Preparation for MultiAccount support (#5157)

This commit is contained in:
Bert Blommers 2022-06-04 11:30:16 +00:00 committed by GitHub
parent 620f15a562
commit 79a2a9d423
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
155 changed files with 724 additions and 943 deletions

View File

@ -423,17 +423,11 @@ class CertBundle(BaseModel):
class AWSCertificateManagerBackend(BaseBackend): class AWSCertificateManagerBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region = region
self._certificates = {} self._certificates = {}
self._idempotency_tokens = {} self._idempotency_tokens = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
@ -491,12 +485,16 @@ class AWSCertificateManagerBackend(BaseBackend):
else: else:
# Will reuse provided ARN # Will reuse provided ARN
bundle = CertBundle( bundle = CertBundle(
certificate, private_key, chain=chain, region=self.region, arn=arn certificate,
private_key,
chain=chain,
region=self.region_name,
arn=arn,
) )
else: else:
# Will generate a random ARN # Will generate a random ARN
bundle = CertBundle( bundle = CertBundle(
certificate, private_key, chain=chain, region=self.region certificate, private_key, chain=chain, region=self.region_name
) )
self._certificates[bundle.arn] = bundle self._certificates[bundle.arn] = bundle
@ -548,7 +546,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return arn return arn
cert = CertBundle.generate_cert( cert = CertBundle.generate_cert(
domain_name, region=self.region, sans=subject_alt_names domain_name, region=self.region_name, sans=subject_alt_names
) )
if idempotency_token is not None: if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn) self._set_idempotency_token_arn(idempotency_token, cert.arn)

View File

