diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index aee1a1483..8c6c0045d 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -2577,7 +2577,7 @@ - [ ] list_job_runs - [ ] list_tags_for_resource - [X] start_application -- [ ] start_job_run +- [X] start_job_run - [X] stop_application - [ ] tag_resource - [ ] untag_resource diff --git a/codecov.yml b/codecov.yml index 4fe85e33f..750eb3b17 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,7 +1,7 @@ codecov: notify: # Leave a GitHub comment after all builds have passed - after_n_builds: 8 + after_n_builds: 10 coverage: status: project: diff --git a/docs/docs/multi_account.rst b/docs/docs/multi_account.rst new file mode 100644 index 000000000..b0b4277b7 --- /dev/null +++ b/docs/docs/multi_account.rst @@ -0,0 +1,133 @@ +.. _multi_account: + +===================== +Multi-Account support +===================== + + +By default, Moto processes all requests in a default account: `12345678910`. The exact credentials provided are usually ignored to make the process of mocking requests as hassle-free as possible. + +If you want to mock resources in multiple accounts, or you want to change the default account ID, there are multiple ways to achieve this. + +Configure the default account +------------------------------ + +It is possible to configure the default account ID that will be used for all incoming requests, by setting the environment variable `MOTO_ACCOUNT_ID`. + +Here is an example of what this looks like in practice: + +.. 
sourcecode:: python + + # Create a bucket in the default account + client = boto3.client("s3", region_name="us-east-1") + client.create_bucket(Bucket="bucket-default-account") + + # Configure another account - all subsequent requests will use this account ID + os.environ["MOTO_ACCOUNT_ID"] = "111111111111" + client.create_bucket(Bucket="bucket-in-account-2") + + assert [b["Name"] for b in client.list_buckets()["Buckets"]] == ["bucket-in-account-2"] + + # Now revert to the default account, by removing the environment variable + del os.environ["MOTO_ACCOUNT_ID"] + assert [b["Name"] for b in client.list_buckets()["Buckets"]] == ["bucket-default-account"] + + + +Configure the account ID using a request header +--------------------------------------------------- + +If you are using Moto in ServerMode you can add a custom header to a request, to specify which account should be used. + +.. note:: + + Moto will only look at the request-header if the environment variable is not set. + +As an example, this is how you would create an S3-bucket in another account: + +.. sourcecode:: python + + headers = {"x-moto-account-id": "333344445555"} + requests.put("http://bucket.localhost:5000/", headers=headers) + + # This will return a list of all buckets in account 333344445555 + requests.get("http://localhost:5000", headers=headers) + + # This will return an empty list, as there are no buckets in the default account + requests.get("http://localhost:5000") +
Configure an account using STS +------------------------------ +
The `STS.assume_role()`-feature is useful if you want to temporarily use a different set of access credentials. +Passing in a role that belongs to a different account will return a set of credentials that give access to that account. + +.. note:: + + To avoid any chicken-and-egg problems trying to create roles in non-existing accounts, these Roles do not need to exist. 
+ Moto will only extract the account ID from the role, and create access credentials for that account. + +.. note:: + + Moto will only look at the access credentials if the environment variable and request header is not set. + +Let's look at some examples. + + +.. sourcecode:: python + + # Create a bucket using the default access credentials + client1 = boto3.client("s3", region_name="us-east-1") + client1.create_bucket(Bucket="foobar") + + # Assume a role in our account + # Note that this Role does not need to exist + default_account = "123456789012" + sts = boto3.client("sts") + response = sts.assume_role( + RoleArn=f"arn:aws:iam::{default_account}:role/my-role", + RoleSessionName="test-session-name", + ExternalId="test-external-id", + ) + + # These access credentials give access to the default account + client2 = boto3.client( + "s3", + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + client2.list_buckets()["Buckets"].should.have.length_of(1) + +Because we assumed a role within the same account, we can see the bucket that we've just created. + +Things get interesting when assuming a role within a different account. + +.. 
sourcecode:: python + + # Create a bucket with default access credentials + client1 = boto3.client("s3", region_name="us-east-1") + client1.create_bucket(Bucket="foobar") + + # Assume a role in a different account + # Note that the Role does not need to exist + sts = boto3.client("sts") + response = sts.assume_role( + RoleArn="arn:aws:iam::111111111111:role/role-in-another-account", + RoleSessionName="test-session-name", + ExternalId="test-external-id", + ) + + # Retrieve all buckets in this new account - this will be completely empty + client2 = boto3.client( + "s3", + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + client2.list_buckets()["Buckets"].should.have.length_of(0) + +Because we've assumed a role in a different account, no buckets were found. The `foobar`-bucket only exists in the default account, not in `111111111111`. + diff --git a/docs/index.rst b/docs/index.rst index 3d93d26b8..0e35964aa 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,6 +35,7 @@ Additional Resources docs/faq docs/iam docs/aws_config + docs/multi_account .. 
toctree:: :hidden: diff --git a/moto/acm/models.py b/moto/acm/models.py index 3a3f3e1e6..525f0b6f7 100644 --- a/moto/acm/models.py +++ b/moto/acm/models.py @@ -13,8 +13,6 @@ import cryptography.hazmat.primitives.asymmetric.rsa from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.backends import default_backend -from moto.core import get_account_id - AWS_ROOT_CA = b"""-----BEGIN CERTIFICATE----- MIIESTCCAzGgAwIBAgITBntQXCplJ7wevi2i0ZmY7bibLDANBgkqhkiG9w0BAQsF @@ -123,6 +121,7 @@ class TagHolder(dict): class CertBundle(BaseModel): def __init__( self, + account_id, certificate, private_key, chain=None, @@ -161,12 +160,12 @@ class CertBundle(BaseModel): # Used for when one wants to overwrite an arn if arn is None: - self.arn = make_arn_for_certificate(get_account_id(), region) + self.arn = make_arn_for_certificate(account_id, region) else: self.arn = arn @classmethod - def generate_cert(cls, domain_name, region, sans=None): + def generate_cert(cls, domain_name, account_id, region, sans=None): if sans is None: sans = set() else: @@ -235,10 +234,11 @@ class CertBundle(BaseModel): ) return cls( - cert_armored, - private_key, + certificate=cert_armored, + private_key=private_key, cert_type="AMAZON_ISSUED", cert_status="PENDING_VALIDATION", + account_id=account_id, region=region, ) @@ -435,11 +435,8 @@ class AWSCertificateManagerBackend(BaseBackend): service_region, zones, "acm-pca" ) - @staticmethod - def _arn_not_found(arn): - msg = "Certificate with arn {0} not found in account {1}".format( - arn, get_account_id() - ) + def _arn_not_found(self, arn): + msg = f"Certificate with arn {arn} not found in account {self.account_id}" return AWSResourceNotFoundException(msg) def set_certificate_in_use_by(self, arn, load_balancer_name): @@ -485,6 +482,7 @@ class AWSCertificateManagerBackend(BaseBackend): else: # Will reuse provided ARN bundle = CertBundle( + self.account_id, certificate, private_key, chain=chain, @@ -494,7 +492,11 @@ class 
AWSCertificateManagerBackend(BaseBackend): else: # Will generate a random ARN bundle = CertBundle( - certificate, private_key, chain=chain, region=self.region_name + self.account_id, + certificate, + private_key, + chain=chain, + region=self.region_name, ) self._certificates[bundle.arn] = bundle @@ -546,7 +548,10 @@ class AWSCertificateManagerBackend(BaseBackend): return arn cert = CertBundle.generate_cert( - domain_name, region=self.region_name, sans=subject_alt_names + domain_name, + account_id=self.account_id, + region=self.region_name, + sans=subject_alt_names, ) if idempotency_token is not None: self._set_idempotency_token_arn(idempotency_token, cert.arn) diff --git a/moto/acm/responses.py b/moto/acm/responses.py index d0696e5f4..8e6e534ef 100644 --- a/moto/acm/responses.py +++ b/moto/acm/responses.py @@ -6,6 +6,9 @@ from .models import acm_backends, AWSValidationException class AWSCertificateManagerResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="acm") + @property def acm_backend(self): """ @@ -14,7 +17,7 @@ class AWSCertificateManagerResponse(BaseResponse): :return: ACM Backend object :rtype: moto.acm.models.AWSCertificateManagerBackend """ - return acm_backends[self.region] + return acm_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 1d14262f1..c5d359f00 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -13,7 +13,7 @@ from urllib.parse import urlparse import responses from openapi_spec_validator.exceptions import OpenAPIValidationError -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from .utils import create_id, to_path from moto.core.utils import path_url, BackendDict from .integration_parsers.aws_parser import TypeAwsParser @@ -84,13 +84,13 @@ class Deployment(CloudFormationModel, dict): 
@classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] rest_api_id = properties["RestApiId"] name = properties["StageName"] desc = properties.get("Description", "") - backend = apigateway_backends[region_name] + backend = apigateway_backends[account_id][region_name] return backend.create_deployment( function_id=rest_api_id, name=name, description=desc ) @@ -209,7 +209,7 @@ class Method(CloudFormationModel, dict): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] rest_api_id = properties["RestApiId"] @@ -217,7 +217,7 @@ class Method(CloudFormationModel, dict): method_type = properties["HttpMethod"] auth_type = properties["AuthorizationType"] key_req = properties["ApiKeyRequired"] - backend = apigateway_backends[region_name] + backend = apigateway_backends[account_id][region_name] m = backend.put_method( function_id=rest_api_id, resource_id=resource_id, @@ -253,9 +253,12 @@ class Method(CloudFormationModel, dict): class Resource(CloudFormationModel): - def __init__(self, resource_id, region_name, api_id, path_part, parent_id): + def __init__( + self, resource_id, account_id, region_name, api_id, path_part, parent_id + ): super().__init__() self.id = resource_id + self.account_id = account_id self.region_name = region_name self.api_id = api_id self.path_part = path_part @@ -291,14 +294,14 @@ class Resource(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] api_id = 
properties["RestApiId"] parent = properties["ParentId"] path = properties["PathPart"] - backend = apigateway_backends[region_name] + backend = apigateway_backends[account_id][region_name] if parent == api_id: # A Root path (/) is automatically created. Any new paths should use this as their parent resources = backend.get_resources(function_id=api_id) @@ -315,7 +318,7 @@ class Resource(CloudFormationModel): def get_parent_path(self): if self.parent_id: - backend = apigateway_backends[self.region_name] + backend = apigateway_backends[self.account_id][self.region_name] parent = backend.get_resource(self.api_id, self.parent_id) parent_path = parent.get_path() if parent_path != "/": # Root parent @@ -780,9 +783,10 @@ class RestAPI(CloudFormationModel): OPERATION_VALUE = "value" OPERATION_OP = "op" - def __init__(self, api_id, region_name, name, description, **kwargs): + def __init__(self, api_id, account_id, region_name, name, description, **kwargs): super().__init__() self.id = api_id + self.account_id = account_id self.region_name = region_name self.name = name self.description = description @@ -883,13 +887,13 @@ class RestAPI(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] name = properties["Name"] desc = properties.get("Description", "") config = properties.get("EndpointConfiguration", None) - backend = apigateway_backends[region_name] + backend = apigateway_backends[account_id][region_name] return backend.create_rest_api( name=name, description=desc, endpoint_configuration=config ) @@ -898,6 +902,7 @@ class RestAPI(CloudFormationModel): child_id = create_id() child = Resource( resource_id=child_id, + account_id=self.account_id, region_name=self.region_name, api_id=self.id, path_part=path, @@ -1267,6 +1272,7 @@ class APIGatewayBackend(BaseBackend): 
api_id = create_id() rest_api = RestAPI( api_id, + self.account_id, self.region_name, name, description, @@ -1576,7 +1582,7 @@ class APIGatewayBackend(BaseBackend): ): resource = self.get_resource(function_id, resource_id) if credentials and not re.match( - "^arn:aws:iam::" + str(get_account_id()), credentials + "^arn:aws:iam::" + str(self.account_id), credentials ): raise CrossAccountNotAllowed() if not integration_method and integration_type in [ diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index 501c209e8..7312ca3b9 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -13,6 +13,9 @@ ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"] class APIGatewayResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="apigateway") + def error(self, type_, message, status=400): headers = self.response_headers or {} headers["X-Amzn-Errortype"] = type_ @@ -20,7 +23,7 @@ class APIGatewayResponse(BaseResponse): @property def backend(self): - return apigateway_backends[self.region] + return apigateway_backends[self.current_account][self.region] def __validate_api_key_source(self, api_key_source): if api_key_source and api_key_source not in API_KEY_SOURCES: diff --git a/moto/apigatewayv2/responses.py b/moto/apigatewayv2/responses.py index b025be40f..8f391e6ac 100644 --- a/moto/apigatewayv2/responses.py +++ b/moto/apigatewayv2/responses.py @@ -11,10 +11,13 @@ from .models import apigatewayv2_backends class ApiGatewayV2Response(BaseResponse): """Handler for ApiGatewayV2 requests and responses.""" + def __init__(self): + super().__init__(service_name="apigatewayv2") + @property def apigatewayv2_backend(self): """Return backend instance specific for this region.""" - return apigatewayv2_backends[self.region] + return apigatewayv2_backends[self.current_account][self.region] def apis(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git 
a/moto/applicationautoscaling/models.py b/moto/applicationautoscaling/models.py index 8a6c9cdb5..b0f369de8 100644 --- a/moto/applicationautoscaling/models.py +++ b/moto/applicationautoscaling/models.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.ecs import ecs_backends from .exceptions import AWSValidationException @@ -65,7 +65,7 @@ class ScalableDimensionValueSet(Enum): class ApplicationAutoscalingBackend(BaseBackend): def __init__(self, region_name, account_id): super().__init__(region_name, account_id) - self.ecs_backend = ecs_backends[region_name] + self.ecs_backend = ecs_backends[account_id][region_name] self.targets = OrderedDict() self.policies = {} self.scheduled_actions = list() @@ -77,10 +77,6 @@ class ApplicationAutoscalingBackend(BaseBackend): service_region, zones, "application-autoscaling" ) - @property - def applicationautoscaling_backend(self): - return applicationautoscaling_backends[self.region_name] - def describe_scalable_targets(self, namespace, r_ids=None, dimension=None): """Describe scalable targets.""" if r_ids is None: @@ -305,6 +301,7 @@ class ApplicationAutoscalingBackend(BaseBackend): start_time, end_time, scalable_target_action, + self.account_id, self.region_name, ) self.scheduled_actions.append(action) @@ -450,9 +447,10 @@ class FakeScheduledAction(BaseModel): start_time, end_time, scalable_target_action, + account_id, region, ): - self.arn = f"arn:aws:autoscaling:{region}:{get_account_id()}:scheduledAction:{service_namespace}:scheduledActionName/{scheduled_action_name}" + self.arn = f"arn:aws:autoscaling:{region}:{account_id}:scheduledAction:{service_namespace}:scheduledActionName/{scheduled_action_name}" self.service_namespace = service_namespace self.schedule = schedule self.timezone = timezone diff --git a/moto/applicationautoscaling/responses.py b/moto/applicationautoscaling/responses.py index 
f25520647..6debfae85 100644 --- a/moto/applicationautoscaling/responses.py +++ b/moto/applicationautoscaling/responses.py @@ -9,9 +9,12 @@ from .exceptions import AWSValidationException class ApplicationAutoScalingResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="application-autoscaling") + @property def applicationautoscaling_backend(self): - return applicationautoscaling_backends[self.region] + return applicationautoscaling_backends[self.current_account][self.region] def describe_scalable_targets(self): self._validate_params() diff --git a/moto/appsync/models.py b/moto/appsync/models.py index 7a21b60e5..1b7b65654 100644 --- a/moto/appsync/models.py +++ b/moto/appsync/models.py @@ -1,6 +1,6 @@ import base64 from datetime import timedelta, datetime, timezone -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, unix_time from moto.utilities.tagging_service import TaggingService @@ -53,6 +53,7 @@ class GraphqlSchema(BaseModel): class GraphqlAPI(BaseModel): def __init__( self, + account_id, region, name, authentication_type, @@ -74,9 +75,7 @@ class GraphqlAPI(BaseModel): self.user_pool_config = user_pool_config self.xray_enabled = xray_enabled - self.arn = ( - f"arn:aws:appsync:{self.region}:{get_account_id()}:apis/{self.api_id}" - ) + self.arn = f"arn:aws:appsync:{self.region}:{account_id}:apis/{self.api_id}" self.graphql_schema = None self.api_keys = dict() @@ -205,6 +204,7 @@ class AppSyncBackend(BaseBackend): tags, ): graphql_api = GraphqlAPI( + account_id=self.account_id, region=self.region_name, name=name, authentication_type=authentication_type, diff --git a/moto/appsync/responses.py b/moto/appsync/responses.py index c9ae5f216..eef8da514 100644 --- a/moto/appsync/responses.py +++ b/moto/appsync/responses.py @@ -9,10 +9,13 @@ from .models import appsync_backends class AppSyncResponse(BaseResponse): """Handler for AppSync requests and 
responses.""" + def __init__(self): + super().__init__(service_name="appsync") + @property def appsync_backend(self): """Return backend instance specific for this region.""" - return appsync_backends[self.region] + return appsync_backends[self.current_account][self.region] def graph_ql(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -114,7 +117,6 @@ class AppSyncResponse(BaseResponse): log_config = params.get("logConfig") authentication_type = params.get("authenticationType") user_pool_config = params.get("userPoolConfig") - print(user_pool_config) open_id_connect_config = params.get("openIDConnectConfig") additional_authentication_providers = params.get( "additionalAuthenticationProviders" @@ -152,7 +154,6 @@ class AppSyncResponse(BaseResponse): api_key = self.appsync_backend.create_api_key( api_id=api_id, description=description, expires=expires ) - print(api_key.to_json()) return 200, {}, json.dumps(dict(apiKey=api_key.to_json())) def delete_api_key(self): diff --git a/moto/athena/models.py b/moto/athena/models.py index b3aef3f3c..7bf7e3f79 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -1,6 +1,6 @@ import time -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from uuid import uuid4 @@ -10,18 +10,11 @@ class TaggableResourceMixin(object): # This mixing was copied from Redshift when initially implementing # Athena. TBD if it's worth the overhead. 
- def __init__(self, region_name, resource_name, tags): + def __init__(self, account_id, region_name, resource_name, tags): self.region = region_name self.resource_name = resource_name self.tags = tags or [] - - @property - def arn(self): - return "arn:aws:athena:{region}:{account_id}:{resource_name}".format( - region=self.region, - account_id=get_account_id(), - resource_name=self.resource_name, - ) + self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}" def create_tags(self, tags): new_keys = [tag_set["Key"] for tag_set in tags] @@ -41,7 +34,12 @@ class WorkGroup(TaggableResourceMixin, BaseModel): def __init__(self, athena_backend, name, configuration, description, tags): self.region_name = athena_backend.region_name - super().__init__(self.region_name, "workgroup/{}".format(name), tags) + super().__init__( + athena_backend.account_id, + self.region_name, + "workgroup/{}".format(name), + tags, + ) self.athena_backend = athena_backend self.name = name self.description = description @@ -53,7 +51,12 @@ class DataCatalog(TaggableResourceMixin, BaseModel): self, athena_backend, name, catalog_type, description, parameters, tags ): self.region_name = athena_backend.region_name - super().__init__(self.region_name, "datacatalog/{}".format(name), tags) + super().__init__( + athena_backend.account_id, + self.region_name, + "datacatalog/{}".format(name), + tags, + ) self.athena_backend = athena_backend self.name = name self.type = catalog_type diff --git a/moto/athena/responses.py b/moto/athena/responses.py index 218e55471..b47f5ed7b 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -5,9 +5,12 @@ from .models import athena_backends class AthenaResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="athena") + @property def athena_backend(self): - return athena_backends[self.region] + return athena_backends[self.current_account][self.region] def create_work_group(self): name = self._get_param("Name") diff --git 
a/moto/autoscaling/models.py b/moto/autoscaling/models.py index 1bf45b1dc..e9659f289 100644 --- a/moto/autoscaling/models.py +++ b/moto/autoscaling/models.py @@ -9,7 +9,7 @@ from moto.packages.boto.ec2.blockdevicemapping import ( from moto.ec2.exceptions import InvalidInstanceIdError from collections import OrderedDict -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import camelcase_to_underscores, BackendDict from moto.ec2 import ec2_backends from moto.elb import elb_backends @@ -97,7 +97,7 @@ class FakeScalingPolicy(BaseModel): @property def arn(self): - return f"arn:aws:autoscaling:{self.autoscaling_backend.region_name}:{get_account_id()}:scalingPolicy:c322761b-3172-4d56-9a21-0ed9d6161d67:autoScalingGroupName/{self.as_name}:policyName/{self.name}" + return f"arn:aws:autoscaling:{self.autoscaling_backend.region_name}:{self.autoscaling_backend.account_id}:scalingPolicy:c322761b-3172-4d56-9a21-0ed9d6161d67:autoScalingGroupName/{self.as_name}:policyName/{self.name}" def execute(self): if self.adjustment_type == "ExactCapacity": @@ -131,6 +131,7 @@ class FakeLaunchConfiguration(CloudFormationModel): ebs_optimized, associate_public_ip_address, block_device_mapping_dict, + account_id, region_name, metadata_options, classic_link_vpc_id, @@ -157,7 +158,7 @@ class FakeLaunchConfiguration(CloudFormationModel): self.metadata_options = metadata_options self.classic_link_vpc_id = classic_link_vpc_id self.classic_link_vpc_security_groups = classic_link_vpc_security_groups - self.arn = f"arn:aws:autoscaling:{region_name}:{get_account_id()}:launchConfiguration:9dbbbf87-6141-428a-a409-0752edbe6cad:launchConfigurationName/{self.name}" + self.arn = f"arn:aws:autoscaling:{region_name}:{account_id}:launchConfiguration:9dbbbf87-6141-428a-a409-0752edbe6cad:launchConfigurationName/{self.name}" @classmethod def create_from_instance(cls, name, instance, backend): @@ 
-191,13 +192,13 @@ class FakeLaunchConfiguration(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] instance_profile_name = properties.get("IamInstanceProfile") - backend = autoscaling_backends[region_name] + backend = autoscaling_backends[account_id][region_name] config = backend.create_launch_configuration( name=resource_name, image_id=properties.get("ImageId"), @@ -218,27 +219,32 @@ class FakeLaunchConfiguration(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - backend = autoscaling_backends[region_name] + backend = autoscaling_backends[account_id][region_name] try: backend.delete_launch_configuration(resource_name) except KeyError: pass - def delete(self, region_name): - backend = autoscaling_backends[region_name] + def delete(self, account_id, region_name): + backend = autoscaling_backends[account_id][region_name] backend.delete_launch_configuration(self.name) @property @@ -315,12 +321,12 @@ class FakeScheduledAction(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, 
**kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - backend = autoscaling_backends[region_name] + backend = autoscaling_backends[account_id][region_name] scheduled_action_name = ( kwargs["LogicalId"] @@ -369,6 +375,7 @@ class FakeAutoScalingGroup(CloudFormationModel): self.name = name self._id = str(uuid4()) self.region = self.autoscaling_backend.region_name + self.account_id = self.autoscaling_backend.account_id self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier) @@ -415,7 +422,7 @@ class FakeAutoScalingGroup(CloudFormationModel): @property def arn(self): - return f"arn:aws:autoscaling:{self.region}:{get_account_id()}:autoScalingGroup:{self._id}:autoScalingGroupName/{self.name}" + return f"arn:aws:autoscaling:{self.region}:{self.account_id}:autoScalingGroup:{self._id}:autoScalingGroupName/{self.name}" def active_instances(self): return [x for x in self.instance_states if x.lifecycle_state == "InService"] @@ -498,7 +505,7 @@ class FakeAutoScalingGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] @@ -510,7 +517,7 @@ class FakeAutoScalingGroup(CloudFormationModel): load_balancer_names = properties.get("LoadBalancerNames", []) target_group_arns = properties.get("TargetGroupARNs", []) - backend = autoscaling_backends[region_name] + backend = autoscaling_backends[account_id][region_name] group = backend.create_auto_scaling_group( name=resource_name, availability_zones=properties.get("AvailabilityZones", []), @@ -540,27 +547,32 @@ class FakeAutoScalingGroup(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + 
cloudformation_json, + account_id, + region_name, ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - backend = autoscaling_backends[region_name] + backend = autoscaling_backends[account_id][region_name] try: backend.delete_auto_scaling_group(resource_name) except KeyError: pass - def delete(self, region_name): - backend = autoscaling_backends[region_name] + def delete(self, account_id, region_name): + backend = autoscaling_backends[account_id][region_name] backend.delete_auto_scaling_group(self.name) @property @@ -740,9 +752,9 @@ class AutoScalingBackend(BaseBackend): self.scheduled_actions = OrderedDict() self.policies = {} self.lifecycle_hooks = {} - self.ec2_backend = ec2_backends[region_name] - self.elb_backend = elb_backends[region_name] - self.elbv2_backend = elbv2_backends[region_name] + self.ec2_backend = ec2_backends[self.account_id][region_name] + self.elb_backend = elb_backends[self.account_id][region_name] + self.elbv2_backend = elbv2_backends[self.account_id][region_name] @staticmethod def default_vpc_endpoint_service(service_region, zones): @@ -800,6 +812,7 @@ class AutoScalingBackend(BaseBackend): ebs_optimized=ebs_optimized, associate_public_ip_address=associate_public_ip_address, block_device_mapping_dict=block_device_mappings, + account_id=self.account_id, region_name=self.region_name, metadata_options=metadata_options, classic_link_vpc_id=classic_link_vpc_id, diff --git a/moto/autoscaling/responses.py b/moto/autoscaling/responses.py index ebc8901d7..93c401f3d 100644 --- 
a/moto/autoscaling/responses.py +++ b/moto/autoscaling/responses.py @@ -10,9 +10,12 @@ from .models import autoscaling_backends class AutoScalingResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="autoscaling") + @property def autoscaling_backend(self): - return autoscaling_backends[self.region] + return autoscaling_backends[self.current_account][self.region] def create_launch_configuration(self): instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled") diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index d9d82d717..819a5cd44 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -51,7 +51,6 @@ from .utils import ( from moto.sqs import sqs_backends from moto.dynamodb import dynamodb_backends from moto.dynamodbstreams import dynamodbstreams_backends -from moto.core import get_account_id from moto.utilities.docker_utilities import DockerModel, parse_image_ref from tempfile import TemporaryDirectory from uuid import uuid4 @@ -179,11 +178,13 @@ def _s3_content(key): return key.value, key.size, base64ed_sha, sha_hex_digest -def _validate_s3_bucket_and_key(data): +def _validate_s3_bucket_and_key(account_id, data): key = None try: # FIXME: does not validate bucket region - key = s3_backends["global"].get_object(data["S3Bucket"], data["S3Key"]) + key = s3_backends[account_id]["global"].get_object( + data["S3Bucket"], data["S3Key"] + ) except MissingBucket: if do_validate_s3(): raise InvalidParameterValueException( @@ -212,18 +213,19 @@ class Permission(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - backend = lambda_backends[region_name] + backend = lambda_backends[account_id][region_name] fn = backend.get_function(properties["FunctionName"]) 
fn.policy.add_statement(raw=json.dumps(properties)) return Permission(region=region_name) class LayerVersion(CloudFormationModel): - def __init__(self, spec, region): + def __init__(self, spec, account_id, region): # required + self.account_id = account_id self.region = region self.name = spec["LayerName"] self.content = spec["Content"] @@ -248,7 +250,7 @@ class LayerVersion(CloudFormationModel): self.code_digest, ) = _zipfile_content(self.content["ZipFile"]) else: - key = _validate_s3_bucket_and_key(self.content) + key = _validate_s3_bucket_and_key(account_id, data=self.content) if key: ( self.code_bytes, @@ -261,7 +263,7 @@ class LayerVersion(CloudFormationModel): def arn(self): if self.version: return make_layer_ver_arn( - self.region, get_account_id(), self.name, self.version + self.region, self.account_id, self.name, self.version ) raise ValueError("Layer version is not set") @@ -297,7 +299,7 @@ class LayerVersion(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] optional_properties = ("Description", "CompatibleRuntimes", "LicenseInfo") @@ -311,16 +313,25 @@ class LayerVersion(CloudFormationModel): if prop in properties: spec[prop] = properties[prop] - backend = lambda_backends[region_name] + backend = lambda_backends[account_id][region_name] layer_version = backend.publish_layer_version(spec) return layer_version class LambdaAlias(BaseModel): def __init__( - self, region, name, function_name, function_version, description, routing_config + self, + account_id, + region, + name, + function_name, + function_version, + description, + routing_config, ): - self.arn = f"arn:aws:lambda:{region}:{get_account_id()}:function:{function_name}:{name}" + self.arn = ( + f"arn:aws:lambda:{region}:{account_id}:function:{function_name}:{name}" + ) self.name = name 
self.function_version = function_version self.description = description @@ -347,11 +358,13 @@ class LambdaAlias(BaseModel): class Layer(object): - def __init__(self, name, region): - self.region = region - self.name = name + def __init__(self, layer_version: LayerVersion): + self.region = layer_version.region + self.name = layer_version.name - self.layer_arn = make_layer_arn(region, get_account_id(), self.name) + self.layer_arn = make_layer_arn( + self.region, layer_version.account_id, self.name + ) self._latest_version = 0 self.layer_versions = {} @@ -378,16 +391,17 @@ class Layer(object): class LambdaFunction(CloudFormationModel, DockerModel): - def __init__(self, spec, region, version=1): + def __init__(self, account_id, spec, region, version=1): DockerModel.__init__(self) # required + self.account_id = account_id self.region = region self.code = spec["Code"] self.function_name = spec["FunctionName"] self.handler = spec.get("Handler") self.role = spec["Role"] self.run_time = spec.get("Runtime") - self.logs_backend = logs_backends[self.region] + self.logs_backend = logs_backends[account_id][self.region] self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.policy = None self.state = "Active" @@ -428,7 +442,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.code["UUID"] = str(uuid.uuid4()) self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"]) elif "S3Bucket" in self.code: - key = _validate_s3_bucket_and_key(self.code) + key = _validate_s3_bucket_and_key(self.account_id, data=self.code) if key: ( self.code_bytes, @@ -447,7 +461,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.code_size = 0 self.function_arn = make_function_arn( - self.region, get_account_id(), self.function_name + self.region, self.account_id, self.function_name ) if spec.get("Tags"): @@ -459,7 +473,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): def set_version(self, version): self.function_arn = 
make_function_ver_arn( - self.region, get_account_id(), self.function_name, version + self.region, self.account_id, self.function_name, version ) self.version = version self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") @@ -479,7 +493,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): return json.dumps(self.get_configuration()) def _get_layers_data(self, layers_versions_arns): - backend = lambda_backends[self.region] + backend = lambda_backends[self.account_id][self.region] layer_versions = [ backend.layers_versions_by_arn(layer_version) for layer_version in layers_versions_arns @@ -602,7 +616,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): key = None try: # FIXME: does not validate bucket region - key = s3_backends["global"].get_object( + key = s3_backends[self.account_id]["global"].get_object( updated_spec["S3Bucket"], updated_spec["S3Key"] ) except MissingBucket: @@ -791,7 +805,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] optional_properties = ( @@ -827,7 +841,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): cls._create_zipfile_from_plaintext_code(spec["Code"]["ZipFile"]) ) - backend = lambda_backends[region_name] + backend = lambda_backends[account_id][region_name] fn = backend.create_function(spec) return fn @@ -839,12 +853,17 @@ class LambdaFunction(CloudFormationModel, DockerModel): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": - return make_function_arn(self.region, get_account_id(), self.function_name) + return make_function_arn(self.region, self.account_id, self.function_name) raise UnformattedGetAttTemplateException() @classmethod def update_from_cloudformation_json( - cls, 
original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): updated_props = cloudformation_json["Properties"] original_resource.update_configuration(updated_props) @@ -865,8 +884,8 @@ class LambdaFunction(CloudFormationModel, DockerModel): zip_output.seek(0) return zip_output.read() - def delete(self, region): - lambda_backends[region].delete_function(self.function_name) + def delete(self, account_id, region): + lambda_backends[account_id][region].delete_function(self.function_name) def delete_alias(self, name): self._aliases.pop(name, None) @@ -874,11 +893,12 @@ class LambdaFunction(CloudFormationModel, DockerModel): def get_alias(self, name): if name in self._aliases: return self._aliases[name] - arn = f"arn:aws:lambda:{self.region}:{get_account_id()}:function:{self.function_name}:{name}" + arn = f"arn:aws:lambda:{self.region}:{self.account_id}:function:{self.function_name}:{name}" raise UnknownAliasException(arn) def put_alias(self, name, description, function_version, routing_config): alias = LambdaAlias( + account_id=self.account_id, region=self.region, name=name, function_name=self.function_name, @@ -968,8 +988,8 @@ class EventSourceMapping(CloudFormationModel): "StateTransitionReason": "User initiated", } - def delete(self, region_name): - lambda_backend = lambda_backends[region_name] + def delete(self, account_id, region_name): + lambda_backend = lambda_backends[account_id][region_name] lambda_backend.delete_event_source_mapping(self.uuid) @staticmethod @@ -983,27 +1003,32 @@ class EventSourceMapping(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - lambda_backend = lambda_backends[region_name] + lambda_backend = 
lambda_backends[account_id][region_name] return lambda_backend.create_event_source_mapping(properties) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] event_source_uuid = original_resource.uuid - lambda_backend = lambda_backends[region_name] + lambda_backend = lambda_backends[account_id][region_name] return lambda_backend.update_event_source_mapping(event_source_uuid, properties) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): properties = cloudformation_json["Properties"] - lambda_backend = lambda_backends[region_name] + lambda_backend = lambda_backends[account_id][region_name] esms = lambda_backend.list_event_source_mappings( event_source_arn=properties["EventSourceArn"], function_name=properties["FunctionName"], @@ -1011,7 +1036,7 @@ class EventSourceMapping(CloudFormationModel): for esm in esms: if esm.uuid == resource_name: - esm.delete(region_name) + esm.delete(account_id, region_name) @property def physical_resource_id(self): @@ -1036,22 +1061,23 @@ class LambdaVersion(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] function_name = properties["FunctionName"] - func = lambda_backends[region_name].publish_function(function_name) + func = lambda_backends[account_id][region_name].publish_function(function_name) spec = {"Version": func.version} return LambdaVersion(spec) class LambdaStorage(object): - def __init__(self, region_name): + def __init__(self, region_name, account_id): # Format 'func_name' 
{'versions': []} self._functions = {} self._aliases = dict() self._arns = weakref.WeakValueDictionary() self.region_name = region_name + self.account_id = account_id def _get_latest(self, name): return self._functions[name]["latest"] @@ -1120,7 +1146,7 @@ class LambdaStorage(object): if name_or_arn.startswith("arn:aws"): arn = name_or_arn else: - arn = make_function_arn(self.region_name, get_account_id(), name_or_arn) + arn = make_function_arn(self.region_name, self.account_id, name_or_arn) if qualifier: arn = f"{arn}:{qualifier}" raise UnknownFunctionException(arn) @@ -1134,10 +1160,11 @@ class LambdaStorage(object): valid_role = re.match(InvalidRoleFormat.pattern, fn.role) if valid_role: account = valid_role.group(2) - if account != get_account_id(): + if account != self.account_id: raise CrossAccountNotAllowed() try: - iam_backends["global"].get_role_by_arn(fn.role) + iam_backend = iam_backends[self.account_id]["global"] + iam_backend.get_role_by_arn(fn.role) except IAMNotFoundException: raise InvalidParameterValueException( "The role defined for the function cannot be assumed by Lambda." 
@@ -1240,9 +1267,7 @@ class LayerStorage(object): :param layer_version: LayerVersion """ if layer_version.name not in self._layers: - self._layers[layer_version.name] = Layer( - layer_version.name, layer_version.region - ) + self._layers[layer_version.name] = Layer(layer_version) self._layers[layer_version.name].attach_version(layer_version) def list_layers(self): @@ -1328,7 +1353,7 @@ class LambdaBackend(BaseBackend): def __init__(self, region_name, account_id): super().__init__(region_name, account_id) - self._lambdas = LambdaStorage(region_name=region_name) + self._lambdas = LambdaStorage(region_name=region_name, account_id=account_id) self._event_source_mappings = {} self._layers = LayerStorage() @@ -1367,7 +1392,12 @@ class LambdaBackend(BaseBackend): if function_name is None: raise RESTError("InvalidParameterValueException", "Missing FunctionName") - fn = LambdaFunction(spec, self.region_name, version="$LATEST") + fn = LambdaFunction( + account_id=self.account_id, + spec=spec, + region=self.region_name, + version="$LATEST", + ) self._lambdas.put_function(fn) @@ -1393,7 +1423,8 @@ class LambdaBackend(BaseBackend): raise RESTError("ResourceNotFoundException", "Invalid FunctionName") # Validate queue - for queue in sqs_backends[self.region_name].queues.values(): + sqs_backend = sqs_backends[self.account_id][self.region_name] + for queue in sqs_backend.queues.values(): if queue.queue_arn == spec["EventSourceArn"]: if queue.lambda_event_source_mappings.get("func.function_arn"): # TODO: Correct exception? 
@@ -1414,15 +1445,15 @@ class LambdaBackend(BaseBackend): queue.lambda_event_source_mappings[esm.function_arn] = esm return esm - for stream in json.loads( - dynamodbstreams_backends[self.region_name].list_streams() - )["Streams"]: + ddbstream_backend = dynamodbstreams_backends[self.account_id][self.region_name] + ddb_backend = dynamodb_backends[self.account_id][self.region_name] + for stream in json.loads(ddbstream_backend.list_streams())["Streams"]: if stream["StreamArn"] == spec["EventSourceArn"]: spec.update({"FunctionArn": func.function_arn}) esm = EventSourceMapping(spec) self._event_source_mappings[esm.uuid] = esm table_name = stream["TableName"] - table = dynamodb_backends[self.region_name].get_table(table_name) + table = ddb_backend.get_table(table_name) table.lambda_event_source_mappings[esm.function_arn] = esm return esm raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn") @@ -1432,7 +1463,9 @@ class LambdaBackend(BaseBackend): for param in required: if not spec.get(param): raise InvalidParameterValueException("Missing {}".format(param)) - layer_version = LayerVersion(spec, self.region_name) + layer_version = LayerVersion( + spec, account_id=self.account_id, region=self.region_name + ) self._layers.put_layer_version(layer_version) return layer_version @@ -1592,7 +1625,7 @@ class LambdaBackend(BaseBackend): ): data = { "messageType": "DATA_MESSAGE", - "owner": get_account_id(), + "owner": self.account_id, "logGroup": log_group_name, "logStream": log_stream_name, "subscriptionFilters": [filter_name], diff --git a/moto/awslambda/responses.py b/moto/awslambda/responses.py index c1d2f4a0a..ca355d583 100644 --- a/moto/awslambda/responses.py +++ b/moto/awslambda/responses.py @@ -9,6 +9,9 @@ from .models import lambda_backends class LambdaResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="awslambda") + @property def json_body(self): """ @@ -18,13 +21,8 @@ class LambdaResponse(BaseResponse): return 
json.loads(self.body) @property - def lambda_backend(self): - """ - Get backend - :return: Lambda Backend - :rtype: moto.awslambda.models.LambdaBackend - """ - return lambda_backends[self.region] + def backend(self): + return lambda_backends[self.current_account][self.region] def root(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -195,13 +193,13 @@ class LambdaResponse(BaseResponse): function_name = unquote(path.split("/")[-2]) qualifier = self.querystring.get("Qualifier", [None])[0] statement = self.body - self.lambda_backend.add_permission(function_name, qualifier, statement) + self.backend.add_permission(function_name, qualifier, statement) return 200, {}, json.dumps({"Statement": statement}) def _get_policy(self, request): path = request.path if hasattr(request, "path") else path_url(request.url) function_name = unquote(path.split("/")[-2]) - out = self.lambda_backend.get_policy(function_name) + out = self.backend.get_policy(function_name) return 200, {}, out def _del_policy(self, request, querystring): @@ -209,8 +207,8 @@ class LambdaResponse(BaseResponse): function_name = unquote(path.split("/")[-3]) statement_id = path.split("/")[-1].split("?")[0] revision = querystring.get("RevisionId", "") - if self.lambda_backend.get_function(function_name): - self.lambda_backend.remove_permission(function_name, statement_id, revision) + if self.backend.get_function(function_name): + self.backend.remove_permission(function_name, statement_id, revision) return 204, {}, "{}" else: return 404, {}, "{}" @@ -222,7 +220,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("qualifier") - payload = self.lambda_backend.invoke( + payload = self.backend.invoke( function_name, qualifier, self.body, self.headers, response_headers ) if payload: @@ -254,7 +252,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.rsplit("/", 3)[-3]) - fn = 
self.lambda_backend.get_function(function_name, None) + fn = self.backend.get_function(function_name, None) payload = fn.invoke(self.body, self.headers, response_headers) response_headers["Content-Length"] = str(len(payload)) return 202, response_headers, payload @@ -264,7 +262,7 @@ class LambdaResponse(BaseResponse): func_version = querystring.get("FunctionVersion", [None])[0] result = {"Functions": []} - for fn in self.lambda_backend.list_functions(func_version): + for fn in self.backend.list_functions(func_version): json_data = fn.get_configuration() result["Functions"].append(json_data) @@ -273,7 +271,7 @@ class LambdaResponse(BaseResponse): def _list_versions_by_function(self, function_name): result = {"Versions": []} - functions = self.lambda_backend.list_versions_by_function(function_name) + functions = self.backend.list_versions_by_function(function_name) if functions: for fn in functions: json_data = fn.get_configuration() @@ -282,38 +280,36 @@ class LambdaResponse(BaseResponse): return 200, {}, json.dumps(result) def _create_function(self): - fn = self.lambda_backend.create_function(self.json_body) + fn = self.backend.create_function(self.json_body) config = fn.get_configuration(on_create=True) return 201, {}, json.dumps(config) def _create_event_source_mapping(self): - fn = self.lambda_backend.create_event_source_mapping(self.json_body) + fn = self.backend.create_event_source_mapping(self.json_body) config = fn.get_configuration() return 201, {}, json.dumps(config) def _list_event_source_mappings(self, event_source_arn, function_name): - esms = self.lambda_backend.list_event_source_mappings( - event_source_arn, function_name - ) + esms = self.backend.list_event_source_mappings(event_source_arn, function_name) result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]} return 200, {}, json.dumps(result) def _get_event_source_mapping(self, uuid): - result = self.lambda_backend.get_event_source_mapping(uuid) + result = 
self.backend.get_event_source_mapping(uuid) if result: return 200, {}, json.dumps(result.get_configuration()) else: return 404, {}, "{}" def _update_event_source_mapping(self, uuid): - result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body) + result = self.backend.update_event_source_mapping(uuid, self.json_body) if result: return 202, {}, json.dumps(result.get_configuration()) else: return 404, {}, "{}" def _delete_event_source_mapping(self, uuid): - esm = self.lambda_backend.delete_event_source_mapping(uuid) + esm = self.backend.delete_event_source_mapping(uuid) if esm: json_result = esm.get_configuration() json_result.update({"State": "Deleting"}) @@ -325,7 +321,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.split("/")[-2]) description = self._get_param("Description") - fn = self.lambda_backend.publish_function(function_name, description) + fn = self.backend.publish_function(function_name, description) config = fn.get_configuration() return 201, {}, json.dumps(config) @@ -333,7 +329,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.rsplit("/", 1)[-1]) qualifier = self._get_param("Qualifier", None) - self.lambda_backend.delete_function(function_name, qualifier) + self.backend.delete_function(function_name, qualifier) return 204, {}, "" @staticmethod @@ -348,7 +344,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.rsplit("/", 1)[-1]) qualifier = self._get_param("Qualifier", None) - fn = self.lambda_backend.get_function(function_name, qualifier) + fn = self.backend.get_function(function_name, qualifier) code = fn.get_code() code["Configuration"] = self._set_configuration_qualifier( @@ -360,7 +356,7 @@ class LambdaResponse(BaseResponse): function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) - fn = self.lambda_backend.get_function(function_name, qualifier) + fn = self.backend.get_function(function_name, qualifier) 
configuration = self._set_configuration_qualifier( fn.get_configuration(), qualifier @@ -377,26 +373,26 @@ class LambdaResponse(BaseResponse): def _list_tags(self): function_arn = unquote(self.path.rsplit("/", 1)[-1]) - tags = self.lambda_backend.list_tags(function_arn) + tags = self.backend.list_tags(function_arn) return 200, {}, json.dumps({"Tags": tags}) def _tag_resource(self): function_arn = unquote(self.path.rsplit("/", 1)[-1]) - self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]) + self.backend.tag_resource(function_arn, self.json_body["Tags"]) return 200, {}, "{}" def _untag_resource(self): function_arn = unquote(self.path.rsplit("/", 1)[-1]) tag_keys = self.querystring["tagKeys"] - self.lambda_backend.untag_resource(function_arn, tag_keys) + self.backend.untag_resource(function_arn, tag_keys) return 204, {}, "{}" def _put_configuration(self): function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) - resp = self.lambda_backend.update_function_configuration( + resp = self.backend.update_function_configuration( function_name, qualifier, body=self.json_body ) @@ -408,7 +404,7 @@ class LambdaResponse(BaseResponse): def _put_code(self): function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) - resp = self.lambda_backend.update_function_code( + resp = self.backend.update_function_code( function_name, qualifier, body=self.json_body ) @@ -419,65 +415,63 @@ class LambdaResponse(BaseResponse): def _get_code_signing_config(self): function_name = unquote(self.path.rsplit("/", 2)[-2]) - resp = self.lambda_backend.get_code_signing_config(function_name) + resp = self.backend.get_code_signing_config(function_name) return 200, {}, json.dumps(resp) def _get_function_concurrency(self): path_function_name = unquote(self.path.rsplit("/", 2)[-2]) - function_name = self.lambda_backend.get_function(path_function_name) + function_name = 
self.backend.get_function(path_function_name) if function_name is None: return 404, {}, "{}" - resp = self.lambda_backend.get_function_concurrency(path_function_name) + resp = self.backend.get_function_concurrency(path_function_name) return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp}) def _delete_function_concurrency(self): path_function_name = unquote(self.path.rsplit("/", 2)[-2]) - function_name = self.lambda_backend.get_function(path_function_name) + function_name = self.backend.get_function(path_function_name) if function_name is None: return 404, {}, "{}" - self.lambda_backend.delete_function_concurrency(path_function_name) + self.backend.delete_function_concurrency(path_function_name) return 204, {}, "{}" def _put_function_concurrency(self): path_function_name = unquote(self.path.rsplit("/", 2)[-2]) - function = self.lambda_backend.get_function(path_function_name) + function = self.backend.get_function(path_function_name) if function is None: return 404, {}, "{}" concurrency = self._get_param("ReservedConcurrentExecutions", None) - resp = self.lambda_backend.put_function_concurrency( - path_function_name, concurrency - ) + resp = self.backend.put_function_concurrency(path_function_name, concurrency) return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp}) def _list_layers(self): - layers = self.lambda_backend.list_layers() + layers = self.backend.list_layers() return 200, {}, json.dumps({"Layers": layers}) def _delete_layer_version(self): layer_name = self.path.split("/")[-3] layer_version = self.path.split("/")[-1] - self.lambda_backend.delete_layer_version(layer_name, layer_version) + self.backend.delete_layer_version(layer_name, layer_version) return 200, {}, "{}" def _get_layer_version(self): layer_name = self.path.split("/")[-3] layer_version = self.path.split("/")[-1] - layer = self.lambda_backend.get_layer_version(layer_name, layer_version) + layer = self.backend.get_layer_version(layer_name, layer_version) return 200, {}, 
json.dumps(layer.get_layer_version()) def _get_layer_versions(self): layer_name = self.path.rsplit("/", 2)[-2] - layer_versions = self.lambda_backend.get_layer_versions(layer_name) + layer_versions = self.backend.get_layer_versions(layer_name) return ( 200, {}, @@ -490,7 +484,7 @@ class LambdaResponse(BaseResponse): spec = self.json_body if "LayerName" not in spec: spec["LayerName"] = self.path.rsplit("/", 2)[-2] - layer_version = self.lambda_backend.publish_layer_version(spec) + layer_version = self.backend.publish_layer_version(spec) config = layer_version.get_layer_version() return 201, {}, json.dumps(config) @@ -501,7 +495,7 @@ class LambdaResponse(BaseResponse): description = params.get("Description", "") function_version = params.get("FunctionVersion") routing_config = params.get("RoutingConfig") - alias = self.lambda_backend.create_alias( + alias = self.backend.create_alias( name=alias_name, function_name=function_name, function_version=function_version, @@ -513,15 +507,13 @@ class LambdaResponse(BaseResponse): def _delete_alias(self): function_name = unquote(self.path.rsplit("/")[-3]) alias_name = unquote(self.path.rsplit("/", 2)[-1]) - self.lambda_backend.delete_alias(name=alias_name, function_name=function_name) + self.backend.delete_alias(name=alias_name, function_name=function_name) return 201, {}, "{}" def _get_alias(self): function_name = unquote(self.path.rsplit("/")[-3]) alias_name = unquote(self.path.rsplit("/", 2)[-1]) - alias = self.lambda_backend.get_alias( - name=alias_name, function_name=function_name - ) + alias = self.backend.get_alias(name=alias_name, function_name=function_name) return 201, {}, json.dumps(alias.to_json()) def _update_alias(self): @@ -531,7 +523,7 @@ class LambdaResponse(BaseResponse): description = params.get("Description") function_version = params.get("FunctionVersion") routing_config = params.get("RoutingConfig") - alias = self.lambda_backend.update_alias( + alias = self.backend.update_alias( name=alias_name, 
function_name=function_name, function_version=function_version, diff --git a/moto/backends.py b/moto/backends.py index 048b56851..ff4137487 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -26,8 +26,9 @@ def backends(): yield _import_backend(module_name, backends_name) -def unique_backends(): - for module_name, backends_name in sorted(set(BACKENDS.values())): +def service_backends(): + services = [(f.name, f.backend) for f in decorator_functions] + for module_name, backends_name in sorted(set(services)): yield _import_backend(module_name, backends_name) diff --git a/moto/batch/models.py b/moto/batch/models.py index 8afe18dc6..cafb5db1c 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -9,7 +9,7 @@ import threading import dateutil.parser from sys import platform -from moto.core import BaseBackend, BaseModel, CloudFormationModel, get_account_id +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.iam import iam_backends from moto.ec2 import ec2_backends from moto.ecs import ecs_backends @@ -60,6 +60,7 @@ class ComputeEnvironment(CloudFormationModel): state, compute_resources, service_role, + account_id, region_name, ): self.name = compute_environment_name @@ -68,7 +69,7 @@ class ComputeEnvironment(CloudFormationModel): self.compute_resources = compute_resources self.service_role = service_role self.arn = make_arn_for_compute_env( - get_account_id(), compute_environment_name, region_name + account_id, compute_environment_name, region_name ) self.instances = [] @@ -97,9 +98,9 @@ class ComputeEnvironment(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - backend = batch_backends[region_name] + backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] env = backend.create_compute_environment( @@ -122,7 +123,6 @@ class 
JobQueue(CloudFormationModel): state, environments, env_order_json, - region_name, backend, tags=None, ): @@ -137,15 +137,13 @@ class JobQueue(CloudFormationModel): :type environments: list of ComputeEnvironment :param env_order_json: Compute Environments JSON for use when describing :type env_order_json: list of dict - :param region_name: Region name - :type region_name: str """ self.name = name self.priority = priority self.state = state self.environments = environments self.env_order_json = env_order_json - self.arn = make_arn_for_job_queue(get_account_id(), name, region_name) + self.arn = make_arn_for_job_queue(backend.account_id, name, backend.region_name) self.status = "VALID" self.backend = backend @@ -182,9 +180,9 @@ class JobQueue(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - backend = batch_backends[region_name] + backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] # Need to deal with difference case from cloudformation compute_resources, e.g. 
instanceRole vs InstanceRole @@ -212,7 +210,6 @@ class JobDefinition(CloudFormationModel): parameters, _type, container_properties, - region_name, tags=None, revision=0, retry_strategy=0, @@ -225,7 +222,7 @@ class JobDefinition(CloudFormationModel): self.retry_strategy = retry_strategy self.type = _type self.revision = revision - self._region = region_name + self._region = backend.region_name self.container_properties = container_properties self.arn = None self.status = "ACTIVE" @@ -257,7 +254,7 @@ class JobDefinition(CloudFormationModel): def _update_arn(self): self.revision += 1 self.arn = make_arn_for_task_def( - get_account_id(), self.name, self.revision, self._region + self.backend.account_id, self.name, self.revision, self._region ) def _get_resource_requirement(self, req_type, default=None): @@ -347,7 +344,6 @@ class JobDefinition(CloudFormationModel): parameters, _type, container_properties, - region_name=self._region, revision=self.revision, retry_strategy=retry_strategy, tags=tags, @@ -392,9 +388,9 @@ class JobDefinition(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - backend = batch_backends[region_name] + backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] res = backend.register_job_definition( def_name=resource_name, @@ -844,7 +840,7 @@ class BatchBackend(BaseBackend): :return: IAM Backend :rtype: moto.iam.models.IAMBackend """ - return iam_backends["global"] + return iam_backends[self.account_id]["global"] @property def ec2_backend(self): @@ -852,7 +848,7 @@ class BatchBackend(BaseBackend): :return: EC2 Backend :rtype: moto.ec2.models.EC2Backend """ - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] @property def ecs_backend(self): @@ -860,7 +856,7 @@ class BatchBackend(BaseBackend): :return: 
ECS Backend :rtype: moto.ecs.models.EC2ContainerServiceBackend """ - return ecs_backends[self.region_name] + return ecs_backends[self.account_id][self.region_name] @property def logs_backend(self): @@ -868,7 +864,7 @@ class BatchBackend(BaseBackend): :return: ECS Backend :rtype: moto.logs.models.LogsBackend """ - return logs_backends[self.region_name] + return logs_backends[self.account_id][self.region_name] def reset(self): for job in self._jobs.values(): @@ -1077,6 +1073,7 @@ class BatchBackend(BaseBackend): state, compute_resources, service_role, + account_id=self.account_id, region_name=self.region_name, ) self._compute_environments[new_comp_env.arn] = new_comp_env @@ -1344,7 +1341,6 @@ class BatchBackend(BaseBackend): state, env_objects, compute_env_order, - self.region_name, backend=self, tags=tags, ) @@ -1450,7 +1446,6 @@ class BatchBackend(BaseBackend): _type, container_properties, tags=tags, - region_name=self.region_name, retry_strategy=retry_strategy, timeout=timeout, backend=self, diff --git a/moto/batch/responses.py b/moto/batch/responses.py index bd7ec5e55..ee6c97c7f 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -6,6 +6,9 @@ import json class BatchResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="batch") + def _error(self, code, message): return json.dumps({"__type": code, "message": message}), dict(status=400) @@ -15,7 +18,7 @@ class BatchResponse(BaseResponse): :return: Batch Backend :rtype: moto.batch.models.BatchBackend """ - return batch_backends[self.region] + return batch_backends[self.current_account][self.region] @property def json(self): diff --git a/moto/batch_simple/models.py b/moto/batch_simple/models.py index 06ffdfd0a..e5afc85aa 100644 --- a/moto/batch_simple/models.py +++ b/moto/batch_simple/models.py @@ -12,7 +12,7 @@ class BatchSimpleBackend(BaseBackend): @property def backend(self): - return batch_backends[self.region_name] + return 
batch_backends[self.account_id][self.region_name] def __getattribute__(self, name): """ @@ -22,6 +22,7 @@ class BatchSimpleBackend(BaseBackend): """ if name in [ "backend", + "account_id", "region_name", "urls", "_url_module", diff --git a/moto/batch_simple/responses.py b/moto/batch_simple/responses.py index 789fa4453..f882b77c3 100644 --- a/moto/batch_simple/responses.py +++ b/moto/batch_simple/responses.py @@ -9,4 +9,4 @@ class BatchSimpleResponse(BatchResponse): :return: Batch Backend :rtype: moto.batch.models.BatchBackend """ - return batch_simple_backends[self.region] + return batch_simple_backends[self.current_account][self.region] diff --git a/moto/budgets/responses.py b/moto/budgets/responses.py index eb0dfbfa2..c60e92aaa 100644 --- a/moto/budgets/responses.py +++ b/moto/budgets/responses.py @@ -5,9 +5,12 @@ from .models import budgets_backends class BudgetsResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="budgets") + @property def backend(self): - return budgets_backends["global"] + return budgets_backends[self.current_account]["global"] def create_budget(self): account_id = self._get_param("AccountId") diff --git a/moto/ce/models.py b/moto/ce/models.py index 848e3a8b3..beee9c424 100644 --- a/moto/ce/models.py +++ b/moto/ce/models.py @@ -1,19 +1,21 @@ """CostExplorerBackend class with methods for supported APIs.""" from .exceptions import CostCategoryNotFound -from moto.core import ACCOUNT_ID, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from uuid import uuid4 class CostCategoryDefinition(BaseModel): - def __init__(self, name, rule_version, rules, default_value, split_charge_rules): + def __init__( + self, account_id, name, rule_version, rules, default_value, split_charge_rules + ): self.name = name self.rule_version = rule_version self.rules = rules self.default_value = default_value self.split_charge_rules = split_charge_rules - self.arn = 
f"arn:aws:ce::{ACCOUNT_ID}:costcategory/{str(uuid4())}" + self.arn = f"arn:aws:ce::{account_id}:costcategory/{str(uuid4())}" def update(self, rule_version, rules, default_value, split_charge_rules): self.rule_version = rule_version @@ -51,7 +53,12 @@ class CostExplorerBackend(BaseBackend): The EffectiveOn and ResourceTags-parameters are not yet implemented """ ccd = CostCategoryDefinition( - name, rule_version, rules, default_value, split_charge_rules + self.account_id, + name, + rule_version, + rules, + default_value, + split_charge_rules, ) self.cost_categories[ccd.arn] = ccd return ccd.arn, "" diff --git a/moto/ce/responses.py b/moto/ce/responses.py index 7f78c3ca3..5ba3b67cb 100644 --- a/moto/ce/responses.py +++ b/moto/ce/responses.py @@ -11,7 +11,7 @@ class CostExplorerResponse(BaseResponse): @property def ce_backend(self): """Return backend instance specific for this region.""" - return ce_backends["global"] + return ce_backends[self.current_account]["global"] def create_cost_category_definition(self): params = json.loads(self.body) diff --git a/moto/cloudformation/custom_model.py b/moto/cloudformation/custom_model.py index fe4ae5bc1..2b888b088 100644 --- a/moto/cloudformation/custom_model.py +++ b/moto/cloudformation/custom_model.py @@ -33,7 +33,7 @@ class CustomModel(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): logical_id = kwargs["LogicalId"] stack_id = kwargs["StackId"] @@ -41,7 +41,7 @@ class CustomModel(CloudFormationModel): properties = cloudformation_json["Properties"] service_token = properties["ServiceToken"] - backend = lambda_backends[region_name] + backend = lambda_backends[account_id][region_name] fn = backend.get_function(service_token) request_id = str(uuid4()) @@ -52,7 +52,7 @@ class CustomModel(CloudFormationModel): from moto.cloudformation import cloudformation_backends - 
stack = cloudformation_backends[region_name].get_stack(stack_id) + stack = cloudformation_backends[account_id][region_name].get_stack(stack_id) stack.add_custom_resource(custom_resource) # A request will be send to this URL to indicate success/failure diff --git a/moto/cloudformation/models.py b/moto/cloudformation/models.py index 387d5747c..53f769891 100644 --- a/moto/cloudformation/models.py +++ b/moto/cloudformation/models.py @@ -7,7 +7,7 @@ from collections import OrderedDict from yaml.parser import ParserError # pylint:disable=c-extension-no-member from yaml.scanner import ScannerError # pylint:disable=c-extension-no-member -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import ( iso_8601_datetime_with_milliseconds, iso_8601_datetime_without_milliseconds, @@ -31,6 +31,7 @@ class FakeStackSet(BaseModel): def __init__( self, stackset_id, + account_id, name, template, region="us-east-1", @@ -42,13 +43,14 @@ class FakeStackSet(BaseModel): execution_role="AWSCloudFormationStackSetExecutionRole", ): self.id = stackset_id - self.arn = generate_stackset_arn(stackset_id, region) + self.arn = generate_stackset_arn(stackset_id, region, account_id) self.name = name self.template = template self.description = description self.parameters = parameters self.tags = tags self.admin_role = admin_role + self.admin_role_arn = f"arn:aws:iam::{account_id}:role/{self.admin_role}" self.execution_role = execution_role self.status = status self.instances = FakeStackInstances(parameters, self.id, self.name) @@ -218,6 +220,7 @@ class FakeStack(BaseModel): name, template, parameters, + account_id, region_name, notification_arns=None, tags=None, @@ -226,6 +229,7 @@ class FakeStack(BaseModel): ): self.stack_id = stack_id self.name = name + self.account_id = account_id self.template = template if template != {}: self._parse_template() @@ -267,9 +271,10 @@ class FakeStack(BaseModel): self.name, self.parameters, 
self.tags, - self.region_name, - self.template_dict, - self.cross_stack_resources, + account_id=self.account_id, + region_name=self.region_name, + template=self.template_dict, + cross_stack_resources=self.cross_stack_resources, ) resource_map.load() return resource_map @@ -296,7 +301,7 @@ class FakeStack(BaseModel): resource_properties=resource_properties, ) - event.sendToSns(self.region_name, self.notification_arns) + event.sendToSns(self.account_id, self.region_name, self.notification_arns) self.events.append(event) def _add_resource_event( @@ -486,7 +491,7 @@ class FakeEvent(BaseModel): self.event_id = uuid.uuid4() self.client_request_token = client_request_token - def sendToSns(self, region, sns_topic_arns): + def sendToSns(self, account_id, region, sns_topic_arns): message = """StackId='{stack_id}' Timestamp='{timestamp}' EventId='{event_id}' @@ -502,7 +507,7 @@ ClientRequestToken='{client_request_token}'""".format( timestamp=iso_8601_datetime_with_milliseconds(self.timestamp), event_id=self.event_id, logical_resource_id=self.logical_resource_id, - account_id=get_account_id(), + account_id=account_id, resource_properties=self.resource_properties, resource_status=self.resource_status, resource_status_reason=self.resource_status_reason, @@ -512,7 +517,7 @@ ClientRequestToken='{client_request_token}'""".format( ) for sns_topic_arn in sns_topic_arns: - sns_backends[region].publish( + sns_backends[account_id][region].publish( message, subject="AWS CloudFormation Notification", arn=sns_topic_arn ) @@ -584,6 +589,7 @@ class CloudFormationBackend(BaseBackend): stackset_id = generate_stackset_id(name) new_stackset = FakeStackSet( stackset_id=stackset_id, + account_id=self.account_id, name=name, template=template, parameters=parameters, @@ -671,12 +677,13 @@ class CloudFormationBackend(BaseBackend): tags=None, role_arn=None, ): - stack_id = generate_stack_id(name, self.region_name) + stack_id = generate_stack_id(name, self.region_name, self.account_id) new_stack = 
FakeStack( stack_id=stack_id, name=name, template=template, parameters=parameters, + account_id=self.account_id, region_name=self.region_name, notification_arns=notification_arns, tags=tags, @@ -712,12 +719,13 @@ class CloudFormationBackend(BaseBackend): else: raise ValidationError(stack_name) else: - stack_id = generate_stack_id(stack_name, self.region_name) + stack_id = generate_stack_id(stack_name, self.region_name, self.account_id) stack = FakeStack( stack_id=stack_id, name=stack_name, template={}, parameters=parameters, + account_id=self.account_id, region_name=self.region_name, notification_arns=notification_arns, tags=tags, @@ -729,7 +737,9 @@ class CloudFormationBackend(BaseBackend): "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated" ) - change_set_id = generate_changeset_id(change_set_name, self.region_name) + change_set_id = generate_changeset_id( + change_set_name, self.region_name, self.account_id + ) new_change_set = FakeChangeSet( change_set_type=change_set_type, diff --git a/moto/cloudformation/parsing.py b/moto/cloudformation/parsing.py index 8302c4df1..b38aa7e1b 100644 --- a/moto/cloudformation/parsing.py +++ b/moto/cloudformation/parsing.py @@ -46,7 +46,7 @@ from moto.ssm import models # noqa # pylint: disable=all # End ugly list of imports -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from moto.s3.models import s3_backends from moto.s3.utils import bucket_and_name_from_url from moto.ssm import ssm_backends @@ -317,7 +317,9 @@ def parse_resource_and_generate_name(logical_id, resource_json, resources_map): return resource_class, resource_json, resource_name -def parse_and_create_resource(logical_id, resource_json, resources_map, region_name): +def parse_and_create_resource( + logical_id, resource_json, resources_map, account_id, region_name +): condition = resource_json.get("Condition") if condition and not resources_map.lazy_condition_map[condition]: # If this has a False 
condition, don't create the resource @@ -336,14 +338,16 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n "ResourceType": resource_type, } resource = resource_class.create_from_cloudformation_json( - resource_physical_name, resource_json, region_name, **kwargs + resource_physical_name, resource_json, account_id, region_name, **kwargs ) resource.type = resource_type resource.logical_resource_id = logical_id return resource -def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): +def parse_and_update_resource( + logical_id, resource_json, resources_map, account_id, region_name +): resource_tuple = parse_resource_and_generate_name( logical_id, resource_json, resources_map ) @@ -358,6 +362,7 @@ def parse_and_update_resource(logical_id, resource_json, resources_map, region_n original_resource=original_resource, new_resource_name=new_resource_name, cloudformation_json=resource_json, + account_id=account_id, region_name=region_name, ) new_resource.type = resource_json["Type"] @@ -367,14 +372,14 @@ def parse_and_update_resource(logical_id, resource_json, resources_map, region_n return None -def parse_and_delete_resource(resource_name, resource_json, region_name): +def parse_and_delete_resource(resource_name, resource_json, account_id, region_name): resource_type = resource_json["Type"] resource_class = resource_class_from_type(resource_type) if not hasattr( resource_class.delete_from_cloudformation_json, "__isabstractmethod__" ): resource_class.delete_from_cloudformation_json( - resource_name, resource_json, region_name + resource_name, resource_json, account_id, region_name ) @@ -439,11 +444,13 @@ class ResourceMap(collections_abc.Mapping): parameters, tags, region_name, + account_id, template, cross_stack_resources, ): self._template = template self._resource_json_map = template["Resources"] if template != {} else {} + self._account_id = account_id self._region_name = region_name self.input_parameters = 
parameters self.tags = copy.deepcopy(tags) @@ -453,7 +460,7 @@ class ResourceMap(collections_abc.Mapping): # Create the default resources self._parsed_resources = { - "AWS::AccountId": get_account_id(), + "AWS::AccountId": account_id, "AWS::Region": self._region_name, "AWS::StackId": stack_id, "AWS::StackName": stack_name, @@ -473,7 +480,11 @@ class ResourceMap(collections_abc.Mapping): if not resource_json: raise KeyError(resource_logical_id) new_resource = parse_and_create_resource( - resource_logical_id, resource_json, self, self._region_name + resource_logical_id, + resource_json, + self, + account_id=self._account_id, + region_name=self._region_name, ) if new_resource is not None: self._parsed_resources[resource_logical_id] = new_resource @@ -528,14 +539,18 @@ class ResourceMap(collections_abc.Mapping): if name == "AWS::Include": location = params["Location"] bucket_name, name = bucket_and_name_from_url(location) - key = s3_backends["global"].get_object(bucket_name, name) + key = s3_backends[self._account_id]["global"].get_object( + bucket_name, name + ) self._parsed_resources.update(json.loads(key.value)) def parse_ssm_parameter(self, value, value_type): # The Value in SSM parameters is the SSM parameter path # we need to use ssm_backend to retrieve the # actual value from parameter store - parameter = ssm_backends[self._region_name].get_parameter(value) + parameter = ssm_backends[self._account_id][self._region_name].get_parameter( + value + ) actual_value = parameter.value if value_type.find("List") > 0: return actual_value.split(",") @@ -646,9 +661,9 @@ class ResourceMap(collections_abc.Mapping): instance = self[resource] if isinstance(instance, TaggedEC2Resource): self.tags["aws:cloudformation:logical-id"] = resource - ec2_models.ec2_backends[self._region_name].create_tags( - [instance.physical_resource_id], self.tags - ) + ec2_models.ec2_backends[self._account_id][ + self._region_name + ].create_tags([instance.physical_resource_id], self.tags) if instance 
and not instance.is_created(): all_resources_ready = False return all_resources_ready @@ -716,7 +731,9 @@ class ResourceMap(collections_abc.Mapping): ].physical_resource_id else: resource_name = None - parse_and_delete_resource(resource_name, resource_json, self._region_name) + parse_and_delete_resource( + resource_name, resource_json, self._account_id, self._region_name + ) self._parsed_resources.pop(logical_name) self._template = template @@ -740,7 +757,11 @@ class ResourceMap(collections_abc.Mapping): resource_json = self._resource_json_map[logical_name] try: changed_resource = parse_and_update_resource( - logical_name, resource_json, self, self._region_name + logical_name, + resource_json, + self, + account_id=self._account_id, + region_name=self._region_name, ) except Exception as e: # skip over dependency violations, and try again in a @@ -765,7 +786,7 @@ class ResourceMap(collections_abc.Mapping): and parsed_resource is not None ): if parsed_resource and hasattr(parsed_resource, "delete"): - parsed_resource.delete(self._region_name) + parsed_resource.delete(self._account_id, self._region_name) else: if hasattr(parsed_resource, "physical_resource_id"): resource_name = parsed_resource.physical_resource_id @@ -777,7 +798,10 @@ class ResourceMap(collections_abc.Mapping): ] parse_and_delete_resource( - resource_name, resource_json, self._region_name + resource_name, + resource_json, + self._account_id, + self._region_name, ) self._parsed_resources.pop(parsed_resource.logical_resource_id) diff --git a/moto/cloudformation/responses.py b/moto/cloudformation/responses.py index 36da31b72..1bb19baf8 100644 --- a/moto/cloudformation/responses.py +++ b/moto/cloudformation/responses.py @@ -8,7 +8,6 @@ from moto.core.responses import BaseResponse from moto.core.utils import amzn_request_id from moto.s3.models import s3_backends from moto.s3.exceptions import S3ClientError -from moto.core import get_account_id from .models import cloudformation_backends from .exceptions 
import ValidationError, MissingParameterError from .utils import yaml_tag_constructor @@ -39,9 +38,12 @@ def get_template_summary_response_from_template(template_body): class CloudFormationResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="cloudformation") + @property def cloudformation_backend(self): - return cloudformation_backends[self.region] + return cloudformation_backends[self.current_account][self.region] @classmethod def cfnresponse(cls, *args, **kwargs): # pylint: disable=unused-argument @@ -68,7 +70,9 @@ class CloudFormationResponse(BaseResponse): bucket_name = template_url_parts.netloc.split(".")[0] key_name = template_url_parts.path.lstrip("/") - key = s3_backends["global"].get_object(bucket_name, key_name) + key = s3_backends[self.current_account]["global"].get_object( + bucket_name, key_name + ) return key.value.decode("utf-8") def _get_params_from_list(self, parameters_list): @@ -515,9 +519,7 @@ class CloudFormationResponse(BaseResponse): stackset = self.cloudformation_backend.get_stack_set(stackset_name) if not stackset.admin_role: - stackset.admin_role = "arn:aws:iam::{AccountId}:role/AWSCloudFormationStackSetAdministrationRole".format( - AccountId=get_account_id() - ) + stackset.admin_role = f"arn:aws:iam::{self.current_account}:role/AWSCloudFormationStackSetAdministrationRole" if not stackset.execution_role: stackset.execution_role = "AWSCloudFormationStackSetExecutionRole" @@ -1169,14 +1171,11 @@ STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE = """ """ -DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = ( - """ +DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """ {{ stackset.execution_role }} - arn:aws:iam::""" - + get_account_id() - + """:role/{{ stackset.admin_role }} + {{ stackset.admin_role_arn }} {{ stackset.id }} {{ operation.CreationTimestamp }} {{ operation.OperationId }} @@ -1193,19 +1192,15 @@ DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = ( """ -) -LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = ( - """ 
+LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """ {% for instance in operation.Instances %} {% for account, region in instance.items() %} - Function not found: arn:aws:lambda:us-west-2:""" - + get_account_id() - + """:function:AWSCloudFormationStackSetAccountGate + Function not found: arn:aws:lambda:us-west-2:{{ account }}:function:AWSCloudFormationStackSetAccountGate SKIPPED {{ region }} @@ -1221,7 +1216,6 @@ LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = ( """ -) # https://docs.aws.amazon.com/AWSCloudFormation/latest/APIReference/API_GetTemplateSummary.html # TODO:implement fields: ResourceIdentifierSummaries, Capabilities, CapabilitiesReason diff --git a/moto/cloudformation/utils.py b/moto/cloudformation/utils.py index 024cd78b0..e31273063 100644 --- a/moto/cloudformation/utils.py +++ b/moto/cloudformation/utils.py @@ -4,21 +4,15 @@ import yaml import os import string -from moto.core import get_account_id - -def generate_stack_id(stack_name, region="us-east-1", account=get_account_id()): +def generate_stack_id(stack_name, region, account): random_id = uuid.uuid4() - return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format( - region, account, stack_name, random_id - ) + return f"arn:aws:cloudformation:{region}:{account}:stack/{stack_name}/{random_id}" -def generate_changeset_id(changeset_name, region_name): +def generate_changeset_id(changeset_name, region_name, account_id): random_id = uuid.uuid4() - return "arn:aws:cloudformation:{0}:{1}:changeSet/{2}/{3}".format( - region_name, get_account_id(), changeset_name, random_id - ) + return f"arn:aws:cloudformation:{region_name}:{account_id}:changeSet/{changeset_name}/{random_id}" def generate_stackset_id(stackset_name): @@ -26,10 +20,8 @@ def generate_stackset_id(stackset_name): return "{}:{}".format(stackset_name, random_id) -def generate_stackset_arn(stackset_id, region_name): - return "arn:aws:cloudformation:{}:{}:stackset/{}".format( - region_name, get_account_id(), stackset_id - ) +def 
generate_stackset_arn(stackset_id, region_name, account_id): + return f"arn:aws:cloudformation:{region_name}:{account_id}:stackset/{stackset_id}" def random_suffix(): diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py index 2215e9323..07a1f6726 100644 --- a/moto/cloudfront/models.py +++ b/moto/cloudfront/models.py @@ -2,7 +2,7 @@ import random import string from datetime import datetime -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, iso_8601_datetime_with_milliseconds from moto.moto_api import state_manager from moto.moto_api._internal.managed_state_model import ManagedState @@ -181,7 +181,7 @@ class Distribution(BaseModel, ManagedState): ) return resource_id - def __init__(self, config): + def __init__(self, account_id, config): # Configured ManagedState super().__init__( "cloudfront::distribution", transitions=[("InProgress", "Deployed")] @@ -189,7 +189,7 @@ class Distribution(BaseModel, ManagedState): # Configure internal properties self.distribution_id = Distribution.random_id() self.arn = ( - f"arn:aws:cloudfront:{get_account_id()}:distribution/{self.distribution_id}" + f"arn:aws:cloudfront:{account_id}:distribution/{self.distribution_id}" ) self.distribution_config = DistributionConfig(config) self.active_trusted_signers = ActiveTrustedSigners() @@ -247,7 +247,7 @@ class CloudFrontBackend(BaseBackend): we're not persisting/returning the correct attributes for your use-case. 
""" - dist = Distribution(distribution_config) + dist = Distribution(self.account_id, distribution_config) caller_reference = dist.distribution_config.caller_reference existing_dist = self._distribution_with_caller_reference(caller_reference) if existing_dist: diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py index f5df5ff01..803bd2ec8 100644 --- a/moto/cloudfront/responses.py +++ b/moto/cloudfront/responses.py @@ -9,12 +9,15 @@ XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" class CloudFrontResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="cloudfront") + def _get_xml_body(self): return xmltodict.parse(self.body, dict_constructor=dict) @property def backend(self): - return cloudfront_backends["global"] + return cloudfront_backends[self.current_account]["global"] def distributions(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/cloudtrail/exceptions.py b/moto/cloudtrail/exceptions.py index 6fe2b7654..5c26a47e3 100644 --- a/moto/cloudtrail/exceptions.py +++ b/moto/cloudtrail/exceptions.py @@ -1,5 +1,4 @@ """Exceptions raised by the cloudtrail service.""" -from moto.core import get_account_id from moto.core.exceptions import JsonRESTError @@ -27,10 +26,10 @@ class InsufficientSnsTopicPolicyException(JsonRESTError): class TrailNotFoundException(JsonRESTError): code = 400 - def __init__(self, name): + def __init__(self, account_id, name): super().__init__( "TrailNotFoundException", - f"Unknown trail: {name} for the user: {get_account_id()}", + f"Unknown trail: {name} for the user: {account_id}", ) diff --git a/moto/cloudtrail/models.py b/moto/cloudtrail/models.py index 86db63ae1..aa604f6e5 100644 --- a/moto/cloudtrail/models.py +++ b/moto/cloudtrail/models.py @@ -2,7 +2,7 @@ import re import time from datetime import datetime -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import 
iso_8601_datetime_without_milliseconds, BackendDict from moto.utilities.tagging_service import TaggingService from .exceptions import ( @@ -74,6 +74,7 @@ class TrailStatus(object): class Trail(BaseModel): def __init__( self, + account_id, region_name, trail_name, bucket_name, @@ -87,6 +88,7 @@ class Trail(BaseModel): cw_role_arn, kms_key_id, ): + self.account_id = account_id self.region_name = region_name self.trail_name = trail_name self.bucket_name = bucket_name @@ -109,12 +111,12 @@ class Trail(BaseModel): @property def arn(self): - return f"arn:aws:cloudtrail:{self.region_name}:{get_account_id()}:trail/{self.trail_name}" + return f"arn:aws:cloudtrail:{self.region_name}:{self.account_id}:trail/{self.trail_name}" @property def topic_arn(self): if self.sns_topic_name: - return f"arn:aws:sns:{self.region_name}:{get_account_id()}:{self.sns_topic_name}" + return f"arn:aws:sns:{self.region_name}:{self.account_id}:{self.sns_topic_name}" return None def check_name(self): @@ -133,7 +135,7 @@ class Trail(BaseModel): from moto.s3.models import s3_backends try: - s3_backends["global"].get_bucket(self.bucket_name) + s3_backends[self.account_id]["global"].get_bucket(self.bucket_name) except Exception: raise S3BucketDoesNotExistException( f"S3 bucket {self.bucket_name} does not exist!" 
@@ -143,7 +145,7 @@ class Trail(BaseModel): if self.sns_topic_name: from moto.sns import sns_backends - sns_backend = sns_backends[self.region_name] + sns_backend = sns_backends[self.account_id][self.region_name] try: sns_backend.get_topic(self.topic_arn) except Exception: @@ -263,6 +265,7 @@ class CloudTrailBackend(BaseBackend): tags_list, ): trail = Trail( + self.account_id, self.region_name, name, bucket_name, @@ -288,7 +291,7 @@ class CloudTrailBackend(BaseBackend): for trail in self.trails.values(): if trail.arn == name_or_arn: return trail - raise TrailNotFoundException(name_or_arn) + raise TrailNotFoundException(account_id=self.account_id, name=name_or_arn) def get_trail_status(self, name): if len(name) < 3: @@ -304,9 +307,9 @@ class CloudTrailBackend(BaseBackend): if not trail_name: # This particular method returns the ARN as part of the error message arn = ( - f"arn:aws:cloudtrail:{self.region_name}:{get_account_id()}:trail/{name}" + f"arn:aws:cloudtrail:{self.region_name}:{self.account_id}:trail/{name}" ) - raise TrailNotFoundException(name=arn) + raise TrailNotFoundException(account_id=self.account_id, name=arn) trail = self.trails[trail_name] return trail.status diff --git a/moto/cloudtrail/responses.py b/moto/cloudtrail/responses.py index 9f6f0bb87..a383238dd 100644 --- a/moto/cloudtrail/responses.py +++ b/moto/cloudtrail/responses.py @@ -9,10 +9,13 @@ from .exceptions import InvalidParameterCombinationException class CloudTrailResponse(BaseResponse): """Handler for CloudTrail requests and responses.""" + def __init__(self): + super().__init__(service_name="cloudtrail") + @property def cloudtrail_backend(self): """Return backend instance specific for this region.""" - return cloudtrail_backends[self.region] + return cloudtrail_backends[self.current_account][self.region] def create_trail(self): name = self._get_param("Name") diff --git a/moto/cloudwatch/models.py b/moto/cloudwatch/models.py index 8a74db335..f8a39ef22 100644 --- 
a/moto/cloudwatch/models.py +++ b/moto/cloudwatch/models.py @@ -20,7 +20,6 @@ from .exceptions import ( from .utils import make_arn_for_dashboard, make_arn_for_alarm from dateutil import parser -from moto.core import get_account_id from ..utilities.tagging_service import TaggingService _EMPTY_LIST = tuple() @@ -103,6 +102,7 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False): class FakeAlarm(BaseModel): def __init__( self, + account_id, region_name, name, namespace, @@ -129,7 +129,7 @@ class FakeAlarm(BaseModel): ): self.region_name = region_name self.name = name - self.alarm_arn = make_arn_for_alarm(region_name, get_account_id(), name) + self.alarm_arn = make_arn_for_alarm(region_name, account_id, name) self.namespace = namespace self.metric_name = metric_name self.metric_data_queries = metric_data_queries @@ -238,9 +238,9 @@ class MetricDatum(BaseModel): class Dashboard(BaseModel): - def __init__(self, name, body): + def __init__(self, account_id, name, body): # Guaranteed to be unique for now as the name is also the key of a dictionary where they are stored - self.arn = make_arn_for_dashboard(get_account_id(), name) + self.arn = make_arn_for_dashboard(account_id, name) self.name = name self.body = body self.last_modified = datetime.now() @@ -327,7 +327,7 @@ class CloudWatchBackend(BaseBackend): providers = CloudWatchMetricProvider.__subclasses__() md = [] for provider in providers: - md.extend(provider.get_cloudwatch_metrics()) + md.extend(provider.get_cloudwatch_metrics(self.account_id)) return md def put_metric_alarm( @@ -370,6 +370,7 @@ class CloudWatchBackend(BaseBackend): ) alarm = FakeAlarm( + account_id=self.account_id, region_name=self.region_name, name=name, namespace=namespace, @@ -590,7 +591,7 @@ class CloudWatchBackend(BaseBackend): return self.metric_data + self.aws_metric_data def put_dashboard(self, name, body): - self.dashboards[name] = Dashboard(name, body) + self.dashboards[name] = Dashboard(self.account_id, name, body) def 
list_dashboards(self, prefix=""): for key, value in self.dashboards.items(): diff --git a/moto/cloudwatch/responses.py b/moto/cloudwatch/responses.py index 7e9094e97..3f4061a82 100644 --- a/moto/cloudwatch/responses.py +++ b/moto/cloudwatch/responses.py @@ -9,9 +9,12 @@ from .exceptions import InvalidParameterCombination class CloudWatchResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="cloudwatch") + @property def cloudwatch_backend(self): - return cloudwatch_backends[self.region] + return cloudwatch_backends[self.current_account][self.region] def _error(self, code, message, status=400): template = self.response_template(ERROR_RESPONSE_TEMPLATE) diff --git a/moto/codebuild/models.py b/moto/codebuild/models.py index ac96dd47e..306483a98 100644 --- a/moto/codebuild/models.py +++ b/moto/codebuild/models.py @@ -1,6 +1,5 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds, BackendDict -from moto.core import get_account_id from collections import defaultdict from random import randint from dateutil import parser @@ -9,14 +8,23 @@ import uuid class CodeBuildProjectMetadata(BaseModel): - def __init__(self, project_name, source_version, artifacts, build_id, service_role): + def __init__( + self, + account_id, + region_name, + project_name, + source_version, + artifacts, + build_id, + service_role, + ): current_date = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()) self.build_metadata = dict() self.build_metadata["id"] = build_id - self.build_metadata["arn"] = "arn:aws:codebuild:eu-west-2:{0}:build/{1}".format( - get_account_id(), build_id - ) + self.build_metadata[ + "arn" + ] = f"arn:aws:codebuild:{region_name}:{account_id}:build/{build_id}" self.build_metadata["buildNumber"] = randint(1, 100) self.build_metadata["startTime"] = current_date @@ -66,9 +74,7 @@ class CodeBuildProjectMetadata(BaseModel): self.build_metadata["logs"] = { "deepLink": 
"https://console.aws.amazon.com/cloudwatch/home?region=eu-west-2#logEvent:group=null;stream=null", - "cloudWatchLogsArn": "arn:aws:logs:eu-west-2:{0}:log-group:null:log-stream:null".format( - get_account_id() - ), + "cloudWatchLogsArn": f"arn:aws:logs:{region_name}:{account_id}:log-group:null:log-stream:null", "cloudWatchLogs": {"status": "ENABLED"}, "s3Logs": {"status": "DISABLED", "encryptionDisabled": False}, } @@ -79,12 +85,13 @@ class CodeBuildProjectMetadata(BaseModel): self.build_metadata["initiator"] = "rootme" self.build_metadata[ "encryptionKey" - ] = "arn:aws:kms:eu-west-2:{0}:alias/aws/s3".format(get_account_id()) + ] = f"arn:aws:kms:{region_name}:{account_id}:alias/aws/s3" class CodeBuild(BaseModel): def __init__( self, + account_id, region, project_name, project_source, @@ -97,16 +104,14 @@ class CodeBuild(BaseModel): self.project_metadata["name"] = project_name self.project_metadata["arn"] = "arn:aws:codebuild:{0}:{1}:project/{2}".format( - region, get_account_id(), self.project_metadata["name"] + region, account_id, self.project_metadata["name"] ) self.project_metadata[ "encryptionKey" - ] = "arn:aws:kms:{0}:{1}:alias/aws/s3".format(region, get_account_id()) + ] = f"arn:aws:kms:{region}:{account_id}:alias/aws/s3" self.project_metadata[ "serviceRole" - ] = "arn:aws:iam::{0}:role/service-role/{1}".format( - get_account_id(), serviceRole - ) + ] = f"arn:aws:iam::{account_id}:role/service-role/{serviceRole}" self.project_metadata["lastModifiedDate"] = current_date self.project_metadata["created"] = current_date self.project_metadata["badge"] = dict() @@ -138,6 +143,7 @@ class CodeBuildBackend(BaseBackend): self.service_role = service_role self.codebuild_projects[project_name] = CodeBuild( + self.account_id, self.region_name, project_name, project_source, @@ -166,7 +172,13 @@ class CodeBuildBackend(BaseBackend): # construct a new build self.build_metadata[project_name] = CodeBuildProjectMetadata( - project_name, source_version, artifact_override, 
build_id, self.service_role + self.account_id, + self.region_name, + project_name, + source_version, + artifact_override, + build_id, + self.service_role, ) self.build_history[project_name].append(build_id) diff --git a/moto/codebuild/responses.py b/moto/codebuild/responses.py index c3ce1b6ff..ad60c1fb5 100644 --- a/moto/codebuild/responses.py +++ b/moto/codebuild/responses.py @@ -5,7 +5,6 @@ from .exceptions import ( ResourceAlreadyExistsException, ResourceNotFoundException, ) -from moto.core import get_account_id import json import re @@ -29,11 +28,8 @@ def _validate_required_params_source(source): raise InvalidInputException("Project source location is required") -def _validate_required_params_service_role(service_role): - if ( - "arn:aws:iam::{0}:role/service-role/".format(get_account_id()) - not in service_role - ): +def _validate_required_params_service_role(account_id, service_role): + if f"arn:aws:iam::{account_id}:role/service-role/" not in service_role: raise InvalidInputException( "Invalid service role: Service role account ID does not match caller's account" ) @@ -99,7 +95,7 @@ def _validate_required_params_id(build_id, build_ids): class CodeBuildResponse(BaseResponse): @property def codebuild_backend(self): - return codebuild_backends[self.region] + return codebuild_backends[self.current_account][self.region] def list_builds_for_project(self): _validate_required_params_project_name(self._get_param("projectName")) @@ -110,7 +106,7 @@ class CodeBuildResponse(BaseResponse): ): raise ResourceNotFoundException( "The provided project arn:aws:codebuild:{0}:{1}:project/{2} does not exist".format( - self.region, get_account_id(), self._get_param("projectName") + self.region, self.current_account, self._get_param("projectName") ) ) @@ -122,7 +118,9 @@ class CodeBuildResponse(BaseResponse): def create_project(self): _validate_required_params_source(self._get_param("source")) - _validate_required_params_service_role(self._get_param("serviceRole")) + 
_validate_required_params_service_role( + self.current_account, self._get_param("serviceRole") + ) _validate_required_params_artifacts(self._get_param("artifacts")) _validate_required_params_environment(self._get_param("environment")) _validate_required_params_project_name(self._get_param("name")) @@ -130,7 +128,7 @@ class CodeBuildResponse(BaseResponse): if self._get_param("name") in self.codebuild_backend.codebuild_projects.keys(): raise ResourceAlreadyExistsException( "Project already exists: arn:aws:codebuild:{0}:{1}:project/{2}".format( - self.region, get_account_id(), self._get_param("name") + self.region, self.current_account, self._get_param("name") ) ) @@ -157,7 +155,7 @@ class CodeBuildResponse(BaseResponse): ): raise ResourceNotFoundException( "Project cannot be found: arn:aws:codebuild:{0}:{1}:project/{2}".format( - self.region, get_account_id(), self._get_param("projectName") + self.region, self.current_account, self._get_param("projectName") ) ) diff --git a/moto/codecommit/models.py b/moto/codecommit/models.py index cdb535733..4a662666d 100644 --- a/moto/codecommit/models.py +++ b/moto/codecommit/models.py @@ -1,13 +1,12 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds, BackendDict from datetime import datetime -from moto.core import get_account_id from .exceptions import RepositoryDoesNotExistException, RepositoryNameExistsException import uuid class CodeCommit(BaseModel): - def __init__(self, region, repository_description, repository_name): + def __init__(self, account_id, region, repository_description, repository_name): current_date = iso_8601_datetime_with_milliseconds(datetime.utcnow()) self.repository_metadata = dict() self.repository_metadata["repositoryName"] = repository_name @@ -25,10 +24,10 @@ class CodeCommit(BaseModel): self.repository_metadata["lastModifiedDate"] = current_date self.repository_metadata["repositoryDescription"] = repository_description 
self.repository_metadata["repositoryId"] = str(uuid.uuid4()) - self.repository_metadata["Arn"] = "arn:aws:codecommit:{0}:{1}:{2}".format( - region, get_account_id(), repository_name - ) - self.repository_metadata["accountId"] = get_account_id() + self.repository_metadata[ + "Arn" + ] = f"arn:aws:codecommit:{region}:{account_id}:{repository_name}" + self.repository_metadata["accountId"] = account_id class CodeCommitBackend(BaseBackend): @@ -49,7 +48,7 @@ class CodeCommitBackend(BaseBackend): raise RepositoryNameExistsException(repository_name) self.repositories[repository_name] = CodeCommit( - self.region_name, repository_description, repository_name + self.account_id, self.region_name, repository_description, repository_name ) return self.repositories[repository_name].repository_metadata diff --git a/moto/codecommit/responses.py b/moto/codecommit/responses.py index 5ccd77127..9537e8e45 100644 --- a/moto/codecommit/responses.py +++ b/moto/codecommit/responses.py @@ -17,9 +17,12 @@ def _is_repository_name_valid(repository_name): class CodeCommitResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="codecommit") + @property def codecommit_backend(self): - return codecommit_backends[self.region] + return codecommit_backends[self.current_account][self.region] def create_repository(self): if not _is_repository_name_valid(self._get_param("repositoryName")): diff --git a/moto/codepipeline/models.py b/moto/codepipeline/models.py index 72574e4b8..b1751aaca 100644 --- a/moto/codepipeline/models.py +++ b/moto/codepipeline/models.py @@ -14,20 +14,18 @@ from moto.codepipeline.exceptions import ( InvalidTagsException, TooManyTagsException, ) -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel class CodePipeline(BaseModel): - def __init__(self, region, pipeline): + def __init__(self, account_id, region, pipeline): # the version number for a new pipeline is always 1 pipeline["version"] = 1 self.pipeline 
= self.add_default_values(pipeline) self.tags = {} - self._arn = "arn:aws:codepipeline:{0}:{1}:{2}".format( - region, get_account_id(), pipeline["name"] - ) + self._arn = f"arn:aws:codepipeline:{region}:{account_id}:{pipeline['name']}" self._created = datetime.utcnow() self._updated = datetime.utcnow() @@ -80,14 +78,13 @@ class CodePipelineBackend(BaseBackend): @property def iam_backend(self): - return iam_backends["global"] + return iam_backends[self.account_id]["global"] def create_pipeline(self, pipeline, tags): - if pipeline["name"] in self.pipelines: + name = pipeline["name"] + if name in self.pipelines: raise InvalidStructureException( - "A pipeline with the name '{0}' already exists in account '{1}'".format( - pipeline["name"], get_account_id() - ) + f"A pipeline with the name '{name}' already exists in account '{self.account_id}'" ) try: @@ -112,7 +109,9 @@ class CodePipelineBackend(BaseBackend): "Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline" ) - self.pipelines[pipeline["name"]] = CodePipeline(self.region_name, pipeline) + self.pipelines[pipeline["name"]] = CodePipeline( + self.account_id, self.region_name, pipeline + ) if tags is not None: self.pipelines[pipeline["name"]].validate_tags(tags) @@ -129,9 +128,7 @@ class CodePipelineBackend(BaseBackend): if not codepipeline: raise PipelineNotFoundException( - "Account '{0}' does not have a pipeline with name '{1}'".format( - get_account_id(), name - ) + f"Account '{self.account_id}' does not have a pipeline with name '{name}'" ) return codepipeline.pipeline, codepipeline.metadata @@ -141,9 +138,7 @@ class CodePipelineBackend(BaseBackend): if not codepipeline: raise ResourceNotFoundException( - "The account with id '{0}' does not include a pipeline with the name '{1}'".format( - get_account_id(), pipeline["name"] - ) + f"The account with id '{self.account_id}' does not include a pipeline with the name '{pipeline['name']}'" ) # version number is auto incremented @@ -177,9 
+172,7 @@ class CodePipelineBackend(BaseBackend): if not pipeline: raise ResourceNotFoundException( - "The account with id '{0}' does not include a pipeline with the name '{1}'".format( - get_account_id(), name - ) + f"The account with id '{self.account_id}' does not include a pipeline with the name '{name}'" ) tags = [{"key": key, "value": value} for key, value in pipeline.tags.items()] @@ -192,9 +185,7 @@ class CodePipelineBackend(BaseBackend): if not pipeline: raise ResourceNotFoundException( - "The account with id '{0}' does not include a pipeline with the name '{1}'".format( - get_account_id(), name - ) + f"The account with id '{self.account_id}' does not include a pipeline with the name '{name}'" ) pipeline.validate_tags(tags) @@ -208,9 +199,7 @@ class CodePipelineBackend(BaseBackend): if not pipeline: raise ResourceNotFoundException( - "The account with id '{0}' does not include a pipeline with the name '{1}'".format( - get_account_id(), name - ) + f"The account with id '{self.account_id}' does not include a pipeline with the name '{name}'" ) for key in tag_keys: diff --git a/moto/codepipeline/responses.py b/moto/codepipeline/responses.py index d18c67602..aaf678805 100644 --- a/moto/codepipeline/responses.py +++ b/moto/codepipeline/responses.py @@ -5,9 +5,12 @@ from .models import codepipeline_backends class CodePipelineResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="codepipeline") + @property def codepipeline_backend(self): - return codepipeline_backends[self.region] + return codepipeline_backends[self.current_account][self.region] def create_pipeline(self): pipeline, tags = self.codepipeline_backend.create_pipeline( diff --git a/moto/cognitoidentity/responses.py b/moto/cognitoidentity/responses.py index 6ac72ee8f..97f4d1d5e 100644 --- a/moto/cognitoidentity/responses.py +++ b/moto/cognitoidentity/responses.py @@ -4,9 +4,12 @@ from .utils import get_random_identity_id class CognitoIdentityResponse(BaseResponse): + def 
__init__(self): + super().__init__(service_name="cognito-identity") + @property def backend(self): - return cognitoidentity_backends[self.region] + return cognitoidentity_backends[self.current_account][self.region] def create_identity_pool(self): identity_pool_name = self._get_param("IdentityPoolName") @@ -64,9 +67,7 @@ class CognitoIdentityResponse(BaseResponse): return self.backend.get_credentials_for_identity(self._get_param("IdentityId")) def get_open_id_token_for_developer_identity(self): - return cognitoidentity_backends[ - self.region - ].get_open_id_token_for_developer_identity( + return self.backend.get_open_id_token_for_developer_identity( self._get_param("IdentityId") or get_random_identity_id(self.region) ) diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index b021ba0ba..a134992e0 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -9,7 +9,6 @@ import random from jose import jws from collections import OrderedDict from moto.core import BaseBackend, BaseModel -from moto.core import get_account_id from moto.core.utils import BackendDict from .exceptions import ( GroupExistsException, @@ -371,16 +370,15 @@ class CognitoIdpUserPool(BaseModel): MAX_ID_LENGTH = 56 - def __init__(self, region, name, extended_config): + def __init__(self, account_id, region, name, extended_config): + self.account_id = account_id self.region = region user_pool_id = generate_id( get_cognito_idp_user_pool_id_strategy(), region, name, extended_config ) self.id = "{}_{}".format(self.region, user_pool_id)[: self.MAX_ID_LENGTH] - self.arn = "arn:aws:cognito-idp:{}:{}:userpool/{}".format( - self.region, get_account_id(), self.id - ) + self.arn = f"arn:aws:cognito-idp:{self.region}:{account_id}:userpool/{self.id}" self.name = name self.status = None @@ -445,7 +443,7 @@ class CognitoIdpUserPool(BaseModel): @property def backend(self): - return cognitoidp_backends[self.region] + return cognitoidp_backends[self.account_id][self.region] @property def 
domain(self): @@ -862,7 +860,9 @@ class CognitoIdpBackend(BaseBackend): # User pool def create_user_pool(self, name, extended_config): - user_pool = CognitoIdpUserPool(self.region_name, name, extended_config) + user_pool = CognitoIdpUserPool( + self.account_id, self.region_name, name, extended_config + ) self.user_pools[user_pool.id] = user_pool return user_pool @@ -1833,24 +1833,24 @@ class RegionAgnosticBackend: # This backend will cycle through all backends as a workaround def _find_backend_by_access_token(self, access_token): - account_specific_backends = cognitoidp_backends[get_account_id()] - for region, backend in account_specific_backends.items(): - if region == "global": - continue - for p in backend.user_pools.values(): - if access_token in p.access_tokens: - return backend - return account_specific_backends["us-east-1"] + for account_specific_backends in cognitoidp_backends.values(): + for region, backend in account_specific_backends.items(): + if region == "global": + continue + for p in backend.user_pools.values(): + if access_token in p.access_tokens: + return backend + return backend def _find_backend_for_clientid(self, client_id): - account_specific_backends = cognitoidp_backends[get_account_id()] - for region, backend in account_specific_backends.items(): - if region == "global": - continue - for p in backend.user_pools.values(): - if client_id in p.clients: - return backend - return account_specific_backends["us-east-1"] + for account_specific_backends in cognitoidp_backends.values(): + for region, backend in account_specific_backends.items(): + if region == "global": + continue + for p in backend.user_pools.values(): + if client_id in p.clients: + return backend + return backend def sign_up(self, client_id, username, password, attributes): backend = self._find_backend_for_clientid(client_id) @@ -1883,17 +1883,16 @@ cognitoidp_backends = BackendDict(CognitoIdpBackend, "cognito-idp") # Hack to help moto-server process requests on localhost, where 
the region isn't # specified in the host header. Some endpoints (change password, confirm forgot # password) have no authorization header from which to extract the region. -def find_region_by_value(key, value): - account_specific_backends = cognitoidp_backends[get_account_id()] - for region in account_specific_backends: - backend = cognitoidp_backends[region] - for user_pool in backend.user_pools.values(): - if key == "client_id" and value in user_pool.clients: - return region +def find_account_region_by_value(key, value): + for account_id, account_specific_backend in cognitoidp_backends.items(): + for region, backend in account_specific_backend.items(): + for user_pool in backend.user_pools.values(): + if key == "client_id" and value in user_pool.clients: + return account_id, region - if key == "access_token" and value in user_pool.access_tokens: - return region + if key == "access_token" and value in user_pool.access_tokens: + return account_id, region # If we can't find the `client_id` or `access_token`, we just pass # back a default backend region, which will raise the appropriate # error message (e.g. NotAuthorized or NotFound). 
- return list(account_specific_backends)[0] + return account_id, region diff --git a/moto/cognitoidp/responses.py b/moto/cognitoidp/responses.py index acf74253a..a748f3db4 100644 --- a/moto/cognitoidp/responses.py +++ b/moto/cognitoidp/responses.py @@ -5,7 +5,7 @@ import re from moto.core.responses import BaseResponse from .models import ( cognitoidp_backends, - find_region_by_value, + find_account_region_by_value, RegionAgnosticBackend, UserStatus, ) @@ -16,13 +16,16 @@ region_agnostic_backend = RegionAgnosticBackend() class CognitoIdpResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="cognito-idp") + @property def parameters(self): return json.loads(self.body) @property def backend(self): - return cognitoidp_backends[self.region] + return cognitoidp_backends[self.current_account][self.region] # User pool def create_user_pool(self): @@ -138,9 +141,7 @@ class CognitoIdpResponse(BaseResponse): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") - user_pool_clients, next_token = cognitoidp_backends[ - self.region - ].list_user_pool_clients( + user_pool_clients, next_token = self.backend.list_user_pool_clients( user_pool_id, max_results=max_results, next_token=next_token ) response = { @@ -189,9 +190,7 @@ class CognitoIdpResponse(BaseResponse): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") - identity_providers, next_token = cognitoidp_backends[ - self.region - ].list_identity_providers( + identity_providers, next_token = self.backend.list_identity_providers( user_pool_id, max_results=max_results, next_token=next_token ) response = { @@ -457,11 +456,10 @@ class CognitoIdpResponse(BaseResponse): def forgot_password(self): client_id = self._get_param("ClientId") username = self._get_param("Username") - region = find_region_by_value("client_id", client_id) - print(f"Region: {region}") 
- confirmation_code, response = cognitoidp_backends[region].forgot_password( - client_id, username - ) + account, region = find_account_region_by_value("client_id", client_id) + confirmation_code, response = cognitoidp_backends[account][ + region + ].forgot_password(client_id, username) self.response_headers[ "x-moto-forgot-password-confirmation-code" ] = confirmation_code @@ -476,8 +474,8 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") password = self._get_param("Password") confirmation_code = self._get_param("ConfirmationCode") - region = find_region_by_value("client_id", client_id) - cognitoidp_backends[region].confirm_forgot_password( + account, region = find_account_region_by_value("client_id", client_id) + cognitoidp_backends[account][region].confirm_forgot_password( client_id, username, password, confirmation_code ) return "" @@ -487,8 +485,8 @@ class CognitoIdpResponse(BaseResponse): access_token = self._get_param("AccessToken") previous_password = self._get_param("PreviousPassword") proposed_password = self._get_param("ProposedPassword") - region = find_region_by_value("access_token", access_token) - cognitoidp_backends[region].change_password( + account, region = find_account_region_by_value("access_token", access_token) + cognitoidp_backends[account][region].change_password( access_token, previous_password, proposed_password ) return "" diff --git a/moto/config/models.py b/moto/config/models.py index 50c3e3a93..b5732db89 100644 --- a/moto/config/models.py +++ b/moto/config/models.py @@ -50,7 +50,6 @@ from moto.config.exceptions import ( ) from moto.core import BaseBackend, BaseModel -from moto.core import get_account_id from moto.core.responses import AWSServiceSpec from moto.core.utils import BackendDict from moto.iam.config import role_config_query, policy_config_query @@ -354,13 +353,13 @@ class OrganizationAggregationSource(ConfigEmptyDictable): class ConfigAggregator(ConfigEmptyDictable): - def __init__(self, name, 
region, account_sources=None, org_source=None, tags=None): + def __init__( + self, name, account_id, region, account_sources=None, org_source=None, tags=None + ): super().__init__(capitalize_start=True, capitalize_arn=False) self.configuration_aggregator_name = name - self.configuration_aggregator_arn = "arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}".format( - region=region, id=get_account_id(), random=random_string() - ) + self.configuration_aggregator_arn = f"arn:aws:config:{region}:{account_id}:config-aggregator/config-aggregator-{random_string()}" self.account_aggregation_sources = account_sources self.organization_aggregation_source = org_source self.creation_time = datetime2int(datetime.utcnow()) @@ -389,7 +388,12 @@ class ConfigAggregator(ConfigEmptyDictable): class ConfigAggregationAuthorization(ConfigEmptyDictable): def __init__( - self, current_region, authorized_account_id, authorized_aws_region, tags=None + self, + account_id, + current_region, + authorized_account_id, + authorized_aws_region, + tags=None, ): super().__init__(capitalize_start=True, capitalize_arn=False) @@ -397,7 +401,7 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable): "arn:aws:config:{region}:{id}:aggregation-authorization/" "{auth_account}/{auth_region}".format( region=current_region, - id=get_account_id(), + id=account_id, auth_account=authorized_account_id, auth_region=authorized_aws_region, ) @@ -413,6 +417,7 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable): class OrganizationConformancePack(ConfigEmptyDictable): def __init__( self, + account_id, region, name, delivery_s3_bucket, @@ -430,11 +435,7 @@ class OrganizationConformancePack(ConfigEmptyDictable): self.delivery_s3_key_prefix = delivery_s3_key_prefix self.excluded_accounts = excluded_accounts or [] self.last_update_time = datetime2int(datetime.utcnow()) - self.organization_conformance_pack_arn = ( - "arn:aws:config:{0}:{1}:organization-conformance-pack/{2}".format( - 
region, get_account_id(), self._unique_pack_name - ) - ) + self.organization_conformance_pack_arn = f"arn:aws:config:{region}:{account_id}:organization-conformance-pack/{self._unique_pack_name}" self.organization_conformance_pack_name = name def update( @@ -602,7 +603,9 @@ class Source(ConfigEmptyDictable): OWNERS = {"AWS", "CUSTOM_LAMBDA"} - def __init__(self, region, owner, source_identifier, source_details=None): + def __init__( + self, account_id, region, owner, source_identifier, source_details=None + ): super().__init__(capitalize_start=True, capitalize_arn=False) if owner not in Source.OWNERS: raise ValidationException( @@ -644,7 +647,7 @@ class Source(ConfigEmptyDictable): from moto.awslambda import lambda_backends try: - lambda_backends[region].get_function(source_identifier) + lambda_backends[account_id][region].get_function(source_identifier) except Exception: raise InsufficientPermissionsException( f"The AWS Lambda function {source_identifier} cannot be " @@ -680,8 +683,9 @@ class ConfigRule(ConfigEmptyDictable): MAX_RULES = 150 RULE_STATES = {"ACTIVE", "DELETING", "DELETING_RESULTS", "EVALUATING"} - def __init__(self, region, config_rule, tags): + def __init__(self, account_id, region, config_rule, tags): super().__init__(capitalize_start=True, capitalize_arn=False) + self.account_id = account_id self.config_rule_name = config_rule.get("ConfigRuleName") if config_rule.get("ConfigRuleArn") or config_rule.get("ConfigRuleId"): raise InvalidParameterValueException( @@ -694,7 +698,9 @@ class ConfigRule(ConfigEmptyDictable): self.maximum_execution_frequency = None # keeps pylint happy self.modify_fields(region, config_rule, tags) self.config_rule_id = f"config-rule-{random_string():.6}" - self.config_rule_arn = f"arn:aws:config:{region}:{get_account_id()}:config-rule/{self.config_rule_id}" + self.config_rule_arn = ( + f"arn:aws:config:{region}:{account_id}:config-rule/{self.config_rule_id}" + ) def modify_fields(self, region, config_rule, tags): 
"""Initialize or update ConfigRule fields.""" @@ -721,7 +727,7 @@ class ConfigRule(ConfigEmptyDictable): self.scope = Scope(**scope_dict) source_dict = convert_to_class_args(config_rule["Source"]) - self.source = Source(region, **source_dict) + self.source = Source(self.account_id, region, **source_dict) self.input_parameters = config_rule.get("InputParameters") self.input_parameters_dict = {} @@ -969,7 +975,8 @@ class ConfigBackend(BaseBackend): ): aggregator = ConfigAggregator( config_aggregator["ConfigurationAggregatorName"], - self.region_name, + account_id=self.account_id, + region=self.region_name, account_sources=account_sources, org_source=org_source, tags=tags, @@ -1049,7 +1056,11 @@ class ConfigBackend(BaseBackend): agg_auth = self.aggregation_authorizations.get(key) if not agg_auth: agg_auth = ConfigAggregationAuthorization( - self.region_name, authorized_account, authorized_region, tags=tags + self.account_id, + self.region_name, + authorized_account, + authorized_region, + tags=tags, ) self.aggregation_authorizations[ "{}/{}".format(authorized_account, authorized_region) @@ -1345,17 +1356,22 @@ class ConfigBackend(BaseBackend): backend_query_region = ( backend_region # Always provide the backend this request arrived from. ) - if RESOURCE_MAP[resource_type].backends.get("global"): + if RESOURCE_MAP[resource_type].backends[self.account_id].get("global"): backend_region = "global" # For non-aggregated queries, the we only care about the # backend_region. 
Need to verify that moto has implemented # the region for the given backend: - if RESOURCE_MAP[resource_type].backends.get(backend_region): + if ( + RESOURCE_MAP[resource_type] + .backends[self.account_id] + .get(backend_region) + ): # Fetch the resources for the backend's region: identifiers, new_token = RESOURCE_MAP[ resource_type ].list_config_service_resources( + self.account_id, resource_ids, resource_name, limit, @@ -1420,6 +1436,7 @@ class ConfigBackend(BaseBackend): identifiers, new_token = RESOURCE_MAP[ resource_type ].list_config_service_resources( + self.account_id, resource_id, resource_name, limit, @@ -1431,7 +1448,7 @@ class ConfigBackend(BaseBackend): resource_identifiers = [] for identifier in identifiers: item = { - "SourceAccountId": get_account_id(), + "SourceAccountId": self.account_id, "SourceRegion": identifier["region"], "ResourceType": identifier["type"], "ResourceId": identifier["id"], @@ -1468,21 +1485,25 @@ class ConfigBackend(BaseBackend): backend_query_region = ( backend_region # Always provide the backend this request arrived from. 
) - if RESOURCE_MAP[resource_type].backends.get("global"): + if RESOURCE_MAP[resource_type].backends[self.account_id].get("global"): backend_region = "global" # If the backend region isn't implemented then we won't find the item: - if not RESOURCE_MAP[resource_type].backends.get(backend_region): + if ( + not RESOURCE_MAP[resource_type] + .backends[self.account_id] + .get(backend_region) + ): raise ResourceNotDiscoveredException(resource_type, resource_id) # Get the item: item = RESOURCE_MAP[resource_type].get_config_resource( - resource_id, backend_region=backend_query_region + self.account_id, resource_id, backend_region=backend_query_region ) if not item: raise ResourceNotDiscoveredException(resource_type, resource_id) - item["accountId"] = get_account_id() + item["accountId"] = self.account_id return {"configurationItems": [item]} @@ -1512,23 +1533,31 @@ class ConfigBackend(BaseBackend): backend_query_region = ( backend_region # Always provide the backend this request arrived from. ) - if RESOURCE_MAP[resource["resourceType"]].backends.get("global"): + if ( + RESOURCE_MAP[resource["resourceType"]] + .backends[self.account_id] + .get("global") + ): config_backend_region = "global" # If the backend region isn't implemented then we won't find the item: - if not RESOURCE_MAP[resource["resourceType"]].backends.get( - config_backend_region + if ( + not RESOURCE_MAP[resource["resourceType"]] + .backends[self.account_id] + .get(config_backend_region) ): continue # Get the item: item = RESOURCE_MAP[resource["resourceType"]].get_config_resource( - resource["resourceId"], backend_region=backend_query_region + self.account_id, + resource["resourceId"], + backend_region=backend_query_region, ) if not item: continue - item["accountId"] = get_account_id() + item["accountId"] = self.account_id results.append(item) @@ -1576,6 +1605,7 @@ class ConfigBackend(BaseBackend): # Get the item: item = RESOURCE_MAP[resource_type].get_config_resource( + self.account_id, resource_id, 
resource_name=resource_name, resource_region=resource_region, @@ -1584,7 +1614,7 @@ class ConfigBackend(BaseBackend): not_found.append(identifier) continue - item["accountId"] = get_account_id() + item["accountId"] = self.account_id # The 'tags' field is not included in aggregate results for some reason... item.pop("tags", None) @@ -1650,6 +1680,7 @@ class ConfigBackend(BaseBackend): ) else: pack = OrganizationConformancePack( + account_id=self.account_id, region=self.region_name, name=name, delivery_s3_bucket=delivery_s3_bucket, @@ -1723,7 +1754,7 @@ class ConfigBackend(BaseBackend): # actually here would be a list of all accounts in the organization statuses = [ { - "AccountId": get_account_id(), + "AccountId": self.account_id, "ConformancePackName": "OrgConformsPack-{0}".format( pack._unique_pack_name ), @@ -1877,7 +1908,7 @@ class ConfigBackend(BaseBackend): raise MaxNumberOfConfigRulesExceededException( rule_name, ConfigRule.MAX_RULES ) - rule = ConfigRule(self.region_name, config_rule, tags) + rule = ConfigRule(self.account_id, self.region_name, config_rule, tags) self.config_rules[rule_name] = rule return "" diff --git a/moto/config/responses.py b/moto/config/responses.py index 8c226b3d9..cbe6a97ee 100644 --- a/moto/config/responses.py +++ b/moto/config/responses.py @@ -4,9 +4,12 @@ from .models import config_backends class ConfigResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="config") + @property def config_backend(self): - return config_backends[self.region] + return config_backends[self.current_account][self.region] def put_configuration_recorder(self): self.config_backend.put_configuration_recorder( diff --git a/moto/core/__init__.py b/moto/core/__init__.py index 0a1dd5d29..6691a5b4e 100644 --- a/moto/core/__init__.py +++ b/moto/core/__init__.py @@ -1,4 +1,4 @@ -from .models import get_account_id, ACCOUNT_ID # noqa +from .models import DEFAULT_ACCOUNT_ID # noqa from .base_backend import BaseBackend # noqa from 
.common_models import BaseModel # noqa from .common_models import CloudFormationModel, CloudWatchMetricProvider # noqa diff --git a/moto/core/common_models.py b/moto/core/common_models.py index f8226ca9f..2785b6703 100644 --- a/moto/core/common_models.py +++ b/moto/core/common_models.py @@ -37,10 +37,10 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): # This must be implemented as a classmethod with parameters: - # cls, resource_name, cloudformation_json, region_name + # cls, resource_name, cloudformation_json, account_id, region_name # Extract the resource parameters from the cloudformation json # and return an instance of the resource class pass @@ -48,10 +48,15 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): # This must be implemented as a classmethod with parameters: - # cls, original_resource, new_resource_name, cloudformation_json, region_name + # cls, original_resource, new_resource_name, cloudformation_json, account_id, region_name # Extract the resource parameters from the cloudformation json, # delete the old resource and return the new one. Optionally inspect # the change in parameters and no-op when nothing has changed. 
@@ -60,10 +65,10 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # This must be implemented as a classmethod with parameters: - # cls, resource_name, cloudformation_json, region_name + # cls, resource_name, cloudformation_json, account_id, region_name # Extract the resource parameters from the cloudformation json # and delete the resource. Do not include a return statement. pass @@ -83,6 +88,7 @@ class ConfigQueryModel: def list_config_service_resources( self, + account_id, resource_ids, resource_name, limit, @@ -114,6 +120,7 @@ class ConfigQueryModel: As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter from there. It may be valuable to make this a concatenation of the region and resource name. + :param account_id: The account number :param resource_ids: A list of resource IDs :param resource_name: The individual name of a resource :param limit: How many per page @@ -140,7 +147,12 @@ class ConfigQueryModel: raise NotImplementedError() def get_config_resource( - self, resource_id, resource_name=None, backend_region=None, resource_region=None + self, + account_id, + resource_id, + resource_name=None, + backend_region=None, + resource_region=None, ): """For AWS Config. This will query the backend for the specific resource type configuration. @@ -160,6 +172,7 @@ class ConfigQueryModel: from all resources in all regions for a given resource type*. ... 
+ :param account_id: :param resource_id: :param resource_name: :param backend_region: @@ -172,5 +185,5 @@ class ConfigQueryModel: class CloudWatchMetricProvider(object): @staticmethod @abstractmethod - def get_cloudwatch_metrics(): + def get_cloudwatch_metrics(account_id): pass diff --git a/moto/core/models.py b/moto/core/models.py index c5f0602e1..041ef717a 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -23,18 +23,7 @@ from .custom_responses_mock import ( ) from .utils import convert_flask_to_responses_response -ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012") - - -def _get_default_account_id(): - return ACCOUNT_ID - - -account_id_resolver = _get_default_account_id - - -def get_account_id(): - return account_id_resolver() +DEFAULT_ACCOUNT_ID = "123456789012" class BaseMockAWS: @@ -42,23 +31,25 @@ class BaseMockAWS: mocks_active = False def __init__(self, backends): - from moto.instance_metadata import instance_metadata_backend + from moto.instance_metadata import instance_metadata_backends from moto.moto_api._internal.models import moto_api_backend self.backends = backends - self.backends_for_urls = {} - default_backends = { - "instance_metadata": instance_metadata_backend, - "moto_api": moto_api_backend, - } - if "us-east-1" in self.backends: + self.backends_for_urls = [] + default_account_id = DEFAULT_ACCOUNT_ID + default_backends = [ + instance_metadata_backends[default_account_id]["global"], + moto_api_backend, + ] + backend_default_account = self.backends[default_account_id] + if "us-east-1" in backend_default_account: # We only need to know the URL for a single region - they will be the same everywhere - self.backends_for_urls["us-east-1"] = self.backends["us-east-1"] - elif "global" in self.backends: + self.backends_for_urls.append(backend_default_account["us-east-1"]) + elif "global" in backend_default_account: # If us-east-1 is not available, it's probably a global service - self.backends_for_urls["global"] = 
self.backends["global"] - self.backends_for_urls.update(default_backends) + self.backends_for_urls.append(backend_default_account["global"]) + self.backends_for_urls.extend(default_backends) self.FAKE_KEYS = { "AWS_ACCESS_KEY_ID": "foobar_key", @@ -283,7 +274,7 @@ class BotocoreEventMockAWS(BaseMockAWS): def enable_patching(self, reset=True): # pylint: disable=unused-argument botocore_stubber.enabled = True for method in BOTOCORE_HTTP_METHODS: - for backend in self.backends_for_urls.values(): + for backend in self.backends_for_urls: for key, value in backend.urls.items(): pattern = re.compile(key) botocore_stubber.register_response(method, pattern, value) @@ -295,7 +286,7 @@ class BotocoreEventMockAWS(BaseMockAWS): for method in RESPONSES_METHODS: # for backend in default_backends.values(): - for backend in self.backends_for_urls.values(): + for backend in self.backends_for_urls: for key, value in backend.urls.items(): responses_mock.add( CallbackResponse( diff --git a/moto/core/responses.py b/moto/core/responses.py index 13ff213b6..590f1c817 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -3,6 +3,7 @@ from collections import defaultdict import datetime import json import logging +import os import re import requests @@ -137,7 +138,11 @@ class ActionAuthenticatorMixin(object): >= settings.INITIAL_NO_AUTH_ACTION_COUNT ): iam_request = iam_request_cls( - method=self.method, path=self.path, data=self.data, headers=self.headers + account_id=self.current_account, + method=self.method, + path=self.path, + data=self.data, + headers=self.headers, ) iam_request.check_signature() iam_request.check_action_permitted() @@ -215,6 +220,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) aws_service_spec = None + def __init__(self, service_name=None): + super().__init__() + self.service_name = service_name + @classmethod def dispatch(cls, *args, **kwargs): return cls()._dispatch(*args, **kwargs) @@ -295,6 +304,18 @@ class 
BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.headers["host"] = urlparse(full_url).netloc self.response_headers = {"server": "amazon.com"} + # Register visit with IAM + from moto.iam.models import mark_account_as_visited + + self.access_key = self.get_access_key() + self.current_account = self.get_current_account() + mark_account_as_visited( + account_id=self.current_account, + access_key=self.access_key, + service=self.service_name, + region=self.region, + ) + def get_region_from_url(self, request, full_url): url_match = self.region_regex.search(full_url) user_agent_match = self.region_from_useragent_regex.search( @@ -313,7 +334,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): region = self.default_region return region - def get_current_user(self): + def get_access_key(self): """ Returns the access key id used in this request as the current user id """ @@ -323,10 +344,24 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return match.group(1) if self.querystring.get("AWSAccessKeyId"): - return self.querystring.get("AWSAccessKeyId") + return self.querystring.get("AWSAccessKeyId")[0] else: - # Should we raise an unauthorized exception instead? 
- return "111122223333" + return "AKIAEXAMPLE" + + def get_current_account(self): + # PRIO 1: Check if we have a Environment Variable set + if "MOTO_ACCOUNT_ID" in os.environ: + return os.environ["MOTO_ACCOUNT_ID"] + + # PRIO 2: Check if we have a specific request header that specifies the Account ID + if "x-moto-account-id" in self.headers: + return self.headers["x-moto-account-id"] + + # PRIO 3: Use the access key to get the Account ID + # PRIO 4: This method will return the default Account ID as a last resort + from moto.iam.models import get_account_id_from + + return get_account_id_from(self.get_access_key()) def _dispatch(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -372,10 +407,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # service response class should have 'SERVICE_NAME' class member, # if you want to get action from method and url - if not hasattr(self, "SERVICE_NAME"): - return None - service = self.SERVICE_NAME - conn = boto3.client(service, region_name=self.region) + conn = boto3.client(self.service_name, region_name=self.region) # make cache if it does not exist yet if not hasattr(self, "method_urls"): @@ -396,15 +428,15 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _get_action(self): action = self.querystring.get("Action", [""])[0] - if not action: # Some services use a header for the action - # Headers are case-insensitive. Probably a better way to do this. - match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target") - if match: - action = match.split(".")[-1] + if action: + return action + # Some services use a header for the action + # Headers are case-insensitive. Probably a better way to do this. 
+ match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target") + if match: + return match.split(".")[-1] # get action from method and uri - if not action: - return self._get_action_from_method_and_request_uri(self.method, self.path) - return action + return self._get_action_from_method_and_request_uri(self.method, self.path) def call_action(self): headers = self.response_headers diff --git a/moto/core/utils.py b/moto/core/utils.py index e51e3119a..8083e96ca 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -1,4 +1,4 @@ -from functools import wraps +from functools import lru_cache, wraps import binascii import datetime @@ -11,6 +11,7 @@ from boto3 import Session from moto.settings import allow_unknown_region from threading import RLock from urllib.parse import urlparse +from uuid import uuid4 REQUEST_ID_LONG = string.digits + string.ascii_uppercase @@ -436,6 +437,20 @@ class AccountSpecificBackend(dict): sess.get_available_regions(service_name, partition_name="aws-cn") ) self.regions.extend(additional_regions or []) + self._id = str(uuid4()) + + def __hash__(self): + return hash(self._id) + + def __eq__(self, other): + return ( + other + and isinstance(other, AccountSpecificBackend) + and other._id == self._id + ) + + def __ne__(self, other): + return not self.__eq__(other) def reset(self): for region_specific_backend in self.values(): @@ -444,6 +459,7 @@ class AccountSpecificBackend(dict): def __contains__(self, region): return region in self.regions or region in self.keys() + @lru_cache() def __getitem__(self, region_name): if region_name in self.keys(): return super().__getitem__(region_name) @@ -466,14 +482,6 @@ class BackendDict(dict): Format: [account_id: str]: AccountSpecificBackend [account_id: str][region: str] = BaseBackend - - Full multi-account support is not yet available. We will always return account_id 123456789012, regardless of the input. 
- - To not break existing usage patterns, the following data access pattern is also supported: - [region: str] = BaseBackend - - This will automatically resolve to: - [default_account_id][region: str] = BaseBackend """ def __init__( @@ -483,46 +491,25 @@ class BackendDict(dict): self.service_name = service_name self._use_boto3_regions = use_boto3_regions self._additional_regions = additional_regions + self._id = str(uuid4()) - def __contains__(self, account_id_or_region): - """ - Possible data access patterns: - backend_dict[account_id][region_name] - backend_dict[region_name] - backend_dict[unknown_region] + def __hash__(self): + # Required for the LRUcache to work. + # service_name is enough to determine uniqueness - other properties are dependent + return hash(self._id) - The latter two will be phased out in the future, and we can remove this method. - """ - if re.match(r"[0-9]+", account_id_or_region): - self._create_account_specific_backend("123456789012") - return True - else: - region = account_id_or_region - self._create_account_specific_backend("123456789012") - return region in self["123456789012"] + def __eq__(self, other): + return other and isinstance(other, BackendDict) and other._id == self._id - def get(self, account_id_or_region, if_none=None): - if self.__contains__(account_id_or_region): - return self.__getitem__(account_id_or_region) - return if_none + def __ne__(self, other): + return not self.__eq__(other) - def __getitem__(self, account_id_or_region): - """ - Possible data access patterns: - backend_dict[account_id][region_name] - backend_dict[region_name] - backend_dict[unknown_region] + @lru_cache() + def __getitem__(self, account_id) -> AccountSpecificBackend: + self._create_account_specific_backend(account_id) + return super().__getitem__(account_id) - The latter two will be phased out in the future. 
- """ - if re.match(r"[0-9]+", account_id_or_region): - self._create_account_specific_backend("123456789012") - return super().__getitem__("123456789012") - else: - region_name = account_id_or_region - return self["123456789012"][region_name] - - def _create_account_specific_backend(self, account_id): + def _create_account_specific_backend(self, account_id) -> None: with backend_lock: if account_id not in self.keys(): self[account_id] = AccountSpecificBackend( diff --git a/moto/databrew/responses.py b/moto/databrew/responses.py index cc5a57885..400ff3716 100644 --- a/moto/databrew/responses.py +++ b/moto/databrew/responses.py @@ -7,12 +7,13 @@ from .models import databrew_backends class DataBrewResponse(BaseResponse): - SERVICE_NAME = "databrew" + def __init__(self): + super().__init__(service_name="databrew") @property def databrew_backend(self): """Return backend instance specific for this region.""" - return databrew_backends[self.region] + return databrew_backends[self.current_account][self.region] # region Recipes @property diff --git a/moto/datapipeline/models.py b/moto/datapipeline/models.py index cd611d2e1..078ea861c 100644 --- a/moto/datapipeline/models.py +++ b/moto/datapipeline/models.py @@ -83,9 +83,9 @@ class Pipeline(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - datapipeline_backend = datapipeline_backends[region_name] + datapipeline_backend = datapipeline_backends[account_id][region_name] properties = cloudformation_json["Properties"] cloudformation_unique_id = "cf-" + resource_name diff --git a/moto/datapipeline/responses.py b/moto/datapipeline/responses.py index 60cc294b8..f9fa3f27c 100644 --- a/moto/datapipeline/responses.py +++ b/moto/datapipeline/responses.py @@ -5,9 +5,12 @@ from .models import datapipeline_backends class DataPipelineResponse(BaseResponse): + def 
__init__(self): + super().__init__(service_name="datapipeline") + @property def datapipeline_backend(self): - return datapipeline_backends[self.region] + return datapipeline_backends[self.current_account][self.region] def create_pipeline(self): name = self._get_param("name") diff --git a/moto/datasync/responses.py b/moto/datasync/responses.py index 8f04055bb..efabb805f 100644 --- a/moto/datasync/responses.py +++ b/moto/datasync/responses.py @@ -6,9 +6,12 @@ from .models import datasync_backends class DataSyncResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="datasync") + @property def datasync_backend(self): - return datasync_backends[self.region] + return datasync_backends[self.current_account][self.region] def list_locations(self): locations = list() diff --git a/moto/dax/models.py b/moto/dax/models.py index fd4d1a746..e293a1993 100644 --- a/moto/dax/models.py +++ b/moto/dax/models.py @@ -1,5 +1,5 @@ """DAXBackend class with methods for supported APIs.""" -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, get_random_hex, unix_time from moto.moto_api import state_manager from moto.moto_api._internal.managed_state_model import ManagedState @@ -68,6 +68,7 @@ class DaxEndpoint: class DaxCluster(BaseModel, ManagedState): def __init__( self, + account_id, region, name, description, @@ -85,7 +86,7 @@ class DaxCluster(BaseModel, ManagedState): # Set internal properties self.name = name self.description = description - self.arn = f"arn:aws:dax:{region}:{get_account_id()}:cache/{self.name}" + self.arn = f"arn:aws:dax:{region}:{account_id}:cache/{self.name}" self.node_type = node_type self.replication_factor = replication_factor self.cluster_hex = get_random_hex(6) @@ -187,6 +188,7 @@ class DAXBackend(BaseBackend): AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName """ cluster = 
DaxCluster( + account_id=self.account_id, region=self.region_name, name=cluster_name, description=description, diff --git a/moto/dax/responses.py b/moto/dax/responses.py index 3ebd66026..b2016b273 100644 --- a/moto/dax/responses.py +++ b/moto/dax/responses.py @@ -7,9 +7,12 @@ from .models import dax_backends class DAXResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="dax") + @property def dax_backend(self): - return dax_backends[self.region] + return dax_backends[self.current_account][self.region] def create_cluster(self): params = json.loads(self.body) diff --git a/moto/dms/models.py b/moto/dms/models.py index d50a0aae5..a5f0958f1 100644 --- a/moto/dms/models.py +++ b/moto/dms/models.py @@ -1,7 +1,7 @@ import json from datetime import datetime -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from .exceptions import ( @@ -46,6 +46,7 @@ class DatabaseMigrationServiceBackend(BaseBackend): migration_type=migration_type, table_mappings=table_mappings, replication_task_settings=replication_task_settings, + account_id=self.account_id, region_name=self.region_name, ) @@ -106,6 +107,7 @@ class FakeReplicationTask(BaseModel): target_endpoint_arn, table_mappings, replication_task_settings, + account_id, region_name, ): self.id = replication_task_identifier @@ -117,18 +119,13 @@ class FakeReplicationTask(BaseModel): self.table_mappings = table_mappings self.replication_task_settings = replication_task_settings + self.arn = f"arn:aws:dms:{region_name}:{account_id}:task:{self.id}" self.status = "creating" self.creation_date = datetime.utcnow() self.start_date = None self.stop_date = None - @property - def arn(self): - return "arn:aws:dms:{region}:{account_id}:task:{task_id}".format( - region=self.region, account_id=get_account_id(), task_id=self.id - ) - def to_dict(self): start_date = self.start_date.isoformat() if self.start_date else None stop_date = 
self.stop_date.isoformat() if self.stop_date else None diff --git a/moto/dms/responses.py b/moto/dms/responses.py index cf671f4e1..c51741c2a 100644 --- a/moto/dms/responses.py +++ b/moto/dms/responses.py @@ -4,11 +4,12 @@ import json class DatabaseMigrationServiceResponse(BaseResponse): - SERVICE_NAME = "dms" + def __init__(self): + super().__init__(service_name="dms") @property def dms_backend(self): - return dms_backends[self.region] + return dms_backends[self.current_account][self.region] def create_replication_task(self): replication_task_identifier = self._get_param("ReplicationTaskIdentifier") diff --git a/moto/ds/models.py b/moto/ds/models.py index de42344ae..c4eb1a274 100644 --- a/moto/ds/models.py +++ b/moto/ds/models.py @@ -46,6 +46,7 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes def __init__( self, + account_id, region, name, password, @@ -57,6 +58,7 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes description=None, edition=None, ): # pylint: disable=too-many-arguments + self.account_id = account_id self.region = region self.name = name self.password = password @@ -101,7 +103,9 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes def create_security_group(self, vpc_id): """Create security group for the network interface.""" - security_group_info = ec2_backends[self.region].create_security_group( + security_group_info = ec2_backends[self.account_id][ + self.region + ].create_security_group( name=f"{self.directory_id}_controllers", description=( f"AWS created security group for {self.directory_id} " @@ -113,14 +117,18 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes def delete_security_group(self): """Delete the given security group.""" - ec2_backends[self.region].delete_security_group(group_id=self.security_group_id) + ec2_backends[self.account_id][self.region].delete_security_group( + group_id=self.security_group_id + ) def create_eni(self, 
security_group_id, subnet_ids): """Return ENI ids and primary addresses created for each subnet.""" eni_ids = [] subnet_ips = [] for subnet_id in subnet_ids: - eni_info = ec2_backends[self.region].create_network_interface( + eni_info = ec2_backends[self.account_id][ + self.region + ].create_network_interface( subnet=subnet_id, private_ip_address=None, group_ids=[security_group_id], @@ -133,7 +141,7 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes def delete_eni(self): """Delete ENI for each subnet and the security group.""" for eni_id in self.eni_ids: - ec2_backends[self.region].delete_network_interface(eni_id) + ec2_backends[self.account_id][self.region].delete_network_interface(eni_id) def update_alias(self, alias): """Change default alias to given alias.""" @@ -192,8 +200,7 @@ class DirectoryServiceBackend(BaseBackend): service_region, zones, "ds" ) - @staticmethod - def _verify_subnets(region, vpc_settings): + def _verify_subnets(self, region, vpc_settings): """Verify subnets are valid, else raise an exception. If settings are valid, add AvailabilityZones to vpc_settings. @@ -207,7 +214,7 @@ class DirectoryServiceBackend(BaseBackend): # Subnet IDs are checked before the VPC ID. The Subnet IDs must # be valid and in different availability zones. try: - subnets = ec2_backends[region].get_all_subnets( + subnets = ec2_backends[self.account_id][region].get_all_subnets( subnet_ids=vpc_settings["SubnetIds"] ) except InvalidSubnetIdError as exc: @@ -223,7 +230,7 @@ class DirectoryServiceBackend(BaseBackend): "different Availability Zones." 
) - vpcs = ec2_backends[region].describe_vpcs() + vpcs = ec2_backends[self.account_id][region].describe_vpcs() if vpc_settings["VpcId"] not in [x.id for x in vpcs]: raise ClientException("Invalid VPC ID.") vpc_settings["AvailabilityZones"] = regions @@ -274,6 +281,7 @@ class DirectoryServiceBackend(BaseBackend): raise DirectoryLimitExceededException("Tag Limit is exceeding") directory = Directory( + self.account_id, region, name, password, @@ -319,6 +327,7 @@ class DirectoryServiceBackend(BaseBackend): raise DirectoryLimitExceededException("Tag Limit is exceeding") directory = Directory( + self.account_id, region, name, password, @@ -400,6 +409,7 @@ class DirectoryServiceBackend(BaseBackend): raise DirectoryLimitExceededException("Tag Limit is exceeding") directory = Directory( + self.account_id, region, name, password, diff --git a/moto/ds/responses.py b/moto/ds/responses.py index 9d852a965..0516a4d97 100644 --- a/moto/ds/responses.py +++ b/moto/ds/responses.py @@ -10,10 +10,13 @@ from moto.ds.models import ds_backends class DirectoryServiceResponse(BaseResponse): """Handler for DirectoryService requests and responses.""" + def __init__(self): + super().__init__(service_name="ds") + @property def ds_backend(self): """Return backend instance specific for this region.""" - return ds_backends[self.region] + return ds_backends[self.current_account][self.region] def connect_directory(self): """Create an AD Connector to connect to a self-managed directory.""" diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index a5f49b66e..38cfabc4e 100644 --- a/moto/dynamodb/models/__init__.py +++ b/moto/dynamodb/models/__init__.py @@ -7,7 +7,6 @@ import re import uuid from collections import OrderedDict -from moto.core import get_account_id from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import unix_time, unix_time_millis, BackendDict from moto.core.exceptions import JsonRESTError @@ -252,7 +251,8 @@ class 
StreamRecord(BaseModel): class StreamShard(BaseModel): - def __init__(self, table): + def __init__(self, account_id, table): + self.account_id = account_id self.table = table self.id = "shardId-00000001541626099285-f35f62ef" self.starting_sequence_number = 1100000000017454423009 @@ -285,7 +285,7 @@ class StreamShard(BaseModel): len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:")) ] - result = lambda_backends[region].send_dynamodb_items( + result = lambda_backends[self.account_id][region].send_dynamodb_items( arn, self.items, esm.event_source_arn ) @@ -398,6 +398,7 @@ class Table(CloudFormationModel): def __init__( self, table_name, + account_id, region, schema=None, attr=None, @@ -410,6 +411,7 @@ class Table(CloudFormationModel): tags=None, ): self.name = table_name + self.account_id = account_id self.attr = attr self.schema = schema self.range_key_attr = None @@ -467,16 +469,16 @@ class Table(CloudFormationModel): self.sse_specification = sse_specification if sse_specification and "KMSMasterKeyId" not in self.sse_specification: self.sse_specification["KMSMasterKeyId"] = self._get_default_encryption_key( - region + account_id, region ) - def _get_default_encryption_key(self, region): + def _get_default_encryption_key(self, account_id, region): from moto.kms import kms_backends # https://aws.amazon.com/kms/features/#AWS_Service_Integration # An AWS managed CMK is created automatically when you first create # an encrypted resource using an AWS service integrated with KMS. 
- kms = kms_backends[region] + kms = kms_backends[account_id][region] ddb_alias = "alias/aws/dynamodb" if not kms.alias_exists(ddb_alias): key = kms.create_key( @@ -485,7 +487,6 @@ class Table(CloudFormationModel): key_spec="SYMMETRIC_DEFAULT", description="Default master key that protects my DynamoDB table storage", tags=None, - region=region, ) kms.add_alias(key.id, ddb_alias) ebs_key = kms.describe_key(ddb_alias) @@ -532,7 +533,7 @@ class Table(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] params = {} @@ -550,27 +551,29 @@ class Table(CloudFormationModel): if "StreamSpecification" in properties: params["streams"] = properties["StreamSpecification"] - table = dynamodb_backends[region_name].create_table( + table = dynamodb_backends[account_id][region_name].create_table( name=resource_name, **params ) return table @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - table = dynamodb_backends[region_name].delete_table(name=resource_name) + table = dynamodb_backends[account_id][region_name].delete_table( + name=resource_name + ) return table def _generate_arn(self, name): - return f"arn:aws:dynamodb:us-east-1:{get_account_id()}:table/{name}" + return f"arn:aws:dynamodb:us-east-1:{self.account_id}:table/{name}" def set_stream_specification(self, streams): self.stream_specification = streams if streams and (streams.get("StreamEnabled") or streams.get("StreamViewType")): self.stream_specification["StreamEnabled"] = True self.latest_stream_label = datetime.datetime.utcnow().isoformat() - self.stream_shard = StreamShard(self) + self.stream_shard = StreamShard(self.account_id, self) else: self.stream_specification = {"StreamEnabled": 
False} @@ -1042,14 +1045,14 @@ class Table(CloudFormationModel): return results, last_evaluated_key - def delete(self, region_name): - dynamodb_backends[region_name].delete_table(self.name) + def delete(self, account_id, region_name): + dynamodb_backends[account_id][region_name].delete_table(self.name) class RestoredTable(Table): - def __init__(self, name, region, backup): + def __init__(self, name, account_id, region, backup): params = self._parse_params_from_backup(backup) - super().__init__(name, region=region, **params) + super().__init__(name, account_id=account_id, region=region, **params) self.indexes = copy.deepcopy(backup.table.indexes) self.global_indexes = copy.deepcopy(backup.table.global_indexes) self.items = copy.deepcopy(backup.table.items) @@ -1079,9 +1082,9 @@ class RestoredTable(Table): class RestoredPITTable(Table): - def __init__(self, name, region, source): + def __init__(self, name, account_id, region, source): params = self._parse_params_from_table(source) - super().__init__(name, region=region, **params) + super().__init__(name, account_id=account_id, region=region, **params) self.indexes = copy.deepcopy(source.indexes) self.global_indexes = copy.deepcopy(source.global_indexes) self.items = copy.deepcopy(source.items) @@ -1129,7 +1132,7 @@ class Backup(object): def arn(self): return "arn:aws:dynamodb:{region}:{account}:table/{table_name}/backup/{identifier}".format( region=self.backend.region_name, - account=get_account_id(), + account=self.backend.account_id, table_name=self.table.name, identifier=self.identifier, ) @@ -1197,7 +1200,9 @@ class DynamoDBBackend(BaseBackend): def create_table(self, name, **params): if name in self.tables: raise ResourceInUseException - table = Table(name, region=self.region_name, **params) + table = Table( + name, account_id=self.account_id, region=self.region_name, **params + ) self.tables[name] = table return table @@ -1818,7 +1823,10 @@ class DynamoDBBackend(BaseBackend): if target_table_name in 
self.tables: raise TableAlreadyExistsException(target_table_name) new_table = RestoredTable( - target_table_name, region=self.region_name, backup=backup + target_table_name, + account_id=self.account_id, + region=self.region_name, + backup=backup, ) self.tables[target_table_name] = new_table return new_table @@ -1836,7 +1844,10 @@ class DynamoDBBackend(BaseBackend): if target_table_name in self.tables: raise TableAlreadyExistsException(target_table_name) new_table = RestoredPITTable( - target_table_name, region=self.region_name, source=source + target_table_name, + account_id=self.account_id, + region=self.region_name, + source=source, ) self.tables[target_table_name] = new_table return new_table diff --git a/moto/dynamodb/parsing/validators.py b/moto/dynamodb/parsing/validators.py index 3969665b0..b4f3dcfd1 100644 --- a/moto/dynamodb/parsing/validators.py +++ b/moto/dynamodb/parsing/validators.py @@ -115,8 +115,6 @@ class ExpressionPathResolver(object): raise NotImplementedError( "Path resolution for {t}".format(t=type(child)) ) - if not isinstance(target, DynamoType): - print(target) return DDBTypedValue(target) def resolve_expression_path_nodes_to_dynamo_type( diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 93fa9b51a..b0fa105b8 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -117,6 +117,9 @@ def check_projection_expression(expression): class DynamoHandler(BaseResponse): + def __init__(self): + super().__init__(service_name="dynamodb") + def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -134,7 +137,7 @@ class DynamoHandler(BaseResponse): :return: DynamoDB2 Backend :rtype: moto.dynamodb2.models.DynamoDBBackend """ - return dynamodb_backends[self.region] + return dynamodb_backends[self.current_account][self.region] @amz_crc32 @amzn_request_id diff --git a/moto/dynamodb_v20111205/models.py 
b/moto/dynamodb_v20111205/models.py index a2a811a72..c8ea16da1 100644 --- a/moto/dynamodb_v20111205/models.py +++ b/moto/dynamodb_v20111205/models.py @@ -5,7 +5,6 @@ import json from collections import OrderedDict from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import unix_time, BackendDict -from moto.core import get_account_id from .comparisons import get_comparison_func @@ -90,6 +89,7 @@ class Item(BaseModel): class Table(CloudFormationModel): def __init__( self, + account_id, name, hash_key_attr, hash_key_type, @@ -98,6 +98,7 @@ class Table(CloudFormationModel): read_capacity=None, write_capacity=None, ): + self.account_id = account_id self.name = name self.hash_key_attr = hash_key_attr self.hash_key_type = hash_key_type @@ -151,7 +152,7 @@ class Table(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] key_attr = [ @@ -165,6 +166,7 @@ class Table(CloudFormationModel): if i["AttributeName"] == key_attr ][0] spec = { + "account_id": account_id, "name": properties["TableName"], "hash_key_attr": key_attr, "hash_key_type": key_type, @@ -306,9 +308,7 @@ class Table(CloudFormationModel): if attribute_name == "StreamArn": region = "us-east-1" time = "2000-01-01T00:00:00.000" - return "arn:aws:dynamodb:{0}:{1}:table/{2}/stream/{3}".format( - region, get_account_id(), self.name, time - ) + return f"arn:aws:dynamodb:{region}:{self.account_id}:table/{self.name}/stream/{time}" raise UnformattedGetAttTemplateException() @@ -318,7 +318,7 @@ class DynamoDBBackend(BaseBackend): self.tables = OrderedDict() def create_table(self, name, **params): - table = Table(name, **params) + table = Table(self.account_id, name, **params) self.tables[name] = table return table diff --git a/moto/dynamodb_v20111205/responses.py 
b/moto/dynamodb_v20111205/responses.py index 1e2cfdc4d..9c7b82732 100644 --- a/moto/dynamodb_v20111205/responses.py +++ b/moto/dynamodb_v20111205/responses.py @@ -6,6 +6,9 @@ from .models import dynamodb_backends, dynamo_json_dump class DynamoHandler(BaseResponse): + def __init__(self): + super().__init__(service_name="dynamodb") + def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -38,7 +41,7 @@ class DynamoHandler(BaseResponse): @property def backend(self): - return dynamodb_backends["global"] + return dynamodb_backends[self.current_account]["global"] def list_tables(self): body = self.body diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 0e82cd5af..04f580ebe 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -70,7 +70,7 @@ class DynamoDBStreamsBackend(BaseBackend): @property def dynamodb(self): - return dynamodb_backends[self.region_name] + return dynamodb_backends[self.account_id][self.region_name] def _get_table_from_arn(self, arn): table_name = arn.split(":", 6)[5].split("/")[1] diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index b0707ff44..ff8a78650 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -4,9 +4,12 @@ from .models import dynamodbstreams_backends class DynamoDBStreamsHandler(BaseResponse): + def __init__(self): + super().__init__(service_name="dynamodb-streams") + @property def backend(self): - return dynamodbstreams_backends[self.region] + return dynamodbstreams_backends[self.current_account][self.region] def describe_stream(self): arn = self._get_param("StreamArn") diff --git a/moto/ebs/models.py b/moto/ebs/models.py index 69a5c32a8..df342ef5f 100644 --- a/moto/ebs/models.py +++ b/moto/ebs/models.py @@ -1,6 +1,6 @@ """EBSBackend class with methods for supported APIs.""" -from moto.core import ACCOUNT_ID, 
BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, unix_time from moto.ec2 import ec2_backends from moto.ec2.models.elastic_block_store import Snapshot @@ -17,7 +17,8 @@ class Block(BaseModel): class EBSSnapshot(BaseModel): - def __init__(self, snapshot: Snapshot): + def __init__(self, account_id, snapshot: Snapshot): + self.account_id = account_id self.snapshot_id = snapshot.id self.status = "pending" self.start_time = unix_time() @@ -39,7 +40,7 @@ class EBSSnapshot(BaseModel): def to_json(self): return { "SnapshotId": self.snapshot_id, - "OwnerId": ACCOUNT_ID, + "OwnerId": self.account_id, "Status": self.status, "StartTime": self.start_time, "VolumeSize": self.volume_size, @@ -58,7 +59,7 @@ class EBSBackend(BaseBackend): @property def ec2_backend(self): - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] def start_snapshot(self, volume_size, tags, description): zone_name = f"{self.region_name}a" @@ -69,7 +70,7 @@ class EBSBackend(BaseBackend): if tags: tags = {tag["Key"]: tag["Value"] for tag in tags} snapshot.add_tags(tags) - ebs_snapshot = EBSSnapshot(snapshot=snapshot) + ebs_snapshot = EBSSnapshot(account_id=self.account_id, snapshot=snapshot) self.snapshots[ebs_snapshot.snapshot_id] = ebs_snapshot return ebs_snapshot diff --git a/moto/ebs/responses.py b/moto/ebs/responses.py index 5b92ff2f5..4f301a415 100644 --- a/moto/ebs/responses.py +++ b/moto/ebs/responses.py @@ -8,10 +8,13 @@ from .models import ebs_backends class EBSResponse(BaseResponse): """Handler for EBS requests and responses.""" + def __init__(self): + super().__init__(service_name="ebs") + @property def ebs_backend(self): """Return backend instance specific for this region.""" - return ebs_backends[self.region] + return ebs_backends[self.current_account][self.region] def snapshots(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/ec2/__init__.py 
b/moto/ec2/__init__.py index 1025a0d60..7ac36e67b 100644 --- a/moto/ec2/__init__.py +++ b/moto/ec2/__init__.py @@ -1,5 +1,4 @@ from .models import ec2_backends from ..core.models import base_decorator -ec2_backend = ec2_backends["us-east-1"] mock_ec2 = base_decorator(ec2_backends) diff --git a/moto/ec2/models/__init__.py b/moto/ec2/models/__init__.py index 37e480ed9..addaadc02 100644 --- a/moto/ec2/models/__init__.py +++ b/moto/ec2/models/__init__.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.core import BaseBackend from moto.core.utils import BackendDict from ..exceptions import ( @@ -55,8 +54,6 @@ from ..utils import ( get_prefix, ) -OWNER_ID = get_account_id() - def validate_resource_ids(resource_ids): if not resource_ids: @@ -72,15 +69,15 @@ class SettingsBackend: self.ebs_encryption_by_default = False def disable_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] ec2_backend.ebs_encryption_by_default = False def enable_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] ec2_backend.ebs_encryption_by_default = True def get_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] return ec2_backend.ebs_encryption_by_default diff --git a/moto/ec2/models/amis.py b/moto/ec2/models/amis.py index 66e7806ae..b262ef018 100644 --- a/moto/ec2/models/amis.py +++ b/moto/ec2/models/amis.py @@ -1,7 +1,6 @@ import json import re from os import environ -from moto.core import get_account_id from moto.utilities.utils import load_resource from ..exceptions import ( InvalidAMIIdError, @@ -34,7 +33,7 @@ class Ami(TaggedEC2Resource): source_ami=None, name=None, description=None, - owner_id=get_account_id(), + owner_id=None, owner_alias=None, public=False, virtualization_type=None, @@ -57,7 +56,7 
@@ class Ami(TaggedEC2Resource): self.name = name self.image_type = image_type self.image_location = image_location - self.owner_id = owner_id + self.owner_id = owner_id or ec2_backend.account_id self.owner_alias = owner_alias self.description = description self.virtualization_type = virtualization_type @@ -68,9 +67,7 @@ class Ami(TaggedEC2Resource): self.root_device_name = root_device_name self.root_device_type = root_device_type self.sriov = sriov - self.creation_date = ( - utc_date_and_time() if creation_date is None else creation_date - ) + self.creation_date = creation_date or utc_date_and_time() if instance: self.instance = instance @@ -107,7 +104,7 @@ class Ami(TaggedEC2Resource): snapshot_description or "Auto-created snapshot for AMI %s" % self.id ) self.ebs_snapshot = self.ec2_backend.create_snapshot( - volume.id, snapshot_description, owner_id, from_ami=ami_id + volume.id, snapshot_description, self.owner_id, from_ami=ami_id ) self.ec2_backend.delete_volume(volume.id) @@ -185,7 +182,7 @@ class AmiBackend: source_ami=None, name=name, description=description, - owner_id=get_account_id(), + owner_id=None, snapshot_description=f"Created by CreateImage({instance_id}) for {ami_id}", ) for tag in tags: @@ -196,7 +193,7 @@ class AmiBackend: def copy_image(self, source_image_id, source_region, name=None, description=None): from ..models import ec2_backends - source_ami = ec2_backends[source_region].describe_images( + source_ami = ec2_backends[self.account_id][source_region].describe_images( ami_ids=[source_image_id] )[0] ami_id = random_ami_id() @@ -245,7 +242,7 @@ class AmiBackend: # support filtering by Owners=['self'] if "self" in owners: owners = list( - map(lambda o: get_account_id() if o == "self" else o, owners) + map(lambda o: self.account_id if o == "self" else o, owners) ) images = [ ami diff --git a/moto/ec2/models/carrier_gateways.py b/moto/ec2/models/carrier_gateways.py index eecb1a10c..51ef97a6b 100644 --- a/moto/ec2/models/carrier_gateways.py +++ 
b/moto/ec2/models/carrier_gateways.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.utilities.utils import filter_resources from .core import TaggedEC2Resource @@ -20,7 +19,7 @@ class CarrierGateway(TaggedEC2Resource): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id class CarrierGatewayBackend: diff --git a/moto/ec2/models/elastic_block_store.py b/moto/ec2/models/elastic_block_store.py index 08f784ce5..c45702f10 100644 --- a/moto/ec2/models/elastic_block_store.py +++ b/moto/ec2/models/elastic_block_store.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from moto.packages.boto.ec2.blockdevicemapping import BlockDeviceType from ..exceptions import ( InvalidAMIAttributeItemValueError, @@ -38,7 +38,7 @@ class VolumeAttachment(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -47,7 +47,7 @@ class VolumeAttachment(CloudFormationModel): instance_id = properties["InstanceId"] volume_id = properties["VolumeId"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] attachment = ec2_backend.attach_volume( volume_id=volume_id, instance_id=instance_id, @@ -90,13 +90,13 @@ class Volume(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] volume = ec2_backend.create_volume( size=properties.get("Size"), zone_name=properties.get("AvailabilityZone") ) @@ 
-150,7 +150,7 @@ class Snapshot(TaggedEC2Resource): volume, description, encrypted=False, - owner_id=get_account_id(), + owner_id=None, from_ami=None, ): self.id = snapshot_id @@ -162,7 +162,7 @@ class Snapshot(TaggedEC2Resource): self.ec2_backend = ec2_backend self.status = "completed" self.encrypted = encrypted - self.owner_id = owner_id + self.owner_id = owner_id or ec2_backend.account_id self.from_ami = from_ami def get_filter_value(self, filter_name): @@ -339,9 +339,9 @@ class EBSBackend: def copy_snapshot(self, source_snapshot_id, source_region, description=None): from ..models import ec2_backends - source_snapshot = ec2_backends[source_region].describe_snapshots( - snapshot_ids=[source_snapshot_id] - )[0] + source_snapshot = ec2_backends[self.account_id][ + source_region + ].describe_snapshots(snapshot_ids=[source_snapshot_id])[0] snapshot_id = random_snapshot_id() snapshot = Snapshot( self, @@ -405,7 +405,7 @@ class EBSBackend: # an encrypted resource using an AWS service integrated with KMS. 
from moto.kms import kms_backends - kms = kms_backends[self.region_name] + kms = kms_backends[self.account_id][self.region_name] ebs_alias = "alias/aws/ebs" if not kms.alias_exists(ebs_alias): key = kms.create_key( @@ -414,7 +414,6 @@ class EBSBackend: key_spec="SYMMETRIC_DEFAULT", description="Default master key that protects my EBS volumes when no other key is defined", tags=None, - region=self.region_name, ) kms.add_alias(key.id, ebs_alias) ebs_key = kms.describe_key(ebs_alias) diff --git a/moto/ec2/models/elastic_ip_addresses.py b/moto/ec2/models/elastic_ip_addresses.py index 85dacfa42..e1dbc7f2e 100644 --- a/moto/ec2/models/elastic_ip_addresses.py +++ b/moto/ec2/models/elastic_ip_addresses.py @@ -42,11 +42,11 @@ class ElasticAddress(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] properties = cloudformation_json.get("Properties") instance_id = None diff --git a/moto/ec2/models/elastic_network_interfaces.py b/moto/ec2/models/elastic_network_interfaces.py index 6cd96fdfe..846bf696a 100644 --- a/moto/ec2/models/elastic_network_interfaces.py +++ b/moto/ec2/models/elastic_network_interfaces.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from ..exceptions import InvalidNetworkAttachmentIdError, InvalidNetworkInterfaceIdError from .core import TaggedEC2Resource from .security_groups import SecurityGroup @@ -134,7 +134,7 @@ class NetworkInterface(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @property def association(self): @@ -160,7 +160,7 @@ class NetworkInterface(TaggedEC2Resource, 
CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -168,7 +168,7 @@ class NetworkInterface(TaggedEC2Resource, CloudFormationModel): security_group_ids = properties.get("SecurityGroups", []) - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] subnet_id = properties.get("SubnetId") if subnet_id: subnet = ec2_backend.get_subnet(subnet_id) diff --git a/moto/ec2/models/flow_logs.py b/moto/ec2/models/flow_logs.py index 38491780f..26ea09614 100644 --- a/moto/ec2/models/flow_logs.py +++ b/moto/ec2/models/flow_logs.py @@ -58,7 +58,7 @@ class FlowLogs(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -74,7 +74,7 @@ class FlowLogs(TaggedEC2Resource, CloudFormationModel): log_format = properties.get("LogFormat") max_aggregation_interval = properties.get("MaxAggregationInterval") - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] flow_log, _ = ec2_backend.create_flow_logs( resource_type, resource_id, @@ -219,7 +219,7 @@ class FlowLogsBackend: arn = log_destination.split(":", 5)[5] try: - s3_backends["global"].get_bucket(arn) + s3_backends[self.account_id]["global"].get_bucket(arn) except MissingBucket: # Instead of creating FlowLog report # the unsuccessful status for the @@ -242,7 +242,9 @@ class FlowLogsBackend: try: # Need something easy to check the group exists. # The list_tags_log_group seems to do the trick. 
- logs_backends[self.region_name].list_tags_log_group(log_group_name) + logs_backends[self.account_id][ + self.region_name + ].list_tags_log_group(log_group_name) except ResourceNotFoundException: deliver_logs_status = "FAILED" deliver_logs_error_message = "Access error" diff --git a/moto/ec2/models/iam_instance_profile.py b/moto/ec2/models/iam_instance_profile.py index 33ead5c40..2f8ea69fb 100644 --- a/moto/ec2/models/iam_instance_profile.py +++ b/moto/ec2/models/iam_instance_profile.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.core import CloudFormationModel from ..exceptions import ( IncorrectStateIamProfileAssociationError, @@ -10,8 +9,6 @@ from ..utils import ( filter_iam_instance_profiles, ) -OWNER_ID = get_account_id() - class IamInstanceProfileAssociation(CloudFormationModel): def __init__(self, ec2_backend, association_id, instance, iam_instance_profile): @@ -32,7 +29,7 @@ class IamInstanceProfileAssociationBackend: iam_association_id = random_iam_instance_profile_association_id() instance_profile = filter_iam_instance_profiles( - iam_instance_profile_arn, iam_instance_profile_name + self.account_id, iam_instance_profile_arn, iam_instance_profile_name ) if instance_id in self.iam_instance_profile_associations.keys(): @@ -101,7 +98,7 @@ class IamInstanceProfileAssociationBackend: iam_instance_profile_arn=None, ): instance_profile = filter_iam_instance_profiles( - iam_instance_profile_arn, iam_instance_profile_name + self.account_id, iam_instance_profile_arn, iam_instance_profile_name ) iam_instance_profile_association = None diff --git a/moto/ec2/models/instances.py b/moto/ec2/models/instances.py index c66fec62a..5f438e5d9 100644 --- a/moto/ec2/models/instances.py +++ b/moto/ec2/models/instances.py @@ -4,7 +4,6 @@ from collections import OrderedDict from datetime import datetime from moto import settings -from moto.core import get_account_id from moto.core import CloudFormationModel from moto.core.utils import camelcase_to_underscores 
from moto.ec2.models.fleets import Fleet @@ -70,7 +69,7 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel): super().__init__() self.ec2_backend = ec2_backend self.id = random_instance_id() - self.owner_id = get_account_id() + self.owner_id = ec2_backend.account_id self.lifecycle = kwargs.get("lifecycle") nics = kwargs.get("nics", {}) @@ -265,13 +264,13 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] security_group_ids = properties.get("SecurityGroups", []) group_names = [ ec2_backend.get_security_group_from_id(group_id).name @@ -307,11 +306,11 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] all_instances = ec2_backend.all_instances() # the resource_name for instances is the stack name, logical id, and random suffix separated @@ -326,7 +325,7 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel): tag["key"] == "aws:cloudformation:logical-id" and tag["value"] == logical_id ): - instance.delete(region_name) + instance.delete(account_id, region_name) @property def physical_resource_id(self): @@ -360,7 +359,7 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel): def is_running(self): return self._state.name == "running" - def delete(self, region): # pylint: disable=unused-argument + def delete(self, 
account_id, region): # pylint: disable=unused-argument self.terminate() def terminate(self): diff --git a/moto/ec2/models/internet_gateways.py b/moto/ec2/models/internet_gateways.py index ebabf9374..9f3d4ea2e 100644 --- a/moto/ec2/models/internet_gateways.py +++ b/moto/ec2/models/internet_gateways.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from .core import TaggedEC2Resource from ..exceptions import ( @@ -80,7 +80,7 @@ class InternetGateway(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @staticmethod def cloudformation_name_type(): @@ -93,11 +93,11 @@ class InternetGateway(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] return ec2_backend.create_internet_gateway() @property diff --git a/moto/ec2/models/managed_prefixes.py b/moto/ec2/models/managed_prefixes.py index 7084cc4ef..247f4b42b 100644 --- a/moto/ec2/models/managed_prefixes.py +++ b/moto/ec2/models/managed_prefixes.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.utilities.utils import filter_resources from .core import TaggedEC2Resource from ..utils import random_managed_prefix_list_id, describe_tag_filter @@ -37,9 +36,7 @@ class ManagedPrefixList(TaggedEC2Resource): @property def owner_id(self): - return ( - get_account_id() if not self.resource_owner_id else self.resource_owner_id - ) + return self.resource_owner_id or self.ec2_backend.account_id class ManagedPrefixListBackend: diff --git a/moto/ec2/models/nat_gateways.py b/moto/ec2/models/nat_gateways.py index ecfe3abbf..593422faf 100644 --- 
a/moto/ec2/models/nat_gateways.py +++ b/moto/ec2/models/nat_gateways.py @@ -69,11 +69,11 @@ class NatGateway(CloudFormationModel, TaggedEC2Resource): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] nat_gateway = ec2_backend.create_nat_gateway( cloudformation_json["Properties"]["SubnetId"], cloudformation_json["Properties"]["AllocationId"], diff --git a/moto/ec2/models/network_acls.py b/moto/ec2/models/network_acls.py index f060be55d..25acf1b39 100644 --- a/moto/ec2/models/network_acls.py +++ b/moto/ec2/models/network_acls.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from ..exceptions import ( InvalidNetworkAclIdError, InvalidRouteTableIdError, @@ -12,9 +11,6 @@ from ..utils import ( ) -OWNER_ID = get_account_id() - - class NetworkAclBackend: def __init__(self): self.network_acls = {} @@ -210,12 +206,12 @@ class NetworkAclAssociation(object): class NetworkAcl(TaggedEC2Resource): def __init__( - self, ec2_backend, network_acl_id, vpc_id, default=False, owner_id=OWNER_ID + self, ec2_backend, network_acl_id, vpc_id, default=False, owner_id=None ): self.ec2_backend = ec2_backend self.id = network_acl_id self.vpc_id = vpc_id - self.owner_id = owner_id + self.owner_id = owner_id or ec2_backend.account_id self.network_acl_entries = [] self.associations = {} self.default = "true" if default is True else "false" diff --git a/moto/ec2/models/route_tables.py b/moto/ec2/models/route_tables.py index 947383724..c2f7bc751 100644 --- a/moto/ec2/models/route_tables.py +++ b/moto/ec2/models/route_tables.py @@ -1,6 +1,6 @@ import ipaddress -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from .core import TaggedEC2Resource from ..exceptions import ( 
DependencyViolationError, @@ -32,7 +32,7 @@ class RouteTable(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @staticmethod def cloudformation_name_type(): @@ -45,14 +45,14 @@ class RouteTable(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] vpc_id = properties["VpcId"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] route_table = ec2_backend.create_route_table(vpc_id=vpc_id) return route_table @@ -255,7 +255,7 @@ class Route(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -270,7 +270,7 @@ class Route(CloudFormationModel): pcx_id = properties.get("VpcPeeringConnectionId") route_table_id = properties["RouteTableId"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] route_table = ec2_backend.create_route( route_table_id=route_table_id, destination_cidr_block=properties.get("DestinationCidrBlock"), diff --git a/moto/ec2/models/security_groups.py b/moto/ec2/models/security_groups.py index 0e4e22475..1109522b5 100644 --- a/moto/ec2/models/security_groups.py +++ b/moto/ec2/models/security_groups.py @@ -3,7 +3,7 @@ import itertools import json from collections import defaultdict -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from moto.core.utils import aws_api_matches from ..exceptions import ( DependencyViolationError, @@ -30,6 +30,7 @@ from ..utils import ( class 
SecurityRule(object): def __init__( self, + account_id, ip_protocol, from_port, to_port, @@ -37,6 +38,7 @@ class SecurityRule(object): source_groups, prefix_list_ids=None, ): + self.account_id = account_id self.id = random_security_group_rule_id() self.ip_protocol = str(ip_protocol) self.ip_ranges = ip_ranges or [] @@ -69,7 +71,7 @@ class SecurityRule(object): @property def owner_id(self): - return get_account_id() + return self.account_id def __eq__(self, other): if self.ip_protocol != other.ip_protocol: @@ -126,7 +128,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): self.egress_rules = [] self.enis = {} self.vpc_id = vpc_id - self.owner_id = get_account_id() + self.owner_id = ec2_backend.account_id self.add_tags(tags or {}) self.is_default = is_default or False @@ -135,11 +137,15 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): vpc = self.ec2_backend.vpcs.get(vpc_id) if vpc: self.egress_rules.append( - SecurityRule("-1", None, None, [{"CidrIp": "0.0.0.0/0"}], []) + SecurityRule( + self.owner_id, "-1", None, None, [{"CidrIp": "0.0.0.0/0"}], [] + ) ) if vpc and len(vpc.get_cidr_block_association_set(ipv6=True)) > 0: self.egress_rules.append( - SecurityRule("-1", None, None, [{"CidrIpv6": "::/0"}], []) + SecurityRule( + self.owner_id, "-1", None, None, [{"CidrIpv6": "::/0"}], [] + ) ) # each filter as a simple function in a mapping @@ -181,13 +187,13 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] vpc_id = properties.get("VpcId") security_group = ec2_backend.create_security_group( name=resource_name, @@ -223,35 +229,44 @@ class 
SecurityGroup(TaggedEC2Resource, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): cls._delete_security_group_given_vpc_id( - original_resource.name, original_resource.vpc_id, region_name + original_resource.name, original_resource.vpc_id, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): properties = cloudformation_json["Properties"] vpc_id = properties.get("VpcId") - cls._delete_security_group_given_vpc_id(resource_name, vpc_id, region_name) + cls._delete_security_group_given_vpc_id( + resource_name, vpc_id, account_id, region_name + ) @classmethod - def _delete_security_group_given_vpc_id(cls, resource_name, vpc_id, region_name): + def _delete_security_group_given_vpc_id( + cls, resource_name, vpc_id, account_id, region_name + ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] security_group = ec2_backend.get_security_group_by_name_or_id( resource_name, vpc_id ) if security_group: - security_group.delete(region_name) + security_group.delete(account_id, region_name) - def delete(self, region_name): # pylint: disable=unused-argument + def delete(self, account_id, region_name): # pylint: disable=unused-argument """Not exposed as part of the ELB API - used for CloudFormation.""" self.ec2_backend.delete_security_group(group_id=self.id) @@ -603,7 +618,13 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( 
- ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) if security_rule in group.ingress_rules: @@ -660,7 +681,13 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) # To match drift property of the security rules. @@ -748,7 +775,13 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) if security_rule in group.egress_rules: @@ -820,7 +853,13 @@ class SecurityGroupBackend: ip_ranges.remove(item) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) # To match drift property of the security rules. 
@@ -901,7 +940,13 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) for rule in group.ingress_rules: if ( @@ -951,7 +996,13 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + self.account_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + _source_groups, + prefix_list_ids, ) for rule in group.egress_rules: if ( @@ -1008,7 +1059,7 @@ class SecurityGroupBackend: _source_groups = [] for item in source_groups or []: if "OwnerId" not in item: - item["OwnerId"] = get_account_id() + item["OwnerId"] = self.account_id # for VPCs if "GroupId" in item: if not self.get_security_group_by_name_or_id( @@ -1061,13 +1112,13 @@ class SecurityGroupIngress(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] group_name = properties.get("GroupName") group_id = properties.get("GroupId") ip_protocol = properties.get("IpProtocol") diff --git a/moto/ec2/models/spot_requests.py b/moto/ec2/models/spot_requests.py index ba57684c4..1b2c9f8d9 100644 --- a/moto/ec2/models/spot_requests.py +++ b/moto/ec2/models/spot_requests.py @@ -306,12 +306,12 @@ class SpotFleetRequest(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, 
cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"]["SpotFleetRequestConfigData"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] spot_price = properties.get("SpotPrice") target_capacity = properties["TargetCapacity"] diff --git a/moto/ec2/models/subnets.py b/moto/ec2/models/subnets.py index db976a8cb..e8b3bbcc0 100644 --- a/moto/ec2/models/subnets.py +++ b/moto/ec2/models/subnets.py @@ -2,7 +2,6 @@ import ipaddress import itertools from collections import defaultdict -from moto.core import get_account_id from moto.core import CloudFormationModel from ..exceptions import ( GenericInvalidParameterValueError, @@ -64,7 +63,7 @@ class Subnet(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @staticmethod def cloudformation_name_type(): @@ -77,7 +76,7 @@ class Subnet(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -86,7 +85,7 @@ class Subnet(TaggedEC2Resource, CloudFormationModel): vpc_id = properties["VpcId"] cidr_block = properties["CidrBlock"] availability_zone = properties.get("AvailabilityZone") - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] subnet = ec2_backend.create_subnet( vpc_id=vpc_id, cidr_block=cidr_block, availability_zone=availability_zone ) @@ -415,7 +414,7 @@ class SubnetRouteTableAssociation(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -424,7 +423,7 @@ class 
SubnetRouteTableAssociation(CloudFormationModel): route_table_id = properties["RouteTableId"] subnet_id = properties["SubnetId"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] subnet_association = ec2_backend.create_subnet_association( route_table_id=route_table_id, subnet_id=subnet_id ) diff --git a/moto/ec2/models/transit_gateway.py b/moto/ec2/models/transit_gateway.py index a82e3591b..e4d301f5a 100644 --- a/moto/ec2/models/transit_gateway.py +++ b/moto/ec2/models/transit_gateway.py @@ -1,5 +1,5 @@ from datetime import datetime -from moto.core import get_account_id, CloudFormationModel +from moto.core import CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.utilities.utils import filter_resources, merge_multiple_dicts from .core import TaggedEC2Resource @@ -38,7 +38,7 @@ class TransitGateway(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @staticmethod def cloudformation_name_type(): @@ -51,11 +51,11 @@ class TransitGateway(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] properties = cloudformation_json["Properties"] description = properties["Description"] options = dict(properties) diff --git a/moto/ec2/models/transit_gateway_attachments.py b/moto/ec2/models/transit_gateway_attachments.py index db0f90445..4b58d1725 100644 --- a/moto/ec2/models/transit_gateway_attachments.py +++ b/moto/ec2/models/transit_gateway_attachments.py @@ -1,5 +1,4 @@ from datetime import datetime -from moto.core import get_account_id from moto.core.utils import iso_8601_datetime_with_milliseconds 
from moto.utilities.utils import merge_multiple_dicts, filter_resources from .core import TaggedEC2Resource @@ -24,20 +23,14 @@ class TransitGatewayAttachment(TaggedEC2Resource): self.add_tags(tags or {}) self._created_at = datetime.utcnow() - self.owner_id = self.resource_owner_id + self.resource_owner_id = backend.account_id + self.transit_gateway_owner_id = backend.account_id + self.owner_id = backend.account_id @property def create_time(self): return iso_8601_datetime_with_milliseconds(self._created_at) - @property - def resource_owner_id(self): - return get_account_id() - - @property - def transit_gateway_owner_id(self): - return get_account_id() - class TransitGatewayVpcAttachment(TransitGatewayAttachment): DEFAULT_OPTIONS = { @@ -93,10 +86,6 @@ class TransitGatewayPeeringAttachment(TransitGatewayAttachment): } self.status = PeeringConnectionStatus() - @property - def resource_owner_id(self): - return get_account_id() - class TransitGatewayAttachmentBackend: def __init__(self): diff --git a/moto/ec2/models/vpc_peering_connections.py b/moto/ec2/models/vpc_peering_connections.py index c37d5196f..fbbc6a9a1 100644 --- a/moto/ec2/models/vpc_peering_connections.py +++ b/moto/ec2/models/vpc_peering_connections.py @@ -65,13 +65,13 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] vpc = ec2_backend.get_vpc(properties["VpcId"]) peer_vpc = ec2_backend.get_vpc(properties["PeerVpcId"]) diff --git a/moto/ec2/models/vpc_service_configuration.py b/moto/ec2/models/vpc_service_configuration.py index 6f3f5e649..e679d7760 100644 --- a/moto/ec2/models/vpc_service_configuration.py +++ 
b/moto/ec2/models/vpc_service_configuration.py @@ -44,7 +44,7 @@ class VPCServiceConfigurationBackend: def elbv2_backend(self): from moto.elbv2.models import elbv2_backends - return elbv2_backends[self.region_name] + return elbv2_backends[self.account_id][self.region_name] def get_vpc_endpoint_service(self, resource_id): return self.configurations.get(resource_id) diff --git a/moto/ec2/models/vpcs.py b/moto/ec2/models/vpcs.py index 239eb9446..e6b7083e9 100644 --- a/moto/ec2/models/vpcs.py +++ b/moto/ec2/models/vpcs.py @@ -4,7 +4,6 @@ import weakref from collections import defaultdict from operator import itemgetter -from moto.core import get_account_id from moto.core import CloudFormationModel from .core import TaggedEC2Resource from ..exceptions import ( @@ -85,7 +84,7 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @property def physical_resource_id(self): @@ -101,7 +100,7 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends @@ -116,7 +115,7 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): route_table_ids = properties.get("RouteTableIds") security_group_ids = properties.get("SecurityGroupIds") - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] vpc_endpoint = ec2_backend.create_vpc_endpoint( vpc_id=vpc_id, service_name=service_name, @@ -167,7 +166,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel): @property def owner_id(self): - return get_account_id() + return self.ec2_backend.account_id @staticmethod def cloudformation_name_type(): @@ -180,13 +179,13 @@ class VPC(TaggedEC2Resource, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, 
resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] vpc = ec2_backend.create_vpc( cidr_block=properties["CidrBlock"], instance_tenancy=properties.get("InstanceTenancy", "default"), @@ -629,7 +628,7 @@ class VPCBackend: return generic_filter(filters, vpc_end_points) @staticmethod - def _collect_default_endpoint_services(region): + def _collect_default_endpoint_services(account_id, region): """Return list of default services using list of backends.""" if DEFAULT_VPC_ENDPOINT_SERVICES: return DEFAULT_VPC_ENDPOINT_SERVICES @@ -643,7 +642,8 @@ class VPCBackend: from moto import backends # pylint: disable=import-outside-toplevel - for _backends in backends.unique_backends(): + for _backends in backends.service_backends(): + _backends = _backends[account_id] if region in _backends: service = _backends[region].default_vpc_endpoint_service(region, zones) if service: @@ -757,7 +757,9 @@ class VPCBackend: The DryRun parameter is ignored. 
""" - default_services = self._collect_default_endpoint_services(region) + default_services = self._collect_default_endpoint_services( + self.account_id, region + ) for service_name in service_names: if service_name not in [x["ServiceName"] for x in default_services]: raise InvalidServiceName(service_name) diff --git a/moto/ec2/models/vpn_gateway.py b/moto/ec2/models/vpn_gateway.py index 00eef817b..322a5a434 100644 --- a/moto/ec2/models/vpn_gateway.py +++ b/moto/ec2/models/vpn_gateway.py @@ -22,13 +22,13 @@ class VPCGatewayAttachment(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] vpn_gateway_id = properties.get("VpnGatewayId", None) internet_gateway_id = properties.get("InternetGatewayId", None) if vpn_gateway_id: @@ -78,14 +78,14 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): from ..models import ec2_backends properties = cloudformation_json["Properties"] _type = properties["Type"] asn = properties.get("AmazonSideAsn", None) - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] return ec2_backend.create_vpn_gateway(gateway_type=_type, amazon_side_asn=asn) diff --git a/moto/ec2/responses/__init__.py b/moto/ec2/responses/__init__.py index 548d97188..bdee5ce03 100644 --- a/moto/ec2/responses/__init__.py +++ b/moto/ec2/responses/__init__.py @@ -88,11 +88,14 @@ class EC2Response( IamInstanceProfiles, CarrierGateway, ): + def __init__(self): + super().__init__(service_name="ec2") + 
@property def ec2_backend(self): from moto.ec2.models import ec2_backends - return ec2_backends[self.region] + return ec2_backends[self.current_account][self.region] @property def should_autoescape(self): diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 635276df7..4c5b9b0b5 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -4,7 +4,6 @@ from moto.ec2.exceptions import ( InvalidParameterCombination, InvalidRequest, ) -from moto.core import get_account_id from copy import deepcopy @@ -36,7 +35,11 @@ class InstanceResponse(EC2BaseResponse): next_token = reservations_resp[-1].id template = self.response_template(EC2_DESCRIBE_INSTANCES) return ( - template.render(reservations=reservations_resp, next_token=next_token) + template.render( + account_id=self.current_account, + reservations=reservations_resp, + next_token=next_token, + ) .replace("True", "true") .replace("False", "false") ) @@ -85,7 +88,9 @@ class InstanceResponse(EC2BaseResponse): ) template = self.response_template(EC2_RUN_INSTANCES) - return template.render(reservation=new_reservation) + return template.render( + account_id=self.current_account, reservation=new_reservation + ) def terminate_instances(self): instance_ids = self._get_multi_param("InstanceId") @@ -94,8 +99,12 @@ class InstanceResponse(EC2BaseResponse): from moto.autoscaling import autoscaling_backends from moto.elbv2 import elbv2_backends - autoscaling_backends[self.region].notify_terminate_instances(instance_ids) - elbv2_backends[self.region].notify_terminate_instances(instance_ids) + autoscaling_backends[self.current_account][ + self.region + ].notify_terminate_instances(instance_ids) + elbv2_backends[self.current_account][ + self.region + ].notify_terminate_instances(instance_ids) template = self.response_template(EC2_TERMINATE_INSTANCES) return template.render(instances=instances) @@ -380,13 +389,10 @@ BLOCK_DEVICE_MAPPING_TEMPLATE = { }, } -EC2_RUN_INSTANCES = ( - """ 
+EC2_RUN_INSTANCES = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ reservation.id }} - """ - + get_account_id() - + """ + {{ account_id }} sg-245f6a01 @@ -473,9 +479,7 @@ EC2_RUN_INSTANCES = ( {{ nic.subnet.vpc_id }} {% endif %} Primary network interface - """ - + get_account_id() - + """ + {{ account_id }} in-use 1b:2b:3c:4d:5e:6f {{ nic.private_ip_address }} @@ -498,9 +502,7 @@ EC2_RUN_INSTANCES = ( {% if nic.public_ip %} {{ nic.public_ip }} - """ - + get_account_id() - + """ + {{ account_id }} {% endif %} @@ -510,9 +512,7 @@ EC2_RUN_INSTANCES = ( {% if nic.public_ip %} {{ nic.public_ip }} - """ - + get_account_id() - + """ + {{ account_id }} {% endif %} @@ -524,18 +524,14 @@ EC2_RUN_INSTANCES = ( {% endfor %} """ -) -EC2_DESCRIBE_INSTANCES = ( - """ +EC2_DESCRIBE_INSTANCES = """ fdcdcab1-ae5c-489e-9c33-4637c5dda355 {% for reservation in reservations %} {{ reservation.id }} - """ - + get_account_id() - + """ + {{ account_id }} {% for group in reservation.dynamic_group_list %} @@ -633,9 +629,7 @@ EC2_DESCRIBE_INSTANCES = ( {% endfor %} {{ instance.virtualization_type }} - ABCDE""" - + get_account_id() - + """3 + ABCDE{{ account_id }}3 {% if instance.get_tags() %} {% for tag in instance.get_tags() %} @@ -658,9 +652,7 @@ EC2_DESCRIBE_INSTANCES = ( {{ nic.subnet.vpc_id }} {% endif %} Primary network interface - """ - + get_account_id() - + """ + {{ account_id }} in-use 1b:2b:3c:4d:5e:6f {{ nic.private_ip_address }} @@ -687,9 +679,7 @@ EC2_DESCRIBE_INSTANCES = ( {% if nic.public_ip %} {{ nic.public_ip }} - """ - + get_account_id() - + """ + {{ account_id }} {% endif %} @@ -699,9 +689,7 @@ EC2_DESCRIBE_INSTANCES = ( {% if nic.public_ip %} {{ nic.public_ip }} - """ - + get_account_id() - + """ + {{ account_id }} {% endif %} @@ -719,7 +707,6 @@ EC2_DESCRIBE_INSTANCES = ( {{ next_token }} {% endif %} """ -) EC2_TERMINATE_INSTANCES = """ diff --git a/moto/ec2/responses/launch_templates.py b/moto/ec2/responses/launch_templates.py index 0029c0d62..a8d18766e 100644 --- 
a/moto/ec2/responses/launch_templates.py +++ b/moto/ec2/responses/launch_templates.py @@ -1,5 +1,4 @@ import uuid -from moto.ec2.models import OWNER_ID from moto.ec2.exceptions import FilterNotImplementedError from ._base_response import EC2BaseResponse @@ -128,9 +127,7 @@ class LaunchTemplates(EC2BaseResponse): "launchTemplate", { "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( - OWNER_ID=OWNER_ID - ), + "createdBy": f"arn:aws:iam::{self.current_account}:root", "defaultVersionNumber": template.default_version_number, "latestVersionNumber": version.number, "launchTemplateId": template.id, @@ -162,9 +159,7 @@ class LaunchTemplates(EC2BaseResponse): "launchTemplateVersion", { "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( - OWNER_ID=OWNER_ID - ), + "createdBy": f"arn:aws:iam::{self.current_account}:root", "defaultVersion": template.is_default(version), "launchTemplateData": version.data, "launchTemplateId": template.id, @@ -253,9 +248,7 @@ class LaunchTemplates(EC2BaseResponse): "item", { "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( - OWNER_ID=OWNER_ID - ), + "createdBy": f"arn:aws:iam::{self.current_account}:root", "defaultVersion": True, "launchTemplateData": version.data, "launchTemplateId": template.id, @@ -291,9 +284,7 @@ class LaunchTemplates(EC2BaseResponse): "item", { "createTime": template.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( - OWNER_ID=OWNER_ID - ), + "createdBy": f"arn:aws:iam::{self.current_account}:root", "defaultVersionNumber": template.default_version_number, "latestVersionNumber": template.latest_version_number, "launchTemplateId": template.id, diff --git a/moto/ec2/responses/vpc_peering_connections.py b/moto/ec2/responses/vpc_peering_connections.py index 773a2b8ca..965cc6548 100644 --- a/moto/ec2/responses/vpc_peering_connections.py +++ b/moto/ec2/responses/vpc_peering_connections.py @@ -1,5 +1,4 
@@ from moto.core.responses import BaseResponse -from moto.core import get_account_id class VPCPeeringConnections(BaseResponse): @@ -15,11 +14,13 @@ class VPCPeeringConnections(BaseResponse): else: from moto.ec2.models import ec2_backends - peer_vpc = ec2_backends[peer_region].get_vpc(self._get_param("PeerVpcId")) + peer_vpc = ec2_backends[self.current_account][peer_region].get_vpc( + self._get_param("PeerVpcId") + ) vpc = self.ec2_backend.get_vpc(self._get_param("VpcId")) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc, tags) template = self.response_template(CREATE_VPC_PEERING_CONNECTION_RESPONSE) - return template.render(vpc_pcx=vpc_pcx) + return template.render(account_id=self.current_account, vpc_pcx=vpc_pcx) def delete_vpc_peering_connection(self): vpc_pcx_id = self._get_param("VpcPeeringConnectionId") @@ -33,13 +34,13 @@ class VPCPeeringConnections(BaseResponse): vpc_peering_ids=ids ) template = self.response_template(DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) - return template.render(vpc_pcxs=vpc_pcxs) + return template.render(account_id=self.current_account, vpc_pcxs=vpc_pcxs) def accept_vpc_peering_connection(self): vpc_pcx_id = self._get_param("VpcPeeringConnectionId") vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id) template = self.response_template(ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) - return template.render(vpc_pcx=vpc_pcx) + return template.render(account_id=self.current_account, vpc_pcx=vpc_pcx) def reject_vpc_peering_connection(self): vpc_pcx_id = self._get_param("VpcPeeringConnectionId") @@ -66,16 +67,13 @@ class VPCPeeringConnections(BaseResponse): # we are assuming that the owner id for accepter and requester vpc are same # as we are checking for the vpc exsistance -CREATE_VPC_PEERING_CONNECTION_RESPONSE = ( - """ +CREATE_VPC_PEERING_CONNECTION_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE {{ vpc_pcx.id }} - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.vpc.id }} {{ 
vpc_pcx.vpc.cidr_block }} @@ -85,9 +83,7 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = ( - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.peer_vpc.id }} {{ vpc_pcx.accepter_options.AllowEgressFromLocalClassicLinkToRemoteVpc or '' }} @@ -111,10 +107,8 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = ( """ -) -DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = ( - """ +DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -122,9 +116,7 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = ( {{ vpc_pcx.id }} - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.vpc.id }} {{ vpc_pcx.vpc.cidr_block }} {{ vpc_pcx.vpc.ec2_backend.region_name }} @@ -135,9 +127,7 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = ( - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.peer_vpc.id }} {{ vpc_pcx.peer_vpc.cidr_block }} {{ vpc_pcx.peer_vpc.ec2_backend.region_name }} @@ -164,7 +154,6 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = ( """ -) DELETE_VPC_PEERING_CONNECTION_RESPONSE = """ @@ -173,24 +162,19 @@ DELETE_VPC_PEERING_CONNECTION_RESPONSE = """ """ -ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = ( - """ +ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE {{ vpc_pcx.id }} - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.vpc.id }} {{ vpc_pcx.vpc.cidr_block }} {{ vpc_pcx.vpc.ec2_backend.region_name }} - """ - + get_account_id() - + """ + {{ account_id }} {{ vpc_pcx.peer_vpc.id }} {{ vpc_pcx.peer_vpc.cidr_block }} @@ -215,7 +199,6 @@ ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = ( """ -) REJECT_VPC_PEERING_CONNECTION_RESPONSE = """ diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 809c1c3b6..0d38527e0 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.core.utils import camelcase_to_underscores from moto.ec2.utils import add_tag_specification from ._base_response import EC2BaseResponse @@ 
-237,7 +236,7 @@ class VPCs(EC2BaseResponse): ) template = self.response_template(DESCRIBE_VPC_ENDPOINT_RESPONSE) return template.render( - vpc_end_points=vpc_end_points, account_id=get_account_id() + vpc_end_points=vpc_end_points, account_id=self.current_account ) def delete_vpc_endpoints(self): diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index 367c183dd..4978ee47e 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -9,7 +9,6 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa -from moto.core import get_account_id from moto.iam import iam_backends from moto.utilities.utils import md5_hash @@ -335,9 +334,7 @@ def get_object_value(obj, attr): keys = attr.split(".") val = obj for key in keys: - if key == "owner_id": - return get_account_id() - elif hasattr(val, key): + if hasattr(val, key): val = getattr(val, key) elif isinstance(val, dict): val = val[key] @@ -346,6 +343,8 @@ def get_object_value(obj, attr): item_val = get_object_value(item, key) if item_val: return item_val + elif key == "owner_id" and hasattr(val, "account_id"): + val = getattr(val, "account_id") else: return None return val @@ -687,19 +686,21 @@ def filter_iam_instance_profile_associations(iam_instance_associations, filter_d return result -def filter_iam_instance_profiles(iam_instance_profile_arn, iam_instance_profile_name): +def filter_iam_instance_profiles( + account_id, iam_instance_profile_arn, iam_instance_profile_name +): instance_profile = None instance_profile_by_name = None instance_profile_by_arn = None if iam_instance_profile_name: - instance_profile_by_name = iam_backends["global"].get_instance_profile( - iam_instance_profile_name - ) + instance_profile_by_name = iam_backends[account_id][ + "global" + ].get_instance_profile(iam_instance_profile_name) instance_profile = instance_profile_by_name if iam_instance_profile_arn: - instance_profile_by_arn = 
iam_backends["global"].get_instance_profile_by_arn( - iam_instance_profile_arn - ) + instance_profile_by_arn = iam_backends[account_id][ + "global" + ].get_instance_profile_by_arn(iam_instance_profile_arn) instance_profile = instance_profile_by_arn # We would prefer instance profile that we found by arn if iam_instance_profile_arn and iam_instance_profile_name: diff --git a/moto/ec2instanceconnect/responses.py b/moto/ec2instanceconnect/responses.py index 9fce11aa2..9e76d5510 100644 --- a/moto/ec2instanceconnect/responses.py +++ b/moto/ec2instanceconnect/responses.py @@ -3,9 +3,12 @@ from .models import ec2instanceconnect_backends class Ec2InstanceConnectResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="ec2-instanceconnect") + @property def ec2instanceconnect_backend(self): - return ec2instanceconnect_backends[self.region] + return ec2instanceconnect_backends[self.current_account][self.region] def send_ssh_public_key(self): return self.ec2instanceconnect_backend.send_ssh_public_key() diff --git a/moto/ecr/models.py b/moto/ecr/models.py index c74c51819..ec7678074 100644 --- a/moto/ecr/models.py +++ b/moto/ecr/models.py @@ -9,7 +9,7 @@ from typing import Dict, List from botocore.exceptions import ParamValidationError -from moto.core import BaseBackend, BaseModel, CloudFormationModel, get_account_id +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import iso_8601_datetime_without_milliseconds, BackendDict from moto.ecr.exceptions import ( ImageNotFoundException, @@ -29,7 +29,6 @@ from moto.iam.exceptions import MalformedPolicyDocument from moto.iam.policy_validation import IAMPolicyDocumentValidator from moto.utilities.tagging_service import TaggingService -DEFAULT_REGISTRY_ID = get_account_id() ECR_REPOSITORY_ARN_PATTERN = "^arn:(?P[^:]+):ecr:(?P[^:]+):(?P[^:]+):repository/(?P.*)$" EcrRepositoryArn = namedtuple( @@ -64,6 +63,7 @@ class BaseObject(BaseModel): class Repository(BaseObject, 
CloudFormationModel): def __init__( self, + account_id, region_name, repository_name, registry_id, @@ -71,8 +71,9 @@ class Repository(BaseObject, CloudFormationModel): image_scan_config, image_tag_mutablility, ): + self.account_id = account_id self.region_name = region_name - self.registry_id = registry_id or DEFAULT_REGISTRY_ID + self.registry_id = registry_id or account_id self.arn = ( f"arn:aws:ecr:{region_name}:{self.registry_id}:repository/{repository_name}" ) @@ -96,7 +97,7 @@ class Repository(BaseObject, CloudFormationModel): if encryption_config == {"encryptionType": "KMS"}: encryption_config[ "kmsKey" - ] = f"arn:aws:kms:{self.region_name}:{get_account_id()}:key/{uuid.uuid4()}" + ] = f"arn:aws:kms:{self.region_name}:{self.account_id}:key/{uuid.uuid4()}" return encryption_config def _get_image(self, image_tag, image_digest): @@ -148,8 +149,8 @@ class Repository(BaseObject, CloudFormationModel): if image_tag_mutability: self.image_tag_mutability = image_tag_mutability - def delete(self, region_name): - ecr_backend = ecr_backends[region_name] + def delete(self, account_id, region_name): + ecr_backend = ecr_backends[account_id][region_name] ecr_backend.delete_repository(self.name) @classmethod @@ -177,9 +178,9 @@ class Repository(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - ecr_backend = ecr_backends[region_name] + ecr_backend = ecr_backends[account_id][region_name] properties = cloudformation_json["Properties"] encryption_config = properties.get("EncryptionConfiguration") @@ -200,9 +201,14 @@ class Repository(BaseObject, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): - 
ecr_backend = ecr_backends[region_name] + ecr_backend = ecr_backends[account_id][region_name] properties = cloudformation_json["Properties"] encryption_configuration = properties.get( "EncryptionConfiguration", {"encryptionType": "AES256"} @@ -223,22 +229,22 @@ class Repository(BaseObject, CloudFormationModel): return original_resource else: - original_resource.delete(region_name) + original_resource.delete(account_id, region_name) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) class Image(BaseObject): def __init__( - self, tag, manifest, repository, digest=None, registry_id=DEFAULT_REGISTRY_ID + self, account_id, tag, manifest, repository, digest=None, registry_id=None ): self.image_tag = tag self.image_tags = [tag] if tag is not None else [] self.image_manifest = manifest self.image_size_in_bytes = 50 * 1024 * 1024 self.repository = repository - self.registry_id = registry_id + self.registry_id = registry_id or account_id self.image_digest = digest self.image_pushed_at = str(datetime.now(timezone.utc).isoformat()) self.last_scan = None @@ -359,7 +365,7 @@ class ECRBackend(BaseBackend): def _get_repository(self, name, registry_id=None) -> Repository: repo = self.repositories.get(name) - reg_id = registry_id or DEFAULT_REGISTRY_ID + reg_id = registry_id or self.account_id if not repo or repo.registry_id != reg_id: raise RepositoryNotFoundException(name, reg_id) @@ -383,7 +389,7 @@ class ECRBackend(BaseBackend): for repository_name in repository_names: if repository_name not in self.repositories: raise RepositoryNotFoundException( - repository_name, registry_id or DEFAULT_REGISTRY_ID + repository_name, registry_id or self.account_id ) repositories = [] @@ -410,9 +416,10 @@ class ECRBackend(BaseBackend): tags, ): if self.repositories.get(repository_name): - raise RepositoryAlreadyExistsException(repository_name, DEFAULT_REGISTRY_ID) + raise 
RepositoryAlreadyExistsException(repository_name, self.account_id) repository = Repository( + account_id=self.account_id, region_name=self.region_name, repository_name=repository_name, registry_id=registry_id, @@ -430,7 +437,7 @@ class ECRBackend(BaseBackend): if repo.images and not force: raise RepositoryNotEmptyException( - repository_name, registry_id or DEFAULT_REGISTRY_ID + repository_name, registry_id or self.account_id ) self.tagger.delete_all_tags_for_resource(repo.arn) @@ -452,7 +459,7 @@ class ECRBackend(BaseBackend): if not found: raise RepositoryNotFoundException( - repository_name, registry_id or DEFAULT_REGISTRY_ID + repository_name, registry_id or self.account_id ) images = [] @@ -492,7 +499,7 @@ class ECRBackend(BaseBackend): ) if not existing_images: # this image is not in ECR yet - image = Image(image_tag, image_manifest, repository_name) + image = Image(self.account_id, image_tag, image_manifest, repository_name) repository.images.append(image) return image else: @@ -508,7 +515,7 @@ class ECRBackend(BaseBackend): repository = self.repositories[repository_name] else: raise RepositoryNotFoundException( - repository_name, registry_id or DEFAULT_REGISTRY_ID + repository_name, registry_id or self.account_id ) if not image_ids: @@ -546,7 +553,7 @@ class ECRBackend(BaseBackend): repository = self.repositories[repository_name] else: raise RepositoryNotFoundException( - repository_name, registry_id or DEFAULT_REGISTRY_ID + repository_name, registry_id or self.account_id ) if not image_ids: @@ -822,28 +829,28 @@ class ECRBackend(BaseBackend): self.registry_policy = policy_text return { - "registryId": get_account_id(), + "registryId": self.account_id, "policyText": policy_text, } def get_registry_policy(self): if not self.registry_policy: - raise RegistryPolicyNotFoundException(get_account_id()) + raise RegistryPolicyNotFoundException(self.account_id) return { - "registryId": get_account_id(), + "registryId": self.account_id, "policyText": 
self.registry_policy, } def delete_registry_policy(self): policy = self.registry_policy if not policy: - raise RegistryPolicyNotFoundException(get_account_id()) + raise RegistryPolicyNotFoundException(self.account_id) self.registry_policy = None return { - "registryId": get_account_id(), + "registryId": self.account_id, "policyText": policy, } @@ -931,7 +938,7 @@ class ECRBackend(BaseBackend): for dest in rules[0]["destinations"]: if ( dest["region"] == self.region_name - and dest["registryId"] == DEFAULT_REGISTRY_ID + and dest["registryId"] == self.account_id ): raise InvalidParameterException( "Invalid parameter at 'replicationConfiguration' failed to satisfy constraint: " @@ -944,7 +951,7 @@ class ECRBackend(BaseBackend): def describe_registry(self): return { - "registryId": DEFAULT_REGISTRY_ID, + "registryId": self.account_id, "replicationConfiguration": self.replication_config, } diff --git a/moto/ecr/responses.py b/moto/ecr/responses.py index 8cfec7424..948d88130 100644 --- a/moto/ecr/responses.py +++ b/moto/ecr/responses.py @@ -4,13 +4,16 @@ from datetime import datetime import time from moto.core.responses import BaseResponse -from .models import ecr_backends, DEFAULT_REGISTRY_ID +from .models import ecr_backends class ECRResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="ecr") + @property def ecr_backend(self): - return ecr_backends[self.region] + return ecr_backends[self.current_account][self.region] @property def request_params(self): @@ -131,7 +134,7 @@ class ECRResponse(BaseResponse): def get_authorization_token(self): registry_ids = self._get_param("registryIds") if not registry_ids: - registry_ids = [DEFAULT_REGISTRY_ID] + registry_ids = [self.current_account] auth_data = [] for registry_id in registry_ids: password = "{}-auth-token".format(registry_id) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 6c906966f..1b53cd14b 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -7,7 +7,7 @@ from random import 
random, randint import pytz from moto import settings -from moto.core import BaseBackend, BaseModel, CloudFormationModel, get_account_id +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import JsonRESTError from moto.core.utils import ( unix_time, @@ -61,11 +61,9 @@ class AccountSetting(BaseObject): class Cluster(BaseObject, CloudFormationModel): - def __init__(self, cluster_name, region_name, cluster_settings=None): + def __init__(self, cluster_name, account_id, region_name, cluster_settings=None): self.active_services_count = 0 - self.arn = "arn:aws:ecs:{0}:{1}:cluster/{2}".format( - region_name, get_account_id(), cluster_name - ) + self.arn = f"arn:aws:ecs:{region_name}:{account_id}:cluster/{cluster_name}" self.name = cluster_name self.pending_tasks_count = 0 self.registered_container_instances_count = 0 @@ -97,9 +95,9 @@ class Cluster(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] return ecs_backend.create_cluster( # ClusterName is optional in CloudFormation, thus create a random # name if necessary @@ -108,10 +106,15 @@ class Cluster(BaseObject, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): if original_resource.name != new_resource_name: - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] ecs_backend.delete_cluster(original_resource.arn) return ecs_backend.create_cluster( # ClusterName is optional in CloudFormation, thus create a @@ -140,6 +143,7 @@ class TaskDefinition(BaseObject, 
CloudFormationModel): family, revision, container_definitions, + account_id, region_name, network_mode=None, volumes=None, @@ -153,9 +157,7 @@ class TaskDefinition(BaseObject, CloudFormationModel): ): self.family = family self.revision = revision - self.arn = "arn:aws:ecs:{0}:{1}:task-definition/{2}:{3}".format( - region_name, get_account_id(), family, revision - ) + self.arn = f"arn:aws:ecs:{region_name}:{account_id}:task-definition/{family}:{revision}" default_container_definition = { "cpu": 0, @@ -236,7 +238,7 @@ class TaskDefinition(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] @@ -248,14 +250,19 @@ class TaskDefinition(BaseObject, CloudFormationModel): ) volumes = remap_nested_keys(properties.get("Volumes", []), pascal_to_camelcase) - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] return ecs_backend.register_task_definition( family=family, container_definitions=container_definitions, volumes=volumes ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] family = properties.get( @@ -270,7 +277,7 @@ class TaskDefinition(BaseObject, CloudFormationModel): ): # currently TaskRoleArn isn't stored at TaskDefinition # instances - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] ecs_backend.deregister_task_definition(original_resource.arn) return ecs_backend.register_task_definition( family=family, @@ -310,15 +317,14 @@ class Task(BaseObject): self.stopped_reason = "" self.resource_requirements = resource_requirements 
self.region_name = cluster.region_name + self._account_id = backend.account_id self._backend = backend @property def task_arn(self): if self._backend.enable_long_arn_for_name(name="taskLongArnFormat"): - return f"arn:aws:ecs:{self.region_name}:{get_account_id()}:task/{self.cluster_name}/{self.id}" - return "arn:aws:ecs:{0}:{1}:task/{2}".format( - self.region_name, get_account_id(), self.id - ) + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.cluster_name}/{self.id}" + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.id}" @property def response_object(self): @@ -328,9 +334,9 @@ class Task(BaseObject): class CapacityProvider(BaseObject): - def __init__(self, region_name, name, asg_details, tags): + def __init__(self, account_id, region_name, name, asg_details, tags): self._id = str(uuid.uuid4()) - self.capacity_provider_arn = f"arn:aws:ecs:{region_name}:{get_account_id()}:capacity_provider/{name}/{self._id}" + self.capacity_provider_arn = f"arn:aws:ecs:{region_name}:{account_id}:capacity_provider/{name}/{self._id}" self.name = name self.status = "ACTIVE" self.auto_scaling_group_provider = asg_details @@ -338,11 +344,9 @@ class CapacityProvider(BaseObject): class CapacityProviderFailure(BaseObject): - def __init__(self, reason, name, region_name): + def __init__(self, reason, name, account_id, region_name): self.reason = reason - self.arn = "arn:aws:ecs:{0}:{1}:capacity_provider/{2}".format( - region_name, get_account_id(), name - ) + self.arn = f"arn:aws:ecs:{region_name}:{account_id}:capacity_provider/{name}" @property def response_object(self): @@ -405,15 +409,14 @@ class Service(BaseObject, CloudFormationModel): self.tags = tags if tags is not None else [] self.pending_count = 0 self.region_name = cluster.region_name + self._account_id = backend.account_id self._backend = backend @property def arn(self): if self._backend.enable_long_arn_for_name(name="serviceLongArnFormat"): - return 
f"arn:aws:ecs:{self.region_name}:{get_account_id()}:service/{self.cluster_name}/{self.name}" - return "arn:aws:ecs:{0}:{1}:service/{2}".format( - self.region_name, get_account_id(), self.name - ) + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:service/{self.cluster_name}/{self.name}" + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:service/{self.name}" @property def physical_resource_id(self): @@ -457,7 +460,7 @@ class Service(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] if isinstance(properties["Cluster"], Cluster): @@ -472,14 +475,19 @@ class Service(BaseObject, CloudFormationModel): # TODO: LoadBalancers # TODO: Role - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] return ecs_backend.create_service( cluster, resource_name, desired_count, task_definition_str=task_definition ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] if isinstance(properties["Cluster"], Cluster): @@ -492,9 +500,12 @@ class Service(BaseObject, CloudFormationModel): task_definition = properties["TaskDefinition"] desired_count = properties.get("DesiredCount", None) - ecs_backend = ecs_backends[region_name] + ecs_backend = ecs_backends[account_id][region_name] service_name = original_resource.name - if original_resource.cluster_arn != Cluster(cluster_name, region_name).arn: + if ( + original_resource.cluster_arn + != Cluster(cluster_name, account_id, region_name).arn + ): # TODO: LoadBalancers # TODO: Role ecs_backend.delete_service(cluster_name, service_name) @@ -522,7 
+533,7 @@ class Service(BaseObject, CloudFormationModel): class ContainerInstance(BaseObject): - def __init__(self, ec2_instance_id, region_name, cluster_name, backend): + def __init__(self, ec2_instance_id, account_id, region_name, cluster_name, backend): self.ec2_instance_id = ec2_instance_id self.agent_connected = True self.status = "ACTIVE" @@ -597,7 +608,7 @@ class ContainerInstance(BaseObject): "agentHash": "4023248", "dockerVersion": "DockerVersion: 1.5.0", } - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] ec2_instance = ec2_backend.get_instance(ec2_instance_id) self.attributes = { "ecs.ami-id": ec2_instance.image_id, @@ -611,6 +622,7 @@ class ContainerInstance(BaseObject): self.region_name = region_name self.id = str(uuid.uuid4()) self.cluster_name = cluster_name + self._account_id = backend.account_id self._backend = backend @property @@ -618,8 +630,8 @@ class ContainerInstance(BaseObject): if self._backend.enable_long_arn_for_name( name="containerInstanceLongArnFormat" ): - return f"arn:aws:ecs:{self.region_name}:{get_account_id()}:container-instance/{self.cluster_name}/{self.id}" - return f"arn:aws:ecs:{self.region_name}:{get_account_id()}:container-instance/{self.id}" + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:container-instance/{self.cluster_name}/{self.id}" + return f"arn:aws:ecs:{self.region_name}:{self._account_id}:container-instance/{self.id}" @property def response_object(self): @@ -643,11 +655,9 @@ class ContainerInstance(BaseObject): class ClusterFailure(BaseObject): - def __init__(self, reason, cluster_name, region_name): + def __init__(self, reason, cluster_name, account_id, region_name): self.reason = reason - self.arn = "arn:aws:ecs:{0}:{1}:cluster/{2}".format( - region_name, get_account_id(), cluster_name - ) + self.arn = f"arn:aws:ecs:{region_name}:{account_id}:cluster/{cluster_name}" @property def response_object(self): @@ -658,11 +668,9 @@ class ClusterFailure(BaseObject): 
class ContainerInstanceFailure(BaseObject): - def __init__(self, reason, container_instance_id, region_name): + def __init__(self, reason, container_instance_id, account_id, region_name): self.reason = reason - self.arn = "arn:aws:ecs:{0}:{1}:container-instance/{2}".format( - region_name, get_account_id(), container_instance_id - ) + self.arn = f"arn:aws:ecs:{region_name}:{account_id}:container-instance/{container_instance_id}" @property def response_object(self): @@ -678,6 +686,7 @@ class TaskSet(BaseObject): service, cluster, task_definition, + account_id, region_name, external_id=None, network_configuration=None, @@ -715,9 +724,7 @@ class TaskSet(BaseObject): cluster_name = self.cluster.split("/")[-1] service_name = self.service.split("/")[-1] - self.task_set_arn = "arn:aws:ecs:{0}:{1}:task-set/{2}/{3}/{4}".format( - region_name, get_account_id(), cluster_name, service_name, self.id - ) + self.task_set_arn = f"arn:aws:ecs:{region_name}:{account_id}:task-set/{cluster_name}/{service_name}/{self.id}" @property def response_object(self): @@ -780,7 +787,9 @@ class EC2ContainerServiceBackend(BaseBackend): return cluster def create_capacity_provider(self, name, asg_details, tags): - capacity_provider = CapacityProvider(self.region_name, name, asg_details, tags) + capacity_provider = CapacityProvider( + self.account_id, self.region_name, name, asg_details, tags + ) self.capacity_providers[name] = capacity_provider if tags: self.tagger.tag_resource(capacity_provider.capacity_provider_arn, tags) @@ -807,7 +816,9 @@ class EC2ContainerServiceBackend(BaseBackend): """ The following parameters are not yet implemented: configuration, capacityProviders, defaultCapacityProviderStrategy """ - cluster = Cluster(cluster_name, self.region_name, cluster_settings) + cluster = Cluster( + cluster_name, self.account_id, self.region_name, cluster_settings + ) self.clusters[cluster_name] = cluster if tags: self.tagger.tag_resource(cluster.arn, tags) @@ -830,7 +841,9 @@ class 
EC2ContainerServiceBackend(BaseBackend): providers.append(provider) else: failures.append( - CapacityProviderFailure("MISSING", name, self.region_name) + CapacityProviderFailure( + "MISSING", name, self.account_id, self.region_name + ) ) return providers, failures @@ -861,7 +874,9 @@ class EC2ContainerServiceBackend(BaseBackend): list_clusters.append(self.clusters[cluster_name].response_object) else: failures.append( - ClusterFailure("MISSING", cluster_name, self.region_name) + ClusterFailure( + "MISSING", cluster_name, self.account_id, self.region_name + ) ) if "TAGS" in (include or []): @@ -902,6 +917,7 @@ class EC2ContainerServiceBackend(BaseBackend): family, revision, container_definitions, + self.account_id, self.region_name, volumes=volumes, network_mode=network_mode, @@ -1304,7 +1320,7 @@ class EC2ContainerServiceBackend(BaseBackend): result.append(self.services[cluster_service_pair]) else: missing_arn = ( - f"arn:aws:ecs:{self.region_name}:{get_account_id()}:service/{name}" + f"arn:aws:ecs:{self.region_name}:{self.account_id}:service/{name}" ) failures.append({"arn": missing_arn, "reason": "MISSING"}) @@ -1348,7 +1364,11 @@ class EC2ContainerServiceBackend(BaseBackend): if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) container_instance = ContainerInstance( - ec2_instance_id, self.region_name, cluster_name, backend=self + ec2_instance_id, + self.account_id, + self.region_name, + cluster_name, + backend=self, ) if not self.container_instances.get(cluster_name): self.container_instances[cluster_name] = {} @@ -1386,7 +1406,10 @@ class EC2ContainerServiceBackend(BaseBackend): else: failures.append( ContainerInstanceFailure( - "MISSING", container_instance_id, self.region_name + "MISSING", + container_instance_id, + self.account_id, + self.region_name, ) ) @@ -1417,7 +1440,10 @@ class EC2ContainerServiceBackend(BaseBackend): else: failures.append( ContainerInstanceFailure( - "MISSING", container_instance_id, 
self.region_name + "MISSING", + container_instance_id, + self.account_id, + self.region_name, ) ) @@ -1722,6 +1748,7 @@ class EC2ContainerServiceBackend(BaseBackend): service, cluster_str, task_definition, + self.account_id, self.region_name, external_id=external_id, network_configuration=network_configuration, diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index 66d5c7a28..b835e7e21 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -5,6 +5,9 @@ from .models import ecs_backends class EC2ContainerServiceResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="ecs") + @property def ecs_backend(self): """ @@ -13,7 +16,7 @@ class EC2ContainerServiceResponse(BaseResponse): :return: ECS Backend object :rtype: moto.ecs.models.EC2ContainerServiceBackend """ - return ecs_backends[self.region] + return ecs_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/efs/models.py b/moto/efs/models.py index 77b9489a0..3fc04ff51 100644 --- a/moto/efs/models.py +++ b/moto/efs/models.py @@ -8,7 +8,7 @@ import json import time from copy import deepcopy -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import ( camelcase_to_underscores, get_random_hex, @@ -34,9 +34,9 @@ from moto.utilities.tagging_service import TaggingService from moto.utilities.utils import md5_hash -def _lookup_az_id(az_name): +def _lookup_az_id(account_id, az_name): """Find the Availability zone ID given the AZ name.""" - ec2 = ec2_backends[az_name[:-1]] + ec2 = ec2_backends[account_id][az_name[:-1]] for zone in ec2.describe_availability_zones(): if zone.name == az_name: return zone.zone_id @@ -45,6 +45,7 @@ def _lookup_az_id(az_name): class AccessPoint(BaseModel): def __init__( self, + account_id, region_name, client_token, file_system_id, @@ -54,15 +55,12 @@ class AccessPoint(BaseModel): context, ): 
self.access_point_id = get_random_hex(8) - self.access_point_arn = "arn:aws:elasticfilesystem:{region}:{user_id}:access-point/fsap-{file_system_id}".format( - region=region_name, - user_id=get_account_id(), - file_system_id=self.access_point_id, - ) + self.access_point_arn = f"arn:aws:elasticfilesystem:{region_name}:{account_id}:access-point/fsap-{self.access_point_id}" self.client_token = client_token self.file_system_id = file_system_id self.name = name self.posix_user = posix_user + self.account_id = account_id if not root_directory: root_directory = {"Path": "/"} @@ -81,7 +79,7 @@ class AccessPoint(BaseModel): "FileSystemId": self.file_system_id, "PosixUser": self.posix_user, "RootDirectory": self.root_directory, - "OwnerId": get_account_id(), + "OwnerId": self.account_id, "LifeCycleState": "available", } @@ -91,6 +89,7 @@ class FileSystem(CloudFormationModel): def __init__( self, + account_id, region_name, creation_token, file_system_id, @@ -120,7 +119,9 @@ class FileSystem(CloudFormationModel): self.availability_zone_name = availability_zone_name self.availability_zone_id = None if self.availability_zone_name: - self.availability_zone_id = _lookup_az_id(self.availability_zone_name) + self.availability_zone_id = _lookup_az_id( + account_id, self.availability_zone_name + ) self._backup = backup self.lifecycle_policies = lifecycle_policies or [] self.file_system_policy = file_system_policy @@ -129,13 +130,9 @@ class FileSystem(CloudFormationModel): # Generate AWS-assigned parameters self.file_system_id = file_system_id - self.file_system_arn = "arn:aws:elasticfilesystem:{region}:{user_id}:file-system/{file_system_id}".format( - region=region_name, - user_id=get_account_id(), - file_system_id=self.file_system_id, - ) + self.file_system_arn = f"arn:aws:elasticfilesystem:{region_name}:{account_id}:file-system/{self.file_system_id}" self.creation_time = time.time() - self.owner_id = get_account_id() + self.owner_id = account_id # Initialize some state parameters 
self.life_cycle_state = "available" @@ -220,7 +217,7 @@ class FileSystem(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-efs-filesystem.html props = deepcopy(cloudformation_json["Properties"]) @@ -240,11 +237,18 @@ class FileSystem(CloudFormationModel): "supported by AWS Cloudformation." ) - return efs_backends[region_name].create_file_system(resource_name, **props) + return efs_backends[account_id][region_name].create_file_system( + resource_name, **props + ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): raise NotImplementedError( "Update of EFS File System via cloudformation is not yet implemented." @@ -252,15 +256,15 @@ class FileSystem(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - return efs_backends[region_name].delete_file_system(resource_name) + return efs_backends[account_id][region_name].delete_file_system(resource_name) class MountTarget(CloudFormationModel): """A model for an EFS Mount Target.""" - def __init__(self, file_system, subnet, ip_address, security_groups): + def __init__(self, account_id, file_system, subnet, ip_address, security_groups): # Set the simple given parameters. self.file_system_id = file_system.file_system_id self._file_system = file_system @@ -292,7 +296,7 @@ class MountTarget(CloudFormationModel): self.ip_address = ip_address # Init non-user-assigned values. 
- self.owner_id = get_account_id() + self.owner_id = account_id self.mount_target_id = "fsmt-{}".format(get_random_hex()) self.life_cycle_state = "available" self.network_interface_id = None @@ -332,16 +336,21 @@ class MountTarget(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-efs-mounttarget.html props = deepcopy(cloudformation_json["Properties"]) props = {camelcase_to_underscores(k): v for k, v in props.items()} - return efs_backends[region_name].create_mount_target(**props) + return efs_backends[account_id][region_name].create_mount_target(**props) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): raise NotImplementedError( "Updates of EFS Mount Target via cloudformation are not yet implemented." 
@@ -349,9 +358,9 @@ class MountTarget(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - return efs_backends[region_name].delete_mount_target(resource_name) + return efs_backends[account_id][region_name].delete_mount_target(resource_name) class EFSBackend(BaseBackend): @@ -384,7 +393,7 @@ class EFSBackend(BaseBackend): @property def ec2_backend(self): - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] def create_file_system( self, @@ -415,6 +424,7 @@ class EFSBackend(BaseBackend): while fsid in self.file_systems_by_id: fsid = make_id() self.file_systems_by_id[fsid] = FileSystem( + self.account_id, self.region_name, creation_token, fsid, @@ -495,7 +505,9 @@ class EFSBackend(BaseBackend): raise SecurityGroupNotFound(sg_id) # Create the new mount target - mount_target = MountTarget(file_system, subnet, ip_address, security_groups) + mount_target = MountTarget( + self.account_id, file_system, subnet, ip_address, security_groups + ) # Establish the network interface. 
network_interface = self.ec2_backend.create_network_interface( @@ -625,6 +637,7 @@ class EFSBackend(BaseBackend): ): name = next((tag["Value"] for tag in tags if tag["Key"] == "Name"), None) access_point = AccessPoint( + self.account_id, self.region_name, client_token, file_system_id, diff --git a/moto/efs/responses.py b/moto/efs/responses.py index 4fb350772..156bfed35 100644 --- a/moto/efs/responses.py +++ b/moto/efs/responses.py @@ -6,11 +6,12 @@ from .models import efs_backends class EFSResponse(BaseResponse): - SERVICE_NAME = "efs" + def __init__(self): + super().__init__(service_name="efs") @property def efs_backend(self): - return efs_backends[self.region] + return efs_backends[self.current_account][self.region] def create_file_system(self): creation_token = self._get_param("CreationToken") diff --git a/moto/eks/models.py b/moto/eks/models.py index f277344d5..9b58aab70 100644 --- a/moto/eks/models.py +++ b/moto/eks/models.py @@ -1,7 +1,7 @@ from datetime import datetime from uuid import uuid4 -from moto.core import get_account_id, BaseBackend +from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds, BackendDict from ..utilities.utils import random_string @@ -14,19 +14,9 @@ from .exceptions import ( from .utils import get_partition, validate_role_arn # String Templates -CLUSTER_ARN_TEMPLATE = ( - "arn:{partition}:eks:{region}:" + str(get_account_id()) + ":cluster/{name}" -) -FARGATE_PROFILE_ARN_TEMPLATE = ( - "arn:{partition}:eks:{region}:" - + str(get_account_id()) - + ":fargateprofile/{cluster_name}/{fargate_profile_name}/{uuid}" -) -NODEGROUP_ARN_TEMPLATE = ( - "arn:{partition}:eks:{region}:" - + str(get_account_id()) - + ":nodegroup/{cluster_name}/{nodegroup_name}/{uuid}" -) +CLUSTER_ARN_TEMPLATE = "arn:{partition}:eks:{region}:{account_id}:cluster/{name}" +FARGATE_PROFILE_ARN_TEMPLATE = "arn:{partition}:eks:{region}:{account_id}:fargateprofile/{cluster_name}/{fargate_profile_name}/{uuid}" 
+NODEGROUP_ARN_TEMPLATE = "arn:{partition}:eks:{region}:{account_id}:nodegroup/{cluster_name}/{nodegroup_name}/{uuid}" ISSUER_TEMPLATE = "https://oidc.eks.{region}.amazonaws.com/id/" + random_string(10) ENDPOINT_TEMPLATE = ( "https://" @@ -103,6 +93,7 @@ class Cluster: name, role_arn, resources_vpc_config, + account_id, region_name, aws_partition, version=None, @@ -124,7 +115,10 @@ class Cluster: self.fargate_profile_count = 0 self.arn = CLUSTER_ARN_TEMPLATE.format( - partition=aws_partition, region=region_name, name=name + partition=aws_partition, + account_id=account_id, + region=region_name, + name=name, ) self.certificateAuthority = {"data": random_string(1400)} self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) @@ -175,6 +169,7 @@ class FargateProfile: fargate_profile_name, pod_execution_role_arn, selectors, + account_id, region_name, aws_partition, client_request_token=None, @@ -190,6 +185,7 @@ class FargateProfile: self.uuid = str(uuid4()) self.fargate_profile_arn = FARGATE_PROFILE_ARN_TEMPLATE.format( partition=aws_partition, + account_id=account_id, region=region_name, cluster_name=cluster_name, fargate_profile_name=fargate_profile_name, @@ -224,6 +220,7 @@ class ManagedNodegroup: node_role, nodegroup_name, subnets, + account_id, region_name, aws_partition, scaling_config=None, @@ -250,6 +247,7 @@ class ManagedNodegroup: self.uuid = str(uuid4()) self.arn = NODEGROUP_ARN_TEMPLATE.format( partition=aws_partition, + account_id=account_id, region=region_name, cluster_name=cluster_name, nodegroup_name=nodegroup_name, @@ -349,6 +347,7 @@ class EKSBackend(BaseBackend): client_request_token=client_request_token, tags=tags, encryption_config=encryption_config, + account_id=self.account_id, region_name=self.region_name, aws_partition=self.partition, ) @@ -401,6 +400,7 @@ class EKSBackend(BaseBackend): selectors=selectors, subnets=subnets, tags=tags, + account_id=self.account_id, region_name=self.region_name, aws_partition=self.partition, ) 
@@ -477,6 +477,7 @@ class EKSBackend(BaseBackend): capacity_type=capacity_type, version=version, release_version=release_version, + account_id=self.account_id, region_name=self.region_name, aws_partition=self.partition, ) diff --git a/moto/eks/responses.py b/moto/eks/responses.py index 21f7e0672..c0deb4845 100644 --- a/moto/eks/responses.py +++ b/moto/eks/responses.py @@ -9,11 +9,12 @@ DEFAULT_NEXT_TOKEN = "" class EKSResponse(BaseResponse): - SERVICE_NAME = "eks" + def __init__(self): + super().__init__(service_name="eks") @property def eks_backend(self): - return eks_backends[self.region] + return eks_backends[self.current_account][self.region] def create_cluster(self): name = self._get_param("name") diff --git a/moto/elasticache/models.py b/moto/elasticache/models.py index e1dade082..e1fdd0809 100644 --- a/moto/elasticache/models.py +++ b/moto/elasticache/models.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from .exceptions import UserAlreadyExists, UserNotFound @@ -7,6 +7,7 @@ from .exceptions import UserAlreadyExists, UserNotFound class User(BaseModel): def __init__( self, + account_id, region, user_id, user_name, @@ -25,10 +26,7 @@ class User(BaseModel): self.minimum_engine_version = "6.0" self.usergroupids = [] self.region = region - - @property - def arn(self): - return f"arn:aws:elasticache:{self.region}:{get_account_id()}:user:{self.id}" + self.arn = f"arn:aws:elasticache:{self.region}:{account_id}:user:{self.id}" class ElastiCacheBackend(BaseBackend): @@ -38,6 +36,7 @@ class ElastiCacheBackend(BaseBackend): super().__init__(region_name, account_id) self.users = dict() self.users["default"] = User( + account_id=self.account_id, region=self.region_name, user_id="default", user_name="default", @@ -52,6 +51,7 @@ class ElastiCacheBackend(BaseBackend): if user_id in self.users: raise UserAlreadyExists user = User( + 
account_id=self.account_id, region=self.region_name, user_id=user_id, user_name=user_name, diff --git a/moto/elasticache/responses.py b/moto/elasticache/responses.py index b63e029e8..087ad6814 100644 --- a/moto/elasticache/responses.py +++ b/moto/elasticache/responses.py @@ -6,10 +6,13 @@ from .models import elasticache_backends class ElastiCacheResponse(BaseResponse): """Handler for ElastiCache requests and responses.""" + def __init__(self): + super().__init__(service_name="elasticache") + @property def elasticache_backend(self): """Return backend instance specific for this region.""" - return elasticache_backends[self.region] + return elasticache_backends[self.current_account][self.region] def create_user(self): params = self._get_params() diff --git a/moto/elasticbeanstalk/models.py b/moto/elasticbeanstalk/models.py index 82388ad4d..ac4fe163d 100644 --- a/moto/elasticbeanstalk/models.py +++ b/moto/elasticbeanstalk/models.py @@ -1,6 +1,6 @@ import weakref -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from .exceptions import InvalidParameterValueError, ResourceNotFoundException from .utils import make_arn @@ -22,7 +22,9 @@ class FakeEnvironment(BaseModel): @property def environment_arn(self): resource_path = "%s/%s" % (self.application_name, self.environment_name) - return make_arn(self.region, get_account_id(), "environment", resource_path) + return make_arn( + self.region, self.application.account_id, "environment", resource_path + ) @property def platform_arn(self): @@ -38,6 +40,11 @@ class FakeApplication(BaseModel): self.backend = weakref.proxy(backend) # weakref to break cycles self.application_name = application_name self.environments = dict() + self.account_id = self.backend.account_id + self.region = self.backend.region_name + self.arn = make_arn( + self.region, self.account_id, "application", self.application_name + ) def create_environment(self, 
environment_name, solution_stack_name, tags): if environment_name in self.environments: @@ -53,16 +60,6 @@ class FakeApplication(BaseModel): return env - @property - def region(self): - return self.backend.region_name - - @property - def arn(self): - return make_arn( - self.region, get_account_id(), "application", self.application_name - ) - class EBBackend(BaseBackend): def __init__(self, region_name, account_id): diff --git a/moto/elasticbeanstalk/responses.py b/moto/elasticbeanstalk/responses.py index adc20497b..74d345a6c 100644 --- a/moto/elasticbeanstalk/responses.py +++ b/moto/elasticbeanstalk/responses.py @@ -5,12 +5,15 @@ from .exceptions import InvalidParameterValueError class EBResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="elasticbeanstalk") + @property def backend(self): """ :rtype: EBBackend """ - return eb_backends[self.region] + return eb_backends[self.current_account][self.region] def create_application(self): app = self.backend.create_application( diff --git a/moto/elastictranscoder/models.py b/moto/elastictranscoder/models.py index 18ffb5c67..01ab2d740 100644 --- a/moto/elastictranscoder/models.py +++ b/moto/elastictranscoder/models.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict import random import string @@ -7,6 +7,7 @@ import string class Pipeline(BaseModel): def __init__( self, + account_id, region, name, input_bucket, @@ -19,9 +20,7 @@ class Pipeline(BaseModel): b = "".join(random.choice(string.ascii_lowercase) for _ in range(6)) self.id = "{}-{}".format(a, b) self.name = name - self.arn = "arn:aws:elastictranscoder:{}:{}:pipeline/{}".format( - region, get_account_id(), self.id - ) + self.arn = f"arn:aws:elastictranscoder:{region}:{account_id}:pipeline/{self.id}" self.status = "Active" self.input_bucket = input_bucket self.output_bucket = output_bucket or content_config["Bucket"] @@ -80,6 
+79,7 @@ class ElasticTranscoderBackend(BaseBackend): AWSKMSKeyArn, Notifications """ pipeline = Pipeline( + self.account_id, self.region_name, name, input_bucket, diff --git a/moto/elastictranscoder/responses.py b/moto/elastictranscoder/responses.py index 0a9dea0d2..2dcdd867f 100644 --- a/moto/elastictranscoder/responses.py +++ b/moto/elastictranscoder/responses.py @@ -1,4 +1,3 @@ -from moto.core import get_account_id from moto.core.responses import BaseResponse from .models import elastictranscoder_backends import json @@ -6,11 +5,12 @@ import re class ElasticTranscoderResponse(BaseResponse): - SERVICE_NAME = "elastictranscoder" + def __init__(self): + super().__init__(service_name="elastictranscoder") @property def elastictranscoder_backend(self): - return elastictranscoder_backends[self.region] + return elastictranscoder_backends[self.current_account][self.region] def pipelines(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -83,9 +83,7 @@ class ElasticTranscoderResponse(BaseResponse): self.elastictranscoder_backend.read_pipeline(pipeline_id) return None except KeyError: - msg = "The specified pipeline was not found: account={}, pipelineId={}.".format( - get_account_id(), pipeline_id - ) + msg = f"The specified pipeline was not found: account={self.current_account}, pipelineId={pipeline_id}." 
return ( 404, {"status": 404, "x-amzn-ErrorType": "ResourceNotFoundException"}, diff --git a/moto/elb/models.py b/moto/elb/models.py index 4290d7567..f5487860d 100644 --- a/moto/elb/models.py +++ b/moto/elb/models.py @@ -138,11 +138,11 @@ class FakeLoadBalancer(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - elb_backend = elb_backends[region_name] + elb_backend = elb_backends[account_id][region_name] new_elb = elb_backend.create_load_balancer( name=properties.get("LoadBalancerName", resource_name), zones=properties.get("AvailabilityZones", []), @@ -186,20 +186,25 @@ class FakeLoadBalancer(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - elb_backend = elb_backends[region_name] + elb_backend = elb_backends[account_id][region_name] try: elb_backend.delete_load_balancer(resource_name) except KeyError: @@ -264,9 +269,9 @@ class FakeLoadBalancer(CloudFormationModel): if key in self.tags: del self.tags[key] - def delete(self, region): + def delete(self, account_id, region): """Not exposed as part of the ELB API - used for CloudFormation.""" - 
elb_backends[region].delete_load_balancer(self.name) + elb_backends[account_id][region].delete_load_balancer(self.name) class ELBBackend(BaseBackend): @@ -284,7 +289,7 @@ class ELBBackend(BaseBackend): security_groups=None, ): vpc_id = None - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] if subnets: subnet = ec2_backend.get_subnet(subnets[0]) vpc_id = subnet.vpc_id @@ -379,7 +384,7 @@ class ELBBackend(BaseBackend): def describe_instance_health(self, lb_name, instances): provided_ids = [i["InstanceId"] for i in instances] registered_ids = self.get_load_balancer(lb_name).instance_ids - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] if len(provided_ids) == 0: provided_ids = registered_ids instances = [] @@ -423,7 +428,7 @@ class ELBBackend(BaseBackend): self, load_balancer_name, security_group_ids ): load_balancer = self.load_balancers.get(load_balancer_name) - ec2_backend = ec2_backends[self.region_name] + ec2_backend = ec2_backends[self.account_id][self.region_name] for security_group_id in security_group_ids: if ec2_backend.get_security_group_from_id(security_group_id) is None: raise InvalidSecurityGroupError() @@ -572,7 +577,7 @@ class ELBBackend(BaseBackend): def _register_certificate(self, ssl_certificate_id, dns_name): from moto.acm.models import acm_backends, AWSResourceNotFoundException - acm_backend = acm_backends[self.region_name] + acm_backend = acm_backends[self.account_id][self.region_name] try: acm_backend.set_certificate_in_use_by(ssl_certificate_id, dns_name) except AWSResourceNotFoundException: diff --git a/moto/elb/responses.py b/moto/elb/responses.py index 3c2389e2c..38a1bc68f 100644 --- a/moto/elb/responses.py +++ b/moto/elb/responses.py @@ -1,13 +1,15 @@ -from moto.core import get_account_id from moto.core.responses import BaseResponse from .models import elb_backends from .exceptions import DuplicateTagKeysError, 
LoadBalancerNotFoundError class ELBResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="elb") + @property def elb_backend(self): - return elb_backends[self.region] + return elb_backends[self.current_account][self.region] def create_load_balancer(self): load_balancer_name = self._get_param("LoadBalancerName") @@ -59,7 +61,7 @@ class ELBResponse(BaseResponse): template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE) return template.render( - ACCOUNT_ID=get_account_id(), + ACCOUNT_ID=self.current_account, load_balancers=load_balancers_resp, marker=next_marker, ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 7434e9887..2b30c3fa8 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -4,7 +4,7 @@ from jinja2 import Template from botocore.exceptions import ParamValidationError from collections import OrderedDict from moto.core.exceptions import RESTError -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import ( iso_8601_datetime_with_milliseconds, get_random_hex, @@ -172,11 +172,11 @@ class FakeTargetGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + elbv2_backend = elbv2_backends[account_id][region_name] vpc_id = properties.get("VpcId") protocol = properties.get("Protocol") @@ -268,11 +268,11 @@ class FakeListener(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + 
elbv2_backend = elbv2_backends[account_id][region_name] load_balancer_arn = properties.get("LoadBalancerArn") protocol = properties.get("Protocol") port = properties.get("Port") @@ -288,11 +288,16 @@ class FakeListener(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + elbv2_backend = elbv2_backends[account_id][region_name] protocol = properties.get("Protocol") port = properties.get("Port") ssl_policy = properties.get("SslPolicy") @@ -330,10 +335,10 @@ class FakeListenerRule(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + elbv2_backend = elbv2_backends[account_id][region_name] listener_arn = properties.get("ListenerArn") priority = properties.get("Priority") conditions = properties.get("Conditions") @@ -346,12 +351,17 @@ class FakeListenerRule(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + elbv2_backend = elbv2_backends[account_id][region_name] conditions = properties.get("Conditions") actions = elbv2_backend.convert_and_validate_action_properties(properties) @@ -563,9 +573,9 @@ class FakeLoadBalancer(CloudFormationModel): if self.state == "provisioning": self.state = "active" - def delete(self, region): + def 
delete(self, account_id, region): """Not exposed as part of the ELB API - used for CloudFormation.""" - elbv2_backends[region].delete_load_balancer(self.arn) + elbv2_backends[account_id][region].delete_load_balancer(self.arn) @staticmethod def cloudformation_name_type(): @@ -578,11 +588,11 @@ class FakeLoadBalancer(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - elbv2_backend = elbv2_backends[region_name] + elbv2_backend = elbv2_backends[account_id][region_name] security_groups = properties.get("SecurityGroups") subnet_ids = properties.get("Subnets") @@ -657,7 +667,7 @@ class ELBv2Backend(BaseBackend): :return: EC2 Backend :rtype: moto.ec2.models.EC2Backend """ - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] def create_load_balancer( self, @@ -689,7 +699,7 @@ class ELBv2Backend(BaseBackend): vpc_id = subnets[0].vpc_id arn = make_arn_for_load_balancer( - account_id=get_account_id(), name=name, region_name=self.region_name + account_id=self.account_id, name=name, region_name=self.region_name ) dns_name = "%s-1.%s.elb.amazonaws.com" % (name, self.region_name) @@ -1017,7 +1027,7 @@ Member must satisfy regular expression pattern: {}".format( ) arn = make_arn_for_target_group( - account_id=get_account_id(), name=name, region_name=self.region_name + account_id=self.account_id, name=name, region_name=self.region_name ) tags = kwargs.pop("tags", None) target_group = FakeTargetGroup(name, arn, **kwargs) @@ -1541,7 +1551,7 @@ Member must satisfy regular expression pattern: {}".format( from moto.acm.models import AWSResourceNotFoundException try: - acm_backend = acm_backends[self.region_name] + acm_backend = acm_backends[self.account_id][self.region_name] acm_backend.get_certificate(certificate_arn) return True 
except AWSResourceNotFoundException: @@ -1549,7 +1559,9 @@ Member must satisfy regular expression pattern: {}".format( from moto.iam import iam_backends - cert = iam_backends["global"].get_certificate_by_arn(certificate_arn) + cert = iam_backends[self.account_id]["global"].get_certificate_by_arn( + certificate_arn + ) if cert is not None: return True diff --git a/moto/elbv2/responses.py b/moto/elbv2/responses.py index 28ef9b8d5..c03b23395 100644 --- a/moto/elbv2/responses.py +++ b/moto/elbv2/responses.py @@ -135,9 +135,12 @@ SSL_POLICIES = [ class ELBV2Response(BaseResponse): + def __init__(self): + super().__init__(service_name="elbv2") + @property def elbv2_backend(self): - return elbv2_backends[self.region] + return elbv2_backends[self.current_account][self.region] @amzn_request_id def create_load_balancer(self): diff --git a/moto/emr/models.py b/moto/emr/models.py index 3411f573c..1503bc941 100644 --- a/moto/emr/models.py +++ b/moto/emr/models.py @@ -5,7 +5,7 @@ import warnings import pytz from dateutil.parser import parse as dtparse -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.emr.exceptions import ( InvalidRequestException, @@ -280,7 +280,7 @@ class FakeCluster(BaseModel): @property def arn(self): return "arn:aws:elasticmapreduce:{0}:{1}:cluster/{2}".format( - self.emr_backend.region_name, get_account_id(), self.id + self.emr_backend.region_name, self.emr_backend.account_id, self.id ) @property @@ -411,7 +411,7 @@ class ElasticMapReduceBackend(BaseBackend): """ from moto.ec2 import ec2_backends - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] def add_applications(self, cluster_id, applications): cluster = self.describe_cluster(cluster_id) diff --git a/moto/emr/responses.py b/moto/emr/responses.py index c98e96750..d64f57519 100644 --- a/moto/emr/responses.py +++ b/moto/emr/responses.py @@ -55,6 +55,9 
@@ class ElasticMapReduceResponse(BaseResponse): aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json") + def __init__(self): + super().__init__(service_name="emr") + def get_region_from_url(self, request, full_url): parsed = urlparse(full_url) for regex in self.region_regex: @@ -65,7 +68,7 @@ class ElasticMapReduceResponse(BaseResponse): @property def backend(self): - return emr_backends[self.region] + return emr_backends[self.current_account][self.region] @generate_boto3_response("AddInstanceGroups") def add_instance_groups(self): diff --git a/moto/emrcontainers/models.py b/moto/emrcontainers/models.py index e76646eb5..bdf9d9ec0 100644 --- a/moto/emrcontainers/models.py +++ b/moto/emrcontainers/models.py @@ -2,7 +2,7 @@ import re from datetime import datetime -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_without_milliseconds, BackendDict from .utils import random_cluster_id, random_job_id, get_partition, paginated_list @@ -10,17 +10,9 @@ from .exceptions import ResourceNotFoundException from ..config.exceptions import ValidationException -VIRTUAL_CLUSTER_ARN_TEMPLATE = ( - "arn:{partition}:emr-containers:{region}:" - + str(get_account_id()) - + ":/virtualclusters/{virtual_cluster_id}" -) +VIRTUAL_CLUSTER_ARN_TEMPLATE = "arn:{partition}:emr-containers:{region}:{account_id}:/virtualclusters/{virtual_cluster_id}" -JOB_ARN_TEMPLATE = ( - "arn:{partition}:emr-containers:{region}:" - + str(get_account_id()) - + ":/virtualclusters/{virtual_cluster_id}/jobruns/{job_id}" -) +JOB_ARN_TEMPLATE = "arn:{partition}:emr-containers:{region}:{account_id}:/virtualclusters/{virtual_cluster_id}/jobruns/{job_id}" # Defaults used for creating a Virtual cluster VIRTUAL_CLUSTER_STATUS = "RUNNING" @@ -33,6 +25,7 @@ class FakeCluster(BaseModel): name, container_provider, client_token, + account_id, region_name, aws_partition, tags=None, @@ -43,7 +36,10 @@ class 
FakeCluster(BaseModel): self.name = name self.client_token = client_token self.arn = VIRTUAL_CLUSTER_ARN_TEMPLATE.format( - partition=aws_partition, region=region_name, virtual_cluster_id=self.id + partition=aws_partition, + region=region_name, + account_id=account_id, + virtual_cluster_id=self.id, ) self.state = VIRTUAL_CLUSTER_STATUS self.container_provider = container_provider @@ -87,6 +83,7 @@ class FakeJob(BaseModel): release_label, job_driver, configuration_overrides, + account_id, region_name, aws_partition, tags, @@ -97,6 +94,7 @@ class FakeJob(BaseModel): self.arn = JOB_ARN_TEMPLATE.format( partition=aws_partition, region=region_name, + account_id=account_id, virtual_cluster_id=self.virtual_cluster_id, job_id=self.id, ) @@ -183,6 +181,7 @@ class EMRContainersBackend(BaseBackend): container_provider=container_provider, client_token=client_token, tags=tags, + account_id=self.account_id, region_name=self.region_name, aws_partition=self.partition, ) @@ -288,6 +287,7 @@ class EMRContainersBackend(BaseBackend): job_driver=job_driver, configuration_overrides=configuration_overrides, tags=tags, + account_id=self.account_id, region_name=self.region_name, aws_partition=self.partition, ) diff --git a/moto/emrcontainers/responses.py b/moto/emrcontainers/responses.py index 084de2fce..601ab1295 100644 --- a/moto/emrcontainers/responses.py +++ b/moto/emrcontainers/responses.py @@ -12,12 +12,13 @@ DEFAULT_CONTAINER_PROVIDER_TYPE = "EKS" class EMRContainersResponse(BaseResponse): """Handler for EMRContainers requests and responses.""" - SERVICE_NAME = "emr-containers" + def __init__(self): + super().__init__(service_name="emr-containers") @property def emrcontainers_backend(self): """Return backend instance specific for this region.""" - return emrcontainers_backends[self.region] + return emrcontainers_backends[self.current_account][self.region] def create_virtual_cluster(self): name = self._get_param("name") diff --git a/moto/emrserverless/models.py 
b/moto/emrserverless/models.py index f1729b2cc..4eed63442 100644 --- a/moto/emrserverless/models.py +++ b/moto/emrserverless/models.py @@ -3,7 +3,7 @@ import re from datetime import datetime import inspect -from moto.core import ACCOUNT_ID, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, iso_8601_datetime_without_milliseconds from .utils import ( default_auto_start_configuration, @@ -16,17 +16,7 @@ from .utils import ( from .exceptions import ResourceNotFoundException, ValidationException -APPLICATION_ARN_TEMPLATE = ( - "arn:{partition}:emr-containers:{region}:" - + str(ACCOUNT_ID) - + ":/applications/{application_id}" -) - -JOB_ARN_TEMPLATE = ( - "arn:{partition}:emr-containers:{region}:" - + str(ACCOUNT_ID) - + ":/applications/{application_id}/jobruns/{job_id}" -) +APPLICATION_ARN_TEMPLATE = "arn:{partition}:emr-containers:{region}:{account_id}:/applications/{application_id}" # Defaults used for creating an EMR Serverless application APPLICATION_STATUS = "STARTED" @@ -40,6 +30,7 @@ class FakeApplication(BaseModel): release_label, application_type, client_token, + account_id, region_name, initial_capacity, maximum_capacity, @@ -67,7 +58,10 @@ class FakeApplication(BaseModel): # Service-generated-parameters self.id = random_appplication_id() self.arn = APPLICATION_ARN_TEMPLATE.format( - partition="aws", region=region_name, application_id=self.id + partition="aws", + region=region_name, + account_id=account_id, + application_id=self.id, ) self.state = APPLICATION_STATUS self.state_details = "" @@ -166,6 +160,7 @@ class EMRServerlessBackend(BaseBackend): name=name, release_label=release_label, application_type=application_type, + account_id=self.account_id, region_name=self.region_name, client_token=client_token, initial_capacity=initial_capacity, diff --git a/moto/emrserverless/responses.py b/moto/emrserverless/responses.py index 7377b65a7..af8edad5f 100644 --- a/moto/emrserverless/responses.py +++ 
b/moto/emrserverless/responses.py @@ -33,12 +33,13 @@ These are the available methos: class EMRServerlessResponse(BaseResponse): """Handler for EMRServerless requests and responses.""" - SERVICE_NAME = "emr-serverless" + def __init__(self): + super().__init__("emr-serverless") @property def emrserverless_backend(self): """Return backend instance specific for this region.""" - return emrserverless_backends[self.region] + return emrserverless_backends[self.current_account][self.region] def create_application(self): name = self._get_param("name") diff --git a/moto/es/responses.py b/moto/es/responses.py index 7aaa11ae4..e61e3e77b 100644 --- a/moto/es/responses.py +++ b/moto/es/responses.py @@ -9,10 +9,13 @@ from .models import es_backends class ElasticsearchServiceResponse(BaseResponse): """Handler for ElasticsearchService requests and responses.""" + def __init__(self): + super().__init__(service_name="elasticsearch") + @property def es_backend(self): """Return backend instance specific for this region.""" - return es_backends[self.region] + return es_backends[self.current_account][self.region] @classmethod def list_domains(cls, request, full_url, headers): diff --git a/moto/events/models.py b/moto/events/models.py index 4a9a7a990..8fade1e07 100644 --- a/moto/events/models.py +++ b/moto/events/models.py @@ -12,7 +12,7 @@ from operator import lt, le, eq, ge, gt from collections import OrderedDict from moto.core.exceptions import JsonRESTError -from moto.core import get_account_id, BaseBackend, CloudFormationModel, BaseModel +from moto.core import BaseBackend, CloudFormationModel, BaseModel from moto.core.utils import ( unix_time, unix_time_millis, @@ -43,6 +43,7 @@ class Rule(CloudFormationModel): def __init__( self, name, + account_id, region_name, description, event_pattern, @@ -54,6 +55,7 @@ class Rule(CloudFormationModel): targets=None, ): self.name = name + self.account_id = account_id self.region_name = region_name self.description = description 
self.event_pattern = EventPattern.load(event_pattern) @@ -62,7 +64,7 @@ class Rule(CloudFormationModel): self.event_bus_name = event_bus_name self.state = state or "ENABLED" self.managed_by = managed_by # can only be set by AWS services - self.created_by = get_account_id() + self.created_by = account_id self.targets = targets or [] @property @@ -76,7 +78,7 @@ class Rule(CloudFormationModel): return ( "arn:aws:events:{region}:{account_id}:rule/{event_bus_name}{name}".format( region=self.region_name, - account_id=get_account_id(), + account_id=self.account_id, event_bus_name=event_bus_name, name=self.name, ) @@ -100,8 +102,8 @@ class Rule(CloudFormationModel): def disable(self): self.state = "DISABLED" - def delete(self, region_name): - event_backend = events_backends[region_name] + def delete(self, account_id, region_name): + event_backend = events_backends[account_id][region_name] event_backend.delete_rule(name=self.name) def put_targets(self, targets): @@ -189,14 +191,15 @@ class Rule(CloudFormationModel): } ] - logs_backends[self.region_name].create_log_stream(name, log_stream_name) - logs_backends[self.region_name].put_log_events( - name, log_stream_name, log_events - ) + log_backend = logs_backends[self.account_id][self.region_name] + log_backend.create_log_stream(name, log_stream_name) + log_backend.put_log_events(name, log_stream_name, log_events) def _send_to_events_archive(self, resource_id, event): archive_name, archive_uuid = resource_id.split(":") - archive = events_backends[self.region_name].archives.get(archive_name) + archive = events_backends[self.account_id][self.region_name].archives.get( + archive_name + ) if archive.uuid == archive_uuid: archive.events.append(event) @@ -209,7 +212,9 @@ class Rule(CloudFormationModel): ) if group_id: - queue_attr = sqs_backends[self.region_name].get_queue_attributes( + queue_attr = sqs_backends[self.account_id][ + self.region_name + ].get_queue_attributes( queue_name=resource_id, 
attribute_names=["ContentBasedDeduplication"] ) if queue_attr["ContentBasedDeduplication"] == "false": @@ -219,7 +224,7 @@ class Rule(CloudFormationModel): ) return - sqs_backends[self.region_name].send_message( + sqs_backends[self.account_id][self.region_name].send_message( queue_name=resource_id, message_body=json.dumps(event_copy), group_id=group_id, @@ -248,7 +253,7 @@ class Rule(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] properties.setdefault("EventBusName", "default") @@ -266,7 +271,7 @@ class Rule(CloudFormationModel): event_bus_name = properties.get("EventBusName") tags = properties.get("Tags") - backend = events_backends[region_name] + backend = events_backends[account_id][region_name] return backend.put_rule( event_name, scheduled_expression=scheduled_expression, @@ -280,18 +285,23 @@ class Rule(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): - original_resource.delete(region_name) + original_resource.delete(account_id, region_name) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - event_backend = events_backends[region_name] + event_backend = events_backends[account_id][region_name] event_backend.delete_rule(resource_name) def describe(self): @@ -314,19 +324,15 @@ class Rule(CloudFormationModel): class EventBus(CloudFormationModel): - def 
__init__(self, region_name, name, tags=None): + def __init__(self, account_id, region_name, name, tags=None): + self.account_id = account_id self.region = region_name self.name = name + self.arn = f"arn:aws:events:{self.region}:{account_id}:event-bus/{name}" self.tags = tags or [] self._statements = {} - @property - def arn(self): - return "arn:aws:events:{region}:{account_id}:event-bus/{name}".format( - region=self.region, account_id=get_account_id(), name=self.name - ) - @property def policy(self): if self._statements: @@ -340,8 +346,8 @@ class EventBus(CloudFormationModel): def has_permissions(self): return len(self._statements) > 0 - def delete(self, region_name): - event_backend = events_backends[region_name] + def delete(self, account_id, region_name): + event_backend = events_backends[account_id][region_name] event_backend.delete_event_bus(name=self.name) @classmethod @@ -371,10 +377,10 @@ class EventBus(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - event_backend = events_backends[region_name] + event_backend = events_backends[account_id][region_name] event_name = resource_name event_source_name = properties.get("EventSourceName") return event_backend.create_event_bus( @@ -383,18 +389,23 @@ class EventBus(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): - original_resource.delete(region_name) + original_resource.delete(account_id, region_name) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def 
delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - event_backend = events_backends[region_name] + event_backend = events_backends[account_id][region_name] event_bus_name = resource_name event_backend.delete_event_bus(event_bus_name) @@ -491,7 +502,14 @@ class Archive(CloudFormationModel): ] def __init__( - self, region_name, name, source_arn, description, event_pattern, retention + self, + account_id, + region_name, + name, + source_arn, + description, + event_pattern, + retention, ): self.region = region_name self.name = name @@ -500,6 +518,7 @@ class Archive(CloudFormationModel): self.event_pattern = EventPattern.load(event_pattern) self.retention = retention if retention else 0 + self.arn = f"arn:aws:events:{region_name}:{account_id}:archive/{name}" self.creation_time = unix_time(datetime.utcnow()) self.state = "ENABLED" self.uuid = str(uuid4()) @@ -507,12 +526,6 @@ class Archive(CloudFormationModel): self.events = [] self.event_bus_name = source_arn.split("/")[-1] - @property - def arn(self): - return "arn:aws:events:{region}:{account_id}:archive/{name}".format( - region=self.region, account_id=get_account_id(), name=self.name - ) - def describe_short(self): return { "ArchiveName": self.name, @@ -542,8 +555,8 @@ class Archive(CloudFormationModel): if retention: self.retention = retention - def delete(self, region_name): - event_backend = events_backends[region_name] + def delete(self, account_id, region_name): + event_backend = events_backends[account_id][region_name] event_backend.archives.pop(self.name) @classmethod @@ -571,10 +584,10 @@ class Archive(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - event_backend = 
events_backends[region_name] + event_backend = events_backends[account_id][region_name] source_arn = properties.get("SourceArn") description = properties.get("Description") @@ -587,7 +600,12 @@ class Archive(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): if new_resource_name == original_resource.name: properties = cloudformation_json["Properties"] @@ -600,9 +618,9 @@ class Archive(CloudFormationModel): return original_resource else: - original_resource.delete(region_name) + original_resource.delete(account_id, region_name) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @@ -620,6 +638,7 @@ class ReplayState(Enum): class Replay(BaseModel): def __init__( self, + account_id, region_name, name, description, @@ -628,6 +647,7 @@ class Replay(BaseModel): end_time, destination, ): + self.account_id = account_id self.region = region_name self.name = name self.description = description @@ -636,16 +656,11 @@ class Replay(BaseModel): self.event_end_time = end_time self.destination = destination + self.arn = f"arn:aws:events:{region_name}:{account_id}:replay/{name}" self.state = ReplayState.STARTING self.start_time = unix_time(datetime.utcnow()) self.end_time = None - @property - def arn(self): - return "arn:aws:events:{region}:{account_id}:replay/{name}".format( - region=self.region, account_id=get_account_id(), name=self.name - ) - def describe_short(self): return { "ReplayName": self.name, @@ -672,7 +687,8 @@ class Replay(BaseModel): event_bus_name = self.destination["Arn"].split("/")[-1] for event in archive.events: - for rule in events_backends[self.region].rules.values(): + event_backend = events_backends[self.account_id][self.region] + for rule in 
event_backend.rules.values(): rule.send_to_targets( event_bus_name, dict(event, **{"id": str(uuid4()), "replay-name": self.name}), @@ -684,7 +700,13 @@ class Replay(BaseModel): class Connection(BaseModel): def __init__( - self, name, region_name, description, authorization_type, auth_parameters + self, + name, + account_id, + region_name, + description, + authorization_type, + auth_parameters, ): self.uuid = uuid4() self.name = name @@ -695,11 +717,7 @@ class Connection(BaseModel): self.creation_time = unix_time(datetime.utcnow()) self.state = "AUTHORIZED" - @property - def arn(self): - return "arn:aws:events:{0}:{1}:connection/{2}/{3}".format( - self.region, get_account_id(), self.name, self.uuid - ) + self.arn = f"arn:aws:events:{region_name}:{account_id}:connection/{self.name}/{self.uuid}" def describe_short(self): """ @@ -758,6 +776,7 @@ class Destination(BaseModel): def __init__( self, name, + account_id, region_name, description, connection_arn, @@ -775,12 +794,7 @@ class Destination(BaseModel): self.creation_time = unix_time(datetime.utcnow()) self.http_method = http_method self.state = "ACTIVE" - - @property - def arn(self): - return "arn:aws:events:{0}:{1}:api-destination/{2}/{3}".format( - self.region, get_account_id(), self.name, self.uuid - ) + self.arn = f"arn:aws:events:{region_name}:{account_id}:api-destination/{name}/{self.uuid}" def describe(self): """ @@ -959,7 +973,9 @@ class EventsBackend(BaseBackend): ) def _add_default_event_bus(self): - self.event_buses["default"] = EventBus(self.region_name, "default") + self.event_buses["default"] = EventBus( + self.account_id, self.region_name, "default" + ) def _gen_next_token(self, index): token = os.urandom(128).encode("base64") @@ -1037,6 +1053,7 @@ class EventsBackend(BaseBackend): targets = existing_rule.targets if existing_rule else list() rule = Rule( name, + self.account_id, self.region_name, description, event_pattern, @@ -1231,7 +1248,7 @@ class EventsBackend(BaseBackend): "id": event_id, 
"detail-type": event["DetailType"], "source": event["Source"], - "account": get_account_id(), + "account": self.account_id, "time": event.get("Time", unix_time(datetime.utcnow())), "region": self.region_name, "resources": event.get("Resources", []), @@ -1371,7 +1388,7 @@ class EventsBackend(BaseBackend): "Event source {} does not exist.".format(event_source_name), ) - event_bus = EventBus(self.region_name, name, tags=tags) + event_bus = EventBus(self.account_id, self.region_name, name, tags=tags) self.event_buses[name] = event_bus if tags: self.tagger.tag_resource(event_bus.arn, tags) @@ -1445,7 +1462,13 @@ class EventsBackend(BaseBackend): ) archive = Archive( - self.region_name, name, source_arn, description, event_pattern, retention + self.account_id, + self.region_name, + name, + source_arn, + description, + event_pattern, + retention, ) rule_event_pattern = json.loads(event_pattern or "{}") @@ -1543,7 +1566,7 @@ class EventsBackend(BaseBackend): if not archive: raise ResourceNotFoundException("Archive {} does not exist.".format(name)) - archive.delete(self.region_name) + archive.delete(self.account_id, self.region_name) def start_replay( self, name, description, source_arn, start_time, end_time, destination @@ -1584,6 +1607,7 @@ class EventsBackend(BaseBackend): ) replay = Replay( + self.account_id, self.region_name, name, description, @@ -1660,7 +1684,12 @@ class EventsBackend(BaseBackend): def create_connection(self, name, description, authorization_type, auth_parameters): connection = Connection( - name, self.region_name, description, authorization_type, auth_parameters + name, + self.account_id, + self.region_name, + description, + authorization_type, + auth_parameters, ) self.connections[name] = connection return connection @@ -1748,6 +1777,7 @@ class EventsBackend(BaseBackend): """ destination = Destination( name=name, + account_id=self.account_id, region_name=self.region_name, description=description, connection_arn=connection_arn, diff --git 
a/moto/events/notifications.py b/moto/events/notifications.py index b8a05e236..3f380b416 100644 --- a/moto/events/notifications.py +++ b/moto/events/notifications.py @@ -1,5 +1,4 @@ import json -from moto.core import get_account_id _EVENT_S3_OBJECT_CREATED = { @@ -36,29 +35,29 @@ def _send_safe_notification(source, event_name, region, resources, detail): if event is None: return - account = events_backends[get_account_id()] - for backend in account.values(): - applicable_targets = [] - for rule in backend.rules.values(): - if rule.state != "ENABLED": - continue - pattern = rule.event_pattern.get_pattern() - if source in pattern.get("source", []): - if event_name in pattern.get("detail", {}).get("eventName", []): - applicable_targets.extend(rule.targets) + for account_id, account in events_backends.items(): + for backend in account.values(): + applicable_targets = [] + for rule in backend.rules.values(): + if rule.state != "ENABLED": + continue + pattern = rule.event_pattern.get_pattern() + if source in pattern.get("source", []): + if event_name in pattern.get("detail", {}).get("eventName", []): + applicable_targets.extend(rule.targets) - for target in applicable_targets: - if target.get("Arn", "").startswith("arn:aws:lambda"): - _invoke_lambda(target.get("Arn"), event=event) + for target in applicable_targets: + if target.get("Arn", "").startswith("arn:aws:lambda"): + _invoke_lambda(account_id, target.get("Arn"), event=event) -def _invoke_lambda(fn_arn, event): +def _invoke_lambda(account_id, fn_arn, event): from moto.awslambda import lambda_backends lmbda_region = fn_arn.split(":")[3] body = json.dumps(event) - lambda_backends[lmbda_region].invoke( + lambda_backends[account_id][lmbda_region].invoke( function_name=fn_arn, qualifier=None, body=body, diff --git a/moto/events/responses.py b/moto/events/responses.py index f79fbe8da..d6cc88a93 100644 --- a/moto/events/responses.py +++ b/moto/events/responses.py @@ -5,6 +5,9 @@ from moto.events import events_backends 
class EventsHandler(BaseResponse): + def __init__(self): + super().__init__(service_name="events") + @property def events_backend(self): """ @@ -13,7 +16,7 @@ class EventsHandler(BaseResponse): :return: Events Backend object :rtype: moto.events.models.EventsBackend """ - return events_backends[self.region] + return events_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/firehose/models.py b/moto/firehose/models.py index a8d43f3ac..75e28961e 100644 --- a/moto/firehose/models.py +++ b/moto/firehose/models.py @@ -27,7 +27,6 @@ import warnings import requests from moto.core import BaseBackend, BaseModel -from moto.core import get_account_id from moto.core.utils import BackendDict from moto.firehose.exceptions import ( ConcurrentModificationException, @@ -117,6 +116,7 @@ class DeliveryStream( def __init__( self, + account_id, region, delivery_stream_name, delivery_stream_type, @@ -151,7 +151,7 @@ class DeliveryStream( del self.destinations[0][destination_name]["S3Configuration"] self.delivery_stream_status = "ACTIVE" - self.delivery_stream_arn = f"arn:aws:firehose:{region}:{get_account_id()}:deliverystream/{delivery_stream_name}" + self.delivery_stream_arn = f"arn:aws:firehose:{region}:{account_id}:deliverystream/{delivery_stream_name}" self.create_timestamp = datetime.now(timezone.utc).isoformat() self.version_id = "1" # Used to track updates of destination configs @@ -197,7 +197,7 @@ class FirehoseBackend(BaseBackend): if delivery_stream_name in self.delivery_streams: raise ResourceInUseException( - f"Firehose {delivery_stream_name} under accountId {get_account_id()} " + f"Firehose {delivery_stream_name} under accountId {self.account_id} " f"already exists" ) @@ -243,6 +243,7 @@ class FirehoseBackend(BaseBackend): # by delivery stream name. This instance will update the state and # create the ARN. 
delivery_stream = DeliveryStream( + self.account_id, region, delivery_stream_name, delivery_stream_type, @@ -266,7 +267,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " + f"Firehose {delivery_stream_name} under account {self.account_id} " f"not found." ) @@ -286,7 +287,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " + f"Firehose {delivery_stream_name} under account {self.account_id} " f"not found." ) @@ -370,8 +371,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " - f"not found." + f"Firehose {delivery_stream_name} under account {self.account_id} not found." ) tags = self.tagger.list_tags_for_resource(delivery_stream.delivery_stream_arn)[ @@ -447,7 +447,9 @@ class FirehoseBackend(BaseBackend): batched_data = b"".join([b64decode(r["Data"]) for r in records]) try: - s3_backends["global"].put_object(bucket_name, object_path, batched_data) + s3_backends[self.account_id]["global"].put_object( + bucket_name, object_path, batched_data + ) except Exception as exc: # This could be better ... raise RuntimeError( @@ -460,7 +462,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " + f"Firehose {delivery_stream_name} under account {self.account_id} " f"not found." 
) @@ -506,7 +508,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " + f"Firehose {delivery_stream_name} under account {self.account_id} " f"not found." ) @@ -528,8 +530,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under account {get_account_id()} " - f"not found." + f"Firehose {delivery_stream_name} under account {self.account_id} not found." ) # If a tag key doesn't exist for the stream, boto3 ignores it. @@ -558,8 +559,7 @@ class FirehoseBackend(BaseBackend): delivery_stream = self.delivery_streams.get(delivery_stream_name) if not delivery_stream: raise ResourceNotFoundException( - f"Firehose {delivery_stream_name} under accountId " - f"{get_account_id()} not found." + f"Firehose {delivery_stream_name} under accountId {self.account_id} not found." 
) if destination_name == "Splunk": @@ -647,7 +647,7 @@ class FirehoseBackend(BaseBackend): "logGroup": log_group_name, "logStream": log_stream_name, "messageType": "DATA_MESSAGE", - "owner": get_account_id(), + "owner": self.account_id, "subscriptionFilters": [filter_name], } diff --git a/moto/firehose/responses.py b/moto/firehose/responses.py index 5b77836ab..1613fca4f 100644 --- a/moto/firehose/responses.py +++ b/moto/firehose/responses.py @@ -8,10 +8,13 @@ from .models import firehose_backends class FirehoseResponse(BaseResponse): """Handler for Firehose requests and responses.""" + def __init__(self): + super().__init__(service_name="firehose") + @property def firehose_backend(self): """Return backend instance specific to this region.""" - return firehose_backends[self.region] + return firehose_backends[self.current_account][self.region] def create_delivery_stream(self): """Prepare arguments and respond to CreateDeliveryStream request.""" diff --git a/moto/forecast/models.py b/moto/forecast/models.py index 2982c3976..e74b68e5f 100644 --- a/moto/forecast/models.py +++ b/moto/forecast/models.py @@ -1,7 +1,7 @@ import re from datetime import datetime -from moto.core import get_account_id, BaseBackend +from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds, BackendDict from .exceptions import ( InvalidInputException, @@ -25,19 +25,18 @@ class DatasetGroup: ] def __init__( - self, region_name, dataset_arns, dataset_group_name, domain, tags=None + self, + account_id, + region_name, + dataset_arns, + dataset_group_name, + domain, + tags=None, ): self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) self.modified_date = self.creation_date - self.arn = ( - "arn:aws:forecast:" - + region_name - + ":" - + str(get_account_id()) - + ":dataset-group/" - + dataset_group_name - ) + self.arn = f"arn:aws:forecast:{region_name}:{account_id}:dataset-group/{dataset_group_name}" self.dataset_arns = dataset_arns if 
dataset_arns else [] self.dataset_group_name = dataset_group_name self.domain = domain @@ -106,6 +105,7 @@ class ForecastBackend(BaseBackend): def create_dataset_group(self, dataset_group_name, domain, dataset_arns, tags): dataset_group = DatasetGroup( + account_id=self.account_id, region_name=self.region_name, dataset_group_name=dataset_group_name, domain=domain, diff --git a/moto/forecast/responses.py b/moto/forecast/responses.py index d5c277b9f..bee513d9a 100644 --- a/moto/forecast/responses.py +++ b/moto/forecast/responses.py @@ -6,9 +6,12 @@ from .models import forecast_backends class ForecastResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="forecast") + @property def forecast_backend(self): - return forecast_backends[self.region] + return forecast_backends[self.current_account][self.region] @amzn_request_id def create_dataset_group(self): diff --git a/moto/glacier/models.py b/moto/glacier/models.py index d5b138958..4067b6282 100644 --- a/moto/glacier/models.py +++ b/moto/glacier/models.py @@ -2,7 +2,7 @@ import hashlib import datetime -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.utilities.utils import md5_hash @@ -90,18 +90,13 @@ class InventoryJob(Job): class Vault(BaseModel): - def __init__(self, vault_name, region): + def __init__(self, vault_name, account_id, region): self.st = datetime.datetime.now() self.vault_name = vault_name self.region = region self.archives = {} self.jobs = {} - - @property - def arn(self): - return "arn:aws:glacier:{0}:{1}:vaults/{2}".format( - self.region, get_account_id(), self.vault_name - ) + self.arn = f"arn:aws:glacier:{region}:{account_id}:vaults/{vault_name}" def to_dict(self): archives_size = 0 @@ -196,7 +191,7 @@ class GlacierBackend(BaseBackend): return self.vaults[vault_name] def create_vault(self, vault_name): - self.vaults[vault_name] = Vault(vault_name, self.region_name) + 
self.vaults[vault_name] = Vault(vault_name, self.account_id, self.region_name) def list_vaults(self): return self.vaults.values() diff --git a/moto/glacier/responses.py b/moto/glacier/responses.py index 334e737bd..0219fdea9 100644 --- a/moto/glacier/responses.py +++ b/moto/glacier/responses.py @@ -6,9 +6,12 @@ from .utils import vault_from_glacier_url class GlacierResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="glacier") + @property def glacier_backend(self): - return glacier_backends[self.region] + return glacier_backends[self.current_account][self.region] def all_vault_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/glue/models.py b/moto/glue/models.py index 3948e1c33..30559565e 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -4,7 +4,6 @@ from datetime import datetime from uuid import uuid4 from moto.core import BaseBackend, BaseModel -from moto.core.models import get_account_id from moto.core.utils import BackendDict from moto.glue.exceptions import ( CrawlerRunningException, @@ -296,14 +295,14 @@ class GlueBackend(BaseBackend): def create_registry(self, registry_name, description=None, tags=None): # If registry name id default-registry, create default-registry if registry_name == DEFAULT_REGISTRY_NAME: - registry = FakeRegistry(registry_name, description, tags) + registry = FakeRegistry(self.account_id, registry_name, description, tags) self.registries[registry_name] = registry return registry # Validate Registry Parameters validate_registry_params(self.registries, registry_name, description, tags) - registry = FakeRegistry(registry_name, description, tags) + registry = FakeRegistry(self.account_id, registry_name, description, tags) self.registries[registry_name] = registry return registry.as_dict() @@ -344,10 +343,15 @@ class GlueBackend(BaseBackend): # Create Schema schema_version = FakeSchemaVersion( - registry_name, schema_name, schema_definition, 
version_number=1 + self.account_id, + registry_name, + schema_name, + schema_definition, + version_number=1, ) schema_version_id = schema_version.get_schema_version_id() schema = FakeSchema( + self.account_id, registry_name, schema_name, data_format, @@ -410,7 +414,11 @@ class GlueBackend(BaseBackend): self.num_schema_versions += 1 schema_version = FakeSchemaVersion( - registry_name, schema_name, schema_definition, version_number + self.account_id, + registry_name, + schema_name, + schema_definition, + version_number, ) self.registries[registry_name].schemas[schema_name].schema_versions[ schema_version.schema_version_id @@ -730,7 +738,7 @@ class FakeCrawler(BaseModel): self.version = 1 self.crawl_elapsed_time = 0 self.last_crawl_info = None - self.arn = f"arn:aws:glue:us-east-1:{get_account_id()}:crawler/{self.name}" + self.arn = f"arn:aws:glue:us-east-1:{backend.account_id}:crawler/{self.name}" self.backend = backend self.backend.tag_resource(self.arn, tags) @@ -853,7 +861,7 @@ class FakeJob: self.worker_type = worker_type self.created_on = datetime.utcnow() self.last_modified_on = datetime.utcnow() - self.arn = f"arn:aws:glue:us-east-1:{get_account_id()}:job/{self.name}" + self.arn = f"arn:aws:glue:us-east-1:{backend.account_id}:job/{self.name}" self.backend = backend self.backend.tag_resource(self.arn, tags) @@ -951,16 +959,14 @@ class FakeJobRun: class FakeRegistry(BaseModel): - def __init__(self, registry_name, description=None, tags=None): + def __init__(self, account_id, registry_name, description=None, tags=None): self.name = registry_name self.description = description self.tags = tags self.created_time = datetime.utcnow() self.updated_time = datetime.utcnow() self.status = "AVAILABLE" - self.registry_arn = ( - f"arn:aws:glue:us-east-1:{get_account_id()}:registry/{self.name}" - ) + self.registry_arn = f"arn:aws:glue:us-east-1:{account_id}:registry/{self.name}" self.schemas = OrderedDict() def as_dict(self): @@ -975,6 +981,7 @@ class 
FakeRegistry(BaseModel): class FakeSchema(BaseModel): def __init__( self, + account_id, registry_name, schema_name, data_format, @@ -985,10 +992,10 @@ class FakeSchema(BaseModel): ): self.registry_name = registry_name self.registry_arn = ( - f"arn:aws:glue:us-east-1:{get_account_id()}:registry/{self.registry_name}" + f"arn:aws:glue:us-east-1:{account_id}:registry/{self.registry_name}" ) self.schema_name = schema_name - self.schema_arn = f"arn:aws:glue:us-east-1:{get_account_id()}:schema/{self.registry_name}/{self.schema_name}" + self.schema_arn = f"arn:aws:glue:us-east-1:{account_id}:schema/{self.registry_name}/{self.schema_name}" self.description = description self.data_format = data_format self.compatibility = compatibility @@ -1032,10 +1039,12 @@ class FakeSchema(BaseModel): class FakeSchemaVersion(BaseModel): - def __init__(self, registry_name, schema_name, schema_definition, version_number): + def __init__( + self, account_id, registry_name, schema_name, schema_definition, version_number + ): self.registry_name = registry_name self.schema_name = schema_name - self.schema_arn = f"arn:aws:glue:us-east-1:{get_account_id()}:schema/{self.registry_name}/{self.schema_name}" + self.schema_arn = f"arn:aws:glue:us-east-1:{account_id}:schema/{self.registry_name}/{self.schema_name}" self.schema_definition = schema_definition self.schema_version_status = AVAILABLE_STATUS self.version_number = version_number @@ -1078,4 +1087,3 @@ class FakeSchemaVersion(BaseModel): glue_backends = BackendDict( GlueBackend, "glue", use_boto3_regions=False, additional_regions=["global"] ) -glue_backend = glue_backends["global"] diff --git a/moto/glue/responses.py b/moto/glue/responses.py index 6d96358b0..3fa5e8f66 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -6,13 +6,16 @@ from .exceptions import ( PartitionNotFoundException, TableNotFoundException, ) -from .models import glue_backend +from .models import glue_backends class GlueResponse(BaseResponse): + def 
__init__(self): + super().__init__(service_name="glue") + @property def glue_backend(self): - return glue_backend + return glue_backends[self.current_account]["global"] @property def parameters(self): diff --git a/moto/greengrass/models.py b/moto/greengrass/models.py index 788cc850d..f54892aa0 100644 --- a/moto/greengrass/models.py +++ b/moto/greengrass/models.py @@ -4,7 +4,7 @@ from collections import OrderedDict from datetime import datetime import re -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, iso_8601_datetime_with_milliseconds from .exceptions import ( GreengrassClientError, @@ -18,11 +18,11 @@ from .exceptions import ( class FakeCoreDefinition(BaseModel): - def __init__(self, region_name, name): + def __init__(self, account_id, region_name, name): self.region_name = region_name self.name = name self.id = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/cores/{self.id}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.id}" self.created_at_datetime = datetime.utcnow() self.latest_version = "" self.latest_version_arn = "" @@ -44,12 +44,12 @@ class FakeCoreDefinition(BaseModel): class FakeCoreDefinitionVersion(BaseModel): - def __init__(self, region_name, core_definition_id, definition): + def __init__(self, account_id, region_name, core_definition_id, definition): self.region_name = region_name self.core_definition_id = core_definition_id self.definition = definition self.version = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/cores/{self.core_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() def to_dict(self, 
include_detail=False): @@ -69,10 +69,10 @@ class FakeCoreDefinitionVersion(BaseModel): class FakeDeviceDefinition(BaseModel): - def __init__(self, region_name, name, initial_version): + def __init__(self, account_id, region_name, name, initial_version): self.region_name = region_name self.id = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/devices/{self.id}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.id}" self.created_at_datetime = datetime.utcnow() self.update_at_datetime = datetime.utcnow() self.latest_version = "" @@ -99,12 +99,12 @@ class FakeDeviceDefinition(BaseModel): class FakeDeviceDefinitionVersion(BaseModel): - def __init__(self, region_name, device_definition_id, devices): + def __init__(self, account_id, region_name, device_definition_id, devices): self.region_name = region_name self.device_definition_id = device_definition_id self.devices = devices self.version = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/devices/{self.device_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() def to_dict(self, include_detail=False): @@ -124,10 +124,10 @@ class FakeDeviceDefinitionVersion(BaseModel): class FakeResourceDefinition(BaseModel): - def __init__(self, region_name, name, initial_version): + def __init__(self, account_id, region_name, name, initial_version): self.region_name = region_name self.id = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/resources/{self.id}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.id}" self.created_at_datetime = datetime.utcnow() self.update_at_datetime = 
datetime.utcnow() self.latest_version = "" @@ -152,12 +152,12 @@ class FakeResourceDefinition(BaseModel): class FakeResourceDefinitionVersion(BaseModel): - def __init__(self, region_name, resource_definition_id, resources): + def __init__(self, account_id, region_name, resource_definition_id, resources): self.region_name = region_name self.resource_definition_id = resource_definition_id self.resources = resources self.version = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{region_name}:{get_account_id()}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{region_name}:{account_id}:greengrass/definition/resources/{self.resource_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() def to_dict(self): @@ -173,10 +173,10 @@ class FakeResourceDefinitionVersion(BaseModel): class FakeFunctionDefinition(BaseModel): - def __init__(self, region_name, name, initial_version): + def __init__(self, account_id, region_name, name, initial_version): self.region_name = region_name self.id = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/definition/functions/{self.id}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.id}" self.created_at_datetime = datetime.utcnow() self.update_at_datetime = datetime.utcnow() self.latest_version = "" @@ -203,13 +203,15 @@ class FakeFunctionDefinition(BaseModel): class FakeFunctionDefinitionVersion(BaseModel): - def __init__(self, region_name, function_definition_id, functions, default_config): + def __init__( + self, account_id, region_name, function_definition_id, functions, default_config + ): self.region_name = region_name self.function_definition_id = function_definition_id self.functions = functions self.default_config = default_config self.version = str(uuid.uuid4()) - self.arn = 
f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/functions/{self.function_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() def to_dict(self): @@ -225,10 +227,10 @@ class FakeFunctionDefinitionVersion(BaseModel): class FakeSubscriptionDefinition(BaseModel): - def __init__(self, region_name, name, initial_version): + def __init__(self, account_id, region_name, name, initial_version): self.region_name = region_name self.id = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/definition/subscriptions/{self.id}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.id}" self.created_at_datetime = datetime.utcnow() self.update_at_datetime = datetime.utcnow() self.latest_version = "" @@ -253,12 +255,14 @@ class FakeSubscriptionDefinition(BaseModel): class FakeSubscriptionDefinitionVersion(BaseModel): - def __init__(self, region_name, subscription_definition_id, subscriptions): + def __init__( + self, account_id, region_name, subscription_definition_id, subscriptions + ): self.region_name = region_name self.subscription_definition_id = subscription_definition_id self.subscriptions = subscriptions self.version = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/definition/subscriptions/{self.subscription_definition_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() def to_dict(self): @@ -274,11 +278,11 @@ class FakeSubscriptionDefinitionVersion(BaseModel): class FakeGroup(BaseModel): - def __init__(self, region_name, name): 
+ def __init__(self, account_id, region_name, name): self.region_name = region_name self.group_id = str(uuid.uuid4()) self.name = name - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/groups/{self.group_id}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/groups/{self.group_id}" self.created_at_datetime = datetime.utcnow() self.last_updated_datetime = datetime.utcnow() self.latest_version = "" @@ -304,6 +308,7 @@ class FakeGroup(BaseModel): class FakeGroupVersion(BaseModel): def __init__( self, + account_id, region_name, group_id, core_definition_version_arn, @@ -315,7 +320,7 @@ class FakeGroupVersion(BaseModel): self.region_name = region_name self.group_id = group_id self.version = str(uuid.uuid4()) - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:greengrass/groups/{self.group_id}/versions/{self.version}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:greengrass/groups/{self.group_id}/versions/{self.version}" self.created_at_datetime = datetime.utcnow() self.core_definition_version_arn = core_definition_version_arn self.device_definition_version_arn = device_definition_version_arn @@ -365,7 +370,7 @@ class FakeGroupVersion(BaseModel): class FakeDeployment(BaseModel): - def __init__(self, region_name, group_id, group_arn, deployment_type): + def __init__(self, account_id, region_name, group_id, group_arn, deployment_type): self.region_name = region_name self.id = str(uuid.uuid4()) self.group_id = group_id @@ -374,7 +379,7 @@ class FakeDeployment(BaseModel): self.update_at_datetime = datetime.utcnow() self.deployment_status = "InProgress" self.deployment_type = deployment_type - self.arn = f"arn:aws:greengrass:{self.region_name}:{get_account_id()}:/greengrass/groups/{self.group_id}/deployments/{self.id}" + self.arn = f"arn:aws:greengrass:{self.region_name}:{account_id}:/greengrass/groups/{self.group_id}/deployments/{self.id}" def to_dict(self, 
include_detail=False): obj = {"DeploymentId": self.id, "DeploymentArn": self.arn} @@ -437,7 +442,7 @@ class GreengrassBackend(BaseBackend): def create_core_definition(self, name, initial_version): - core_definition = FakeCoreDefinition(self.region_name, name) + core_definition = FakeCoreDefinition(self.account_id, self.region_name, name) self.core_definitions[core_definition.id] = core_definition self.create_core_definition_version( core_definition.id, initial_version["Cores"] @@ -473,7 +478,7 @@ class GreengrassBackend(BaseBackend): definition = {"Cores": cores} core_def_ver = FakeCoreDefinitionVersion( - self.region_name, core_definition_id, definition + self.account_id, self.region_name, core_definition_id, definition ) core_def_vers = self.core_definition_versions.get( core_def_ver.core_definition_id, {} @@ -512,7 +517,9 @@ class GreengrassBackend(BaseBackend): ] def create_device_definition(self, name, initial_version): - device_def = FakeDeviceDefinition(self.region_name, name, initial_version) + device_def = FakeDeviceDefinition( + self.account_id, self.region_name, name, initial_version + ) self.device_definitions[device_def.id] = device_def init_ver = device_def.initial_version init_device_def = init_ver.get("Devices", {}) @@ -529,7 +536,7 @@ class GreengrassBackend(BaseBackend): raise IdNotFoundException("That devices definition does not exist.") device_ver = FakeDeviceDefinitionVersion( - self.region_name, device_definition_id, devices + self.account_id, self.region_name, device_definition_id, devices ) device_vers = self.device_definition_versions.get( device_ver.device_definition_id, {} @@ -597,7 +604,9 @@ class GreengrassBackend(BaseBackend): resources = initial_version.get("Resources", []) GreengrassBackend._validate_resources(resources) - resource_def = FakeResourceDefinition(self.region_name, name, initial_version) + resource_def = FakeResourceDefinition( + self.account_id, self.region_name, name, initial_version + ) 
self.resource_definitions[resource_def.id] = resource_def init_ver = resource_def.initial_version resources = init_ver.get("Resources", {}) @@ -636,7 +645,7 @@ class GreengrassBackend(BaseBackend): GreengrassBackend._validate_resources(resources) resource_def_ver = FakeResourceDefinitionVersion( - self.region_name, resource_definition_id, resources + self.account_id, self.region_name, resource_definition_id, resources ) resources_ver = self.resource_definition_versions.get( @@ -711,7 +720,9 @@ class GreengrassBackend(BaseBackend): ) def create_function_definition(self, name, initial_version): - func_def = FakeFunctionDefinition(self.region_name, name, initial_version) + func_def = FakeFunctionDefinition( + self.account_id, self.region_name, name, initial_version + ) self.function_definitions[func_def.id] = func_def init_ver = func_def.initial_version init_func_def = init_ver.get("Functions", {}) @@ -753,7 +764,11 @@ class GreengrassBackend(BaseBackend): raise IdNotFoundException("That lambdas does not exist.") func_ver = FakeFunctionDefinitionVersion( - self.region_name, function_definition_id, functions, default_config + self.account_id, + self.region_name, + function_definition_id, + functions, + default_config, ) func_vers = self.function_definition_versions.get( func_ver.function_definition_id, {} @@ -854,7 +869,9 @@ class GreengrassBackend(BaseBackend): initial_version["Subscriptions"] ) - sub_def = FakeSubscriptionDefinition(self.region_name, name, initial_version) + sub_def = FakeSubscriptionDefinition( + self.account_id, self.region_name, name, initial_version + ) self.subscription_definitions[sub_def.id] = sub_def init_ver = sub_def.initial_version subscriptions = init_ver.get("Subscriptions", {}) @@ -903,7 +920,7 @@ class GreengrassBackend(BaseBackend): raise IdNotFoundException("That subscriptions does not exist.") sub_def_ver = FakeSubscriptionDefinitionVersion( - self.region_name, subscription_definition_id, subscriptions + self.account_id, 
self.region_name, subscription_definition_id, subscriptions ) sub_vers = self.subscription_definition_versions.get( @@ -939,7 +956,7 @@ class GreengrassBackend(BaseBackend): ] def create_group(self, name, initial_version): - group = FakeGroup(self.region_name, name) + group = FakeGroup(self.account_id, self.region_name, name) self.groups[group.group_id] = group definitions = initial_version or {} @@ -1013,6 +1030,7 @@ class GreengrassBackend(BaseBackend): ) group_ver = FakeGroupVersion( + self.account_id, self.region_name, group_id=group_id, core_definition_version_arn=core_definition_version_arn, @@ -1172,7 +1190,11 @@ class GreengrassBackend(BaseBackend): raise MissingCoreException(json.dumps(err)) group_version_arn = self.group_versions[group_id][group_version_id].arn deployment = FakeDeployment( - self.region_name, group_id, group_version_arn, deployment_type + self.account_id, + self.region_name, + group_id, + group_version_arn, + deployment_type, ) self.deployments[deployment.id] = deployment return deployment @@ -1221,7 +1243,7 @@ class GreengrassBackend(BaseBackend): group = self.groups[group_id] deployment = FakeDeployment( - self.region_name, group_id, group.arn, deployment_type + self.account_id, self.region_name, group_id, group.arn, deployment_type ) self.deployments[deployment.id] = deployment return deployment diff --git a/moto/greengrass/responses.py b/moto/greengrass/responses.py index 6c2c1d808..d36ad8516 100644 --- a/moto/greengrass/responses.py +++ b/moto/greengrass/responses.py @@ -7,11 +7,12 @@ from .models import greengrass_backends class GreengrassResponse(BaseResponse): - SERVICE_NAME = "greengrass" + def __init__(self): + super().__init__(service_name="greengrass") @property def greengrass_backend(self): - return greengrass_backends[self.region] + return greengrass_backends[self.current_account][self.region] def core_definitions(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git 
a/moto/guardduty/models.py b/moto/guardduty/models.py index ad12b2098..fcdf53dd3 100644 --- a/moto/guardduty/models.py +++ b/moto/guardduty/models.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, get_random_hex from datetime import datetime @@ -21,6 +21,7 @@ class GuardDutyBackend(BaseBackend): finding_publishing_frequency = "SIX_HOURS" detector = Detector( + account_id=self.account_id, created_at=datetime.now(), finding_publish_freq=finding_publishing_frequency, enabled=enable, @@ -121,6 +122,7 @@ class Filter(BaseModel): class Detector(BaseModel): def __init__( self, + account_id, created_at, finding_publish_freq, enabled, @@ -130,7 +132,7 @@ class Detector(BaseModel): self.id = get_random_hex(length=32) self.created_at = created_at self.finding_publish_freq = finding_publish_freq - self.service_role = f"arn:aws:iam::{get_account_id()}:role/aws-service-role/guardduty.amazonaws.com/AWSServiceRoleForAmazonGuardDuty" + self.service_role = f"arn:aws:iam::{account_id}:role/aws-service-role/guardduty.amazonaws.com/AWSServiceRoleForAmazonGuardDuty" self.enabled = enabled self.updated_at = created_at self.datasources = datasources or {} diff --git a/moto/guardduty/responses.py b/moto/guardduty/responses.py index 104ed745d..50f39e5de 100644 --- a/moto/guardduty/responses.py +++ b/moto/guardduty/responses.py @@ -6,11 +6,12 @@ from urllib.parse import unquote class GuardDutyResponse(BaseResponse): - SERVICE_NAME = "guardduty" + def __init__(self): + super().__init__(service_name="guardduty") @property def guardduty_backend(self): - return guardduty_backends[self.region] + return guardduty_backends[self.current_account][self.region] def filter(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/iam/access_control.py b/moto/iam/access_control.py index 5c2d24bdf..3f7f89026 
100644 --- a/moto/iam/access_control.py +++ b/moto/iam/access_control.py @@ -22,7 +22,6 @@ from botocore.auth import SigV4Auth, S3SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials -from moto.core import get_account_id from moto.core.exceptions import ( SignatureDoesNotMatchError, AccessDeniedError, @@ -45,20 +44,22 @@ from .models import iam_backends, Policy log = logging.getLogger(__name__) -def create_access_key(access_key_id, headers): +def create_access_key(account_id, access_key_id, headers): if access_key_id.startswith("AKIA") or "X-Amz-Security-Token" not in headers: - return IAMUserAccessKey(access_key_id, headers) + return IAMUserAccessKey(account_id, access_key_id, headers) else: - return AssumedRoleAccessKey(access_key_id, headers) + return AssumedRoleAccessKey(account_id, access_key_id, headers) -class IAMUserAccessKey(object): +class IAMUserAccessKey: @property def backend(self): - return iam_backends["global"] + return iam_backends[self.account_id]["global"] - def __init__(self, access_key_id, headers): + def __init__(self, account_id, access_key_id, headers): + self.account_id = account_id iam_users = self.backend.list_users("/", None, None) + for iam_user in iam_users: for access_key in iam_user.access_keys: if access_key.access_key_id == access_key_id: @@ -73,7 +74,7 @@ class IAMUserAccessKey(object): @property def arn(self): return "arn:aws:iam::{account_id}:user/{iam_user_name}".format( - account_id=get_account_id(), iam_user_name=self._owner_user_name + account_id=self.account_id, iam_user_name=self._owner_user_name ) def create_credentials(self): @@ -116,10 +117,11 @@ class IAMUserAccessKey(object): class AssumedRoleAccessKey(object): @property def backend(self): - return iam_backends["global"] + return iam_backends[self.account_id]["global"] - def __init__(self, access_key_id, headers): - for assumed_role in sts_backends["global"].assumed_roles: + def __init__(self, account_id, access_key_id, 
headers): + self.account_id = account_id + for assumed_role in sts_backends[account_id]["global"].assumed_roles: if assumed_role.access_key_id == access_key_id: self._access_key_id = access_key_id self._secret_access_key = assumed_role.secret_access_key @@ -135,7 +137,7 @@ class AssumedRoleAccessKey(object): def arn(self): return ( "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=get_account_id(), + account_id=self.account_id, role_name=self._owner_role_name, session_name=self._session_name, ) @@ -171,7 +173,7 @@ class CreateAccessKeyFailure(Exception): class IAMRequestBase(object, metaclass=ABCMeta): - def __init__(self, method, path, data, headers): + def __init__(self, account_id, method, path, data, headers): log.debug( "Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( class_name=self.__class__.__name__, @@ -181,6 +183,7 @@ class IAMRequestBase(object, metaclass=ABCMeta): headers=headers, ) ) + self.account_id = account_id self._method = method self._path = path self._data = data @@ -202,7 +205,9 @@ class IAMRequestBase(object, metaclass=ABCMeta): ) try: self._access_key = create_access_key( - access_key_id=credential_data[0], headers=headers + account_id=self.account_id, + access_key_id=credential_data[0], + headers=headers, ) except CreateAccessKeyFailure as e: self._raise_invalid_access_key(e.reason) diff --git a/moto/iam/config.py b/moto/iam/config.py index 357907db9..b3386eba5 100644 --- a/moto/iam/config.py +++ b/moto/iam/config.py @@ -8,6 +8,7 @@ from moto.iam import iam_backends class RoleConfigQuery(ConfigQueryModel): def list_config_service_resources( self, + account_id, resource_ids, resource_name, limit, @@ -22,7 +23,7 @@ class RoleConfigQuery(ConfigQueryModel): # Stored in moto backend with the AWS-assigned random string like "AROA0BSVNSZKXVHS00SBJ" # Grab roles from backend; need the full values since names and id's are different - role_list = 
list(self.backends["global"].roles.values()) + role_list = list(self.backends[account_id]["global"].roles.values()) if not role_list: return [], None @@ -126,10 +127,15 @@ class RoleConfigQuery(ConfigQueryModel): ) def get_config_resource( - self, resource_id, resource_name=None, backend_region=None, resource_region=None + self, + account_id, + resource_id, + resource_name=None, + backend_region=None, + resource_region=None, ): - role = self.backends["global"].roles.get(resource_id, {}) + role = self.backends[account_id]["global"].roles.get(resource_id, {}) if not role: return @@ -154,6 +160,7 @@ class RoleConfigQuery(ConfigQueryModel): class PolicyConfigQuery(ConfigQueryModel): def list_config_service_resources( self, + account_id, resource_ids, resource_name, limit, @@ -167,7 +174,9 @@ class PolicyConfigQuery(ConfigQueryModel): # The resource name is a user-assigned string like "my-development-policy" # Stored in moto backend with the arn like "arn:aws:iam::123456789012:policy/my-development-policy" - policy_list = list(self.backends["global"].managed_policies.values()) + policy_list = list( + self.backends[account_id]["global"].managed_policies.values() + ) # We don't want to include AWS Managed Policies. 
This technically needs to # respect the configuration recorder's 'includeGlobalResourceTypes' setting, @@ -286,13 +295,18 @@ class PolicyConfigQuery(ConfigQueryModel): ) def get_config_resource( - self, resource_id, resource_name=None, backend_region=None, resource_region=None + self, + account_id, + resource_id, + resource_name=None, + backend_region=None, + resource_region=None, ): # policies are listed in the backend as arns, but we have to accept the PolicyID as the resource_id # we'll make a really crude search for it policy = None - for arn in self.backends["global"].managed_policies.keys(): - policy_candidate = self.backends["global"].managed_policies[arn] + for arn in self.backends[account_id]["global"].managed_policies.keys(): + policy_candidate = self.backends[account_id]["global"].managed_policies[arn] if policy_candidate.id == resource_id: policy = policy_candidate break diff --git a/moto/iam/models.py b/moto/iam/models.py index 43f8d1cbb..9e56aaf07 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -13,9 +13,10 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from jinja2 import Template +from typing import Mapping from urllib import parse from moto.core.exceptions import RESTError -from moto.core import BaseBackend, BaseModel, get_account_id, CloudFormationModel +from moto.core import DEFAULT_ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import ( iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds, @@ -62,6 +63,24 @@ SERVICE_NAME_CONVERSION = { } +def get_account_id_from(access_key): + for account_id, account in iam_backends.items(): + if access_key in account["global"].access_keys: + return account_id + return DEFAULT_ACCOUNT_ID + + +def mark_account_as_visited(account_id, access_key, service, region): + account = iam_backends[account_id] + if access_key in account["global"].access_keys: + account["global"].access_keys[access_key].last_used = 
AccessKeyLastUsed( + timestamp=datetime.utcnow(), service=service, region=region + ) + else: + # User provided access credentials unknown to us + pass + + LIMIT_KEYS_PER_USER = 2 @@ -80,10 +99,8 @@ class MFADevice(object): class VirtualMfaDevice(object): - def __init__(self, device_name): - self.serial_number = "arn:aws:iam::{0}:mfa{1}".format( - get_account_id(), device_name - ) + def __init__(self, account_id, device_name): + self.serial_number = f"arn:aws:iam::{account_id}:mfa{device_name}" random_base32_string = "".join( random.choice(string.ascii_uppercase + "234567") for _ in range(64) @@ -114,6 +131,7 @@ class Policy(CloudFormationModel): def __init__( self, name, + account_id, default_version_id=None, description=None, document=None, @@ -123,7 +141,7 @@ class Policy(CloudFormationModel): tags=None, ): self.name = name - + self.account_id = account_id self.attachment_count = 0 self.description = description or "" self.id = random_policy_id() @@ -166,20 +184,24 @@ class Policy(CloudFormationModel): class SAMLProvider(BaseModel): - def __init__(self, name, saml_metadata_document=None): + def __init__(self, account_id, name, saml_metadata_document=None): + self.account_id = account_id self.name = name self.saml_metadata_document = saml_metadata_document @property def arn(self): - return "arn:aws:iam::{0}:saml-provider/{1}".format(get_account_id(), self.name) + return f"arn:aws:iam::{self.account_id}:saml-provider/{self.name}" class OpenIDConnectProvider(BaseModel): - def __init__(self, url, thumbprint_list, client_id_list=None, tags=None): + def __init__( + self, account_id, url, thumbprint_list, client_id_list=None, tags=None + ): self._errors = [] self._validate(url, thumbprint_list, client_id_list) + self.account_id = account_id parsed_url = parse.urlparse(url) self.url = parsed_url.netloc + parsed_url.path self.thumbprint_list = thumbprint_list @@ -189,7 +211,7 @@ class OpenIDConnectProvider(BaseModel): @property def arn(self): - return 
"arn:aws:iam::{0}:oidc-provider/{1}".format(get_account_id(), self.url) + return f"arn:aws:iam::{self.account_id}:oidc-provider/{self.url}" @property def created_iso_8601(self): @@ -282,6 +304,10 @@ class PolicyVersion(object): class ManagedPolicy(Policy, CloudFormationModel): """Managed policy.""" + @property + def backend(self): + return iam_backends[self.account_id]["global"] + is_attachable = True def attach_to(self, obj): @@ -295,7 +321,7 @@ class ManagedPolicy(Policy, CloudFormationModel): @property def arn(self): return "arn:aws:iam::{0}:policy{1}{2}".format( - get_account_id(), self.path, self.name + self.account_id, self.path, self.name ) def to_config_dict(self): @@ -306,7 +332,7 @@ class ManagedPolicy(Policy, CloudFormationModel): "configurationStateId": str( int(time.mktime(self.create_date.timetuple())) ), # PY2 and 3 compatible - "arn": "arn:aws:iam::{}:policy/{}".format(get_account_id(), self.name), + "arn": "arn:aws:iam::{}:policy/{}".format(self.account_id, self.name), "resourceType": "AWS::IAM::Policy", "resourceId": self.id, "resourceName": self.name, @@ -317,7 +343,7 @@ class ManagedPolicy(Policy, CloudFormationModel): "configuration": { "policyName": self.name, "policyId": self.id, - "arn": "arn:aws:iam::{}:policy/{}".format(get_account_id(), self.name), + "arn": "arn:aws:iam::{}:policy/{}".format(self.account_id, self.name), "path": self.path, "defaultVersionId": self.default_version_id, "attachmentCount": self.attachment_count, @@ -357,7 +383,7 @@ class ManagedPolicy(Policy, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json.get("Properties", {}) policy_document = json.dumps(properties.get("PolicyDocument")) @@ -369,7 +395,7 @@ class ManagedPolicy(Policy, CloudFormationModel): role_names = properties.get("Roles", []) tags = properties.get("Tags", 
{}) - policy = iam_backends["global"].create_policy( + policy = iam_backends[account_id]["global"].create_policy( description=description, path=path, policy_document=policy_document, @@ -377,15 +403,15 @@ class ManagedPolicy(Policy, CloudFormationModel): tags=tags, ) for group_name in group_names: - iam_backends["global"].attach_group_policy( + iam_backends[account_id]["global"].attach_group_policy( group_name=group_name, policy_arn=policy.arn ) for user_name in user_names: - iam_backends["global"].attach_user_policy( + iam_backends[account_id]["global"].attach_user_policy( user_name=user_name, policy_arn=policy.arn ) for role_name in role_names: - iam_backends["global"].attach_role_policy( + iam_backends[account_id]["global"].attach_role_policy( role_name=role_name, policy_arn=policy.arn ) return policy @@ -399,9 +425,10 @@ class AWSManagedPolicy(ManagedPolicy): """AWS-managed policy.""" @classmethod - def from_data(cls, name, data): + def from_data(cls, name, account_id, data): return cls( name, + account_id=account_id, default_version_id=data.get("DefaultVersionId"), path=data.get("Path"), document=json.dumps(data.get("Document")), @@ -418,16 +445,6 @@ class AWSManagedPolicy(ManagedPolicy): return "arn:aws:iam::aws:policy{0}{1}".format(self.path, self.name) -# AWS defines some of its own managed policies and we periodically -# import them via `make aws_managed_policies` -# FIXME: Takes about 40ms at import time -aws_managed_policies_data_parsed = json.loads(aws_managed_policies_data) -aws_managed_policies = [ - AWSManagedPolicy.from_data(name, d) - for name, d in aws_managed_policies_data_parsed.items() -] - - class InlinePolicy(CloudFormationModel): # Represents an Inline Policy created by CloudFormation def __init__( @@ -468,7 +485,7 @@ class InlinePolicy(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, 
**kwargs ): properties = cloudformation_json.get("Properties", {}) policy_document = properties.get("PolicyDocument") @@ -477,7 +494,7 @@ class InlinePolicy(CloudFormationModel): role_names = properties.get("Roles") group_names = properties.get("Groups") - return iam_backends["global"].create_inline_policy( + return iam_backends[account_id]["global"].create_inline_policy( resource_name, policy_name, policy_document, @@ -488,7 +505,12 @@ class InlinePolicy(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] @@ -497,11 +519,14 @@ class InlinePolicy(CloudFormationModel): if resource_name_property not in properties: properties[resource_name_property] = new_resource_name new_resource = cls.create_from_cloudformation_json( - properties[resource_name_property], cloudformation_json, region_name + properties[resource_name_property], + cloudformation_json, + account_id, + region_name, ) properties[resource_name_property] = original_resource.name cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return new_resource @@ -513,7 +538,7 @@ class InlinePolicy(CloudFormationModel): role_names = properties.get("Roles") group_names = properties.get("Groups") - return iam_backends["global"].update_inline_policy( + return iam_backends[account_id]["global"].update_inline_policy( original_resource.name, policy_name, policy_document, @@ -524,9 +549,9 @@ class InlinePolicy(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - iam_backends["global"].delete_inline_policy(resource_name) + 
iam_backends[account_id]["global"].delete_inline_policy(resource_name) @staticmethod def is_replacement_update(properties): @@ -574,6 +599,7 @@ class InlinePolicy(CloudFormationModel): class Role(CloudFormationModel): def __init__( self, + account_id, role_id, name, assume_role_policy_document, @@ -584,6 +610,7 @@ class Role(CloudFormationModel): max_session_duration, linked_service=None, ): + self.account_id = account_id self.id = role_id self.name = name self.assume_role_policy_document = assume_role_policy_document @@ -619,12 +646,13 @@ class Role(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] role_name = properties.get("RoleName", resource_name) - role = iam_backends["global"].create_role( + iam_backend = iam_backends[account_id]["global"] + role = iam_backend.create_role( role_name=role_name, assume_role_policy_document=properties["AssumeRolePolicyDocument"], path=properties.get("Path", "/"), @@ -644,24 +672,23 @@ class Role(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - for profile in iam_backends["global"].instance_profiles.values(): + backend = iam_backends[account_id]["global"] + for profile in backend.instance_profiles.values(): profile.delete_role(role_name=resource_name) - for role in iam_backends["global"].roles.values(): + for role in backend.roles.values(): if role.name == resource_name: for arn in role.policies.keys(): role.delete_policy(arn) - iam_backends["global"].delete_role(resource_name) + backend.delete_role(resource_name) @property def arn(self): if self._linked_service: - return f"arn:aws:iam::{get_account_id()}:role/aws-service-role/{self._linked_service}/{self.name}" - return 
"arn:aws:iam::{0}:role{1}{2}".format( - get_account_id(), self.path, self.name - ) + return f"arn:aws:iam::{self.account_id}:role/aws-service-role/{self._linked_service}/{self.name}" + return f"arn:aws:iam::{self.account_id}:role{self.path}{self.name}" def to_config_dict(self): _managed_policies = [] @@ -669,7 +696,9 @@ class Role(CloudFormationModel): _managed_policies.append( { "policyArn": key, - "policyName": iam_backends["global"].managed_policies[key].name, + "policyName": iam_backends[self.account_id]["global"] + .managed_policies[key] + .name, } ) @@ -680,7 +709,9 @@ class Role(CloudFormationModel): ) _instance_profiles = [] - for key, instance_profile in iam_backends["global"].instance_profiles.items(): + for key, instance_profile in iam_backends[self.account_id][ + "global" + ].instance_profiles.items(): for _ in instance_profile.roles: _instance_profiles.append(instance_profile.to_embedded_config_dict()) break @@ -692,7 +723,7 @@ class Role(CloudFormationModel): "configurationStateId": str( int(time.mktime(self.create_date.timetuple())) ), # PY2 and 3 compatible - "arn": "arn:aws:iam::{}:role/{}".format(get_account_id(), self.name), + "arn": f"arn:aws:iam::{self.account_id}:role/{self.name}", "resourceType": "AWS::IAM::Role", "resourceId": self.name, "resourceName": self.name, @@ -706,7 +737,7 @@ class Role(CloudFormationModel): "path": self.path, "roleName": self.name, "roleId": self.id, - "arn": "arn:aws:iam::{}:role/{}".format(get_account_id(), self.name), + "arn": f"arn:aws:iam::{self.account_id}:role/{self.name}", "assumeRolePolicyDocument": parse.quote( self.assume_role_policy_document ) @@ -809,8 +840,9 @@ class Role(CloudFormationModel): class InstanceProfile(CloudFormationModel): - def __init__(self, instance_profile_id, name, path, roles, tags=None): + def __init__(self, account_id, instance_profile_id, name, path, roles, tags=None): self.id = instance_profile_id + self.account_id = account_id self.name = name self.path = path or "/" self.roles 
= roles if roles else [] @@ -832,12 +864,12 @@ class InstanceProfile(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] role_names = properties["Roles"] - return iam_backends["global"].create_instance_profile( + return iam_backends[account_id]["global"].create_instance_profile( name=resource_name, path=properties.get("Path", "/"), role_names=role_names, @@ -845,18 +877,16 @@ class InstanceProfile(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - iam_backends["global"].delete_instance_profile(resource_name) + iam_backends[account_id]["global"].delete_instance_profile(resource_name) def delete_role(self, role_name): self.roles = [role for role in self.roles if role.name != role_name] @property def arn(self): - return "arn:aws:iam::{0}:instance-profile{1}{2}".format( - get_account_id(), self.path, self.name - ) + return f"arn:aws:iam::{self.account_id}:instance-profile{self.path}{self.name}" @property def physical_resource_id(self): @@ -883,9 +913,7 @@ "path": role.path, "roleName": role.name, "roleId": role.id, - "arn": "arn:aws:iam::{}:role/{}".format( - get_account_id(), role.name - ), + "arn": f"arn:aws:iam::{self.account_id}:role/{role.name}", "createDate": str(role.create_date), "assumeRolePolicyDocument": parse.quote( role.assume_role_policy_document @@ -907,16 +935,17 @@ "path": self.path, "instanceProfileName": self.name, "instanceProfileId": self.id, - "arn": "arn:aws:iam::{}:instance-profile/{}".format( - get_account_id(), self.name - ), + "arn": f"arn:aws:iam::{self.account_id}:instance-profile/{self.name}",
"createDate": str(self.create_date), "roles": roles, } class Certificate(BaseModel): - def __init__(self, cert_name, cert_body, private_key, cert_chain=None, path=None): + def __init__( + self, account_id, cert_name, cert_body, private_key, cert_chain=None, path=None + ): + self.account_id = account_id self.cert_name = cert_name if cert_body: cert_body = cert_body.rstrip() @@ -931,9 +960,7 @@ class Certificate(BaseModel): @property def arn(self): - return "arn:aws:iam::{0}:server-certificate{1}{2}".format( - get_account_id(), self.path, self.cert_name - ) + return f"arn:aws:iam::{self.account_id}:server-certificate{self.path}{self.cert_name}" class SigningCertificate(BaseModel): @@ -949,23 +976,30 @@ class SigningCertificate(BaseModel): return iso_8601_datetime_without_milliseconds(self.upload_date) +class AccessKeyLastUsed: + def __init__(self, timestamp, service, region): + self._timestamp = timestamp + self.service = service + self.region = region + + @property + def timestamp(self): + return iso_8601_datetime_without_milliseconds(self._timestamp) + + class AccessKey(CloudFormationModel): - def __init__(self, user_name, status="Active"): + def __init__(self, user_name, prefix, status="Active"): self.user_name = user_name - self.access_key_id = "AKIA" + random_access_key() + self.access_key_id = prefix + random_access_key() self.secret_access_key = random_alphanumeric(40) self.status = status self.create_date = datetime.utcnow() - self.last_used = None + self.last_used: AccessKeyLastUsed = None @property def created_iso_8601(self): return iso_8601_datetime_without_milliseconds(self.create_date) - @property - def last_used_iso_8601(self): - return iso_8601_datetime_without_milliseconds(self.last_used) - @classmethod def has_cfn_attr(cls, attr): return attr in ["SecretAccessKey"] @@ -987,41 +1021,51 @@ class AccessKey(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, 
resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json.get("Properties", {}) user_name = properties.get("UserName") status = properties.get("Status", "Active") - return iam_backends["global"].create_access_key(user_name, status=status) + return iam_backends[account_id]["global"].create_access_key( + user_name, status=status + ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] if cls.is_replacement_update(properties): new_resource = cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) cls.delete_from_cloudformation_json( - original_resource.physical_resource_id, cloudformation_json, region_name + original_resource.physical_resource_id, + cloudformation_json, + account_id, + region_name, ) return new_resource else: # No Interruption properties = cloudformation_json.get("Properties", {}) status = properties.get("Status") - return iam_backends["global"].update_access_key( + return iam_backends[account_id]["global"].update_access_key( original_resource.user_name, original_resource.access_key_id, status ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - iam_backends["global"].delete_access_key_by_name(resource_name) + iam_backends[account_id]["global"].delete_access_key_by_name(resource_name) @staticmethod def is_replacement_update(properties): @@ -1053,7 +1097,8 @@ class SshPublicKey(BaseModel): class Group(BaseModel): - def __init__(self, name, path="/"): + def __init__(self, account_id, name, path="/"): + self.account_id = account_id self.name = name self.id = 
random_resource_id() self.path = path @@ -1081,12 +1126,10 @@ class Group(BaseModel): @property def arn(self): if self.path == "/": - return "arn:aws:iam::{0}:group/{1}".format(get_account_id(), self.name) + return f"arn:aws:iam::{self.account_id}:group/{self.name}" else: - return "arn:aws:iam::{0}:group/{1}/{2}".format( - get_account_id(), self.path, self.name - ) + return f"arn:aws:iam::{self.account_id}:group/{self.path}/{self.name}" def get_policy(self, policy_name): try: @@ -1114,7 +1157,8 @@ class Group(BaseModel): class User(CloudFormationModel): - def __init__(self, name, path=None): + def __init__(self, account_id, name, path=None): + self.account_id = account_id self.name = name self.id = random_resource_id() self.path = path if path else "/" @@ -1122,7 +1166,7 @@ class User(CloudFormationModel): self.mfa_devices = {} self.policies = {} self.managed_policies = {} - self.access_keys = [] + self.access_keys: Mapping[str, AccessKey] = [] self.ssh_public_keys = [] self.password = None self.password_last_used = None @@ -1131,9 +1175,7 @@ class User(CloudFormationModel): @property def arn(self): - return "arn:aws:iam::{0}:user{1}{2}".format( - get_account_id(), self.path, self.name - ) + return f"arn:aws:iam::{self.account_id}:user{self.path}{self.name}" @property def created_iso_8601(self): @@ -1164,8 +1206,8 @@ class User(CloudFormationModel): del self.policies[policy_name] - def create_access_key(self, status="Active"): - access_key = AccessKey(self.name, status) + def create_access_key(self, prefix, status="Active") -> AccessKey: + access_key = AccessKey(self.name, prefix=prefix, status=status) self.access_keys.append(access_key) return access_key @@ -1328,16 +1370,21 @@ class User(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json.get("Properties", {}) path = 
properties.get("Path") - user, _ = iam_backends["global"].create_user(resource_name, path) + user, _ = iam_backends[account_id]["global"].create_user(resource_name, path) return user @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] @@ -1346,11 +1393,14 @@ class User(CloudFormationModel): if resource_name_property not in properties: properties[resource_name_property] = new_resource_name new_resource = cls.create_from_cloudformation_json( - properties[resource_name_property], cloudformation_json, region_name + properties[resource_name_property], + cloudformation_json, + account_id, + region_name, ) properties[resource_name_property] = original_resource.name cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return new_resource @@ -1361,9 +1411,9 @@ class User(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - iam_backends["global"].delete_user(resource_name) + iam_backends[account_id]["global"].delete_user(resource_name) @staticmethod def is_replacement_update(properties): @@ -1596,13 +1646,15 @@ def filter_items_with_path_prefix(path_prefix, items): class IAMBackend(BaseBackend): - def __init__(self, region_name, account_id=None): + def __init__(self, region_name, account_id=None, aws_policies=None): + super().__init__(region_name=region_name, account_id=account_id) self.instance_profiles = {} self.roles = {} self.certificates = {} self.groups = {} self.users = {} self.credential_report = None + self.aws_managed_policies = aws_policies or self._init_aws_policies() 
self.managed_policies = self._init_managed_policies() self.account_aliases = [] self.saml_providers = {} @@ -1615,10 +1667,27 @@ class IAMBackend(BaseBackend): self.access_keys = {} self.tagger = TaggingService() - super().__init__(region_name=region_name, account_id=account_id) + + def _init_aws_policies(self): + # AWS defines some of its own managed policies and we periodically + # import them via `make aws_managed_policies` + aws_managed_policies_data_parsed = json.loads(aws_managed_policies_data) + return [ + AWSManagedPolicy.from_data(name, self.account_id, d) + for name, d in aws_managed_policies_data_parsed.items() + ] def _init_managed_policies(self): - return dict((p.arn, p) for p in aws_managed_policies) + return dict((p.arn, p) for p in self.aws_managed_policies) + + def reset(self): + region_name = self.region_name + account_id = self.account_id + # Do not reset these policies, as they take a long time to load + aws_policies = self.aws_managed_policies + self._reset_model_refs() + self.__dict__ = {} + self.__init__(region_name, account_id, aws_policies) def attach_role_policy(self, policy_arn, role_name): arns = dict((p.arn, p) for p in self.managed_policies.values()) @@ -1702,6 +1771,7 @@ class IAMBackend(BaseBackend): clean_tags = self._tag_verification(tags) policy = ManagedPolicy( policy_name, + account_id=self.account_id, description=description, document=policy_document, path=path, @@ -1817,6 +1887,7 @@ class IAMBackend(BaseBackend): clean_tags = self._tag_verification(tags) role = Role( + self.account_id, role_id, role_name, assume_role_policy_document, @@ -2078,8 +2149,10 @@ class IAMBackend(BaseBackend): instance_profile_id = random_resource_id() - roles = [iam_backends["global"].get_role(role_name) for role_name in role_names] - instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags) + roles = [self.get_role(role_name) for role_name in role_names] + instance_profile = InstanceProfile( + self.account_id, 
instance_profile_id, name, path, roles, tags + ) self.instance_profiles[name] = instance_profile return instance_profile @@ -2141,7 +2214,9 @@ class IAMBackend(BaseBackend): self, cert_name, cert_body, private_key, cert_chain=None, path=None ): certificate_id = random_resource_id() - cert = Certificate(cert_name, cert_body, private_key, cert_chain, path) + cert = Certificate( + self.account_id, cert_name, cert_body, private_key, cert_chain, path + ) self.certificates[certificate_id] = cert return cert @@ -2178,7 +2253,7 @@ class IAMBackend(BaseBackend): if group_name in self.groups: raise IAMConflictException("Group {0} already exists".format(group_name)) - group = Group(group_name, path) + group = Group(self.account_id, group_name, path) self.groups[group_name] = group return group @@ -2262,12 +2337,12 @@ class IAMBackend(BaseBackend): "EntityAlreadyExists", "User {0} already exists".format(user_name) ) - user = User(user_name, path) + user = User(self.account_id, user_name, path) self.tagger.tag_resource(user.arn, tags or []) self.users[user_name] = user return user, self.tagger.list_tags_for_resource(user.arn) - def get_user(self, name): + def get_user(self, name) -> User: user = self.users.get(name) if not user: @@ -2448,15 +2523,21 @@ class IAMBackend(BaseBackend): policy = self.get_policy(policy_arn) del self.managed_policies[policy.arn] - def create_access_key(self, user_name=None, status="Active"): - user = self.get_user(user_name) + def create_access_key(self, user_name=None, prefix="AKIA", status="Active"): keys = self.list_access_keys(user_name) if len(keys) >= LIMIT_KEYS_PER_USER: raise IAMLimitExceededException( f"Cannot exceed quota for AccessKeysPerUser: {LIMIT_KEYS_PER_USER}" ) + user = self.get_user(user_name) + key = user.create_access_key(prefix=prefix, status=status) + self.access_keys[key.physical_resource_id] = key + return key + + def create_temp_access_key(self): + # Temporary access keys such as the ones returned by STS when assuming a role 
temporarily + key = AccessKey(user_name=None, prefix="ASIA") - key = user.create_access_key(status) self.access_keys[key.physical_resource_id] = key return key @@ -2468,7 +2549,7 @@ class IAMBackend(BaseBackend): access_keys_list = self.get_all_access_keys_for_all_users() for key in access_keys_list: if key.access_key_id == access_key_id: - return {"user_name": key.user_name, "last_used": key.last_used_iso_8601} + return {"user_name": key.user_name, "last_used": key.last_used} raise IAMNotFoundException( f"The Access Key with id {access_key_id} cannot be found" @@ -2476,8 +2557,9 @@ class IAMBackend(BaseBackend): def get_all_access_keys_for_all_users(self): access_keys_list = [] - for user_name in self.users: - access_keys_list += self.list_access_keys(user_name) + for account in iam_backends.values(): + for user_name in account["global"].users: + access_keys_list += account["global"].list_access_keys(user_name) return access_keys_list def list_access_keys(self, user_name): @@ -2592,7 +2674,7 @@ class IAMBackend(BaseBackend): "Member must have length less than or equal to 512" ) - device = VirtualMfaDevice(path + device_name) + device = VirtualMfaDevice(self.account_id, path + device_name) if device.serial_number in self.virtual_mfa_devices: raise EntityAlreadyExists( @@ -2678,7 +2760,7 @@ class IAMBackend(BaseBackend): def get_account_authorization_details(self, policy_filter): policies = self.managed_policies.values() - local_policies = set(policies) - set(aws_managed_policies) + local_policies = set(policies) - set(self.aws_managed_policies) returned_policies = [] if len(policy_filter) == 0: @@ -2691,7 +2773,7 @@ class IAMBackend(BaseBackend): } if "AWSManagedPolicy" in policy_filter: - returned_policies = aws_managed_policies + returned_policies = self.aws_managed_policies if "LocalManagedPolicy" in policy_filter: returned_policies = returned_policies + list(local_policies) @@ -2704,7 +2786,7 @@ class IAMBackend(BaseBackend): } def create_saml_provider(self, 
name, saml_metadata_document): - saml_provider = SAMLProvider(name, saml_metadata_document) + saml_provider = SAMLProvider(self.account_id, name, saml_metadata_document) self.saml_providers[name] = saml_provider return saml_provider @@ -2747,7 +2829,7 @@ class IAMBackend(BaseBackend): ): clean_tags = self._tag_verification(tags) open_id_provider = OpenIDConnectProvider( - url, thumbprint_list, client_id_list, clean_tags + self.account_id, url, thumbprint_list, client_id_list, clean_tags ) if open_id_provider.arn in self.open_id_providers: @@ -2833,9 +2915,7 @@ class IAMBackend(BaseBackend): def get_account_password_policy(self): if not self.account_password_policy: raise NoSuchEntity( - "The Password Policy with domain name {} cannot be found.".format( - get_account_id() - ) + f"The Password Policy with domain name {self.account_id} cannot be found." ) return self.account_password_policy @@ -2963,6 +3043,6 @@ class IAMBackend(BaseBackend): return True -iam_backends = BackendDict( +iam_backends: Mapping[str, Mapping[str, IAMBackend]] = BackendDict( IAMBackend, "iam", use_boto3_regions=False, additional_regions=["global"] ) diff --git a/moto/iam/responses.py b/moto/iam/responses.py index d9ab5e06a..e0b6a707e 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -4,9 +4,12 @@ from .models import iam_backends, User class IamResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="iam") + @property def backend(self): - return iam_backends["global"] + return iam_backends[self.current_account]["global"] def attach_role_policy(self): policy_arn = self._get_param("PolicyArn") @@ -524,10 +527,10 @@ class IamResponse(BaseResponse): def get_user(self): user_name = self._get_param("UserName") if not user_name: - access_key_id = self.get_current_user() + access_key_id = self.get_access_key() user = self.backend.get_user_from_access_key_id(access_key_id) if user is None: - user = User("default_user") + user = User(self.current_account, 
"default_user") else: user = self.backend.get_user(user_name) tags = self.backend.tagger.list_tags_for_resource(user.arn).get("Tags", []) @@ -640,7 +643,7 @@ class IamResponse(BaseResponse): def create_access_key(self): user_name = self._get_param("UserName") if not user_name: - access_key_id = self.get_current_user() + access_key_id = self.get_access_key() access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] @@ -672,7 +675,7 @@ class IamResponse(BaseResponse): def list_access_keys(self): user_name = self._get_param("UserName") if not user_name: - access_key_id = self.get_current_user() + access_key_id = self.get_access_key() access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] @@ -1960,10 +1963,13 @@ GET_ACCESS_KEY_LAST_USED_TEMPLATE = """ {{ user_name }} {% if last_used %} - {{ last_used }} - {% endif %} + {{ last_used.timestamp }} + {{ last_used.service }} + {{ last_used.region }} + {% else %} N/A N/A + {% endif %} diff --git a/moto/instance_metadata/__init__.py b/moto/instance_metadata/__init__.py index 4f02388a4..767847804 100644 --- a/moto/instance_metadata/__init__.py +++ b/moto/instance_metadata/__init__.py @@ -1,3 +1 @@ -from .models import instance_metadata_backend - -instance_metadata_backends = {"global": instance_metadata_backend} +from .models import instance_metadata_backends # noqa diff --git a/moto/instance_metadata/models.py b/moto/instance_metadata/models.py index c8ba5115a..ee22e2fff 100644 --- a/moto/instance_metadata/models.py +++ b/moto/instance_metadata/models.py @@ -1,8 +1,14 @@ from moto.core import BaseBackend +from moto.core.utils import BackendDict class InstanceMetadataBackend(BaseBackend): pass -instance_metadata_backend = InstanceMetadataBackend(region_name="global") +instance_metadata_backends = BackendDict( + InstanceMetadataBackend, + "instance_metadata", + use_boto3_regions=False, + additional_regions=["global"], +) diff --git 
a/moto/instance_metadata/responses.py b/moto/instance_metadata/responses.py index f96675581..dbd4d3103 100644 --- a/moto/instance_metadata/responses.py +++ b/moto/instance_metadata/responses.py @@ -6,6 +6,12 @@ from moto.core.responses import BaseResponse class InstanceMetadataResponse(BaseResponse): + def __init__(self): + super().__init__(service_name=None) + + def backends(self): + pass + def metadata_response( self, request, full_url, headers ): # pylint: disable=unused-argument diff --git a/moto/iot/models.py b/moto/iot/models.py index 1c79656e1..fd01b31bb 100644 --- a/moto/iot/models.py +++ b/moto/iot/models.py @@ -14,7 +14,7 @@ from datetime import datetime, timedelta from .utils import PAGINATION_MODEL -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.utilities.utils import random_string from moto.utilities.paginator import paginate @@ -32,12 +32,12 @@ from .exceptions import ( class FakeThing(BaseModel): - def __init__(self, thing_name, thing_type, attributes, region_name): + def __init__(self, thing_name, thing_type, attributes, account_id, region_name): self.region_name = region_name self.thing_name = thing_name self.thing_type = thing_type self.attributes = attributes - self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:thing/{thing_name}" + self.arn = f"arn:aws:iot:{region_name}:{account_id}:thing/{thing_name}" self.version = 1 # TODO: we need to handle "version"? 
@@ -149,17 +149,17 @@ class FakeThingGroup(BaseModel): class FakeCertificate(BaseModel): - def __init__(self, certificate_pem, status, region_name, ca_certificate_id=None): + def __init__( + self, certificate_pem, status, account_id, region_name, ca_certificate_id=None + ): m = hashlib.sha256() m.update(certificate_pem.encode("utf-8")) self.certificate_id = m.hexdigest() - self.arn = ( - f"arn:aws:iot:{region_name}:{get_account_id()}:cert/{self.certificate_id}" - ) + self.arn = f"arn:aws:iot:{region_name}:{account_id}:cert/{self.certificate_id}" self.certificate_pem = certificate_pem self.status = status - self.owner = get_account_id() + self.owner = account_id self.transfer_data = {} self.creation_date = time.time() self.last_modified_date = self.creation_date @@ -199,10 +199,13 @@ class FakeCertificate(BaseModel): class FakeCaCertificate(FakeCertificate): - def __init__(self, ca_certificate, status, region_name, registration_config): + def __init__( + self, ca_certificate, status, account_id, region_name, registration_config + ): super().__init__( certificate_pem=ca_certificate, status=status, + account_id=account_id, region_name=region_name, ca_certificate_id=None, ) @@ -210,12 +213,14 @@ class FakeCaCertificate(FakeCertificate): class FakePolicy(BaseModel): - def __init__(self, name, document, region_name, default_version_id="1"): + def __init__(self, name, document, account_id, region_name, default_version_id="1"): self.name = name self.document = document - self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:policy/{name}" + self.arn = f"arn:aws:iot:{region_name}:{account_id}:policy/{name}" self.default_version_id = default_version_id - self.versions = [FakePolicyVersion(self.name, document, True, region_name)] + self.versions = [ + FakePolicyVersion(self.name, document, True, account_id, region_name) + ] self._max_version_id = self.versions[0]._version_id def to_get_dict(self): @@ -239,9 +244,11 @@ class FakePolicy(BaseModel): class 
FakePolicyVersion(object): - def __init__(self, policy_name, document, is_default, region_name, version_id=1): + def __init__( + self, policy_name, document, is_default, account_id, region_name, version_id=1 + ): self.name = policy_name - self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:policy/{policy_name}" + self.arn = f"arn:aws:iot:{region_name}:{account_id}:policy/{policy_name}" self.document = document or {} self.is_default = is_default self._version_id = version_id @@ -663,7 +670,9 @@ class IoTBackend(BaseBackend): attributes = {} else: attributes = attribute_payload["attributes"] - thing = FakeThing(thing_name, thing_type, attributes, self.region_name) + thing = FakeThing( + thing_name, thing_type, attributes, self.account_id, self.region_name + ) self.things[thing.arn] = thing return thing.thing_name, thing.arn @@ -843,7 +852,9 @@ class IoTBackend(BaseBackend): } certificate_pem = self._random_string() status = "ACTIVE" if set_as_active else "INACTIVE" - certificate = FakeCertificate(certificate_pem, status, self.region_name) + certificate = FakeCertificate( + certificate_pem, status, self.account_id, self.region_name + ) self.certificates[certificate.certificate_id] = certificate return certificate, key_pair @@ -937,6 +948,7 @@ class IoTBackend(BaseBackend): certificate = FakeCaCertificate( ca_certificate=ca_certificate, status="ACTIVE" if set_as_active else "INACTIVE", + account_id=self.account_id, region_name=self.region_name, registration_config=registration_config, ) @@ -957,6 +969,7 @@ class IoTBackend(BaseBackend): certificate = FakeCertificate( certificate_pem, "ACTIVE" if set_as_active else status, + self.account_id, self.region_name, ca_certificate_id, ) @@ -968,7 +981,9 @@ class IoTBackend(BaseBackend): return certificate def register_certificate_without_ca(self, certificate_pem, status): - certificate = FakeCertificate(certificate_pem, status, self.region_name) + certificate = FakeCertificate( + certificate_pem, status, self.account_id, 
self.region_name + ) self.__raise_if_certificate_already_exists( certificate.certificate_id, certificate_arn=certificate.arn ) @@ -999,7 +1014,9 @@ class IoTBackend(BaseBackend): current_policy.name, current_policy.arn, ) - policy = FakePolicy(policy_name, policy_document, self.region_name) + policy = FakePolicy( + policy_name, policy_document, self.account_id, self.region_name + ) self.policies[policy.name] = policy return policy @@ -1065,6 +1082,7 @@ class IoTBackend(BaseBackend): policy_name, policy_document, set_as_default, + self.account_id, self.region_name, version_id=policy._max_version_id, ) @@ -1133,7 +1151,7 @@ class IoTBackend(BaseBackend): ) from moto.cognitoidentity import cognitoidentity_backends - cognito = cognitoidentity_backends[self.region_name] + cognito = cognitoidentity_backends[self.account_id][self.region_name] identities = [] for identity_pool in cognito.identity_pools: pool_identities = cognito.pools_identities.get(identity_pool, None) diff --git a/moto/iot/responses.py b/moto/iot/responses.py index 6223c20de..d9cf2b08d 100644 --- a/moto/iot/responses.py +++ b/moto/iot/responses.py @@ -6,11 +6,12 @@ from .models import iot_backends class IoTResponse(BaseResponse): - SERVICE_NAME = "iot" + def __init__(self): + super().__init__(service_name="iot") @property def iot_backend(self): - return iot_backends[self.region] + return iot_backends[self.current_account][self.region] def create_certificate_from_csr(self): certificate_signing_request = self._get_param("certificateSigningRequest") diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index 02cc11d34..5164db52a 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -143,6 +143,10 @@ class IoTDataPlaneBackend(BaseBackend): super().__init__(region_name, account_id) self.published_payloads = list() + @property + def iot_backend(self): + return iot_backends[self.account_id][self.region_name] + def update_thing_shadow(self, thing_name, payload): """ spec of payload: @@ 
-150,7 +154,7 @@ class IoTDataPlaneBackend(BaseBackend): - state node must be an Object - State contains an invalid node: 'foo' """ - thing = iot_backends[self.region_name].describe_thing(thing_name) + thing = self.iot_backend.describe_thing(thing_name) # validate try: @@ -173,14 +177,14 @@ class IoTDataPlaneBackend(BaseBackend): return thing.thing_shadow def get_thing_shadow(self, thing_name): - thing = iot_backends[self.region_name].describe_thing(thing_name) + thing = self.iot_backend.describe_thing(thing_name) if thing.thing_shadow is None or thing.thing_shadow.deleted: raise ResourceNotFoundException() return thing.thing_shadow def delete_thing_shadow(self, thing_name): - thing = iot_backends[self.region_name].describe_thing(thing_name) + thing = self.iot_backend.describe_thing(thing_name) if thing.thing_shadow is None: raise ResourceNotFoundException() payload = None diff --git a/moto/iotdata/responses.py b/moto/iotdata/responses.py index 02b8d6ea8..6aa157327 100644 --- a/moto/iotdata/responses.py +++ b/moto/iotdata/responses.py @@ -5,11 +5,12 @@ from urllib.parse import unquote class IoTDataPlaneResponse(BaseResponse): - SERVICE_NAME = "iot-data" + def __init__(self): + super().__init__(service_name="iot-data") @property def iotdata_backend(self): - return iotdata_backends[self.region] + return iotdata_backends[self.current_account][self.region] def update_thing_shadow(self): thing_name = self._get_param("thingName") diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index 9dbf1f0a2..519dcb35e 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -1,6 +1,5 @@ import json from werkzeug.exceptions import BadRequest -from moto.core import get_account_id class ResourceNotFoundError(BadRequest): @@ -20,24 +19,20 @@ class ResourceInUseError(BadRequest): class StreamNotFoundError(ResourceNotFoundError): - def __init__(self, stream_name): - super().__init__( - "Stream {0} under account {1} not found.".format( - stream_name, 
get_account_id() - ) - ) + def __init__(self, stream_name, account_id): + super().__init__(f"Stream {stream_name} under account {account_id} not found.") class ShardNotFoundError(ResourceNotFoundError): - def __init__(self, shard_id, stream): + def __init__(self, shard_id, stream, account_id): super().__init__( - f"Could not find shard {shard_id} in stream {stream} under account {get_account_id()}." + f"Could not find shard {shard_id} in stream {stream} under account {account_id}." ) class ConsumerNotFound(ResourceNotFoundError): - def __init__(self, consumer): - super().__init__(f"Consumer {consumer}, account {get_account_id()} not found.") + def __init__(self, consumer, account_id): + super().__init__(f"Consumer {consumer}, account {account_id} not found.") class InvalidArgumentError(BadRequest): diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index a6f495519..f53f10279 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -7,7 +7,6 @@ from operator import attrgetter from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import unix_time, BackendDict -from moto.core import get_account_id from moto.utilities.paginator import paginate from moto.utilities.utils import md5_hash from .exceptions import ( @@ -31,12 +30,12 @@ from .utils import ( class Consumer(BaseModel): - def __init__(self, consumer_name, region_name, stream_arn): + def __init__(self, consumer_name, account_id, region_name, stream_arn): self.consumer_name = consumer_name self.created = unix_time() self.stream_arn = stream_arn stream_name = stream_arn.split("/")[-1] - self.consumer_arn = f"arn:aws:kinesis:{region_name}:{get_account_id()}:stream/{stream_name}/consumer/{consumer_name}" + self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}" def to_json(self, include_stream_arn=False): resp = { @@ -164,13 +163,16 @@ class Shard(BaseModel): class Stream(CloudFormationModel): - def 
__init__(self, stream_name, shard_count, retention_period_hours, region_name): + def __init__( + self, stream_name, shard_count, retention_period_hours, account_id, region_name + ): self.stream_name = stream_name self.creation_datetime = datetime.datetime.now().strftime( "%Y-%m-%dT%H:%M:%S.%f000" ) self.region = region_name - self.account_number = get_account_id() + self.account_id = account_id + self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}" self.shards = {} self.tags = {} self.status = "ACTIVE" @@ -211,12 +213,12 @@ class Stream(CloudFormationModel): pass else: raise InvalidArgumentError( - message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {get_account_id()} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." + message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." ) if not shard.is_open: raise InvalidArgumentError( - message=f"Shard {shard.shard_id} in stream {self.stream_name} under account {get_account_id()} has already been merged or split, and thus is not eligible for merging or splitting." + message=f"Shard {shard.shard_id} in stream {self.stream_name} under account {self.account_id} has already been merged or split, and thus is not eligible for merging or splitting." 
) last_id = sorted(self.shards.values(), key=attrgetter("_shard_id"))[ @@ -333,19 +335,11 @@ class Stream(CloudFormationModel): self.shard_count = target_shard_count - @property - def arn(self): - return "arn:aws:kinesis:{region}:{account_number}:stream/{stream_name}".format( - region=self.region, - account_number=self.account_number, - stream_name=self.stream_name, - ) - def get_shard(self, shard_id): if shard_id in self.shards: return self.shards[shard_id] else: - raise ShardNotFoundError(shard_id, stream="") + raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id) def get_shard_for_key(self, partition_key, explicit_hash_key): if not isinstance(partition_key, str): @@ -415,7 +409,7 @@ class Stream(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json.get("Properties", {}) shard_count = properties.get("ShardCount", 1) @@ -425,7 +419,7 @@ class Stream(CloudFormationModel): for tag_item in properties.get("Tags", []) } - backend = kinesis_backends[region_name] + backend = kinesis_backends[account_id][region_name] stream = backend.create_stream( resource_name, shard_count, retention_period_hours ) @@ -435,7 +429,12 @@ class Stream(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] @@ -468,9 +467,9 @@ class Stream(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - backend = kinesis_backends[region_name] + backend = kinesis_backends[account_id][region_name] 
backend.delete_stream(resource_name) @staticmethod @@ -515,7 +514,11 @@ class KinesisBackend(BaseBackend): if stream_name in self.streams: raise ResourceInUseError(stream_name) stream = Stream( - stream_name, shard_count, retention_period_hours, self.region_name + stream_name, + shard_count, + retention_period_hours, + self.account_id, + self.region_name, ) self.streams[stream_name] = stream return stream @@ -524,7 +527,7 @@ class KinesisBackend(BaseBackend): if stream_name in self.streams: return self.streams[stream_name] else: - raise StreamNotFoundError(stream_name) + raise StreamNotFoundError(stream_name, self.account_id) def describe_stream_summary(self, stream_name): return self.describe_stream(stream_name) @@ -535,7 +538,7 @@ class KinesisBackend(BaseBackend): def delete_stream(self, stream_name): if stream_name in self.streams: return self.streams.pop(stream_name) - raise StreamNotFoundError(stream_name) + raise StreamNotFoundError(stream_name, self.account_id) def get_shard_iterator( self, @@ -551,7 +554,7 @@ class KinesisBackend(BaseBackend): shard = stream.get_shard(shard_id) except ShardNotFoundError: raise ResourceNotFoundError( - message=f"Shard {shard_id} in stream {stream_name} under account {get_account_id()} does not exist" + message=f"Shard {shard_id} in stream {stream_name} under account {self.account_id} does not exist" ) shard_iterator = compose_new_shard_iterator( @@ -625,7 +628,9 @@ class KinesisBackend(BaseBackend): ) if shard_to_split not in stream.shards: - raise ShardNotFoundError(shard_id=shard_to_split, stream=stream_name) + raise ShardNotFoundError( + shard_id=shard_to_split, stream=stream_name, account_id=self.account_id + ) if not re.match(r"0|([1-9]\d{0,38})", new_starting_hash_key): raise ValidationException( @@ -640,10 +645,14 @@ class KinesisBackend(BaseBackend): stream = self.describe_stream(stream_name) if shard_to_merge not in stream.shards: - raise ShardNotFoundError(shard_to_merge, stream=stream_name) + raise 
ShardNotFoundError( + shard_to_merge, stream=stream_name, account_id=self.account_id + ) if adjacent_shard_to_merge not in stream.shards: - raise ShardNotFoundError(adjacent_shard_to_merge, stream=stream_name) + raise ShardNotFoundError( + adjacent_shard_to_merge, stream=stream_name, account_id=self.account_id + ) stream.merge_shards(shard_to_merge, adjacent_shard_to_merge) @@ -749,7 +758,9 @@ class KinesisBackend(BaseBackend): return stream.consumers def register_stream_consumer(self, stream_arn, consumer_name): - consumer = Consumer(consumer_name, self.region_name, stream_arn) + consumer = Consumer( + consumer_name, self.account_id, self.region_name, stream_arn + ) stream = self._find_stream_by_arn(stream_arn) stream.consumers.append(consumer) return consumer @@ -765,7 +776,9 @@ class KinesisBackend(BaseBackend): consumer = stream.get_consumer_by_arn(consumer_arn) if consumer: return consumer - raise ConsumerNotFound(consumer=consumer_name or consumer_arn) + raise ConsumerNotFound( + consumer=consumer_name or consumer_arn, account_id=self.account_id + ) def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn): if stream_arn: diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index f8e321409..d27475094 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -5,13 +5,16 @@ from .models import kinesis_backends class KinesisResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="kinesis") + @property def parameters(self): return json.loads(self.body) @property def kinesis_backend(self): - return kinesis_backends[self.region] + return kinesis_backends[self.current_account][self.region] def create_stream(self): stream_name = self.parameters.get("StreamName") diff --git a/moto/kinesisvideo/models.py b/moto/kinesisvideo/models.py index 179a59fd6..ba2fc3779 100644 --- a/moto/kinesisvideo/models.py +++ b/moto/kinesisvideo/models.py @@ -4,12 +4,12 @@ from .exceptions import 
ResourceNotFoundException, ResourceInUseException import random import string from moto.core.utils import get_random_hex, BackendDict -from moto.core import get_account_id class Stream(BaseModel): def __init__( self, + account_id, region_name, device_name, stream_name, @@ -28,9 +28,7 @@ class Stream(BaseModel): self.status = "ACTIVE" self.version = self._get_random_string() self.creation_time = datetime.utcnow() - stream_arn = "arn:aws:kinesisvideo:{}:{}:stream/{}/1598784211076".format( - self.region_name, get_account_id(), self.stream_name - ) + stream_arn = f"arn:aws:kinesisvideo:{region_name}:{account_id}:stream/{stream_name}/1598784211076" self.data_endpoint_number = get_random_hex() self.arn = stream_arn @@ -79,6 +77,7 @@ class KinesisVideoBackend(BaseBackend): "The stream {} already exists.".format(stream_name) ) stream = Stream( + self.account_id, self.region_name, device_name, stream_name, diff --git a/moto/kinesisvideo/responses.py b/moto/kinesisvideo/responses.py index 41d954627..98a596817 100644 --- a/moto/kinesisvideo/responses.py +++ b/moto/kinesisvideo/responses.py @@ -4,11 +4,12 @@ import json class KinesisVideoResponse(BaseResponse): - SERVICE_NAME = "kinesisvideo" + def __init__(self): + super().__init__(service_name="kinesisvideo") @property def kinesisvideo_backend(self): - return kinesisvideo_backends[self.region] + return kinesisvideo_backends[self.current_account][self.region] def create_stream(self): device_name = self._get_param("DeviceName") diff --git a/moto/kinesisvideoarchivedmedia/models.py b/moto/kinesisvideoarchivedmedia/models.py index 6cdd469ef..fdf9d5a4e 100644 --- a/moto/kinesisvideoarchivedmedia/models.py +++ b/moto/kinesisvideoarchivedmedia/models.py @@ -5,10 +5,12 @@ from moto.sts.utils import random_session_token class KinesisVideoArchivedMediaBackend(BaseBackend): + @property + def backend(self): + return kinesisvideo_backends[self.account_id][self.region_name] + def _get_streaming_url(self, stream_name, stream_arn, 
api_name): - stream = kinesisvideo_backends[self.region_name]._get_stream( - stream_name, stream_arn - ) + stream = self.backend._get_stream(stream_name, stream_arn) data_endpoint = stream.get_data_endpoint(api_name) session_token = random_session_token() api_to_relative_path = { @@ -32,7 +34,7 @@ class KinesisVideoArchivedMediaBackend(BaseBackend): return url def get_clip(self, stream_name, stream_arn): - kinesisvideo_backends[self.region_name]._get_stream(stream_name, stream_arn) + self.backend._get_stream(stream_name, stream_arn) content_type = "video/mp4" # Fixed content_type as it depends on input stream payload = b"sample-mp4-video" return content_type, payload diff --git a/moto/kinesisvideoarchivedmedia/responses.py b/moto/kinesisvideoarchivedmedia/responses.py index caa10ca54..e86824b46 100644 --- a/moto/kinesisvideoarchivedmedia/responses.py +++ b/moto/kinesisvideoarchivedmedia/responses.py @@ -4,11 +4,12 @@ import json class KinesisVideoArchivedMediaResponse(BaseResponse): - SERVICE_NAME = "kinesis-video-archived-media" + def __init__(self): + super().__init__(service_name="kinesis-video-archived-media") @property def kinesisvideoarchivedmedia_backend(self): - return kinesisvideoarchivedmedia_backends[self.region] + return kinesisvideoarchivedmedia_backends[self.current_account][self.region] def get_hls_streaming_session_url(self): stream_name = self._get_param("StreamName") diff --git a/moto/kms/models.py b/moto/kms/models.py index 2e0008d24..4bec69f64 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -8,7 +8,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import get_random_hex, unix_time, BackendDict from moto.utilities.tagging_service import TaggingService from moto.core.exceptions import JsonRESTError @@ 
-57,10 +57,18 @@ class Grant(BaseModel): class Key(CloudFormationModel): def __init__( - self, policy, key_usage, key_spec, description, region, multi_region=False + self, + policy, + key_usage, + key_spec, + description, + account_id, + region, + multi_region=False, ): self.id = generate_key_id(multi_region) self.creation_date = unix_time() + self.account_id = account_id self.policy = policy or self.generate_default_policy() self.key_usage = key_usage self.key_state = "Enabled" @@ -68,7 +76,6 @@ class Key(CloudFormationModel): self.enabled = True self.region = region self.multi_region = multi_region - self.account_id = get_account_id() self.key_rotation_status = False self.deletion_date = None self.key_material = generate_master_key() @@ -76,6 +83,7 @@ class Key(CloudFormationModel): self.origin = "AWS_KMS" self.key_manager = "CUSTOMER" self.key_spec = key_spec or "SYMMETRIC_DEFAULT" + self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}" self.grants = dict() @@ -126,7 +134,7 @@ class Key(CloudFormationModel): { "Sid": "Enable IAM User Permissions", "Effect": "Allow", - "Principal": {"AWS": f"arn:aws:iam::{get_account_id()}:root"}, + "Principal": {"AWS": f"arn:aws:iam::{self.account_id}:root"}, "Action": "kms:*", "Resource": "*", } @@ -138,12 +146,6 @@ class Key(CloudFormationModel): def physical_resource_id(self): return self.id - @property - def arn(self): - return "arn:aws:kms:{0}:{1}:key/{2}".format( - self.region, self.account_id, self.id - ) - @property def encryption_algorithms(self): if self.key_usage == "SIGN_VERIFY": @@ -197,8 +199,8 @@ class Key(CloudFormationModel): key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date) return key_dict - def delete(self, region_name): - kms_backends[region_name].delete_key(self.id) + def delete(self, account_id, region_name): + kms_backends[account_id][region_name].delete_key(self.id) @staticmethod def cloudformation_name_type(): @@ -211,9 +213,9 @@ class Key(CloudFormationModel): @classmethod 
def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - kms_backend = kms_backends[region_name] + kms_backend = kms_backends[account_id][region_name] properties = cloudformation_json["Properties"] key = kms_backend.create_key( @@ -222,7 +224,6 @@ class Key(CloudFormationModel): key_spec="SYMMETRIC_DEFAULT", description=properties["Description"], tags=properties.get("Tags", []), - region=region_name, ) key.key_rotation_status = properties["EnableKeyRotation"] key.enabled = properties["Enabled"] @@ -264,15 +265,22 @@ class KmsBackend(BaseBackend): "SYMMETRIC_DEFAULT", "Default key", None, - self.region_name, ) self.add_alias(key.id, alias_name) return key.id def create_key( - self, policy, key_usage, key_spec, description, tags, region, multi_region=False + self, policy, key_usage, key_spec, description, tags, multi_region=False ): - key = Key(policy, key_usage, key_spec, description, region, multi_region) + key = Key( + policy, + key_usage, + key_spec, + description, + self.account_id, + self.region_name, + multi_region, + ) self.keys[key.id] = key if tags is not None and len(tags) > 0: self.tag_resource(key.id, tags) @@ -291,7 +299,7 @@ class KmsBackend(BaseBackend): # Since we only update top level properties, copy() should suffice. 
replica_key = copy(self.keys[key_id]) replica_key.region = replica_region - to_region_backend = kms_backends[replica_region] + to_region_backend = kms_backends[self.account_id][replica_region] to_region_backend.keys[replica_key.id] = replica_key def update_key_description(self, key_id, description): diff --git a/moto/kms/responses.py b/moto/kms/responses.py index e4c16fef2..4784cb794 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -4,7 +4,6 @@ import os import re import warnings -from moto.core import get_account_id from moto.core.responses import BaseResponse from moto.kms.utils import RESERVED_ALIASES from .models import kms_backends @@ -17,6 +16,9 @@ from .exceptions import ( class KmsResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="kms") + @property def parameters(self): params = json.loads(self.body) @@ -29,7 +31,7 @@ class KmsResponse(BaseResponse): @property def kms_backend(self): - return kms_backends[self.region] + return kms_backends[self.current_account][self.region] def _display_arn(self, key_id): if key_id.startswith("arn:"): @@ -40,9 +42,7 @@ class KmsResponse(BaseResponse): else: id_type = "key/" - return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format( - region=self.region, account=get_account_id(), id_type=id_type, key_id=key_id - ) + return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}" def _validate_cmk_id(self, key_id): """Determine whether a CMK ID exists. 
@@ -120,7 +120,7 @@ class KmsResponse(BaseResponse): multi_region = self.parameters.get("MultiRegion") key = self.kms_backend.create_key( - policy, key_usage, key_spec, description, tags, self.region, multi_region + policy, key_usage, key_spec, description, tags, multi_region ) return json.dumps(key.to_dict()) @@ -235,7 +235,7 @@ class KmsResponse(BaseResponse): "An alias with the name arn:aws:kms:{region}:{account_id}:{alias_name} " "already exists".format( region=self.region, - account_id=get_account_id(), + account_id=self.current_account, alias_name=alias_name, ) ) @@ -270,11 +270,7 @@ class KmsResponse(BaseResponse): # TODO: add creation date and last updated in response_aliases response_aliases.append( { - "AliasArn": "arn:aws:kms:{region}:{account_id}:{alias_name}".format( - region=region, - account_id=get_account_id(), - alias_name=alias_name, - ), + "AliasArn": f"arn:aws:kms:{region}:{self.current_account}:{alias_name}", "AliasName": alias_name, "TargetKeyId": target_key_id, } @@ -286,11 +282,7 @@ class KmsResponse(BaseResponse): if not exsisting: response_aliases.append( { - "AliasArn": "arn:aws:kms:{region}:{account_id}:{reserved_alias}".format( - region=region, - account_id=get_account_id(), - reserved_alias=reserved_alias, - ), + "AliasArn": f"arn:aws:kms:{region}:{self.current_account}:{reserved_alias}", "AliasName": reserved_alias, } ) diff --git a/moto/logs/models.py b/moto/logs/models.py index 4cf5cc33e..7447fd804 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime, timedelta -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core import CloudFormationModel from moto.core.utils import unix_time_millis, BackendDict from moto.utilities.paginator import paginate @@ -58,9 +58,10 @@ class LogEvent(BaseModel): class LogStream(BaseModel): _log_ids = 0 - def __init__(self, region, log_group, name): + def __init__(self, 
account_id, region, log_group, name): + self.account_id = account_id self.region = region - self.arn = f"arn:aws:logs:{region}:{get_account_id()}:log-group:{log_group}:log-stream:{name}" + self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{log_group}:log-stream:{name}" self.creation_time = int(unix_time_millis()) self.first_event_timestamp = None self.last_event_timestamp = None @@ -134,7 +135,7 @@ class LogStream(BaseModel): if service == "lambda": from moto.awslambda import lambda_backends # due to circular dependency - lambda_backends[self.region].send_log_event( + lambda_backends[self.account_id][self.region].send_log_event( self.destination_arn, self.filter_name, log_group_name, @@ -142,11 +143,9 @@ class LogStream(BaseModel): formatted_log_events, ) elif service == "firehose": - from moto.firehose import ( # pylint: disable=import-outside-toplevel - firehose_backends, - ) + from moto.firehose import firehose_backends - firehose_backends[self.region].send_log_event( + firehose_backends[self.account_id][self.region].send_log_event( self.destination_arn, self.filter_name, log_group_name, @@ -258,10 +257,11 @@ class LogStream(BaseModel): class LogGroup(CloudFormationModel): - def __init__(self, region, name, tags, **kwargs): + def __init__(self, account_id, region, name, tags, **kwargs): self.name = name + self.account_id = account_id self.region = region - self.arn = f"arn:aws:logs:{region}:{get_account_id()}:log-group:{name}" + self.arn = f"arn:aws:logs:{region}:{account_id}:log-group:{name}" self.creation_time = int(unix_time_millis()) self.tags = tags self.streams = dict() # {name: LogStream} @@ -286,18 +286,18 @@ class LogGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] tags = properties.get("Tags", {}) - return 
logs_backends[region_name].create_log_group( + return logs_backends[account_id][region_name].create_log_group( resource_name, tags, **properties ) def create_log_stream(self, log_stream_name): if log_stream_name in self.streams: raise ResourceAlreadyExistsException() - stream = LogStream(self.region, self.name, log_stream_name) + stream = LogStream(self.account_id, self.region, self.name, log_stream_name) filters = self.describe_subscription_filters() if filters: @@ -561,38 +561,42 @@ class LogResourcePolicy(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] policy_name = properties["PolicyName"] policy_document = properties["PolicyDocument"] - return logs_backends[region_name].put_resource_policy( + return logs_backends[account_id][region_name].put_resource_policy( policy_name, policy_document ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] policy_name = properties["PolicyName"] policy_document = properties["PolicyDocument"] - updated = logs_backends[region_name].put_resource_policy( - policy_name, policy_document - ) + backend = logs_backends[account_id][region_name] + updated = backend.put_resource_policy(policy_name, policy_document) # TODO: move `update by replacement logic` to cloudformation. 
this is required for implementing rollbacks if original_resource.policy_name != policy_name: - logs_backends[region_name].delete_resource_policy( - original_resource.policy_name - ) + backend.delete_resource_policy(original_resource.policy_name) return updated @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - return logs_backends[region_name].delete_resource_policy(resource_name) + return logs_backends[account_id][region_name].delete_resource_policy( + resource_name + ) class LogsBackend(BaseBackend): @@ -620,14 +624,16 @@ class LogsBackend(BaseBackend): value=log_group_name, ) self.groups[log_group_name] = LogGroup( - self.region_name, log_group_name, tags, **kwargs + self.account_id, self.region_name, log_group_name, tags, **kwargs ) return self.groups[log_group_name] def ensure_log_group(self, log_group_name, tags): if log_group_name in self.groups: return - self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags) + self.groups[log_group_name] = LogGroup( + self.account_id, self.region_name, log_group_name, tags + ) def delete_log_group(self, log_group_name): if log_group_name not in self.groups: @@ -883,12 +889,12 @@ class LogsBackend(BaseBackend): service = destination_arn.split(":")[2] if service == "lambda": - from moto.awslambda import ( # pylint: disable=import-outside-toplevel - lambda_backends, - ) + from moto.awslambda import lambda_backends try: - lambda_backends[self.region_name].get_function(destination_arn) + lambda_backends[self.account_id][self.region_name].get_function( + destination_arn + ) # no specific permission check implemented except Exception: raise InvalidParameterException( @@ -897,13 +903,11 @@ class LogsBackend(BaseBackend): "function." 
) elif service == "firehose": - from moto.firehose import ( # pylint: disable=import-outside-toplevel - firehose_backends, - ) + from moto.firehose import firehose_backends - firehose = firehose_backends[self.region_name].lookup_name_from_arn( - destination_arn - ) + firehose = firehose_backends[self.account_id][ + self.region_name + ].lookup_name_from_arn(destination_arn) if not firehose: raise InvalidParameterException( "Could not deliver test message to specified Firehose " @@ -940,7 +944,7 @@ class LogsBackend(BaseBackend): return query_id def create_export_task(self, log_group_name, destination): - s3_backends["global"].get_bucket(destination) + s3_backends[self.account_id]["global"].get_bucket(destination) if log_group_name not in self.groups: raise ResourceNotFoundException() task_id = uuid.uuid4() diff --git a/moto/logs/responses.py b/moto/logs/responses.py index fe8900c8c..92a01e72f 100644 --- a/moto/logs/responses.py +++ b/moto/logs/responses.py @@ -33,9 +33,12 @@ def validate_param( class LogsResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="logs") + @property def logs_backend(self): - return logs_backends[self.region] + return logs_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/managedblockchain/responses.py b/moto/managedblockchain/responses.py index 3c0f28908..e1ed7f724 100644 --- a/moto/managedblockchain/responses.py +++ b/moto/managedblockchain/responses.py @@ -14,9 +14,12 @@ from .utils import ( class ManagedBlockchainResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="managedblockchain") + @property def backend(self): - return managedblockchain_backends[self.region] + return managedblockchain_backends[self.current_account][self.region] @exception_handler def network_response(self, request, full_url, headers): diff --git a/moto/mediaconnect/models.py b/moto/mediaconnect/models.py index 625f61a89..574fdfd3b 100644 --- 
a/moto/mediaconnect/models.py +++ b/moto/mediaconnect/models.py @@ -1,7 +1,7 @@ from collections import OrderedDict from uuid import uuid4 -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.mediaconnect.exceptions import NotFoundException @@ -80,7 +80,7 @@ class MediaConnectBackend(BaseBackend): def _add_source_details(self, source, flow_id, ingest_ip="127.0.0.1"): if source: source["sourceArn"] = ( - f"arn:aws:mediaconnect:{self.region_name}:{get_account_id()}:source" + f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source" f":{flow_id}:{source['name']}" ) if not source.get("entitlementArn"): @@ -91,7 +91,7 @@ class MediaConnectBackend(BaseBackend): flow.description = "A Moto test flow" flow.egress_ip = "127.0.0.1" - flow.flow_arn = f"arn:aws:mediaconnect:{self.region_name}:{get_account_id()}:flow:{flow_id}:{flow.name}" + flow.flow_arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:flow:{flow_id}:{flow.name}" for index, _source in enumerate(flow.sources): self._add_source_details(_source, flow_id, f"127.0.0.{index}") @@ -250,7 +250,7 @@ class MediaConnectBackend(BaseBackend): for source in sources: source_id = uuid4().hex name = source["name"] - arn = f"arn:aws:mediaconnect:{self.region_name}:{get_account_id()}:source:{source_id}:{name}" + arn = f"arn:aws:mediaconnect:{self.region_name}:{self.account_id}:source:{source_id}:{name}" source["sourceArn"] = arn flow.sources = sources return flow_arn, sources diff --git a/moto/mediaconnect/responses.py b/moto/mediaconnect/responses.py index c741126b3..99b4fb034 100644 --- a/moto/mediaconnect/responses.py +++ b/moto/mediaconnect/responses.py @@ -7,11 +7,12 @@ from urllib.parse import unquote class MediaConnectResponse(BaseResponse): - SERVICE_NAME = "mediaconnect" + def __init__(self): + super().__init__(service_name="mediaconnect") @property def mediaconnect_backend(self): - return 
mediaconnect_backends[self.region] + return mediaconnect_backends[self.current_account][self.region] def create_flow(self): availability_zone = self._get_param("availabilityZone") diff --git a/moto/medialive/responses.py b/moto/medialive/responses.py index 9d75eb4b8..f3a0832fe 100644 --- a/moto/medialive/responses.py +++ b/moto/medialive/responses.py @@ -4,11 +4,12 @@ import json class MediaLiveResponse(BaseResponse): - SERVICE_NAME = "medialive" + def __init__(self): + super().__init__(service_name="medialive") @property def medialive_backend(self): - return medialive_backends[self.region] + return medialive_backends[self.current_account][self.region] def create_channel(self): cdi_input_specification = self._get_param("cdiInputSpecification") diff --git a/moto/mediapackage/responses.py b/moto/mediapackage/responses.py index 27384dc39..f3d85d7e7 100644 --- a/moto/mediapackage/responses.py +++ b/moto/mediapackage/responses.py @@ -4,11 +4,12 @@ import json class MediaPackageResponse(BaseResponse): - SERVICE_NAME = "mediapackage" + def __init__(self): + super().__init__(service_name="mediapackage") @property def mediapackage_backend(self): - return mediapackage_backends[self.region] + return mediapackage_backends[self.current_account][self.region] def create_channel(self): description = self._get_param("description") diff --git a/moto/mediastore/responses.py b/moto/mediastore/responses.py index 5d55c179f..ecb90f779 100644 --- a/moto/mediastore/responses.py +++ b/moto/mediastore/responses.py @@ -5,11 +5,12 @@ from .models import mediastore_backends class MediaStoreResponse(BaseResponse): - SERVICE_NAME = "mediastore" + def __init__(self): + super().__init__(service_name="mediastore") @property def mediastore_backend(self): - return mediastore_backends[self.region] + return mediastore_backends[self.current_account][self.region] def create_container(self): name = self._get_param("ContainerName") diff --git a/moto/mediastoredata/responses.py 
b/moto/mediastoredata/responses.py index 2803e0913..8e3251a17 100644 --- a/moto/mediastoredata/responses.py +++ b/moto/mediastoredata/responses.py @@ -5,11 +5,12 @@ from .models import mediastoredata_backends class MediaStoreDataResponse(BaseResponse): - SERVICE_NAME = "mediastore-data" + def __init__(self): + super().__init__(service_name="mediastore-data") @property def mediastoredata_backend(self): - return mediastoredata_backends[self.region] + return mediastoredata_backends[self.current_account][self.region] def get_object(self): path = self._get_param("Path") diff --git a/moto/moto_server/werkzeug_app.py b/moto/moto_server/werkzeug_app.py index 6983c970c..8d9c3bc77 100644 --- a/moto/moto_server/werkzeug_app.py +++ b/moto/moto_server/werkzeug_app.py @@ -8,7 +8,8 @@ from flask_cors import CORS import moto.backends as backends import moto.backend_index as backend_index -from moto.core.utils import convert_to_flask_response +from moto.core import DEFAULT_ACCOUNT_ID +from moto.core.utils import convert_to_flask_response, BackendDict from .utilities import AWSTestHelper, RegexConverter @@ -259,8 +260,13 @@ def create_backend_app(service): backend_app.url_map.converters["regex"] = RegexConverter backend_dict = backends.get_backend(service) - if "us-east-1" in backend_dict: - backend = backend_dict["us-east-1"] + # Get an instance of this backend. 
+ # We'll only use this backend to resolve the URL's, so the exact region/account_id is irrelevant + if isinstance(backend_dict, BackendDict): + if "us-east-1" in backend_dict[DEFAULT_ACCOUNT_ID]: + backend = backend_dict[DEFAULT_ACCOUNT_ID]["us-east-1"] + else: + backend = backend_dict[DEFAULT_ACCOUNT_ID]["global"] else: backend = backend_dict["global"] diff --git a/moto/mq/models.py b/moto/mq/models.py index 0ffcc2735..82626f12a 100644 --- a/moto/mq/models.py +++ b/moto/mq/models.py @@ -1,7 +1,7 @@ import base64 import xmltodict -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, get_random_hex, unix_time from moto.utilities.tagging_service import TaggingService @@ -57,9 +57,9 @@ class ConfigurationRevision(BaseModel): class Configuration(BaseModel): - def __init__(self, region, name, engine_type, engine_version): + def __init__(self, account_id, region, name, engine_type, engine_version): self.id = f"c-{get_random_hex(6)}" - self.arn = f"arn:aws:mq:{region}:{get_account_id()}:configuration:{self.id}" + self.arn = f"arn:aws:mq:{region}:{account_id}:configuration:{self.id}" self.created = unix_time() self.name = name @@ -140,6 +140,7 @@ class Broker(BaseModel): def __init__( self, name, + account_id, region, authentication_strategy, auto_minor_version_upgrade, @@ -160,7 +161,7 @@ class Broker(BaseModel): ): self.name = name self.id = get_random_hex(6) - self.arn = f"arn:aws:mq:{region}:{get_account_id()}:broker:{self.id}" + self.arn = f"arn:aws:mq:{region}:{account_id}:broker:{self.id}" self.state = "RUNNING" self.created = unix_time() @@ -379,6 +380,7 @@ class MQBackend(BaseBackend): ): broker = Broker( name=broker_name, + account_id=self.account_id, region=self.region_name, authentication_strategy=authentication_strategy, auto_minor_version_upgrade=auto_minor_version_upgrade, @@ -444,6 +446,7 @@ class MQBackend(BaseBackend): if engine_type.upper() != "ACTIVEMQ": 
raise UnknownEngineType(engine_type) config = Configuration( + account_id=self.account_id, region=self.region_name, name=name, engine_type=engine_type, diff --git a/moto/mq/responses.py b/moto/mq/responses.py index 4cb18d60b..bee19d039 100644 --- a/moto/mq/responses.py +++ b/moto/mq/responses.py @@ -9,10 +9,13 @@ from .models import mq_backends class MQResponse(BaseResponse): """Handler for MQ requests and responses.""" + def __init__(self): + super().__init__(service_name="mq") + @property def mq_backend(self): """Return backend instance specific for this region.""" - return mq_backends[self.region] + return mq_backends[self.current_account][self.region] def broker(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/opsworks/models.py b/moto/opsworks/models.py index 5bbc9590b..8a22f2b72 100644 --- a/moto/opsworks/models.py +++ b/moto/opsworks/models.py @@ -1,6 +1,5 @@ from moto.core import BaseBackend, BaseModel from moto.ec2 import ec2_backends -from moto.core import get_account_id from moto.core.utils import BackendDict import uuid import datetime @@ -316,6 +315,7 @@ class Stack(BaseModel): def __init__( self, name, + account_id, region, service_role_arn, default_instance_profile_arn, @@ -372,7 +372,7 @@ class Stack(BaseModel): self.id = "{0}".format(uuid.uuid4()) self.layers = [] self.apps = [] - self.account_number = get_account_id() + self.account_number = account_id self.created_at = datetime.datetime.utcnow() def __eq__(self, other): @@ -502,10 +502,10 @@ class OpsWorksBackend(BaseBackend): self.layers = {} self.apps = {} self.instances = {} - self.ec2_backend = ec2_backends[region_name] + self.ec2_backend = ec2_backends[account_id][region_name] def create_stack(self, **kwargs): - stack = Stack(**kwargs) + stack = Stack(account_id=self.account_id, **kwargs) self.stacks[stack.id] = stack return stack diff --git a/moto/opsworks/responses.py b/moto/opsworks/responses.py index fb8db2598..fe8c1b5be 100644 --- 
a/moto/opsworks/responses.py +++ b/moto/opsworks/responses.py @@ -5,13 +5,16 @@ from .models import opsworks_backends class OpsWorksResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="opsworks") + @property def parameters(self): return json.loads(self.body) @property def opsworks_backend(self): - return opsworks_backends[self.region] + return opsworks_backends[self.current_account][self.region] def create_stack(self): kwargs = dict( diff --git a/moto/organizations/models.py b/moto/organizations/models.py index 7aa9c09f6..777124d5e 100644 --- a/moto/organizations/models.py +++ b/moto/organizations/models.py @@ -2,7 +2,7 @@ import datetime import re import json -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.core.utils import unix_time, BackendDict from moto.organizations import utils @@ -25,11 +25,11 @@ from .utils import PAGINATION_MODEL class FakeOrganization(BaseModel): - def __init__(self, feature_set): + def __init__(self, account_id, feature_set): self.id = utils.make_random_org_id() self.root_id = utils.make_random_root_id() self.feature_set = feature_set - self.master_account_id = utils.MASTER_ACCOUNT_ID + self.master_account_id = account_id self.master_account_email = utils.MASTER_ACCOUNT_EMAIL self.available_policy_types = [ # This policy is available, but not applied @@ -355,7 +355,7 @@ class OrganizationsBackend(BaseBackend): return root def create_organization(self, **kwargs): - self.org = FakeOrganization(kwargs["FeatureSet"]) + self.org = FakeOrganization(self.account_id, kwargs["FeatureSet"]) root_ou = FakeRoot(self.org) self.ou.append(root_ou) master_account = FakeAccount( @@ -775,7 +775,7 @@ class OrganizationsBackend(BaseBackend): def register_delegated_administrator(self, **kwargs): account_id = kwargs["AccountId"] - if account_id == get_account_id(): + if account_id == self.account_id: raise 
ConstraintViolationException( "You cannot register master account/yourself as delegated administrator for your organization." ) @@ -834,7 +834,7 @@ class OrganizationsBackend(BaseBackend): account_id = kwargs["AccountId"] service = kwargs["ServicePrincipal"] - if account_id == get_account_id(): + if account_id == self.account_id: raise ConstraintViolationException( "You cannot register master account/yourself as delegated administrator for your organization." ) @@ -920,4 +920,3 @@ organizations_backends = BackendDict( use_boto3_regions=False, additional_regions=["global"], ) -organizations_backend = organizations_backends["global"] diff --git a/moto/organizations/responses.py b/moto/organizations/responses.py index 4c0b519ad..f30e2b4bf 100644 --- a/moto/organizations/responses.py +++ b/moto/organizations/responses.py @@ -1,13 +1,16 @@ import json from moto.core.responses import BaseResponse -from .models import organizations_backend +from .models import organizations_backends class OrganizationsResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="organizations") + @property def organizations_backend(self): - return organizations_backend + return organizations_backends[self.current_account]["global"] @property def request_params(self): diff --git a/moto/organizations/utils.py b/moto/organizations/utils.py index 3e9117ced..f328889cb 100644 --- a/moto/organizations/utils.py +++ b/moto/organizations/utils.py @@ -1,10 +1,8 @@ import random import re import string -from moto.core import get_account_id -MASTER_ACCOUNT_ID = get_account_id() MASTER_ACCOUNT_EMAIL = "master@example.com" DEFAULT_POLICY_ID = "p-FullAWSAccess" ORGANIZATION_ARN_FORMAT = "arn:aws:organizations::{0}:organization/{1}" diff --git a/moto/pinpoint/models.py b/moto/pinpoint/models.py index 1149f6ff8..a2e82d004 100644 --- a/moto/pinpoint/models.py +++ b/moto/pinpoint/models.py @@ -1,5 +1,5 @@ from datetime import datetime -from moto.core import get_account_id, BaseBackend, 
BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, unix_time from moto.utilities.tagging_service import TaggingService from uuid import uuid4 @@ -8,9 +8,11 @@ from .exceptions import ApplicationNotFound, EventStreamNotFound class App(BaseModel): - def __init__(self, name): + def __init__(self, account_id, name): self.application_id = str(uuid4()).replace("-", "") - self.arn = f"arn:aws:mobiletargeting:us-east-1:{get_account_id()}:apps/{self.application_id}" + self.arn = ( + f"arn:aws:mobiletargeting:us-east-1:{account_id}:apps/{self.application_id}" + ) self.name = name self.created = unix_time() self.settings = AppSettings() @@ -90,7 +92,7 @@ class PinpointBackend(BaseBackend): self.tagger = TaggingService() def create_app(self, name, tags): - app = App(name) + app = App(self.account_id, name) self.apps[app.application_id] = app tags = self.tagger.convert_dict_to_tags_input(tags) self.tagger.tag_resource(app.arn, tags) diff --git a/moto/pinpoint/responses.py b/moto/pinpoint/responses.py index 3d6b1aa3d..54d4b01e8 100644 --- a/moto/pinpoint/responses.py +++ b/moto/pinpoint/responses.py @@ -9,10 +9,13 @@ from .models import pinpoint_backends class PinpointResponse(BaseResponse): """Handler for Pinpoint requests and responses.""" + def __init__(self): + super().__init__(service_name="pinpoint") + @property def pinpoint_backend(self): """Return backend instance specific for this region.""" - return pinpoint_backends[self.region] + return pinpoint_backends[self.current_account][self.region] def app(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/polly/models.py b/moto/polly/models.py index 1e4ad64ba..20455812a 100644 --- a/moto/polly/models.py +++ b/moto/polly/models.py @@ -7,11 +7,9 @@ from moto.core.utils import BackendDict from .resources import VOICE_DATA from .utils import make_arn_for_lexicon -from moto.core import get_account_id - class Lexicon(BaseModel): - def 
__init__(self, name, content, region_name): + def __init__(self, name, content, account_id, region_name): self.name = name self.content = content self.size = 0 @@ -19,7 +17,7 @@ class Lexicon(BaseModel): self.last_modified = None self.language_code = None self.lexemes_count = 0 - self.arn = make_arn_for_lexicon(get_account_id(), name, region_name) + self.arn = make_arn_for_lexicon(account_id, name, region_name) self.update() @@ -107,7 +105,9 @@ class PollyBackend(BaseBackend): # but keeps the ARN self._lexicons.update(content) else: - lexicon = Lexicon(name, content, region_name=self.region_name) + lexicon = Lexicon( + name, content, self.account_id, region_name=self.region_name + ) self._lexicons[name] = lexicon diff --git a/moto/polly/responses.py b/moto/polly/responses.py index 39ecc7385..f0f7d8b77 100644 --- a/moto/polly/responses.py +++ b/moto/polly/responses.py @@ -11,9 +11,12 @@ LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$") class PollyResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="polly") + @property def polly_backend(self): - return polly_backends[self.region] + return polly_backends[self.current_account][self.region] @property def json(self): diff --git a/moto/quicksight/models.py b/moto/quicksight/models.py index c915afe21..03ba06918 100644 --- a/moto/quicksight/models.py +++ b/moto/quicksight/models.py @@ -1,6 +1,6 @@ """QuickSightBackend class with methods for supported APIs.""" -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from .exceptions import ResourceNotFoundException @@ -10,23 +10,24 @@ def _create_id(aws_account_id, namespace, _id): class QuicksightDataSet(BaseModel): - def __init__(self, region, _id, name): - self.arn = f"arn:aws:quicksight:{region}:{get_account_id()}:data-set/{_id}" + def __init__(self, account_id, region, _id, name): + self.arn = 
f"arn:aws:quicksight:{region}:{account_id}:data-set/{_id}" self._id = _id self.name = name self.region = region + self.account_id = account_id def to_json(self): return { "Arn": self.arn, "DataSetId": self._id, - "IngestionArn": f"arn:aws:quicksight:{self.region}:{get_account_id()}:ingestion/tbd", + "IngestionArn": f"arn:aws:quicksight:{self.region}:{self.account_id}:ingestion/tbd", } class QuicksightIngestion(BaseModel): - def __init__(self, region, data_set_id, ingestion_id): - self.arn = f"arn:aws:quicksight:{region}:{get_account_id()}:data-set/{data_set_id}/ingestions/{ingestion_id}" + def __init__(self, account_id, region, data_set_id, ingestion_id): + self.arn = f"arn:aws:quicksight:{region}:{account_id}:data-set/{data_set_id}/ingestions/{ingestion_id}" self.ingestion_id = ingestion_id def to_json(self): @@ -38,10 +39,12 @@ class QuicksightIngestion(BaseModel): class QuicksightMembership(BaseModel): - def __init__(self, region, group, user): + def __init__(self, account_id, region, group, user): self.group = group self.user = user - self.arn = f"arn:aws:quicksight:{region}:{get_account_id()}:group/default/{group}/{user}" + self.arn = ( + f"arn:aws:quicksight:{region}:{account_id}:group/default/{group}/{user}" + ) def to_json(self): return {"Arn": self.arn, "MemberName": self.user} @@ -50,7 +53,7 @@ class QuicksightMembership(BaseModel): class QuicksightGroup(BaseModel): def __init__(self, region, group_name, description, aws_account_id, namespace): self.arn = ( - f"arn:aws:quicksight:{region}:{get_account_id()}:group/default/{group_name}" + f"arn:aws:quicksight:{region}:{aws_account_id}:group/default/{group_name}" ) self.group_name = group_name self.description = description @@ -61,7 +64,9 @@ class QuicksightGroup(BaseModel): self.members = dict() def add_member(self, user_name): - membership = QuicksightMembership(self.region, self.group_name, user_name) + membership = QuicksightMembership( + self.aws_account_id, self.region, self.group_name, user_name + ) 
self.members[user_name] = membership return membership @@ -85,10 +90,8 @@ class QuicksightGroup(BaseModel): class QuicksightUser(BaseModel): - def __init__(self, region, email, identity_type, username, user_role): - self.arn = ( - f"arn:aws:quicksight:{region}:{get_account_id()}:user/default/{username}" - ) + def __init__(self, account_id, region, email, identity_type, username, user_role): + self.arn = f"arn:aws:quicksight:{region}:{account_id}:user/default/{username}" self.email = email self.identity_type = identity_type self.username = username @@ -115,7 +118,9 @@ class QuickSightBackend(BaseBackend): self.users = dict() def create_data_set(self, data_set_id, name): - return QuicksightDataSet(self.region_name, data_set_id, name=name) + return QuicksightDataSet( + self.account_id, self.region_name, data_set_id, name=name + ) def create_group(self, group_name, description, aws_account_id, namespace): group = QuicksightGroup( @@ -134,7 +139,9 @@ class QuickSightBackend(BaseBackend): return group.add_member(user_name) def create_ingestion(self, data_set_id, ingestion_id): - return QuicksightIngestion(self.region_name, data_set_id, ingestion_id) + return QuicksightIngestion( + self.account_id, self.region_name, data_set_id, ingestion_id + ) def delete_group(self, aws_account_id, namespace, group_name): _id = _create_id(aws_account_id, namespace, group_name) @@ -197,6 +204,7 @@ class QuickSightBackend(BaseBackend): IamArn, SessionName, CustomsPermissionsName, ExternalLoginFederationProviderType, CustomFederationProviderUrl, ExternalLoginId """ user = QuicksightUser( + account_id=self.account_id, region=self.region_name, email=email, identity_type=identity_type, diff --git a/moto/quicksight/responses.py b/moto/quicksight/responses.py index 91cc0bd31..4ef3efb69 100644 --- a/moto/quicksight/responses.py +++ b/moto/quicksight/responses.py @@ -8,10 +8,13 @@ from .models import quicksight_backends class QuickSightResponse(BaseResponse): """Handler for QuickSight requests 
and responses.""" + def __init__(self): + super().__init__(service_name="quicksight") + @property def quicksight_backend(self): """Return backend instance specific for this region.""" - return quicksight_backends[self.region] + return quicksight_backends[self.current_account][self.region] def dataset(self, request, full_url, headers): self.setup_class(request, full_url, headers) diff --git a/moto/ram/models.py b/moto/ram/models.py index 6d589be7e..eefc7f6ec 100644 --- a/moto/ram/models.py +++ b/moto/ram/models.py @@ -4,7 +4,7 @@ from datetime import datetime import random from uuid import uuid4 -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time, BackendDict from moto.organizations import organizations_backends from moto.ram.exceptions import ( @@ -38,25 +38,24 @@ class ResourceShare(BaseModel): "transit-gateway", # Amazon EC2 transit gateway ] - def __init__(self, region, **kwargs): + def __init__(self, account_id, region, **kwargs): + self.account_id = account_id self.region = region self.allow_external_principals = kwargs.get("allowExternalPrincipals", True) - self.arn = "arn:aws:ram:{0}:{1}:resource-share/{2}".format( - self.region, get_account_id(), uuid4() - ) + self.arn = f"arn:aws:ram:{self.region}:{account_id}:resource-share/{uuid4()}" self.creation_time = datetime.utcnow() self.feature_set = "STANDARD" self.last_updated_time = datetime.utcnow() self.name = kwargs["name"] - self.owning_account_id = get_account_id() + self.owning_account_id = account_id self.principals = [] self.resource_arns = [] self.status = "ACTIVE" @property def organizations_backend(self): - return organizations_backends["global"] + return organizations_backends[self.account_id]["global"] def add_principals(self, principals): for principal in principals: @@ -161,10 +160,10 @@ class ResourceAccessManagerBackend(BaseBackend): @property def organizations_backend(self): - return 
organizations_backends["global"] + return organizations_backends[self.account_id]["global"] def create_resource_share(self, **kwargs): - resource = ResourceShare(self.region_name, **kwargs) + resource = ResourceShare(self.account_id, self.region_name, **kwargs) resource.add_principals(kwargs.get("principals", [])) resource.add_resources(kwargs.get("resourceArns", [])) diff --git a/moto/ram/responses.py b/moto/ram/responses.py index fbc0d351d..7687e8090 100644 --- a/moto/ram/responses.py +++ b/moto/ram/responses.py @@ -4,11 +4,12 @@ import json class ResourceAccessManagerResponse(BaseResponse): - SERVICE_NAME = "ram" + def __init__(self): + super().__init__(service_name="ram") @property def ram_backend(self): - return ram_backends[self.region] + return ram_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/rds/models.py b/moto/rds/models.py index 691cf49d7..8e7202cf0 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -8,7 +8,7 @@ from collections import defaultdict from jinja2 import Template from re import compile as re_compile from collections import OrderedDict -from moto.core import BaseBackend, BaseModel, CloudFormationModel, get_account_id +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds, BackendDict from moto.ec2.models import ec2_backends @@ -52,6 +52,7 @@ class Cluster: self.engine_mode = kwargs.get("engine_mode") or "provisioned" self.iops = kwargs.get("iops") self.status = "active" + self.account_id = kwargs.get("account_id") self.region_name = kwargs.get("region") self.cluster_create_time = iso_8601_datetime_with_milliseconds( datetime.datetime.now() @@ -116,9 +117,7 @@ class Cluster: @property def db_cluster_arn(self): - return "arn:aws:rds:{0}:{1}:cluster:{2}".format( - self.region_name, get_account_id(), self.db_cluster_identifier - ) + return 
f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}" def to_xml(self): template = Template( @@ -265,9 +264,7 @@ class ClusterSnapshot(BaseModel): @property def snapshot_arn(self): - return "arn:aws:rds:{0}:{1}:cluster-snapshot:{2}".format( - self.cluster.region_name, get_account_id(), self.snapshot_id - ) + return f"arn:aws:rds:{self.cluster.region_name}:{self.cluster.account_id}:cluster-snapshot:{self.snapshot_id}" def to_xml(self): template = Template( @@ -341,6 +338,7 @@ class Database(CloudFormationModel): self.status = "available" self.is_replica = False self.replicas = [] + self.account_id = kwargs.get("account_id") self.region_name = kwargs.get("region") self.engine = kwargs.get("engine") self.engine_version = kwargs.get("engine_version", None) @@ -392,7 +390,7 @@ class Database(CloudFormationModel): self.multi_az = kwargs.get("multi_az") self.db_subnet_group_name = kwargs.get("db_subnet_group_name") if self.db_subnet_group_name: - self.db_subnet_group = rds_backends[ + self.db_subnet_group = rds_backends[self.account_id][ self.region_name ].describe_subnet_groups(self.db_subnet_group_name)[0] else: @@ -407,7 +405,7 @@ class Database(CloudFormationModel): self.db_parameter_group_name and not self.is_default_parameter_group(self.db_parameter_group_name) and self.db_parameter_group_name - not in rds_backends[self.region_name].db_parameter_groups + not in rds_backends[self.account_id][self.region_name].db_parameter_groups ): raise DBParameterGroupNotFoundError(self.db_parameter_group_name) @@ -420,7 +418,7 @@ class Database(CloudFormationModel): if ( self.option_group_name and self.option_group_name - not in rds_backends[self.region_name].option_groups + not in rds_backends[self.account_id][self.region_name].option_groups ): raise OptionGroupNotFoundFaultError(self.option_group_name) self.default_option_groups = { @@ -443,9 +441,7 @@ class Database(CloudFormationModel): @property def db_instance_arn(self): - return 
"arn:aws:rds:{0}:{1}:db:{2}".format( - self.region_name, get_account_id(), self.db_instance_identifier - ) + return f"arn:aws:rds:{self.region_name}:{self.account_id}:db:{self.db_instance_identifier}" @property def physical_resource_id(self): @@ -462,6 +458,7 @@ class Database(CloudFormationModel): description = "Default parameter group for {0}".format(db_family) return [ DBParameterGroup( + account_id=self.account_id, name=db_parameter_group_name, family=db_family, description=description, @@ -470,17 +467,11 @@ class Database(CloudFormationModel): ) ] else: - if ( - self.db_parameter_group_name - not in rds_backends[self.region_name].db_parameter_groups - ): + backend = rds_backends[self.account_id][self.region_name] + if self.db_parameter_group_name not in backend.db_parameter_groups: raise DBParameterGroupNotFoundError(self.db_parameter_group_name) - return [ - rds_backends[self.region_name].db_parameter_groups[ - self.db_parameter_group_name - ] - ] + return [backend.db_parameter_groups[self.db_parameter_group_name]] def is_default_parameter_group(self, param_group_name): return param_group_name.startswith("default.%s" % self.engine.lower()) @@ -710,7 +701,7 @@ class Database(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] @@ -740,6 +731,7 @@ class Database(CloudFormationModel): "port": properties.get("Port", 3306), "publicly_accessible": properties.get("PubliclyAccessible"), "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"), + "account_id": account_id, "region": region_name, "security_groups": security_groups, "storage_encrypted": properties.get("StorageEncrypted"), @@ -748,7 +740,7 @@ class Database(CloudFormationModel): "vpc_security_group_ids": properties.get("VpcSecurityGroupIds", []), } - rds_backend = rds_backends[region_name] + 
rds_backend = rds_backends[account_id][region_name] source_db_identifier = properties.get("SourceDBInstanceIdentifier") if source_db_identifier: # Replica @@ -846,8 +838,8 @@ class Database(CloudFormationModel): def remove_tags(self, tag_keys): self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, region_name): - backend = rds_backends[region_name] + def delete(self, account_id, region_name): + backend = rds_backends[account_id][region_name] backend.delete_db_instance(self.db_instance_identifier) @@ -873,9 +865,7 @@ class DatabaseSnapshot(BaseModel): @property def snapshot_arn(self): - return "arn:aws:rds:{0}:{1}:snapshot:{2}".format( - self.database.region_name, get_account_id(), self.snapshot_id - ) + return f"arn:aws:rds:{self.database.region_name}:{self.database.account_id}:snapshot:{self.snapshot_id}" def to_xml(self): template = Template( @@ -983,15 +973,13 @@ class EventSubscription(BaseModel): self.tags = kwargs.get("tags", True) self.region_name = "" - self.customer_aws_id = copy.copy(get_account_id()) + self.customer_aws_id = kwargs["account_id"] self.status = "active" self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) @property def es_arn(self): - return "arn:aws:rds:{0}:{1}:es:{2}".format( - self.region_name, get_account_id(), self.subscription_name - ) + return f"arn:aws:rds:{self.region_name}:{self.customer_aws_id}:es:{self.subscription_name}" def to_xml(self): template = Template( @@ -1039,14 +1027,14 @@ class EventSubscription(BaseModel): class SecurityGroup(CloudFormationModel): - def __init__(self, group_name, description, tags): + def __init__(self, account_id, group_name, description, tags): self.group_name = group_name self.description = description self.status = "authorized" self.ip_ranges = [] self.ec2_security_groups = [] self.tags = tags - self.owner_id = get_account_id() + self.owner_id = account_id self.vpc_id = None def to_xml(self): @@ -1112,7 +1100,7 @@ class 
SecurityGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] group_name = resource_name.lower() @@ -1120,8 +1108,8 @@ class SecurityGroup(CloudFormationModel): security_group_ingress_rules = properties.get("DBSecurityGroupIngress", []) tags = properties.get("Tags") - ec2_backend = ec2_backends[region_name] - rds_backend = rds_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] + rds_backend = rds_backends[account_id][region_name] security_group = rds_backend.create_db_security_group( group_name, description, tags ) @@ -1149,8 +1137,8 @@ class SecurityGroup(CloudFormationModel): def remove_tags(self, tag_keys): self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, region_name): - backend = rds_backends[region_name] + def delete(self, account_id, region_name): + backend = rds_backends[account_id][region_name] backend.delete_security_group(self.group_name) @@ -1220,7 +1208,7 @@ class SubnetGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] @@ -1228,9 +1216,9 @@ class SubnetGroup(CloudFormationModel): subnet_ids = properties["SubnetIds"] tags = properties.get("Tags") - ec2_backend = ec2_backends[region_name] + ec2_backend = ec2_backends[account_id][region_name] subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids] - rds_backend = rds_backends[region_name] + rds_backend = rds_backends[account_id][region_name] subnet_group = rds_backend.create_subnet_group( resource_name, description, subnets, tags ) @@ -1248,8 +1236,8 @@ class 
SubnetGroup(CloudFormationModel): def remove_tags(self, tag_keys): self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, region_name): - backend = rds_backends[region_name] + def delete(self, account_id, region_name): + backend = rds_backends[account_id][region_name] backend.delete_subnet_group(self.subnet_name) @@ -1440,7 +1428,7 @@ class RDSBackend(BaseBackend): if self.arn_regex.match(db_id): arn_breakdown = db_id.split(":") region = arn_breakdown[3] - backend = rds_backends[region] + backend = rds_backends[self.account_id][region] db_name = arn_breakdown[-1] else: backend = self @@ -1466,7 +1454,7 @@ class RDSBackend(BaseBackend): raise DBInstanceNotFoundError(db_instance_identifier) def create_db_security_group(self, group_name, description, tags): - security_group = SecurityGroup(group_name, description, tags) + security_group = SecurityGroup(self.account_id, group_name, description, tags) self.security_groups[group_name] = security_group return security_group @@ -1701,6 +1689,7 @@ class RDSBackend(BaseBackend): "The parameter DBParameterGroupName must be provided and must not be blank.", ) db_parameter_group_kwargs["region"] = self.region_name + db_parameter_group_kwargs["account_id"] = self.account_id db_parameter_group = DBParameterGroup(**db_parameter_group_kwargs) self.db_parameter_groups[db_parameter_group_id] = db_parameter_group return db_parameter_group @@ -1748,6 +1737,7 @@ class RDSBackend(BaseBackend): def create_db_cluster(self, kwargs): cluster_id = kwargs["db_cluster_identifier"] + kwargs["account_id"] = self.account_id cluster = Cluster(**kwargs) self.clusters[cluster_id] = cluster initial_state = copy.deepcopy(cluster) # Return status=creating @@ -1924,6 +1914,7 @@ class RDSBackend(BaseBackend): if subscription_name in self.event_subscriptions: raise SubscriptionAlreadyExistError(subscription_name) + kwargs["account_id"] = self.account_id subscription = EventSubscription(kwargs) 
self.event_subscriptions[subscription_name] = subscription @@ -2143,18 +2134,14 @@ class OptionGroup(object): self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] -def make_rds_arn(region, name): - return "arn:aws:rds:{0}:{1}:pg:{2}".format(region, get_account_id(), name) - - class DBParameterGroup(CloudFormationModel): - def __init__(self, name, description, family, tags, region): + def __init__(self, account_id, name, description, family, tags, region): self.name = name self.description = description self.family = family self.tags = tags self.parameters = defaultdict(dict) - self.arn = make_rds_arn(region, name) + self.arn = f"arn:aws:rds:{region}:{account_id}:pg:{name}" def to_xml(self): template = Template( @@ -2184,8 +2171,8 @@ class DBParameterGroup(CloudFormationModel): parameter = self.parameters[new_parameter["ParameterName"]] parameter.update(new_parameter) - def delete(self, region_name): - backend = rds_backends[region_name] + def delete(self, account_id, region_name): + backend = rds_backends[account_id][region_name] backend.delete_db_parameter_group(self.name) @staticmethod @@ -2199,7 +2186,7 @@ class DBParameterGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] @@ -2217,7 +2204,7 @@ class DBParameterGroup(CloudFormationModel): {"ParameterName": db_parameter, "ParameterValue": db_parameter_value} ) - rds_backend = rds_backends[region_name] + rds_backend = rds_backends[account_id][region_name] db_parameter_group = rds_backend.create_db_parameter_group( db_parameter_group_kwargs ) diff --git a/moto/rds/responses.py b/moto/rds/responses.py index 77f01bb51..862f34444 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -7,9 +7,12 @@ from .exceptions import DBParameterGroupNotFoundError class 
RDSResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="rds") + @property def backend(self): - return rds_backends[self.region] + return rds_backends[self.current_account][self.region] def _get_db_kwargs(self): args = { @@ -44,6 +47,7 @@ class RDSResponse(BaseResponse): # PreferredBackupWindow # PreferredMaintenanceWindow "publicly_accessible": self._get_param("PubliclyAccessible"), + "account_id": self.current_account, "region": self.region, "security_groups": self._get_multi_param( "DBSecurityGroups.DBSecurityGroupName" @@ -344,7 +348,8 @@ class RDSResponse(BaseResponse): subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) subnets = [ - ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ec2_backends[self.current_account][self.region].get_subnet(subnet_id) + for subnet_id in subnet_ids ] subnet_group = self.backend.create_subnet_group( subnet_name, description, subnets, tags @@ -363,7 +368,8 @@ class RDSResponse(BaseResponse): description = self._get_param("DBSubnetGroupDescription") subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") subnets = [ - ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ec2_backends[self.current_account][self.region].get_subnet(subnet_id) + for subnet_id in subnet_ids ] subnet_group = self.backend.modify_db_subnet_group( subnet_name, description, subnets diff --git a/moto/redshift/models.py b/moto/redshift/models.py index e29e8a98a..98f48389e 100644 --- a/moto/redshift/models.py +++ b/moto/redshift/models.py @@ -29,14 +29,12 @@ from .exceptions import ( ) -from moto.core import get_account_id - - class TaggableResourceMixin(object): resource_type = None - def __init__(self, region_name, tags): + def __init__(self, account_id, region_name, tags): + self.account_id = account_id self.region = region_name self.tags = tags or [] @@ -46,12 +44,7 @@ class 
TaggableResourceMixin(object): @property def arn(self): - return "arn:aws:redshift:{region}:{account_id}:{resource_type}:{resource_id}".format( - region=self.region, - account_id=get_account_id(), - resource_type=self.resource_type, - resource_id=self.resource_id, - ) + return f"arn:aws:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}" def create_tags(self, tags): new_keys = [tag_set["Key"] for tag_set in tags] @@ -97,7 +90,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel): restored_from_snapshot=False, kms_key_id=None, ): - super().__init__(region_name, tags) + super().__init__(redshift_backend.account_id, region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier self.create_time = iso_8601_datetime_with_milliseconds( @@ -171,9 +164,9 @@ class Cluster(TaggableResourceMixin, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - redshift_backend = redshift_backends[region_name] + redshift_backend = redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] if "ClusterSubnetGroupName" in properties: @@ -359,7 +352,7 @@ class SubnetGroup(TaggableResourceMixin, CloudFormationModel): region_name, tags=None, ): - super().__init__(region_name, tags) + super().__init__(ec2_backend.account_id, region_name, tags) self.ec2_backend = ec2_backend self.cluster_subnet_group_name = cluster_subnet_group_name self.description = description @@ -378,9 +371,9 @@ class SubnetGroup(TaggableResourceMixin, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - redshift_backend = redshift_backends[region_name] + redshift_backend = 
redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] subnet_group = redshift_backend.create_cluster_subnet_group( @@ -426,9 +419,14 @@ class SecurityGroup(TaggableResourceMixin, BaseModel): resource_type = "securitygroup" def __init__( - self, cluster_security_group_name, description, region_name, tags=None + self, + cluster_security_group_name, + description, + account_id, + region_name, + tags=None, ): - super().__init__(region_name, tags) + super().__init__(account_id, region_name, tags) self.cluster_security_group_name = cluster_security_group_name self.description = description self.ingress_rules = [] @@ -456,10 +454,11 @@ class ParameterGroup(TaggableResourceMixin, CloudFormationModel): cluster_parameter_group_name, group_family, description, + account_id, region_name, tags=None, ): - super().__init__(region_name, tags) + super().__init__(account_id, region_name, tags) self.cluster_parameter_group_name = cluster_parameter_group_name self.group_family = group_family self.description = description @@ -475,9 +474,9 @@ class ParameterGroup(TaggableResourceMixin, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - redshift_backend = redshift_backends[region_name] + redshift_backend = redshift_backends[account_id][region_name] properties = cloudformation_json["Properties"] parameter_group = redshift_backend.create_cluster_parameter_group( @@ -509,12 +508,13 @@ class Snapshot(TaggableResourceMixin, BaseModel): self, cluster, snapshot_identifier, + account_id, region_name, tags=None, iam_roles_arn=None, snapshot_type="manual", ): - super().__init__(region_name, tags) + super().__init__(account_id, region_name, tags) self.cluster = copy.copy(cluster) self.snapshot_identifier = snapshot_identifier self.snapshot_type = snapshot_type @@ -559,7 +559,7 @@ class 
RedshiftBackend(BaseBackend): self.subnet_groups = {} self.security_groups = { "Default": SecurityGroup( - "Default", "Default Redshift Security Group", self.region_name + "Default", "Default Redshift Security Group", account_id, region_name ) } self.parameter_groups = { @@ -567,10 +567,11 @@ class RedshiftBackend(BaseBackend): "default.redshift-1.0", "redshift-1.0", "Default Redshift parameter group", + self.account_id, self.region_name, ) } - self.ec2_backend = ec2_backends[self.region_name] + self.ec2_backend = ec2_backends[self.account_id][self.region_name] self.snapshots = OrderedDict() self.RESOURCE_TYPE_MAP = { "cluster": self.clusters, @@ -776,10 +777,14 @@ class RedshiftBackend(BaseBackend): raise ClusterSubnetGroupNotFoundError(subnet_identifier) def create_cluster_security_group( - self, cluster_security_group_name, description, region_name, tags=None + self, cluster_security_group_name, description, tags=None ): security_group = SecurityGroup( - cluster_security_group_name, description, region_name, tags + cluster_security_group_name, + description, + self.account_id, + self.region_name, + tags, ) self.security_groups[cluster_security_group_name] = security_group return security_group @@ -817,7 +822,12 @@ class RedshiftBackend(BaseBackend): tags=None, ): parameter_group = ParameterGroup( - cluster_parameter_group_name, group_family, description, region_name, tags + cluster_parameter_group_name, + group_family, + description, + self.account_id, + region_name, + tags, ) self.parameter_groups[cluster_parameter_group_name] = parameter_group @@ -851,7 +861,12 @@ class RedshiftBackend(BaseBackend): if self.snapshots.get(snapshot_identifier) is not None: raise ClusterSnapshotAlreadyExistsError(snapshot_identifier) snapshot = Snapshot( - cluster, snapshot_identifier, region_name, tags, snapshot_type=snapshot_type + cluster, + snapshot_identifier, + self.account_id, + region_name, + tags, + snapshot_type=snapshot_type, ) self.snapshots[snapshot_identifier] = 
snapshot return snapshot diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index f3422c034..a9934682a 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -45,9 +45,12 @@ def itemize(data): class RedshiftResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="redshift") + @property def redshift_backend(self): - return redshift_backends[self.region] + return redshift_backends[self.current_account][self.region] def get_response(self, response): if self.request_json: @@ -390,7 +393,6 @@ class RedshiftResponse(BaseResponse): security_group = self.redshift_backend.create_cluster_security_group( cluster_security_group_name=cluster_security_group_name, description=description, - region_name=self.region, tags=tags, ) diff --git a/moto/redshiftdata/responses.py b/moto/redshiftdata/responses.py index d73e7c205..9674bb315 100644 --- a/moto/redshiftdata/responses.py +++ b/moto/redshiftdata/responses.py @@ -4,9 +4,12 @@ from .models import redshiftdata_backends class RedshiftDataAPIServiceResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="redshift-data") + @property def redshiftdata_backend(self): - return redshiftdata_backends[self.region] + return redshiftdata_backends[self.current_account][self.region] def cancel_statement(self): statement_id = self._get_param("Id") diff --git a/moto/rekognition/responses.py b/moto/rekognition/responses.py index fcb767a5e..d82e73b40 100644 --- a/moto/rekognition/responses.py +++ b/moto/rekognition/responses.py @@ -8,10 +8,13 @@ from .models import rekognition_backends class RekognitionResponse(BaseResponse): """Handler for Rekognition requests and responses.""" + def __init__(self): + super().__init__(service_name="rekognition") + @property def rekognition_backend(self): """Return backend instance specific for this region.""" - return rekognition_backends[self.region] + return rekognition_backends[self.current_account][self.region] def 
get_face_search(self): ( diff --git a/moto/resourcegroups/models.py b/moto/resourcegroups/models.py index 72fb2d16b..e17272538 100644 --- a/moto/resourcegroups/models.py +++ b/moto/resourcegroups/models.py @@ -3,14 +3,20 @@ from builtins import str import json import re -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from .exceptions import BadRequestException class FakeResourceGroup(BaseModel): def __init__( - self, name, resource_query, description=None, tags=None, configuration=None + self, + account_id, + name, + resource_query, + description=None, + tags=None, + configuration=None, ): self.errors = [] description = description or "" @@ -24,9 +30,7 @@ class FakeResourceGroup(BaseModel): if self._validate_tags(value=tags): self._tags = tags self._raise_errors() - self.arn = "arn:aws:resource-groups:us-west-1:{AccountId}:{name}".format( - name=name, AccountId=get_account_id() - ) + self.arn = f"arn:aws:resource-groups:us-west-1:{account_id}:{name}" self.configuration = configuration @staticmethod @@ -305,6 +309,7 @@ class ResourceGroupsBackend(BaseBackend): ): tags = tags or {} group = FakeResourceGroup( + account_id=self.account_id, name=name, resource_query=resource_query, description=description, diff --git a/moto/resourcegroups/responses.py b/moto/resourcegroups/responses.py index 19042e2fd..9a076f867 100644 --- a/moto/resourcegroups/responses.py +++ b/moto/resourcegroups/responses.py @@ -7,11 +7,12 @@ from .models import resourcegroups_backends class ResourceGroupsResponse(BaseResponse): - SERVICE_NAME = "resource-groups" + def __init__(self): + super().__init__(service_name="resource-groups") @property def resourcegroups_backend(self): - return resourcegroups_backends[self.region] + return resourcegroups_backends[self.current_account][self.region] def create_group(self): name = self._get_param("Name") diff --git a/moto/resourcegroupstaggingapi/models.py 
b/moto/resourcegroupstaggingapi/models.py index 145bd9a1c..369858033 100644 --- a/moto/resourcegroupstaggingapi/models.py +++ b/moto/resourcegroupstaggingapi/models.py @@ -1,6 +1,5 @@ import uuid -from moto.core import get_account_id from moto.core import BaseBackend from moto.core.exceptions import RESTError from moto.core.utils import BackendDict @@ -37,77 +36,77 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): """ :rtype: moto.s3.models.S3Backend """ - return s3_backends["global"] + return s3_backends[self.account_id]["global"] @property def ec2_backend(self): """ :rtype: moto.ec2.models.EC2Backend """ - return ec2_backends[self.region_name] + return ec2_backends[self.account_id][self.region_name] @property def elb_backend(self): """ :rtype: moto.elb.models.ELBBackend """ - return elb_backends[self.region_name] + return elb_backends[self.account_id][self.region_name] @property def elbv2_backend(self): """ :rtype: moto.elbv2.models.ELBv2Backend """ - return elbv2_backends[self.region_name] + return elbv2_backends[self.account_id][self.region_name] @property def kinesis_backend(self): """ :rtype: moto.kinesis.models.KinesisBackend """ - return kinesis_backends[self.region_name] + return kinesis_backends[self.account_id][self.region_name] @property def kms_backend(self): """ :rtype: moto.kms.models.KmsBackend """ - return kms_backends[self.region_name] + return kms_backends[self.account_id][self.region_name] @property def rds_backend(self): """ :rtype: moto.rds.models.RDSBackend """ - return rds_backends[self.region_name] + return rds_backends[self.account_id][self.region_name] @property def glacier_backend(self): """ :rtype: moto.glacier.models.GlacierBackend """ - return glacier_backends[self.region_name] + return glacier_backends[self.account_id][self.region_name] @property def emr_backend(self): """ :rtype: moto.emr.models.ElasticMapReduceBackend """ - return emr_backends[self.region_name] + return emr_backends[self.account_id][self.region_name] @property 
def redshift_backend(self): """ :rtype: moto.redshift.models.RedshiftBackend """ - return redshift_backends[self.region_name] + return redshift_backends[self.account_id][self.region_name] @property def lambda_backend(self): """ :rtype: moto.awslambda.models.LambdaBackend """ - return lambda_backends[self.region_name] + return lambda_backends[self.account_id][self.region_name] def _get_resources_generator(self, tag_filters=None, resource_type_filters=None): # Look at @@ -414,9 +413,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): ): # Skip if no tags, or invalid filter continue yield { - "ResourceARN": "arn:aws:ec2:{0}:{1}:vpc/{2}".format( - self.region_name, get_account_id(), vpc.id - ), + "ResourceARN": f"arn:aws:ec2:{self.region_name}:{self.account_id}:vpc/{vpc.id}", "Tags": tags, } # VPC Customer Gateway diff --git a/moto/resourcegroupstaggingapi/responses.py b/moto/resourcegroupstaggingapi/responses.py index a153e2bdf..3ecea5f6e 100644 --- a/moto/resourcegroupstaggingapi/responses.py +++ b/moto/resourcegroupstaggingapi/responses.py @@ -4,7 +4,8 @@ import json class ResourceGroupsTaggingAPIResponse(BaseResponse): - SERVICE_NAME = "resourcegroupstaggingapi" + def __init__(self): + super().__init__(service_name="resourcegroupstaggingapi") @property def backend(self): @@ -13,7 +14,7 @@ class ResourceGroupsTaggingAPIResponse(BaseResponse): :returns: Resource tagging api backend :rtype: moto.resourcegroupstaggingapi.models.ResourceGroupsTaggingAPIBackend """ - return resourcegroupstaggingapi_backends[self.region] + return resourcegroupstaggingapi_backends[self.current_account][self.region] def get_resources(self): pagination_token = self._get_param("PaginationToken") diff --git a/moto/route53/models.py b/moto/route53/models.py index db713cbe9..ceb771bbe 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -21,7 +21,7 @@ from moto.route53.exceptions import ( PublicZoneVPCAssociation, QueryLoggingConfigAlreadyExists, ) -from moto.core import 
BaseBackend, BaseModel, CloudFormationModel, get_account_id +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import BackendDict from moto.utilities.paginator import paginate from .utils import PAGINATION_MODEL @@ -96,7 +96,7 @@ class HealthCheck(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"]["HealthCheckConfig"] health_check_args = { @@ -109,7 +109,8 @@ class HealthCheck(CloudFormationModel): "request_interval": properties.get("RequestInterval"), "failure_threshold": properties.get("FailureThreshold"), } - health_check = route53_backend.create_health_check( + backend = route53_backends[account_id]["global"] + health_check = backend.create_health_check( caller_reference=resource_name, health_check_args=health_check_args ) return health_check @@ -193,42 +194,49 @@ class RecordSet(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") + backend = route53_backends[account_id]["global"] if zone_name: - hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) + hosted_zone = backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) + hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"]) record_set = hosted_zone.add_rrset(properties) return record_set @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): 
cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # this will break if you changed the zone the record is in, # unfortunately properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") + backend = route53_backends[account_id]["global"] if zone_name: - hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) + hosted_zone = backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) + hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"]) try: hosted_zone.delete_rrset({"Name": resource_name}) @@ -239,11 +247,12 @@ class RecordSet(CloudFormationModel): def physical_resource_id(self): return self.name - def delete(self, *args, **kwargs): # pylint: disable=unused-argument - """Not exposed as part of the Route 53 API - used for CloudFormation. 
args are ignored""" - hosted_zone = route53_backend.get_hosted_zone_by_name(self.hosted_zone_name) + def delete(self, account_id, region): # pylint: disable=unused-argument + """Not exposed as part of the Route 53 API - used for CloudFormation""" + backend = route53_backends[account_id][region] + hosted_zone = backend.get_hosted_zone_by_name(self.hosted_zone_name) if not hosted_zone: - hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id) + hosted_zone = backend.get_hosted_zone(self.hosted_zone_id) hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_}) @@ -352,9 +361,9 @@ class FakeZone(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - hosted_zone = route53_backend.create_hosted_zone( + hosted_zone = route53_backends[account_id]["global"].create_hosted_zone( resource_name, private_zone=False ) return hosted_zone @@ -380,15 +389,16 @@ class RecordSetGroup(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") + backend = route53_backends[account_id]["global"] if zone_name: - hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) + hosted_zone = backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) + hosted_zone = backend.get_hosted_zone(properties["HostedZoneId"]) record_sets = properties["RecordSets"] for record_set in record_sets: hosted_zone.add_rrset(record_set) @@ -599,7 +609,7 @@ class Route53Backend(BaseBackend): { "HostedZoneId": zone.id, "Name": zone.name, - "Owner": {"OwningAccount": get_account_id()}, + "Owner": {"OwningAccount": 
self.account_id}, } ) @@ -723,7 +733,7 @@ class Route53Backend(BaseBackend): from moto.logs import logs_backends # pylint: disable=import-outside-toplevel - response = logs_backends[region].describe_log_groups() + response = logs_backends[self.account_id][region].describe_log_groups() log_groups = response[0] if response else [] for entry in log_groups: if log_group_arn == entry["arn"]: @@ -804,4 +814,3 @@ class Route53Backend(BaseBackend): route53_backends = BackendDict( Route53Backend, "route53", use_boto3_regions=False, additional_regions=["global"] ) -route53_backend = route53_backends["global"] diff --git a/moto/route53/responses.py b/moto/route53/responses.py index 01ca96431..d0f5a4549 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -6,7 +6,7 @@ import xmltodict from moto.core.responses import BaseResponse from moto.route53.exceptions import InvalidChangeBatch -from moto.route53.models import route53_backend +from moto.route53.models import route53_backends XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/" @@ -14,6 +14,9 @@ XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/" class Route53(BaseResponse): """Handler for Route53 requests and responses.""" + def __init__(self): + super().__init__(service_name="route53") + @staticmethod def _convert_to_bool(bool_str): if isinstance(bool_str, bool): @@ -24,6 +27,10 @@ class Route53(BaseResponse): return False + @property + def backend(self): + return route53_backends[self.current_account]["global"] + def list_or_create_hostzone_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -60,7 +67,7 @@ class Route53(BaseResponse): name += "." 
delegation_set_id = zone_request.get("DelegationSetId") - new_zone = route53_backend.create_hosted_zone( + new_zone = self.backend.create_hosted_zone( name, comment=comment, private_zone=private_zone, @@ -72,7 +79,7 @@ class Route53(BaseResponse): return 201, headers, template.render(zone=new_zone) elif request.method == "GET": - all_zones = route53_backend.list_hosted_zones() + all_zones = self.backend.list_hosted_zones() template = Template(LIST_HOSTED_ZONES_RESPONSE) return 200, headers, template.render(zones=all_zones) @@ -82,7 +89,7 @@ class Route53(BaseResponse): query_params = parse_qs(parsed_url.query) dnsname = query_params.get("dnsname") - dnsname, zones = route53_backend.list_hosted_zones_by_name(dnsname) + dnsname, zones = self.backend.list_hosted_zones_by_name(dnsname) template = Template(LIST_HOSTED_ZONES_BY_NAME_RESPONSE) return 200, headers, template.render(zones=zones, dnsname=dnsname, xmlns=XMLNS) @@ -92,13 +99,13 @@ class Route53(BaseResponse): parsed_url = urlparse(full_url) query_params = parse_qs(parsed_url.query) vpc_id = query_params.get("vpcid")[0] - zones = route53_backend.list_hosted_zones_by_vpc(vpc_id) + zones = self.backend.list_hosted_zones_by_vpc(vpc_id) template = Template(LIST_HOSTED_ZONES_BY_VPC_RESPONSE) return 200, headers, template.render(zones=zones, xmlns=XMLNS) def get_hosted_zone_count_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) - num_zones = route53_backend.get_hosted_zone_count() + num_zones = self.backend.get_hosted_zone_count() template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE) return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS) @@ -108,18 +115,18 @@ class Route53(BaseResponse): zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1] if request.method == "GET": - the_zone = route53_backend.get_hosted_zone(zoneid) + the_zone = self.backend.get_hosted_zone(zoneid) template = Template(GET_HOSTED_ZONE_RESPONSE) return 200, headers, 
template.render(zone=the_zone) elif request.method == "DELETE": - route53_backend.delete_hosted_zone(zoneid) + self.backend.delete_hosted_zone(zoneid) return 200, headers, DELETE_HOSTED_ZONE_RESPONSE elif request.method == "POST": elements = xmltodict.parse(self.body) comment = elements.get("UpdateHostedZoneCommentRequest", {}).get( "Comment", None ) - zone = route53_backend.update_hosted_zone_comment(zoneid, comment) + zone = self.backend.update_hosted_zone_comment(zoneid, comment) template = Template(UPDATE_HOSTED_ZONE_COMMENT_RESPONSE) return 200, headers, template.render(zone=zone) @@ -134,7 +141,7 @@ class Route53(BaseResponse): zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1] if method == "GET": - route53_backend.get_dnssec(zoneid) + self.backend.get_dnssec(zoneid) return 200, headers, GET_DNSSEC def associate_vpc_response(self, request, full_url, headers): @@ -151,7 +158,7 @@ class Route53(BaseResponse): vpcid = vpc.get("VPCId", None) vpcregion = vpc.get("VPCRegion", None) - route53_backend.associate_vpc_with_hosted_zone(zoneid, vpcid, vpcregion) + self.backend.associate_vpc_with_hosted_zone(zoneid, vpcid, vpcregion) template = Template(ASSOCIATE_VPC_RESPONSE) return 200, headers, template.render(comment=comment) @@ -169,7 +176,7 @@ class Route53(BaseResponse): vpc = elements.get("DisassociateVPCFromHostedZoneRequest", {}).get("VPC", {}) vpcid = vpc.get("VPCId", None) - route53_backend.disassociate_vpc_from_hosted_zone(zoneid, vpcid) + self.backend.disassociate_vpc_from_hosted_zone(zoneid, vpcid) template = Template(DISASSOCIATE_VPC_RESPONSE) return 200, headers, template.render(comment=comment) @@ -212,7 +219,7 @@ class Route53(BaseResponse): if effective_rr_count > 1000: raise InvalidChangeBatch - error_msg = route53_backend.change_resource_record_sets(zoneid, change_list) + error_msg = self.backend.change_resource_record_sets(zoneid, change_list) if error_msg: return 400, headers, error_msg @@ -233,7 +240,7 @@ class Route53(BaseResponse): next_name, 
next_type, is_truncated, - ) = route53_backend.list_resource_record_sets( + ) = self.backend.list_resource_record_sets( zoneid, start_type=start_type, start_name=start_name, @@ -275,19 +282,19 @@ class Route53(BaseResponse): "children": config.get("ChildHealthChecks", {}).get("ChildHealthCheck"), "regions": config.get("Regions", {}).get("Region"), } - health_check = route53_backend.create_health_check( + health_check = self.backend.create_health_check( caller_reference, health_check_args ) template = Template(CREATE_HEALTH_CHECK_RESPONSE) return 201, headers, template.render(health_check=health_check, xmlns=XMLNS) elif method == "DELETE": health_check_id = parsed_url.path.split("/")[-1] - route53_backend.delete_health_check(health_check_id) + self.backend.delete_health_check(health_check_id) template = Template(DELETE_HEALTH_CHECK_RESPONSE) return 200, headers, template.render(xmlns=XMLNS) elif method == "GET": template = Template(LIST_HEALTH_CHECKS_RESPONSE) - health_checks = route53_backend.list_health_checks() + health_checks = self.backend.list_health_checks() return ( 200, headers, @@ -302,11 +309,11 @@ class Route53(BaseResponse): health_check_id = parsed_url.path.split("/")[-1] if method == "GET": - health_check = route53_backend.get_health_check(health_check_id) + health_check = self.backend.get_health_check(health_check_id) template = Template(GET_HEALTH_CHECK_RESPONSE) return 200, headers, template.render(health_check=health_check) elif method == "DELETE": - route53_backend.delete_health_check(health_check_id) + self.backend.delete_health_check(health_check_id) template = Template(DELETE_HEALTH_CHECK_RESPONSE) return 200, headers, template.render(xmlns=XMLNS) elif method == "POST": @@ -325,7 +332,7 @@ class Route53(BaseResponse): "children": config.get("ChildHealthChecks", {}).get("ChildHealthCheck"), "regions": config.get("Regions", {}).get("Region"), } - health_check = route53_backend.update_health_check( + health_check = 
self.backend.update_health_check( health_check_id, health_check_args ) template = Template(UPDATE_HEALTH_CHECK_RESPONSE) @@ -351,7 +358,7 @@ class Route53(BaseResponse): type_ = parsed_url.path.split("/")[-2] if request.method == "GET": - tags = route53_backend.list_tags_for_resource(id_) + tags = self.backend.list_tags_for_resource(id_) template = Template(LIST_TAGS_FOR_RESOURCE_RESPONSE) return ( 200, @@ -367,7 +374,7 @@ class Route53(BaseResponse): elif "RemoveTagKeys" in tags: tags = tags["RemoveTagKeys"] - route53_backend.change_tags_for_resource(id_, tags) + self.backend.change_tags_for_resource(id_, tags) template = Template(CHANGE_TAGS_FOR_RESOURCE_RESPONSE) return 200, headers, template.render() @@ -388,7 +395,7 @@ class Route53(BaseResponse): hosted_zone_id = json_body["HostedZoneId"] log_group_arn = json_body["CloudWatchLogsLogGroupArn"] - query_logging_config = route53_backend.create_query_logging_config( + query_logging_config = self.backend.create_query_logging_config( self.region, hosted_zone_id, log_group_arn ) @@ -407,7 +414,7 @@ class Route53(BaseResponse): # The paginator picks up named arguments, returns tuple. 
# pylint: disable=unbalanced-tuple-unpacking - (all_configs, next_token,) = route53_backend.list_query_logging_configs( + (all_configs, next_token,) = self.backend.list_query_logging_configs( hosted_zone_id=hosted_zone_id, next_token=next_token, max_results=max_results, @@ -430,7 +437,7 @@ class Route53(BaseResponse): query_logging_config_id = parsed_url.path.rstrip("/").rsplit("/", 1)[1] if request.method == "GET": - query_logging_config = route53_backend.get_query_logging_config( + query_logging_config = self.backend.get_query_logging_config( query_logging_config_id ) template = Template(GET_QUERY_LOGGING_CONFIG_RESPONSE) @@ -441,13 +448,13 @@ class Route53(BaseResponse): ) elif request.method == "DELETE": - route53_backend.delete_query_logging_config(query_logging_config_id) + self.backend.delete_query_logging_config(query_logging_config_id) return 200, headers, "" def reusable_delegation_sets(self, request, full_url, headers): self.setup_class(request, full_url, headers) if request.method == "GET": - delegation_sets = route53_backend.list_reusable_delegation_sets() + delegation_sets = self.backend.list_reusable_delegation_sets() template = self.response_template(LIST_REUSABLE_DELEGATION_SETS_TEMPLATE) return ( 200, @@ -464,7 +471,7 @@ class Route53(BaseResponse): root_elem = elements["CreateReusableDelegationSetRequest"] caller_reference = root_elem.get("CallerReference") hosted_zone_id = root_elem.get("HostedZoneId") - delegation_set = route53_backend.create_reusable_delegation_set( + delegation_set = self.backend.create_reusable_delegation_set( caller_reference=caller_reference, hosted_zone_id=hosted_zone_id ) template = self.response_template(CREATE_REUSABLE_DELEGATION_SET_TEMPLATE) @@ -479,13 +486,13 @@ class Route53(BaseResponse): parsed_url = urlparse(full_url) ds_id = parsed_url.path.rstrip("/").rsplit("/")[-1] if request.method == "GET": - delegation_set = route53_backend.get_reusable_delegation_set( + delegation_set = 
self.backend.get_reusable_delegation_set( delegation_set_id=ds_id ) template = self.response_template(GET_REUSABLE_DELEGATION_SET_TEMPLATE) return 200, {}, template.render(delegation_set=delegation_set) if request.method == "DELETE": - route53_backend.delete_reusable_delegation_set(delegation_set_id=ds_id) + self.backend.delete_reusable_delegation_set(delegation_set_id=ds_id) template = self.response_template(DELETE_REUSABLE_DELEGATION_SET_TEMPLATE) return 200, {}, template.render() diff --git a/moto/route53resolver/models.py b/moto/route53resolver/models.py index 90eb6e66b..7ae1158a5 100644 --- a/moto/route53resolver/models.py +++ b/moto/route53resolver/models.py @@ -4,7 +4,6 @@ from datetime import datetime, timezone from ipaddress import ip_address, ip_network, IPv4Address import re -from moto.core import get_account_id from moto.core import BaseBackend, BaseModel from moto.core.utils import get_random_hex, BackendDict from moto.ec2 import ec2_backends @@ -87,6 +86,7 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes def __init__( self, + account_id, region, rule_id, creator_request_id, @@ -96,6 +96,7 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes resolver_endpoint_id=None, name=None, ): # pylint: disable=too-many-arguments + self.account_id = account_id self.region = region self.creator_request_id = creator_request_id self.name = name @@ -123,7 +124,7 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes @property def arn(self): """Return ARN for this resolver rule.""" - return f"arn:aws:route53resolver:{self.region}:{get_account_id()}:resolver-rule/{self.id}" + return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-rule/{self.id}" def description(self): """Return a dictionary of relevant info for this resolver rule.""" @@ -138,7 +139,7 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes "Name": self.name, "TargetIps": 
self.target_ips, "ResolverEndpointId": self.resolver_endpoint_id, - "OwnerId": get_account_id(), + "OwnerId": self.account_id, "ShareStatus": self.share_status, "CreationTime": self.creation_time, "ModificationTime": self.modification_time, @@ -165,6 +166,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut def __init__( self, + account_id, region, endpoint_id, creator_request_id, @@ -173,12 +175,14 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut ip_addresses, name=None, ): # pylint: disable=too-many-arguments + self.account_id = account_id self.region = region self.creator_request_id = creator_request_id self.name = name self.security_group_ids = security_group_ids self.direction = direction self.ip_addresses = ip_addresses + self.ec2_backend = ec2_backends[self.account_id][self.region] # Constructed members. self.id = endpoint_id # pylint: disable=invalid-name @@ -204,7 +208,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut @property def arn(self): """Return ARN for this resolver endpoint.""" - return f"arn:aws:route53resolver:{self.region}:{get_account_id()}:resolver-endpoint/{self.id}" + return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-endpoint/{self.id}" def _vpc_id_from_subnet(self): """Return VPC Id associated with the subnet. @@ -214,9 +218,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut of the subnets has already been checked. 
""" first_subnet_id = self.ip_addresses[0]["SubnetId"] - subnet_info = ec2_backends[self.region].get_all_subnets( - subnet_ids=[first_subnet_id] - )[0] + subnet_info = self.ec2_backend.get_all_subnets(subnet_ids=[first_subnet_id])[0] return subnet_info.vpc_id def _build_subnet_info(self): @@ -234,7 +236,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut eni_ids = [] for subnet, ip_info in self.subnets.items(): for ip_addr, eni_id in ip_info.items(): - eni_info = ec2_backends[self.region].create_network_interface( + eni_info = self.ec2_backend.create_network_interface( description=f"Route 53 Resolver: {self.id}:{eni_id}", group_ids=self.security_group_ids, interface_type="interface", @@ -250,7 +252,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut def delete_eni(self): """Delete the VPC ENI created for the subnet and IP combos.""" for eni_id in self.eni_ids: - ec2_backends[self.region].delete_network_interface(eni_id) + self.ec2_backend.delete_network_interface(eni_id) def description(self): """Return a dictionary of relevant info for this resolver endpoint.""" @@ -299,7 +301,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut eni_id = f"rni-{get_random_hex(17)}" self.subnets[ip_address["SubnetId"]][ip_address["Ip"]] = eni_id - eni_info = ec2_backends[self.region].create_network_interface( + eni_info = self.ec2_backend.create_network_interface( description=f"Route 53 Resolver: {self.id}:{eni_id}", group_ids=self.security_group_ids, interface_type="interface", @@ -326,9 +328,9 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut else: self.subnets[ip_address["SubnetId"]].pop(ip_address["Ip"]) for eni_id in self.eni_ids: - eni_info = ec2_backends[self.region].get_network_interface(eni_id) + eni_info = self.ec2_backend.get_network_interface(eni_id) if eni_info.private_ip_address == ip_address.get("Ip"): - 
ec2_backends[self.region].delete_network_interface(eni_id) + self.ec2_backend.delete_network_interface(eni_id) self.eni_ids.remove(eni_id) self.ip_address_count = len(self.ip_addresses) @@ -343,6 +345,8 @@ class Route53ResolverBackend(BaseBackend): self.resolver_rule_associations = {} # Key is resolver_rule_association_id) self.tagger = TaggingService() + self.ec2_backend = ec2_backends[self.account_id][self.region_name] + @staticmethod def default_vpc_endpoint_service(service_region, zones): """List of dicts representing default VPC endpoints for this service.""" @@ -350,18 +354,20 @@ class Route53ResolverBackend(BaseBackend): service_region, zones, "route53resolver" ) - def associate_resolver_rule(self, region, resolver_rule_id, name, vpc_id): + def associate_resolver_rule(self, resolver_rule_id, name, vpc_id): validate_args( [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)] ) associations = [ - x for x in self.resolver_rule_associations.values() if x.region == region + x + for x in self.resolver_rule_associations.values() + if x.region == self.region_name ] if len(associations) > ResolverRuleAssociation.MAX_RULE_ASSOCIATIONS_PER_REGION: # This error message was not verified to be the same for AWS. raise LimitExceededException( - f"Account '{get_account_id()}' has exceeded 'max-rule-association'" + f"Account '{self.account_id}' has exceeded 'max-rule-association'" ) if resolver_rule_id not in self.resolver_rules: @@ -369,7 +375,7 @@ class Route53ResolverBackend(BaseBackend): f"Resolver rule with ID '{resolver_rule_id}' does not exist." 
) - vpcs = ec2_backends[region].describe_vpcs() + vpcs = self.ec2_backend.describe_vpcs() if vpc_id not in [x.id for x in vpcs]: raise InvalidParameterException(f"The vpc ID '{vpc_id}' does not exist") @@ -386,14 +392,14 @@ class Route53ResolverBackend(BaseBackend): rule_association_id = f"rslvr-rrassoc-{get_random_hex(17)}" rule_association = ResolverRuleAssociation( - region, rule_association_id, resolver_rule_id, vpc_id, name + self.region_name, rule_association_id, resolver_rule_id, vpc_id, name ) self.resolver_rule_associations[rule_association_id] = rule_association return rule_association - @staticmethod - def _verify_subnet_ips(region, ip_addresses, initial=True): - """Perform additional checks on the IPAddresses. + def _verify_subnet_ips(self, ip_addresses, initial=True): + """ + Perform additional checks on the IPAddresses. NOTE: This does not include IPv6 addresses. """ @@ -407,9 +413,9 @@ class Route53ResolverBackend(BaseBackend): subnets = defaultdict(set) for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]: try: - subnet_info = ec2_backends[region].get_all_subnets( - subnet_ids=[subnet_id] - )[0] + subnet_info = self.ec2_backend.get_all_subnets(subnet_ids=[subnet_id])[ + 0 + ] except InvalidSubnetIdError as exc: raise InvalidParameterException( f"The subnet ID '{subnet_id}' does not exist" @@ -430,8 +436,7 @@ class Route53ResolverBackend(BaseBackend): ) subnets[subnet_id].add(ip_addr) - @staticmethod - def _verify_security_group_ids(region, security_group_ids): + def _verify_security_group_ids(self, security_group_ids): """Perform additional checks on the security groups.""" if len(security_group_ids) > 10: raise InvalidParameterException("Maximum of 10 security groups are allowed") @@ -443,7 +448,7 @@ class Route53ResolverBackend(BaseBackend): f"(expecting 'sg-...')" ) try: - ec2_backends[region].describe_security_groups(group_ids=[group_id]) + self.ec2_backend.describe_security_groups(group_ids=[group_id]) except 
InvalidSecurityGroupNotFoundError as exc: raise ResourceNotFoundException( f"The security group '{group_id}' does not exist" @@ -484,18 +489,18 @@ class Route53ResolverBackend(BaseBackend): endpoints = [x for x in self.resolver_endpoints.values() if x.region == region] if len(endpoints) > ResolverEndpoint.MAX_ENDPOINTS_PER_REGION: raise LimitExceededException( - f"Account '{get_account_id()}' has exceeded 'max-endpoints'" + f"Account '{self.account_id}' has exceeded 'max-endpoints'" ) for x in ip_addresses: if not x.get("Ip"): - subnet_info = ec2_backends[region].get_all_subnets( + subnet_info = self.ec2_backend.get_all_subnets( subnet_ids=[x["SubnetId"]] )[0] x["Ip"] = subnet_info.get_available_subnet_ip(self) - self._verify_subnet_ips(region, ip_addresses) - self._verify_security_group_ids(region, security_group_ids) + self._verify_subnet_ips(ip_addresses) + self._verify_security_group_ids(security_group_ids) if creator_request_id in [ x.creator_request_id for x in self.resolver_endpoints.values() ]: @@ -508,6 +513,7 @@ class Route53ResolverBackend(BaseBackend): f"rslvr-{'in' if direction == 'INBOUND' else 'out'}-{get_random_hex(17)}" ) resolver_endpoint = ResolverEndpoint( + self.account_id, region, endpoint_id, creator_request_id, @@ -553,7 +559,7 @@ class Route53ResolverBackend(BaseBackend): if len(rules) > ResolverRule.MAX_RULES_PER_REGION: # Did not verify that this is the actual error message. 
raise LimitExceededException( - f"Account '{get_account_id()}' has exceeded 'max-rules'" + f"Account '{self.account_id}' has exceeded 'max-rules'" ) # Per the AWS documentation and as seen with the AWS console, target @@ -601,6 +607,7 @@ class Route53ResolverBackend(BaseBackend): rule_id = f"rslvr-rr-{get_random_hex(17)}" resolver_rule = ResolverRule( + self.account_id, region, rule_id, creator_request_id, @@ -865,18 +872,16 @@ class Route53ResolverBackend(BaseBackend): resolver_endpoint.update_name(name) return resolver_endpoint - def associate_resolver_endpoint_ip_address( - self, region, resolver_endpoint_id, ip_address - ): + def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, ip_address): self._validate_resolver_endpoint_id(resolver_endpoint_id) resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] if not ip_address.get("Ip"): - subnet_info = ec2_backends[region].get_all_subnets( + subnet_info = self.ec2_backend.get_all_subnets( subnet_ids=[ip_address.get("SubnetId")] )[0] ip_address["Ip"] = subnet_info.get_available_subnet_ip(self) - self._verify_subnet_ips(region, [ip_address], False) + self._verify_subnet_ips([ip_address], False) resolver_endpoint.associate_ip_address(ip_address) return resolver_endpoint diff --git a/moto/route53resolver/responses.py b/moto/route53resolver/responses.py index f2ac8e95c..316fbaf01 100644 --- a/moto/route53resolver/responses.py +++ b/moto/route53resolver/responses.py @@ -11,10 +11,13 @@ from moto.route53resolver.validations import validate_args class Route53ResolverResponse(BaseResponse): """Handler for Route53Resolver requests and responses.""" + def __init__(self): + super().__init__(service_name="route53-resolver") + @property def route53resolver_backend(self): """Return backend instance specific for this region.""" - return route53resolver_backends[self.region] + return route53resolver_backends[self.current_account][self.region] def associate_resolver_rule(self): """Associate a Resolver rule 
with a VPC.""" @@ -23,7 +26,6 @@ class Route53ResolverResponse(BaseResponse): vpc_id = self._get_param("VPCId") resolver_rule_association = ( self.route53resolver_backend.associate_resolver_rule( - region=self.region, resolver_rule_id=resolver_rule_id, name=name, vpc_id=vpc_id, @@ -263,7 +265,6 @@ class Route53ResolverResponse(BaseResponse): resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint = ( self.route53resolver_backend.associate_resolver_endpoint_ip_address( - region=self.region, resolver_endpoint_id=resolver_endpoint_id, ip_address=ip_address, ) diff --git a/moto/s3/config.py b/moto/s3/config.py index 98bdfe2dd..acb7d57e3 100644 --- a/moto/s3/config.py +++ b/moto/s3/config.py @@ -8,6 +8,7 @@ from moto.s3 import s3_backends class S3ConfigQuery(ConfigQueryModel): def list_config_service_resources( self, + account_id, resource_ids, resource_name, limit, @@ -28,14 +29,14 @@ class S3ConfigQuery(ConfigQueryModel): # If no filter was passed in for resource names/ids then return them all: if not resource_ids and not resource_name: - bucket_list = list(self.backends["global"].buckets.keys()) + bucket_list = list(self.backends[account_id]["global"].buckets.keys()) else: # Match the resource name / ID: bucket_list = [] filter_buckets = [resource_name] if resource_name else resource_ids - for bucket in self.backends["global"].buckets.keys(): + for bucket in self.backends[account_id]["global"].buckets.keys(): if bucket in filter_buckets: bucket_list.append(bucket) @@ -45,7 +46,10 @@ class S3ConfigQuery(ConfigQueryModel): region_buckets = [] for bucket in bucket_list: - if self.backends["global"].buckets[bucket].region_name == region_filter: + if ( + self.backends[account_id]["global"].buckets[bucket].region_name + == region_filter + ): region_buckets.append(bucket) bucket_list = region_buckets @@ -80,7 +84,9 @@ class S3ConfigQuery(ConfigQueryModel): "type": "AWS::S3::Bucket", "id": bucket, "name": bucket, - "region": 
self.backends["global"].buckets[bucket].region_name, + "region": self.backends[account_id]["global"] + .buckets[bucket] + .region_name, } for bucket in bucket_list ], @@ -88,10 +94,15 @@ class S3ConfigQuery(ConfigQueryModel): ) def get_config_resource( - self, resource_id, resource_name=None, backend_region=None, resource_region=None + self, + account_id, + resource_id, + resource_name=None, + backend_region=None, + resource_region=None, ): # Get the bucket: - bucket = self.backends["global"].buckets.get(resource_id, {}) + bucket = self.backends[account_id]["global"].buckets.get(resource_id, {}) if not bucket: return diff --git a/moto/s3/models.py b/moto/s3/models.py index 84459b54a..eb599e0bf 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -17,13 +17,8 @@ import urllib.parse from bisect import insort from importlib import reload -from moto.core import ( - get_account_id, - BaseBackend, - BaseModel, - CloudFormationModel, - CloudWatchMetricProvider, -) +from moto.core import BaseBackend, BaseModel, CloudFormationModel +from moto.core import CloudWatchMetricProvider from moto.core.utils import ( iso_8601_datetime_without_milliseconds_s3, @@ -80,13 +75,6 @@ DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a" -def get_moto_s3_account_id(): - """This makes it easy for mocking AWS Account IDs when using AWS Config - -- Simply mock.patch get_account_id() here, and Config gets it for free. 
- """ - return get_account_id() - - class FakeDeleteMarker(BaseModel): def __init__(self, key): self.key = key @@ -108,6 +96,7 @@ class FakeKey(BaseModel, ManagedState): self, name, value, + account_id=None, storage="STANDARD", etag=None, is_versioned=False, @@ -131,6 +120,7 @@ class FakeKey(BaseModel, ManagedState): ], ) self.name = name + self.account_id = account_id self.last_modified = datetime.datetime.utcnow() self.acl = get_canned_acl("private") self.website_redirect_location = None @@ -289,7 +279,9 @@ class FakeKey(BaseModel, ManagedState): res["x-amz-object-lock-retain-until-date"] = self.lock_until if self.lock_mode: res["x-amz-object-lock-mode"] = self.lock_mode - tags = s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn) + tags = s3_backends[self.account_id]["global"].tagger.get_tag_dict_for_resource( + self.arn + ) if tags: res["x-amz-tagging-count"] = str(len(tags.keys())) @@ -865,8 +857,9 @@ class PublicAccessBlock(BaseModel): class FakeBucket(CloudFormationModel): - def __init__(self, name, region_name): + def __init__(self, name, account_id, region_name): self.name = name + self.account_id = account_id self.region_name = region_name self.keys = _VersionedKeyStore() self.multiparts = {} @@ -1150,7 +1143,7 @@ class FakeBucket(CloudFormationModel): raise InvalidNotificationDestination() # Send test events so the user can verify these notifications were set correctly - notifications.send_test_event(bucket=self) + notifications.send_test_event(account_id=self.account_id, bucket=self) def set_accelerate_configuration(self, accelerate_config): if self.accelerate_configuration is None and accelerate_config == "Suspended": @@ -1224,15 +1217,17 @@ class FakeBucket(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - bucket = s3_backends["global"].create_bucket(resource_name, 
region_name) + bucket = s3_backends[account_id]["global"].create_bucket( + resource_name, region_name + ) properties = cloudformation_json.get("Properties", {}) if "BucketEncryption" in properties: bucket_encryption = cfn_to_api_encryption(properties["BucketEncryption"]) - s3_backends["global"].put_bucket_encryption( + s3_backends[account_id]["global"].put_bucket_encryption( bucket_name=resource_name, encryption=bucket_encryption ) @@ -1240,7 +1235,12 @@ class FakeBucket(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] @@ -1249,11 +1249,14 @@ class FakeBucket(CloudFormationModel): if resource_name_property not in properties: properties[resource_name_property] = new_resource_name new_resource = cls.create_from_cloudformation_json( - properties[resource_name_property], cloudformation_json, region_name + properties[resource_name_property], + cloudformation_json, + account_id, + region_name, ) properties[resource_name_property] = original_resource.name cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return new_resource @@ -1262,16 +1265,16 @@ class FakeBucket(CloudFormationModel): bucket_encryption = cfn_to_api_encryption( properties["BucketEncryption"] ) - s3_backends["global"].put_bucket_encryption( + s3_backends[account_id]["global"].put_bucket_encryption( bucket_name=original_resource.name, encryption=bucket_encryption ) return original_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): - s3_backends["global"].delete_bucket(resource_name) + 
s3_backends[account_id]["global"].delete_bucket(resource_name) def to_config_dict(self): """Return the AWS Config JSON format of this S3 bucket. @@ -1296,7 +1299,9 @@ class FakeBucket(CloudFormationModel): "resourceCreationTime": str(self.creation_date), "relatedEvents": [], "relationships": [], - "tags": s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn), + "tags": s3_backends[self.account_id][ + "global" + ].tagger.get_tag_dict_for_resource(self.arn), "configuration": { "name": self.name, "owner": {"id": OWNER}, @@ -1445,9 +1450,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): # metric_providers["S3"] = self @classmethod - def get_cloudwatch_metrics(cls): + def get_cloudwatch_metrics(cls, account_id): metrics = [] - for name, bucket in s3_backends["global"].buckets.items(): + for name, bucket in s3_backends[account_id]["global"].buckets.items(): metrics.append( MetricDatum( namespace="AWS/S3", @@ -1485,7 +1490,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): raise BucketAlreadyExists(bucket=bucket_name) if not MIN_BUCKET_NAME_LENGTH <= len(bucket_name) <= MAX_BUCKET_NAME_LENGTH: raise InvalidBucketName() - new_bucket = FakeBucket(name=bucket_name, region_name=region_name) + new_bucket = FakeBucket( + name=bucket_name, account_id=self.account_id, region_name=region_name + ) self.buckets[bucket_name] = new_bucket @@ -1493,7 +1500,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): "version": "0", "bucket": {"name": bucket_name}, "request-id": "N4N7GDK58NMKJ12R", - "requester": get_account_id(), + "requester": self.account_id, "source-ip-address": "1.2.3.4", "reason": "PutObject", } @@ -1687,6 +1694,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): name=key_name, bucket_name=bucket_name, value=value, + account_id=self.account_id, storage=storage, etag=etag, is_versioned=bucket.is_versioned, @@ -1707,7 +1715,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): ] + [new_key] 
bucket.keys.setlist(key_name, keys) - notifications.send_event(notifications.S3_OBJECT_CREATE_PUT, bucket, new_key) + notifications.send_event( + self.account_id, notifications.S3_OBJECT_CREATE_PUT, bucket, new_key + ) return new_key @@ -2130,7 +2140,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): # Send notifications that an object was copied bucket = self.get_bucket(dest_bucket_name) - notifications.send_event(notifications.S3_OBJECT_CREATE_COPY, bucket, new_key) + notifications.send_event( + self.account_id, notifications.S3_OBJECT_CREATE_COPY, bucket, new_key + ) def put_bucket_acl(self, bucket_name, acl): bucket = self.get_bucket(bucket_name) diff --git a/moto/s3/notifications.py b/moto/s3/notifications.py index 0de29b852..12be09203 100644 --- a/moto/s3/notifications.py +++ b/moto/s3/notifications.py @@ -38,7 +38,7 @@ def _get_region_from_arn(arn): return arn.split(":")[3] -def send_event(event_name, bucket, key): +def send_event(account_id, event_name, bucket, key): if bucket.notification_configuration is None: return @@ -47,7 +47,7 @@ def send_event(event_name, bucket, key): event_body = _get_s3_event(event_name, bucket, key, notification.id) region_name = _get_region_from_arn(notification.arn) - _invoke_awslambda(event_body, notification.arn, region_name) + _invoke_awslambda(account_id, event_body, notification.arn, region_name) for notification in bucket.notification_configuration.queue: if notification.matches(event_name, key.name): @@ -55,14 +55,14 @@ def send_event(event_name, bucket, key): region_name = _get_region_from_arn(notification.arn) queue_name = notification.arn.split(":")[-1] - _send_sqs_message(event_body, queue_name, region_name) + _send_sqs_message(account_id, event_body, queue_name, region_name) -def _send_sqs_message(event_body, queue_name, region_name): +def _send_sqs_message(account_id, event_body, queue_name, region_name): try: from moto.sqs.models import sqs_backends - sqs_backend = sqs_backends[region_name] + 
sqs_backend = sqs_backends[account_id][region_name] sqs_backend.send_message( queue_name=queue_name, message_body=json.dumps(event_body) ) @@ -74,11 +74,11 @@ def _send_sqs_message(event_body, queue_name, region_name): pass -def _invoke_awslambda(event_body, fn_arn, region_name): +def _invoke_awslambda(account_id, event_body, fn_arn, region_name): try: from moto.awslambda.models import lambda_backends - lambda_backend = lambda_backends[region_name] + lambda_backend = lambda_backends[account_id][region_name] func = lambda_backend.get_function(fn_arn) func.invoke(json.dumps(event_body), dict(), dict()) except: # noqa @@ -99,10 +99,10 @@ def _get_test_event(bucket_name): } -def send_test_event(bucket): +def send_test_event(account_id, bucket): arns = [n.arn for n in bucket.notification_configuration.queue] for arn in set(arns): region_name = _get_region_from_arn(arn) queue_name = arn.split(":")[-1] message_body = _get_test_event(bucket.name) - _send_sqs_message(message_body, queue_name, region_name) + _send_sqs_message(account_id, message_body, queue_name, region_name) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 6d88bae44..5a59f1097 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -17,7 +17,6 @@ import xmltodict from moto.core.responses import BaseResponse from moto.core.utils import path_url -from moto.core import get_account_id from moto.s3bucket_path.utils import ( bucket_name_from_url as bucketpath_bucket_name_from_url, @@ -155,9 +154,12 @@ def is_delete_keys(request, path): class S3Response(BaseResponse): + def __init__(self): + super().__init__(service_name="s3") + @property def backend(self): - return s3_backends["global"] + return s3_backends[self.current_account]["global"] @property def should_autoescape(self): @@ -429,7 +431,11 @@ class S3Response(BaseResponse): if upload.key_name.startswith(prefix) ] template = self.response_template(S3_ALL_MULTIPARTS) - return template.render(bucket_name=bucket_name, uploads=multiparts) + 
return template.render( + bucket_name=bucket_name, + uploads=multiparts, + account_id=self.current_account, + ) elif "location" in querystring: location = self.backend.get_bucket_location(bucket_name) template = self.response_template(S3_BUCKET_LOCATION) @@ -2429,8 +2435,7 @@ S3_MULTIPART_COMPLETE_RESPONSE = """ """ -S3_ALL_MULTIPARTS = ( - """ +S3_ALL_MULTIPARTS = """ {{ bucket_name }} @@ -2442,9 +2447,7 @@ S3_ALL_MULTIPARTS = ( {{ upload.key_name }} {{ upload.id }} - arn:aws:iam::""" - + get_account_id() - + """:user/user1-11111a31-17b5-4fb7-9df5-b111111f13de + arn:aws:iam::{{ account_id }}:user/user1-11111a31-17b5-4fb7-9df5-b111111f13de user1-11111a31-17b5-4fb7-9df5-b111111f13de @@ -2457,7 +2460,6 @@ S3_ALL_MULTIPARTS = ( {% endfor %} """ -) S3_NO_POLICY = """ diff --git a/moto/s3control/config.py b/moto/s3control/config.py index 10914a21f..5d579adf8 100644 --- a/moto/s3control/config.py +++ b/moto/s3control/config.py @@ -7,12 +7,12 @@ from boto3 import Session from moto.core.exceptions import InvalidNextTokenException from moto.core.common_models import ConfigQueryModel from moto.s3control import s3control_backends -from moto.s3.models import get_moto_s3_account_id class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel): def list_config_service_resources( self, + account_id, resource_ids, resource_name, limit, @@ -29,19 +29,18 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel): return [], None pab = None - account_id = get_moto_s3_account_id() regions = [region for region in Session().get_available_regions("config")] # If a resource ID was passed in, then filter accordingly: if resource_ids: for resource_id in resource_ids: if account_id == resource_id: - pab = self.backends["global"].public_access_block + pab = self.backends[account_id]["global"].public_access_block break # Otherwise, just grab the one from the backend: if not resource_ids: - pab = self.backends["global"].public_access_block + pab = 
self.backends[account_id]["global"].public_access_block # If it's not present, then return nothing if not pab: @@ -95,18 +94,23 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel): ) def get_config_resource( - self, resource_id, resource_name=None, backend_region=None, resource_region=None + self, + account_id, + resource_id, + resource_name=None, + backend_region=None, + resource_region=None, ): + # Do we even have this defined? - if not self.backends["global"].public_access_block: + backend = self.backends[account_id]["global"] + if not backend.public_access_block: return None # Resource name can only ever be "" if it's supplied: if resource_name is not None and resource_name != "": return None - # Are we filtering based on region? - account_id = get_moto_s3_account_id() regions = [region for region in Session().get_available_regions("config")] # Is the resource ID correct?: @@ -138,9 +142,7 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel): "resourceId": account_id, "awsRegion": pab_region, "availabilityZone": "Not Applicable", - "configuration": self.backends[ - "global" - ].public_access_block.to_config_dict(), + "configuration": backend.public_access_block.to_config_dict(), "supplementaryConfiguration": {}, } diff --git a/moto/s3control/models.py b/moto/s3control/models.py index 2b3719d5d..f8b0951d8 100644 --- a/moto/s3control/models.py +++ b/moto/s3control/models.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import datetime -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import get_random_hex, BackendDict from moto.s3.exceptions import ( WrongPublicAccessBlockAccountIdError, @@ -14,13 +14,18 @@ from .exceptions import AccessPointNotFound, AccessPointPolicyNotFound class AccessPoint(BaseModel): def __init__( - self, name, bucket, vpc_configuration, public_access_block_configuration + self, + account_id, + name, + bucket, + 
vpc_configuration, + public_access_block_configuration, ): self.name = name self.alias = f"{name}-{get_random_hex(34)}-s3alias" self.bucket = bucket self.created = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f") - self.arn = f"arn:aws:s3:us-east-1:{get_account_id()}:accesspoint/{name}" + self.arn = f"arn:aws:s3:us-east-1:{account_id}:accesspoint/{name}" self.policy = None self.network_origin = "VPC" if vpc_configuration else "Internet" self.vpc_id = (vpc_configuration or {}).get("VpcId") @@ -50,7 +55,7 @@ class S3ControlBackend(BaseBackend): def get_public_access_block(self, account_id): # The account ID should equal the account id that is set for Moto: - if account_id != get_account_id(): + if account_id != self.account_id: raise WrongPublicAccessBlockAccountIdError() if not self.public_access_block: @@ -60,14 +65,14 @@ class S3ControlBackend(BaseBackend): def delete_public_access_block(self, account_id): # The account ID should equal the account id that is set for Moto: - if account_id != get_account_id(): + if account_id != self.account_id: raise WrongPublicAccessBlockAccountIdError() self.public_access_block = None def put_public_access_block(self, account_id, pub_block_config): # The account ID should equal the account id that is set for Moto: - if account_id != get_account_id(): + if account_id != self.account_id: raise WrongPublicAccessBlockAccountIdError() if not pub_block_config: @@ -89,7 +94,11 @@ class S3ControlBackend(BaseBackend): public_access_block_configuration, ): access_point = AccessPoint( - name, bucket, vpc_configuration, public_access_block_configuration + account_id, + name, + bucket, + vpc_configuration, + public_access_block_configuration, ) self.access_points[account_id][name] = access_point return access_point diff --git a/moto/s3control/responses.py b/moto/s3control/responses.py index 07cc33ee5..1c166d478 100644 --- a/moto/s3control/responses.py +++ b/moto/s3control/responses.py @@ -9,14 +9,16 @@ from .models import s3control_backends 
class S3ControlResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="s3control") + @property def backend(self): - return s3control_backends["global"] + return s3control_backends[self.current_account]["global"] @amzn_request_id - def public_access_block( - self, request, full_url, headers - ): # pylint: disable=unused-argument + def public_access_block(self, request, full_url, headers): + self.setup_class(request, full_url, headers) try: if request.method == "GET": return self.get_public_access_block(request) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index f5490d10a..e4c9c2523 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1,7 +1,7 @@ import json import os from datetime import datetime -from moto.core import get_account_id, BaseBackend, BaseModel, CloudFormationModel +from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.utils import BackendDict from moto.sagemaker import validators from moto.utilities.paginator import paginate @@ -45,6 +45,10 @@ PAGINATION_MODEL = { } +def arn_formatter(_type, _id, account_id, region_name): + return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}" + + class BaseObject(BaseModel): def camelCase(self, key): words = [] @@ -80,14 +84,15 @@ class FakeProcessingJob(BaseObject): processing_inputs, processing_job_name, processing_output_config, + account_id, region_name, role_arn, tags, stopping_condition, ): self.processing_job_name = processing_job_name - self.processing_job_arn = FakeProcessingJob.arn_formatter( - processing_job_name, region_name + self.processing_job_arn = arn_formatter( + "processing-job", processing_job_name, account_id, region_name ) now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -115,21 +120,11 @@ class FakeProcessingJob(BaseObject): def response_create(self): return {"ProcessingJobArn": self.processing_job_arn} - @staticmethod - def arn_formatter(endpoint_name, region_name): - return ( - 
"arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":processing-job/" - + endpoint_name - ) - class FakeTrainingJob(BaseObject): def __init__( self, + account_id, region_name, training_job_name, hyper_parameters, @@ -170,8 +165,8 @@ class FakeTrainingJob(BaseObject): self.debug_rule_configurations = debug_rule_configurations self.tensor_board_output_config = tensor_board_output_config self.experiment_config = experiment_config - self.training_job_arn = FakeTrainingJob.arn_formatter( - training_job_name, region_name + self.training_job_arn = arn_formatter( + "training-job", training_job_name, account_id, region_name ) self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" @@ -224,21 +219,11 @@ class FakeTrainingJob(BaseObject): def response_create(self): return {"TrainingJobArn": self.training_job_arn} - @staticmethod - def arn_formatter(endpoint_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":training-job/" - + endpoint_name - ) - class FakeEndpoint(BaseObject, CloudFormationModel): def __init__( self, + account_id, region_name, endpoint_name, endpoint_config_name, @@ -247,7 +232,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel): tags, ): self.endpoint_name = endpoint_name - self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name) + self.endpoint_arn = FakeEndpoint.arn_formatter( + endpoint_name, account_id, region_name + ) self.endpoint_config_name = endpoint_config_name self.production_variants = self._process_production_variants( production_variants @@ -308,15 +295,8 @@ class FakeEndpoint(BaseObject, CloudFormationModel): return {"EndpointArn": self.endpoint_arn} @staticmethod - def arn_formatter(endpoint_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":endpoint/" - + endpoint_name - ) + def arn_formatter(endpoint_name, account_id, region_name): + return 
arn_formatter("endpoint", endpoint_name, account_id, region_name) @property def physical_resource_id(self): @@ -345,9 +325,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - sagemaker_backend = sagemaker_backends[region_name] + sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] @@ -362,32 +342,41 @@ class FakeEndpoint(BaseObject, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): # Changes to the Endpoint will not change resource name cls.delete_from_cloudformation_json( - original_resource.endpoint_arn, cloudformation_json, region_name + original_resource.endpoint_arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( - original_resource.endpoint_name, cloudformation_json, region_name + original_resource.endpoint_name, + cloudformation_json, + account_id, + region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. 
endpoint_name = resource_name.split("/")[-1] - sagemaker_backends[region_name].delete_endpoint(endpoint_name) + sagemaker_backends[account_id][region_name].delete_endpoint(endpoint_name) class FakeEndpointConfig(BaseObject, CloudFormationModel): def __init__( self, + account_id, region_name, endpoint_config_name, production_variants, @@ -399,7 +388,7 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): self.endpoint_config_name = endpoint_config_name self.endpoint_config_arn = FakeEndpointConfig.arn_formatter( - endpoint_config_name, region_name + endpoint_config_name, account_id, region_name ) self.production_variants = production_variants or [] self.data_capture_config = data_capture_config or {} @@ -498,14 +487,9 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): return {"EndpointConfigArn": self.endpoint_config_arn} @staticmethod - def arn_formatter(model_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":endpoint-config/" - + model_name + def arn_formatter(endpoint_config_name, account_id, region_name): + return arn_formatter( + "endpoint-config", endpoint_config_name, account_id, region_name ) @property @@ -535,9 +519,9 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - sagemaker_backend = sagemaker_backends[region_name] + sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] @@ -554,32 +538,43 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + 
account_id, + region_name, ): # Most changes to the endpoint config will change resource name for EndpointConfigs cls.delete_from_cloudformation_json( - original_resource.endpoint_config_arn, cloudformation_json, region_name + original_resource.endpoint_config_arn, + cloudformation_json, + account_id, + region_name, ) new_resource = cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) return new_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. endpoint_config_name = resource_name.split("/")[-1] - sagemaker_backends[region_name].delete_endpoint_config(endpoint_config_name) + sagemaker_backends[account_id][region_name].delete_endpoint_config( + endpoint_config_name + ) class Model(BaseObject, CloudFormationModel): def __init__( self, + account_id, region_name, model_name, execution_role_arn, @@ -596,7 +591,9 @@ class Model(BaseObject, CloudFormationModel): self.vpc_config = vpc_config self.primary_container = primary_container self.execution_role_arn = execution_role_arn or "arn:test" - self.model_arn = self.arn_for_model_name(self.model_name, region_name) + self.model_arn = arn_formatter( + "model", self.model_name, account_id, region_name + ) @property def response_object(self): @@ -609,17 +606,6 @@ class Model(BaseObject, CloudFormationModel): def response_create(self): return {"ModelArn": self.model_arn} - @staticmethod - def arn_for_model_name(model_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":model/" - + model_name - ) - @property def physical_resource_id(self): return self.model_arn @@ 
-647,9 +633,9 @@ class Model(BaseObject, CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - sagemaker_backend = sagemaker_backends[region_name] + sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] @@ -668,27 +654,32 @@ class Model(BaseObject, CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): # Most changes to the model will change resource name for Models cls.delete_from_cloudformation_json( - original_resource.model_arn, cloudformation_json, region_name + original_resource.model_arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) return new_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. 
model_name = resource_name.split("/")[-1] - sagemaker_backends[region_name].delete_model(model_name) + sagemaker_backends[account_id][region_name].delete_model(model_name) class VpcConfig(BaseObject): @@ -723,6 +714,7 @@ class Container(BaseObject): class FakeSagemakerNotebookInstance(CloudFormationModel): def __init__( self, + account_id, region_name, notebook_instance_name, instance_type, @@ -759,6 +751,9 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): self.root_access = root_access self.status = None self.creation_time = self.last_modified_time = datetime.now() + self.arn = arn_formatter( + "notebook-instance", notebook_instance_name, account_id, region_name + ) self.start() def validate_volume_size_in_gb(self, volume_size_in_gb): @@ -813,17 +808,6 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): ) raise ValidationError(message=message) - @property - def arn(self): - return ( - "arn:aws:sagemaker:" - + self.region_name - + ":" - + str(get_account_id()) - + ":notebook-instance/" - + self.notebook_instance_name - ) - @property def url(self): return "{}.notebook.{}.sagemaker.aws".format( @@ -867,14 +851,14 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] instance_type = properties["InstanceType"] role_arn = properties["RoleArn"] - notebook = sagemaker_backends[region_name].create_notebook_instance( + notebook = sagemaker_backends[account_id][region_name].create_notebook_instance( notebook_instance_name=resource_name, instance_type=instance_type, role_arn=role_arn, @@ -883,34 +867,47 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, 
new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( - original_resource.arn, cloudformation_json, region_name + original_resource.arn, cloudformation_json, account_id, region_name ) new_resource = cls.create_from_cloudformation_json( - original_resource.notebook_instance_name, cloudformation_json, region_name + original_resource.notebook_instance_name, + cloudformation_json, + account_id, + region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. notebook_instance_name = resource_name.split("/")[-1] - backend = sagemaker_backends[region_name] + backend = sagemaker_backends[account_id][region_name] backend.stop_notebook_instance(notebook_instance_name) backend.delete_notebook_instance(notebook_instance_name) class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( - self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start + self, + account_id, + region_name, + notebook_instance_lifecycle_config_name, + on_create, + on_start, ): self.region_name = region_name self.notebook_instance_lifecycle_config_name = ( @@ -923,19 +920,14 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod ) self.notebook_instance_lifecycle_config_arn = ( FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( - self.notebook_instance_lifecycle_config_name, self.region_name + self.notebook_instance_lifecycle_config_name, account_id, region_name ) ) 
@staticmethod - def arn_formatter(notebook_instance_lifecycle_config_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":notebook-instance-lifecycle-configuration/" - + notebook_instance_lifecycle_config_name + def arn_formatter(name, account_id, region_name): + return arn_formatter( + "notebook-instance-lifecycle-configuration", name, account_id, region_name ) @property @@ -976,11 +968,11 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] - config = sagemaker_backends[ + config = sagemaker_backends[account_id][ region_name ].create_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=resource_name, @@ -991,31 +983,38 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( original_resource.notebook_instance_lifecycle_config_arn, cloudformation_json, + account_id, region_name, ) new_resource = cls.create_from_cloudformation_json( original_resource.notebook_instance_lifecycle_config_name, cloudformation_json, + account_id, region_name, ) return new_resource @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite 
SageMaker # using the name for most of its operations. config_name = resource_name.split("/")[-1] - backend = sagemaker_backends[region_name] + backend = sagemaker_backends[account_id][region_name] backend.delete_notebook_instance_lifecycle_config(config_name) @@ -1087,6 +1086,7 @@ class SageMakerModelBackend(BaseBackend): def create_model(self, **kwargs): model_obj = Model( + account_id=self.account_id, region_name=self.region_name, model_name=kwargs.get("ModelName"), execution_role_arn=kwargs.get("ExecutionRoleArn"), @@ -1103,10 +1103,8 @@ class SageMakerModelBackend(BaseBackend): model = self._models.get(model_name) if model: return model - message = "Could not find model '{}'.".format( - Model.arn_for_model_name(model_name, self.region_name) - ) - raise ValidationError(message=message) + arn = arn_formatter("model", model_name, self.account_id, self.region_name) + raise ValidationError(message=f"Could not find model '{arn}'.") def list_models(self): return self._models.values() @@ -1121,7 +1119,10 @@ class SageMakerModelBackend(BaseBackend): def create_experiment(self, experiment_name): experiment = FakeExperiment( - region_name=self.region_name, experiment_name=experiment_name, tags=[] + account_id=self.account_id, + region_name=self.region_name, + experiment_name=experiment_name, + tags=[], ) self.experiments[experiment_name] = experiment return experiment.response_create @@ -1310,6 +1311,7 @@ class SageMakerModelBackend(BaseBackend): def create_trial(self, trial_name, experiment_name): trial = FakeTrial( + account_id=self.account_id, region_name=self.region_name, trial_name=trial_name, experiment_name=experiment_name, @@ -1360,6 +1362,7 @@ class SageMakerModelBackend(BaseBackend): def create_trial_component(self, trial_component_name, trial_name): trial_component = FakeTrialComponent( + account_id=self.account_id, region_name=self.region_name, trial_component_name=trial_component_name, trial_name=trial_name, @@ -1382,7 +1385,9 @@ class 
SageMakerModelBackend(BaseBackend): return self.trial_components[trial_component_name].response_object except KeyError: message = "Could not find trial component '{}'.".format( - FakeTrialComponent.arn_formatter(trial_component_name, self.region_name) + FakeTrialComponent.arn_formatter( + trial_component_name, self.account_id, self.region_name + ) ) raise ValidationError(message=message) @@ -1407,7 +1412,7 @@ class SageMakerModelBackend(BaseBackend): self.trials[trial_name].trial_components.extend([trial_component_name]) else: raise ResourceNotFound( - message=f"Trial 'arn:aws:sagemaker:{self.region_name}:{get_account_id()}:experiment-trial/{trial_name}' does not exist." + message=f"Trial 'arn:aws:sagemaker:{self.region_name}:{self.account_id}:experiment-trial/{trial_name}' does not exist." ) if trial_component_name in self.trial_components.keys(): @@ -1436,8 +1441,8 @@ class SageMakerModelBackend(BaseBackend): ) return { - "TrialComponentArn": f"arn:aws:sagemaker:{self.region_name}:{get_account_id()}:experiment-trial-component/{trial_component_name}", - "TrialArn": f"arn:aws:sagemaker:{self.region_name}:{get_account_id()}:experiment-trial/{trial_name}", + "TrialComponentArn": f"arn:aws:sagemaker:{self.region_name}:{self.account_id}:experiment-trial-component/{trial_component_name}", + "TrialArn": f"arn:aws:sagemaker:{self.region_name}:{self.account_id}:experiment-trial/{trial_name}", } def create_notebook_instance( @@ -1460,6 +1465,7 @@ class SageMakerModelBackend(BaseBackend): self._validate_unique_notebook_instance_name(notebook_instance_name) notebook_instance = FakeSagemakerNotebookInstance( + account_id=self.account_id, region_name=self.region_name, notebook_instance_name=notebook_instance_name, instance_type=instance_type, @@ -1521,11 +1527,14 @@ class SageMakerModelBackend(BaseBackend): ): message = "Unable to create Notebook Instance Lifecycle Config {}. 
(Details: Notebook Instance Lifecycle Config already exists.)".format( FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( - notebook_instance_lifecycle_config_name, self.region_name + notebook_instance_lifecycle_config_name, + self.account_id, + self.region_name, ) ) raise ValidationError(message=message) lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig( + account_id=self.account_id, region_name=self.region_name, notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name, on_create=on_create, @@ -1546,7 +1555,9 @@ class SageMakerModelBackend(BaseBackend): except KeyError: message = "Unable to describe Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format( FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( - notebook_instance_lifecycle_config_name, self.region_name + notebook_instance_lifecycle_config_name, + self.account_id, + self.region_name, ) ) raise ValidationError(message=message) @@ -1561,7 +1572,9 @@ class SageMakerModelBackend(BaseBackend): except KeyError: message = "Unable to delete Notebook Instance Lifecycle Config '{}'. 
(Details: Notebook Instance Lifecycle Config does not exist.)".format( FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( - notebook_instance_lifecycle_config_name, self.region_name + notebook_instance_lifecycle_config_name, + self.account_id, + self.region_name, ) ) raise ValidationError(message=message) @@ -1575,6 +1588,7 @@ class SageMakerModelBackend(BaseBackend): kms_key_id, ): endpoint_config = FakeEndpointConfig( + account_id=self.account_id, region_name=self.region_name, endpoint_config_name=endpoint_config_name, production_variants=production_variants, @@ -1590,41 +1604,47 @@ class SageMakerModelBackend(BaseBackend): def validate_production_variants(self, production_variants): for production_variant in production_variants: if production_variant["ModelName"] not in self._models: - message = "Could not find model '{}'.".format( - Model.arn_for_model_name( - production_variant["ModelName"], self.region_name - ) + arn = arn_formatter( + "model", + production_variant["ModelName"], + self.account_id, + self.region_name, ) - raise ValidationError(message=message) + raise ValidationError(message=f"Could not find model '{arn}'.") def describe_endpoint_config(self, endpoint_config_name): try: return self.endpoint_configs[endpoint_config_name].response_object except KeyError: - message = "Could not find endpoint configuration '{}'.".format( - FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + arn = FakeEndpointConfig.arn_formatter( + endpoint_config_name, self.account_id, self.region_name + ) + raise ValidationError( + message=f"Could not find endpoint configuration '{arn}'." 
) - raise ValidationError(message=message) def delete_endpoint_config(self, endpoint_config_name): try: del self.endpoint_configs[endpoint_config_name] except KeyError: - message = "Could not find endpoint configuration '{}'.".format( - FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + arn = FakeEndpointConfig.arn_formatter( + endpoint_config_name, self.account_id, self.region_name + ) + raise ValidationError( + message=f"Could not find endpoint configuration '{arn}'." ) - raise ValidationError(message=message) def create_endpoint(self, endpoint_name, endpoint_config_name, tags): try: endpoint_config = self.describe_endpoint_config(endpoint_config_name) except KeyError: - message = "Could not find endpoint_config '{}'.".format( - FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + arn = FakeEndpointConfig.arn_formatter( + endpoint_config_name, self.account_id, self.region_name ) - raise ValidationError(message=message) + raise ValidationError(message=f"Could not find endpoint_config '{arn}'.") endpoint = FakeEndpoint( + account_id=self.account_id, region_name=self.region_name, endpoint_name=endpoint_name, endpoint_config_name=endpoint_config_name, @@ -1640,19 +1660,19 @@ class SageMakerModelBackend(BaseBackend): try: return self.endpoints[endpoint_name].response_object except KeyError: - message = "Could not find endpoint '{}'.".format( - FakeEndpoint.arn_formatter(endpoint_name, self.region_name) + arn = FakeEndpoint.arn_formatter( + endpoint_name, self.account_id, self.region_name ) - raise ValidationError(message=message) + raise ValidationError(message=f"Could not find endpoint '{arn}'.") def delete_endpoint(self, endpoint_name): try: del self.endpoints[endpoint_name] except KeyError: - message = "Could not find endpoint '{}'.".format( - FakeEndpoint.arn_formatter(endpoint_name, self.region_name) + arn = FakeEndpoint.arn_formatter( + endpoint_name, self.account_id, self.region_name ) - raise 
ValidationError(message=message) + raise ValidationError(message=f"Could not find endpoint '{arn}'.") def create_processing_job( self, @@ -1673,6 +1693,7 @@ class SageMakerModelBackend(BaseBackend): processing_inputs=processing_inputs, processing_job_name=processing_job_name, processing_output_config=processing_output_config, + account_id=self.account_id, region_name=self.region_name, role_arn=role_arn, stopping_condition=stopping_condition, @@ -1685,10 +1706,10 @@ class SageMakerModelBackend(BaseBackend): try: return self.processing_jobs[processing_job_name].response_object except KeyError: - message = "Could not find processing job '{}'.".format( - FakeProcessingJob.arn_formatter(processing_job_name, self.region_name) + arn = FakeProcessingJob.arn_formatter( + processing_job_name, self.account_id, self.region_name ) - raise ValidationError(message=message) + raise ValidationError(message=f"Could not find processing job '{arn}'.") def list_processing_jobs( self, @@ -1797,6 +1818,7 @@ class SageMakerModelBackend(BaseBackend): experiment_config, ): training_job = FakeTrainingJob( + account_id=self.account_id, region_name=self.region_name, training_job_name=training_job_name, hyper_parameters=hyper_parameters, @@ -1929,9 +1951,10 @@ class SageMakerModelBackend(BaseBackend): # Validate inputs endpoint = self.endpoints.get(endpoint_name, None) if not endpoint: - raise AWSValidationException( - f'Could not find endpoint "{FakeEndpoint.arn_formatter(endpoint_name, self.region_name)}".' 
+ arn = FakeEndpoint.arn_formatter( + endpoint_name, self.account_id, self.region_name ) + raise AWSValidationException(f'Could not find endpoint "{arn}".') names_checked = [] for variant_config in desired_weights_and_capacities: @@ -1973,9 +1996,11 @@ class SageMakerModelBackend(BaseBackend): class FakeExperiment(BaseObject): - def __init__(self, region_name, experiment_name, tags): + def __init__(self, account_id, region_name, experiment_name, tags): self.experiment_name = experiment_name - self.experiment_arn = FakeExperiment.arn_formatter(experiment_name, region_name) + self.experiment_arn = arn_formatter( + "experiment", experiment_name, account_id, region_name + ) self.tags = tags self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" @@ -1992,24 +2017,21 @@ class FakeExperiment(BaseObject): def response_create(self): return {"ExperimentArn": self.experiment_arn} - @staticmethod - def arn_formatter(experiment_arn, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":experiment/" - + experiment_arn - ) - class FakeTrial(BaseObject): def __init__( - self, region_name, trial_name, experiment_name, tags, trial_components + self, + account_id, + region_name, + trial_name, + experiment_name, + tags, + trial_components, ): self.trial_name = trial_name - self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name) + self.trial_arn = arn_formatter( + "experiment-trial", trial_name, account_id, region_name + ) self.tags = tags self.trial_components = trial_components self.experiment_name = experiment_name @@ -2028,23 +2050,12 @@ class FakeTrial(BaseObject): def response_create(self): return {"TrialArn": self.trial_arn} - @staticmethod - def arn_formatter(trial_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":experiment-trial/" - + trial_name - ) - class FakeTrialComponent(BaseObject): - def __init__(self, region_name, 
trial_component_name, trial_name, tags): + def __init__(self, account_id, region_name, trial_component_name, trial_name, tags): self.trial_component_name = trial_component_name self.trial_component_arn = FakeTrialComponent.arn_formatter( - trial_component_name, region_name + trial_component_name, account_id, region_name ) self.tags = tags self.trial_name = trial_name @@ -2063,14 +2074,9 @@ class FakeTrialComponent(BaseObject): return {"TrialComponentArn": self.trial_component_arn} @staticmethod - def arn_formatter(trial_component_name, region_name): - return ( - "arn:aws:sagemaker:" - + region_name - + ":" - + str(get_account_id()) - + ":experiment-trial-component/" - + trial_component_name + def arn_formatter(trial_component_name, account_id, region_name): + return arn_formatter( + "experiment-trial-component", trial_component_name, account_id, region_name ) diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index b4be4cd0e..195678502 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -12,9 +12,12 @@ def format_enum_error(value, attribute, allowed): class SageMakerResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="sagemaker") + @property def sagemaker_backend(self): - return sagemaker_backends[self.region] + return sagemaker_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/sdb/responses.py b/moto/sdb/responses.py index 9b8a5e044..1e03a7d61 100644 --- a/moto/sdb/responses.py +++ b/moto/sdb/responses.py @@ -3,9 +3,12 @@ from .models import sdb_backends class SimpleDBResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="sdb") + @property def sdb_backend(self): - return sdb_backends[self.region] + return sdb_backends[self.current_account][self.region] def create_domain(self): domain_name = self._get_param("DomainName") diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index 0e2e96622..91ff04760 
100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -53,6 +53,7 @@ class SecretsManager(BaseModel): class FakeSecret: def __init__( self, + account_id, region_name, secret_id, secret_string=None, @@ -67,7 +68,7 @@ class FakeSecret: ): self.secret_id = secret_id self.name = secret_id - self.arn = secret_arn(region_name, secret_id) + self.arn = secret_arn(account_id, region_name, secret_id) self.secret_string = secret_string self.secret_binary = secret_binary self.description = description @@ -391,6 +392,7 @@ class SecretsManagerBackend(BaseBackend): secret.versions[version_id] = secret_version else: secret = FakeSecret( + account_id=self.account_id, region_name=self.region_name, secret_id=secret_id, secret_string=secret_string, @@ -533,7 +535,7 @@ class SecretsManagerBackend(BaseBackend): if secret.rotation_lambda_arn: from moto.awslambda.models import lambda_backends - lambda_backend = lambda_backends[self.region_name] + lambda_backend = lambda_backends[self.account_id][self.region_name] request_headers = {} response_headers = {} @@ -673,7 +675,7 @@ class SecretsManagerBackend(BaseBackend): if not force_delete_without_recovery: raise SecretNotFoundException() else: - secret = FakeSecret(self.region_name, secret_id) + secret = FakeSecret(self.account_id, self.region_name, secret_id) arn = secret.arn name = secret.name deletion_date = datetime.datetime.utcnow() diff --git a/moto/secretsmanager/responses.py b/moto/secretsmanager/responses.py index 959f3339f..b662e019e 100644 --- a/moto/secretsmanager/responses.py +++ b/moto/secretsmanager/responses.py @@ -30,9 +30,12 @@ def _validate_filters(filters): class SecretsManagerResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="secretsmanager") + @property def backend(self): - return secretsmanager_backends[self.region] + return secretsmanager_backends[self.current_account][self.region] def get_secret_value(self): secret_id = self._get_param("SecretId") diff --git 
a/moto/secretsmanager/utils.py b/moto/secretsmanager/utils.py index b9a7671b3..e46b36717 100644 --- a/moto/secretsmanager/utils.py +++ b/moto/secretsmanager/utils.py @@ -2,8 +2,6 @@ import random import string import re -from moto.core import get_account_id - def random_password( password_length, @@ -63,10 +61,10 @@ def random_password( return password -def secret_arn(region, secret_id): +def secret_arn(account_id, region, secret_id): id_string = "".join(random.choice(string.ascii_letters) for _ in range(6)) - return "arn:aws:secretsmanager:{0}:{1}:secret:{2}-{3}".format( - region, get_account_id(), secret_id, id_string + return ( + f"arn:aws:secretsmanager:{region}:{account_id}:secret:{secret_id}-{id_string}" ) diff --git a/moto/servicediscovery/models.py b/moto/servicediscovery/models.py index d1bcd8655..2753d79e0 100644 --- a/moto/servicediscovery/models.py +++ b/moto/servicediscovery/models.py @@ -1,7 +1,7 @@ import random import string -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, unix_time from moto.utilities.tagging_service import TaggingService @@ -22,6 +22,7 @@ def random_id(size): class Namespace(BaseModel): def __init__( self, + account_id, region, name, ns_type, @@ -33,9 +34,7 @@ class Namespace(BaseModel): ): super().__init__() self.id = f"ns-{random_id(20)}" - self.arn = ( - f"arn:aws:servicediscovery:{region}:{get_account_id()}:namespace/{self.id}" - ) + self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:namespace/{self.id}" self.name = name self.type = ns_type self.creator_request_id = creator_request_id @@ -66,6 +65,7 @@ class Namespace(BaseModel): class Service(BaseModel): def __init__( self, + account_id, region, name, namespace_id, @@ -78,9 +78,7 @@ class Service(BaseModel): ): super().__init__() self.id = f"srv-{random_id(8)}" - self.arn = ( - f"arn:aws:servicediscovery:{region}:{get_account_id()}:service/{self.id}" - ) + self.arn = 
f"arn:aws:servicediscovery:{region}:{account_id}:service/{self.id}" self.name = name self.namespace_id = namespace_id self.description = description @@ -164,6 +162,7 @@ class ServiceDiscoveryBackend(BaseBackend): def create_http_namespace(self, name, creator_request_id, description, tags): namespace = Namespace( + account_id=self.account_id, region=self.region_name, name=name, ns_type="HTTP", @@ -235,6 +234,7 @@ class ServiceDiscoveryBackend(BaseBackend): dns_properties = (properties or {}).get("DnsProperties", {}) dns_properties["HostedZoneId"] = "hzi" namespace = Namespace( + account_id=self.account_id, region=self.region_name, name=name, ns_type="DNS_PRIVATE", @@ -258,6 +258,7 @@ class ServiceDiscoveryBackend(BaseBackend): dns_properties = (properties or {}).get("DnsProperties", {}) dns_properties["HostedZoneId"] = "hzi" namespace = Namespace( + account_id=self.account_id, region=self.region_name, name=name, ns_type="DNS_PUBLIC", @@ -287,6 +288,7 @@ class ServiceDiscoveryBackend(BaseBackend): service_type, ): service = Service( + account_id=self.account_id, region=self.region_name, name=name, namespace_id=namespace_id, diff --git a/moto/servicediscovery/responses.py b/moto/servicediscovery/responses.py index 1633d401d..55d43fe7a 100644 --- a/moto/servicediscovery/responses.py +++ b/moto/servicediscovery/responses.py @@ -6,10 +6,13 @@ from .models import servicediscovery_backends class ServiceDiscoveryResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="servicediscovery") + @property def servicediscovery_backend(self): """Return backend instance specific for this region.""" - return servicediscovery_backends[self.region] + return servicediscovery_backends[self.current_account][self.region] def list_namespaces(self): namespaces = self.servicediscovery_backend.list_namespaces() diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py index b643d51ab..2d55e1f33 100644 --- a/moto/ses/feedback.py +++ b/moto/ses/feedback.py @@ -1,5 +1,3 @@ 
-from moto.core import get_account_id - """ SES Feedback messages Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html @@ -12,7 +10,7 @@ COMMON_MAIL = { "source": "sender@example.com", "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com", "sourceIp": "127.0.3.0", - "sendingAccountId": get_account_id(), + "sendingAccountId": None, "destination": ["recipient@example.com"], "headersTruncated": False, "headers": [ diff --git a/moto/ses/models.py b/moto/ses/models.py index ceb66a1c5..bd3375eff 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -50,8 +50,9 @@ class SESFeedback(BaseModel): FORWARDING_ENABLED = "feedback_forwarding_enabled" @staticmethod - def generate_message(msg_type): + def generate_message(account_id, msg_type): msg = dict(COMMON_MAIL) + msg["mail"]["sendingAccountId"] = account_id if msg_type == SESFeedback.BOUNCE: msg["bounce"] = BOUNCE elif msg_type == SESFeedback.COMPLAINT: @@ -277,7 +278,7 @@ class SESBackend(BaseBackend): def __generate_feedback__(self, msg_type): """Generates the SNS message for the feedback""" - return SESFeedback.generate_message(msg_type) + return SESFeedback.generate_message(self.account_id, msg_type) def __process_sns_feedback__(self, source, destinations, region): domain = str(source) @@ -290,7 +291,9 @@ class SESBackend(BaseBackend): if sns_topic is not None: message = self.__generate_feedback__(msg_type) if message: - sns_backends[region].publish(message, arn=sns_topic) + sns_backends[self.account_id][region].publish( + message, arn=sns_topic + ) def send_raw_email(self, source, destinations, raw_data, region): if source is not None: diff --git a/moto/ses/responses.py b/moto/ses/responses.py index 245d6c26a..95116446c 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -6,9 +6,12 @@ from datetime import datetime class EmailResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="ses") + @property def backend(self): - 
return ses_backends["global"] + return ses_backends[self.current_account]["global"] def verify_email_identity(self): address = self.querystring.get("EmailAddress")[0] diff --git a/moto/sns/models.py b/moto/sns/models.py index f2f417a29..e024afce8 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -30,7 +30,6 @@ from .exceptions import ( ) from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164 -from moto.core import get_account_id DEFAULT_PAGE_SIZE = 100 MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB @@ -41,7 +40,7 @@ class Topic(CloudFormationModel): def __init__(self, name, sns_backend): self.name = name self.sns_backend = sns_backend - self.account_id = get_account_id() + self.account_id = sns_backend.account_id self.display_name = "" self.delivery_policy = "" self.kms_master_key_id = "" @@ -110,9 +109,9 @@ class Topic(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): - sns_backend = sns_backends[region_name] + sns_backend = sns_backends[account_id][region_name] properties = cloudformation_json["Properties"] topic = sns_backend.create_topic(resource_name) @@ -124,26 +123,29 @@ class Topic(CloudFormationModel): @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + original_resource.name, cloudformation_json, account_id, region_name ) return cls.create_from_cloudformation_json( - new_resource_name, cloudformation_json, region_name + new_resource_name, cloudformation_json, account_id, region_name ) @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, 
resource_name, cloudformation_json, account_id, region_name ): - sns_backend = sns_backends[region_name] + sns_backend = sns_backends[account_id][region_name] properties = cloudformation_json["Properties"] topic_name = properties.get(cls.cloudformation_name_type()) or resource_name - topic_arn = make_arn_for_topic( - get_account_id(), topic_name, sns_backend.region_name - ) + topic_arn = make_arn_for_topic(account_id, topic_name, sns_backend.region_name) subscriptions, _ = sns_backend.list_subscriptions(topic_arn) for subscription in subscriptions: sns_backend.unsubscribe(subscription.arn) @@ -177,7 +179,8 @@ class Topic(CloudFormationModel): class Subscription(BaseModel): - def __init__(self, topic, endpoint, protocol): + def __init__(self, account_id, topic, endpoint, protocol): + self.account_id = account_id self.topic = topic self.endpoint = endpoint self.protocol = protocol @@ -196,7 +199,7 @@ class Subscription(BaseModel): queue_name = self.endpoint.split(":")[-1] region = self.endpoint.split(":")[3] if self.attributes.get("RawMessageDelivery") != "true": - sqs_backends[region].send_message( + sqs_backends[self.account_id][region].send_message( queue_name, json.dumps( self.get_post_data( @@ -226,7 +229,7 @@ class Subscription(BaseModel): attr_type: type_value, } - sqs_backends[region].send_message( + sqs_backends[self.account_id][region].send_message( queue_name, message, message_attributes=raw_message_attributes, @@ -257,7 +260,7 @@ class Subscription(BaseModel): from moto.awslambda import lambda_backends - lambda_backends[region].send_sns_message( + lambda_backends[self.account_id][region].send_sns_message( function_name, message, subject=subject, qualifier=qualifier ) @@ -333,9 +336,7 @@ class Subscription(BaseModel): "SignatureVersion": "1", "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=", "SigningCertURL": 
"https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem", - "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:{}:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55".format( - get_account_id() - ), + "UnsubscribeURL": f"https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:{self.account_id}:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55", } if subject: post_data["Subject"] = subject @@ -345,30 +346,25 @@ class Subscription(BaseModel): class PlatformApplication(BaseModel): - def __init__(self, region, name, platform, attributes): + def __init__(self, account_id, region, name, platform, attributes): self.region = region self.name = name self.platform = platform self.attributes = attributes - - @property - def arn(self): - return "arn:aws:sns:{region}:{AccountId}:app/{platform}/{name}".format( - region=self.region, - platform=self.platform, - name=self.name, - AccountId=get_account_id(), - ) + self.arn = f"arn:aws:sns:{region}:{account_id}:app/{platform}/{name}" class PlatformEndpoint(BaseModel): - def __init__(self, region, application, custom_user_data, token, attributes): + def __init__( + self, account_id, region, application, custom_user_data, token, attributes + ): self.region = region self.application = application self.custom_user_data = custom_user_data self.token = token self.attributes = attributes self.id = uuid.uuid4() + self.arn = f"arn:aws:sns:{region}:{account_id}:endpoint/{self.application.platform}/{self.application.name}/{self.id}" self.messages = OrderedDict() self.__fixup_attributes() @@ -387,18 +383,6 @@ class PlatformEndpoint(BaseModel): def enabled(self): return json.loads(self.attributes.get("Enabled", "true").lower()) - @property - def arn(self): - return ( - "arn:aws:sns:{region}:{AccountId}:endpoint/{platform}/{name}/{id}".format( - region=self.region, - AccountId=get_account_id(), - 
platform=self.application.platform, - name=self.application.name, - id=self.id, - ) - ) - def publish(self, message): if not self.enabled: raise SnsEndpointDisabled("Endpoint %s disabled" % self.id) @@ -545,7 +529,7 @@ class SNSBackend(BaseBackend): if old_subscription: return old_subscription topic = self.get_topic(topic_arn) - subscription = Subscription(topic, endpoint, protocol) + subscription = Subscription(self.account_id, topic, endpoint, protocol) attributes = { "PendingConfirmation": "false", "ConfirmationWasAuthenticated": "true", @@ -553,7 +537,7 @@ class SNSBackend(BaseBackend): "TopicArn": topic_arn, "Protocol": protocol, "SubscriptionArn": subscription.arn, - "Owner": get_account_id(), + "Owner": self.account_id, "RawMessageDelivery": "false", } @@ -641,8 +625,10 @@ class SNSBackend(BaseBackend): message_id = endpoint.publish(message) return message_id - def create_platform_application(self, region, name, platform, attributes): - application = PlatformApplication(region, name, platform, attributes) + def create_platform_application(self, name, platform, attributes): + application = PlatformApplication( + self.account_id, self.region_name, name, platform, attributes + ) self.applications[application.arn] = application return application @@ -667,7 +653,7 @@ class SNSBackend(BaseBackend): self.platform_endpoints.pop(endpoint.arn) def create_platform_endpoint( - self, region, application, custom_user_data, token, attributes + self, application, custom_user_data, token, attributes ): for endpoint in self.platform_endpoints.values(): if token == endpoint.token: @@ -680,7 +666,12 @@ class SNSBackend(BaseBackend): "Duplicate endpoint token with different attributes: %s" % token ) platform_endpoint = PlatformEndpoint( - region, application, custom_user_data, token, attributes + self.account_id, + self.region_name, + application, + custom_user_data, + token, + attributes, ) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return 
platform_endpoint diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 4abc8713a..d138a4ff6 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -15,9 +15,12 @@ class SNSResponse(BaseResponse): ) OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$") + def __init__(self): + super().__init__(service_name="sns") + @property def backend(self): - return sns_backends[self.region] + return sns_backends[self.current_account][self.region] def _error(self, code, message, sender="Sender"): template = self.response_template(ERROR_RESPONSE) @@ -410,7 +413,7 @@ class SNSResponse(BaseResponse): platform = self._get_param("Platform") attributes = self._get_attributes() platform_application = self.backend.create_platform_application( - self.region, name, platform, attributes + name, platform, attributes ) if self.request_json: @@ -525,7 +528,7 @@ class SNSResponse(BaseResponse): attributes = self._get_attributes() platform_endpoint = self.backend.create_platform_endpoint( - self.region, application, custom_user_data, token, attributes + application, custom_user_data, token, attributes ) if self.request_json: diff --git a/moto/sqs/models.py b/moto/sqs/models.py index d92a0400f..32abbb1df 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -38,8 +38,6 @@ from .exceptions import ( InvalidAttributeValue, ) -from moto.core import get_account_id - DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB @@ -250,9 +248,10 @@ class Queue(CloudFormationModel): "SendMessage", ) - def __init__(self, name, region, **kwargs): + def __init__(self, name, region, account_id, **kwargs): self.name = name self.region = region + self.account_id = account_id self.tags = {} self.permissions = {} @@ -262,9 +261,7 @@ class Queue(CloudFormationModel): now = unix_time() self.created_timestamp = now - self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format( - self.region, get_account_id(), self.name - ) + self.queue_arn = 
f"arn:aws:sqs:{region}:{account_id}:{name}" self.dead_letter_queue = None self.lambda_event_source_mappings = {} @@ -389,7 +386,8 @@ class Queue(CloudFormationModel): self.redrive_policy["maxReceiveCount"] ) - for queue in sqs_backends[self.region].queues.values(): + sqs_backend = sqs_backends[self.account_id][self.region] + for queue in sqs_backend.queues.values(): if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]: self.dead_letter_queue = queue @@ -418,7 +416,7 @@ class Queue(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = deepcopy(cloudformation_json["Properties"]) # remove Tags from properties and convert tags list to dict @@ -428,19 +426,24 @@ class Queue(CloudFormationModel): # Could be passed as an integer - just treat it as a string resource_name = str(resource_name) - sqs_backend = sqs_backends[region_name] + sqs_backend = sqs_backends[account_id][region_name] return sqs_backend.create_queue( name=resource_name, tags=tags_dict, region=region_name, **properties ) @classmethod def update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json["Properties"] queue_name = original_resource.name - sqs_backend = sqs_backends[region_name] + sqs_backend = sqs_backends[account_id][region_name] queue = sqs_backend.get_queue(queue_name) if "VisibilityTimeout" in properties: queue.visibility_timeout = int(properties["VisibilityTimeout"]) @@ -453,12 +456,12 @@ class Queue(CloudFormationModel): @classmethod def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name + cls, resource_name, cloudformation_json, account_id, region_name ): # ResourceName will be the 
full queue URL - we only need the name # https://sqs.us-west-1.amazonaws.com/123456789012/queue_name queue_name = resource_name.split("/")[-1] - sqs_backend = sqs_backends[region_name] + sqs_backend = sqs_backends[account_id][region_name] sqs_backend.delete_queue(queue_name) @property @@ -475,7 +478,7 @@ class Queue(CloudFormationModel): @property def physical_resource_id(self): - return f"https://sqs.{self.region}.amazonaws.com/{get_account_id()}/{self.name}" + return f"https://sqs.{self.region}.amazonaws.com/{self.account_id}/{self.name}" @property def attributes(self): @@ -509,7 +512,7 @@ class Queue(CloudFormationModel): def url(self, request_url): return "{0}://{1}/{2}/{3}".format( - request_url.scheme, request_url.netloc, get_account_id(), self.name + request_url.scheme, request_url.netloc, self.account_id, self.name ) @property @@ -537,7 +540,7 @@ class Queue(CloudFormationModel): self._messages.append(message) for arn, esm in self.lambda_event_source_mappings.items(): - backend = sqs_backends[self.region] + backend = sqs_backends[self.account_id][self.region] """ Lambda polls the queue and invokes your function synchronously with an event @@ -554,7 +557,7 @@ class Queue(CloudFormationModel): from moto.awslambda import lambda_backends - result = lambda_backends[self.region].send_sqs_batch( + result = lambda_backends[self.account_id][self.region].send_sqs_batch( arn, messages, self.queue_arn ) @@ -650,7 +653,9 @@ class SQSBackend(BaseBackend): except KeyError: pass - new_queue = Queue(name, region=self.region_name, **kwargs) + new_queue = Queue( + name, region=self.region_name, account_id=self.account_id, **kwargs + ) queue_attributes = queue.attributes new_queue_attributes = new_queue.attributes @@ -665,7 +670,9 @@ class SQSBackend(BaseBackend): kwargs.pop("region") except KeyError: pass - queue = Queue(name, region=self.region_name, **kwargs) + queue = Queue( + name, region=self.region_name, account_id=self.account_id, **kwargs + ) self.queues[name] = queue 
if tags: diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index c6dea2601..28fd57e85 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -28,9 +28,12 @@ class SQSResponse(BaseResponse): region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") + def __init__(self): + super().__init__(service_name="sqs") + @property def sqs_backend(self): - return sqs_backends[self.region] + return sqs_backends[self.current_account][self.region] @property def attribute(self): diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 19c2aabb0..77846f1cc 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -4,7 +4,7 @@ from typing import Dict from collections import defaultdict -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.core.utils import BackendDict from moto.ec2 import ec2_backends @@ -46,11 +46,12 @@ from .exceptions import ( class ParameterDict(defaultdict): - def __init__(self, region_name): + def __init__(self, account_id, region_name): # each value is a list of all of the versions for a parameter # to get the current value, grab the last item of the list super().__init__(list) self.parameters_loaded = False + self.account_id = account_id self.region_name = region_name def _check_loading_status(self, key): @@ -73,6 +74,7 @@ class ParameterDict(defaultdict): version = 1 super().__getitem__(name).append( Parameter( + account_id=self.account_id, name=name, value=value, parameter_type=parameter_type, @@ -87,14 +89,15 @@ class ParameterDict(defaultdict): self.parameters_loaded = True def _get_secretsmanager_parameter(self, secret_name): - secret = secretsmanager_backends[self.region_name].describe_secret(secret_name) + secrets_backend = secretsmanager_backends[self.account_id][self.region_name] + secret = secrets_backend.describe_secret(secret_name) version_id_to_stage = secret["VersionIdsToStages"] # Sort version ID's so that 
AWSCURRENT is last sorted_version_ids = [ k for k in version_id_to_stage if "AWSCURRENT" not in version_id_to_stage[k] ] + [k for k in version_id_to_stage if "AWSCURRENT" in version_id_to_stage[k]] values = [ - secretsmanager_backends[self.region_name].get_secret_value( + secrets_backend.get_secret_value( secret_name, version_id=version_id, version_stage=None, @@ -103,6 +106,7 @@ class ParameterDict(defaultdict): ] return [ Parameter( + account_id=self.account_id, name=secret["Name"], value=val.get("SecretString"), parameter_type="SecureString", @@ -153,6 +157,7 @@ PARAMETER_HISTORY_MAX_RESULTS = 50 class Parameter(BaseModel): def __init__( self, + account_id, name, value, parameter_type, @@ -166,6 +171,7 @@ class Parameter(BaseModel): labels=None, source_result=None, ): + self.account_id = account_id self.name = name self.type = parameter_type self.description = description @@ -210,7 +216,7 @@ class Parameter(BaseModel): r["SourceResult"] = self.source_result if region: - r["ARN"] = parameter_arn(region, self.name) + r["ARN"] = parameter_arn(self.account_id, region, self.name) return r @@ -426,6 +432,7 @@ class Documents(BaseModel): class Document(BaseModel): def __init__( self, + account_id, name, version_name, content, @@ -447,7 +454,7 @@ class Document(BaseModel): self.status = "Active" self.document_version = document_version - self.owner = get_account_id() + self.owner = account_id self.created_date = datetime.datetime.utcnow() if document_format == "JSON": @@ -523,6 +530,7 @@ class Document(BaseModel): class Command(BaseModel): def __init__( self, + account_id, comment="", document_name="", timeout_seconds=MAX_TIMEOUT_SECONDS, @@ -554,6 +562,7 @@ class Command(BaseModel): self.command_id = str(uuid.uuid4()) self.status = "Success" self.status_details = "Details placeholder" + self.account_id = account_id self.requested_date_time = datetime.datetime.now() self.requested_date_time_iso = self.requested_date_time.isoformat() @@ -598,7 +607,7 @@ class 
Command(BaseModel): def _get_instance_ids_from_targets(self): target_instance_ids = [] - ec2_backend = ec2_backends[self.backend_region] + ec2_backend = ec2_backends[self.account_id][self.backend_region] ec2_filters = {target["Key"]: target["Values"] for target in self.targets} reservations = ec2_backend.all_reservations(filters=ec2_filters) for reservation in reservations: @@ -735,7 +744,7 @@ def _document_filter_list_includes_comparator(keyed_value_list, _filter): return False -def _document_filter_match(filters, ssm_doc): +def _document_filter_match(account_id, filters, ssm_doc): for _filter in filters: if _filter["Key"] == "Name" and not _document_filter_equal_comparator( ssm_doc.name, _filter @@ -747,7 +756,7 @@ def _document_filter_match(filters, ssm_doc): raise ValidationException("Owner filter can only have one value.") if _filter["Values"][0] == "Self": # Update to running account ID - _filter["Values"][0] = get_account_id() + _filter["Values"][0] = account_id if not _document_filter_equal_comparator(ssm_doc.owner, _filter): return False @@ -840,7 +849,7 @@ class SimpleSystemManagerBackend(BaseBackend): def __init__(self, region_name, account_id): super().__init__(region_name, account_id) - self._parameters = ParameterDict(region_name) + self._parameters = ParameterDict(account_id, region_name) self._resource_tags = defaultdict(lambda: defaultdict(dict)) self._commands = [] @@ -918,6 +927,7 @@ class SimpleSystemManagerBackend(BaseBackend): tags, ): ssm_document = Document( + account_id=self.account_id, name=name, version_name=version_name, content=content, @@ -1065,6 +1075,7 @@ class SimpleSystemManagerBackend(BaseBackend): new_version = str(int(documents.latest_version) + 1) new_ssm_document = Document( + account_id=self.account_id, name=name, version_name=version_name, content=content, @@ -1115,7 +1126,9 @@ class SimpleSystemManagerBackend(BaseBackend): continue ssm_doc = documents.get_default_version() - if filters and not 
_document_filter_match(filters, ssm_doc): + if filters and not _document_filter_match( + self.account_id, filters, ssm_doc + ): # If we have filters enabled, and we don't match them, continue else: @@ -1761,6 +1774,7 @@ class SimpleSystemManagerBackend(BaseBackend): last_modified_date = time.time() self._parameters[name].append( Parameter( + account_id=self.account_id, name=name, value=value, parameter_type=parameter_type, @@ -1821,6 +1835,7 @@ class SimpleSystemManagerBackend(BaseBackend): def send_command(self, **kwargs): command = Command( + account_id=self.account_id, comment=kwargs.get("Comment", ""), document_name=kwargs.get("DocumentName"), timeout_seconds=kwargs.get("TimeoutSeconds", 3600), diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index 6b7cf7025..956cb4f49 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -6,9 +6,12 @@ from .models import ssm_backends class SimpleSystemManagerResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="ssm") + @property def ssm_backend(self): - return ssm_backends[self.region] + return ssm_backends[self.current_account][self.region] @property def request_params(self): diff --git a/moto/ssm/utils.py b/moto/ssm/utils.py index 6641fef8d..121ce019c 100644 --- a/moto/ssm/utils.py +++ b/moto/ssm/utils.py @@ -1,12 +1,7 @@ -from moto.core import get_account_id - - -def parameter_arn(region, parameter_name): +def parameter_arn(account_id, region, parameter_name): if parameter_name[0] == "/": parameter_name = parameter_name[1:] - return "arn:aws:ssm:{0}:{1}:parameter/{2}".format( - region, get_account_id(), parameter_name - ) + return f"arn:aws:ssm:{region}:{account_id}:parameter/{parameter_name}" def convert_to_tree(parameters): diff --git a/moto/ssoadmin/responses.py b/moto/ssoadmin/responses.py index 8988b9a7f..e5a70a9be 100644 --- a/moto/ssoadmin/responses.py +++ b/moto/ssoadmin/responses.py @@ -9,10 +9,13 @@ from .models import ssoadmin_backends class 
SSOAdminResponse(BaseResponse): """Handler for SSOAdmin requests and responses.""" + def __init__(self): + super().__init__(service_name="sso-admin") + @property def ssoadmin_backend(self): """Return backend instance specific for this region.""" - return ssoadmin_backends[self.region] + return ssoadmin_backends[self.current_account][self.region] def create_account_assignment(self): params = json.loads(self.body) diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index e4180d892..aebf5c18a 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -3,7 +3,7 @@ import re from datetime import datetime from dateutil.tz import tzlocal -from moto.core import get_account_id, BaseBackend, CloudFormationModel +from moto.core import BaseBackend, CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds, BackendDict from uuid import uuid4 from .exceptions import ( @@ -159,24 +159,29 @@ class StateMachine(CloudFormationModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name, **kwargs + cls, resource_name, cloudformation_json, account_id, region_name, **kwargs ): properties = cloudformation_json["Properties"] name = properties.get("StateMachineName", resource_name) definition = properties.get("DefinitionString", "") role_arn = properties.get("RoleArn", "") tags = cfn_to_api_tags(properties.get("Tags", [])) - sf_backend = stepfunction_backends[region_name] + sf_backend = stepfunction_backends[account_id][region_name] return sf_backend.create_state_machine(name, definition, role_arn, tags=tags) @classmethod - def delete_from_cloudformation_json(cls, resource_name, _, region_name): - sf_backend = stepfunction_backends[region_name] + def delete_from_cloudformation_json(cls, resource_name, _, account_id, region_name): + sf_backend = stepfunction_backends[account_id][region_name] sf_backend.delete_state_machine(resource_name) @classmethod def 
update_from_cloudformation_json( - cls, original_resource, new_resource_name, cloudformation_json, region_name + cls, + original_resource, + new_resource_name, + cloudformation_json, + account_id, + region_name, ): properties = cloudformation_json.get("Properties", {}) name = properties.get("StateMachineName", original_resource.name) @@ -186,10 +191,10 @@ class StateMachine(CloudFormationModel): new_properties = original_resource.get_cfn_properties(properties) cloudformation_json["Properties"] = new_properties new_resource = cls.create_from_cloudformation_json( - name, cloudformation_json, region_name + name, cloudformation_json, account_id, region_name ) cls.delete_from_cloudformation_json( - original_resource.arn, cloudformation_json, region_name + original_resource.arn, cloudformation_json, account_id, region_name ) return new_resource @@ -198,7 +203,7 @@ class StateMachine(CloudFormationModel): definition = properties.get("DefinitionString") role_arn = properties.get("RoleArn") tags = cfn_to_api_tags(properties.get("Tags", [])) - sf_backend = stepfunction_backends[region_name] + sf_backend = stepfunction_backends[account_id][region_name] state_machine = sf_backend.update_state_machine( original_resource.arn, definition=definition, role_arn=role_arn ) @@ -451,14 +456,7 @@ class StepFunctionBackend(BaseBackend): def create_state_machine(self, name, definition, roleArn, tags=None): self._validate_name(name) self._validate_role_arn(roleArn) - arn = ( - "arn:aws:states:" - + self.region_name - + ":" - + str(self._get_account_id()) - + ":stateMachine:" - + name - ) + arn = f"arn:aws:states:{self.region_name}:{self.account_id}:stateMachine:{name}" try: return self.describe_state_machine(arn) except StateMachineDoesNotExist: @@ -499,7 +497,7 @@ class StepFunctionBackend(BaseBackend): state_machine = self.describe_state_machine(state_machine_arn) execution = state_machine.start_execution( region_name=self.region_name, - account_id=self._get_account_id(), + 
account_id=self.account_id, execution_name=name or str(uuid4()), execution_input=execution_input, ) @@ -613,8 +611,5 @@ class StepFunctionBackend(BaseBackend): ) return self.describe_state_machine(state_machine_arn) - def _get_account_id(self): - return get_account_id() - stepfunction_backends = BackendDict(StepFunctionBackend, "stepfunctions") diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index e7bc2c571..540a38d06 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -6,9 +6,12 @@ from .models import stepfunction_backends class StepFunctionResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="stepfunctions") + @property def stepfunction_backend(self): - return stepfunction_backends[self.region] + return stepfunction_backends[self.current_account][self.region] @amzn_request_id def create_state_machine(self): diff --git a/moto/sts/models.py b/moto/sts/models.py index 10109897c..36409b475 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -1,12 +1,11 @@ from base64 import b64decode import datetime +import re import xmltodict from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds, BackendDict -from moto.core import get_account_id +from moto.iam import iam_backends from moto.sts.utils import ( - random_access_key_id, - random_secret_access_key, random_session_token, random_assumed_role_id, DEFAULT_STS_SESSION_DURATION, @@ -27,15 +26,26 @@ class Token(BaseModel): class AssumedRole(BaseModel): - def __init__(self, role_session_name, role_arn, policy, duration, external_id): + def __init__( + self, + account_id, + access_key, + role_session_name, + role_arn, + policy, + duration, + external_id, + ): + self.account_id = account_id self.session_name = role_session_name self.role_arn = role_arn self.policy = policy now = datetime.datetime.utcnow() self.expiration = now + datetime.timedelta(seconds=duration) self.external_id 
= external_id - self.access_key_id = "ASIA" + random_access_key_id() - self.secret_access_key = random_secret_access_key() + self.access_key = access_key + self.access_key_id = access_key.access_key_id + self.secret_access_key = access_key.secret_access_key self.session_token = random_session_token() self.assumed_role_id = "AROA" + random_assumed_role_id() @@ -51,7 +61,7 @@ class AssumedRole(BaseModel): def arn(self): return ( "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=get_account_id(), + account_id=self.account_id, role_name=self.role_arn.split("/")[-1], session_name=self.session_name, ) @@ -78,8 +88,20 @@ class STSBackend(BaseBackend): token = Token(duration=duration, name=name) return token - def assume_role(self, **kwargs): - role = AssumedRole(**kwargs) + def assume_role(self, role_session_name, role_arn, policy, duration, external_id): + """ + Assume an IAM Role. Note that the role does not need to exist. The ARN can point to another account, providing an opportunity to switch accounts. 
+ """ + account_id, access_key = self._create_access_key(role=role_arn) + role = AssumedRole( + account_id, + access_key, + role_session_name, + role_arn, + policy, + duration, + external_id, + ) self.assumed_roles.append(role) return role @@ -109,6 +131,7 @@ class STSBackend(BaseBackend): namespace_separator="|", ) + target_role = None saml_assertion_attributes = saml_assertion["samlp|Response"]["saml|Assertion"][ "saml|AttributeStatement" ]["saml|Attribute"] @@ -123,20 +146,45 @@ class STSBackend(BaseBackend): == "https://aws.amazon.com/SAML/Attributes/SessionDuration" ): kwargs["duration"] = int(attribute["saml|AttributeValue"]["#text"]) + if attribute["@Name"] == "https://aws.amazon.com/SAML/Attributes/Role": + target_role = attribute["saml|AttributeValue"]["#text"].split(",")[0] if "duration" not in kwargs: kwargs["duration"] = DEFAULT_STS_SESSION_DURATION + account_id, access_key = self._create_access_key(role=target_role) + kwargs["account_id"] = account_id + kwargs["access_key"] = access_key + kwargs["external_id"] = None kwargs["policy"] = None role = AssumedRole(**kwargs) self.assumed_roles.append(role) return role - def get_caller_identity(self): - # Logic resides in responses.py - # Fake method here to make implementation coverage script aware that this method is implemented - pass + def get_caller_identity(self, access_key_id): + assumed_role = self.get_assumed_role_from_access_key(access_key_id) + if assumed_role: + return assumed_role.user_id, assumed_role.arn, assumed_role.account_id + + iam_backend = iam_backends[self.account_id]["global"] + user = iam_backend.get_user_from_access_key_id(access_key_id) + if user: + return user.id, user.arn, user.account_id + + # Default values in case the request does not use valid credentials generated by moto + user_id = "AKIAIOSFODNN7EXAMPLE" + arn = f"arn:aws:sts::{self.account_id}:user/moto" + return user_id, arn, self.account_id + + def _create_access_key(self, role): + account_id_match = 
re.search(r"arn:aws:iam::([0-9]+).+", role) + if account_id_match: + account_id = account_id_match.group(1) + else: + account_id = self.account_id + iam_backend = iam_backends[account_id]["global"] + return account_id, iam_backend.create_temp_access_key() sts_backends: Mapping[str, STSBackend] = BackendDict( diff --git a/moto/sts/responses.py b/moto/sts/responses.py index 0622ac82c..f16043941 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -1,6 +1,4 @@ from moto.core.responses import BaseResponse -from moto.core import get_account_id -from moto.iam import iam_backends from .exceptions import STSValidationError from .models import sts_backends @@ -8,9 +6,12 @@ MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 class TokenResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="sts") + @property def backend(self): - return sts_backends["global"] + return sts_backends[self.current_account]["global"] def get_session_token(self): duration = int(self.querystring.get("DurationSeconds", [43200])[0]) @@ -33,7 +34,7 @@ class TokenResponse(BaseResponse): name = self.querystring.get("Name")[0] token = self.backend.get_federation_token(duration=duration, name=name) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) - return template.render(token=token, account_id=get_account_id()) + return template.render(token=token, account_id=self.current_account) def assume_role(self): role_session_name = self.querystring.get("RoleSessionName")[0] @@ -87,22 +88,10 @@ class TokenResponse(BaseResponse): def get_caller_identity(self): template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) - # Default values in case the request does not use valid credentials generated by moto - user_id = "AKIAIOSFODNN7EXAMPLE" - arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=get_account_id()) + access_key_id = self.get_access_key() + user_id, arn, account_id = self.backend.get_caller_identity(access_key_id) - access_key_id = 
self.get_current_user() - assumed_role = self.backend.get_assumed_role_from_access_key(access_key_id) - if assumed_role: - user_id = assumed_role.user_id - arn = assumed_role.arn - - user = iam_backends["global"].get_user_from_access_key_id(access_key_id) - if user: - user_id = user.id - arn = user.arn - - return template.render(account_id=get_account_id(), user_id=user_id, arn=arn) + return template.render(account_id=account_id, user_id=user_id, arn=arn) GET_SESSION_TOKEN_RESPONSE = """ diff --git a/moto/sts/utils.py b/moto/sts/utils.py index 668c5d0e1..fd3f2ac87 100644 --- a/moto/sts/utils.py +++ b/moto/sts/utils.py @@ -9,14 +9,6 @@ SESSION_TOKEN_PREFIX = "FQoGZXIvYXdzEBYaD" DEFAULT_STS_SESSION_DURATION = 3600 -def random_access_key_id(): - return ACCOUNT_SPECIFIC_ACCESS_KEY_PREFIX + _random_uppercase_or_digit_sequence(8) - - -def random_secret_access_key(): - return base64.b64encode(os.urandom(30)).decode() - - def random_session_token(): return ( SESSION_TOKEN_PREFIX diff --git a/moto/support/responses.py b/moto/support/responses.py index 3246b9164..a400dc36c 100644 --- a/moto/support/responses.py +++ b/moto/support/responses.py @@ -4,11 +4,12 @@ import json class SupportResponse(BaseResponse): - SERVICE_NAME = "support" + def __init__(self): + super().__init__(service_name="support") @property def support_backend(self): - return support_backends[self.region] + return support_backends[self.current_account][self.region] def describe_trusted_advisor_checks(self): checks = self.support_backend.describe_trusted_advisor_checks() diff --git a/moto/swf/models/__init__.py b/moto/swf/models/__init__.py index 2d7a98e50..a70a5403f 100644 --- a/moto/swf/models/__init__.py +++ b/moto/swf/models/__init__.py @@ -105,8 +105,9 @@ class SWFBackend(BaseBackend): domain = Domain( name, workflow_execution_retention_period_in_days, - self.region_name, - description, + account_id=self.account_id, + region_name=self.region_name, + description=description, ) self.domains.append(domain) 
diff --git a/moto/swf/models/domain.py b/moto/swf/models/domain.py index 5230b5de7..262dda1db 100644 --- a/moto/swf/models/domain.py +++ b/moto/swf/models/domain.py @@ -1,6 +1,6 @@ from collections import defaultdict -from moto.core import get_account_id, BaseModel +from moto.core import BaseModel from ..exceptions import ( SWFUnknownResourceFault, SWFWorkflowExecutionAlreadyStartedFault, @@ -8,9 +8,10 @@ from ..exceptions import ( class Domain(BaseModel): - def __init__(self, name, retention, region_name, description=None): + def __init__(self, name, retention, account_id, region_name, description=None): self.name = name self.retention = retention + self.account_id = account_id self.region_name = region_name self.description = description self.status = "REGISTERED" @@ -31,9 +32,9 @@ class Domain(BaseModel): hsh = {"name": self.name, "status": self.status} if self.description: hsh["description"] = self.description - hsh["arn"] = "arn:aws:swf:{0}:{1}:/domain/{2}".format( - self.region_name, get_account_id(), self.name - ) + hsh[ + "arn" + ] = f"arn:aws:swf:{self.region_name}:{self.account_id}:/domain/{self.name}" return hsh def to_full_dict(self): diff --git a/moto/swf/responses.py b/moto/swf/responses.py index 9de26facd..ee546f9ec 100644 --- a/moto/swf/responses.py +++ b/moto/swf/responses.py @@ -7,9 +7,12 @@ from .models import swf_backends class SWFResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="swf") + @property def swf_backend(self): - return swf_backends[self.region] + return swf_backends[self.current_account][self.region] # SWF parameters are passed through a JSON body, so let's ease retrieval @property diff --git a/moto/textract/responses.py b/moto/textract/responses.py index 332cb24c5..b10949531 100644 --- a/moto/textract/responses.py +++ b/moto/textract/responses.py @@ -8,10 +8,13 @@ from .models import textract_backends class TextractResponse(BaseResponse): """Handler for Textract requests and responses.""" + def 
__init__(self): + super().__init__(service_name="textract") + @property def textract_backend(self): """Return backend instance specific for this region.""" - return textract_backends[self.region] + return textract_backends[self.current_account][self.region] def get_document_text_detection(self): params = json.loads(self.body) diff --git a/moto/timestreamwrite/models.py b/moto/timestreamwrite/models.py index f2fe02985..c35c454ae 100644 --- a/moto/timestreamwrite/models.py +++ b/moto/timestreamwrite/models.py @@ -1,4 +1,4 @@ -from moto.core import get_account_id, BaseBackend, BaseModel +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.utilities.tagging_service import TaggingService from .exceptions import ResourceNotFound @@ -7,6 +7,7 @@ from .exceptions import ResourceNotFound class TimestreamTable(BaseModel): def __init__( self, + account_id, region_name, table_name, db_name, @@ -22,6 +23,7 @@ class TimestreamTable(BaseModel): } self.magnetic_store_write_properties = magnetic_store_write_properties or {} self.records = [] + self.arn = f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.db_name}/table/{self.name}" def update(self, retention_properties, magnetic_store_write_properties): self.retention_properties = retention_properties @@ -31,10 +33,6 @@ class TimestreamTable(BaseModel): def write_records(self, records): self.records.extend(records) - @property - def arn(self): - return f"arn:aws:timestream:{self.region_name}:{get_account_id()}:database/{self.db_name}/table/{self.name}" - def description(self): return { "Arn": self.arn, @@ -47,12 +45,15 @@ class TimestreamTable(BaseModel): class TimestreamDatabase(BaseModel): - def __init__(self, region_name, database_name, kms_key_id): + def __init__(self, account_id, region_name, database_name, kms_key_id): + self.account_id = account_id self.region_name = region_name self.name = database_name self.kms_key_id = ( - kms_key_id - or 
f"arn:aws:kms:{region_name}:{get_account_id()}:key/default_key" + kms_key_id or f"arn:aws:kms:{region_name}:{account_id}:key/default_key" + ) + self.arn = ( + f"arn:aws:timestream:{self.region_name}:{account_id}:database/{self.name}" ) self.tables = dict() @@ -63,6 +64,7 @@ class TimestreamDatabase(BaseModel): self, table_name, retention_properties, magnetic_store_write_properties ): table = TimestreamTable( + account_id=self.account_id, region_name=self.region_name, table_name=table_name, db_name=self.name, @@ -93,10 +95,6 @@ class TimestreamDatabase(BaseModel): def list_tables(self): return self.tables.values() - @property - def arn(self): - return f"arn:aws:timestream:{self.region_name}:{get_account_id()}:database/{self.name}" - def description(self): return { "Arn": self.arn, @@ -113,7 +111,9 @@ class TimestreamWriteBackend(BaseBackend): self.tagging_service = TaggingService() def create_database(self, database_name, kms_key_id, tags): - database = TimestreamDatabase(self.region_name, database_name, kms_key_id) + database = TimestreamDatabase( + self.account_id, self.region_name, database_name, kms_key_id + ) self.databases[database_name] = database self.tagging_service.tag_resource(database.arn, tags) return database diff --git a/moto/timestreamwrite/responses.py b/moto/timestreamwrite/responses.py index 232202029..2580b3a3b 100644 --- a/moto/timestreamwrite/responses.py +++ b/moto/timestreamwrite/responses.py @@ -6,12 +6,12 @@ from .models import timestreamwrite_backends class TimestreamWriteResponse(BaseResponse): def __init__(self): - super().__init__() + super().__init__(service_name="timestream-write") @property def timestreamwrite_backend(self): """Return backend instance specific for this region.""" - return timestreamwrite_backends[self.region] + return timestreamwrite_backends[self.current_account][self.region] def create_database(self): database_name = self._get_param("DatabaseName") diff --git a/moto/transcribe/models.py b/moto/transcribe/models.py 
index b58cefdad..79fd4b1d9 100644 --- a/moto/transcribe/models.py +++ b/moto/transcribe/models.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime, timedelta -from moto.core import BaseBackend, BaseModel, get_account_id +from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.moto_api import state_manager from moto.moto_api._internal.managed_state_model import ManagedState @@ -31,6 +31,7 @@ class BaseObject(BaseModel): class FakeTranscriptionJob(BaseObject, ManagedState): def __init__( self, + account_id, region_name, transcription_job_name, language_code, @@ -56,6 +57,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): ("IN_PROGRESS", "COMPLETED"), ], ) + self._account_id = account_id self._region_name = region_name self.transcription_job_name = transcription_job_name self.language_code = language_code @@ -200,7 +202,7 @@ class FakeTranscriptionJob(BaseObject, ManagedState): else: transcript_file_uri = "https://s3.{0}.amazonaws.com/aws-transcribe-{0}-prod/{1}/{2}/{3}/asrOutput.json".format( # noqa: E501 self._region_name, - get_account_id(), + self._account_id, self.transcription_job_name, uuid.uuid4(), ) @@ -210,7 +212,13 @@ class FakeTranscriptionJob(BaseObject, ManagedState): class FakeVocabulary(BaseObject, ManagedState): def __init__( - self, region_name, vocabulary_name, language_code, phrases, vocabulary_file_uri + self, + account_id, + region_name, + vocabulary_name, + language_code, + phrases, + vocabulary_file_uri, ): # Configured ManagedState super().__init__( @@ -226,7 +234,7 @@ class FakeVocabulary(BaseObject, ManagedState): self.last_modified_time = None self.failure_reason = None self.download_uri = "https://s3.{0}.amazonaws.com/aws-transcribe-dictionary-model-{0}-prod/{1}/{2}/{3}/input.txt".format( # noqa: E501 - region_name, get_account_id(), vocabulary_name, uuid + region_name, account_id, vocabulary_name, uuid ) def response_object(self, response_type): @@ -406,9 +414,15 @@ class 
FakeMedicalTranscriptionJob(BaseObject, ManagedState): class FakeMedicalVocabulary(FakeVocabulary): def __init__( - self, region_name, vocabulary_name, language_code, vocabulary_file_uri + self, + account_id, + region_name, + vocabulary_name, + language_code, + vocabulary_file_uri, ): super().__init__( + account_id, region_name, vocabulary_name, language_code=language_code, @@ -423,7 +437,7 @@ class FakeMedicalVocabulary(FakeVocabulary): self.last_modified_time = None self.failure_reason = None self.download_uri = "https://s3.us-east-1.amazonaws.com/aws-transcribe-dictionary-model-{}-prod/{}/medical/{}/{}/input.txt".format( # noqa: E501 - region_name, get_account_id(), self.vocabulary_name, uuid.uuid4() + region_name, account_id, self.vocabulary_name, uuid.uuid4() ) @@ -477,6 +491,7 @@ class TranscribeBackend(BaseBackend): ) transcription_job_object = FakeTranscriptionJob( + account_id=self.account_id, region_name=self.region_name, transcription_job_name=name, language_code=kwargs.get("language_code"), @@ -662,6 +677,7 @@ class TranscribeBackend(BaseBackend): ) vocabulary_object = FakeVocabulary( + account_id=self.account_id, region_name=self.region_name, vocabulary_name=vocabulary_name, language_code=language_code, @@ -686,6 +702,7 @@ class TranscribeBackend(BaseBackend): ) medical_vocabulary_object = FakeMedicalVocabulary( + account_id=self.account_id, region_name=self.region_name, vocabulary_name=vocabulary_name, language_code=language_code, diff --git a/moto/transcribe/responses.py b/moto/transcribe/responses.py index af26f41a7..05e06bab7 100644 --- a/moto/transcribe/responses.py +++ b/moto/transcribe/responses.py @@ -6,9 +6,12 @@ from .models import transcribe_backends class TranscribeResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="transcribe") + @property def transcribe_backend(self): - return transcribe_backends[self.region] + return transcribe_backends[self.current_account][self.region] @property def request_params(self): diff 
--git a/moto/wafv2/models.py b/moto/wafv2/models.py index cfe891bd8..37f77325a 100644 --- a/moto/wafv2/models.py +++ b/moto/wafv2/models.py @@ -80,7 +80,11 @@ class WAFV2Backend(BaseBackend): def create_web_acl(self, name, visibility_config, default_action, scope): wacl_id = str(uuid4()) arn = make_arn_for_wacl( - name=name, region_name=self.region_name, wacl_id=wacl_id, scope=scope + name=name, + account_id=self.account_id, + region_name=self.region_name, + wacl_id=wacl_id, + scope=scope, ) if arn in self.wacls or self._is_duplicate_name(name): raise WAFV2DuplicateItemException() diff --git a/moto/wafv2/responses.py b/moto/wafv2/responses.py index 9792ff6b9..37d0af200 100644 --- a/moto/wafv2/responses.py +++ b/moto/wafv2/responses.py @@ -6,9 +6,12 @@ from .models import GLOBAL_REGION, wafv2_backends class WAFV2Response(BaseResponse): + def __init__(self): + super().__init__(service_name="wafv2") + @property def wafv2_backend(self): - return wafv2_backends[self.region] # default region is "us-east-1" + return wafv2_backends[self.current_account][self.region] @amzn_request_id def create_web_acl(self): diff --git a/moto/wafv2/utils.py b/moto/wafv2/utils.py index f5e437d2c..e0af81346 100644 --- a/moto/wafv2/utils.py +++ b/moto/wafv2/utils.py @@ -1,17 +1,14 @@ -from moto.core import get_account_id from moto.core.utils import pascal_to_camelcase, camelcase_to_underscores -def make_arn_for_wacl(name, region_name, wacl_id, scope): +def make_arn_for_wacl(name, account_id, region_name, wacl_id, scope): """https://docs.aws.amazon.com/waf/latest/developerguide/how-aws-waf-works.html - explains --scope (cloudfront vs regional)""" if scope == "REGIONAL": scope = "regional" elif scope == "CLOUDFRONT": scope = "global" - return "arn:aws:wafv2:{}:{}:{}/webacl/{}/{}".format( - region_name, get_account_id(), scope, name, wacl_id - ) + return f"arn:aws:wafv2:{region_name}:{account_id}:{scope}/webacl/{name}/{wacl_id}" def pascal_to_underscores_dict(original_dict): diff --git 
a/moto/xray/responses.py b/moto/xray/responses.py index c173c4471..4b1d1c548 100644 --- a/moto/xray/responses.py +++ b/moto/xray/responses.py @@ -10,12 +10,15 @@ from .exceptions import BadSegmentException class XRayResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="xray") + def _error(self, code, message): return json.dumps({"__type": code, "message": message}), dict(status=400) @property def xray_backend(self): - return xray_backends[self.region] + return xray_backends[self.current_account][self.region] @property def request_params(self): diff --git a/scripts/template/lib/responses.py.j2 b/scripts/template/lib/responses.py.j2 index def142151..5d3994163 100644 --- a/scripts/template/lib/responses.py.j2 +++ b/scripts/template/lib/responses.py.j2 @@ -8,10 +8,13 @@ from .models import {{ escaped_service }}_backends class {{ service_class }}Response(BaseResponse): """Handler for {{ service_class }} requests and responses.""" + def __init__(self): + super().__init__(service_name="{{ escaped_service }}") + @property def {{ escaped_service }}_backend(self): """Return backend instance specific for this region.""" - return {{ escaped_service }}_backends[self.region] + return {{ escaped_service }}_backends[self.current_account][self.region] # add methods from here diff --git a/tests/__init__.py b/tests/__init__.py index fb4a8ea14..e938a957f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,3 +13,5 @@ EXAMPLE_AMI_ID = "ami-12c6146b" EXAMPLE_AMI_ID2 = "ami-03cf127a" EXAMPLE_AMI_PARAVIRTUAL = "ami-fa7cdd89" EXAMPLE_AMI_WINDOWS = "ami-f4cf1d8d" + +DEFAULT_ACCOUNT_ID = "123456789012" diff --git a/tests/test_acm/test_acm.py b/tests/test_acm/test_acm.py index 6594787a5..cdbcea379 100644 --- a/tests/test_acm/test_acm.py +++ b/tests/test_acm/test_acm.py @@ -11,7 +11,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from freezegun import freeze_time from moto import mock_acm, 
mock_elb, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest, mock diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 647a7e1ef..507524fd1 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -6,7 +6,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_apigateway, mock_cognitoidp -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import pytest diff --git a/tests/test_apigateway/test_apigateway_deployments.py b/tests/test_apigateway/test_apigateway_deployments.py index 346f4f8cc..641e7bc25 100644 --- a/tests/test_apigateway/test_apigateway_deployments.py +++ b/tests/test_apigateway/test_apigateway_deployments.py @@ -185,7 +185,6 @@ def test_delete_deployment__requires_stage_to_be_deleted(): # Deployment still exists deployments = client.get_deployments(restApiId=api_id)["items"] - print(deployments) deployments.should.have.length_of(1) # Now delete deployment diff --git a/tests/test_applicationautoscaling/test_applicationautoscaling.py b/tests/test_applicationautoscaling/test_applicationautoscaling.py index 2807535fb..7782ae98d 100644 --- a/tests/test_applicationautoscaling/test_applicationautoscaling.py +++ b/tests/test_applicationautoscaling/test_applicationautoscaling.py @@ -2,7 +2,7 @@ import boto3 import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_applicationautoscaling, mock_ecs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID DEFAULT_REGION = "us-east-1" DEFAULT_ECS_CLUSTER = "default" diff --git a/tests/test_appsync/test_appsync.py b/tests/test_appsync/test_appsync.py index 14c275643..8f05817d1 100644 --- a/tests/test_appsync/test_appsync.py +++ b/tests/test_appsync/test_appsync.py @@ -4,7 +4,7 @@ 
import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_appsync -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_autoscaling/test_autoscaling.py b/tests/test_autoscaling/test_autoscaling.py index 9060b243e..1a9197ded 100644 --- a/tests/test_autoscaling/test_autoscaling.py +++ b/tests/test_autoscaling/test_autoscaling.py @@ -4,7 +4,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_autoscaling, mock_elb, mock_ec2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from .utils import setup_networking, setup_instance_with_networking from tests import EXAMPLE_AMI_ID diff --git a/tests/test_autoscaling/test_launch_configurations.py b/tests/test_autoscaling/test_launch_configurations.py index c31f1f7b3..39bd85e6e 100644 --- a/tests/test_autoscaling/test_launch_configurations.py +++ b/tests/test_autoscaling/test_launch_configurations.py @@ -6,7 +6,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_autoscaling, mock_ec2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID diff --git a/tests/test_awslambda/test_lambda.py b/tests/test_awslambda/test_lambda.py index e409e06d3..7ca021cb3 100644 --- a/tests/test_awslambda/test_lambda.py +++ b/tests/test_awslambda/test_lambda.py @@ -8,7 +8,7 @@ import pytest from botocore.exceptions import ClientError from freezegun import freeze_time from moto import mock_lambda, mock_s3 -from moto.core.models import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 from .utilities import ( get_role_name, diff --git 
a/tests/test_awslambda/test_lambda_alias.py b/tests/test_awslambda/test_lambda_alias.py index 3b731c64d..f324d4442 100644 --- a/tests/test_awslambda/test_lambda_alias.py +++ b/tests/test_awslambda/test_lambda_alias.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_lambda -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 from .utilities import ( get_role_name, diff --git a/tests/test_awslambda/test_lambda_layers.py b/tests/test_awslambda/test_lambda_layers.py index 0de7635ca..7404b9259 100644 --- a/tests/test_awslambda/test_lambda_layers.py +++ b/tests/test_awslambda/test_lambda_layers.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from freezegun import freeze_time from moto import mock_lambda, mock_s3 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 from .utilities import get_role_name, get_test_zip_file1 diff --git a/tests/test_awslambda/test_lambda_policy.py b/tests/test_awslambda/test_lambda_policy.py index 25087eb6c..9d6dd9d7c 100644 --- a/tests/test_awslambda/test_lambda_policy.py +++ b/tests/test_awslambda/test_lambda_policy.py @@ -5,7 +5,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_lambda, mock_s3 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 from .utilities import get_role_name, get_test_zip_file1 diff --git a/tests/test_awslambda/test_lambda_tags.py b/tests/test_awslambda/test_lambda_tags.py index a789eebf9..916f6ad7d 100644 --- a/tests/test_awslambda/test_lambda_tags.py +++ b/tests/test_awslambda/test_lambda_tags.py @@ -3,7 +3,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_lambda, mock_s3 -from moto.core.models import ACCOUNT_ID 
+from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 from .utilities import get_role_name, get_test_zip_file2 diff --git a/tests/test_budgets/test_budgets.py b/tests/test_budgets/test_budgets.py index b3de31214..90e2f8715 100644 --- a/tests/test_budgets/test_budgets.py +++ b/tests/test_budgets/test_budgets.py @@ -5,7 +5,7 @@ from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import from datetime import datetime from moto import mock_budgets -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_budgets diff --git a/tests/test_budgets/test_notifications.py b/tests/test_budgets/test_notifications.py index e836e5f48..7e1de0747 100644 --- a/tests/test_budgets/test_notifications.py +++ b/tests/test_budgets/test_notifications.py @@ -4,7 +4,7 @@ import pytest from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import from moto import mock_budgets -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_budgets diff --git a/tests/test_ce/test_ce.py b/tests/test_ce/test_ce.py index 08cb15c74..e49d0ceb0 100644 --- a/tests/test_ce/test_ce.py +++ b/tests/test_ce/test_ce.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_ce -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_cloudformation/test_cloudformation_depends_on.py b/tests/test_cloudformation/test_cloudformation_depends_on.py index 3d59b9767..3e6328d07 100644 --- a/tests/test_cloudformation/test_cloudformation_depends_on.py +++ b/tests/test_cloudformation/test_cloudformation_depends_on.py @@ -2,7 +2,7 @@ import boto3 from moto import 
mock_cloudformation, mock_ecs, mock_autoscaling, mock_s3 import json -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID depends_on_template_list = { diff --git a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py index 904023c40..657446b4c 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py +++ b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py @@ -20,7 +20,7 @@ from moto import ( settings, ) from moto.cloudformation import cloudformation_backends -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID @@ -1911,7 +1911,9 @@ def test_update_stack_when_rolled_back(): stack = cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) stack_id = stack["StackId"] - cloudformation_backends["us-east-1"].stacks[stack_id].status = "ROLLBACK_COMPLETE" + cloudformation_backends[ACCOUNT_ID]["us-east-1"].stacks[ + stack_id + ].status = "ROLLBACK_COMPLETE" with pytest.raises(ClientError) as ex: cf.update_stack(StackName="test_stack", TemplateBody=dummy_template_json) diff --git a/tests/test_cloudformation/test_cloudformation_stack_integration.py b/tests/test_cloudformation/test_cloudformation_stack_integration.py index 9ae341085..f10267641 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_integration.py +++ b/tests/test_cloudformation/test_cloudformation_stack_integration.py @@ -23,7 +23,7 @@ from moto import ( mock_sqs, mock_elbv2, ) -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID, EXAMPLE_AMI_ID2 from tests.test_cloudformation.fixtures import fn_join, single_instance_with_ebs_volume diff --git a/tests/test_cloudformation/test_stack_parsing.py b/tests/test_cloudformation/test_stack_parsing.py index 
4481e060b..87229a4be 100644 --- a/tests/test_cloudformation/test_stack_parsing.py +++ b/tests/test_cloudformation/test_stack_parsing.py @@ -15,6 +15,7 @@ from moto.cloudformation.parsing import ( Output, ) from moto import mock_cloudformation, mock_sqs, mock_ssm, settings +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.sqs.models import Queue from moto.s3.models import FakeBucket from moto.cloudformation.utils import yaml_tag_constructor @@ -183,6 +184,7 @@ def test_parse_stack_resources(): name="test_stack", template=dummy_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -209,6 +211,7 @@ def test_parse_stack_with_name_type_resource(): name="test_stack", template=name_type_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -224,6 +227,7 @@ def test_parse_stack_with_tabbed_json_template(): name="test_stack", template=name_type_template_with_tabs_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -239,6 +243,7 @@ def test_parse_stack_with_yaml_template(): name="test_stack", template=yaml.dump(name_type_template), parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -254,6 +259,7 @@ def test_parse_stack_with_outputs(): name="test_stack", template=output_type_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -270,6 +276,7 @@ def test_parse_stack_with_get_attribute_outputs(): name="test_stack", template=get_attribute_outputs_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -289,6 +296,7 @@ def test_parse_stack_with_get_attribute_kms(): name="test_stack", template=template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -304,6 +312,7 @@ def test_parse_stack_with_get_availability_zones(): name="test_stack", template=get_availability_zones_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-east-1", ) @@ -328,11 +337,17 @@ def 
test_parse_stack_with_bad_get_attribute_outputs_using_boto3(): def test_parse_stack_with_null_outputs_section(): - FakeStack.when.called_with( - "test_id", "test_stack", null_output_template_json, {}, "us-west-1" - ).should.throw( - ValidationError, "[/Outputs] 'null' values are not allowed in templates" - ) + with pytest.raises(ValidationError) as exc: + FakeStack( + "test_id", + "test_stack", + null_output_template_json, + {}, + account_id=ACCOUNT_ID, + region_name="us-west-1", + ) + err = str(exc.value) + err.should.contain("[/Outputs] 'null' values are not allowed in templates") def test_parse_stack_with_parameters(): @@ -346,6 +361,7 @@ def test_parse_stack_with_parameters(): "NumberListParam": "42,3.14159", "NoEchoParam": "hidden value", }, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -449,6 +465,7 @@ def test_parse_split_and_select(): name="test_stack", template=split_select_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -463,6 +480,7 @@ def test_sub(): name="test_stack", template=sub_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -477,6 +495,7 @@ def test_import(): name="test_stack", template=export_value_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) import_stack = FakeStack( @@ -484,6 +503,7 @@ def test_import(): name="test_stack", template=import_value_template_json, parameters={}, + account_id=ACCOUNT_ID, region_name="us-west-1", cross_stack_resources={export_stack.exports[0].value: export_stack.exports[0]}, ) @@ -552,6 +572,7 @@ def test_ssm_parameter_parsing(): "SingleParamCfn": "/path/to/single/param", "ListParamCfn": "/path/to/list/param", }, + account_id=ACCOUNT_ID, region_name="us-west-1", ) @@ -567,6 +588,7 @@ def test_ssm_parameter_parsing(): name="test_stack", template=ssm_parameter_template_json, parameters={"SingleParamCfn": "/path/to/single/param"}, + account_id=ACCOUNT_ID, region_name="us-west-1", ) diff --git 
a/tests/test_cloudfront/test_cloudfront.py b/tests/test_cloudfront/test_cloudfront.py index 612c775f7..509d4e076 100644 --- a/tests/test_cloudfront/test_cloudfront.py +++ b/tests/test_cloudfront/test_cloudfront.py @@ -4,7 +4,7 @@ import boto3 from botocore.exceptions import ClientError, ParamValidationError from moto import mock_cloudfront import sure # noqa # pylint: disable=unused-import -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from . import cloudfront_test_scaffolding as scaffold # See our Development Tips on writing tests for hints on how to write good tests: diff --git a/tests/test_cloudfront/test_cloudfront_distributions.py b/tests/test_cloudfront/test_cloudfront_distributions.py index e5b6cc587..daf71aa21 100644 --- a/tests/test_cloudfront/test_cloudfront_distributions.py +++ b/tests/test_cloudfront/test_cloudfront_distributions.py @@ -1,7 +1,7 @@ import boto3 from botocore.exceptions import ClientError from moto import mock_cloudfront -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from . 
import cloudfront_test_scaffolding as scaffold import pytest import sure # noqa # pylint: disable=unused-import diff --git a/tests/test_cloudtrail/test_cloudtrail.py b/tests/test_cloudtrail/test_cloudtrail.py index 23e1b9fe6..48ac54cae 100644 --- a/tests/test_cloudtrail/test_cloudtrail.py +++ b/tests/test_cloudtrail/test_cloudtrail.py @@ -6,7 +6,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from datetime import datetime from moto import mock_cloudtrail, mock_s3, mock_sns -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 diff --git a/tests/test_cloudtrail/test_cloudtrail_eventselectors.py b/tests/test_cloudtrail/test_cloudtrail_eventselectors.py index e7f1b3c3f..0e9cfe088 100644 --- a/tests/test_cloudtrail/test_cloudtrail_eventselectors.py +++ b/tests/test_cloudtrail/test_cloudtrail_eventselectors.py @@ -2,7 +2,7 @@ import boto3 import pytest from moto import mock_cloudtrail, mock_s3 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from .test_cloudtrail import create_trail_simple diff --git a/tests/test_cloudwatch/test_cloudwatch_alarms.py b/tests/test_cloudwatch/test_cloudwatch_alarms.py index e5337ad8e..46e454486 100644 --- a/tests/test_cloudwatch/test_cloudwatch_alarms.py +++ b/tests/test_cloudwatch/test_cloudwatch_alarms.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_cloudwatch -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_cloudwatch diff --git a/tests/test_cloudwatch/test_cloudwatch_boto3.py b/tests/test_cloudwatch/test_cloudwatch_boto3.py index a9c725de3..2e4a85ba2 100644 --- a/tests/test_cloudwatch/test_cloudwatch_boto3.py +++ b/tests/test_cloudwatch/test_cloudwatch_boto3.py @@ -12,7 +12,7 @@ from operator import itemgetter from uuid import uuid4 from moto import mock_cloudwatch, mock_s3 -from 
moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_cloudwatch diff --git a/tests/test_cloudwatch/test_cloudwatch_tags.py b/tests/test_cloudwatch/test_cloudwatch_tags.py index 8b467fa80..dddab423c 100644 --- a/tests/test_cloudwatch/test_cloudwatch_tags.py +++ b/tests/test_cloudwatch/test_cloudwatch_tags.py @@ -7,7 +7,7 @@ import sure # noqa # pylint: disable=unused-import from moto import mock_cloudwatch from moto.cloudwatch.utils import make_arn_for_alarm -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_cloudwatch diff --git a/tests/test_codebuild/test_codebuild.py b/tests/test_codebuild/test_codebuild.py index c6ec1320f..60905952c 100644 --- a/tests/test_codebuild/test_codebuild.py +++ b/tests/test_codebuild/test_codebuild.py @@ -1,7 +1,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_codebuild -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from botocore.exceptions import ClientError, ParamValidationError from uuid import uuid1 import pytest @@ -40,10 +40,9 @@ def test_codebuild_create_project_s3_artifacts(): serviceRole=service_role, ) - response.should_not.be.none - response["project"].should_not.be.none - response["project"]["serviceRole"].should_not.be.none - response["project"]["name"].should_not.be.none + response.should.have.key("project") + response["project"].should.have.key("serviceRole") + response["project"].should.have.key("name").equals(name) response["project"]["environment"].should.equal( { @@ -92,10 +91,9 @@ def test_codebuild_create_project_no_artifacts(): serviceRole=service_role, ) - response.should_not.be.none - response["project"].should_not.be.none - response["project"]["serviceRole"].should_not.be.none - response["project"]["name"].should_not.be.none + response.should.have.key("project") + response["project"].should.have.key("serviceRole") + 
response["project"].should.have.key("name").equals(name) response["project"]["environment"].should.equal( { @@ -258,7 +256,6 @@ def test_codebuild_list_projects(): projects = client.list_projects() - projects["projects"].should_not.be.none projects["projects"].should.equal(["project1", "project2"]) @@ -293,7 +290,7 @@ def test_codebuild_list_builds_for_project_no_history(): history = client.list_builds_for_project(projectName=name) # no build history if it's never started - history["ids"].should.be.empty + history["ids"].should.equal([]) @mock_codebuild @@ -327,7 +324,7 @@ def test_codebuild_list_builds_for_project_with_history(): client.start_build(projectName=name) response = client.list_builds_for_project(projectName=name) - response["ids"].should_not.be.empty + response["ids"].should.have.length_of(1) # project never started @@ -347,9 +344,7 @@ def test_codebuild_get_batch_builds_for_project_no_history(): environment["image"] = "contents_not_validated" environment["computeType"] = "BUILD_GENERAL1_SMALL" service_role = ( - "arn:aws:iam::{0}:role/service-role/my-codebuild-service-role".format( - ACCOUNT_ID - ) + f"arn:aws:iam::{ACCOUNT_ID}:role/service-role/my-codebuild-service-role" ) client.create_project( @@ -361,8 +356,7 @@ def test_codebuild_get_batch_builds_for_project_no_history(): ) response = client.list_builds_for_project(projectName=name) - response.should_not.be.none - response["ids"].should.be.empty + response["ids"].should.equal([]) with pytest.raises(ParamValidationError) as err: client.batch_get_builds(ids=response["ids"]) @@ -412,8 +406,7 @@ def test_codebuild_start_build_no_overrides(): ) response = client.start_build(projectName=name) - response.should_not.be.none - response["build"].should_not.be.none + response.should.have.key("build") response["build"]["sourceVersion"].should.equal("refs/heads/main") @@ -491,8 +484,7 @@ def test_codebuild_start_build_with_overrides(): artifactsOverride=artifacts_override, ) - response.should_not.be.none - 
response["build"].should_not.be.none + response.should.have.key("build") response["build"]["sourceVersion"].should.equal("fix/testing") @@ -529,11 +521,10 @@ def test_codebuild_batch_get_builds_1_project(): history = client.list_builds_for_project(projectName=name) response = client.batch_get_builds(ids=history["ids"]) - response.should_not.be.none - response["builds"].should_not.be.none + response.should.have.key("builds").length_of(1) response["builds"][0]["currentPhase"].should.equal("COMPLETED") response["builds"][0]["buildNumber"].should.be.a(int) - response["builds"][0]["phases"].should_not.be.none + response["builds"][0].should.have.key("phases") len(response["builds"][0]["phases"]).should.equal(11) @@ -576,13 +567,14 @@ def test_codebuild_batch_get_builds_2_projects(): client.start_build(projectName="project-2") response = client.list_builds() - response["ids"].should_not.be.empty + response["ids"].should.have.length_of(2) "project-1".should.be.within(response["ids"][0]) "project-2".should.be.within(response["ids"][1]) metadata = client.batch_get_builds(ids=response["ids"])["builds"] - metadata.should_not.be.none + metadata.should.have.length_of(2) + "project-1".should.be.within(metadata[0]["id"]) "project-2".should.be.within(metadata[1]["id"]) @@ -636,7 +628,7 @@ def test_codebuild_delete_project(): client.start_build(projectName=name) response = client.list_builds_for_project(projectName=name) - response["ids"].should_not.be.empty + response["ids"].should.have.length_of(1) client.delete_project(name=name) diff --git a/tests/test_codecommit/test_codecommit.py b/tests/test_codecommit/test_codecommit.py index 750a6239c..cf918b34c 100644 --- a/tests/test_codecommit/test_codecommit.py +++ b/tests/test_codecommit/test_codecommit.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_codecommit -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from botocore.exceptions 
import ClientError import pytest diff --git a/tests/test_cognitoidentity/test_cognitoidentity.py b/tests/test_cognitoidentity/test_cognitoidentity.py index b30cad7c2..c61e4e01e 100644 --- a/tests/test_cognitoidentity/test_cognitoidentity.py +++ b/tests/test_cognitoidentity/test_cognitoidentity.py @@ -5,7 +5,7 @@ import pytest from moto import mock_cognitoidentity from moto.cognitoidentity.utils import get_random_identity_id -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import UUID diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index ef3677995..abb54e997 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -22,7 +22,7 @@ import pytest from moto import mock_cognitoidp, settings from moto.cognitoidp.utils import create_id -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_cognitoidp @@ -2408,7 +2408,7 @@ def test_get_user_unconfirmed(): conn = boto3.client("cognito-idp", "us-west-2") outputs = authentication_flow(conn, "ADMIN_NO_SRP_AUTH") - backend = moto.cognitoidp.models.cognitoidp_backends["us-west-2"] + backend = moto.cognitoidp.models.cognitoidp_backends[ACCOUNT_ID]["us-west-2"] user_pool = backend.user_pools[outputs["user_pool_id"]] user_pool.users[outputs["username"]].status = "UNCONFIRMED" diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index a7196a1da..86a576815 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -10,7 +10,7 @@ import pytest from moto import mock_s3 from moto.config import mock_config -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import sure # noqa # pylint: disable=unused-import diff --git a/tests/test_config/test_config_rules_integration.py b/tests/test_config/test_config_rules_integration.py index 20931cef2..bf4ece64d 100644 --- 
a/tests/test_config/test_config_rules_integration.py +++ b/tests/test_config/test_config_rules_integration.py @@ -1,7 +1,7 @@ from .test_config_rules import managed_config_rule, TEST_REGION from botocore.exceptions import ClientError from moto import mock_config, mock_iam, mock_lambda -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from io import BytesIO from uuid import uuid4 from zipfile import ZipFile, ZIP_DEFLATED diff --git a/tests/test_config/test_config_tags.py b/tests/test_config/test_config_tags.py index a8166d5d0..b7257ec0e 100644 --- a/tests/test_config/test_config_tags.py +++ b/tests/test_config/test_config_tags.py @@ -14,7 +14,7 @@ import pytest from moto.config import mock_config from moto.config.models import MAX_TAGS_IN_ARG from moto.config.models import random_string -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID TEST_REGION = "us-east-1" diff --git a/tests/test_core/test_account_id_resolution.py b/tests/test_core/test_account_id_resolution.py new file mode 100644 index 000000000..b181f1401 --- /dev/null +++ b/tests/test_core/test_account_id_resolution.py @@ -0,0 +1,67 @@ +import os + +import requests +import xmltodict + +from moto import settings +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from moto.server import ThreadedMotoServer +from unittest import SkipTest + + +SERVER_PORT = 5001 +BASE_URL = f"http://localhost:{SERVER_PORT}/" + + +class TestAccountIdResolution: + def setup(self): + if settings.TEST_SERVER_MODE: + raise SkipTest( + "No point in testing this in ServerMode, as we already start our own server" + ) + self.server = ThreadedMotoServer(port=SERVER_PORT, verbose=False) + self.server.start() + + def teardown(self): + self.server.stop() + + def test_environment_variable_takes_precedence(self): + # Verify ACCOUNT ID is standard + resp = self._get_caller_identity() + self._get_account_id(resp).should.equal(ACCOUNT_ID) + + # Specify 
environment variable, and verify this becomes the new ACCOUNT ID + os.environ["MOTO_ACCOUNT_ID"] = "111122223333" + resp = self._get_caller_identity() + self._get_account_id(resp).should.equal("111122223333") + + # Specify special request header - the environment variable should still take precedence + resp = self._get_caller_identity( + extra_headers={"x-moto-account-id": "333344445555"} + ) + self._get_account_id(resp).should.equal("111122223333") + + # Remove the environment variable - the Request Header should now take precedence + del os.environ["MOTO_ACCOUNT_ID"] + resp = self._get_caller_identity( + extra_headers={"x-moto-account-id": "333344445555"} + ) + self._get_account_id(resp).should.equal("333344445555") + + # Without Header, we're back to the regular account ID + resp = self._get_caller_identity() + self._get_account_id(resp).should.equal(ACCOUNT_ID) + + def _get_caller_identity(self, extra_headers=None): + data = "Action=GetCallerIdentity&Version=2011-06-15" + headers = { + "Authorization": "AWS4-HMAC-SHA256 Credential=abcd/20010101/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=...", + "Content-Length": f"{len(data)}", + "Content-Type": "application/x-www-form-urlencoded", + } + headers.update(extra_headers or {}) + return requests.post(f"{BASE_URL}", headers=headers, data=data) + + def _get_account_id(self, resp): + data = xmltodict.parse(resp.content) + return data["GetCallerIdentityResponse"]["GetCallerIdentityResult"]["Account"] diff --git a/tests/test_core/test_auth.py b/tests/test_core/test_auth.py index 00460c7dd..6966a8f6d 100644 --- a/tests/test_core/test_auth.py +++ b/tests/test_core/test_auth.py @@ -8,7 +8,7 @@ import pytest from moto import mock_iam, mock_ec2, mock_s3, mock_sts, mock_elbv2, mock_rds from moto.core import set_initial_no_auth_action_count -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 diff --git 
a/tests/test_core/test_backenddict.py b/tests/test_core/test_backenddict.py index e468cf367..1d7f557fb 100644 --- a/tests/test_core/test_backenddict.py +++ b/tests/test_core/test_backenddict.py @@ -2,8 +2,13 @@ import random import time import pytest -from moto.core import BaseBackend +from moto.core import BaseBackend, DEFAULT_ACCOUNT_ID from moto.core.utils import AccountSpecificBackend, BackendDict + +from moto.autoscaling.models import AutoScalingBackend +from moto.ec2.models import EC2Backend +from moto.elbv2.models import ELBv2Backend + from threading import Thread @@ -14,43 +19,41 @@ class ExampleBackend(BaseBackend): def test_backend_dict_returns_nothing_by_default(): backend_dict = BackendDict(ExampleBackend, "ebs") - backend_dict.should.equal({}) + list(backend_dict.items()).should.equal([]) -def test_backend_dict_contains_known_regions(): +def test_account_specific_dict_contains_known_regions(): backend_dict = BackendDict(ExampleBackend, "ec2") - backend_dict.should.have.key("eu-north-1") - backend_dict["eu-north-1"].should.be.a(ExampleBackend) + backend_dict["account"].should.have.key("eu-north-1") + backend_dict["account"]["eu-north-1"].should.be.a(ExampleBackend) def test_backend_dict_known_regions_can_be_retrieved_directly(): backend_dict = BackendDict(ExampleBackend, "ec2") - backend_dict["eu-west-1"].should.be.a(ExampleBackend) + backend_dict["account"]["eu-west-1"].should.be.a(ExampleBackend) def test_backend_dict_can_get_known_regions(): - backend_dict = BackendDict(ExampleBackend, "ec2") - backend_dict.get("us-east-1").should.be.a(ExampleBackend) + backend_dict = BackendDict(ExampleBackend, "ec2")["12345"] + backend_dict["us-east-1"].should.be.a(ExampleBackend) def test_backend_dict_does_not_contain_unknown_regions(): backend_dict = BackendDict(ExampleBackend, "ec2") - backend_dict.shouldnt.have.key("mars-south-1") + backend_dict["account"].shouldnt.have.key("mars-south-1") def test_backend_dict_fails_when_retrieving_unknown_regions(): 
backend_dict = BackendDict(ExampleBackend, "ec2") with pytest.raises(KeyError): - backend_dict["mars-south-1"] # pylint: disable=pointless-statement + backend_dict["account"]["mars-south-1"] # pylint: disable=pointless-statement def test_backend_dict_can_retrieve_for_specific_account(): backend_dict = BackendDict(ExampleBackend, "ec2") - # Retrieve AccountSpecificBackend after checking it exists - backend_dict.should.have.key("000000") - backend = backend_dict.get("000000") - backend.should.be.a(AccountSpecificBackend) + # Random account does not exist + backend_dict.shouldnt.have.key("000000") # Retrieve AccountSpecificBackend by assuming it exists backend = backend_dict["012345"] @@ -61,21 +64,21 @@ def test_backend_dict_can_retrieve_for_specific_account(): regional_backend.should.be.a(ExampleBackend) regional_backend.region_name.should.equal("eu-north-1") # We always return a fixed account_id for now, until we have proper multi-account support - regional_backend.account_id.should.equal("123456789012") + regional_backend.account_id.should.equal("012345") def test_backend_dict_can_ignore_boto3_regions(): backend_dict = BackendDict(ExampleBackend, "ec2", use_boto3_regions=False) - backend_dict.get("us-east-1").should.equal(None) + backend_dict["account"].get("us-east-1").should.equal(None) def test_backend_dict_can_specify_additional_regions(): backend_dict = BackendDict( ExampleBackend, "ec2", additional_regions=["region1", "global"] - ) - backend_dict.get("us-east-1").should.be.a(ExampleBackend) - backend_dict.get("region1").should.be.a(ExampleBackend) - backend_dict.get("global").should.be.a(ExampleBackend) + )["123456"] + backend_dict["us-east-1"].should.be.a(ExampleBackend) + backend_dict["region1"].should.be.a(ExampleBackend) + backend_dict["global"].should.be.a(ExampleBackend) # Unknown regions still do not exist backend_dict.get("us-east-3").should.equal(None) @@ -92,7 +95,7 @@ class TestMultiThreadedAccess: self.backend = 
BackendDict(TestMultiThreadedAccess.SlowExampleBackend, "ec2") def test_access_a_slow_backend_concurrently(self): - """ + """ Usecase that we want to avoid: Thread 1 comes in, and sees the backend does not exist for this region @@ -126,3 +129,140 @@ class TestMultiThreadedAccess: x.join() self.backend["123456789012"]["us-east-1"].data.should.have.length_of(15) + + +def test_backend_dict_can_be_hashed(): + hashes = [] + for backend in [ExampleBackend, set, list, BaseBackend]: + hashes.append(BackendDict(backend, "n/a").__hash__()) + # Hash is different for different backends + set(hashes).should.have.length_of(4) + + +def test_account_specific_dict_can_be_hashed(): + hashes = [] + ids = ["01234567912", "01234567911", "01234567913", "000000000000", "0"] + for accnt_id in ids: + asb = _create_asb(accnt_id) + hashes.append(asb.__hash__()) + # Hash is different for different accounts + set(hashes).should.have.length_of(5) + + +def _create_asb(account_id, backend=None, use_boto3_regions=False, regions=None): + return AccountSpecificBackend( + service_name="ec2", + account_id=account_id, + backend=backend or ExampleBackend, + use_boto3_regions=use_boto3_regions, + additional_regions=regions, + ) + + +def test_multiple_backends_cache_behaviour(): + + ec2 = BackendDict(EC2Backend, "ec2") + ec2_useast1 = ec2[DEFAULT_ACCOUNT_ID]["us-east-1"] + assert type(ec2_useast1) == EC2Backend + + autoscaling = BackendDict(AutoScalingBackend, "autoscaling") + as_1 = autoscaling[DEFAULT_ACCOUNT_ID]["us-east-1"] + assert type(as_1) == AutoScalingBackend + + from moto.elbv2 import elbv2_backends + + elbv2_useast = elbv2_backends["00000000"]["us-east-1"] + assert type(elbv2_useast) == ELBv2Backend + elbv2_useast2 = elbv2_backends[DEFAULT_ACCOUNT_ID]["us-east-2"] + assert type(elbv2_useast2) == ELBv2Backend + + ec2_useast1 = ec2[DEFAULT_ACCOUNT_ID]["us-east-1"] + assert type(ec2_useast1) == EC2Backend + ec2_useast2 = ec2[DEFAULT_ACCOUNT_ID]["us-east-2"] + assert type(ec2_useast2) == 
EC2Backend + + as_1 = autoscaling[DEFAULT_ACCOUNT_ID]["us-east-1"] + assert type(as_1) == AutoScalingBackend + + +def test_backenddict_cache_hits_and_misses(): + backend = BackendDict(ExampleBackend, "ebs") + backend.__getitem__.cache_clear() + + assert backend.__getitem__.cache_info().hits == 0 + assert backend.__getitem__.cache_info().misses == 0 + assert backend.__getitem__.cache_info().currsize == 0 + + # Create + Retrieve an account - verify it is stored in cache + accnt_1 = backend["accnt1"] + assert accnt_1.account_id == "accnt1" + + assert backend.__getitem__.cache_info().hits == 0 + assert backend.__getitem__.cache_info().misses == 1 + assert backend.__getitem__.cache_info().currsize == 1 + + # Creating + Retrieving a second account + accnt_2 = backend["accnt2"] + assert accnt_2.account_id == "accnt2" + + assert backend.__getitem__.cache_info().hits == 0 + assert backend.__getitem__.cache_info().misses == 2 + assert backend.__getitem__.cache_info().currsize == 2 + + # Retrieving the first account from cache + accnt_1_again = backend["accnt1"] + assert accnt_1_again.account_id == "accnt1" + + assert backend.__getitem__.cache_info().hits == 1 + assert backend.__getitem__.cache_info().misses == 2 + assert backend.__getitem__.cache_info().currsize == 2 + + # Retrieving the second account from cache + accnt_2_again = backend["accnt2"] + assert accnt_2_again.account_id == "accnt2" + + assert backend.__getitem__.cache_info().hits == 2 + assert backend.__getitem__.cache_info().misses == 2 + assert backend.__getitem__.cache_info().currsize == 2 + + +def test_asb_cache_hits_and_misses(): + backend = BackendDict(ExampleBackend, "ebs") + acb = backend["accnt_id"] + acb.__getitem__.cache_clear() + + assert acb.__getitem__.cache_info().hits == 0 + assert acb.__getitem__.cache_info().misses == 0 + assert acb.__getitem__.cache_info().currsize == 0 + + # Create + Retrieve a region - verify it is stored in cache + region_1 = acb["us-east-1"] + assert region_1.region_name 
== "us-east-1" + + assert acb.__getitem__.cache_info().hits == 0 + assert acb.__getitem__.cache_info().misses == 1 + assert acb.__getitem__.cache_info().currsize == 1 + + # Creating + Retrieving a second account + region_2 = acb["us-east-2"] + assert region_2.region_name == "us-east-2" + + assert acb.__getitem__.cache_info().hits == 0 + assert acb.__getitem__.cache_info().misses == 2 + assert acb.__getitem__.cache_info().currsize == 2 + + # Retrieving the first account from cache + region_1_again = acb["us-east-1"] + assert region_1_again.region_name == "us-east-1" + + assert acb.__getitem__.cache_info().hits == 1 + assert acb.__getitem__.cache_info().misses == 2 + assert acb.__getitem__.cache_info().currsize == 2 + + # Retrieving the second account from cache + region_2_again = acb["us-east-2"] + assert region_2_again.region_name == "us-east-2" + + assert acb.__getitem__.cache_info().hits == 2 + assert acb.__getitem__.cache_info().misses == 2 + assert acb.__getitem__.cache_info().currsize == 2 diff --git a/tests/test_core/test_context_manager.py b/tests/test_core/test_context_manager.py index 077e6c44f..d7d2d6d1c 100644 --- a/tests/test_core/test_context_manager.py +++ b/tests/test_core/test_context_manager.py @@ -1,6 +1,7 @@ import sure # noqa # pylint: disable=unused-import import boto3 from moto import mock_sqs, settings +from tests import DEFAULT_ACCOUNT_ID def test_context_manager_returns_mock(): @@ -9,4 +10,5 @@ def test_context_manager_returns_mock(): conn.create_queue(QueueName="queue1") if not settings.TEST_SERVER_MODE: - list(sqs_mock.backends["us-west-1"].queues.keys()).should.equal(["queue1"]) + backend = sqs_mock.backends[DEFAULT_ACCOUNT_ID]["us-west-1"] + list(backend.queues.keys()).should.equal(["queue1"]) diff --git a/tests/test_databrew/test_databrew_datasets.py b/tests/test_databrew/test_databrew_datasets.py index 81722324f..27f16c154 100644 --- a/tests/test_databrew/test_databrew_datasets.py +++ b/tests/test_databrew/test_databrew_datasets.py @@ 
-5,7 +5,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_databrew -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def _create_databrew_client(): diff --git a/tests/test_databrew/test_databrew_jobs.py b/tests/test_databrew/test_databrew_jobs.py index 06a5acfed..d559fe23c 100644 --- a/tests/test_databrew/test_databrew_jobs.py +++ b/tests/test_databrew/test_databrew_jobs.py @@ -5,7 +5,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_databrew -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def _create_databrew_client(): diff --git a/tests/test_dax/test_dax.py b/tests/test_dax/test_dax.py index 4091d2f0b..4722f8ae3 100644 --- a/tests/test_dax/test_dax.py +++ b/tests/test_dax/test_dax.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_dax -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_dynamodb/conftest.py b/tests/test_dynamodb/conftest.py index 5918f9891..63b1c950e 100644 --- a/tests/test_dynamodb/conftest.py +++ b/tests/test_dynamodb/conftest.py @@ -1,4 +1,5 @@ import pytest +from moto.core import DEFAULT_ACCOUNT_ID from moto.dynamodb.models import Table @@ -6,6 +7,7 @@ from moto.dynamodb.models import Table def table(): return Table( "Forums", + account_id=DEFAULT_ACCOUNT_ID, region="us-east-1", schema=[ {"KeyType": "HASH", "AttributeName": "forum_name"}, diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index d359e49f6..97213af75 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -7,6 +7,7 @@ from boto3.dynamodb.conditions 
import Attr, Key import re import sure # noqa # pylint: disable=unused-import from moto import mock_dynamodb, settings +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.dynamodb import dynamodb_backends from botocore.exceptions import ClientError @@ -401,7 +402,7 @@ def test_put_item_with_streams(): ) if not settings.TEST_SERVER_MODE: - table = dynamodb_backends["us-west-2"].get_table(name) + table = dynamodb_backends[ACCOUNT_ID]["us-west-2"].get_table(name) len(table.stream_shard.items).should.be.equal(1) stream_record = table.stream_shard.items[0].record stream_record["eventName"].should.be.equal("INSERT") diff --git a/tests/test_dynamodb/test_dynamodb_create_table.py b/tests/test_dynamodb/test_dynamodb_create_table.py index b7985ae7b..e640970a6 100644 --- a/tests/test_dynamodb/test_dynamodb_create_table.py +++ b/tests/test_dynamodb/test_dynamodb_create_table.py @@ -5,7 +5,7 @@ from datetime import datetime import pytest from moto import mock_dynamodb -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_dynamodb diff --git a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py index 5b7372fe8..caf5bdda2 100644 --- a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py @@ -5,7 +5,7 @@ import pytest from datetime import datetime from botocore.exceptions import ClientError from moto import mock_dynamodb -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import botocore diff --git a/tests/test_dynamodb_v20111205/test_server.py b/tests/test_dynamodb_v20111205/test_server.py index 0d88079f6..5b7fec296 100644 --- a/tests/test_dynamodb_v20111205/test_server.py +++ b/tests/test_dynamodb_v20111205/test_server.py @@ -315,7 +315,6 @@ def test_put_return_none_without_range_key(test_client): } res = test_client.post("/", headers=headers, 
json=request_body) res = json.loads(res.data) - print(res) # This seems wrong - it should return nothing, considering return_values is set to none res["Attributes"].should.equal({"hkey": "customer", "name": "myname"}) diff --git a/tests/test_ebs/test_ebs.py b/tests/test_ebs/test_ebs.py index 8e288834e..f4a19f291 100644 --- a/tests/test_ebs/test_ebs.py +++ b/tests/test_ebs/test_ebs.py @@ -3,7 +3,7 @@ import boto3 import hashlib import sure # noqa # pylint: disable=unused-import from moto import mock_ebs, mock_ec2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_ec2/test_amis.py b/tests/test_ec2/test_amis.py index d93c8ce32..40943a42b 100644 --- a/tests/test_ec2/test_amis.py +++ b/tests/test_ec2/test_amis.py @@ -6,9 +6,8 @@ import sure # noqa # pylint: disable=unused-import import random from moto import mock_ec2 -from moto.ec2.models import OWNER_ID from moto.ec2.models.amis import AMIS -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID, EXAMPLE_AMI_PARAVIRTUAL from uuid import uuid4 @@ -281,7 +280,7 @@ def test_copy_image_changes_owner_id(): # confirm the source ami owner id is different from the default owner id. # if they're ever the same it means this test is invalid. 
check_resp = conn.describe_images(ImageIds=[source_ami_id]) - check_resp["Images"][0]["OwnerId"].should_not.equal(OWNER_ID) + check_resp["Images"][0]["OwnerId"].should_not.equal(ACCOUNT_ID) new_image_name = str(uuid4())[0:6] @@ -296,7 +295,7 @@ def test_copy_image_changes_owner_id(): Owners=["self"], Filters=[{"Name": "name", "Values": [new_image_name]}] )["Images"] describe_resp.should.have.length_of(1) - describe_resp[0]["OwnerId"].should.equal(OWNER_ID) + describe_resp[0]["OwnerId"].should.equal(ACCOUNT_ID) describe_resp[0]["ImageId"].should.equal(copy_resp["ImageId"]) diff --git a/tests/test_ec2/test_carrier_gateways.py b/tests/test_ec2/test_carrier_gateways.py index aadda048f..fc09125e4 100644 --- a/tests/test_ec2/test_carrier_gateways.py +++ b/tests/test_ec2/test_carrier_gateways.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from botocore.exceptions import ClientError from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest diff --git a/tests/test_ec2/test_elastic_block_store.py b/tests/test_ec2/test_elastic_block_store.py index 071a61de7..9ce9d226b 100644 --- a/tests/test_ec2/test_elastic_block_store.py +++ b/tests/test_ec2/test_elastic_block_store.py @@ -4,7 +4,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_ec2 -from moto.ec2.models import OWNER_ID +from moto.core import DEFAULT_ACCOUNT_ID as OWNER_ID from moto.kms import mock_kms from tests import EXAMPLE_AMI_ID from uuid import uuid4 diff --git a/tests/test_ec2/test_elastic_network_interfaces.py b/tests/test_ec2/test_elastic_network_interfaces.py index c64872e41..c9282b053 100644 --- a/tests/test_ec2/test_elastic_network_interfaces.py +++ b/tests/test_ec2/test_elastic_network_interfaces.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError import sure # noqa # pylint: 
disable=unused-import from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.ec2.utils import random_private_ip from tests import EXAMPLE_AMI_ID from uuid import uuid4 diff --git a/tests/test_ec2/test_flow_logs.py b/tests/test_ec2/test_flow_logs.py index 9bd3eb562..fdba998ef 100644 --- a/tests/test_ec2/test_flow_logs.py +++ b/tests/test_ec2/test_flow_logs.py @@ -7,7 +7,7 @@ from botocore.parsers import ResponseParserError import sure # noqa # pylint: disable=unused-import from moto import settings, mock_ec2, mock_s3, mock_logs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.ec2.exceptions import FilterNotImplementedError from uuid import uuid4 diff --git a/tests/test_ec2/test_flow_logs_cloudformation.py b/tests/test_ec2/test_flow_logs_cloudformation.py index ebe8a095c..91db6c446 100644 --- a/tests/test_ec2/test_flow_logs_cloudformation.py +++ b/tests/test_ec2/test_flow_logs_cloudformation.py @@ -4,7 +4,7 @@ import json import sure # noqa # pylint: disable=unused-import from moto import mock_cloudformation, mock_ec2, mock_s3 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID from uuid import uuid4 diff --git a/tests/test_ec2/test_instances.py b/tests/test_ec2/test_instances.py index eee9283ea..8b4f3c6a4 100644 --- a/tests/test_ec2/test_instances.py +++ b/tests/test_ec2/test_instances.py @@ -9,7 +9,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError, ParamValidationError from freezegun import freeze_time from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID decode_method = base64.decodebytes diff --git a/tests/test_ec2/test_network_acls.py b/tests/test_ec2/test_network_acls.py index 6f777bef3..9904f32e2 100644 --- 
a/tests/test_ec2/test_network_acls.py +++ b/tests/test_ec2/test_network_acls.py @@ -4,7 +4,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_ec2, settings -from moto.ec2.models import OWNER_ID +from moto.core import DEFAULT_ACCOUNT_ID as OWNER_ID from random import randint from unittest import SkipTest diff --git a/tests/test_ec2/test_prefix_lists.py b/tests/test_ec2/test_prefix_lists.py index 8f08602e9..a31a5106a 100644 --- a/tests/test_ec2/test_prefix_lists.py +++ b/tests/test_ec2/test_prefix_lists.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_ec2 diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 331139d5a..3f9969545 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -9,7 +9,8 @@ from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import from moto import mock_ec2, settings -from moto.ec2 import ec2_backend +from moto.core import DEFAULT_ACCOUNT_ID +from moto.ec2 import ec2_backends from random import randint from uuid import uuid4 from unittest import SkipTest @@ -1202,6 +1203,7 @@ def test_non_existent_security_group_raises_error_on_authorize(): def test_security_group_rules_added_via_the_backend_can_be_revoked_via_the_api(): if settings.TEST_SERVER_MODE: raise unittest.SkipTest("Can't test backend directly in server mode.") + ec2_backend = ec2_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] ec2_resource = boto3.resource("ec2", region_name="us-east-1") ec2_client = boto3.client("ec2", region_name="us-east-1") vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") diff --git a/tests/test_ec2/test_spot_fleet.py b/tests/test_ec2/test_spot_fleet.py index 689b56cf9..b6483972f 100644 --- a/tests/test_ec2/test_spot_fleet.py +++ 
b/tests/test_ec2/test_spot_fleet.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_ec2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID from uuid import uuid4 diff --git a/tests/test_ec2/test_transit_gateway.py b/tests/test_ec2/test_transit_gateway.py index a2d638bef..dcfc39d36 100644 --- a/tests/test_ec2/test_transit_gateway.py +++ b/tests/test_ec2/test_transit_gateway.py @@ -1,7 +1,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest diff --git a/tests/test_ec2/test_transit_gateway_peering_attachments.py b/tests/test_ec2/test_transit_gateway_peering_attachments.py index 8fe86e397..a61cd5f91 100644 --- a/tests/test_ec2/test_transit_gateway_peering_attachments.py +++ b/tests/test_ec2/test_transit_gateway_peering_attachments.py @@ -1,7 +1,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_ec2, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest diff --git a/tests/test_ec2/test_vpc_endpoint_services_integration.py b/tests/test_ec2/test_vpc_endpoint_services_integration.py index 83f56f3ca..3eb3b016c 100644 --- a/tests/test_ec2/test_vpc_endpoint_services_integration.py +++ b/tests/test_ec2/test_vpc_endpoint_services_integration.py @@ -5,6 +5,7 @@ import boto3 from botocore.exceptions import ClientError from moto import mock_ec2, settings +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest @@ -147,7 +148,7 @@ def test_describe_vpc_endpoint_services_filters(): """Verify that different type of filters return the expected results.""" from moto.ec2.models import ec2_backends # pylint: disable=import-outside-toplevel - ec2_backend 
= ec2_backends["us-west-1"] + ec2_backend = ec2_backends[ACCOUNT_ID]["us-west-1"] test_data = fake_endpoint_services() # Allow access to _filter_endpoint_services as it provides the best diff --git a/tests/test_ecr/test_ecr_boto3.py b/tests/test_ecr/test_ecr_boto3.py index f498b79a4..e7c69b8de 100644 --- a/tests/test_ecr/test_ecr_boto3.py +++ b/tests/test_ecr/test_ecr_boto3.py @@ -17,7 +17,7 @@ from dateutil.tz import tzlocal from moto import mock_ecr from unittest import SkipTest -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def _create_image_digest(contents=None): diff --git a/tests/test_ecr/test_ecr_cloudformation.py b/tests/test_ecr/test_ecr_cloudformation.py index 232648830..43f19fc99 100644 --- a/tests/test_ecr/test_ecr_cloudformation.py +++ b/tests/test_ecr/test_ecr_cloudformation.py @@ -6,7 +6,7 @@ import json from moto import mock_cloudformation, mock_ecr import sure # noqa # pylint: disable=unused-import -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID repo_template = Template( json.dumps( diff --git a/tests/test_ecs/test_ecs_account_settings.py b/tests/test_ecs/test_ecs_account_settings.py index 537a200e7..0f2dbe406 100644 --- a/tests/test_ecs/test_ecs_account_settings.py +++ b/tests/test_ecs/test_ecs_account_settings.py @@ -3,7 +3,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import import json -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.ec2 import utils as ec2_utils from moto import mock_ecs, mock_ec2 diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index 0f459e2ff..2cd30dbde 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -7,7 +7,7 @@ import sure # noqa # pylint: disable=unused-import import json import os -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.ec2 import utils as 
ec2_utils from uuid import UUID diff --git a/tests/test_ecs/test_ecs_capacity_provider.py b/tests/test_ecs/test_ecs_capacity_provider.py index 2311594a6..0cfa98a74 100644 --- a/tests/test_ecs/test_ecs_capacity_provider.py +++ b/tests/test_ecs/test_ecs_capacity_provider.py @@ -1,7 +1,7 @@ import boto3 from moto import mock_ecs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_ecs diff --git a/tests/test_efs/test_access_points.py b/tests/test_efs/test_access_points.py index b3e45fc33..2526a58f1 100644 --- a/tests/test_efs/test_access_points.py +++ b/tests/test_efs/test_access_points.py @@ -3,7 +3,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_efs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @pytest.fixture(scope="function") diff --git a/tests/test_efs/test_mount_target.py b/tests/test_efs/test_mount_target.py index 7b4a1d2e5..4503c1d2b 100644 --- a/tests/test_efs/test_mount_target.py +++ b/tests/test_efs/test_mount_target.py @@ -7,7 +7,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_ec2, mock_efs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests.test_efs.junk_drawer import has_status_code diff --git a/tests/test_efs/test_mount_target_security_groups.py b/tests/test_efs/test_mount_target_security_groups.py index 05fc2ee2f..bbfa7fea6 100644 --- a/tests/test_efs/test_mount_target_security_groups.py +++ b/tests/test_efs/test_mount_target_security_groups.py @@ -84,7 +84,6 @@ def test_modify_mount_target_security_groups(efs, ec2, file_system, subnet): file_system_id = file_system["FileSystemId"] desc_sg_resp = ec2.describe_security_groups()["SecurityGroups"] - print(desc_sg_resp) security_group_id = desc_sg_resp[0]["GroupId"] # Create Mount Target diff --git a/tests/test_eks/test_eks.py b/tests/test_eks/test_eks.py index 21349e50e..1cedb9e3a 100644 --- 
a/tests/test_eks/test_eks.py +++ b/tests/test_eks/test_eks.py @@ -9,7 +9,7 @@ from botocore.exceptions import ClientError from freezegun import freeze_time from moto import mock_eks, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.eks.exceptions import ( InvalidParameterException, diff --git a/tests/test_eks/test_eks_constants.py b/tests/test_eks/test_eks_constants.py index 2719aece7..d7ef9e7f7 100644 --- a/tests/test_eks/test_eks_constants.py +++ b/tests/test_eks/test_eks_constants.py @@ -6,7 +6,7 @@ from enum import Enum from boto3 import Session -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.eks import REGION as DEFAULT_REGION DEFAULT_ENCODING = "utf-8" diff --git a/tests/test_eks/test_server.py b/tests/test_eks/test_server.py index 1c65ff4a8..ec1f62a20 100644 --- a/tests/test_eks/test_server.py +++ b/tests/test_eks/test_server.py @@ -6,7 +6,7 @@ import sure # noqa # pylint: disable=unused-import import moto.server as server from moto import mock_eks -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.eks.exceptions import ResourceInUseException, ResourceNotFoundException from moto.eks.models import ( CLUSTER_EXISTS_MSG, @@ -81,7 +81,7 @@ class TestNodegroup: @pytest.fixture(autouse=True) def test_client(): - backend = server.create_backend_app(SERVICE) + backend = server.create_backend_app(service=SERVICE) yield backend.test_client() diff --git a/tests/test_elasticache/test_elasticache.py b/tests/test_elasticache/test_elasticache.py index adee2ff0a..4d2a68492 100644 --- a/tests/test_elasticache/test_elasticache.py +++ b/tests/test_elasticache/test_elasticache.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_elasticache -from moto.core import 
ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_elastictranscoder/test_elastictranscoder.py b/tests/test_elastictranscoder/test_elastictranscoder.py index 948a27e3f..36b0f631c 100644 --- a/tests/test_elastictranscoder/test_elastictranscoder.py +++ b/tests/test_elastictranscoder/test_elastictranscoder.py @@ -3,7 +3,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_elastictranscoder -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_elastictranscoder diff --git a/tests/test_elb/test_elb.py b/tests/test_elb/test_elb.py index 06a259d80..0e666d020 100644 --- a/tests/test_elb/test_elb.py +++ b/tests/test_elb/test_elb.py @@ -5,7 +5,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_acm, mock_elb, mock_ec2, mock_iam -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID from tests import EXAMPLE_AMI_ID from uuid import uuid4 @@ -1129,7 +1129,7 @@ def test_subnets(): lb.should.have.key("VPCId").which.should.equal(vpc.id) lb.should.have.key("SourceSecurityGroup").equals( - {"OwnerAlias": f"{ACCOUNT_ID}", "GroupName": "default"} + {"OwnerAlias": f"{DEFAULT_ACCOUNT_ID}", "GroupName": "default"} ) diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index ca83cefe3..1c169f526 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -7,7 +7,7 @@ import sure # noqa # pylint: disable=unused-import from moto import mock_elbv2, mock_ec2, mock_acm from moto.elbv2 import elbv2_backends -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from tests import EXAMPLE_AMI_ID @@ -1377,7 +1377,7 @@ def test_modify_listener_http_to_https(): # 
Check default cert, can't do this in server mode if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": listener = ( - elbv2_backends["eu-central-1"] + elbv2_backends[ACCOUNT_ID]["eu-central-1"] .load_balancers[load_balancer_arn] .listeners[listener_arn] ) diff --git a/tests/test_elbv2/test_elbv2_cloudformation.py b/tests/test_elbv2/test_elbv2_cloudformation.py index 11907d5b3..a661d0561 100644 --- a/tests/test_elbv2/test_elbv2_cloudformation.py +++ b/tests/test_elbv2/test_elbv2_cloudformation.py @@ -3,7 +3,7 @@ import json import sure # noqa # pylint: disable=unused-import from moto import mock_elbv2, mock_ec2, mock_cloudformation -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_elbv2 diff --git a/tests/test_elbv2/test_elbv2_target_groups.py b/tests/test_elbv2/test_elbv2_target_groups.py index bc42876b4..a67350466 100644 --- a/tests/test_elbv2/test_elbv2_target_groups.py +++ b/tests/test_elbv2/test_elbv2_target_groups.py @@ -4,7 +4,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_elbv2, mock_ec2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from .test_elbv2 import create_load_balancer diff --git a/tests/test_emr/test_emr_boto3.py b/tests/test_emr/test_emr_boto3.py index 06f14b08e..2a4aed97b 100644 --- a/tests/test_emr/test_emr_boto3.py +++ b/tests/test_emr/test_emr_boto3.py @@ -11,7 +11,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_emr -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID run_job_flow_args = dict( diff --git a/tests/test_emr/test_emr_integration.py b/tests/test_emr/test_emr_integration.py index 963dfe18d..51dbefc2e 100644 --- a/tests/test_emr/test_emr_integration.py +++ b/tests/test_emr/test_emr_integration.py @@ -3,11 +3,15 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import settings -from 
moto.ec2 import mock_ec2, ec2_backend +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from moto.ec2 import mock_ec2, ec2_backends from moto.emr import mock_emr from moto.emr.utils import EmrSecurityGroupManager +ec2_backend = ec2_backends[ACCOUNT_ID]["us-east-1"] + + @mock_emr @mock_ec2 def test_default_emr_security_groups_get_created_on_first_job_flow(): diff --git a/tests/test_emrcontainers/test_emrcontainers.py b/tests/test_emrcontainers/test_emrcontainers.py index eec9caa8c..5f34a42fb 100644 --- a/tests/test_emrcontainers/test_emrcontainers.py +++ b/tests/test_emrcontainers/test_emrcontainers.py @@ -9,7 +9,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError, ParamValidationError from moto import mock_emrcontainers, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest.mock import patch from moto.emrcontainers import REGION as DEFAULT_REGION diff --git a/tests/test_emrserverless/test_emrserverless.py b/tests/test_emrserverless/test_emrserverless.py index 443960465..d4e1cb6a6 100644 --- a/tests/test_emrserverless/test_emrserverless.py +++ b/tests/test_emrserverless/test_emrserverless.py @@ -8,7 +8,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_emrserverless, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.emrserverless import REGION as DEFAULT_REGION from moto.emrserverless import RELEASE_LABEL as DEFAULT_RELEASE_LABEL from unittest.mock import patch diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index 375dcae6e..b222abf1c 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -10,7 +10,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_logs -from moto.core import 
ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.events import mock_events @@ -286,7 +286,6 @@ def test_list_rule_names_by_target_using_limit(): client = generate_environment() response = client.list_rule_names_by_target(TargetArn=test_1_target["Arn"], Limit=1) - print(response) response.should.have.key("NextToken") response["RuleNames"].should.have.length_of(1) # diff --git a/tests/test_events/test_events_cloudformation.py b/tests/test_events/test_events_cloudformation.py index 46fcfa76b..d33c6b995 100644 --- a/tests/test_events/test_events_cloudformation.py +++ b/tests/test_events/test_events_cloudformation.py @@ -8,7 +8,7 @@ from botocore.exceptions import ClientError from moto import mock_cloudformation, mock_events import sure # noqa # pylint: disable=unused-import -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID archive_template = Template( json.dumps( diff --git a/tests/test_events/test_events_integration.py b/tests/test_events/test_events_integration.py index 7de45cceb..231760133 100644 --- a/tests/test_events/test_events_integration.py +++ b/tests/test_events/test_events_integration.py @@ -5,7 +5,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_events, mock_sqs, mock_logs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import iso_8601_datetime_without_milliseconds diff --git a/tests/test_events/test_events_lambdatriggers_integration.py b/tests/test_events/test_events_lambdatriggers_integration.py index 387b869d2..dd6624086 100644 --- a/tests/test_events/test_events_lambdatriggers_integration.py +++ b/tests/test_events/test_events_lambdatriggers_integration.py @@ -2,7 +2,7 @@ import boto3 import json from moto import mock_events, mock_iam, mock_lambda, mock_logs, mock_s3 -from moto.core import ACCOUNT_ID +from moto.core 
import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from ..test_awslambda.utilities import get_test_zip_file1, wait_for_log_msg diff --git a/tests/test_firehose/test_firehose.py b/tests/test_firehose/test_firehose.py index 720d1264c..ffcfd15d2 100644 --- a/tests/test_firehose/test_firehose.py +++ b/tests/test_firehose/test_firehose.py @@ -8,7 +8,7 @@ import pytest from moto import mock_firehose from moto import settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex from moto.firehose.models import DeliveryStream @@ -527,6 +527,8 @@ def test_lookup_name_from_arn(): firehose_backends, ) - delivery_stream = firehose_backends[TEST_REGION].lookup_name_from_arn(arn) + delivery_stream = firehose_backends[ACCOUNT_ID][TEST_REGION].lookup_name_from_arn( + arn + ) assert delivery_stream.delivery_stream_arn == arn assert delivery_stream.delivery_stream_name == stream_name diff --git a/tests/test_firehose/test_firehose_destination_types.py b/tests/test_firehose/test_firehose_destination_types.py index 96c587768..05f83a748 100644 --- a/tests/test_firehose/test_firehose_destination_types.py +++ b/tests/test_firehose/test_firehose_destination_types.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from moto import mock_firehose from moto import settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex TEST_REGION = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2" diff --git a/tests/test_firehose/test_firehose_put.py b/tests/test_firehose/test_firehose_put.py index 4eeab28d5..5704f0a06 100644 --- a/tests/test_firehose/test_firehose_put.py +++ b/tests/test_firehose/test_firehose_put.py @@ -4,7 +4,7 @@ import sure # noqa pylint: disable=unused-import from moto import mock_firehose from moto import mock_s3 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from 
moto.core.utils import get_random_hex from tests.test_firehose.test_firehose import TEST_REGION from tests.test_firehose.test_firehose import sample_s3_dest_config diff --git a/tests/test_firehose/test_firehose_tags.py b/tests/test_firehose/test_firehose_tags.py index 3c804df63..010ae3f71 100644 --- a/tests/test_firehose/test_firehose_tags.py +++ b/tests/test_firehose/test_firehose_tags.py @@ -4,7 +4,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_firehose -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex from moto.firehose.models import MAX_TAGS_PER_DELIVERY_STREAM from tests.test_firehose.test_firehose import TEST_REGION diff --git a/tests/test_forecast/test_forecast.py b/tests/test_forecast/test_forecast.py index 94f278ab5..f19371ff4 100644 --- a/tests/test_forecast/test_forecast.py +++ b/tests/test_forecast/test_forecast.py @@ -3,7 +3,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_forecast -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID region = "us-east-1" account_id = None diff --git a/tests/test_glacier/test_glacier_jobs.py b/tests/test_glacier/test_glacier_jobs.py index 6ac49749a..b5035cf82 100644 --- a/tests/test_glacier/test_glacier_jobs.py +++ b/tests/test_glacier/test_glacier_jobs.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import import time from moto import mock_glacier -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_glacier diff --git a/tests/test_glacier/test_glacier_vaults.py b/tests/test_glacier/test_glacier_vaults.py index 550ddfc28..e01f80bd6 100644 --- a/tests/test_glacier/test_glacier_vaults.py +++ b/tests/test_glacier/test_glacier_vaults.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from moto 
import mock_glacier -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from uuid import uuid4 diff --git a/tests/test_glue/fixtures/schema_registry.py b/tests/test_glue/fixtures/schema_registry.py index 899c41106..d248b24ef 100644 --- a/tests/test_glue/fixtures/schema_registry.py +++ b/tests/test_glue/fixtures/schema_registry.py @@ -1,4 +1,4 @@ -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID TEST_DESCRIPTION = "test_description" diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index d0ce39741..0be66d903 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -11,7 +11,7 @@ import pytz from freezegun import freeze_time from moto import mock_glue, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from . import helpers diff --git a/tests/test_glue/test_schema_registry.py b/tests/test_glue/test_schema_registry.py index 5541174ff..ae3b6192a 100644 --- a/tests/test_glue/test_schema_registry.py +++ b/tests/test_glue/test_schema_registry.py @@ -3,7 +3,7 @@ import boto3 import pytest import sure # noqa # pylint: disable=unused-import from botocore.client import ClientError -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto import mock_glue diff --git a/tests/test_greengrass/test_greengrass_core.py b/tests/test_greengrass/test_greengrass_core.py index ccd89b747..684745807 100644 --- a/tests/test_greengrass/test_greengrass_core.py +++ b/tests/test_greengrass/test_greengrass_core.py @@ -4,11 +4,9 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import get_account_id +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @freezegun.freeze_time("2022-06-01 12:00:00") @mock_greengrass diff --git 
a/tests/test_greengrass/test_greengrass_deployment.py b/tests/test_greengrass/test_greengrass_deployment.py index 95dbeb264..791cae16d 100644 --- a/tests/test_greengrass/test_greengrass_deployment.py +++ b/tests/test_greengrass/test_greengrass_deployment.py @@ -6,11 +6,9 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import get_account_id +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @mock_greengrass def test_create_deployment(): diff --git a/tests/test_greengrass/test_greengrass_device.py b/tests/test_greengrass/test_greengrass_device.py index 668e38406..5390c55af 100644 --- a/tests/test_greengrass/test_greengrass_device.py +++ b/tests/test_greengrass/test_greengrass_device.py @@ -4,11 +4,9 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import get_account_id +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @freezegun.freeze_time("2022-06-01 12:00:00") @mock_greengrass diff --git a/tests/test_greengrass/test_greengrass_functions.py b/tests/test_greengrass/test_greengrass_functions.py index 39a614de7..506920b01 100644 --- a/tests/test_greengrass/test_greengrass_functions.py +++ b/tests/test_greengrass/test_greengrass_functions.py @@ -5,11 +5,8 @@ import pytest from moto import mock_greengrass -from moto.core import get_account_id from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @freezegun.freeze_time("2022-06-01 12:00:00") @mock_greengrass diff --git a/tests/test_greengrass/test_greengrass_groups.py b/tests/test_greengrass/test_greengrass_groups.py index 4d345ea08..082a0c5f3 100644 --- a/tests/test_greengrass/test_greengrass_groups.py +++ b/tests/test_greengrass/test_greengrass_groups.py @@ -4,11 +4,9 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import 
get_account_id +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @freezegun.freeze_time("2022-06-01 12:00:00") @mock_greengrass diff --git a/tests/test_greengrass/test_greengrass_resource.py b/tests/test_greengrass/test_greengrass_resource.py index df4932063..ace36e27e 100644 --- a/tests/test_greengrass/test_greengrass_resource.py +++ b/tests/test_greengrass/test_greengrass_resource.py @@ -4,11 +4,8 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import get_account_id from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @freezegun.freeze_time("2022-06-01 12:00:00") @mock_greengrass diff --git a/tests/test_greengrass/test_greengrass_subscriptions.py b/tests/test_greengrass/test_greengrass_subscriptions.py index ec75ba46b..a684a9e50 100644 --- a/tests/test_greengrass/test_greengrass_subscriptions.py +++ b/tests/test_greengrass/test_greengrass_subscriptions.py @@ -4,11 +4,8 @@ import freezegun import pytest from moto import mock_greengrass -from moto.core import get_account_id from moto.settings import TEST_SERVER_MODE -ACCOUNT_ID = get_account_id() - @pytest.mark.parametrize( "target", diff --git a/tests/test_iam/test_iam.py b/tests/test_iam/test_iam.py index 67232fef4..54e533c98 100644 --- a/tests/test_iam/test_iam.py +++ b/tests/test_iam/test_iam.py @@ -6,9 +6,10 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_config, mock_iam, settings -from moto.core import ACCOUNT_ID -from moto.iam.models import aws_managed_policies +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from moto.iam import iam_backends from moto.backends import get_backend +from tests import DEFAULT_ACCOUNT_ID import pytest from datetime import datetime @@ -89,7 +90,7 @@ def test_get_role__should_contain_last_used(): role["RoleLastUsed"].should.equal({}) if not 
settings.TEST_SERVER_MODE: - iam_backend = get_backend("iam")["global"] + iam_backend = get_backend("iam")[ACCOUNT_ID]["global"] last_used = datetime.strptime( "2022-07-18T10:30:00+00:00", "%Y-%m-%dT%H:%M:%S+00:00" ) @@ -1889,7 +1890,7 @@ def test_get_credential_report_content(): key1 = conn.create_access_key(UserName=username)["AccessKey"] timestamp = datetime.utcnow() if not settings.TEST_SERVER_MODE: - iam_backend = get_backend("iam")["global"] + iam_backend = get_backend("iam")[ACCOUNT_ID]["global"] iam_backend.users[username].access_keys[1].last_used = timestamp iam_backend.users[username].password_last_used = timestamp with pytest.raises(ClientError): @@ -1927,20 +1928,21 @@ def test_get_access_key_last_used_when_used(): with pytest.raises(ClientError): client.get_access_key_last_used(AccessKeyId="non-existent-key-id") create_key_response = client.create_access_key(UserName=username)["AccessKey"] - # Set last used date using the IAM backend. Moto currently does not have a mechanism for tracking usage of access keys - if not settings.TEST_SERVER_MODE: - timestamp = datetime.utcnow() - iam_backend = get_backend("iam")["global"] - iam_backend.users[username].access_keys[0].last_used = timestamp + + access_key_client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=create_key_response["AccessKeyId"], + aws_secret_access_key=create_key_response["SecretAccessKey"], + ) + access_key_client.list_users() + resp = client.get_access_key_last_used( AccessKeyId=create_key_response["AccessKeyId"] ) - if not settings.TEST_SERVER_MODE: - datetime.strftime( - resp["AccessKeyLastUsed"]["LastUsedDate"], "%Y-%m-%d" - ).should.equal(timestamp.strftime("%Y-%m-%d")) - else: - resp["AccessKeyLastUsed"].should_not.contain("LastUsedDate") + resp["AccessKeyLastUsed"].should.have.key("LastUsedDate") + resp["AccessKeyLastUsed"].should.have.key("ServiceName").equals("iam") + resp["AccessKeyLastUsed"].should.have.key("Region").equals("us-east-1") @mock_iam @@ -1961,6 
+1963,7 @@ def test_managed_policy(): for policy in response["Policies"]: aws_policies.append(policy) marker = response.get("Marker") + aws_managed_policies = iam_backends[ACCOUNT_ID]["global"].aws_managed_policies set(p.name for p in aws_managed_policies).should.equal( set(p["PolicyName"] for p in aws_policies) ) @@ -3373,7 +3376,9 @@ def test_role_list_config_discovered_resources(): from moto.iam.config import role_config_query # Without any roles - assert role_config_query.list_config_service_resources(None, None, 100, None) == ( + assert role_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + ) == ( [], None, ) @@ -3382,7 +3387,9 @@ def test_role_list_config_discovered_resources(): roles = [] num_roles = 3 for ix in range(1, num_roles + 1): - this_role = role_config_query.backends["global"].create_role( + this_role = role_config_query.backends[DEFAULT_ACCOUNT_ID][ + "global" + ].create_role( role_name="role{}".format(ix), assume_role_policy_document=None, path="/", @@ -3395,7 +3402,9 @@ def test_role_list_config_discovered_resources(): assert len(roles) == num_roles - result = role_config_query.list_config_service_resources(None, None, 100, None)[0] + result = role_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + )[0] assert len(result) == num_roles # The roles gets a random ID, so we can't directly test it @@ -3407,13 +3416,13 @@ def test_role_list_config_discovered_resources(): # test passing list of resource ids resource_ids = role_config_query.list_config_service_resources( - [roles[0]["id"], roles[1]["id"]], None, 100, None + DEFAULT_ACCOUNT_ID, [roles[0]["id"], roles[1]["id"]], None, 100, None )[0] assert len(resource_ids) == 2 # test passing a single resource name resource_name = role_config_query.list_config_service_resources( - None, roles[0]["name"], 100, None + DEFAULT_ACCOUNT_ID, None, roles[0]["name"], 100, None )[0] assert len(resource_name) == 1 assert 
resource_name[0]["id"] == roles[0]["id"] @@ -3421,14 +3430,22 @@ def test_role_list_config_discovered_resources(): # test passing a single resource name AND some resource id's both_filter_good = role_config_query.list_config_service_resources( - [roles[0]["id"], roles[1]["id"]], roles[0]["name"], 100, None + DEFAULT_ACCOUNT_ID, + [roles[0]["id"], roles[1]["id"]], + roles[0]["name"], + 100, + None, )[0] assert len(both_filter_good) == 1 assert both_filter_good[0]["id"] == roles[0]["id"] assert both_filter_good[0]["name"] == roles[0]["name"] both_filter_bad = role_config_query.list_config_service_resources( - [roles[0]["id"], roles[1]["id"]], roles[2]["name"], 100, None + DEFAULT_ACCOUNT_ID, + [roles[0]["id"], roles[1]["id"]], + roles[2]["name"], + 100, + None, )[0] assert len(both_filter_bad) == 0 @@ -3439,8 +3456,10 @@ def test_role_config_dict(): from moto.iam.utils import random_resource_id, random_policy_id # Without any roles - assert not role_config_query.get_config_resource("something") - assert role_config_query.list_config_service_resources(None, None, 100, None) == ( + assert not role_config_query.get_config_resource(DEFAULT_ACCOUNT_ID, "something") + assert role_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + ) == ( [], None, ) @@ -3459,7 +3478,7 @@ def test_role_config_dict(): # Create a policy for use in role permissions boundary policy_arn = ( - policy_config_query.backends["global"] + policy_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .create_policy( description="basic_policy", path="/", @@ -3471,12 +3490,12 @@ def test_role_config_dict(): ) policy_id = policy_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0][0]["id"] assert len(policy_id) == len(random_policy_id()) # Create some roles (and grab them repeatedly since they create with random names) - role_config_query.backends["global"].create_role( + 
role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="plain_role", assume_role_policy_document=None, path="/", @@ -3486,13 +3505,13 @@ def test_role_config_dict(): max_session_duration=3600, ) - plain_role = role_config_query.list_config_service_resources(None, None, 100, None)[ - 0 - ][0] + plain_role = role_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + )[0][0] assert plain_role is not None assert len(plain_role["id"]) == len(random_resource_id()) - role_config_query.backends["global"].create_role( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="assume_role", assume_role_policy_document=json.dumps(basic_assume_role), path="/", @@ -3505,7 +3524,7 @@ def test_role_config_dict(): assume_role = next( role for role in role_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0] if role["id"] not in [plain_role["id"]] ) @@ -3513,7 +3532,7 @@ def test_role_config_dict(): assert len(assume_role["id"]) == len(random_resource_id()) assert assume_role["id"] is not plain_role["id"] - role_config_query.backends["global"].create_role( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="assume_and_permission_boundary_role", assume_role_policy_document=json.dumps(basic_assume_role), path="/", @@ -3526,7 +3545,7 @@ def test_role_config_dict(): assume_and_permission_boundary_role = next( role for role in role_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0] if role["id"] not in [plain_role["id"], assume_role["id"]] ) @@ -3535,7 +3554,7 @@ def test_role_config_dict(): assert assume_and_permission_boundary_role["id"] is not plain_role["id"] assert assume_and_permission_boundary_role["id"] is not assume_role["id"] - role_config_query.backends["global"].create_role( + 
role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="role_with_attached_policy", assume_role_policy_document=json.dumps(basic_assume_role), path="/", @@ -3544,13 +3563,13 @@ def test_role_config_dict(): tags=[], max_session_duration=3600, ) - role_config_query.backends["global"].attach_role_policy( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].attach_role_policy( policy_arn, "role_with_attached_policy" ) role_with_attached_policy = next( role for role in role_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0] if role["id"] not in [ @@ -3567,7 +3586,7 @@ def test_role_config_dict(): role_with_attached_policy["id"] is not assume_and_permission_boundary_role["id"] ) - role_config_query.backends["global"].create_role( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="role_with_inline_policy", assume_role_policy_document=json.dumps(basic_assume_role), path="/", @@ -3576,14 +3595,14 @@ def test_role_config_dict(): tags=[], max_session_duration=3600, ) - role_config_query.backends["global"].put_role_policy( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].put_role_policy( "role_with_inline_policy", "inline_policy", json.dumps(basic_policy) ) role_with_inline_policy = next( role for role in role_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0] if role["id"] not in [ @@ -3604,7 +3623,9 @@ def test_role_config_dict(): # plain role plain_role_config = ( - role_config_query.backends["global"].roles[plain_role["id"]].to_config_dict() + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] + .roles[plain_role["id"]] + .to_config_dict() ) assert plain_role_config["version"] == "1.3" assert plain_role_config["configurationItemStatus"] == "ResourceDiscovered" @@ -3633,7 +3654,9 @@ def test_role_config_dict(): # assume_role assume_role_config = ( - 
role_config_query.backends["global"].roles[assume_role["id"]].to_config_dict() + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] + .roles[assume_role["id"]] + .to_config_dict() ) assert assume_role_config["arn"] == "arn:aws:iam::123456789012:role/assume_role" assert assume_role_config["resourceId"] == "assume_role" @@ -3644,7 +3667,7 @@ def test_role_config_dict(): # assume_and_permission_boundary_role assume_and_permission_boundary_role_config = ( - role_config_query.backends["global"] + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .roles[assume_and_permission_boundary_role["id"]] .to_config_dict() ) @@ -3672,7 +3695,7 @@ def test_role_config_dict(): # role_with_attached_policy role_with_attached_policy_config = ( - role_config_query.backends["global"] + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .roles[role_with_attached_policy["id"]] .to_config_dict() ) @@ -3686,7 +3709,7 @@ def test_role_config_dict(): # role_with_inline_policy role_with_inline_policy_config = ( - role_config_query.backends["global"] + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .roles[role_with_inline_policy["id"]] .to_config_dict() ) @@ -3946,7 +3969,9 @@ def test_policy_list_config_discovered_resources(): from moto.iam.config import policy_config_query # Without any policies - assert policy_config_query.list_config_service_resources(None, None, 100, None) == ( + assert policy_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + ) == ( [], None, ) @@ -3962,7 +3987,9 @@ def test_policy_list_config_discovered_resources(): policies = [] num_policies = 3 for ix in range(1, num_policies + 1): - this_policy = policy_config_query.backends["global"].create_policy( + this_policy = policy_config_query.backends[DEFAULT_ACCOUNT_ID][ + "global" + ].create_policy( description="policy{}".format(ix), path="", policy_document=json.dumps(basic_policy), @@ -3975,11 +4002,15 @@ def test_policy_list_config_discovered_resources(): 
# We expect the backend to have arns as their keys for backend_key in list( - policy_config_query.backends["global"].managed_policies.keys() + policy_config_query.backends[DEFAULT_ACCOUNT_ID][ + "global" + ].managed_policies.keys() ): assert backend_key.startswith("arn:aws:iam::") - result = policy_config_query.list_config_service_resources(None, None, 100, None)[0] + result = policy_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + )[0] assert len(result) == num_policies policy = result[0] @@ -3990,13 +4021,13 @@ def test_policy_list_config_discovered_resources(): # test passing list of resource ids resource_ids = policy_config_query.list_config_service_resources( - [policies[0]["id"], policies[1]["id"]], None, 100, None + DEFAULT_ACCOUNT_ID, [policies[0]["id"], policies[1]["id"]], None, 100, None )[0] assert len(resource_ids) == 2 # test passing a single resource name resource_name = policy_config_query.list_config_service_resources( - None, policies[0]["name"], 100, None + DEFAULT_ACCOUNT_ID, None, policies[0]["name"], 100, None )[0] assert len(resource_name) == 1 assert resource_name[0]["id"] == policies[0]["id"] @@ -4004,14 +4035,22 @@ def test_policy_list_config_discovered_resources(): # test passing a single resource name AND some resource id's both_filter_good = policy_config_query.list_config_service_resources( - [policies[0]["id"], policies[1]["id"]], policies[0]["name"], 100, None + DEFAULT_ACCOUNT_ID, + [policies[0]["id"], policies[1]["id"]], + policies[0]["name"], + 100, + None, )[0] assert len(both_filter_good) == 1 assert both_filter_good[0]["id"] == policies[0]["id"] assert both_filter_good[0]["name"] == policies[0]["name"] both_filter_bad = policy_config_query.list_config_service_resources( - [policies[0]["id"], policies[1]["id"]], policies[2]["name"], 100, None + DEFAULT_ACCOUNT_ID, + [policies[0]["id"], policies[1]["id"]], + policies[2]["name"], + 100, + None, )[0] assert len(both_filter_bad) == 0 @@ 
-4023,9 +4062,11 @@ def test_policy_config_dict(): # Without any roles assert not policy_config_query.get_config_resource( - "arn:aws:iam::123456789012:policy/basic_policy" + DEFAULT_ACCOUNT_ID, "arn:aws:iam::123456789012:policy/basic_policy" ) - assert policy_config_query.list_config_service_resources(None, None, 100, None) == ( + assert policy_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 100, None + ) == ( [], None, ) @@ -4043,7 +4084,7 @@ def test_policy_config_dict(): } policy_arn = ( - policy_config_query.backends["global"] + policy_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .create_policy( description="basic_policy", path="/", @@ -4055,20 +4096,23 @@ def test_policy_config_dict(): ) policy_id = policy_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None )[0][0]["id"] assert len(policy_id) == len(random_policy_id()) assert policy_arn == "arn:aws:iam::123456789012:policy/basic_policy" - assert policy_config_query.get_config_resource(policy_id) is not None + assert ( + policy_config_query.get_config_resource(DEFAULT_ACCOUNT_ID, policy_id) + is not None + ) # Create a new version - policy_config_query.backends["global"].create_policy_version( + policy_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_policy_version( policy_arn, json.dumps(basic_policy_v2), "true" ) # Create role to trigger attachment - role_config_query.backends["global"].create_role( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].create_role( role_name="role_with_attached_policy", assume_role_policy_document=None, path="/", @@ -4077,12 +4121,12 @@ def test_policy_config_dict(): tags=[], max_session_duration=3600, ) - role_config_query.backends["global"].attach_role_policy( + role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"].attach_role_policy( policy_arn, "role_with_attached_policy" ) policy = ( - role_config_query.backends["global"] + 
role_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] .managed_policies["arn:aws:iam::123456789012:policy/basic_policy"] .to_config_dict() ) diff --git a/tests/test_iam/test_iam_access_integration.py b/tests/test_iam/test_iam_access_integration.py new file mode 100644 index 000000000..1c4451e9f --- /dev/null +++ b/tests/test_iam/test_iam_access_integration.py @@ -0,0 +1,25 @@ +import boto3 +from moto import mock_ec2, mock_iam + + +@mock_ec2 +@mock_iam +def test_invoking_ec2_mark_access_key_as_used(): + c_iam = boto3.client("iam", region_name="us-east-1") + c_iam.create_user(Path="my/path", UserName="fakeUser") + key = c_iam.create_access_key(UserName="fakeUser") + + c_ec2 = boto3.client( + "ec2", + region_name="us-east-2", + aws_access_key_id=key["AccessKey"]["AccessKeyId"], + aws_secret_access_key=key["AccessKey"]["SecretAccessKey"], + ) + c_ec2.describe_instances() + + last_used = c_iam.get_access_key_last_used( + AccessKeyId=key["AccessKey"]["AccessKeyId"] + )["AccessKeyLastUsed"] + last_used.should.have.key("LastUsedDate") + last_used.should.have.key("ServiceName").equals("ec2") + last_used.should.have.key("Region").equals("us-east-2") diff --git a/tests/test_iam/test_iam_cloudformation.py b/tests/test_iam/test_iam_cloudformation.py index c80ed0e6b..36dcd8561 100644 --- a/tests/test_iam/test_iam_cloudformation.py +++ b/tests/test_iam/test_iam_cloudformation.py @@ -6,7 +6,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from botocore.exceptions import ClientError -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto import mock_autoscaling, mock_iam, mock_cloudformation, mock_s3, mock_sts from tests import EXAMPLE_AMI_ID diff --git a/tests/test_iam/test_iam_groups.py b/tests/test_iam/test_iam_groups.py index 60bce11ff..67c73f22b 100644 --- a/tests/test_iam/test_iam_groups.py +++ b/tests/test_iam/test_iam_groups.py @@ -7,7 +7,7 @@ import json import pytest from botocore.exceptions import 
ClientError from moto import mock_iam -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID MOCK_POLICY = """ { diff --git a/tests/test_iam/test_iam_oidc.py b/tests/test_iam/test_iam_oidc.py index 0522564a7..69ab0b05a 100644 --- a/tests/test_iam/test_iam_oidc.py +++ b/tests/test_iam/test_iam_oidc.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_iam -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import pytest from datetime import datetime diff --git a/tests/test_iam/test_iam_server_certificates.py b/tests/test_iam/test_iam_server_certificates.py index 065f31982..a93344002 100644 --- a/tests/test_iam/test_iam_server_certificates.py +++ b/tests/test_iam/test_iam_server_certificates.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from datetime import datetime from moto import mock_iam -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_iam diff --git a/tests/test_iot/test_iot.py b/tests/test_iot/test_iot.py index c868423ce..9ba031e83 100644 --- a/tests/test_iot/test_iot.py +++ b/tests/test_iot/test_iot.py @@ -2,7 +2,7 @@ import sure # noqa # pylint: disable=unused-import import boto3 from moto import mock_iot -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from botocore.exceptions import ClientError import pytest diff --git a/tests/test_iot/test_iot_things.py b/tests/test_iot/test_iot_things.py index 5cdcdc069..ee1ca8b56 100644 --- a/tests/test_iot/test_iot_things.py +++ b/tests/test_iot/test_iot_things.py @@ -1,7 +1,7 @@ import boto3 from moto import mock_iot -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_iot diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index 768af2d14..0a0b7a350 100644 --- 
a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -6,6 +6,7 @@ from botocore.exceptions import ClientError import moto.iotdata.models from moto import mock_iotdata, mock_iot, settings +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_iot @@ -112,7 +113,7 @@ def test_publish(): client.publish(topic="test/topic", qos=1, payload=b"pl") if not settings.TEST_SERVER_MODE: - mock_backend = moto.iotdata.models.iotdata_backends[region_name] + mock_backend = moto.iotdata.models.iotdata_backends[ACCOUNT_ID][region_name] mock_backend.published_payloads.should.have.length_of(1) mock_backend.published_payloads.should.contain(("test/topic", "pl")) diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py index 60c1ab7bb..a5acfa1b7 100644 --- a/tests/test_kinesis/test_kinesis.py +++ b/tests/test_kinesis/test_kinesis.py @@ -8,7 +8,7 @@ from botocore.exceptions import ClientError from dateutil.tz import tzlocal from moto import mock_kinesis -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import sure # noqa # pylint: disable=unused-import diff --git a/tests/test_kinesis/test_kinesis_boto3.py b/tests/test_kinesis/test_kinesis_boto3.py index e8ea13753..cc8ab6d17 100644 --- a/tests/test_kinesis/test_kinesis_boto3.py +++ b/tests/test_kinesis/test_kinesis_boto3.py @@ -3,7 +3,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_kinesis -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import sure # noqa # pylint: disable=unused-import diff --git a/tests/test_kinesis/test_kinesis_stream_consumers.py b/tests/test_kinesis/test_kinesis_stream_consumers.py index 9b8ca2962..7be8d612d 100644 --- a/tests/test_kinesis/test_kinesis_stream_consumers.py +++ b/tests/test_kinesis/test_kinesis_stream_consumers.py @@ -3,7 +3,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_kinesis -from 
moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def create_stream(client): diff --git a/tests/test_kms/test_kms_boto3.py b/tests/test_kms/test_kms_boto3.py index 7be496506..3d251b546 100644 --- a/tests/test_kms/test_kms_boto3.py +++ b/tests/test_kms/test_kms_boto3.py @@ -13,7 +13,7 @@ from freezegun import freeze_time import pytest from moto import mock_kms -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID PLAINTEXT_VECTORS = [ diff --git a/tests/test_kms/test_kms_grants.py b/tests/test_kms/test_kms_grants.py index d8258ae2a..964f8bc5a 100644 --- a/tests/test_kms/test_kms_grants.py +++ b/tests/test_kms/test_kms_grants.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_kms -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID grantee_principal = ( diff --git a/tests/test_kms/test_model.py b/tests/test_kms/test_model.py index 5d0ffc097..991d9a6e3 100644 --- a/tests/test_kms/test_model.py +++ b/tests/test_kms/test_model.py @@ -14,7 +14,7 @@ def backend(): @pytest.fixture def key(backend): return backend.create_key( - None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Test key", None, REGION + None, "ENCRYPT_DECRYPT", "SYMMETRIC_DEFAULT", "Test key", None ) diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index 0de0d9457..e8a0e4391 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -101,7 +101,7 @@ def test_deserialize_ciphertext_blob(raw, serialized): ) def test_encrypt_decrypt_cycle(encryption_context): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", "nop", "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( @@ -132,7 +132,7 @@ def test_encrypt_unknown_key_id(): def test_decrypt_invalid_ciphertext_format(): - master_key = Key("nop", "nop", 
"nop", "nop", "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} with pytest.raises(InvalidCiphertextException): @@ -152,7 +152,7 @@ def test_decrypt_unknwown_key_id(): def test_decrypt_invalid_ciphertext(): - master_key = Key("nop", "nop", "nop", "nop", "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = ( master_key.id.encode("utf-8") + b"123456789012" @@ -170,7 +170,7 @@ def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_encryption_context(): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", "nop", "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( diff --git a/tests/test_logs/test_models.py b/tests/test_logs/test_models.py index 59c61fba0..e64471b4f 100644 --- a/tests/test_logs/test_models.py +++ b/tests/test_logs/test_models.py @@ -1,6 +1,7 @@ import sure # noqa # pylint: disable=unused-import from moto.logs.models import LogGroup +from tests import DEFAULT_ACCOUNT_ID def test_log_group_to_describe_dict(): @@ -14,7 +15,7 @@ def test_log_group_to_describe_dict(): kwargs = dict(kmsKeyId=kms_key_id) # When - log_group = LogGroup(region, name, tags, **kwargs) + log_group = LogGroup(DEFAULT_ACCOUNT_ID, region, name, tags, **kwargs) describe_dict = log_group.to_describe_dict() # Then diff --git a/tests/test_mediaconnect/test_mediaconnect.py b/tests/test_mediaconnect/test_mediaconnect.py index b17244a53..047dc829b 100644 --- a/tests/test_mediaconnect/test_mediaconnect.py +++ b/tests/test_mediaconnect/test_mediaconnect.py @@ -6,7 +6,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_mediaconnect -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID region = "eu-west-1" diff --git 
a/tests/test_medialive/test_medialive.py b/tests/test_medialive/test_medialive.py index b38d7cb08..1cf9ef49c 100644 --- a/tests/test_medialive/test_medialive.py +++ b/tests/test_medialive/test_medialive.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import from moto import mock_medialive from uuid import uuid4 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID region = "eu-west-1" diff --git a/tests/test_mq/test_mq_configuration.py b/tests/test_mq/test_mq_configuration.py index cefbe5fac..fe8f6a854 100644 --- a/tests/test_mq/test_mq_configuration.py +++ b/tests/test_mq/test_mq_configuration.py @@ -7,7 +7,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_mq -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_organizations/organizations_test_utils.py b/tests/test_organizations/organizations_test_utils.py index 15bee591b..26bc27e15 100644 --- a/tests/test_organizations/organizations_test_utils.py +++ b/tests/test_organizations/organizations_test_utils.py @@ -1,4 +1,5 @@ import datetime +from moto.core import DEFAULT_ACCOUNT_ID from moto.organizations import utils @@ -47,7 +48,7 @@ def validate_organization(response): ] ) org["Id"].should.match(utils.ORG_ID_REGEX) - org["MasterAccountId"].should.equal(utils.MASTER_ACCOUNT_ID) + org["MasterAccountId"].should.equal(DEFAULT_ACCOUNT_ID) org["MasterAccountArn"].should.equal( utils.MASTER_ACCOUNT_ARN_FORMAT.format(org["MasterAccountId"], org["Id"]) ) diff --git a/tests/test_organizations/test_organizations_boto3.py b/tests/test_organizations/test_organizations_boto3.py index 2e335cfd0..8c783b0e4 100644 --- a/tests/test_organizations/test_organizations_boto3.py +++ 
b/tests/test_organizations/test_organizations_boto3.py @@ -17,7 +17,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_organizations -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.organizations import utils from .organizations_test_utils import ( validate_organization, @@ -40,7 +40,7 @@ def test_create_organization(): response = client.list_accounts() len(response["Accounts"]).should.equal(1) response["Accounts"][0]["Name"].should.equal("master") - response["Accounts"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response["Accounts"][0]["Id"].should.equal(ACCOUNT_ID) response["Accounts"][0]["Email"].should.equal(utils.MASTER_ACCOUNT_EMAIL) response = client.list_policies(Filter="SERVICE_CONTROL_POLICY") @@ -433,7 +433,7 @@ def test_list_children(): response02 = client.list_children(ParentId=root_id, ChildType="ORGANIZATIONAL_UNIT") response03 = client.list_children(ParentId=ou01_id, ChildType="ACCOUNT") response04 = client.list_children(ParentId=ou01_id, ChildType="ORGANIZATIONAL_UNIT") - response01["Children"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response01["Children"][0]["Id"].should.equal(ACCOUNT_ID) response01["Children"][0]["Type"].should.equal("ACCOUNT") response01["Children"][1]["Id"].should.equal(account01_id) response01["Children"][1]["Type"].should.equal("ACCOUNT") @@ -1240,7 +1240,7 @@ def test_tag_resource_errors(): def test__get_resource_for_tagging_existing_root(): - org = FakeOrganization("ALL") + org = FakeOrganization(ACCOUNT_ID, "ALL") root = FakeRoot(org) org_backend = OrganizationsBackend(region_name="N/A", account_id="N/A") @@ -1260,7 +1260,7 @@ def test__get_resource_for_tagging_existing_non_root(): def test__get_resource_for_tagging_existing_ou(): - org = FakeOrganization("ALL") + org = FakeOrganization(ACCOUNT_ID, "ALL") ou = FakeOrganizationalUnit(org) org_backend = OrganizationsBackend(region_name="N/A", account_id="N/A") @@ -1280,7 
+1280,7 @@ def test__get_resource_for_tagging_non_existing_ou(): def test__get_resource_for_tagging_existing_account(): - org = FakeOrganization("ALL") + org = FakeOrganization(ACCOUNT_ID, "ALL") org_backend = OrganizationsBackend(region_name="N/A", account_id="N/A") account = FakeAccount(org, AccountName="test", Email="test@test.test") @@ -1300,7 +1300,7 @@ def test__get_resource_for_tagging_non_existing_account(): def test__get_resource_for_tagging_existing_policy(): - org = FakeOrganization("ALL") + org = FakeOrganization(ACCOUNT_ID, "ALL") org_backend = OrganizationsBackend(region_name="N/A", account_id="N/A") policy = FakePolicy(org, Type="SERVICE_CONTROL_POLICY") diff --git a/tests/test_quicksight/test_quicksight_datasets.py b/tests/test_quicksight/test_quicksight_datasets.py index 659f303c8..4fc878276 100644 --- a/tests/test_quicksight/test_quicksight_datasets.py +++ b/tests/test_quicksight/test_quicksight_datasets.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_quicksight -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_quicksight/test_quicksight_groups.py b/tests/test_quicksight/test_quicksight_groups.py index 354d19947..a0734a0b3 100644 --- a/tests/test_quicksight/test_quicksight_groups.py +++ b/tests/test_quicksight/test_quicksight_groups.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_quicksight -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git 
a/tests/test_quicksight/test_quicksight_users.py b/tests/test_quicksight/test_quicksight_users.py index a9cee283b..81a678981 100644 --- a/tests/test_quicksight/test_quicksight_users.py +++ b/tests/test_quicksight/test_quicksight_users.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_quicksight -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_ram/test_ram.py b/tests/test_ram/test_ram.py index 772058441..cf3fc915d 100644 --- a/tests/test_ram/test_ram.py +++ b/tests/test_ram/test_ram.py @@ -7,7 +7,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_ram, mock_organizations -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_ram diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index 1dcaff6bc..345f30df2 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -3,7 +3,7 @@ import boto3 import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_ec2, mock_kms, mock_rds -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_rds diff --git a/tests/test_rds/test_rds_clusters.py b/tests/test_rds/test_rds_clusters.py index 0b018ff71..010998916 100644 --- a/tests/test_rds/test_rds_clusters.py +++ b/tests/test_rds/test_rds_clusters.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_rds -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_rds diff --git a/tests/test_rds/test_rds_event_subscriptions.py b/tests/test_rds/test_rds_event_subscriptions.py index 
4bb2114af..88bbb804e 100644 --- a/tests/test_rds/test_rds_event_subscriptions.py +++ b/tests/test_rds/test_rds_event_subscriptions.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_rds -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID DB_INSTANCE_IDENTIFIER = "db-primary-1" diff --git a/tests/test_rds/test_rds_export_tasks.py b/tests/test_rds/test_rds_export_tasks.py index 185bf6f73..2f614ef26 100644 --- a/tests/test_rds/test_rds_export_tasks.py +++ b/tests/test_rds/test_rds_export_tasks.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_rds -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def _prepare_db_snapshot(client, snapshot_name="snapshot-1"): diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index d33615884..41aa6a58e 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -3,7 +3,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_rds2 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def test_deprecation_warning(): diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index 8490ede56..47f7adc18 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -8,7 +8,7 @@ import sure # noqa # pylint: disable=unused-import from moto import mock_ec2 from moto import mock_redshift -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_redshift diff --git a/tests/test_route53/test_route53_query_logging_config.py b/tests/test_route53/test_route53_query_logging_config.py index 5bf389267..641bda0f6 100644 --- a/tests/test_route53/test_route53_query_logging_config.py +++ 
b/tests/test_route53/test_route53_query_logging_config.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError from moto import mock_logs from moto import mock_route53 -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex # The log group must be in the us-east-1 region. diff --git a/tests/test_route53resolver/test_route53resolver_endpoint.py b/tests/test_route53resolver/test_route53resolver_endpoint.py index b85969be6..172c293d3 100644 --- a/tests/test_route53resolver/test_route53resolver_endpoint.py +++ b/tests/test_route53resolver/test_route53resolver_endpoint.py @@ -8,7 +8,7 @@ import pytest from moto import mock_route53resolver from moto import settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex from moto.ec2 import mock_ec2 diff --git a/tests/test_route53resolver/test_route53resolver_rule.py b/tests/test_route53resolver/test_route53resolver_rule.py index d2f59f967..acdc100d9 100644 --- a/tests/test_route53resolver/test_route53resolver_rule.py +++ b/tests/test_route53resolver/test_route53resolver_rule.py @@ -7,7 +7,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_route53resolver -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.utils import get_random_hex from moto.ec2 import mock_ec2 diff --git a/tests/test_s3/test_multiple_accounts_server.py b/tests/test_s3/test_multiple_accounts_server.py new file mode 100644 index 000000000..323526f86 --- /dev/null +++ b/tests/test_s3/test_multiple_accounts_server.py @@ -0,0 +1,49 @@ +import requests + +from moto import settings +from moto.server import ThreadedMotoServer +from unittest import SkipTest + + +SERVER_PORT = 5001 +BASE_URL = f"http://localhost:{SERVER_PORT}/" + + +class TestAccountIdResolution: + def setup(self): + if settings.TEST_SERVER_MODE: + 
raise SkipTest( + "No point in testing this in ServerMode, as we already start our own server" + ) + self.server = ThreadedMotoServer(port=SERVER_PORT, verbose=False) + self.server.start() + + def teardown(self): + self.server.stop() + + def test_with_custom_request_header(self): + buckets_for_account_1 = ["foo", "bar"] + for name in buckets_for_account_1: + requests.put(f"http://{name}.localhost:{SERVER_PORT}/") + + res = requests.get(BASE_URL) + res.content.should.contain(b"foo") + res.content.should.contain(b"bar") + + # Create two more buckets in another account + headers = {"x-moto-account-id": "333344445555"} + buckets_for_account_2 = ["baz", "bla"] + for name in buckets_for_account_2: + requests.put(f"http://{name}.localhost:{SERVER_PORT}/", headers=headers) + + # Verify only these buckets exist in this account + res = requests.get(BASE_URL, headers=headers) + res.content.should.contain(b"baz") + res.content.should.contain(b"bla") + res.content.shouldnt.contain(b"foo") + res.content.shouldnt.contain(b"bar") + + # Verify these buckets do not exist in the original account + res = requests.get(BASE_URL) + res.content.shouldnt.contain(b"baz") + res.content.shouldnt.contain(b"bla") diff --git a/tests/test_s3/test_s3_auth.py b/tests/test_s3/test_s3_auth.py index feb22f035..c915ee26e 100644 --- a/tests/test_s3/test_s3_auth.py +++ b/tests/test_s3/test_s3_auth.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_iam, mock_s3, mock_sts, settings -from moto.core import ACCOUNT_ID, set_initial_no_auth_action_count +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID, set_initial_no_auth_action_count from unittest import SkipTest diff --git a/tests/test_s3/test_s3_config.py b/tests/test_s3/test_s3_config.py index 8927e7eaf..2dde559cc 100644 --- a/tests/test_s3/test_s3_config.py +++ b/tests/test_s3/test_s3_config.py @@ -4,14 +4,17 @@ import sure # noqa # pylint: disable=unused-import 
from moto import mock_s3 from moto.core.exceptions import InvalidNextTokenException +from moto.s3.config import s3_config_query +from tests import DEFAULT_ACCOUNT_ID + +s3_config_query_backend = s3_config_query.backends[DEFAULT_ACCOUNT_ID]["global"] @mock_s3 def test_s3_public_access_block_to_config_dict(): - from moto.s3.config import s3_config_query # With 1 bucket in us-west-2: - s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + s3_config_query_backend.create_bucket("bucket1", "us-west-2") public_access_block = { "BlockPublicAcls": "True", @@ -21,22 +24,20 @@ def test_s3_public_access_block_to_config_dict(): } # Add a public access block: - s3_config_query.backends["global"].put_bucket_public_access_block( + s3_config_query_backend.put_bucket_public_access_block( "bucket1", public_access_block ) - result = ( - s3_config_query.backends["global"] - .buckets["bucket1"] - .public_access_block.to_config_dict() - ) + result = s3_config_query_backend.buckets[ + "bucket1" + ].public_access_block.to_config_dict() for key, value in public_access_block.items(): k = "{lowercase}{rest}".format(lowercase=key[0].lower(), rest=key[1:]) assert result[k] is (value == "True") # Verify that this resides in the full bucket's to_config_dict: - full_result = s3_config_query.backends["global"].buckets["bucket1"].to_config_dict() + full_result = s3_config_query_backend.buckets["bucket1"].to_config_dict() assert ( json.loads( full_result["supplementaryConfiguration"]["PublicAccessBlockConfiguration"] @@ -47,7 +48,6 @@ def test_s3_public_access_block_to_config_dict(): @mock_s3 def test_list_config_discovered_resources(): - from moto.s3.config import s3_config_query # Without any buckets: assert s3_config_query.list_config_service_resources( @@ -56,18 +56,14 @@ def test_list_config_discovered_resources(): # With 10 buckets in us-west-2: for x in range(0, 10): - s3_config_query.backends["global"].create_bucket( - "bucket{}".format(x), "us-west-2" - ) + 
s3_config_query_backend.create_bucket(f"bucket{x}", "us-west-2") # With 2 buckets in eu-west-1: for x in range(10, 12): - s3_config_query.backends["global"].create_bucket( - "eu-bucket{}".format(x), "eu-west-1" - ) + s3_config_query_backend.create_bucket(f"eu-bucket{x}", "eu-west-1") result, next_token = s3_config_query.list_config_service_resources( - None, None, 100, None + DEFAULT_ACCOUNT_ID, None, None, 100, None ) assert not next_token assert len(result) == 12 @@ -88,19 +84,19 @@ def test_list_config_discovered_resources(): # With a name: result, next_token = s3_config_query.list_config_service_resources( - None, "bucket0", 100, None + DEFAULT_ACCOUNT_ID, None, "bucket0", 100, None ) assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token # With a region: result, next_token = s3_config_query.list_config_service_resources( - None, None, 100, None, resource_region="eu-west-1" + DEFAULT_ACCOUNT_ID, None, None, 100, None, resource_region="eu-west-1" ) assert len(result) == 2 and not next_token and result[1]["name"] == "eu-bucket11" # With resource ids: result, next_token = s3_config_query.list_config_service_resources( - ["bucket0", "bucket1"], None, 100, None + DEFAULT_ACCOUNT_ID, ["bucket0", "bucket1"], None, 100, None ) assert ( len(result) == 2 @@ -111,13 +107,13 @@ def test_list_config_discovered_resources(): # With duplicated resource ids: result, next_token = s3_config_query.list_config_service_resources( - ["bucket0", "bucket0"], None, 100, None + DEFAULT_ACCOUNT_ID, ["bucket0", "bucket0"], None, 100, None ) assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token # Pagination: result, next_token = s3_config_query.list_config_service_resources( - None, None, 1, None + DEFAULT_ACCOUNT_ID, None, None, 1, None ) assert ( len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" @@ -125,13 +121,13 @@ def test_list_config_discovered_resources(): # Last Page: result, next_token = 
s3_config_query.list_config_service_resources( - None, None, 1, "eu-bucket11", resource_region="eu-west-1" + DEFAULT_ACCOUNT_ID, None, None, 1, "eu-bucket11", resource_region="eu-west-1" ) assert len(result) == 1 and result[0]["name"] == "eu-bucket11" and not next_token # With a list of buckets: result, next_token = s3_config_query.list_config_service_resources( - ["bucket0", "bucket1"], None, 1, None + DEFAULT_ACCOUNT_ID, ["bucket0", "bucket1"], None, 1, None ) assert ( len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" @@ -139,17 +135,17 @@ def test_list_config_discovered_resources(): # With an invalid page: with pytest.raises(InvalidNextTokenException) as inte: - s3_config_query.list_config_service_resources(None, None, 1, "notabucket") + s3_config_query.list_config_service_resources( + DEFAULT_ACCOUNT_ID, None, None, 1, "notabucket" + ) assert "The nextToken provided is invalid" in inte.value.message @mock_s3 def test_s3_lifecycle_config_dict(): - from moto.s3.config import s3_config_query - # With 1 bucket in us-west-2: - s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + s3_config_query_backend.create_bucket("bucket1", "us-west-2") # And a lifecycle policy lifecycle = [ @@ -178,12 +174,12 @@ def test_s3_lifecycle_config_dict(): "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 1}, }, ] - s3_config_query.backends["global"].put_bucket_lifecycle("bucket1", lifecycle) + s3_config_query_backend.put_bucket_lifecycle("bucket1", lifecycle) # Get the rules for this: lifecycles = [ rule.to_config_dict() - for rule in s3_config_query.backends["global"].buckets["bucket1"].rules + for rule in s3_config_query_backend.buckets["bucket1"].rules ] # Verify the first: @@ -260,10 +256,8 @@ def test_s3_lifecycle_config_dict(): @mock_s3 def test_s3_notification_config_dict(): - from moto.s3.config import s3_config_query - # With 1 bucket in us-west-2: - s3_config_query.backends["global"].create_bucket("bucket1", 
"us-west-2") + s3_config_query_backend.create_bucket("bucket1", "us-west-2") # And some notifications: notifications = { @@ -305,16 +299,14 @@ def test_s3_notification_config_dict(): ], } - s3_config_query.backends["global"].put_bucket_notification_configuration( - "bucket1", notifications - ) + s3_config_query.backends[DEFAULT_ACCOUNT_ID][ + "global" + ].put_bucket_notification_configuration("bucket1", notifications) # Get the notifications for this: - notifications = ( - s3_config_query.backends["global"] - .buckets["bucket1"] - .notification_configuration.to_config_dict() - ) + notifications = s3_config_query_backend.buckets[ + "bucket1" + ].notification_configuration.to_config_dict() # Verify it all: assert notifications == { @@ -361,14 +353,13 @@ def test_s3_notification_config_dict(): @mock_s3 def test_s3_acl_to_config_dict(): - from moto.s3.config import s3_config_query from moto.s3.models import FakeAcl, FakeGrant, FakeGrantee, OWNER # With 1 bucket in us-west-2: - s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") + s3_config_query_backend.create_bucket("logbucket", "us-west-2") # Get the config dict with nothing other than the owner details: - acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + acls = s3_config_query_backend.buckets["logbucket"].acl.to_config_dict() owner_acl = { "grantee": {"id": OWNER, "displayName": None}, "permission": "FullControl", @@ -393,9 +384,9 @@ def test_s3_acl_to_config_dict(): FakeGrant([FakeGrantee(grantee_id=OWNER)], "FULL_CONTROL"), ] ) - s3_config_query.backends["global"].put_bucket_acl("logbucket", log_acls) + s3_config_query_backend.put_bucket_acl("logbucket", log_acls) - acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + acls = s3_config_query_backend.buckets["logbucket"].acl.to_config_dict() assert acls == { "grantSet": None, "grantList": [ @@ -419,8 +410,8 @@ def test_s3_acl_to_config_dict(): 
FakeGrant([FakeGrantee(grantee_id=OWNER)], "WRITE_ACP"), ] ) - s3_config_query.backends["global"].put_bucket_acl("logbucket", log_acls) - acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + s3_config_query_backend.put_bucket_acl("logbucket", log_acls) + acls = s3_config_query_backend.buckets["logbucket"].acl.to_config_dict() assert acls == { "grantSet": None, "grantList": [ @@ -433,20 +424,19 @@ def test_s3_acl_to_config_dict(): @mock_s3 def test_s3_config_dict(): - from moto.s3.config import s3_config_query from moto.s3.models import FakeAcl, FakeGrant, FakeGrantee, OWNER # Without any buckets: - assert not s3_config_query.get_config_resource("some_bucket") + assert not s3_config_query.get_config_resource(DEFAULT_ACCOUNT_ID, "some_bucket") tags = {"someTag": "someValue", "someOtherTag": "someOtherValue"} # With 1 bucket in us-west-2: - s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") - s3_config_query.backends["global"].put_bucket_tagging("bucket1", tags) + s3_config_query_backend.create_bucket("bucket1", "us-west-2") + s3_config_query_backend.put_bucket_tagging("bucket1", tags) # With a log bucket: - s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") + s3_config_query_backend.create_bucket("logbucket", "us-west-2") log_acls = FakeAcl( [ FakeGrant( @@ -461,8 +451,8 @@ def test_s3_config_dict(): ] ) - s3_config_query.backends["global"].put_bucket_acl("logbucket", log_acls) - s3_config_query.backends["global"].put_bucket_logging( + s3_config_query_backend.put_bucket_acl("logbucket", log_acls) + s3_config_query_backend.put_bucket_logging( "bucket1", {"TargetBucket": "logbucket", "TargetPrefix": ""} ) @@ -481,10 +471,10 @@ def test_s3_config_dict(): # The policy is a byte array -- need to encode in Python 3 pass_policy = bytes(policy, "utf-8") - s3_config_query.backends["global"].put_bucket_policy("bucket1", pass_policy) + s3_config_query_backend.put_bucket_policy("bucket1", pass_policy) 
# Get the us-west-2 bucket and verify that it works properly: - bucket1_result = s3_config_query.get_config_resource("bucket1") + bucket1_result = s3_config_query.get_config_resource(DEFAULT_ACCOUNT_ID, "bucket1") # Just verify a few things: assert bucket1_result["arn"] == "arn:aws:s3:::bucket1" @@ -541,26 +531,33 @@ def test_s3_config_dict(): # Filter by correct region: assert bucket1_result == s3_config_query.get_config_resource( - "bucket1", resource_region="us-west-2" + DEFAULT_ACCOUNT_ID, "bucket1", resource_region="us-west-2" ) # By incorrect region: assert not s3_config_query.get_config_resource( - "bucket1", resource_region="eu-west-1" + DEFAULT_ACCOUNT_ID, "bucket1", resource_region="eu-west-1" ) # With correct resource ID and name: assert bucket1_result == s3_config_query.get_config_resource( - "bucket1", resource_name="bucket1" + DEFAULT_ACCOUNT_ID, "bucket1", resource_name="bucket1" ) # With an incorrect resource name: assert not s3_config_query.get_config_resource( - "bucket1", resource_name="eu-bucket-1" + DEFAULT_ACCOUNT_ID, "bucket1", resource_name="eu-bucket-1" + ) + + # With an incorrect account: + assert not s3_config_query.get_config_resource( + "unknown-accountid", "bucket1", resource_name="bucket-1" ) # Verify that no bucket policy returns the proper value: - logging_bucket = s3_config_query.get_config_resource("logbucket") + logging_bucket = s3_config_query.get_config_resource( + DEFAULT_ACCOUNT_ID, "logbucket" + ) assert json.loads(logging_bucket["supplementaryConfiguration"]["BucketPolicy"]) == { "policyText": None } diff --git a/tests/test_s3/test_s3_lambda_integration.py b/tests/test_s3/test_s3_lambda_integration.py index 943c6198e..a804b0574 100644 --- a/tests/test_s3/test_s3_lambda_integration.py +++ b/tests/test_s3/test_s3_lambda_integration.py @@ -2,7 +2,7 @@ import boto3 import json import pytest from moto import mock_lambda, mock_logs, mock_s3, mock_sqs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as 
ACCOUNT_ID from tests.test_awslambda.utilities import ( get_test_zip_file_print_event, get_role_name, diff --git a/tests/test_s3control/test_s3control.py b/tests/test_s3control/test_s3control.py index b56f3c980..06f231635 100644 --- a/tests/test_s3control/test_s3control.py +++ b/tests/test_s3control/test_s3control.py @@ -9,7 +9,7 @@ from moto import mock_s3control @mock_s3control def test_get_public_access_block_for_account(): - from moto.core import ACCOUNT_ID + from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID client = boto3.client("s3control", region_name="us-west-2") diff --git a/tests/test_s3control/test_s3control_access_points.py b/tests/test_s3control/test_s3control_access_points.py index 8b09b4bbb..e92d991f2 100644 --- a/tests/test_s3control/test_s3control_access_points.py +++ b/tests/test_s3control/test_s3control_access_points.py @@ -4,7 +4,6 @@ import sure # noqa # pylint: disable=unused-import from botocore.client import ClientError from moto import mock_s3control -from moto.core import ACCOUNT_ID @mock_s3control @@ -53,7 +52,7 @@ def test_get_access_point_minimal(): resp.should.have.key("CreationDate") resp.should.have.key("Alias").match("ap_name-[a-z0-9]+-s3alias") resp.should.have.key("AccessPointArn").equals( - f"arn:aws:s3:us-east-1:{ACCOUNT_ID}:accesspoint/ap_name" + "arn:aws:s3:us-east-1:111111111111:accesspoint/ap_name" ) resp.should.have.key("Endpoints") diff --git a/tests/test_s3control/test_s3control_config_integration.py b/tests/test_s3control/test_s3control_config_integration.py index 0a01d9716..87d45b526 100644 --- a/tests/test_s3control/test_s3control_config_integration.py +++ b/tests/test_s3control/test_s3control_config_integration.py @@ -18,7 +18,7 @@ if not settings.TEST_SERVER_MODE: @mock_s3control @mock_config def test_config_list_account_pab(): - from moto.core import ACCOUNT_ID + from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID client = boto3.client("s3control", region_name="us-west-2") config_client = 
boto3.client("config", region_name="us-west-2") @@ -191,7 +191,7 @@ if not settings.TEST_SERVER_MODE: @mock_s3control @mock_config def test_config_get_account_pab(): - from moto.core import ACCOUNT_ID + from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID client = boto3.client("s3control", region_name="us-west-2") config_client = boto3.client("config", region_name="us-west-2") diff --git a/tests/test_s3control/test_s3control_s3.py b/tests/test_s3control/test_s3control_s3.py index 0921862ac..7547bfef5 100644 --- a/tests/test_s3control/test_s3control_s3.py +++ b/tests/test_s3control/test_s3control_s3.py @@ -3,7 +3,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_s3, mock_s3control, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID if not settings.TEST_SERVER_MODE: diff --git a/tests/test_sagemaker/cloudformation_test_configs.py b/tests/test_sagemaker/cloudformation_test_configs.py index e92e69263..3571c8c80 100644 --- a/tests/test_sagemaker/cloudformation_test_configs.py +++ b/tests/test_sagemaker/cloudformation_test_configs.py @@ -1,7 +1,7 @@ import json from abc import ABCMeta, abstractmethod -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID class TestConfig(metaclass=ABCMeta): diff --git a/tests/test_sagemaker/test_sagemaker_cloudformation.py b/tests/test_sagemaker/test_sagemaker_cloudformation.py index 7ac908fe1..742126a13 100644 --- a/tests/test_sagemaker/test_sagemaker_cloudformation.py +++ b/tests/test_sagemaker/test_sagemaker_cloudformation.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_cloudformation, mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from .cloudformation_test_configs import ( NotebookInstanceTestConfig, diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py 
b/tests/test_sagemaker/test_sagemaker_endpoint.py index b036084f1..5f73ab819 100644 --- a/tests/test_sagemaker/test_sagemaker_endpoint.py +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -6,7 +6,7 @@ from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import pytest TEST_REGION_NAME = "us-east-1" diff --git a/tests/test_sagemaker/test_sagemaker_experiment.py b/tests/test_sagemaker/test_sagemaker_experiment.py index 8bb9f442e..24e4719fe 100644 --- a/tests/test_sagemaker/test_sagemaker_experiment.py +++ b/tests/test_sagemaker/test_sagemaker_experiment.py @@ -2,7 +2,7 @@ import boto3 import pytest from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID TEST_REGION_NAME = "us-east-1" TEST_EXPERIMENT_NAME = "MyExperimentName" diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py index 45c452fb9..d36abbb70 100644 --- a/tests/test_sagemaker/test_sagemaker_notebooks.py +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ -4,7 +4,7 @@ from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import pytest TEST_REGION_NAME = "us-east-1" diff --git a/tests/test_sagemaker/test_sagemaker_processing.py b/tests/test_sagemaker/test_sagemaker_processing.py index dce4848f7..cc3a9d68c 100644 --- a/tests/test_sagemaker/test_sagemaker_processing.py +++ b/tests/test_sagemaker/test_sagemaker_processing.py @@ -4,7 +4,7 @@ import datetime import pytest from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" 
FAKE_PROCESSING_JOB_NAME = "MyProcessingJob" diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py index 19dc7d338..93bc8a708 100644 --- a/tests/test_sagemaker/test_sagemaker_training.py +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) TEST_REGION_NAME = "us-east-1" diff --git a/tests/test_sagemaker/test_sagemaker_trial.py b/tests/test_sagemaker/test_sagemaker_trial.py index 1eb28d882..a727aabb9 100644 --- a/tests/test_sagemaker/test_sagemaker_trial.py +++ b/tests/test_sagemaker/test_sagemaker_trial.py @@ -3,7 +3,7 @@ import uuid import boto3 from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID TEST_REGION_NAME = "us-east-1" diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py index aeaf5302d..0affddc3c 100644 --- a/tests/test_sagemaker/test_sagemaker_trial_component.py +++ b/tests/test_sagemaker/test_sagemaker_trial_component.py @@ -6,7 +6,7 @@ import pytest from botocore.exceptions import ClientError from moto import mock_sagemaker -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID TEST_REGION_NAME = "us-east-1" diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index 67e0cfa07..3fc8e7f73 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -5,7 +5,7 @@ from dateutil.tz import tzlocal import re from moto import mock_secretsmanager, mock_lambda, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as 
ACCOUNT_ID from botocore.exceptions import ClientError, ParamValidationError import string import pytz diff --git a/tests/test_servicediscovery/test_servicediscovery_httpnamespaces.py b/tests/test_servicediscovery/test_servicediscovery_httpnamespaces.py index 97998c2ea..e7b4374f2 100644 --- a/tests/test_servicediscovery/test_servicediscovery_httpnamespaces.py +++ b/tests/test_servicediscovery/test_servicediscovery_httpnamespaces.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_servicediscovery -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py index 5d57d19ff..6f6e94361 100644 --- a/tests/test_ses/test_ses_sns_boto3.py +++ b/tests/test_ses/test_ses_sns_boto3.py @@ -4,7 +4,7 @@ import json import sure # noqa # pylint: disable=unused-import from moto import mock_ses, mock_sns, mock_sqs from moto.ses.models import SESFeedback -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_ses @@ -88,7 +88,7 @@ def __test_sns_feedback__(addr, expected_msg, raw_email=False): if expected_msg is not None: msg = messages[0].body msg = json.loads(msg) - assert msg["Message"] == SESFeedback.generate_message(expected_msg) + assert msg["Message"] == SESFeedback.generate_message(ACCOUNT_ID, expected_msg) else: assert len(messages) == 0 diff --git a/tests/test_sns/test_application_boto3.py b/tests/test_sns/test_application_boto3.py index 4489ccfed..e9799532d 100644 --- a/tests/test_sns/test_application_boto3.py +++ b/tests/test_sns/test_application_boto3.py @@ -2,7 +2,7 @@ import boto3 from botocore.exceptions import ClientError from moto import mock_sns import sure # noqa # 
pylint: disable=unused-import -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import pytest diff --git a/tests/test_sns/test_publish_batch.py b/tests/test_sns/test_publish_batch.py index 5d441e30e..8cb6abaf6 100644 --- a/tests/test_sns/test_publish_batch.py +++ b/tests/test_sns/test_publish_batch.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError import pytest from moto import mock_sns, mock_sqs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_sns diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 962f0e7ef..3f71cf600 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -10,7 +10,7 @@ from botocore.exceptions import ClientError from unittest import SkipTest import pytest from moto import mock_sns, mock_sqs, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core.models import responses_mock from moto.sns import sns_backends @@ -288,7 +288,7 @@ def test_publish_sms(): result.should.contain("MessageId") if not settings.TEST_SERVER_MODE: - sns_backend = sns_backends["us-east-1"] + sns_backend = sns_backends[ACCOUNT_ID]["us-east-1"] sns_backend.sms_messages.should.have.key(result["MessageId"]).being.equal( ("+15551234567", "my message") ) @@ -422,7 +422,7 @@ def test_publish_to_http(): conn.publish(TopicArn=topic_arn, Message="my message", Subject="my subject") - sns_backend = sns_backends["us-east-1"] + sns_backend = sns_backends[ACCOUNT_ID]["us-east-1"] sns_backend.topics[topic_arn].sent_notifications.should.have.length_of(1) notification = sns_backend.topics[topic_arn].sent_notifications[0] _, msg, subject, _, _ = notification diff --git a/tests/test_sns/test_server.py b/tests/test_sns/test_server.py index 40ca1c2e0..338cc2789 100644 --- 
a/tests/test_sns/test_server.py +++ b/tests/test_sns/test_server.py @@ -1,4 +1,4 @@ -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID import sure # noqa # pylint: disable=unused-import diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index e5844bd4d..25d9136be 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -7,7 +7,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_sns, mock_sqs -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.sns.models import ( DEFAULT_PAGE_SIZE, DEFAULT_EFFECTIVE_DELIVERY_POLICY, diff --git a/tests/test_sns/test_topics_boto3.py b/tests/test_sns/test_topics_boto3.py index 9e15020dd..8aa82aa9f 100644 --- a/tests/test_sns/test_topics_boto3.py +++ b/tests/test_sns/test_topics_boto3.py @@ -6,7 +6,7 @@ import json from botocore.exceptions import ClientError from moto import mock_sns from moto.sns.models import DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_sns diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 2c5f82b42..a7e8b68aa 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -13,7 +13,7 @@ from moto import mock_sqs, settings from unittest import SkipTest, mock import pytest -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.sqs.models import ( Queue, MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND, diff --git a/tests/test_sqs/test_sqs_cloudformation.py b/tests/test_sqs/test_sqs_cloudformation.py index 975de690e..c96d63fef 100644 --- a/tests/test_sqs/test_sqs_cloudformation.py +++ b/tests/test_sqs/test_sqs_cloudformation.py @@ -3,7 +3,7 @@ import boto3 import json import sure # noqa # pylint: disable=unused-import from moto 
import mock_sqs, mock_cloudformation -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from string import Template from random import randint from uuid import uuid4 diff --git a/tests/test_sqs/test_sqs_multiaccount.py b/tests/test_sqs/test_sqs_multiaccount.py new file mode 100644 index 000000000..4ae03eee6 --- /dev/null +++ b/tests/test_sqs/test_sqs_multiaccount.py @@ -0,0 +1,36 @@ +import unittest +import boto3 +from moto import mock_sts, mock_sqs +from uuid import uuid4 + + +class TestStsAssumeRole(unittest.TestCase): + @mock_sqs + @mock_sts + def test_list_queues_in_different_account(self): + + sqs = boto3.client("sqs", region_name="us-east-1") + queue_url = sqs.create_queue(QueueName=str(uuid4()))["QueueUrl"] + + # verify function exists + all_urls = sqs.list_queues()["QueueUrls"] + all_urls.should.contain(queue_url) + + # assume role to another aws account + account_b = "111111111111" + sts = boto3.client("sts", region_name="us-east-1") + response = sts.assume_role( + RoleArn=f"arn:aws:iam::{account_b}:role/my-role", + RoleSessionName="test-session-name", + ExternalId="test-external-id", + ) + client2 = boto3.client( + "sqs", + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + + # client2 belongs to another account, where there are no queues + client2.list_queues().shouldnt.have.key("QueueUrls") diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index 2384233dd..d5fa3c57c 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -10,7 +10,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_ec2, mock_ssm -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.ssm.models import PARAMETER_VERSION_LIMIT, 
PARAMETER_HISTORY_MAX_RESULTS from tests import EXAMPLE_AMI_ID diff --git a/tests/test_ssm/test_ssm_defaults.py b/tests/test_ssm/test_ssm_defaults.py index b3cefc791..10c591b9e 100644 --- a/tests/test_ssm/test_ssm_defaults.py +++ b/tests/test_ssm/test_ssm_defaults.py @@ -2,7 +2,7 @@ import boto3 import sure # noqa # pylint: disable=unused-import from moto import mock_ssm -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_ssm diff --git a/tests/test_ssm/test_ssm_docs.py b/tests/test_ssm/test_ssm_docs.py index 2ac44da28..ef3058395 100644 --- a/tests/test_ssm/test_ssm_docs.py +++ b/tests/test_ssm/test_ssm_docs.py @@ -12,7 +12,7 @@ import pytest from botocore.exceptions import ClientError -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto import mock_ssm diff --git a/tests/test_ssm/test_ssm_parameterstore.py b/tests/test_ssm/test_ssm_parameterstore.py index edac9f261..9d51faff7 100644 --- a/tests/test_ssm/test_ssm_parameterstore.py +++ b/tests/test_ssm/test_ssm_parameterstore.py @@ -4,20 +4,20 @@ from moto.ssm.models import ParameterDict def test_simple_setget(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") store["/a/b/c"] = "some object" store.get("/a/b/c").should.equal("some object") def test_get_none(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") store.get(None).should.equal(None) def test_get_aws_param(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") p = store["/aws/service/global-infrastructure/regions/us-west-1/longName"] p.should.have.length_of(1) @@ -25,7 +25,7 @@ def test_get_aws_param(): def test_iter(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") store["/a/b/c"] = "some object" "/a/b/c".should.be.within(store) @@ -33,12 +33,12 @@ def test_iter(): def test_iter_none(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") 
None.shouldnt.be.within(store) def test_iter_aws(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") "/aws/service/global-infrastructure/regions/us-west-1/longName".should.be.within( store @@ -46,7 +46,7 @@ def test_iter_aws(): def test_get_key_beginning_with(): - store = ParameterDict(list) + store = ParameterDict("accnt", "region") store["/a/b/c"] = "some object" store["/b/c/d"] = "some other object" store["/a/c/d"] = "some third object" @@ -66,7 +66,7 @@ def test_get_key_beginning_with_aws(): ParameterDict should load the default parameters if we request a key starting with '/aws' :return: """ - store = ParameterDict(list) + store = ParameterDict("accnt", "region") uswest_params = set( store.get_keys_beginning_with( diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index ce3017cc3..0134da75a 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -8,7 +8,7 @@ from botocore.exceptions import ClientError import pytest from moto import mock_sts, mock_stepfunctions -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from unittest import SkipTest, mock @@ -295,11 +295,7 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): # with pytest.raises(ClientError): unknown_state_machine = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":stateMachine:unknown" + f"arn:aws:states:{region}:{ACCOUNT_ID}:stateMachine:unknown" ) client.describe_state_machine(stateMachineArn=unknown_state_machine) @@ -466,7 +462,7 @@ def test_state_machine_list_tags_for_nonexisting_machine(): client = boto3.client("stepfunctions", region_name=region) # non_existing_state_machine = ( - "arn:aws:states:" + region + ":" + _get_account_id() + ":stateMachine:unknown" + f"arn:aws:states:{region}:{ACCOUNT_ID}:stateMachine:unknown" ) response = 
client.list_tags_for_resource(resourceArn=non_existing_state_machine) tags = response["tags"] @@ -486,12 +482,7 @@ def test_state_machine_start_execution(): execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) uuid_regex = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" expected_exec_name = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":execution:name:" - + uuid_regex + f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:name:{uuid_regex}" ) execution["executionArn"].should.match(expected_exec_name) execution["startDate"].should.be.a(datetime) @@ -520,11 +511,7 @@ def test_state_machine_start_execution_with_custom_name(): # execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) expected_exec_name = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":execution:name:execution_name" + f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:name:execution_name" ) execution["executionArn"].should.equal(expected_exec_name) execution["startDate"].should.be.a(datetime) @@ -567,12 +554,7 @@ def test_state_machine_start_execution_with_custom_input(): execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) uuid_regex = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" expected_exec_name = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":execution:name:" - + uuid_regex + f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:name:{uuid_regex}" ) execution["executionArn"].should.match(expected_exec_name) execution["startDate"].should.be.a(datetime) @@ -743,9 +725,7 @@ def test_execution_throws_error_when_describing_unknown_execution(): client = boto3.client("stepfunctions", region_name=region) # with pytest.raises(ClientError): - unknown_execution = ( - "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" - ) + unknown_execution = f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:unknown" client.describe_execution(executionArn=unknown_execution) @@ 
-774,9 +754,7 @@ def test_state_machine_throws_error_when_describing_unknown_execution(): client = boto3.client("stepfunctions", region_name=region) # with pytest.raises(ClientError): - unknown_execution = ( - "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" - ) + unknown_execution = f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:unknown" client.describe_state_machine_for_execution(executionArn=unknown_execution) @@ -806,11 +784,7 @@ def test_state_machine_stop_raises_error_when_unknown_execution(): ) with pytest.raises(ClientError) as ex: unknown_execution = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":execution:test-state-machine:unknown" + f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:test-state-machine:unknown" ) client.stop_execution(executionArn=unknown_execution) ex.value.response["Error"]["Code"].should.equal("ExecutionDoesNotExist") @@ -844,11 +818,7 @@ def test_state_machine_get_execution_history_throws_error_with_unknown_execution ) with pytest.raises(ClientError) as ex: unknown_execution = ( - "arn:aws:states:" - + region - + ":" - + _get_account_id() - + ":execution:test-state-machine:unknown" + f"arn:aws:states:{region}:{ACCOUNT_ID}:execution:test-state-machine:unknown" ) client.get_execution_history(executionArn=unknown_execution) ex.value.response["Error"]["Code"].should.equal("ExecutionDoesNotExist") @@ -981,15 +951,5 @@ def test_state_machine_get_execution_history_contains_expected_failure_events_wh execution_history["events"].should.equal(expected_events) -def _get_account_id(): - global account_id - if account_id: - return account_id - sts = boto3.client("sts", region_name=region) - identity = sts.get_caller_identity() - account_id = identity["Account"] - return account_id - - def _get_default_role(): - return "arn:aws:iam::" + _get_account_id() + ":role/unknown_sf_role" + return "arn:aws:iam::" + ACCOUNT_ID + ":role/unknown_sf_role" diff --git a/tests/test_sts/test_sts.py 
b/tests/test_sts/test_sts.py index b516121de..025fa0be6 100644 --- a/tests/test_sts/test_sts.py +++ b/tests/test_sts/test_sts.py @@ -9,7 +9,7 @@ import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_sts, mock_iam, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH diff --git a/tests/test_sts/test_sts_integration.py b/tests/test_sts/test_sts_integration.py new file mode 100644 index 000000000..08f78e9b8 --- /dev/null +++ b/tests/test_sts/test_sts_integration.py @@ -0,0 +1,185 @@ +import boto3 +import unittest + +from base64 import b64encode +from moto import mock_dynamodb, mock_sts, mock_iam +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID + + +@mock_sts +@mock_iam +@mock_dynamodb +class TestStsAssumeRole(unittest.TestCase): + def setUp(self) -> None: + self.account_b = "111111111111" + self.sts = boto3.client("sts", region_name="us-east-1") + + def test_assume_role_in_different_account(self): + # assume role to another aws account + role_name = f"arn:aws:iam::{self.account_b}:role/my-role" + response = self.sts.assume_role( + RoleArn=role_name, + RoleSessionName="test-session-name", + ExternalId="test-external-id", + ) + + # Assume the new role + iam_account_b = boto3.client( + "iam", + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + + # Verify new users belong to the different account + user = iam_account_b.create_user(UserName="user-in-new-account")["User"] + user["Arn"].should.equal( + f"arn:aws:iam::{self.account_b}:user/user-in-new-account" + ) + + def test_assume_role_with_saml_in_different_account(self): + role_name = "test-role" + provider_name = "TestProvFed" + fed_identifier = "7ca82df9-1bad-4dd3-9b2b-adb68b554282" + fed_name = 
"testuser" + role_input = "arn:aws:iam::{account_id}:role/{role_name}".format( + account_id=self.account_b, role_name=role_name + ) + principal_role = ( + "arn:aws:iam:{account_id}:saml-provider/{provider_name}".format( + account_id=ACCOUNT_ID, provider_name=provider_name + ) + ) + saml_assertion = """ + + http://localhost/ + + + + + http://localhost:3000/ + + + + + + + + + + + NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo= + + + NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo= + + + NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo= + + + + + {fed_identifier} + + + + + + + urn:amazon:webservices + + + + + {fed_name} + + + arn:aws:iam::{account_id}:role/{role_name},arn:aws:iam::{account_id}:saml-provider/{provider_name} + + + 900 + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + """.format( + account_id=self.account_b, + role_name=role_name, + provider_name=provider_name, + fed_identifier=fed_identifier, + fed_name=fed_name, + ).replace( + "\n", "" + ) + + assume_role_response = self.sts.assume_role_with_saml( + RoleArn=role_input, + PrincipalArn=principal_role, + SAMLAssertion=b64encode(saml_assertion.encode("utf-8")).decode("utf-8"), + ) + + # Assume the new role + iam_account_b = boto3.client( + "iam", + aws_access_key_id=assume_role_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=assume_role_response["Credentials"][ + "SecretAccessKey" + ], + aws_session_token=assume_role_response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + + # Verify new users belong to the different account + user = iam_account_b.create_user(UserName="user-in-new-account")["User"] + user["Arn"].should.equal( + f"arn:aws:iam::{self.account_b}:user/user-in-new-account" + ) + + def test_dynamodb_supports_multiple_accounts(self): + ddb_client = boto3.client("dynamodb", region_name="us-east-1") + + 
ddb_client.create_table( + TableName="table-in-default-account", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 5}, + ) + # assume role to another aws account + role_name = f"arn:aws:iam::{self.account_b}:role/my-role" + response = self.sts.assume_role( + RoleArn=role_name, + RoleSessionName="test-session-name", + ExternalId="test-external-id", + ) + + # Assume the new role + ddb_account_b = boto3.client( + "dynamodb", + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + region_name="us-east-1", + ) + + # Verify new dynamodb belong to the different account + ddb_account_b.create_table( + TableName="table-in-new-account", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 5}, + ) + + table = ddb_client.describe_table(TableName="table-in-default-account")["Table"] + table["TableArn"].should.equal( + "arn:aws:dynamodb:us-east-1:123456789012:table/table-in-default-account" + ) + + table = ddb_account_b.describe_table(TableName="table-in-new-account")["Table"] + table["TableArn"].should.equal( + f"arn:aws:dynamodb:us-east-1:{self.account_b}:table/table-in-new-account" + ) diff --git a/tests/test_swf/models/test_domain.py b/tests/test_swf/models/test_domain.py index 31a840217..896118b0f 100644 --- a/tests/test_swf/models/test_domain.py +++ b/tests/test_swf/models/test_domain.py @@ -1,7 +1,7 @@ from collections import namedtuple import sure # noqa # pylint: disable=unused-import -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.swf.exceptions import SWFUnknownResourceFault from 
moto.swf.models import Domain @@ -13,7 +13,7 @@ WorkflowExecution = namedtuple( def test_domain_short_dict_representation(): - domain = Domain("foo", "52", TEST_REGION) + domain = Domain("foo", "52", ACCOUNT_ID, TEST_REGION) domain.to_short_dict().should.equal( { "name": "foo", @@ -27,7 +27,7 @@ def test_domain_short_dict_representation(): def test_domain_full_dict_representation(): - domain = Domain("foo", "52", TEST_REGION) + domain = Domain("foo", "52", ACCOUNT_ID, TEST_REGION) domain.to_full_dict()["domainInfo"].should.equal(domain.to_short_dict()) _config = domain.to_full_dict()["configuration"] @@ -35,38 +35,38 @@ def test_domain_full_dict_representation(): def test_domain_string_representation(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) str(domain).should.equal("Domain(name: my-domain, status: REGISTERED)") def test_domain_add_to_activity_task_list(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) domain.add_to_activity_task_list("foo", "bar") domain.activity_task_lists.should.equal({"foo": ["bar"]}) def test_domain_activity_tasks(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) domain.add_to_activity_task_list("foo", "bar") domain.add_to_activity_task_list("other", "baz") sorted(domain.activity_tasks).should.equal(["bar", "baz"]) def test_domain_add_to_decision_task_list(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) domain.add_to_decision_task_list("foo", "bar") domain.decision_task_lists.should.equal({"foo": ["bar"]}) def test_domain_decision_tasks(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) domain.add_to_decision_task_list("foo", "bar") domain.add_to_decision_task_list("other", "baz") 
sorted(domain.decision_tasks).should.equal(["bar", "baz"]) def test_domain_get_workflow_execution(): - domain = Domain("my-domain", "60", TEST_REGION) + domain = Domain("my-domain", "60", ACCOUNT_ID, TEST_REGION) wfe1 = WorkflowExecution( workflow_id="wf-id-1", run_id="run-id-1", execution_status="OPEN", open=True diff --git a/tests/test_swf/responses/test_domains.py b/tests/test_swf/responses/test_domains.py index 9d00cbfca..e5b4cd88d 100644 --- a/tests/test_swf/responses/test_domains.py +++ b/tests/test_swf/responses/test_domains.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_swf -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID # RegisterDomain endpoint diff --git a/tests/test_swf/utils.py b/tests/test_swf/utils.py index 0ba7177ac..f5282acf2 100644 --- a/tests/test_swf/utils.py +++ b/tests/test_swf/utils.py @@ -1,5 +1,6 @@ import boto3 +from moto.core import DEFAULT_ACCOUNT_ID from moto.swf.models import ActivityType, Domain, WorkflowType, WorkflowExecution @@ -31,7 +32,7 @@ for key, value in ACTIVITY_TASK_TIMEOUTS.items(): # A test Domain def get_basic_domain(): - return Domain("test-domain", "90", "us-east-1") + return Domain("test-domain", "90", DEFAULT_ACCOUNT_ID, "us-east-1") # A test WorkflowType diff --git a/tests/test_timestreamwrite/test_timestreamwrite_database.py b/tests/test_timestreamwrite/test_timestreamwrite_database.py index cf003b52c..55c8ca6bf 100644 --- a/tests/test_timestreamwrite/test_timestreamwrite_database.py +++ b/tests/test_timestreamwrite/test_timestreamwrite_database.py @@ -4,7 +4,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_timestreamwrite -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_timestreamwrite diff --git a/tests/test_timestreamwrite/test_timestreamwrite_table.py 
b/tests/test_timestreamwrite/test_timestreamwrite_table.py index 33ad966da..83738b658 100644 --- a/tests/test_timestreamwrite/test_timestreamwrite_table.py +++ b/tests/test_timestreamwrite/test_timestreamwrite_table.py @@ -5,7 +5,7 @@ import sure # noqa # pylint: disable=unused-import from botocore.exceptions import ClientError from moto import mock_timestreamwrite, settings -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_timestreamwrite @@ -299,7 +299,7 @@ def test_write_records(): if not settings.TEST_SERVER_MODE: from moto.timestreamwrite.models import timestreamwrite_backends - backend = timestreamwrite_backends["us-east-1"] + backend = timestreamwrite_backends[ACCOUNT_ID]["us-east-1"] records = backend.databases["mydatabase"].tables["mytable"].records records.should.equal(sample_records) diff --git a/tests/test_wafv2/test_server.py b/tests/test_wafv2/test_server.py index c9bac1907..28b313949 100644 --- a/tests/test_wafv2/test_server.py +++ b/tests/test_wafv2/test_server.py @@ -3,7 +3,7 @@ import sure # noqa # pylint: disable=unused-import import moto.server as server from moto import mock_wafv2 from .test_helper_functions import CREATE_WEB_ACL_BODY, LIST_WEB_ACL_BODY -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID CREATE_WEB_ACL_HEADERS = { "X-Amz-Target": "AWSWAF_20190729.CreateWebACL", diff --git a/tests/test_wafv2/test_utils.py b/tests/test_wafv2/test_utils.py index 5482c751d..c970cdc60 100644 --- a/tests/test_wafv2/test_utils.py +++ b/tests/test_wafv2/test_utils.py @@ -1,7 +1,7 @@ import uuid from moto.wafv2.utils import make_arn_for_wacl -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID def test_make_arn_for_wacl(): @@ -9,13 +9,13 @@ def test_make_arn_for_wacl(): region = "us-east-1" name = "testName" scope = "REGIONAL" - arn = make_arn_for_wacl(name, region, uniqueID, scope) + arn = make_arn_for_wacl(name, ACCOUNT_ID, region, 
uniqueID, scope) assert arn == "arn:aws:wafv2:{}:{}:regional/webacl/{}/{}".format( region, ACCOUNT_ID, name, uniqueID ) scope = "CLOUDFRONT" - arn = make_arn_for_wacl(name, region, uniqueID, scope) + arn = make_arn_for_wacl(name, ACCOUNT_ID, region, uniqueID, scope) assert arn == "arn:aws:wafv2:{}:{}:global/webacl/{}/{}".format( region, ACCOUNT_ID, name, uniqueID ) diff --git a/tests/test_wafv2/test_wafv2.py b/tests/test_wafv2/test_wafv2.py index 6bca57713..570f4ee55 100644 --- a/tests/test_wafv2/test_wafv2.py +++ b/tests/test_wafv2/test_wafv2.py @@ -5,7 +5,7 @@ import boto3 from botocore.exceptions import ClientError from moto import mock_wafv2 from .test_helper_functions import CREATE_WEB_ACL_BODY, LIST_WEB_ACL_BODY -from moto.core import ACCOUNT_ID +from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @mock_wafv2