@ -1239,23 +1239,17 @@ class APIGatewayBackend(BaseBackend):
- This only works when using the decorators, not in ServerMode - This only works when using the decorators, not in ServerMode
""" """
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.apis = {} self.apis = {}
self.keys = {} self.keys = {}
self.usage_plans = {} self.usage_plans = {}
self.usage_plan_keys = {} self.usage_plan_keys = {}
self.domain_names = {} self.domain_names = {}
self.models = {} self.models = {}
self.region_name = region_name
self.base_path_mappings = {} self.base_path_mappings = {}
self.vpc_links = {} self.vpc_links = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_rest_api( def create_rest_api(
self, self,
name, name,

View File

@ -980,18 +980,12 @@ class VpcLink(BaseModel):
class ApiGatewayV2Backend(BaseBackend): class ApiGatewayV2Backend(BaseBackend):
"""Implementation of ApiGatewayV2 APIs.""" """Implementation of ApiGatewayV2 APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.apis = dict() self.apis = dict()
self.vpc_links = dict() self.vpc_links = dict()
self.tagger = TaggingService() self.tagger = TaggingService()
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_api( def create_api(
self, self,
api_key_selection_expression, api_key_selection_expression,

View File

@ -63,19 +63,13 @@ class ScalableDimensionValueSet(Enum):
class ApplicationAutoscalingBackend(BaseBackend): class ApplicationAutoscalingBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region = region self.ecs_backend = ecs_backends[region_name]
self.ecs_backend = ecs_backends[region]
self.targets = OrderedDict() self.targets = OrderedDict()
self.policies = {} self.policies = {}
self.scheduled_actions = list() self.scheduled_actions = list()
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
@ -85,7 +79,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
@property @property
def applicationautoscaling_backend(self): def applicationautoscaling_backend(self):
return applicationautoscaling_backends[self.region] return applicationautoscaling_backends[self.region_name]
def describe_scalable_targets(self, namespace, r_ids=None, dimension=None): def describe_scalable_targets(self, namespace, r_ids=None, dimension=None):
"""Describe scalable targets.""" """Describe scalable targets."""
@ -166,7 +160,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
if policy_key in self.policies: if policy_key in self.policies:
old_policy = self.policies[policy_key] old_policy = self.policies[policy_key]
policy = FakeApplicationAutoscalingPolicy( policy = FakeApplicationAutoscalingPolicy(
region_name=self.region, region_name=self.region_name,
policy_name=policy_name, policy_name=policy_name,
service_namespace=service_namespace, service_namespace=service_namespace,
resource_id=resource_id, resource_id=resource_id,
@ -176,7 +170,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
) )
else: else:
policy = FakeApplicationAutoscalingPolicy( policy = FakeApplicationAutoscalingPolicy(
region_name=self.region, region_name=self.region_name,
policy_name=policy_name, policy_name=policy_name,
service_namespace=service_namespace, service_namespace=service_namespace,
resource_id=resource_id, resource_id=resource_id,
@ -311,7 +305,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
start_time, start_time,
end_time, end_time,
scalable_target_action, scalable_target_action,
self.region, self.region_name,
) )
self.scheduled_actions.append(action) self.scheduled_actions.append(action)

View File

@ -187,17 +187,11 @@ class GraphqlAPIKey(BaseModel):
class AppSyncBackend(BaseBackend): class AppSyncBackend(BaseBackend):
"""Implementation of AppSync APIs.""" """Implementation of AppSync APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.graphql_apis = dict() self.graphql_apis = dict()
self.tagger = TaggingService() self.tagger = TaggingService()
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_graphql_api( def create_graphql_api(
self, self,
name, name,

View File

@ -85,9 +85,8 @@ class NamedQuery(BaseModel):
class AthenaBackend(BaseBackend): class AthenaBackend(BaseBackend):
region_name = None region_name = None
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
if region_name is not None: super().__init__(region_name, account_id)
self.region_name = region_name
self.work_groups = {} self.work_groups = {}
self.executions = {} self.executions = {}
self.named_queries = {} self.named_queries = {}

View File

@ -97,7 +97,7 @@ class FakeScalingPolicy(BaseModel):
@property @property
def arn(self): def arn(self):
return f"arn:aws:autoscaling:{self.autoscaling_backend.region}:{get_account_id()}:scalingPolicy:c322761b-3172-4d56-9a21-0ed9d6161d67:autoScalingGroupName/{self.as_name}:policyName/{self.name}" return f"arn:aws:autoscaling:{self.autoscaling_backend.region_name}:{get_account_id()}:scalingPolicy:c322761b-3172-4d56-9a21-0ed9d6161d67:autoScalingGroupName/{self.as_name}:policyName/{self.name}"
def execute(self): def execute(self):
if self.adjustment_type == "ExactCapacity": if self.adjustment_type == "ExactCapacity":
@ -303,7 +303,7 @@ class FakeAutoScalingGroup(CloudFormationModel):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.name = name self.name = name
self._id = str(uuid4()) self._id = str(uuid4())
self.region = self.autoscaling_backend.region self.region = self.autoscaling_backend.region_name
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier) self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier)
@ -650,7 +650,8 @@ class FakeAutoScalingGroup(CloudFormationModel):
class AutoScalingBackend(BaseBackend): class AutoScalingBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.autoscaling_groups = OrderedDict() self.autoscaling_groups = OrderedDict()
self.launch_configurations = OrderedDict() self.launch_configurations = OrderedDict()
self.policies = {} self.policies = {}
@ -658,12 +659,6 @@ class AutoScalingBackend(BaseBackend):
self.ec2_backend = ec2_backends[region_name] self.ec2_backend = ec2_backends[region_name]
self.elb_backend = elb_backends[region_name] self.elb_backend = elb_backends[region_name]
self.elbv2_backend = elbv2_backends[region_name] self.elbv2_backend = elbv2_backends[region_name]
self.region = region_name
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -721,7 +716,7 @@ class AutoScalingBackend(BaseBackend):
ebs_optimized=ebs_optimized, ebs_optimized=ebs_optimized,
associate_public_ip_address=associate_public_ip_address, associate_public_ip_address=associate_public_ip_address,
block_device_mapping_dict=block_device_mappings, block_device_mapping_dict=block_device_mappings,
region_name=self.region, region_name=self.region_name,
metadata_options=metadata_options, metadata_options=metadata_options,
classic_link_vpc_id=classic_link_vpc_id, classic_link_vpc_id=classic_link_vpc_id,
classic_link_vpc_security_groups=classic_link_vpc_security_groups, classic_link_vpc_security_groups=classic_link_vpc_security_groups,
@ -1273,4 +1268,4 @@ class AutoScalingBackend(BaseBackend):
return tags return tags
autoscaling_backends = BackendDict(AutoScalingBackend, "ec2") autoscaling_backends = BackendDict(AutoScalingBackend, "autoscaling")

View File

@ -1308,16 +1308,11 @@ class LambdaBackend(BaseBackend):
.. note:: When using the decorators, a Docker container cannot reach Moto, as it does not run as a server. Any boto3-invocations used within your Lambda will try to connect to AWS. .. note:: When using the decorators, a Docker container cannot reach Moto, as it does not run as a server. Any boto3-invocations used within your Lambda will try to connect to AWS.
""" """
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self._lambdas = LambdaStorage(region_name=region_name) self._lambdas = LambdaStorage(region_name=region_name)
self._event_source_mappings = {} self._event_source_mappings = {}
self._layers = LayerStorage() self._layers = LayerStorage()
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):

View File

@ -825,9 +825,8 @@ class BatchBackend(BaseBackend):
With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts. With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts.
""" """
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.tagger = TaggingService() self.tagger = TaggingService()
self._compute_environments = {} self._compute_environments = {}
@ -872,16 +871,13 @@ class BatchBackend(BaseBackend):
return logs_backends[self.region_name] return logs_backends[self.region_name]
def reset(self): def reset(self):
region_name = self.region_name
for job in self._jobs.values(): for job in self._jobs.values():
if job.status not in ("FAILED", "SUCCEEDED"): if job.status not in ("FAILED", "SUCCEEDED"):
job.stop = True job.stop = True
# Try to join # Try to join
job.join(0.2) job.join(0.2)
self.__dict__ = {} super().reset()
self.__init__(region_name)
def get_compute_environment_by_arn(self, arn): def get_compute_environment_by_arn(self, arn):
return self._compute_environments.get(arn) return self._compute_environments.get(arn)

View File

@ -7,12 +7,9 @@ import datetime
class BatchSimpleBackend(BaseBackend): class BatchSimpleBackend(BaseBackend):
""" """
Implements a Batch-Backend that does not use Docker containers. Submitted Jobs are simply marked as Success Implements a Batch-Backend that does not use Docker containers. Submitted Jobs are simply marked as Success
Use the `@mock_batch_simple`-decorator to use this class. Annotate your tests with `@mock_batch_simple`-decorator to use this Batch-implementation.
""" """
def __init__(self, region_name=None):
self.region_name = region_name
@property @property
def backend(self): def backend(self):
return batch_backends[self.region_name] return batch_backends[self.region_name]

View File

@ -1,5 +1,4 @@
from .models import budgets_backend from .models import budgets_backends
from ..core.models import base_decorator from ..core.models import base_decorator
budgets_backends = {"global": budgets_backend}
mock_budgets = base_decorator(budgets_backends) mock_budgets = base_decorator(budgets_backends)

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time, BackendDict
from .exceptions import BudgetMissingLimit, DuplicateRecordException, NotFoundException from .exceptions import BudgetMissingLimit, DuplicateRecordException, NotFoundException
@ -69,7 +69,8 @@ class Budget(BaseModel):
class BudgetsBackend(BaseBackend): class BudgetsBackend(BaseBackend):
"""Implementation of Budgets APIs.""" """Implementation of Budgets APIs."""
def __init__(self): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
# {"account_id": {"budget_name": Budget}} # {"account_id": {"budget_name": Budget}}
self.budgets = defaultdict(dict) self.budgets = defaultdict(dict)
@ -123,4 +124,6 @@ class BudgetsBackend(BaseBackend):
return self.budgets[account_id][budget_name].get_notifications() return self.budgets[account_id][budget_name].get_notifications()
budgets_backend = BudgetsBackend() budgets_backends = BackendDict(
BudgetsBackend, "budgets", use_boto3_regions=False, additional_regions=["global"]
)

View File

@ -1,15 +1,19 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import budgets_backend from .models import budgets_backends
class BudgetsResponse(BaseResponse): class BudgetsResponse(BaseResponse):
@property
def backend(self):
return budgets_backends["global"]
def create_budget(self): def create_budget(self):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budget = self._get_param("Budget") budget = self._get_param("Budget")
notifications = self._get_param("NotificationsWithSubscribers", []) notifications = self._get_param("NotificationsWithSubscribers", [])
budgets_backend.create_budget( self.backend.create_budget(
account_id=account_id, budget=budget, notifications=notifications account_id=account_id, budget=budget, notifications=notifications
) )
return json.dumps(dict()) return json.dumps(dict())
@ -17,20 +21,20 @@ class BudgetsResponse(BaseResponse):
def describe_budget(self): def describe_budget(self):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budget_name = self._get_param("BudgetName") budget_name = self._get_param("BudgetName")
budget = budgets_backend.describe_budget( budget = self.backend.describe_budget(
account_id=account_id, budget_name=budget_name account_id=account_id, budget_name=budget_name
) )
return json.dumps(dict(Budget=budget)) return json.dumps(dict(Budget=budget))
def describe_budgets(self): def describe_budgets(self):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budgets = budgets_backend.describe_budgets(account_id=account_id) budgets = self.backend.describe_budgets(account_id=account_id)
return json.dumps(dict(Budgets=budgets, nextToken=None)) return json.dumps(dict(Budgets=budgets, nextToken=None))
def delete_budget(self): def delete_budget(self):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budget_name = self._get_param("BudgetName") budget_name = self._get_param("BudgetName")
budgets_backend.delete_budget(account_id=account_id, budget_name=budget_name) self.backend.delete_budget(account_id=account_id, budget_name=budget_name)
return json.dumps(dict()) return json.dumps(dict())
def create_notification(self): def create_notification(self):
@ -38,7 +42,7 @@ class BudgetsResponse(BaseResponse):
budget_name = self._get_param("BudgetName") budget_name = self._get_param("BudgetName")
notification = self._get_param("Notification") notification = self._get_param("Notification")
subscribers = self._get_param("Subscribers") subscribers = self._get_param("Subscribers")
budgets_backend.create_notification( self.backend.create_notification(
account_id=account_id, account_id=account_id,
budget_name=budget_name, budget_name=budget_name,
notification=notification, notification=notification,
@ -50,7 +54,7 @@ class BudgetsResponse(BaseResponse):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budget_name = self._get_param("BudgetName") budget_name = self._get_param("BudgetName")
notification = self._get_param("Notification") notification = self._get_param("Notification")
budgets_backend.delete_notification( self.backend.delete_notification(
account_id=account_id, budget_name=budget_name, notification=notification account_id=account_id, budget_name=budget_name, notification=notification
) )
return json.dumps(dict()) return json.dumps(dict())
@ -58,7 +62,7 @@ class BudgetsResponse(BaseResponse):
def describe_notifications_for_budget(self): def describe_notifications_for_budget(self):
account_id = self._get_param("AccountId") account_id = self._get_param("AccountId")
budget_name = self._get_param("BudgetName") budget_name = self._get_param("BudgetName")
notifications = budgets_backend.describe_notifications_for_budget( notifications = self.backend.describe_notifications_for_budget(
account_id=account_id, budget_name=budget_name account_id=account_id, budget_name=budget_name
) )
return json.dumps(dict(Notifications=notifications, NextToken=None)) return json.dumps(dict(Notifications=notifications, NextToken=None))

View File

@ -534,18 +534,13 @@ class CloudFormationBackend(BaseBackend):
This means it has to run inside a Docker-container, or be started using `moto_server -h 0.0.0.0`. This means it has to run inside a Docker-container, or be started using `moto_server -h 0.0.0.0`.
""" """
def __init__(self, region=None): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.stacks = OrderedDict() self.stacks = OrderedDict()
self.stacksets = OrderedDict() self.stacksets = OrderedDict()
self.deleted_stacks = {} self.deleted_stacks = {}
self.exports = OrderedDict() self.exports = OrderedDict()
self.change_sets = OrderedDict() self.change_sets = OrderedDict()
self.region = region
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -676,13 +671,13 @@ class CloudFormationBackend(BaseBackend):
tags=None, tags=None,
role_arn=None, role_arn=None,
): ):
stack_id = generate_stack_id(name, self.region) stack_id = generate_stack_id(name, self.region_name)
new_stack = FakeStack( new_stack = FakeStack(
stack_id=stack_id, stack_id=stack_id,
name=name, name=name,
template=template, template=template,
parameters=parameters, parameters=parameters,
region_name=self.region, region_name=self.region_name,
notification_arns=notification_arns, notification_arns=notification_arns,
tags=tags, tags=tags,
role_arn=role_arn, role_arn=role_arn,
@ -717,13 +712,13 @@ class CloudFormationBackend(BaseBackend):
else: else:
raise ValidationError(stack_name) raise ValidationError(stack_name)
else: else:
stack_id = generate_stack_id(stack_name, self.region) stack_id = generate_stack_id(stack_name, self.region_name)
stack = FakeStack( stack = FakeStack(
stack_id=stack_id, stack_id=stack_id,
name=stack_name, name=stack_name,
template={}, template={},
parameters=parameters, parameters=parameters,
region_name=self.region, region_name=self.region_name,
notification_arns=notification_arns, notification_arns=notification_arns,
tags=tags, tags=tags,
role_arn=role_arn, role_arn=role_arn,
@ -734,7 +729,7 @@ class CloudFormationBackend(BaseBackend):
"REVIEW_IN_PROGRESS", resource_status_reason="User Initiated" "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated"
) )
change_set_id = generate_changeset_id(change_set_name, self.region) change_set_id = generate_changeset_id(change_set_name, self.region_name)
new_change_set = FakeChangeSet( new_change_set = FakeChangeSet(
change_set_type=change_set_type, change_set_type=change_set_type,

View File

@ -47,7 +47,7 @@ from moto.ssm import models # noqa # pylint: disable=all
# End ugly list of imports # End ugly list of imports
from moto.core import get_account_id, CloudFormationModel from moto.core import get_account_id, CloudFormationModel
from moto.s3 import s3_backend from moto.s3.models import s3_backend
from moto.s3.utils import bucket_and_name_from_url from moto.s3.utils import bucket_and_name_from_url
from moto.ssm import ssm_backends from moto.ssm import ssm_backends
from .utils import random_suffix from .utils import random_suffix

View File

@ -6,7 +6,7 @@ from yaml.scanner import ScannerError # pylint:disable=c-extension-no-member
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from moto.s3 import s3_backend from moto.s3.models import s3_backend
from moto.s3.exceptions import S3ClientError from moto.s3.exceptions import S3ClientError
from moto.core import get_account_id from moto.core import get_account_id
from .models import cloudformation_backends from .models import cloudformation_backends

View File

@ -1,5 +1,4 @@
from .models import cloudfront_backend from .models import cloudfront_backends
from ..core.models import base_decorator from ..core.models import base_decorator
cloudfront_backends = {"global": cloudfront_backend}
mock_cloudfront = base_decorator(cloudfront_backends) mock_cloudfront = base_decorator(cloudfront_backends)

View File

@ -2,6 +2,7 @@ import random
import string import string
from moto.core import get_account_id, BaseBackend, BaseModel from moto.core import get_account_id, BaseBackend, BaseModel
from moto.core.utils import BackendDict
from moto.moto_api import state_manager from moto.moto_api import state_manager
from moto.moto_api._internal.managed_state_model import ManagedState from moto.moto_api._internal.managed_state_model import ManagedState
from uuid import uuid4 from uuid import uuid4
@ -171,7 +172,8 @@ class Distribution(BaseModel, ManagedState):
class CloudFrontBackend(BaseBackend): class CloudFrontBackend(BaseBackend):
def __init__(self): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.distributions = dict() self.distributions = dict()
state_manager.register_default_transition( state_manager.register_default_transition(
@ -247,4 +249,10 @@ class CloudFrontBackend(BaseBackend):
return dist, dist.location, dist.etag return dist, dist.location, dist.etag
cloudfront_backend = CloudFrontBackend() cloudfront_backends = BackendDict(
CloudFrontBackend,
"cloudfront",
use_boto3_regions=False,
additional_regions=["global"],
)
cloudfront_backend = cloudfront_backends["global"]

View File

@ -130,7 +130,7 @@ class Trail(BaseModel):
raise TrailNameInvalidChars() raise TrailNameInvalidChars()
def check_bucket_exists(self): def check_bucket_exists(self):
from moto.s3 import s3_backend from moto.s3.models import s3_backend
try: try:
s3_backend.get_bucket(self.bucket_name) s3_backend.get_bucket(self.bucket_name)
@ -242,8 +242,8 @@ class Trail(BaseModel):
class CloudTrailBackend(BaseBackend): class CloudTrailBackend(BaseBackend):
"""Implementation of CloudTrail APIs.""" """Implementation of CloudTrail APIs."""
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.trails = dict() self.trails = dict()
self.tagging_service = TaggingService(tag_name="TagsList") self.tagging_service = TaggingService(tag_name="TagsList")
@ -313,7 +313,8 @@ class CloudTrailBackend(BaseBackend):
def describe_trails(self, include_shadow_trails): def describe_trails(self, include_shadow_trails):
all_trails = [] all_trails = []
if include_shadow_trails: if include_shadow_trails:
for backend in cloudtrail_backends.values(): current_account = cloudtrail_backends[self.account_id]
for backend in current_account.values():
all_trails.extend(backend.trails.values()) all_trails.extend(backend.trails.values())
else: else:
all_trails.extend(self.trails.values()) all_trails.extend(self.trails.values())
@ -363,12 +364,6 @@ class CloudTrailBackend(BaseBackend):
) )
return trail return trail
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def put_event_selectors( def put_event_selectors(
self, trail_name, event_selectors, advanced_event_selectors self, trail_name, event_selectors, advanced_event_selectors
): ):

View File

@ -305,19 +305,14 @@ class Statistics:
class CloudWatchBackend(BaseBackend): class CloudWatchBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.alarms = {} self.alarms = {}
self.dashboards = {} self.dashboards = {}
self.metric_data = [] self.metric_data = []
self.paged_metric_data = {} self.paged_metric_data = {}
self.tagger = TaggingService() self.tagger = TaggingService()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -32,14 +32,9 @@ class CodeCommit(BaseModel):
class CodeCommitBackend(BaseBackend): class CodeCommitBackend(BaseBackend):
def __init__(self, region=None): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.repositories = {} self.repositories = {}
self.region = region
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -54,7 +49,7 @@ class CodeCommitBackend(BaseBackend):
raise RepositoryNameExistsException(repository_name) raise RepositoryNameExistsException(repository_name)
self.repositories[repository_name] = CodeCommit( self.repositories[repository_name] = CodeCommit(
self.region, repository_description, repository_name self.region_name, repository_description, repository_name
) )
return self.repositories[repository_name].repository_metadata return self.repositories[repository_name].repository_metadata

View File

@ -67,14 +67,9 @@ class CodePipeline(BaseModel):
class CodePipelineBackend(BaseBackend): class CodePipelineBackend(BaseBackend):
def __init__(self, region=None): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.pipelines = {} self.pipelines = {}
self.region = region
def reset(self):
region_name = self.region
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -114,7 +109,7 @@ class CodePipelineBackend(BaseBackend):
"Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline" "Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline"
) )
self.pipelines[pipeline["name"]] = CodePipeline(self.region, pipeline) self.pipelines[pipeline["name"]] = CodePipeline(self.region_name, pipeline)
if tags: if tags:
self.pipelines[pipeline["name"]].validate_tags(tags) self.pipelines[pipeline["name"]].validate_tags(tags)

View File

@ -48,17 +48,11 @@ class CognitoIdentity(BaseModel):
class CognitoIdentityBackend(BaseBackend): class CognitoIdentityBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region = region
self.identity_pools = OrderedDict() self.identity_pools = OrderedDict()
self.pools_identities = {} self.pools_identities = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
def describe_identity_pool(self, identity_pool_id): def describe_identity_pool(self, identity_pool_id):
identity_pool = self.identity_pools.get(identity_pool_id, None) identity_pool = self.identity_pools.get(identity_pool_id, None)
@ -93,7 +87,7 @@ class CognitoIdentityBackend(BaseBackend):
tags=None, tags=None,
): ):
new_identity = CognitoIdentity( new_identity = CognitoIdentity(
self.region, self.region_name,
identity_pool_name, identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers, supported_login_providers=supported_login_providers,
@ -151,7 +145,7 @@ class CognitoIdentityBackend(BaseBackend):
return response return response
def get_id(self, identity_pool_id: str): def get_id(self, identity_pool_id: str):
identity_id = {"IdentityId": get_random_identity_id(self.region)} identity_id = {"IdentityId": get_random_identity_id(self.region_name)}
self.pools_identities[identity_pool_id]["Identities"].append(identity_id) self.pools_identities[identity_pool_id]["Identities"].append(identity_id)
return json.dumps(identity_id) return json.dumps(identity_id)
@ -175,13 +169,19 @@ class CognitoIdentityBackend(BaseBackend):
def get_open_id_token_for_developer_identity(self, identity_id): def get_open_id_token_for_developer_identity(self, identity_id):
response = json.dumps( response = json.dumps(
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)} {
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region_name),
}
) )
return response return response
def get_open_id_token(self, identity_id): def get_open_id_token(self, identity_id):
response = json.dumps( response = json.dumps(
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)} {
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region_name),
}
) )
return response return response

View File

@ -821,21 +821,15 @@ class CognitoResourceServer(BaseModel):
class CognitoIdpBackend(BaseBackend): class CognitoIdpBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region = region
self.user_pools = OrderedDict() self.user_pools = OrderedDict()
self.user_pool_domains = OrderedDict() self.user_pool_domains = OrderedDict()
self.sessions = {} self.sessions = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
# User pool # User pool
def create_user_pool(self, name, extended_config): def create_user_pool(self, name, extended_config):
user_pool = CognitoIdpUserPool(self.region, name, extended_config) user_pool = CognitoIdpUserPool(self.region_name, name, extended_config)
self.user_pools[user_pool.id] = user_pool self.user_pools[user_pool.id] = user_pool
return user_pool return user_pool
@ -1797,28 +1791,30 @@ class CognitoIdpBackend(BaseBackend):
raise NotAuthorizedError(access_token) raise NotAuthorizedError(access_token)
class GlobalCognitoIdpBackend(CognitoIdpBackend): class RegionAgnosticBackend:
# Some operations are unauthenticated # Some operations are unauthenticated
# Without authentication-header, we lose the context of which region the request was send to # Without authentication-header, we lose the context of which region the request was send to
# This backend will cycle through all backends as a workaround # This backend will cycle through all backends as a workaround
def _find_backend_by_access_token(self, access_token): def _find_backend_by_access_token(self, access_token):
for region, backend in cognitoidp_backends.items(): account_specific_backends = cognitoidp_backends[get_account_id()]
for region, backend in account_specific_backends.items():
if region == "global": if region == "global":
continue continue
for p in backend.user_pools.values(): for p in backend.user_pools.values():
if access_token in p.access_tokens: if access_token in p.access_tokens:
return backend return backend
return cognitoidp_backends["us-east-1"] return account_specific_backends["us-east-1"]
def _find_backend_for_clientid(self, client_id): def _find_backend_for_clientid(self, client_id):
for region, backend in cognitoidp_backends.items(): account_specific_backends = cognitoidp_backends[get_account_id()]
for region, backend in account_specific_backends.items():
if region == "global": if region == "global":
continue continue
for p in backend.user_pools.values(): for p in backend.user_pools.values():
if client_id in p.clients: if client_id in p.clients:
return backend return backend
return cognitoidp_backends["us-east-1"] return account_specific_backends["us-east-1"]
def sign_up(self, client_id, username, password, attributes): def sign_up(self, client_id, username, password, attributes):
backend = self._find_backend_for_clientid(client_id) backend = self._find_backend_for_clientid(client_id)
@ -1846,14 +1842,14 @@ class GlobalCognitoIdpBackend(CognitoIdpBackend):
cognitoidp_backends = BackendDict(CognitoIdpBackend, "cognito-idp") cognitoidp_backends = BackendDict(CognitoIdpBackend, "cognito-idp")
cognitoidp_backends["global"] = GlobalCognitoIdpBackend("global")
# Hack to help moto-server process requests on localhost, where the region isn't # Hack to help moto-server process requests on localhost, where the region isn't
# specified in the host header. Some endpoints (change password, confirm forgot # specified in the host header. Some endpoints (change password, confirm forgot
# password) have no authorization header from which to extract the region. # password) have no authorization header from which to extract the region.
def find_region_by_value(key, value): def find_region_by_value(key, value):
for region in cognitoidp_backends: account_specific_backends = cognitoidp_backends[get_account_id()]
for region in account_specific_backends:
backend = cognitoidp_backends[region] backend = cognitoidp_backends[region]
for user_pool in backend.user_pools.values(): for user_pool in backend.user_pools.values():
if key == "client_id" and value in user_pool.clients: if key == "client_id" and value in user_pool.clients:
@ -1864,4 +1860,4 @@ def find_region_by_value(key, value):
# If we can't find the `client_id` or `access_token`, we just pass # If we can't find the `client_id` or `access_token`, we just pass
# back a default backend region, which will raise the appropriate # back a default backend region, which will raise the appropriate
# error message (e.g. NotAuthorized or NotFound). # error message (e.g. NotAuthorized or NotFound).
return list(cognitoidp_backends)[0] return list(account_specific_backends)[0]

View File

@ -3,10 +3,18 @@ import os
import re import re
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cognitoidp_backends, find_region_by_value, UserStatus from .models import (
cognitoidp_backends,
find_region_by_value,
RegionAgnosticBackend,
UserStatus,
)
from .exceptions import InvalidParameterException from .exceptions import InvalidParameterException
region_agnostic_backend = RegionAgnosticBackend()
class CognitoIdpResponse(BaseResponse): class CognitoIdpResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
@ -346,7 +354,7 @@ class CognitoIdpResponse(BaseResponse):
def get_user(self): def get_user(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
user = cognitoidp_backends["global"].get_user(access_token=access_token) user = region_agnostic_backend.get_user(access_token=access_token)
return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes")) return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
def list_users(self): def list_users(self):
@ -444,7 +452,7 @@ class CognitoIdpResponse(BaseResponse):
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
challenge_name = self._get_param("ChallengeName") challenge_name = self._get_param("ChallengeName")
challenge_responses = self._get_param("ChallengeResponses") challenge_responses = self._get_param("ChallengeResponses")
auth_result = cognitoidp_backends["global"].respond_to_auth_challenge( auth_result = region_agnostic_backend.respond_to_auth_challenge(
session, client_id, challenge_name, challenge_responses session, client_id, challenge_name, challenge_responses
) )
@ -454,6 +462,7 @@ class CognitoIdpResponse(BaseResponse):
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
username = self._get_param("Username") username = self._get_param("Username")
region = find_region_by_value("client_id", client_id) region = find_region_by_value("client_id", client_id)
print(f"Region: {region}")
confirmation_code, response = cognitoidp_backends[region].forgot_password( confirmation_code, response = cognitoidp_backends[region].forgot_password(
client_id, username client_id, username
) )
@ -534,7 +543,7 @@ class CognitoIdpResponse(BaseResponse):
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
username = self._get_param("Username") username = self._get_param("Username")
password = self._get_param("Password") password = self._get_param("Password")
user = cognitoidp_backends["global"].sign_up( user = region_agnostic_backend.sign_up(
client_id=client_id, client_id=client_id,
username=username, username=username,
password=password, password=password,
@ -550,9 +559,7 @@ class CognitoIdpResponse(BaseResponse):
def confirm_sign_up(self): def confirm_sign_up(self):
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends["global"].confirm_sign_up( region_agnostic_backend.confirm_sign_up(client_id=client_id, username=username)
client_id=client_id, username=username
)
return "" return ""
def initiate_auth(self): def initiate_auth(self):
@ -560,7 +567,7 @@ class CognitoIdpResponse(BaseResponse):
auth_flow = self._get_param("AuthFlow") auth_flow = self._get_param("AuthFlow")
auth_parameters = self._get_param("AuthParameters") auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends["global"].initiate_auth( auth_result = region_agnostic_backend.initiate_auth(
client_id, auth_flow, auth_parameters client_id, auth_flow, auth_parameters
) )

View File

@ -848,7 +848,8 @@ class ConfigRule(ConfigEmptyDictable):
class ConfigBackend(BaseBackend): class ConfigBackend(BaseBackend):
def __init__(self, region=None): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.recorders = {} self.recorders = {}
self.delivery_channels = {} self.delivery_channels = {}
self.config_aggregators = {} self.config_aggregators = {}
@ -856,12 +857,6 @@ class ConfigBackend(BaseBackend):
self.organization_conformance_packs = {} self.organization_conformance_packs = {}
self.config_rules = {} self.config_rules = {}
self.config_schema = None self.config_schema = None
self.region = region
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -974,7 +969,7 @@ class ConfigBackend(BaseBackend):
): ):
aggregator = ConfigAggregator( aggregator = ConfigAggregator(
config_aggregator["ConfigurationAggregatorName"], config_aggregator["ConfigurationAggregatorName"],
self.region, self.region_name,
account_sources=account_sources, account_sources=account_sources,
org_source=org_source, org_source=org_source,
tags=tags, tags=tags,
@ -1054,7 +1049,7 @@ class ConfigBackend(BaseBackend):
agg_auth = self.aggregation_authorizations.get(key) agg_auth = self.aggregation_authorizations.get(key)
if not agg_auth: if not agg_auth:
agg_auth = ConfigAggregationAuthorization( agg_auth = ConfigAggregationAuthorization(
self.region, authorized_account, authorized_region, tags=tags self.region_name, authorized_account, authorized_region, tags=tags
) )
self.aggregation_authorizations[ self.aggregation_authorizations[
"{}/{}".format(authorized_account, authorized_region) "{}/{}".format(authorized_account, authorized_region)
@ -1473,11 +1468,14 @@ class ConfigBackend(BaseBackend):
backend_query_region = ( backend_query_region = (
backend_region # Always provide the backend this request arrived from. backend_region # Always provide the backend this request arrived from.
) )
print(RESOURCE_MAP[resource_type].backends)
if RESOURCE_MAP[resource_type].backends.get("global"): if RESOURCE_MAP[resource_type].backends.get("global"):
print("yes, its global")
backend_region = "global" backend_region = "global"
# If the backend region isn't implemented then we won't find the item: # If the backend region isn't implemented then we won't find the item:
if not RESOURCE_MAP[resource_type].backends.get(backend_region): if not RESOURCE_MAP[resource_type].backends.get(backend_region):
print(f"cant find {backend_region} for {resource_type}")
raise ResourceNotDiscoveredException(resource_type, resource_id) raise ResourceNotDiscoveredException(resource_type, resource_id)
# Get the item: # Get the item:
@ -1485,6 +1483,7 @@ class ConfigBackend(BaseBackend):
resource_id, backend_region=backend_query_region resource_id, backend_region=backend_query_region
) )
if not item: if not item:
print("item not found")
raise ResourceNotDiscoveredException(resource_type, resource_id) raise ResourceNotDiscoveredException(resource_type, resource_id)
item["accountId"] = get_account_id() item["accountId"] = get_account_id()
@ -1655,7 +1654,7 @@ class ConfigBackend(BaseBackend):
) )
else: else:
pack = OrganizationConformancePack( pack = OrganizationConformancePack(
region=self.region, region=self.region_name,
name=name, name=name,
delivery_s3_bucket=delivery_s3_bucket, delivery_s3_bucket=delivery_s3_bucket,
delivery_s3_key_prefix=delivery_s3_key_prefix, delivery_s3_key_prefix=delivery_s3_key_prefix,
@ -1875,14 +1874,14 @@ class ConfigBackend(BaseBackend):
) )
# Update the current rule. # Update the current rule.
rule.modify_fields(self.region, config_rule, tags) rule.modify_fields(self.region_name, config_rule, tags)
else: else:
# Create a new ConfigRule if the limit hasn't been reached. # Create a new ConfigRule if the limit hasn't been reached.
if len(self.config_rules) == ConfigRule.MAX_RULES: if len(self.config_rules) == ConfigRule.MAX_RULES:
raise MaxNumberOfConfigRulesExceededException( raise MaxNumberOfConfigRulesExceededException(
rule_name, ConfigRule.MAX_RULES rule_name, ConfigRule.MAX_RULES
) )
rule = ConfigRule(self.region, config_rule, tags) rule = ConfigRule(self.region_name, config_rule, tags)
self.config_rules[rule_name] = rule self.config_rules[rule_name] = rule
return "" return ""

View File

@ -22,6 +22,10 @@ class InstanceTrackerMeta(type):
class BaseBackend: class BaseBackend:
def __init__(self, region_name, account_id=None):
self.region_name = region_name
self.account_id = account_id
def _reset_model_refs(self): def _reset_model_refs(self):
# Remove all references to the models stored # Remove all references to the models stored
for models in model_data.values(): for models in model_data.values():
@ -29,9 +33,11 @@ class BaseBackend:
model.instances = [] model.instances = []
def reset(self): def reset(self):
region_name = self.region_name
account_id = self.account_id
self._reset_model_refs() self._reset_model_refs()
self.__dict__ = {} self.__dict__ = {}
self.__init__() self.__init__(region_name, account_id)
@property @property
def _url_module(self): def _url_module(self):

View File

@ -53,13 +53,11 @@ class BaseMockAWS:
"moto_api": moto_api_backend, "moto_api": moto_api_backend,
} }
if "us-east-1" in self.backends: if "us-east-1" in self.backends:
# We only need to know the URL for a single region # We only need to know the URL for a single region - they will be the same everywhere
# They will be the same everywhere
self.backends_for_urls["us-east-1"] = self.backends["us-east-1"] self.backends_for_urls["us-east-1"] = self.backends["us-east-1"]
else: elif "global" in self.backends:
# If us-east-1 is not available, it's probably a global service # If us-east-1 is not available, it's probably a global service
# Global services will only have a single region anyway self.backends_for_urls["global"] = self.backends["global"]
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends) self.backends_for_urls.update(default_backends)
self.FAKE_KEYS = { self.FAKE_KEYS = {

View File

@ -9,6 +9,7 @@ import string
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from boto3 import Session from boto3 import Session
from moto.settings import allow_unknown_region from moto.settings import allow_unknown_region
from threading import RLock
from urllib.parse import urlparse from urllib.parse import urlparse
@ -408,28 +409,126 @@ def extract_region_from_aws_authorization(string):
return region return region
backend_lock = RLock()
class BackendDict(dict): class BackendDict(dict):
def __init__(self, fn, service_name): """
self.fn = fn Data Structure to store everything related to a specific service.
Format:
[account_id: str]: AccountSpecificBackend
[account_id: str][region: str] = BaseBackend
Full multi-account support is not yet available. We will always return account_id 123456789012, regardless of the input.
To not break existing usage patterns, the following data access pattern is also supported:
[region: str] = BaseBackend
This will automatically resolve to:
[default_account_id][region: str] = BaseBackend
"""
def __init__(
self, backend, service_name, use_boto3_regions=True, additional_regions=None
):
self.backend = backend
self.service_name = service_name self.service_name = service_name
sess = Session() self._use_boto3_regions = use_boto3_regions
self.regions = list(sess.get_available_regions(service_name)) self._additional_regions = additional_regions
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
def __contains__(self, item): def __contains__(self, account_id_or_region):
return item in self.regions or item in self.keys() """
Possible data access patterns:
backend_dict[account_id][region_name]
backend_dict[region_name]
backend_dict[unknown_region]
def __getitem__(self, item): The latter two will be phased out in the future, and we can remove this method.
if item in self.keys(): """
return super().__getitem__(item) if re.match(r"[0-9]+", account_id_or_region):
self._create_account_specific_backend("123456789012")
return True
else:
region = account_id_or_region
self._create_account_specific_backend("123456789012")
return region in self["123456789012"]
def get(self, account_id_or_region, if_none=None):
if self.__contains__(account_id_or_region):
return self.__getitem__(account_id_or_region)
return if_none
def __getitem__(self, account_id_or_region):
"""
Possible data access patterns:
backend_dict[account_id][region_name]
backend_dict[region_name]
backend_dict[unknown_region]
The latter two will be phased out in the future.
"""
if re.match(r"[0-9]+", account_id_or_region):
self._create_account_specific_backend("123456789012")
return super().__getitem__("123456789012")
else:
region_name = account_id_or_region
return self["123456789012"][region_name]
def _create_account_specific_backend(self, account_id):
with backend_lock:
if account_id not in self.keys():
self[account_id] = AccountSpecificBackend(
service_name=self.service_name,
account_id=account_id,
backend=self.backend,
use_boto3_regions=self._use_boto3_regions,
additional_regions=self._additional_regions,
)
class AccountSpecificBackend(dict):
"""
Dictionary storing the data for a service in a specific account.
Data access pattern:
account_specific_backend[region: str] = backend: BaseBackend
"""
def __init__(
self, service_name, account_id, backend, use_boto3_regions, additional_regions
):
self.service_name = service_name
self.account_id = account_id
self.backend = backend
self.regions = []
if use_boto3_regions:
sess = Session()
self.regions.extend(sess.get_available_regions(service_name))
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
self.regions.extend(additional_regions or [])
def reset(self):
for region_specific_backend in self.values():
region_specific_backend.reset()
def __contains__(self, region):
return region in self.regions or region in self.keys()
def __getitem__(self, region_name):
if region_name in self.keys():
return super().__getitem__(region_name)
# Create the backend for a specific region # Create the backend for a specific region
if item in self.regions and item not in self.keys(): with backend_lock:
super().__setitem__(item, self.fn(item)) if region_name in self.regions and region_name not in self.keys():
if item not in self.regions and allow_unknown_region(): super().__setitem__(
super().__setitem__(item, self.fn(item)) region_name, self.backend(region_name, account_id=self.account_id)
return super().__getitem__(item) )
if region_name not in self.regions and allow_unknown_region():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
return super().__getitem__(region_name)

View File

@ -45,17 +45,12 @@ class DataBrewBackend(BaseBackend):
}, },
} }
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.recipes = OrderedDict() self.recipes = OrderedDict()
self.rulesets = OrderedDict() self.rulesets = OrderedDict()
self.datasets = OrderedDict() self.datasets = OrderedDict()
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__init__(region_name)
@staticmethod @staticmethod
def validate_length(param, param_name, max_length): def validate_length(param, param_name, max_length):
if len(param) > max_length: if len(param) > max_length:

View File

@ -102,9 +102,9 @@ class Pipeline(CloudFormationModel):
class DataPipelineBackend(BaseBackend): class DataPipelineBackend(BaseBackend):
def __init__(self, region=None): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.pipelines = OrderedDict() self.pipelines = OrderedDict()
self.region = region
def create_pipeline(self, name, unique_id, **kwargs): def create_pipeline(self, name, unique_id, **kwargs):
pipeline = Pipeline(name, unique_id, **kwargs) pipeline = Pipeline(name, unique_id, **kwargs)

View File

@ -96,8 +96,8 @@ class TaskExecution(BaseModel):
class DataSyncBackend(BaseBackend): class DataSyncBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
# Always increase when new things are created # Always increase when new things are created
# This ensures uniqueness # This ensures uniqueness
self.arn_counter = 0 self.arn_counter = 0
@ -105,12 +105,6 @@ class DataSyncBackend(BaseBackend):
self.tasks = OrderedDict() self.tasks = OrderedDict()
self.task_executions = OrderedDict() self.task_executions = OrderedDict()
def reset(self):
region_name = self.region_name
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -153,8 +153,8 @@ class DaxCluster(BaseModel, ManagedState):
class DAXBackend(BaseBackend): class DAXBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self._clusters = dict() self._clusters = dict()
self._tagger = TaggingService() self._tagger = TaggingService()
@ -171,11 +171,6 @@ class DAXBackend(BaseBackend):
} }
return self._clusters return self._clusters
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_cluster( def create_cluster(
self, self,
cluster_name, cluster_name,

View File

@ -13,16 +13,10 @@ from .utils import filter_tasks
class DatabaseMigrationServiceBackend(BaseBackend): class DatabaseMigrationServiceBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.replication_tasks = {} self.replication_tasks = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -180,17 +180,11 @@ class Directory(BaseModel): # pylint: disable=too-many-instance-attributes
class DirectoryServiceBackend(BaseBackend): class DirectoryServiceBackend(BaseBackend):
"""Implementation of DirectoryService APIs.""" """Implementation of DirectoryService APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.directories = {} self.directories = {}
self.tagger = TaggingService() self.tagger = TaggingService()
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""List of dicts representing default VPC endpoints for this service.""" """List of dicts representing default VPC endpoints for this service."""
@ -508,4 +502,4 @@ class DirectoryServiceBackend(BaseBackend):
return self.tagger.list_tags_for_resource(resource_id).get("Tags") return self.tagger.list_tags_for_resource(resource_id).get("Tags")
ds_backends = BackendDict(fn=DirectoryServiceBackend, service_name="ds") ds_backends = BackendDict(DirectoryServiceBackend, service_name="ds")

View File

@ -1176,16 +1176,11 @@ class Backup(object):
class DynamoDBBackend(BaseBackend): class DynamoDBBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.tables = OrderedDict() self.tables = OrderedDict()
self.backups = OrderedDict() self.backups = OrderedDict()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -1,4 +1,4 @@
from .models import dynamodb_backend from .models import dynamodb_backends
from ..core.models import base_decorator from ..core.models import base_decorator
""" """
@ -6,5 +6,4 @@ An older API version of DynamoDB.
Please see the corresponding tests (tests/test_dynamodb_v20111205) on how to invoke this API. Please see the corresponding tests (tests/test_dynamodb_v20111205) on how to invoke this API.
""" """
dynamodb_backends = {"global": dynamodb_backend} mock_dynamodb = base_decorator(dynamodb_backends)
mock_dynamodb = base_decorator(dynamodb_backend)

View File

@ -4,7 +4,7 @@ import json
from collections import OrderedDict from collections import OrderedDict
from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import unix_time from moto.core.utils import unix_time, BackendDict
from moto.core import get_account_id from moto.core import get_account_id
from .comparisons import get_comparison_func from .comparisons import get_comparison_func
@ -313,7 +313,8 @@ class Table(CloudFormationModel):
class DynamoDBBackend(BaseBackend): class DynamoDBBackend(BaseBackend):
def __init__(self): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.tables = OrderedDict() self.tables = OrderedDict()
def create_table(self, name, **params): def create_table(self, name, **params):
@ -390,4 +391,10 @@ class DynamoDBBackend(BaseBackend):
return table.update_item(hash_key, range_key, attr_updates) return table.update_item(hash_key, range_key, attr_updates)
dynamodb_backend = DynamoDBBackend() dynamodb_backends = BackendDict(
DynamoDBBackend,
"dynamodb_v20111205",
use_boto3_regions=False,
additional_regions=["global"],
)
dynamodb_backend = dynamodb_backends["global"]

View File

@ -64,18 +64,13 @@ class ShardIterator(BaseModel):
class DynamoDBStreamsBackend(BaseBackend): class DynamoDBStreamsBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
self.region = region super().__init__(region_name, account_id)
self.shard_iterators = {} self.shard_iterators = {}
def reset(self):
region = self.region
self.__dict__ = {}
self.__init__(region)
@property @property
def dynamodb(self): def dynamodb(self):
return dynamodb_backends[self.region] return dynamodb_backends[self.region_name]
def _get_table_from_arn(self, arn): def _get_table_from_arn(self, arn):
table_name = arn.split(":", 6)[5].split("/")[1] table_name = arn.split(":", 6)[5].split("/")[1]

View File

@ -52,20 +52,14 @@ class EBSSnapshot(BaseModel):
class EBSBackend(BaseBackend): class EBSBackend(BaseBackend):
"""Implementation of EBS APIs.""" """Implementation of EBS APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.snapshots = dict() self.snapshots = dict()
@property @property
def ec2_backend(self): def ec2_backend(self):
return ec2_backends[self.region_name] return ec2_backends[self.region_name]
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def start_snapshot(self, volume_size, tags, description): def start_snapshot(self, volume_size, tags, description):
zone_name = f"{self.region_name}a" zone_name = f"{self.region_name}a"
vol = self.ec2_backend.create_volume(size=volume_size, zone_name=zone_name) vol = self.ec2_backend.create_volume(size=volume_size, zone_name=zone_name)

View File

@ -66,10 +66,9 @@ def validate_resource_ids(resource_ids):
return True return True
class SettingsBackend(object): class SettingsBackend:
def __init__(self): def __init__(self):
self.ebs_encryption_by_default = False self.ebs_encryption_by_default = False
super().__init__()
def disable_ebs_encryption_by_default(self): def disable_ebs_encryption_by_default(self):
ec2_backend = ec2_backends[self.region_name] ec2_backend = ec2_backends[self.region_name]
@ -140,9 +139,11 @@ class EC2Backend(
""" """
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name BaseBackend.__init__(self, region_name, account_id)
super().__init__() for backend in EC2Backend.__mro__:
if backend not in [EC2Backend, BaseBackend, object]:
backend.__init__(self)
# Default VPC exists by default, which is the current behavior # Default VPC exists by default, which is the current behavior
# of EC2-VPC. See for detail: # of EC2-VPC. See for detail:
@ -169,11 +170,6 @@ class EC2Backend(
self.create_subnet(vpc.id, cidr_block, availability_zone=az_name) self.create_subnet(vpc.id, cidr_block, availability_zone=az_name)
ip[2] += 16 ip[2] += 16
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -142,14 +142,13 @@ class Ami(TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeImages") return super().get_filter_value(filter_name, "DescribeImages")
class AmiBackend(object): class AmiBackend:
AMI_REGEX = re.compile("ami-[a-z0-9]+") AMI_REGEX = re.compile("ami-[a-z0-9]+")
def __init__(self): def __init__(self):
self.amis = {} self.amis = {}
self.deleted_amis = list() self.deleted_amis = list()
self._load_amis() self._load_amis()
super().__init__()
def _load_amis(self): def _load_amis(self):
for ami in AMIS: for ami in AMIS:

View File

@ -15,7 +15,7 @@ class Zone(object):
self.zone_id = zone_id self.zone_id = zone_id
class RegionsAndZonesBackend(object): class RegionsAndZonesBackend:
regions_opt_in_not_required = [ regions_opt_in_not_required = [
"af-south-1", "af-south-1",
"ap-northeast-1", "ap-northeast-1",

View File

@ -23,10 +23,9 @@ class CarrierGateway(TaggedEC2Resource):
return get_account_id() return get_account_id()
class CarrierGatewayBackend(object): class CarrierGatewayBackend:
def __init__(self): def __init__(self):
self.carrier_gateways = {} self.carrier_gateways = {}
super().__init__()
def create_carrier_gateway(self, vpc_id, tags=None): def create_carrier_gateway(self, vpc_id, tags=None):
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)

View File

@ -28,10 +28,9 @@ class CustomerGateway(TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeCustomerGateways") return super().get_filter_value(filter_name, "DescribeCustomerGateways")
class CustomerGatewayBackend(object): class CustomerGatewayBackend:
def __init__(self): def __init__(self):
self.customer_gateways = {} self.customer_gateways = {}
super().__init__()
def create_customer_gateway( def create_customer_gateway(
self, gateway_type="ipsec.1", ip_address=None, bgp_asn=None, tags=None self, gateway_type="ipsec.1", ip_address=None, bgp_asn=None, tags=None

View File

@ -58,10 +58,9 @@ class DHCPOptionsSet(TaggedEC2Resource):
return self._options return self._options
class DHCPOptionsSetBackend(object): class DHCPOptionsSetBackend:
def __init__(self): def __init__(self):
self.dhcp_options_sets = {} self.dhcp_options_sets = {}
super().__init__()
def associate_dhcp_options(self, dhcp_options, vpc): def associate_dhcp_options(self, dhcp_options, vpc):
dhcp_options.vpc = vpc dhcp_options.vpc = vpc

View File

@ -187,12 +187,11 @@ class Snapshot(TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeSnapshots") return super().get_filter_value(filter_name, "DescribeSnapshots")
class EBSBackend(object): class EBSBackend:
def __init__(self): def __init__(self):
self.volumes = {} self.volumes = {}
self.attachments = {} self.attachments = {}
self.snapshots = {} self.snapshots = {}
super().__init__()
def create_volume( def create_volume(
self, self,

View File

@ -107,10 +107,9 @@ class ElasticAddress(TaggedEC2Resource, CloudFormationModel):
return super().get_filter_value(filter_name, "DescribeAddresses") return super().get_filter_value(filter_name, "DescribeAddresses")
class ElasticAddressBackend(object): class ElasticAddressBackend:
def __init__(self): def __init__(self):
self.addresses = [] self.addresses = []
super().__init__()
def allocate_address(self, domain, address=None, tags=None): def allocate_address(self, domain, address=None, tags=None):
if domain not in ["standard", "vpc"]: if domain not in ["standard", "vpc"]:

View File

@ -249,10 +249,9 @@ class NetworkInterface(TaggedEC2Resource, CloudFormationModel):
return super().get_filter_value(filter_name, "DescribeNetworkInterfaces") return super().get_filter_value(filter_name, "DescribeNetworkInterfaces")
class NetworkInterfaceBackend(object): class NetworkInterfaceBackend:
def __init__(self): def __init__(self):
self.enis = {} self.enis = {}
super().__init__()
def create_network_interface( def create_network_interface(
self, self,

View File

@ -128,10 +128,9 @@ class FlowLogs(TaggedEC2Resource, CloudFormationModel):
return super().get_filter_value(filter_name, "DescribeFlowLogs") return super().get_filter_value(filter_name, "DescribeFlowLogs")
class FlowLogsBackend(object): class FlowLogsBackend:
def __init__(self): def __init__(self):
self.flow_logs = defaultdict(dict) self.flow_logs = defaultdict(dict)
super().__init__()
def _validate_request( def _validate_request(
self, self,

View File

@ -22,10 +22,9 @@ class IamInstanceProfileAssociation(CloudFormationModel):
self.state = "associated" self.state = "associated"
class IamInstanceProfileAssociationBackend(object): class IamInstanceProfileAssociationBackend:
def __init__(self): def __init__(self):
self.iam_instance_profile_associations = {} self.iam_instance_profile_associations = {}
super().__init__()
def associate_iam_instance_profile( def associate_iam_instance_profile(
self, instance_id, iam_instance_profile_name=None, iam_instance_profile_arn=None self, instance_id, iam_instance_profile_name=None, iam_instance_profile_arn=None

View File

@ -21,10 +21,7 @@ for location_type in listdir(root / offerings_path):
INSTANCE_TYPE_OFFERINGS[location_type][_region.replace(".json", "")] = res INSTANCE_TYPE_OFFERINGS[location_type][_region.replace(".json", "")] = res
class InstanceTypeBackend(object): class InstanceTypeBackend:
def __init__(self):
super().__init__()
def describe_instance_types(self, instance_types=None): def describe_instance_types(self, instance_types=None):
matches = INSTANCE_TYPES.values() matches = INSTANCE_TYPES.values()
if instance_types: if instance_types:
@ -37,10 +34,7 @@ class InstanceTypeBackend(object):
return matches return matches
class InstanceTypeOfferingBackend(object): class InstanceTypeOfferingBackend:
def __init__(self):
super().__init__()
def describe_instance_type_offerings(self, location_type=None, filters=None): def describe_instance_type_offerings(self, location_type=None, filters=None):
location_type = location_type or "region" location_type = location_type or "region"
matches = INSTANCE_TYPE_OFFERINGS[location_type] matches = INSTANCE_TYPE_OFFERINGS[location_type]

View File

@ -539,10 +539,9 @@ class Instance(TaggedEC2Resource, BotoInstance, CloudFormationModel):
return True return True
class InstanceBackend(object): class InstanceBackend:
def __init__(self): def __init__(self):
self.reservations = OrderedDict() self.reservations = OrderedDict()
super().__init__()
def get_instance(self, instance_id): def get_instance(self, instance_id):
for instance in self.all_instances(): for instance in self.all_instances():

View File

@ -30,10 +30,9 @@ class EgressOnlyInternetGateway(TaggedEC2Resource):
return self.id return self.id
class EgressOnlyInternetGatewayBackend(object): class EgressOnlyInternetGatewayBackend:
def __init__(self): def __init__(self):
self.egress_only_internet_gateway_backend = {} self.egress_only_internet_gateway_backend = {}
super().__init__()
def create_egress_only_internet_gateway(self, vpc_id, tags=None): def create_egress_only_internet_gateway(self, vpc_id, tags=None):
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
@ -113,10 +112,9 @@ class InternetGateway(TaggedEC2Resource, CloudFormationModel):
return "detached" return "detached"
class InternetGatewayBackend(object): class InternetGatewayBackend:
def __init__(self): def __init__(self):
self.internet_gateways = {} self.internet_gateways = {}
super().__init__()
def create_internet_gateway(self, tags=None): def create_internet_gateway(self, tags=None):
igw = InternetGateway(self) igw = InternetGateway(self)

View File

@ -28,10 +28,9 @@ class KeyPair(BaseModel):
raise FilterNotImplementedError(filter_name, "DescribeKeyPairs") raise FilterNotImplementedError(filter_name, "DescribeKeyPairs")
class KeyPairBackend(object): class KeyPairBackend:
def __init__(self): def __init__(self):
self.keypairs = {} self.keypairs = {}
super().__init__()
def create_key_pair(self, name): def create_key_pair(self, name):
if name in self.keypairs: if name in self.keypairs:

View File

@ -73,12 +73,11 @@ class LaunchTemplate(TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeLaunchTemplates") return super().get_filter_value(filter_name, "DescribeLaunchTemplates")
class LaunchTemplateBackend(object): class LaunchTemplateBackend:
def __init__(self): def __init__(self):
self.launch_template_name_to_ids = {} self.launch_template_name_to_ids = {}
self.launch_templates = OrderedDict() self.launch_templates = OrderedDict()
self.launch_template_insert_order = [] self.launch_template_insert_order = []
super().__init__()
def create_launch_template(self, name, description, template_data): def create_launch_template(self, name, description, template_data):
if name in self.launch_template_name_to_ids: if name in self.launch_template_name_to_ids:

View File

@ -42,11 +42,10 @@ class ManagedPrefixList(TaggedEC2Resource):
) )
class ManagedPrefixListBackend(object): class ManagedPrefixListBackend:
def __init__(self): def __init__(self):
self.managed_prefix_lists = {} self.managed_prefix_lists = {}
self.create_default_pls() self.create_default_pls()
super().__init__()
def create_managed_prefix_list( def create_managed_prefix_list(
self, self,

View File

@ -81,10 +81,9 @@ class NatGateway(CloudFormationModel, TaggedEC2Resource):
return nat_gateway return nat_gateway
class NatGatewayBackend(object): class NatGatewayBackend:
def __init__(self): def __init__(self):
self.nat_gateways = {} self.nat_gateways = {}
super().__init__()
def describe_nat_gateways(self, filters, nat_gateway_ids): def describe_nat_gateways(self, filters, nat_gateway_ids):
nat_gateways = list(self.nat_gateways.values()) nat_gateways = list(self.nat_gateways.values())

View File

@ -15,10 +15,9 @@ from ..utils import (
OWNER_ID = get_account_id() OWNER_ID = get_account_id()
class NetworkAclBackend(object): class NetworkAclBackend:
def __init__(self): def __init__(self):
self.network_acls = {} self.network_acls = {}
super().__init__()
def get_network_acl(self, network_acl_id): def get_network_acl(self, network_acl_id):
network_acl = self.network_acls.get(network_acl_id, None) network_acl = self.network_acls.get(network_acl_id, None)

View File

@ -86,10 +86,9 @@ class RouteTable(TaggedEC2Resource, CloudFormationModel):
return super().get_filter_value(filter_name, "DescribeRouteTables") return super().get_filter_value(filter_name, "DescribeRouteTables")
class RouteTableBackend(object): class RouteTableBackend:
def __init__(self): def __init__(self):
self.route_tables = {} self.route_tables = {}
super().__init__()
def create_route_table(self, vpc_id, tags=None, main=False): def create_route_table(self, vpc_id, tags=None, main=False):
route_table_id = random_route_table_id() route_table_id = random_route_table_id()
@ -284,10 +283,7 @@ class Route(CloudFormationModel):
return route_table return route_table
class RouteBackend(object): class RouteBackend:
def __init__(self):
super().__init__()
def create_route( def create_route(
self, self,
route_table_id, route_table_id,

View File

@ -454,7 +454,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
) )
class SecurityGroupBackend(object): class SecurityGroupBackend:
def __init__(self): def __init__(self):
# the key in the dict group is the vpc_id or None (non-vpc) # the key in the dict group is the vpc_id or None (non-vpc)
self.groups = defaultdict(dict) self.groups = defaultdict(dict)
@ -462,8 +462,6 @@ class SecurityGroupBackend(object):
self.sg_old_ingress_ruls = {} self.sg_old_ingress_ruls = {}
self.sg_old_egress_ruls = {} self.sg_old_egress_ruls = {}
super().__init__()
def create_security_group( def create_security_group(
self, name, description, vpc_id=None, tags=None, force=False, is_default=None self, name, description, vpc_id=None, tags=None, force=False, is_default=None
): ):

View File

@ -117,10 +117,9 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
return instance return instance
class SpotRequestBackend(object): class SpotRequestBackend:
def __init__(self): def __init__(self):
self.spot_instance_requests = {} self.spot_instance_requests = {}
super().__init__()
def request_spot_instances( def request_spot_instances(
self, self,
@ -411,10 +410,9 @@ class SpotFleetRequest(TaggedEC2Resource, CloudFormationModel):
self.ec2_backend.terminate_instances(instance_ids) self.ec2_backend.terminate_instances(instance_ids)
class SpotFleetBackend(object): class SpotFleetBackend:
def __init__(self): def __init__(self):
self.spot_fleet_requests = {} self.spot_fleet_requests = {}
super().__init__()
def request_spot_fleet( def request_spot_fleet(
self, self,
@ -485,7 +483,7 @@ class SpotFleetBackend(object):
return True return True
class SpotPriceBackend(object): class SpotPriceBackend:
def describe_spot_price_history(self, instance_types=None, filters=None): def describe_spot_price_history(self, instance_types=None, filters=None):
matches = INSTANCE_TYPE_OFFERINGS["availability-zone"] matches = INSTANCE_TYPE_OFFERINGS["availability-zone"]
matches = matches.get(self.region_name, []) matches = matches.get(self.region_name, [])

View File

@ -224,11 +224,10 @@ class Subnet(TaggedEC2Resource, CloudFormationModel):
return association return association
class SubnetBackend(object): class SubnetBackend:
def __init__(self): def __init__(self):
# maps availability zone to dict of (subnet_id, subnet) # maps availability zone to dict of (subnet_id, subnet)
self.subnets = defaultdict(dict) self.subnets = defaultdict(dict)
super().__init__()
def get_subnet(self, subnet_id): def get_subnet(self, subnet_id):
for subnets in self.subnets.values(): for subnets in self.subnets.values():
@ -432,10 +431,9 @@ class SubnetRouteTableAssociation(CloudFormationModel):
return subnet_association return subnet_association
class SubnetRouteTableAssociationBackend(object): class SubnetRouteTableAssociationBackend:
def __init__(self): def __init__(self):
self.subnet_associations = {} self.subnet_associations = {}
super().__init__()
def create_subnet_association(self, route_table_id, subnet_id): def create_subnet_association(self, route_table_id, subnet_id):
subnet_association = SubnetRouteTableAssociation(route_table_id, subnet_id) subnet_association = SubnetRouteTableAssociation(route_table_id, subnet_id)

View File

@ -12,12 +12,11 @@ from ..utils import (
) )
class TagBackend(object): class TagBackend:
VALID_TAG_FILTERS = ["key", "resource-id", "resource-type", "value"] VALID_TAG_FILTERS = ["key", "resource-id", "resource-type", "value"]
def __init__(self): def __init__(self):
self.tags = defaultdict(dict) self.tags = defaultdict(dict)
super().__init__()
def create_tags(self, resource_ids, tags): def create_tags(self, resource_ids, tags):
if None in set([tags[tag] for tag in tags]): if None in set([tags[tag] for tag in tags]):

View File

@ -72,10 +72,9 @@ class TransitGateway(TaggedEC2Resource, CloudFormationModel):
return transit_gateway return transit_gateway
class TransitGatewayBackend(object): class TransitGatewayBackend:
def __init__(self): def __init__(self):
self.transit_gateways = {} self.transit_gateways = {}
super().__init__()
def create_transit_gateway(self, description=None, options=None, tags=None): def create_transit_gateway(self, description=None, options=None, tags=None):
transit_gateway = TransitGateway(self, description, options) transit_gateway = TransitGateway(self, description, options)

View File

@ -98,10 +98,9 @@ class TransitGatewayPeeringAttachment(TransitGatewayAttachment):
return get_account_id() return get_account_id()
class TransitGatewayAttachmentBackend(object): class TransitGatewayAttachmentBackend:
def __init__(self): def __init__(self):
self.transit_gateway_attachments = {} self.transit_gateway_attachments = {}
super().__init__()
def create_transit_gateway_vpn_attachment( def create_transit_gateway_vpn_attachment(
self, vpn_id, transit_gateway_id, tags=None self, vpn_id, transit_gateway_id, tags=None

View File

@ -37,10 +37,9 @@ class TransitGatewayRouteTable(TaggedEC2Resource):
return iso_8601_datetime_with_milliseconds(self._created_at) return iso_8601_datetime_with_milliseconds(self._created_at)
class TransitGatewayRouteTableBackend(object): class TransitGatewayRouteTableBackend:
def __init__(self): def __init__(self):
self.transit_gateways_route_tables = {} self.transit_gateways_route_tables = {}
super().__init__()
def create_transit_gateway_route_table( def create_transit_gateway_route_table(
self, self,
@ -286,11 +285,10 @@ class TransitGatewayRelations(object):
self.state = state self.state = state
class TransitGatewayRelationsBackend(object): class TransitGatewayRelationsBackend:
def __init__(self): def __init__(self):
self.transit_gateway_associations = {} self.transit_gateway_associations = {}
self.transit_gateway_propagations = {} self.transit_gateway_propagations = {}
super().__init__()
def associate_transit_gateway_route_table( def associate_transit_gateway_route_table(
self, transit_gateway_attachment_id=None, transit_gateway_route_table_id=None self, transit_gateway_attachment_id=None, transit_gateway_route_table_id=None

View File

@ -84,14 +84,13 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
return self.id return self.id
class VPCPeeringConnectionBackend(object): class VPCPeeringConnectionBackend:
# for cross region vpc reference # for cross region vpc reference
vpc_pcx_refs = defaultdict(set) vpc_pcx_refs = defaultdict(set)
def __init__(self): def __init__(self):
self.vpc_pcxs = {} self.vpc_pcxs = {}
self.vpc_pcx_refs[self.__class__].add(weakref.ref(self)) self.vpc_pcx_refs[self.__class__].add(weakref.ref(self))
super().__init__()
@classmethod @classmethod
def get_vpc_pcx_refs(cls): def get_vpc_pcx_refs(cls):

View File

@ -36,10 +36,9 @@ class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
class VPCServiceConfigurationBackend(object): class VPCServiceConfigurationBackend:
def __init__(self): def __init__(self):
self.configurations = {} self.configurations = {}
super().__init__()
@property @property
def elbv2_backend(self): def elbv2_backend(self):

View File

@ -324,14 +324,13 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
] ]
class VPCBackend(object): class VPCBackend:
vpc_refs = defaultdict(set) vpc_refs = defaultdict(set)
def __init__(self): def __init__(self):
self.vpcs = {} self.vpcs = {}
self.vpc_end_points = {} self.vpc_end_points = {}
self.vpc_refs[self.__class__].add(weakref.ref(self)) self.vpc_refs[self.__class__].add(weakref.ref(self))
super().__init__()
def create_vpc( def create_vpc(
self, self,

View File

@ -31,10 +31,9 @@ class VPNConnection(TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeVpnConnections") return super().get_filter_value(filter_name, "DescribeVpnConnections")
class VPNConnectionBackend(object): class VPNConnectionBackend:
def __init__(self): def __init__(self):
self.vpn_connections = {} self.vpn_connections = {}
super().__init__()
def create_vpn_connection( def create_vpn_connection(
self, self,

View File

@ -105,10 +105,9 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
return super().get_filter_value(filter_name, "DescribeVpnGateways") return super().get_filter_value(filter_name, "DescribeVpnGateways")
class VpnGatewayBackend(object): class VpnGatewayBackend:
def __init__(self): def __init__(self):
self.vpn_gateways = {} self.vpn_gateways = {}
super().__init__()
def create_vpn_gateway( def create_vpn_gateway(
self, self,

View File

@ -4,9 +4,6 @@ from moto.core.utils import BackendDict
class Ec2InstanceConnectBackend(BaseBackend): class Ec2InstanceConnectBackend(BaseBackend):
def __init__(self, region=None):
pass
def send_ssh_public_key(self): def send_ssh_public_key(self):
return json.dumps( return json.dumps(
{"RequestId": "example-2a47-4c91-9700-e37e85162cb6", "Success": True} {"RequestId": "example-2a47-4c91-9700-e37e85162cb6", "Success": True}

View File

@ -326,18 +326,13 @@ class Image(BaseObject):
class ECRBackend(BaseBackend): class ECRBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.registry_policy = None self.registry_policy = None
self.replication_config = {"rules": []} self.replication_config = {"rules": []}
self.repositories: Dict[str, Repository] = {} self.repositories: Dict[str, Repository] = {}
self.tagger = TaggingService(tag_name="tags") self.tagger = TaggingService(tag_name="tags")
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -748,8 +748,8 @@ class EC2ContainerServiceBackend(BaseBackend):
AWS reference: https://aws.amazon.com/blogs/compute/migrating-your-amazon-ecs-deployment-to-the-new-arn-and-resource-id-format-2/ AWS reference: https://aws.amazon.com/blogs/compute/migrating-your-amazon-ecs-deployment-to-the-new-arn-and-resource-id-format-2/
""" """
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.account_settings = dict() self.account_settings = dict()
self.capacity_providers = dict() self.capacity_providers = dict()
self.clusters = {} self.clusters = {}
@ -758,16 +758,10 @@ class EC2ContainerServiceBackend(BaseBackend):
self.services = {} self.services = {}
self.container_instances = {} self.container_instances = {}
self.task_sets = {} self.task_sets = {}
self.region_name = region_name
self.tagger = TaggingService( self.tagger = TaggingService(
tag_name="tags", key_name="key", value_name="value" tag_name="tags", key_name="key", value_name="value"
) )
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -362,9 +362,8 @@ class EFSBackend(BaseBackend):
such resources should always go through this class. such resources should always go through this class.
""" """
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.creation_tokens = set() self.creation_tokens = set()
self.access_points = dict() self.access_points = dict()
self.file_systems_by_id = {} self.file_systems_by_id = {}
@ -372,12 +371,6 @@ class EFSBackend(BaseBackend):
self.next_markers = {} self.next_markers = {}
self.tagging_service = TaggingService() self.tagging_service = TaggingService()
def reset(self):
# preserve region
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def _mark_description(self, corpus, max_items): def _mark_description(self, corpus, max_items):
if max_items < len(corpus): if max_items < len(corpus):
new_corpus = corpus[max_items:] new_corpus = corpus[max_items:]

View File

@ -311,18 +311,12 @@ class ManagedNodegroup:
class EKSBackend(BaseBackend): class EKSBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.clusters = dict() self.clusters = dict()
self.cluster_count = 0 self.cluster_count = 0
self.region_name = region_name
self.partition = get_partition(region_name) self.partition = get_partition(region_name)
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_cluster( def create_cluster(
self, self,
name, name,

View File

@ -34,8 +34,8 @@ class User(BaseModel):
class ElastiCacheBackend(BaseBackend): class ElastiCacheBackend(BaseBackend):
"""Implementation of ElastiCache APIs.""" """Implementation of ElastiCache APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.users = dict() self.users = dict()
self.users["default"] = User( self.users["default"] = User(
region=self.region_name, region=self.region_name,
@ -46,11 +46,6 @@ class ElastiCacheBackend(BaseBackend):
no_password_required=True, no_password_required=True,
) )
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_user( def create_user(
self, user_id, user_name, engine, passwords, access_string, no_password_required self, user_id, user_name, engine, passwords, access_string, no_password_required
): ):

View File

@ -55,7 +55,7 @@ class FakeApplication(BaseModel):
@property @property
def region(self): def region(self):
return self.backend.region return self.backend.region_name
@property @property
def arn(self): def arn(self):
@ -65,17 +65,10 @@ class FakeApplication(BaseModel):
class EBBackend(BaseBackend): class EBBackend(BaseBackend):
def __init__(self, region): def __init__(self, region_name, account_id):
self.region = region super().__init__(region_name, account_id)
self.applications = dict() self.applications = dict()
def reset(self):
# preserve region
region = self.region
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -18,7 +18,7 @@ class EBResponse(BaseResponse):
) )
template = self.response_template(EB_CREATE_APPLICATION) template = self.response_template(EB_CREATE_APPLICATION)
return template.render(region_name=self.backend.region, application=app) return template.render(region_name=self.backend.region_name, application=app)
def describe_applications(self): def describe_applications(self):
template = self.response_template(EB_DESCRIBE_APPLICATIONS) template = self.response_template(EB_DESCRIBE_APPLICATIONS)
@ -42,7 +42,7 @@ class EBResponse(BaseResponse):
) )
template = self.response_template(EB_CREATE_ENVIRONMENT) template = self.response_template(EB_CREATE_ENVIRONMENT)
return template.render(environment=env, region=self.backend.region) return template.render(environment=env, region=self.backend.region_name)
def describe_environments(self): def describe_environments(self):
envs = self.backend.describe_environments() envs = self.backend.describe_environments()

View File

@ -62,16 +62,10 @@ class Pipeline(BaseModel):
class ElasticTranscoderBackend(BaseBackend): class ElasticTranscoderBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.pipelines = {} self.pipelines = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_pipeline( def create_pipeline(
self, self,
name, name,

View File

@ -270,15 +270,10 @@ class FakeLoadBalancer(CloudFormationModel):
class ELBBackend(BaseBackend): class ELBBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.load_balancers = OrderedDict() self.load_balancers = OrderedDict()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_load_balancer( def create_load_balancer(
self, self,
name, name,

View File

@ -636,8 +636,8 @@ class FakeLoadBalancer(CloudFormationModel):
class ELBv2Backend(BaseBackend): class ELBv2Backend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.target_groups = OrderedDict() self.target_groups = OrderedDict()
self.load_balancers = OrderedDict() self.load_balancers = OrderedDict()
self.tagging_service = TaggingService() self.tagging_service = TaggingService()
@ -659,11 +659,6 @@ class ELBv2Backend(BaseBackend):
""" """
return ec2_backends[self.region_name] return ec2_backends[self.region_name]
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_load_balancer( def create_load_balancer(
self, self,
name, name,

View File

@ -390,18 +390,12 @@ class FakeSecurityConfiguration(BaseModel):
class ElasticMapReduceBackend(BaseBackend): class ElasticMapReduceBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.clusters = {} self.clusters = {}
self.instance_groups = {} self.instance_groups = {}
self.security_configurations = {} self.security_configurations = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -159,21 +159,14 @@ class FakeJob(BaseModel):
class EMRContainersBackend(BaseBackend): class EMRContainersBackend(BaseBackend):
"""Implementation of EMRContainers APIs.""" """Implementation of EMRContainers APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.virtual_clusters = dict() self.virtual_clusters = dict()
self.virtual_cluster_count = 0 self.virtual_cluster_count = 0
self.jobs = dict() self.jobs = dict()
self.job_count = 0 self.job_count = 0
self.region_name = region_name
self.partition = get_partition(region_name) self.partition = get_partition(region_name)
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_virtual_cluster(self, name, container_provider, client_token, tags=None): def create_virtual_cluster(self, name, container_provider, client_token, tags=None):
occupied_namespaces = [ occupied_namespaces = [
virtual_cluster.namespace virtual_cluster.namespace

View File

@ -76,16 +76,10 @@ class Domain(BaseModel):
class ElasticsearchServiceBackend(BaseBackend): class ElasticsearchServiceBackend(BaseBackend):
"""Implementation of ElasticsearchService APIs.""" """Implementation of ElasticsearchService APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.domains = dict() self.domains = dict()
def reset(self):
"""Re-initialize all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_elasticsearch_domain( def create_elasticsearch_domain(
self, self,
domain_name, domain_name,

View File

@ -937,10 +937,10 @@ class EventsBackend(BaseBackend):
_CRON_REGEX = re.compile(r"^cron\(.*\)") _CRON_REGEX = re.compile(r"^cron\(.*\)")
_RATE_REGEX = re.compile(r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)") _RATE_REGEX = re.compile(r"^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)")
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.rules = OrderedDict() self.rules = OrderedDict()
self.next_tokens = {} self.next_tokens = {}
self.region_name = region_name
self.event_buses = {} self.event_buses = {}
self.event_sources = {} self.event_sources = {}
self.archives = {} self.archives = {}
@ -951,11 +951,6 @@ class EventsBackend(BaseBackend):
self.connections = {} self.connections = {}
self.destinations = {} self.destinations = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -1,4 +1,5 @@
import json import json
from moto.core import get_account_id
_EVENT_S3_OBJECT_CREATED = { _EVENT_S3_OBJECT_CREATED = {
@ -35,7 +36,8 @@ def _send_safe_notification(source, event_name, region, resources, detail):
if event is None: if event is None:
return return
for backend in events_backends.values(): account = events_backends[get_account_id()]
for backend in account.values():
applicable_targets = [] applicable_targets = []
for rule in backend.rules.values(): for rule in backend.rules.values():
if rule.state != "ENABLED": if rule.state != "ENABLED":

View File

@ -37,7 +37,7 @@ from moto.firehose.exceptions import (
ResourceNotFoundException, ResourceNotFoundException,
ValidationException, ValidationException,
) )
from moto.s3 import s3_backend from moto.s3.models import s3_backend
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
MAX_TAGS_PER_DELIVERY_STREAM = 50 MAX_TAGS_PER_DELIVERY_STREAM = 50
@ -163,17 +163,11 @@ class DeliveryStream(
class FirehoseBackend(BaseBackend): class FirehoseBackend(BaseBackend):
"""Implementation of Firehose APIs.""" """Implementation of Firehose APIs."""
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.delivery_streams = {} self.delivery_streams = {}
self.tagger = TaggingService() self.tagger = TaggingService()
def reset(self):
"""Re-initializes all attributes for this instance."""
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -99,11 +99,10 @@ class DatasetGroup:
class ForecastBackend(BaseBackend): class ForecastBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.dataset_groups = {} self.dataset_groups = {}
self.datasets = {} self.datasets = {}
self.region_name = region_name
def create_dataset_group(self, dataset_group_name, domain, dataset_arns, tags): def create_dataset_group(self, dataset_group_name, domain, dataset_arns, tags):
dataset_group = DatasetGroup( dataset_group = DatasetGroup(
@ -159,10 +158,5 @@ class ForecastBackend(BaseBackend):
def list_dataset_groups(self): def list_dataset_groups(self):
return [v for (_, v) in self.dataset_groups.items()] return [v for (_, v) in self.dataset_groups.items()]
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
forecast_backends = BackendDict(ForecastBackend, "forecast") forecast_backends = BackendDict(ForecastBackend, "forecast")

View File

@ -188,14 +188,9 @@ class Vault(BaseModel):
class GlacierBackend(BaseBackend): class GlacierBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.vaults = {} self.vaults = {}
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def get_vault(self, vault_name): def get_vault(self, vault_name):
return self.vaults[vault_name] return self.vaults[vault_name]

View File

@ -1,5 +1,4 @@
from .models import glue_backend from .models import glue_backends
from ..core.models import base_decorator from ..core.models import base_decorator
glue_backends = {"global": glue_backend}
mock_glue = base_decorator(glue_backends) mock_glue = base_decorator(glue_backends)

View File

@ -4,6 +4,7 @@ from datetime import datetime
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.models import get_account_id from moto.core.models import get_account_id
from moto.core.utils import BackendDict
from moto.glue.exceptions import CrawlerRunningException, CrawlerNotRunningException from moto.glue.exceptions import CrawlerRunningException, CrawlerNotRunningException
from .exceptions import ( from .exceptions import (
JsonRESTError, JsonRESTError,
@ -40,7 +41,8 @@ class GlueBackend(BaseBackend):
}, },
} }
def __init__(self): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.databases = OrderedDict() self.databases = OrderedDict()
self.crawlers = OrderedDict() self.crawlers = OrderedDict()
self.jobs = OrderedDict() self.jobs = OrderedDict()
@ -624,4 +626,7 @@ class FakeJobRun:
} }
glue_backend = GlueBackend() glue_backends = BackendDict(
GlueBackend, "glue", use_boto3_regions=False, additional_regions=["global"]
)
glue_backend = glue_backends["global"]

View File

@ -53,9 +53,8 @@ class FakeCoreDefinitionVersion(BaseModel):
class GreengrassBackend(BaseBackend): class GreengrassBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.groups = OrderedDict() self.groups = OrderedDict()
self.group_versions = OrderedDict() self.group_versions = OrderedDict()
self.core_definitions = OrderedDict() self.core_definitions = OrderedDict()

View File

@ -7,17 +7,11 @@ from .exceptions import DetectorNotFoundException, FilterNotFoundException
class GuardDutyBackend(BaseBackend): class GuardDutyBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.admin_account_ids = [] self.admin_account_ids = []
self.detectors = {} self.detectors = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_detector(self, enable, finding_publishing_frequency, data_sources, tags): def create_detector(self, enable, finding_publishing_frequency, data_sources, tags):
if finding_publishing_frequency not in [ if finding_publishing_frequency not in [
"FIFTEEN_MINUTES", "FIFTEEN_MINUTES",

View File

@ -39,7 +39,7 @@ from moto.s3.exceptions import (
BucketSignatureDoesNotMatchError, BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError, S3SignatureDoesNotMatchError,
) )
from moto.sts import sts_backend from moto.sts.models import sts_backend
from .models import iam_backend, Policy from .models import iam_backend, Policy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -1567,7 +1567,7 @@ def filter_items_with_path_prefix(path_prefix, items):
class IAMBackend(BaseBackend): class IAMBackend(BaseBackend):
def __init__(self): def __init__(self, region_name, account_id=None):
self.instance_profiles = {} self.instance_profiles = {}
self.roles = {} self.roles = {}
self.certificates = {} self.certificates = {}
@ -1586,7 +1586,7 @@ class IAMBackend(BaseBackend):
self.access_keys = {} self.access_keys = {}
self.tagger = TaggingService() self.tagger = TaggingService()
super().__init__() super().__init__(region_name=region_name, account_id=account_id)
def _init_managed_policies(self): def _init_managed_policies(self):
return dict((p.arn, p) for p in aws_managed_policies) return dict((p.arn, p) for p in aws_managed_policies)
@ -2924,4 +2924,4 @@ class IAMBackend(BaseBackend):
return True return True
iam_backend = IAMBackend() iam_backend = IAMBackend("global")

View File

@ -5,4 +5,4 @@ class InstanceMetadataBackend(BaseBackend):
pass pass
instance_metadata_backend = InstanceMetadataBackend() instance_metadata_backend = InstanceMetadataBackend(region_name="global")

View File

@ -564,9 +564,8 @@ class FakeDomainConfiguration(BaseModel):
class IoTBackend(BaseBackend): class IoTBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.things = OrderedDict() self.things = OrderedDict()
self.jobs = OrderedDict() self.jobs = OrderedDict()
self.job_executions = OrderedDict() self.job_executions = OrderedDict()
@ -581,11 +580,6 @@ class IoTBackend(BaseBackend):
self.endpoint = None self.endpoint = None
self.domain_configurations = OrderedDict() self.domain_configurations = OrderedDict()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service.""" """Default VPC endpoint service."""

View File

@ -139,16 +139,10 @@ class FakeShadow(BaseModel):
class IoTDataPlaneBackend(BaseBackend): class IoTDataPlaneBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
super().__init__() super().__init__(region_name, account_id)
self.region_name = region_name
self.published_payloads = list() self.published_payloads = list()
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def update_thing_shadow(self, thing_name, payload): def update_thing_shadow(self, thing_name, payload):
""" """
spec of payload: spec of payload:

Some files were not shown because too many files have changed in this diff Show More