From a2c2c06243b49207797ab0798dbfa8c1f6cb6477 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Thu, 9 Jun 2022 17:40:22 +0000 Subject: [PATCH] Techdebt - Align models-responses integration for all services (#5207) --- moto/awslambda/models.py | 13 +- moto/cloudformation/parsing.py | 4 +- moto/cloudformation/responses.py | 4 +- moto/cloudfront/models.py | 1 - moto/cloudfront/responses.py | 16 +- moto/cloudtrail/models.py | 4 +- moto/cloudwatch/models.py | 1 - moto/cognitoidentity/responses.py | 24 +-- moto/cognitoidp/responses.py | 134 +++++------- moto/config/models.py | 4 - moto/core/responses.py | 11 +- moto/core/utils.py | 96 ++++----- moto/datapipeline/responses.py | 34 ++- moto/dynamodb_v20111205/models.py | 1 - moto/dynamodb_v20111205/responses.py | 38 ++-- moto/ec2/models/flow_logs.py | 4 +- moto/elbv2/models.py | 4 +- moto/firehose/models.py | 4 +- moto/iam/__init__.py | 3 +- moto/iam/access_control.py | 38 ++-- moto/iam/models.py | 58 ++--- moto/iam/responses.py | 304 +++++++++++++-------------- moto/logs/models.py | 4 +- moto/managedblockchain/responses.py | 165 ++++----------- moto/managedblockchain/urls.py | 28 +-- moto/managedblockchain/utils.py | 8 - moto/s3/models.py | 19 +- moto/s3/responses.py | 177 +++++++--------- moto/s3control/__init__.py | 3 +- moto/s3control/models.py | 18 +- moto/s3control/responses.py | 28 +-- moto/ses/models.py | 4 +- moto/ses/responses.py | 66 +++--- moto/sts/models.py | 4 +- moto/sts/responses.py | 22 +- tests/test_core/test_server.py | 2 +- tests/test_kms/test_kms_boto3.py | 5 +- tests/test_s3/test_server.py | 3 + 38 files changed, 609 insertions(+), 747 deletions(-) diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 92080d407..17f5a5a77 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -4,6 +4,7 @@ from collections import defaultdict import copy import datetime from gzip import GzipFile +from typing import Mapping from sys import platform import docker @@ -25,10 +26,10 @@ import requests.exceptions from moto.awslambda.policy import Policy from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import RESTError -from moto.iam.models import iam_backend +from moto.iam.models import iam_backends from moto.iam.exceptions import IAMNotFoundException from moto.core.utils import unix_time_millis, BackendDict -from moto.s3.models import s3_backend +from moto.s3.models import s3_backends from moto.logs.models import logs_backends from moto.s3.exceptions import MissingBucket, MissingKey from moto import settings @@ -182,7 +183,7 @@ def _validate_s3_bucket_and_key(data): key = None try: # FIXME: does not validate bucket region - key = s3_backend.get_object(data["S3Bucket"], data["S3Key"]) + key = s3_backends["global"].get_object(data["S3Bucket"], data["S3Key"]) except MissingBucket: if do_validate_s3(): raise InvalidParameterValueException( @@ -585,7 +586,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): key = None try: # FIXME: does not validate bucket region - key = s3_backend.get_object( + key = s3_backends["global"].get_object( updated_spec["S3Bucket"], updated_spec["S3Key"] ) except MissingBucket: @@ -1121,7 +1122,7 @@ class LambdaStorage(object): if account != get_account_id(): raise CrossAccountNotAllowed() try: - iam_backend.get_role_by_arn(fn.role) + iam_backends["global"].get_role_by_arn(fn.role) except IAMNotFoundException: raise InvalidParameterValueException( "The role defined for the function cannot be assumed by Lambda." @@ -1666,4 +1667,4 @@ def do_validate_s3(): return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"] -lambda_backends = BackendDict(LambdaBackend, "lambda") +lambda_backends: Mapping[str, LambdaBackend] = BackendDict(LambdaBackend, "lambda") diff --git a/moto/cloudformation/parsing.py b/moto/cloudformation/parsing.py index 1d55c7a41..8302c4df1 100644 --- a/moto/cloudformation/parsing.py +++ b/moto/cloudformation/parsing.py @@ -47,7 +47,7 @@ from moto.ssm import models # noqa # pylint: disable=all # End ugly list of imports from moto.core import get_account_id, CloudFormationModel -from moto.s3.models import s3_backend +from moto.s3.models import s3_backends from moto.s3.utils import bucket_and_name_from_url from moto.ssm import ssm_backends from .utils import random_suffix @@ -528,7 +528,7 @@ class ResourceMap(collections_abc.Mapping): if name == "AWS::Include": location = params["Location"] bucket_name, name = bucket_and_name_from_url(location) - key = s3_backend.get_object(bucket_name, name) + key = s3_backends["global"].get_object(bucket_name, name) self._parsed_resources.update(json.loads(key.value)) def parse_ssm_parameter(self, value, value_type): diff --git a/moto/cloudformation/responses.py b/moto/cloudformation/responses.py index 32d1c9f11..36da31b72 100644 --- a/moto/cloudformation/responses.py +++ b/moto/cloudformation/responses.py @@ -6,7 +6,7 @@ from yaml.scanner import ScannerError # pylint:disable=c-extension-no-member from moto.core.responses import BaseResponse from moto.core.utils import amzn_request_id -from moto.s3.models import s3_backend +from moto.s3.models import s3_backends from moto.s3.exceptions import S3ClientError from moto.core import get_account_id from .models import cloudformation_backends @@ -68,7 +68,7 @@ class CloudFormationResponse(BaseResponse): bucket_name = template_url_parts.netloc.split(".")[0] key_name = template_url_parts.path.lstrip("/") - key = s3_backend.get_object(bucket_name, key_name) + key = s3_backends["global"].get_object(bucket_name, key_name) return key.value.decode("utf-8") def _get_params_from_list(self, parameters_list): diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py index 2d3dbf00e..b947c0336 100644 --- a/moto/cloudfront/models.py +++ b/moto/cloudfront/models.py @@ -255,4 +255,3 @@ cloudfront_backends = BackendDict( use_boto3_regions=False, additional_regions=["global"], ) -cloudfront_backend = cloudfront_backends["global"] diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py index 4888e00ad..252d5b559 100644 --- a/moto/cloudfront/responses.py +++ b/moto/cloudfront/responses.py @@ -1,7 +1,7 @@ import xmltodict from moto.core.responses import BaseResponse -from .models import cloudfront_backend +from .models import cloudfront_backends XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" @@ -11,6 +11,10 @@ class CloudFrontResponse(BaseResponse): def _get_xml_body(self): return xmltodict.parse(self.body, dict_constructor=dict) + @property + def backend(self): + return cloudfront_backends["global"] + def distributions(self, request, full_url, headers): self.setup_class(request, full_url, headers) if request.method == "POST": @@ -21,7 +25,7 @@ class CloudFrontResponse(BaseResponse): def create_distribution(self): params = self._get_xml_body() distribution_config = params.get("DistributionConfig") - distribution, location, e_tag = cloudfront_backend.create_distribution( + distribution, location, e_tag = self.backend.create_distribution( distribution_config=distribution_config ) template = self.response_template(CREATE_DISTRIBUTION_TEMPLATE) @@ -30,7 +34,7 @@ class CloudFrontResponse(BaseResponse): return 200, headers, response def list_distributions(self): - distributions = cloudfront_backend.list_distributions() + distributions = self.backend.list_distributions() template = self.response_template(LIST_TEMPLATE) response = template.render(distributions=distributions) return 200, {}, response @@ -40,10 +44,10 @@ class CloudFrontResponse(BaseResponse): distribution_id = full_url.split("/")[-1] if request.method == "DELETE": if_match = self._get_param("If-Match") - cloudfront_backend.delete_distribution(distribution_id, if_match) + self.backend.delete_distribution(distribution_id, if_match) return 204, {}, "" if request.method == "GET": - dist, etag = cloudfront_backend.get_distribution(distribution_id) + dist, etag = self.backend.get_distribution(distribution_id) template = self.response_template(GET_DISTRIBUTION_TEMPLATE) response = template.render(distribution=dist, xmlns=XMLNS) return 200, {"ETag": etag}, response @@ -55,7 +59,7 @@ class CloudFrontResponse(BaseResponse): dist_id = full_url.split("/")[-2] if_match = headers["If-Match"] - dist, location, e_tag = cloudfront_backend.update_distribution( + dist, location, e_tag = self.backend.update_distribution( DistributionConfig=distribution_config, Id=dist_id, IfMatch=if_match, diff --git a/moto/cloudtrail/models.py b/moto/cloudtrail/models.py index 3c156fbe1..86db63ae1 100644 --- a/moto/cloudtrail/models.py +++ b/moto/cloudtrail/models.py @@ -130,10 +130,10 @@ class Trail(BaseModel): raise TrailNameInvalidChars() def check_bucket_exists(self): - from moto.s3.models import s3_backend + from moto.s3.models import s3_backends try: - s3_backend.get_bucket(self.bucket_name) + s3_backends["global"].get_bucket(self.bucket_name) except Exception: raise S3BucketDoesNotExistException( f"S3 bucket {self.bucket_name} does not exist!" diff --git a/moto/cloudwatch/models.py b/moto/cloudwatch/models.py index 3508b7d01..8a74db335 100644 --- a/moto/cloudwatch/models.py +++ b/moto/cloudwatch/models.py @@ -544,7 +544,6 @@ class CloudWatchBackend(BaseBackend): unit=None, ): period_delta = timedelta(seconds=period) - # TODO: Also filter by unit and dimensions filtered_data = [ md for md in self.get_all_metrics() diff --git a/moto/cognitoidentity/responses.py b/moto/cognitoidentity/responses.py index 51b16bfec..6ac72ee8f 100644 --- a/moto/cognitoidentity/responses.py +++ b/moto/cognitoidentity/responses.py @@ -4,6 +4,10 @@ from .utils import get_random_identity_id class CognitoIdentityResponse(BaseResponse): + @property + def backend(self): + return cognitoidentity_backends[self.region] + def create_identity_pool(self): identity_pool_name = self._get_param("IdentityPoolName") allow_unauthenticated_identities = self._get_param( @@ -16,7 +20,7 @@ class CognitoIdentityResponse(BaseResponse): saml_provider_arns = self._get_param("SamlProviderARNs") pool_tags = self._get_param("IdentityPoolTags") - return cognitoidentity_backends[self.region].create_identity_pool( + return self.backend.create_identity_pool( identity_pool_name=identity_pool_name, allow_unauthenticated_identities=allow_unauthenticated_identities, supported_login_providers=supported_login_providers, @@ -38,7 +42,7 @@ class CognitoIdentityResponse(BaseResponse): saml_providers = self._get_param("SamlProviderARNs") pool_tags = self._get_param("IdentityPoolTags") - return cognitoidentity_backends[self.region].update_identity_pool( + return self.backend.update_identity_pool( identity_pool_id=pool_id, identity_pool_name=pool_name, allow_unauthenticated=allow_unauthenticated, @@ -51,19 +55,13 @@ class CognitoIdentityResponse(BaseResponse): ) def get_id(self): - return cognitoidentity_backends[self.region].get_id( - identity_pool_id=self._get_param("IdentityPoolId") - ) + return self.backend.get_id(identity_pool_id=self._get_param("IdentityPoolId")) def describe_identity_pool(self): - return cognitoidentity_backends[self.region].describe_identity_pool( - self._get_param("IdentityPoolId") - ) + return self.backend.describe_identity_pool(self._get_param("IdentityPoolId")) def get_credentials_for_identity(self): - return cognitoidentity_backends[self.region].get_credentials_for_identity( - self._get_param("IdentityId") - ) + return self.backend.get_credentials_for_identity(self._get_param("IdentityId")) def get_open_id_token_for_developer_identity(self): return cognitoidentity_backends[ @@ -73,11 +71,11 @@ class CognitoIdentityResponse(BaseResponse): ) def get_open_id_token(self): - return cognitoidentity_backends[self.region].get_open_id_token( + return self.backend.get_open_id_token( self._get_param("IdentityId") or get_random_identity_id(self.region) ) def list_identities(self): - return cognitoidentity_backends[self.region].list_identities( + return self.backend.list_identities( self._get_param("IdentityPoolId") or get_random_identity_id(self.region) ) diff --git a/moto/cognitoidp/responses.py b/moto/cognitoidp/responses.py index d68e67c1c..4ce243c37 100644 --- a/moto/cognitoidp/responses.py +++ b/moto/cognitoidp/responses.py @@ -20,12 +20,14 @@ class CognitoIdpResponse(BaseResponse): def parameters(self): return json.loads(self.body) + @property + def backend(self): + return cognitoidp_backends[self.region] + # User pool def create_user_pool(self): name = self.parameters.pop("PoolName") - user_pool = cognitoidp_backends[self.region].create_user_pool( - name, self.parameters - ) + user_pool = self.backend.create_user_pool(name, self.parameters) return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def set_user_pool_mfa_config(self): @@ -50,22 +52,20 @@ class CognitoIdpResponse(BaseResponse): "[SmsConfiguration] is a required member of [SoftwareTokenMfaConfiguration]." ) - response = cognitoidp_backends[self.region].set_user_pool_mfa_config( + response = self.backend.set_user_pool_mfa_config( user_pool_id, sms_config, token_config, mfa_config ) return json.dumps(response) def get_user_pool_mfa_config(self): user_pool_id = self._get_param("UserPoolId") - response = cognitoidp_backends[self.region].get_user_pool_mfa_config( - user_pool_id - ) + response = self.backend.get_user_pool_mfa_config(user_pool_id) return json.dumps(response) def list_user_pools(self): max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") - user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( + user_pools, next_token = self.backend.list_user_pools( max_results=max_results, next_token=next_token ) response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]} @@ -75,16 +75,16 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool(self): user_pool_id = self._get_param("UserPoolId") - user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) + user_pool = self.backend.describe_user_pool(user_pool_id) return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def update_user_pool(self): user_pool_id = self._get_param("UserPoolId") - cognitoidp_backends[self.region].update_user_pool(user_pool_id, self.parameters) + self.backend.update_user_pool(user_pool_id, self.parameters) def delete_user_pool(self): user_pool_id = self._get_param("UserPoolId") - cognitoidp_backends[self.region].delete_user_pool(user_pool_id) + self.backend.delete_user_pool(user_pool_id) return "" # User pool domain @@ -92,7 +92,7 @@ class CognitoIdpResponse(BaseResponse): domain = self._get_param("Domain") user_pool_id = self._get_param("UserPoolId") custom_domain_config = self._get_param("CustomDomainConfig") - user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain( + user_pool_domain = self.backend.create_user_pool_domain( user_pool_id, domain, custom_domain_config ) domain_description = user_pool_domain.to_json(extended=False) @@ -102,9 +102,7 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_domain(self): domain = self._get_param("Domain") - user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain( - domain - ) + user_pool_domain = self.backend.describe_user_pool_domain(domain) domain_description = {} if user_pool_domain: domain_description = user_pool_domain.to_json() @@ -113,13 +111,13 @@ class CognitoIdpResponse(BaseResponse): def delete_user_pool_domain(self): domain = self._get_param("Domain") - cognitoidp_backends[self.region].delete_user_pool_domain(domain) + self.backend.delete_user_pool_domain(domain) return "" def update_user_pool_domain(self): domain = self._get_param("Domain") custom_domain_config = self._get_param("CustomDomainConfig") - user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain( + user_pool_domain = self.backend.update_user_pool_domain( domain, custom_domain_config ) domain_description = user_pool_domain.to_json(extended=False) @@ -131,7 +129,7 @@ class CognitoIdpResponse(BaseResponse): def create_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") generate_secret = self.parameters.pop("GenerateSecret", False) - user_pool_client = cognitoidp_backends[self.region].create_user_pool_client( + user_pool_client = self.backend.create_user_pool_client( user_pool_id, generate_secret, self.parameters ) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) @@ -157,7 +155,7 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client( + user_pool_client = self.backend.describe_user_pool_client( user_pool_id, client_id ) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) @@ -165,7 +163,7 @@ class CognitoIdpResponse(BaseResponse): def update_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") client_id = self.parameters.pop("ClientId") - user_pool_client = cognitoidp_backends[self.region].update_user_pool_client( + user_pool_client = self.backend.update_user_pool_client( user_pool_id, client_id, self.parameters ) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) @@ -173,16 +171,14 @@ class CognitoIdpResponse(BaseResponse): def delete_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - cognitoidp_backends[self.region].delete_user_pool_client( - user_pool_id, client_id - ) + self.backend.delete_user_pool_client(user_pool_id, client_id) return "" # Identity provider def create_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self.parameters.pop("ProviderName") - identity_provider = cognitoidp_backends[self.region].create_identity_provider( + identity_provider = self.backend.create_identity_provider( user_pool_id, name, self.parameters ) return json.dumps( @@ -210,9 +206,7 @@ class CognitoIdpResponse(BaseResponse): def describe_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].describe_identity_provider( - user_pool_id, name - ) + identity_provider = self.backend.describe_identity_provider(user_pool_id, name) return json.dumps( {"IdentityProvider": identity_provider.to_json(extended=True)} ) @@ -220,7 +214,7 @@ class CognitoIdpResponse(BaseResponse): def update_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].update_identity_provider( + identity_provider = self.backend.update_identity_provider( user_pool_id, name, self.parameters ) return json.dumps( @@ -230,7 +224,7 @@ class CognitoIdpResponse(BaseResponse): def delete_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - cognitoidp_backends[self.region].delete_identity_provider(user_pool_id, name) + self.backend.delete_identity_provider(user_pool_id, name) return "" # Group @@ -241,7 +235,7 @@ class CognitoIdpResponse(BaseResponse): role_arn = self._get_param("RoleArn") precedence = self._get_param("Precedence") - group = cognitoidp_backends[self.region].create_group( + group = self.backend.create_group( user_pool_id, group_name, description, role_arn, precedence ) @@ -250,18 +244,18 @@ class CognitoIdpResponse(BaseResponse): def get_group(self): group_name = self._get_param("GroupName") user_pool_id = self._get_param("UserPoolId") - group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) + group = self.backend.get_group(user_pool_id, group_name) return json.dumps({"Group": group.to_json()}) def list_groups(self): user_pool_id = self._get_param("UserPoolId") - groups = cognitoidp_backends[self.region].list_groups(user_pool_id) + groups = self.backend.list_groups(user_pool_id) return json.dumps({"Groups": [group.to_json() for group in groups]}) def delete_group(self): group_name = self._get_param("GroupName") user_pool_id = self._get_param("UserPoolId") - cognitoidp_backends[self.region].delete_group(user_pool_id, group_name) + self.backend.delete_group(user_pool_id, group_name) return "" def update_group(self): @@ -271,7 +265,7 @@ class CognitoIdpResponse(BaseResponse): role_arn = self._get_param("RoleArn") precedence = self._get_param("Precedence") - group = cognitoidp_backends[self.region].update_group( + group = self.backend.update_group( user_pool_id, group_name, description, role_arn, precedence ) @@ -282,26 +276,20 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") group_name = self._get_param("GroupName") - cognitoidp_backends[self.region].admin_add_user_to_group( - user_pool_id, group_name, username - ) + self.backend.admin_add_user_to_group(user_pool_id, group_name, username) return "" def list_users_in_group(self): user_pool_id = self._get_param("UserPoolId") group_name = self._get_param("GroupName") - users = cognitoidp_backends[self.region].list_users_in_group( - user_pool_id, group_name - ) + users = self.backend.list_users_in_group(user_pool_id, group_name) return json.dumps({"Users": [user.to_json(extended=True) for user in users]}) def admin_list_groups_for_user(self): username = self._get_param("Username") user_pool_id = self._get_param("UserPoolId") - groups = cognitoidp_backends[self.region].admin_list_groups_for_user( - user_pool_id, username - ) + groups = self.backend.admin_list_groups_for_user(user_pool_id, username) return json.dumps({"Groups": [group.to_json() for group in groups]}) def admin_remove_user_from_group(self): @@ -309,18 +297,14 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") group_name = self._get_param("GroupName") - cognitoidp_backends[self.region].admin_remove_user_from_group( - user_pool_id, group_name, username - ) + self.backend.admin_remove_user_from_group(user_pool_id, group_name, username) return "" def admin_reset_user_password(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - cognitoidp_backends[self.region].admin_reset_user_password( - user_pool_id, username - ) + self.backend.admin_reset_user_password(user_pool_id, username) return "" # User @@ -329,7 +313,7 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") message_action = self._get_param("MessageAction") temporary_password = self._get_param("TemporaryPassword") - user = cognitoidp_backends[self.region].admin_create_user( + user = self.backend.admin_create_user( user_pool_id, username, message_action, @@ -342,14 +326,12 @@ class CognitoIdpResponse(BaseResponse): def admin_confirm_sign_up(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - return cognitoidp_backends[self.region].admin_confirm_sign_up( - user_pool_id, username - ) + return self.backend.admin_confirm_sign_up(user_pool_id, username) def admin_get_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) + user = self.backend.admin_get_user(user_pool_id, username) return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes")) def get_user(self): @@ -363,7 +345,7 @@ class CognitoIdpResponse(BaseResponse): token = self._get_param("PaginationToken") filt = self._get_param("Filter") attributes_to_get = self._get_param("AttributesToGet") - users, token = cognitoidp_backends[self.region].list_users( + users, token = self.backend.list_users( user_pool_id, limit=limit, pagination_token=token ) if filt: @@ -420,19 +402,19 @@ class CognitoIdpResponse(BaseResponse): def admin_disable_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - cognitoidp_backends[self.region].admin_disable_user(user_pool_id, username) + self.backend.admin_disable_user(user_pool_id, username) return "" def admin_enable_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - cognitoidp_backends[self.region].admin_enable_user(user_pool_id, username) + self.backend.admin_enable_user(user_pool_id, username) return "" def admin_delete_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - cognitoidp_backends[self.region].admin_delete_user(user_pool_id, username) + self.backend.admin_delete_user(user_pool_id, username) return "" def admin_initiate_auth(self): @@ -441,7 +423,7 @@ class CognitoIdpResponse(BaseResponse): auth_flow = self._get_param("AuthFlow") auth_parameters = self._get_param("AuthParameters") - auth_result = cognitoidp_backends[self.region].admin_initiate_auth( + auth_result = self.backend.admin_initiate_auth( user_pool_id, client_id, auth_flow, auth_parameters ) @@ -501,31 +483,25 @@ class CognitoIdpResponse(BaseResponse): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") attributes = self._get_param("UserAttributes") - cognitoidp_backends[self.region].admin_update_user_attributes( - user_pool_id, username, attributes - ) + self.backend.admin_update_user_attributes(user_pool_id, username, attributes) return "" def admin_delete_user_attributes(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") attributes = self._get_param("UserAttributeNames") - cognitoidp_backends[self.region].admin_delete_user_attributes( - user_pool_id, username, attributes - ) + self.backend.admin_delete_user_attributes(user_pool_id, username, attributes) return "" def admin_user_global_sign_out(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") - cognitoidp_backends[self.region].admin_user_global_sign_out( - user_pool_id, username - ) + self.backend.admin_user_global_sign_out(user_pool_id, username) return "" def global_sign_out(self): access_token = self._get_param("AccessToken") - cognitoidp_backends[self.region].global_sign_out(access_token) + self.backend.global_sign_out(access_token) return "" # Resource Server @@ -534,7 +510,7 @@ class CognitoIdpResponse(BaseResponse): identifier = self._get_param("Identifier") name = self._get_param("Name") scopes = self._get_param("Scopes") - resource_server = cognitoidp_backends[self.region].create_resource_server( + resource_server = self.backend.create_resource_server( user_pool_id, identifier, name, scopes ) return json.dumps({"ResourceServer": resource_server.to_json()}) @@ -575,19 +551,19 @@ class CognitoIdpResponse(BaseResponse): def associate_software_token(self): access_token = self._get_param("AccessToken") - result = cognitoidp_backends[self.region].associate_software_token(access_token) + result = self.backend.associate_software_token(access_token) return json.dumps(result) def verify_software_token(self): access_token = self._get_param("AccessToken") - result = cognitoidp_backends[self.region].verify_software_token(access_token) + result = self.backend.verify_software_token(access_token) return json.dumps(result) def set_user_mfa_preference(self): access_token = self._get_param("AccessToken") software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings") sms_mfa_settings = self._get_param("SMSMfaSettings") - cognitoidp_backends[self.region].set_user_mfa_preference( + self.backend.set_user_mfa_preference( access_token, software_token_mfa_settings, sms_mfa_settings ) return "" @@ -597,7 +573,7 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings") sms_mfa_settings = self._get_param("SMSMfaSettings") - cognitoidp_backends[self.region].admin_set_user_mfa_preference( + self.backend.admin_set_user_mfa_preference( user_pool_id, username, software_token_mfa_settings, sms_mfa_settings ) return "" @@ -607,7 +583,7 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") password = self._get_param("Password") permanent = self._get_param("Permanent") - cognitoidp_backends[self.region].admin_set_user_password( + self.backend.admin_set_user_password( user_pool_id, username, password, permanent ) return "" @@ -615,17 +591,13 @@ class CognitoIdpResponse(BaseResponse): def add_custom_attributes(self): user_pool_id = self._get_param("UserPoolId") custom_attributes = self._get_param("CustomAttributes") - cognitoidp_backends[self.region].add_custom_attributes( - user_pool_id, custom_attributes - ) + self.backend.add_custom_attributes(user_pool_id, custom_attributes) return "" def update_user_attributes(self): access_token = self._get_param("AccessToken") attributes = self._get_param("UserAttributes") - cognitoidp_backends[self.region].update_user_attributes( - access_token, attributes - ) + self.backend.update_user_attributes(access_token, attributes) return json.dumps({}) diff --git a/moto/config/models.py b/moto/config/models.py index b82ebbfeb..50c3e3a93 100644 --- a/moto/config/models.py +++ b/moto/config/models.py @@ -1468,14 +1468,11 @@ class ConfigBackend(BaseBackend): backend_query_region = ( backend_region # Always provide the backend this request arrived from. ) - print(RESOURCE_MAP[resource_type].backends) if RESOURCE_MAP[resource_type].backends.get("global"): - print("yes, its global") backend_region = "global" # If the backend region isn't implemented then we won't find the item: if not RESOURCE_MAP[resource_type].backends.get(backend_region): - print(f"cant find {backend_region} for {resource_type}") raise ResourceNotDiscoveredException(resource_type, resource_id) # Get the item: @@ -1483,7 +1480,6 @@ class ConfigBackend(BaseBackend): resource_id, backend_region=backend_query_region ) if not item: - print("item not found") raise ResourceNotDiscoveredException(resource_type, resource_id) item["accountId"] = get_account_id() diff --git a/moto/core/responses.py b/moto/core/responses.py index 57d3172eb..da3eefe74 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -204,7 +204,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def dispatch(cls, *args, **kwargs): return cls()._dispatch(*args, **kwargs) - def setup_class(self, request, full_url, headers): + def setup_class(self, request, full_url, headers, use_raw_body=False): + """ + use_raw_body: Use incoming bytes if True, encode to string otherwise + """ querystring = OrderedDict() if hasattr(request, "body"): # Boto @@ -222,7 +225,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): querystring[key] = [value] raw_body = self.body - if isinstance(self.body, bytes): + if isinstance(self.body, bytes) and not use_raw_body: self.body = self.body.decode("utf-8") if not querystring: @@ -244,7 +247,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): flat = flatten_json_request_body("", decoded, input_spec) for key, value in flat.items(): querystring[key] = [value] - elif self.body: + elif self.body and not use_raw_body: try: querystring.update( OrderedDict( @@ -254,7 +257,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) ) ) - except (UnicodeEncodeError, UnicodeDecodeError): + except (UnicodeEncodeError, UnicodeDecodeError, AttributeError): pass # ignore encoding errors, as the body may not contain a legitimate querystring if not querystring: querystring.update(headers) diff --git a/moto/core/utils.py b/moto/core/utils.py index c1c3c8ec3..e51e3119a 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -412,6 +412,54 @@ def extract_region_from_aws_authorization(string): backend_lock = RLock() +class AccountSpecificBackend(dict): + """ + Dictionary storing the data for a service in a specific account. + Data access pattern: + account_specific_backend[region: str] = backend: BaseBackend + """ + + def __init__( + self, service_name, account_id, backend, use_boto3_regions, additional_regions + ): + self.service_name = service_name + self.account_id = account_id + self.backend = backend + self.regions = [] + if use_boto3_regions: + sess = Session() + self.regions.extend(sess.get_available_regions(service_name)) + self.regions.extend( + sess.get_available_regions(service_name, partition_name="aws-us-gov") + ) + self.regions.extend( + sess.get_available_regions(service_name, partition_name="aws-cn") + ) + self.regions.extend(additional_regions or []) + + def reset(self): + for region_specific_backend in self.values(): + region_specific_backend.reset() + + def __contains__(self, region): + return region in self.regions or region in self.keys() + + def __getitem__(self, region_name): + if region_name in self.keys(): + return super().__getitem__(region_name) + # Create the backend for a specific region + with backend_lock: + if region_name in self.regions and region_name not in self.keys(): + super().__setitem__( + region_name, self.backend(region_name, account_id=self.account_id) + ) + if region_name not in self.regions and allow_unknown_region(): + super().__setitem__( + region_name, self.backend(region_name, account_id=self.account_id) + ) + return super().__getitem__(region_name) + + class BackendDict(dict): """ Data Structure to store everything related to a specific service. @@ -484,51 +532,3 @@ class BackendDict(dict): use_boto3_regions=self._use_boto3_regions, additional_regions=self._additional_regions, ) - - -class AccountSpecificBackend(dict): - """ - Dictionary storing the data for a service in a specific account. - Data access pattern: - account_specific_backend[region: str] = backend: BaseBackend - """ - - def __init__( - self, service_name, account_id, backend, use_boto3_regions, additional_regions - ): - self.service_name = service_name - self.account_id = account_id - self.backend = backend - self.regions = [] - if use_boto3_regions: - sess = Session() - self.regions.extend(sess.get_available_regions(service_name)) - self.regions.extend( - sess.get_available_regions(service_name, partition_name="aws-us-gov") - ) - self.regions.extend( - sess.get_available_regions(service_name, partition_name="aws-cn") - ) - self.regions.extend(additional_regions or []) - - def reset(self): - for region_specific_backend in self.values(): - region_specific_backend.reset() - - def __contains__(self, region): - return region in self.regions or region in self.keys() - - def __getitem__(self, region_name): - if region_name in self.keys(): - return super().__getitem__(region_name) - # Create the backend for a specific region - with backend_lock: - if region_name in self.regions and region_name not in self.keys(): - super().__setitem__( - region_name, self.backend(region_name, account_id=self.account_id) - ) - if region_name not in self.regions and allow_unknown_region(): - super().__setitem__( - region_name, self.backend(region_name, account_id=self.account_id) - ) - return super().__getitem__(region_name) diff --git a/moto/datapipeline/responses.py b/moto/datapipeline/responses.py index cba52609b..60cc294b8 100644 --- a/moto/datapipeline/responses.py +++ b/moto/datapipeline/responses.py @@ -5,23 +5,15 @@ from .models import datapipeline_backends class DataPipelineResponse(BaseResponse): - @property - def parameters(self): - # TODO this should really be moved to core/responses.py - if self.body: - return json.loads(self.body) - else: - return self.querystring - @property def datapipeline_backend(self): return datapipeline_backends[self.region] def create_pipeline(self): - name = self.parameters.get("name") - unique_id = self.parameters.get("uniqueId") - description = self.parameters.get("description", "") - tags = self.parameters.get("tags", []) + name = self._get_param("name") + unique_id = self._get_param("uniqueId") + description = self._get_param("description", "") + tags = self._get_param("tags", []) pipeline = self.datapipeline_backend.create_pipeline( name, unique_id, description=description, tags=tags ) @@ -31,7 +23,7 @@ class DataPipelineResponse(BaseResponse): pipelines = list(self.datapipeline_backend.list_pipelines()) pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] max_pipelines = 50 - marker = self.parameters.get("marker") + marker = self._get_param("marker") if marker: start = pipeline_ids.index(marker) + 1 else: @@ -53,7 +45,7 @@ class DataPipelineResponse(BaseResponse): ) def describe_pipelines(self): - pipeline_ids = self.parameters["pipelineIds"] + pipeline_ids = self._get_param("pipelineIds") pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) return json.dumps( @@ -61,19 +53,19 @@ class DataPipelineResponse(BaseResponse): ) def delete_pipeline(self): - pipeline_id = self.parameters["pipelineId"] + pipeline_id = self._get_param("pipelineId") self.datapipeline_backend.delete_pipeline(pipeline_id) return json.dumps({}) def put_pipeline_definition(self): - pipeline_id = self.parameters["pipelineId"] - pipeline_objects = self.parameters["pipelineObjects"] + pipeline_id = self._get_param("pipelineId") + pipeline_objects = self._get_param("pipelineObjects") self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects) return json.dumps({"errored": False}) def get_pipeline_definition(self): - pipeline_id = self.parameters["pipelineId"] + pipeline_id = self._get_param("pipelineId") pipeline_definition = self.datapipeline_backend.get_pipeline_definition( pipeline_id ) @@ -86,8 +78,8 @@ class DataPipelineResponse(BaseResponse): ) def describe_objects(self): - pipeline_id = self.parameters["pipelineId"] - object_ids = self.parameters["objectIds"] + pipeline_id = self._get_param("pipelineId") + object_ids = self._get_param("objectIds") pipeline_objects = self.datapipeline_backend.describe_objects( object_ids, pipeline_id @@ -103,6 +95,6 @@ class DataPipelineResponse(BaseResponse): ) def activate_pipeline(self): - pipeline_id = self.parameters["pipelineId"] + pipeline_id = self._get_param("pipelineId") self.datapipeline_backend.activate_pipeline(pipeline_id) return json.dumps({}) diff --git a/moto/dynamodb_v20111205/models.py b/moto/dynamodb_v20111205/models.py index dc3e3af29..a2a811a72 100644 --- a/moto/dynamodb_v20111205/models.py +++ b/moto/dynamodb_v20111205/models.py @@ -397,4 +397,3 @@ dynamodb_backends = BackendDict( use_boto3_regions=False, additional_regions=["global"], ) -dynamodb_backend = dynamodb_backends["global"] diff --git a/moto/dynamodb_v20111205/responses.py b/moto/dynamodb_v20111205/responses.py index 4584ae29a..1e2cfdc4d 100644 --- a/moto/dynamodb_v20111205/responses.py +++ b/moto/dynamodb_v20111205/responses.py @@ -2,7 +2,7 @@ import json from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from .models import dynamodb_backend, dynamo_json_dump +from .models import dynamodb_backends, dynamo_json_dump class DynamoHandler(BaseResponse): @@ -36,15 +36,19 @@ class DynamoHandler(BaseResponse): else: return 404, self.response_headers, "" + @property + def backend(self): + return dynamodb_backends["global"] + def list_tables(self): body = self.body limit = body.get("Limit") if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") - start = list(dynamodb_backend.tables.keys()).index(last) + 1 + start = list(self.backend.tables.keys()).index(last) + 1 else: start = 0 - all_tables = list(dynamodb_backend.tables.keys()) + all_tables = list(self.backend.tables.keys()) if limit: tables = all_tables[start : start + limit] else: @@ -71,7 +75,7 @@ class DynamoHandler(BaseResponse): read_units = throughput["ReadCapacityUnits"] write_units = throughput["WriteCapacityUnits"] - table = dynamodb_backend.create_table( + table = self.backend.create_table( name, hash_key_attr=hash_key_attr, hash_key_type=hash_key_type, @@ -84,7 +88,7 @@ class DynamoHandler(BaseResponse): def delete_table(self): name = self.body["TableName"] - table = dynamodb_backend.delete_table(name) + table = self.backend.delete_table(name) if table: return dynamo_json_dump(table.describe) else: @@ -96,7 +100,7 @@ class DynamoHandler(BaseResponse): throughput = self.body["ProvisionedThroughput"] new_read_units = throughput["ReadCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"] - table = dynamodb_backend.update_table_throughput( + table = self.backend.update_table_throughput( name, new_read_units, new_write_units ) return dynamo_json_dump(table.describe) @@ -104,7 +108,7 @@ class DynamoHandler(BaseResponse): def describe_table(self): name = self.body["TableName"] try: - table = dynamodb_backend.tables[name] + table = self.backend.tables[name] except KeyError: er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) @@ -113,7 +117,7 @@ class DynamoHandler(BaseResponse): def put_item(self): name = self.body["TableName"] item = self.body["Item"] - result = dynamodb_backend.put_item(name, item) + result = self.backend.put_item(name, item) if result: item_dict = result.to_json() item_dict["ConsumedCapacityUnits"] = 1 @@ -132,12 +136,12 @@ class DynamoHandler(BaseResponse): if request_type == "PutRequest": item = request["Item"] - dynamodb_backend.put_item(table_name, item) + self.backend.put_item(table_name, item) elif request_type == "DeleteRequest": key = request["Key"] hash_key = key["HashKeyElement"] range_key = key.get("RangeKeyElement") - item = dynamodb_backend.delete_item(table_name, hash_key, range_key) + self.backend.delete_item(table_name, hash_key, range_key) response = { "Responses": { @@ -156,7 +160,7 @@ class DynamoHandler(BaseResponse): range_key = key.get("RangeKeyElement") attrs_to_get = self.body.get("AttributesToGet") try: - item = dynamodb_backend.get_item(name, hash_key, range_key) + item = self.backend.get_item(name, hash_key, range_key) except ValueError: er = "com.amazon.coral.validate#ValidationException" return self.error(er, status=400) @@ -181,7 +185,7 @@ class DynamoHandler(BaseResponse): for key in keys: hash_key = key["HashKeyElement"] range_key = key.get("RangeKeyElement") - item = dynamodb_backend.get_item(table_name, hash_key, range_key) + item = self.backend.get_item(table_name, hash_key, range_key) if item: item_describe = item.describe_attrs(attributes_to_get) items.append(item_describe) @@ -202,9 +206,7 @@ class DynamoHandler(BaseResponse): range_comparison = None range_values = [] - items, _ = dynamodb_backend.query( - name, hash_key, range_comparison, range_values - ) + items, _ = self.backend.query(name, hash_key, range_comparison, range_values) if items is None: er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" @@ -236,7 +238,7 @@ class DynamoHandler(BaseResponse): comparison_values = scan_filter.get("AttributeValueList", []) filters[attribute_name] = (comparison_operator, comparison_values) - items, scanned_count, _ = dynamodb_backend.scan(name, filters) + items, scanned_count, _ = self.backend.scan(name, filters) if items is None: er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" @@ -263,7 +265,7 @@ class DynamoHandler(BaseResponse): hash_key = key["HashKeyElement"] range_key = key.get("RangeKeyElement") return_values = self.body.get("ReturnValues", "") - item = dynamodb_backend.delete_item(name, hash_key, range_key) + item = self.backend.delete_item(name, hash_key, range_key) if item: if return_values == "ALL_OLD": item_dict = item.to_json() @@ -282,7 +284,7 @@ class DynamoHandler(BaseResponse): range_key = key.get("RangeKeyElement") updates = self.body["AttributeUpdates"] - item = dynamodb_backend.update_item(name, hash_key, range_key, updates) + item = self.backend.update_item(name, hash_key, range_key, updates) if item: item_dict = item.to_json() diff --git a/moto/ec2/models/flow_logs.py b/moto/ec2/models/flow_logs.py index 3a2044a71..38491780f 100644 --- a/moto/ec2/models/flow_logs.py +++ b/moto/ec2/models/flow_logs.py @@ -214,12 +214,12 @@ class FlowLogsBackend: self.get_network_interface(resource_id) if log_destination_type == "s3": - from moto.s3.models import s3_backend + from moto.s3.models import s3_backends from moto.s3.exceptions import MissingBucket arn = log_destination.split(":", 5)[5] try: - s3_backend.get_bucket(arn) + s3_backends["global"].get_bucket(arn) except MissingBucket: # Instead of creating FlowLog report # the unsuccessful status for the diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 69eba6f0e..7434e9887 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -1547,9 +1547,9 @@ Member must satisfy regular expression pattern: {}".format( except AWSResourceNotFoundException: pass - from moto.iam import iam_backend + from moto.iam import iam_backends - cert = iam_backend.get_certificate_by_arn(certificate_arn) + cert = iam_backends["global"].get_certificate_by_arn(certificate_arn) if cert is not None: return True diff --git a/moto/firehose/models.py b/moto/firehose/models.py index 86bfff24f..a8d43f3ac 100644 --- a/moto/firehose/models.py +++ b/moto/firehose/models.py @@ -37,7 +37,7 @@ from moto.firehose.exceptions import ( ResourceNotFoundException, ValidationException, ) -from moto.s3.models import s3_backend +from moto.s3.models import s3_backends from moto.utilities.tagging_service import TaggingService MAX_TAGS_PER_DELIVERY_STREAM = 50 @@ -447,7 +447,7 @@ class FirehoseBackend(BaseBackend): batched_data = b"".join([b64decode(r["Data"]) for r in records]) try: - s3_backend.put_object(bucket_name, object_path, batched_data) + s3_backends["global"].put_object(bucket_name, object_path, batched_data) except Exception as exc: # This could be better ... raise RuntimeError( diff --git a/moto/iam/__init__.py b/moto/iam/__init__.py index 7131d1370..401c426b3 100644 --- a/moto/iam/__init__.py +++ b/moto/iam/__init__.py @@ -1,5 +1,4 @@ -from .models import iam_backend +from .models import iam_backends from ..core.models import base_decorator -iam_backends = {"global": iam_backend} mock_iam = base_decorator(iam_backends) diff --git a/moto/iam/access_control.py b/moto/iam/access_control.py index 58d87303c..5c2d24bdf 100644 --- a/moto/iam/access_control.py +++ b/moto/iam/access_control.py @@ -39,8 +39,8 @@ from moto.s3.exceptions import ( BucketSignatureDoesNotMatchError, S3SignatureDoesNotMatchError, ) -from moto.sts.models import sts_backend -from .models import iam_backend, Policy +from moto.sts.models import sts_backends +from .models import iam_backends, Policy log = logging.getLogger(__name__) @@ -53,8 +53,12 @@ def create_access_key(access_key_id, headers): class IAMUserAccessKey(object): + @property + def backend(self): + return iam_backends["global"] + def __init__(self, access_key_id, headers): - iam_users = iam_backend.list_users("/", None, None) + iam_users = self.backend.list_users("/", None, None) for iam_user in iam_users: for access_key in iam_user.access_keys: if access_key.access_key_id == access_key_id: @@ -78,28 +82,30 @@ class IAMUserAccessKey(object): def collect_policies(self): user_policies = [] - inline_policy_names = iam_backend.list_user_policies(self._owner_user_name) + inline_policy_names = self.backend.list_user_policies(self._owner_user_name) for inline_policy_name in inline_policy_names: - inline_policy = iam_backend.get_user_policy( + inline_policy = self.backend.get_user_policy( self._owner_user_name, inline_policy_name ) user_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_user_policies( + attached_policies, _ = self.backend.list_attached_user_policies( self._owner_user_name ) user_policies += attached_policies - user_groups = iam_backend.get_groups_for_user(self._owner_user_name) + user_groups = self.backend.get_groups_for_user(self._owner_user_name) for user_group in user_groups: - inline_group_policy_names = iam_backend.list_group_policies(user_group.name) + inline_group_policy_names = self.backend.list_group_policies( + user_group.name + ) for inline_group_policy_name in inline_group_policy_names: - inline_user_group_policy = iam_backend.get_group_policy( + inline_user_group_policy = self.backend.get_group_policy( user_group.name, inline_group_policy_name ) user_policies.append(inline_user_group_policy) - attached_group_policies, _ = iam_backend.list_attached_group_policies( + attached_group_policies, _ = self.backend.list_attached_group_policies( user_group.name ) user_policies += attached_group_policies @@ -108,8 +114,12 @@ class IAMUserAccessKey(object): class AssumedRoleAccessKey(object): + @property + def backend(self): + return iam_backends["global"] + def __init__(self, access_key_id, headers): - for assumed_role in sts_backend.assumed_roles: + for assumed_role in sts_backends["global"].assumed_roles: if assumed_role.access_key_id == access_key_id: self._access_key_id = access_key_id self._secret_access_key = assumed_role.secret_access_key @@ -139,14 +149,14 @@ class AssumedRoleAccessKey(object): def collect_policies(self): role_policies = [] - inline_policy_names = iam_backend.list_role_policies(self._owner_role_name) + inline_policy_names = self.backend.list_role_policies(self._owner_role_name) for inline_policy_name in inline_policy_names: - _, inline_policy = iam_backend.get_role_policy( + _, inline_policy = self.backend.get_role_policy( self._owner_role_name, inline_policy_name ) role_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_role_policies( + attached_policies, _ = self.backend.list_attached_role_policies( self._owner_role_name ) role_policies += attached_policies diff --git a/moto/iam/models.py b/moto/iam/models.py index 5f12238e3..bc71df6d9 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -19,6 +19,7 @@ from moto.core import BaseBackend, BaseModel, get_account_id, CloudFormationMode from moto.core.utils import ( iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds, + BackendDict, ) from moto.iam.policy_validation import IAMPolicyDocumentValidator from moto.utilities.utils import md5_hash @@ -362,7 +363,7 @@ class ManagedPolicy(Policy, CloudFormationModel): role_names = properties.get("Roles", []) tags = properties.get("Tags", {}) - policy = iam_backend.create_policy( + policy = iam_backends["global"].create_policy( description=description, path=path, policy_document=policy_document, @@ -370,13 +371,17 @@ class ManagedPolicy(Policy, CloudFormationModel): tags=tags, ) for group_name in group_names: - iam_backend.attach_group_policy( + iam_backends["global"].attach_group_policy( group_name=group_name, policy_arn=policy.arn ) for user_name in user_names: - iam_backend.attach_user_policy(user_name=user_name, policy_arn=policy.arn) + iam_backends["global"].attach_user_policy( + user_name=user_name, policy_arn=policy.arn + ) for role_name in role_names: - iam_backend.attach_role_policy(role_name=role_name, policy_arn=policy.arn) + iam_backends["global"].attach_role_policy( + role_name=role_name, policy_arn=policy.arn + ) return policy @property @@ -466,7 +471,7 @@ class InlinePolicy(CloudFormationModel): role_names = properties.get("Roles") group_names = properties.get("Groups") - return iam_backend.create_inline_policy( + return iam_backends["global"].create_inline_policy( resource_name, policy_name, policy_document, @@ -502,7 +507,7 @@ class InlinePolicy(CloudFormationModel): role_names = properties.get("Roles") group_names = properties.get("Groups") - return iam_backend.update_inline_policy( + return iam_backends["global"].update_inline_policy( original_resource.name, policy_name, policy_document, @@ -515,7 +520,7 @@ class InlinePolicy(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - iam_backend.delete_inline_policy(resource_name) + iam_backends["global"].delete_inline_policy(resource_name) @staticmethod def is_replacement_update(properties): @@ -606,7 +611,7 @@ class Role(CloudFormationModel): properties = cloudformation_json["Properties"] role_name = properties.get("RoleName", resource_name) - role = iam_backend.create_role( + role = iam_backends["global"].create_role( role_name=role_name, assume_role_policy_document=properties["AssumeRolePolicyDocument"], path=properties.get("Path", "/"), @@ -628,14 +633,14 @@ class Role(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - for profile in iam_backend.instance_profiles.values(): + for profile in iam_backends["global"].instance_profiles.values(): profile.delete_role(role_name=resource_name) - for role in iam_backend.roles.values(): + for role in iam_backends["global"].roles.values(): if role.name == resource_name: for arn in role.policies.keys(): role.delete_policy(arn) - iam_backend.delete_role(resource_name) + iam_backends["global"].delete_role(resource_name) @property def arn(self): @@ -649,7 +654,10 @@ class Role(CloudFormationModel): _managed_policies = [] for key in self.managed_policies.keys(): _managed_policies.append( - {"policyArn": key, "policyName": iam_backend.managed_policies[key].name} + { + "policyArn": key, + "policyName": iam_backends["global"].managed_policies[key].name, + } ) _role_policy_list = [] @@ -659,7 +667,7 @@ class Role(CloudFormationModel): ) _instance_profiles = [] - for key, instance_profile in iam_backend.instance_profiles.items(): + for key, instance_profile in iam_backends["global"].instance_profiles.items(): for _ in instance_profile.roles: _instance_profiles.append(instance_profile.to_embedded_config_dict()) break @@ -808,7 +816,7 @@ class InstanceProfile(CloudFormationModel): properties = cloudformation_json["Properties"] role_names = properties["Roles"] - return iam_backend.create_instance_profile( + return iam_backends["global"].create_instance_profile( name=resource_name, path=properties.get("Path", "/"), role_names=role_names, @@ -818,7 +826,7 @@ class InstanceProfile(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - iam_backend.delete_instance_profile(resource_name) + iam_backends["global"].delete_instance_profile(resource_name) def delete_role(self, role_name): self.roles = [role for role in self.roles if role.name != role_name] @@ -964,7 +972,7 @@ class AccessKey(CloudFormationModel): user_name = properties.get("UserName") status = properties.get("Status", "Active") - return iam_backend.create_access_key(user_name, status=status) + return iam_backends["global"].create_access_key(user_name, status=status) @classmethod def update_from_cloudformation_json( @@ -984,7 +992,7 @@ class AccessKey(CloudFormationModel): else: # No Interruption properties = cloudformation_json.get("Properties", {}) status = properties.get("Status") - return iam_backend.update_access_key( + return iam_backends["global"].update_access_key( original_resource.user_name, original_resource.access_key_id, status ) @@ -992,7 +1000,7 @@ class AccessKey(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - iam_backend.delete_access_key_by_name(resource_name) + iam_backends["global"].delete_access_key_by_name(resource_name) @staticmethod def is_replacement_update(properties): @@ -1303,7 +1311,7 @@ class User(CloudFormationModel): ): properties = cloudformation_json.get("Properties", {}) path = properties.get("Path") - user, _ = iam_backend.create_user(resource_name, path) + user, _ = iam_backends["global"].create_user(resource_name, path) return user @classmethod @@ -1334,7 +1342,7 @@ class User(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - iam_backend.delete_user(resource_name) + iam_backends["global"].delete_user(resource_name) @staticmethod def is_replacement_update(properties): @@ -2043,7 +2051,7 @@ class IAMBackend(BaseBackend): instance_profile_id = random_resource_id() - roles = [iam_backend.get_role(role_name) for role_name in role_names] + roles = [iam_backends["global"].get_role(role_name) for role_name in role_names] instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags) self.instance_profiles[name] = instance_profile return instance_profile @@ -2838,12 +2846,10 @@ class IAMBackend(BaseBackend): return inline_policy def get_inline_policy(self, policy_id): - inline_policy = None try: - inline_policy = self.inline_policies[policy_id] + return self.inline_policies[policy_id] except KeyError: raise IAMNotFoundException("Inline policy {0} not found".format(policy_id)) - return inline_policy def update_inline_policy( self, @@ -2924,4 +2930,6 @@ class IAMBackend(BaseBackend): return True -iam_backend = IAMBackend("global") +iam_backends = BackendDict( + IAMBackend, "iam", use_boto3_regions=False, additional_regions=["global"] +) diff --git a/moto/iam/responses.py b/moto/iam/responses.py index c203921c7..9a149b65f 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -1,48 +1,52 @@ from moto.core.responses import BaseResponse -from .models import iam_backend, User +from .models import iam_backends, User class IamResponse(BaseResponse): + @property + def backend(self): + return iam_backends["global"] + def attach_role_policy(self): policy_arn = self._get_param("PolicyArn") role_name = self._get_param("RoleName") - iam_backend.attach_role_policy(policy_arn, role_name) + self.backend.attach_role_policy(policy_arn, role_name) template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE) return template.render() def detach_role_policy(self): role_name = self._get_param("RoleName") policy_arn = self._get_param("PolicyArn") - iam_backend.detach_role_policy(policy_arn, role_name) + self.backend.detach_role_policy(policy_arn, role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DetachRolePolicy") def attach_group_policy(self): policy_arn = self._get_param("PolicyArn") group_name = self._get_param("GroupName") - iam_backend.attach_group_policy(policy_arn, group_name) + self.backend.attach_group_policy(policy_arn, group_name) template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE) return template.render() def detach_group_policy(self): policy_arn = self._get_param("PolicyArn") group_name = self._get_param("GroupName") - iam_backend.detach_group_policy(policy_arn, group_name) + self.backend.detach_group_policy(policy_arn, group_name) template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE) return template.render() def attach_user_policy(self): policy_arn = self._get_param("PolicyArn") user_name = self._get_param("UserName") - iam_backend.attach_user_policy(policy_arn, user_name) + self.backend.attach_user_policy(policy_arn, user_name) template = self.response_template(ATTACH_USER_POLICY_TEMPLATE) return template.render() def detach_user_policy(self): policy_arn = self._get_param("PolicyArn") user_name = self._get_param("UserName") - iam_backend.detach_user_policy(policy_arn, user_name) + self.backend.detach_user_policy(policy_arn, user_name) template = self.response_template(DETACH_USER_POLICY_TEMPLATE) return template.render() @@ -52,7 +56,7 @@ class IamResponse(BaseResponse): policy_document = self._get_param("PolicyDocument") policy_name = self._get_param("PolicyName") tags = self._get_multi_param("Tags.member") - policy = iam_backend.create_policy( + policy = self.backend.create_policy( description, path, policy_document, policy_name, tags ) template = self.response_template(CREATE_POLICY_TEMPLATE) @@ -60,7 +64,7 @@ class IamResponse(BaseResponse): def get_policy(self): policy_arn = self._get_param("PolicyArn") - policy = iam_backend.get_policy(policy_arn) + policy = self.backend.get_policy(policy_arn) template = self.response_template(GET_POLICY_TEMPLATE) return template.render(policy=policy) @@ -69,7 +73,7 @@ class IamResponse(BaseResponse): max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") role_name = self._get_param("RoleName") - policies, marker = iam_backend.list_attached_role_policies( + policies, marker = self.backend.list_attached_role_policies( role_name, marker=marker, max_items=max_items, path_prefix=path_prefix ) template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) @@ -80,7 +84,7 @@ class IamResponse(BaseResponse): max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") group_name = self._get_param("GroupName") - policies, marker = iam_backend.list_attached_group_policies( + policies, marker = self.backend.list_attached_group_policies( group_name, marker=marker, max_items=max_items, path_prefix=path_prefix ) template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE) @@ -91,7 +95,7 @@ class IamResponse(BaseResponse): max_items = self._get_int_param("MaxItems", 100) path_prefix = self._get_param("PathPrefix", "/") user_name = self._get_param("UserName") - policies, marker = iam_backend.list_attached_user_policies( + policies, marker = self.backend.list_attached_user_policies( user_name, marker=marker, max_items=max_items, path_prefix=path_prefix ) template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE) @@ -103,7 +107,7 @@ class IamResponse(BaseResponse): only_attached = self._get_bool_param("OnlyAttached", False) path_prefix = self._get_param("PathPrefix", "/") scope = self._get_param("Scope", "All") - policies, marker = iam_backend.list_policies( + policies, marker = self.backend.list_policies( marker, max_items, only_attached, path_prefix, scope ) template = self.response_template(LIST_POLICIES_TEMPLATE) @@ -124,7 +128,7 @@ class IamResponse(BaseResponse): entity_users = [] if not entity or entity == "User": - users = iam_backend.list_users(path_prefix, marker, max_items) + users = self.backend.list_users(path_prefix, marker, max_items) if users: for user in users: for p in user.managed_policies: @@ -132,7 +136,7 @@ class IamResponse(BaseResponse): entity_users.append({"name": user.name, "id": user.id}) if not entity or entity == "Role": - roles, _ = iam_backend.list_roles(path_prefix, marker, max_items) + roles, _ = self.backend.list_roles(path_prefix, marker, max_items) if roles: for role in roles: for p in role.managed_policies: @@ -140,7 +144,7 @@ class IamResponse(BaseResponse): entity_roles.append({"name": role.name, "id": role.id}) if not entity or entity == "Group": - groups = iam_backend.list_groups() + groups = self.backend.list_groups() if groups: for group in groups: for p in group.managed_policies: @@ -148,21 +152,21 @@ class IamResponse(BaseResponse): entity_groups.append({"name": group.name, "id": group.id}) if entity == "LocalManagedPolicy" or entity == "AWSManagedPolicy": - users = iam_backend.list_users(path_prefix, marker, max_items) + users = self.backend.list_users(path_prefix, marker, max_items) if users: for user in users: for p in user.managed_policies: if p == policy_arn: entity_users.append({"name": user.name, "id": user.id}) - roles, _ = iam_backend.list_roles(path_prefix, marker, max_items) + roles, _ = self.backend.list_roles(path_prefix, marker, max_items) if roles: for role in roles: for p in role.managed_policies: if p == policy_arn: entity_roles.append({"name": role.name, "id": role.id}) - groups = iam_backend.list_groups() + groups = self.backend.list_groups() if groups: for group in groups: for p in group.managed_policies: @@ -177,7 +181,7 @@ class IamResponse(BaseResponse): def set_default_policy_version(self): policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") - iam_backend.set_default_policy_version(policy_arn, version_id) + self.backend.set_default_policy_version(policy_arn, version_id) template = self.response_template(SET_DEFAULT_POLICY_VERSION_TEMPLATE) return template.render() @@ -190,7 +194,7 @@ class IamResponse(BaseResponse): tags = self._get_multi_param("Tags.member") max_session_duration = self._get_param("MaxSessionDuration", 3600) - role = iam_backend.create_role( + role = self.backend.create_role( role_name, assume_role_policy_document, path, @@ -204,20 +208,20 @@ class IamResponse(BaseResponse): def get_role(self): role_name = self._get_param("RoleName") - role = iam_backend.get_role(role_name) + role = self.backend.get_role(role_name) template = self.response_template(GET_ROLE_TEMPLATE) return template.render(role=role) def delete_role(self): role_name = self._get_param("RoleName") - iam_backend.delete_role(role_name) + self.backend.delete_role(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRole") def list_role_policies(self): role_name = self._get_param("RoleName") - role_policies_names = iam_backend.list_role_policies(role_name) + role_policies_names = self.backend.list_role_policies(role_name) template = self.response_template(LIST_ROLE_POLICIES) return template.render(role_policies=role_policies_names) @@ -225,21 +229,21 @@ class IamResponse(BaseResponse): role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") - iam_backend.put_role_policy(role_name, policy_name, policy_document) + self.backend.put_role_policy(role_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePolicy") def delete_role_policy(self): role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") - iam_backend.delete_role_policy(role_name, policy_name) + self.backend.delete_role_policy(role_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePolicy") def get_role_policy(self): role_name = self._get_param("RoleName") policy_name = self._get_param("PolicyName") - policy_name, policy_document = iam_backend.get_role_policy( + policy_name, policy_document = self.backend.get_role_policy( role_name, policy_name ) template = self.response_template(GET_ROLE_POLICY_TEMPLATE) @@ -251,7 +255,7 @@ class IamResponse(BaseResponse): def update_assume_role_policy(self): role_name = self._get_param("RoleName") - role = iam_backend.get_role(role_name) + role = self.backend.get_role(role_name) role.assume_role_policy_document = self._get_param("PolicyDocument") template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAssumeRolePolicy") @@ -259,7 +263,7 @@ class IamResponse(BaseResponse): def update_role_description(self): role_name = self._get_param("RoleName") description = self._get_param("Description") - role = iam_backend.update_role_description(role_name, description) + role = self.backend.update_role_description(role_name, description) template = self.response_template(UPDATE_ROLE_DESCRIPTION_TEMPLATE) return template.render(role=role) @@ -267,20 +271,20 @@ class IamResponse(BaseResponse): role_name = self._get_param("RoleName") description = self._get_param("Description") max_session_duration = self._get_param("MaxSessionDuration", 3600) - role = iam_backend.update_role(role_name, description, max_session_duration) + role = self.backend.update_role(role_name, description, max_session_duration) template = self.response_template(UPDATE_ROLE_TEMPLATE) return template.render(role=role) def put_role_permissions_boundary(self): permissions_boundary = self._get_param("PermissionsBoundary") role_name = self._get_param("RoleName") - iam_backend.put_role_permissions_boundary(role_name, permissions_boundary) + self.backend.put_role_permissions_boundary(role_name, permissions_boundary) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePermissionsBoundary") def delete_role_permissions_boundary(self): role_name = self._get_param("RoleName") - iam_backend.delete_role_permissions_boundary(role_name) + self.backend.delete_role_permissions_boundary(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePermissionsBoundary") @@ -288,7 +292,7 @@ class IamResponse(BaseResponse): policy_arn = self._get_param("PolicyArn") policy_document = self._get_param("PolicyDocument") set_as_default = self._get_param("SetAsDefault") - policy_version = iam_backend.create_policy_version( + policy_version = self.backend.create_policy_version( policy_arn, policy_document, set_as_default ) template = self.response_template(CREATE_POLICY_VERSION_TEMPLATE) @@ -297,13 +301,13 @@ class IamResponse(BaseResponse): def get_policy_version(self): policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") - policy_version = iam_backend.get_policy_version(policy_arn, version_id) + policy_version = self.backend.get_policy_version(policy_arn, version_id) template = self.response_template(GET_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) def list_policy_versions(self): policy_arn = self._get_param("PolicyArn") - policy_versions = iam_backend.list_policy_versions(policy_arn) + policy_versions = self.backend.list_policy_versions(policy_arn) template = self.response_template(LIST_POLICY_VERSIONS_TEMPLATE) return template.render(policy_versions=policy_versions) @@ -313,7 +317,7 @@ class IamResponse(BaseResponse): marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) - tags, marker = iam_backend.list_policy_tags(policy_arn, marker, max_items) + tags, marker = self.backend.list_policy_tags(policy_arn, marker, max_items) template = self.response_template(LIST_POLICY_TAG_TEMPLATE) return template.render(tags=tags, marker=marker) @@ -322,7 +326,7 @@ class IamResponse(BaseResponse): policy_arn = self._get_param("PolicyArn") tags = self._get_multi_param("Tags.member") - iam_backend.tag_policy(policy_arn, tags) + self.backend.tag_policy(policy_arn, tags) template = self.response_template(TAG_POLICY_TEMPLATE) return template.render() @@ -331,7 +335,7 @@ class IamResponse(BaseResponse): policy_arn = self._get_param("PolicyArn") tag_keys = self._get_multi_param("TagKeys.member") - iam_backend.untag_policy(policy_arn, tag_keys) + self.backend.untag_policy(policy_arn, tag_keys) template = self.response_template(UNTAG_POLICY_TEMPLATE) return template.render() @@ -340,7 +344,7 @@ class IamResponse(BaseResponse): policy_arn = self._get_param("PolicyArn") version_id = self._get_param("VersionId") - iam_backend.delete_policy_version(policy_arn, version_id) + self.backend.delete_policy_version(policy_arn, version_id) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeletePolicyVersion") @@ -349,7 +353,7 @@ class IamResponse(BaseResponse): path = self._get_param("Path", "/") tags = self._get_multi_param("Tags.member") - profile = iam_backend.create_instance_profile( + profile = self.backend.create_instance_profile( profile_name, path, role_names=[], tags=tags ) template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) @@ -358,13 +362,13 @@ class IamResponse(BaseResponse): def delete_instance_profile(self): profile_name = self._get_param("InstanceProfileName") - profile = iam_backend.delete_instance_profile(profile_name) + profile = self.backend.delete_instance_profile(profile_name) template = self.response_template(DELETE_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) def get_instance_profile(self): profile_name = self._get_param("InstanceProfileName") - profile = iam_backend.get_instance_profile(profile_name) + profile = self.backend.get_instance_profile(profile_name) template = self.response_template(GET_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) @@ -373,7 +377,7 @@ class IamResponse(BaseResponse): profile_name = self._get_param("InstanceProfileName") role_name = self._get_param("RoleName") - iam_backend.add_role_to_instance_profile(profile_name, role_name) + self.backend.add_role_to_instance_profile(profile_name, role_name) template = self.response_template(ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) return template.render() @@ -381,7 +385,7 @@ class IamResponse(BaseResponse): profile_name = self._get_param("InstanceProfileName") role_name = self._get_param("RoleName") - iam_backend.remove_role_from_instance_profile(profile_name, role_name) + self.backend.remove_role_from_instance_profile(profile_name, role_name) template = self.response_template(REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) return template.render() @@ -390,19 +394,19 @@ class IamResponse(BaseResponse): marker = self._get_param("Marker", "0") max_items = self._get_param("MaxItems", 100) - roles, marker = iam_backend.list_roles(path_prefix, marker, max_items) + roles, marker = self.backend.list_roles(path_prefix, marker, max_items) template = self.response_template(LIST_ROLES_TEMPLATE) return template.render(roles=roles, marker=marker) def list_instance_profiles(self): - profiles = iam_backend.get_instance_profiles() + profiles = self.backend.get_instance_profiles() template = self.response_template(LIST_INSTANCE_PROFILES_TEMPLATE) return template.render(instance_profiles=profiles) def list_instance_profiles_for_role(self): role_name = self._get_param("RoleName") - profiles = iam_backend.get_instance_profiles_for_role(role_name=role_name) + profiles = self.backend.get_instance_profiles_for_role(role_name=role_name) template = self.response_template(LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) return template.render(instance_profiles=profiles) @@ -414,26 +418,26 @@ class IamResponse(BaseResponse): private_key = self._get_param("PrivateKey") cert_chain = self._get_param("CertificateName") - cert = iam_backend.upload_server_certificate( + cert = self.backend.upload_server_certificate( cert_name, cert_body, private_key, cert_chain=cert_chain, path=path ) template = self.response_template(UPLOAD_CERT_TEMPLATE) return template.render(certificate=cert) def list_server_certificates(self): - certs = iam_backend.list_server_certificates() + certs = self.backend.list_server_certificates() template = self.response_template(LIST_SERVER_CERTIFICATES_TEMPLATE) return template.render(server_certificates=certs) def get_server_certificate(self): cert_name = self._get_param("ServerCertificateName") - cert = iam_backend.get_server_certificate(cert_name) + cert = self.backend.get_server_certificate(cert_name) template = self.response_template(GET_SERVER_CERTIFICATE_TEMPLATE) return template.render(certificate=cert) def delete_server_certificate(self): cert_name = self._get_param("ServerCertificateName") - iam_backend.delete_server_certificate(cert_name) + self.backend.delete_server_certificate(cert_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteServerCertificate") @@ -441,26 +445,26 @@ class IamResponse(BaseResponse): group_name = self._get_param("GroupName") path = self._get_param("Path", "/") - group = iam_backend.create_group(group_name, path) + group = self.backend.create_group(group_name, path) template = self.response_template(CREATE_GROUP_TEMPLATE) return template.render(group=group) def get_group(self): group_name = self._get_param("GroupName") - group = iam_backend.get_group(group_name) + group = self.backend.get_group(group_name) template = self.response_template(GET_GROUP_TEMPLATE) return template.render(group=group) def list_groups(self): - groups = iam_backend.list_groups() + groups = self.backend.list_groups() template = self.response_template(LIST_GROUPS_TEMPLATE) return template.render(groups=groups) def list_groups_for_user(self): user_name = self._get_param("UserName") - groups = iam_backend.get_groups_for_user(user_name) + groups = self.backend.get_groups_for_user(user_name) template = self.response_template(LIST_GROUPS_FOR_USER_TEMPLATE) return template.render(groups=groups) @@ -468,14 +472,14 @@ class IamResponse(BaseResponse): group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") - iam_backend.put_group_policy(group_name, policy_name, policy_document) + self.backend.put_group_policy(group_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutGroupPolicy") def list_group_policies(self): group_name = self._get_param("GroupName") marker = self._get_param("Marker") - policies = iam_backend.list_group_policies(group_name) + policies = self.backend.list_group_policies(group_name) template = self.response_template(LIST_GROUP_POLICIES_TEMPLATE) return template.render( name="ListGroupPoliciesResponse", policies=policies, marker=marker @@ -484,20 +488,20 @@ class IamResponse(BaseResponse): def get_group_policy(self): group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") - policy_result = iam_backend.get_group_policy(group_name, policy_name) + policy_result = self.backend.get_group_policy(group_name, policy_name) template = self.response_template(GET_GROUP_POLICY_TEMPLATE) return template.render(name="GetGroupPolicyResponse", **policy_result) def delete_group_policy(self): group_name = self._get_param("GroupName") policy_name = self._get_param("PolicyName") - iam_backend.delete_group_policy(group_name, policy_name) + self.backend.delete_group_policy(group_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteGroupPolicy") def delete_group(self): group_name = self._get_param("GroupName") - iam_backend.delete_group(group_name) + self.backend.delete_group(group_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteGroup") @@ -505,7 +509,7 @@ class IamResponse(BaseResponse): group_name = self._get_param("GroupName") new_group_name = self._get_param("NewGroupName") new_path = self._get_param("NewPath") - iam_backend.update_group(group_name, new_group_name, new_path) + self.backend.update_group(group_name, new_group_name, new_path) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateGroup") @@ -513,7 +517,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") path = self._get_param("Path") tags = self._get_multi_param("Tags.member") - user, user_tags = iam_backend.create_user(user_name, path, tags) + user, user_tags = self.backend.create_user(user_name, path, tags) template = self.response_template(USER_TEMPLATE) return template.render(action="Create", user=user, tags=user_tags["Tags"]) @@ -521,12 +525,12 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_current_user() - user = iam_backend.get_user_from_access_key_id(access_key_id) + user = self.backend.get_user_from_access_key_id(access_key_id) if user is None: user = User("default_user") else: - user = iam_backend.get_user(user_name) - tags = iam_backend.tagger.list_tags_for_resource(user.arn).get("Tags", []) + user = self.backend.get_user(user_name) + tags = self.backend.tagger.list_tags_for_resource(user.arn).get("Tags", []) template = self.response_template(USER_TEMPLATE) return template.render(action="Get", user=user, tags=tags) @@ -534,7 +538,7 @@ class IamResponse(BaseResponse): path_prefix = self._get_param("PathPrefix") marker = self._get_param("Marker") max_items = self._get_param("MaxItems") - users = iam_backend.list_users(path_prefix, marker, max_items) + users = self.backend.list_users(path_prefix, marker, max_items) template = self.response_template(LIST_USERS_TEMPLATE) return template.render(action="List", users=users, isTruncated=False) @@ -542,25 +546,25 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") new_path = self._get_param("NewPath") new_user_name = self._get_param("NewUserName") - iam_backend.update_user(user_name, new_path, new_user_name) + self.backend.update_user(user_name, new_path, new_user_name) if new_user_name: - user = iam_backend.get_user(new_user_name) + user = self.backend.get_user(new_user_name) else: - user = iam_backend.get_user(user_name) + user = self.backend.get_user(user_name) template = self.response_template(USER_TEMPLATE) return template.render(action="Update", user=user) def create_login_profile(self): user_name = self._get_param("UserName") password = self._get_param("Password") - user = iam_backend.create_login_profile(user_name, password) + user = self.backend.create_login_profile(user_name, password) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def get_login_profile(self): user_name = self._get_param("UserName") - user = iam_backend.get_login_profile(user_name) + user = self.backend.get_login_profile(user_name) template = self.response_template(GET_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) @@ -569,7 +573,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") password = self._get_param("Password") password_reset_required = self._get_param("PasswordResetRequired") - user = iam_backend.update_login_profile( + user = self.backend.update_login_profile( user_name, password, password_reset_required ) @@ -580,7 +584,7 @@ class IamResponse(BaseResponse): group_name = self._get_param("GroupName") user_name = self._get_param("UserName") - iam_backend.add_user_to_group(group_name, user_name) + self.backend.add_user_to_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="AddUserToGroup") @@ -588,7 +592,7 @@ class IamResponse(BaseResponse): group_name = self._get_param("GroupName") user_name = self._get_param("UserName") - iam_backend.remove_user_from_group(group_name, user_name) + self.backend.remove_user_from_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="RemoveUserFromGroup") @@ -596,7 +600,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") policy_name = self._get_param("PolicyName") - policy_document = iam_backend.get_user_policy(user_name, policy_name) + policy_document = self.backend.get_user_policy(user_name, policy_name) template = self.response_template(GET_USER_POLICY_TEMPLATE) return template.render( user_name=user_name, @@ -606,13 +610,13 @@ class IamResponse(BaseResponse): def list_user_policies(self): user_name = self._get_param("UserName") - policies = iam_backend.list_user_policies(user_name) + policies = self.backend.list_user_policies(user_name) template = self.response_template(LIST_USER_POLICIES_TEMPLATE) return template.render(policies=policies) def list_user_tags(self): user_name = self._get_param("UserName") - tags = iam_backend.list_user_tags(user_name) + tags = self.backend.list_user_tags(user_name) template = self.response_template(LIST_USER_TAGS_TEMPLATE) return template.render(user_tags=tags["Tags"]) @@ -621,7 +625,7 @@ class IamResponse(BaseResponse): policy_name = self._get_param("PolicyName") policy_document = self._get_param("PolicyDocument") - iam_backend.put_user_policy(user_name, policy_name, policy_document) + self.backend.put_user_policy(user_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutUserPolicy") @@ -629,7 +633,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") policy_name = self._get_param("PolicyName") - iam_backend.delete_user_policy(user_name, policy_name) + self.backend.delete_user_policy(user_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteUserPolicy") @@ -637,10 +641,10 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_current_user() - access_key = iam_backend.get_access_key_last_used(access_key_id) + access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] - key = iam_backend.create_access_key(user_name) + key = self.backend.create_access_key(user_name) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) return template.render(key=key) @@ -649,16 +653,16 @@ class IamResponse(BaseResponse): access_key_id = self._get_param("AccessKeyId") status = self._get_param("Status") if not user_name: - access_key = iam_backend.get_access_key_last_used(access_key_id) + access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] - iam_backend.update_access_key(user_name, access_key_id, status) + self.backend.update_access_key(user_name, access_key_id, status) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAccessKey") def get_access_key_last_used(self): access_key_id = self._get_param("AccessKeyId") - last_used_response = iam_backend.get_access_key_last_used(access_key_id) + last_used_response = self.backend.get_access_key_last_used(access_key_id) template = self.response_template(GET_ACCESS_KEY_LAST_USED_TEMPLATE) return template.render( user_name=last_used_response["user_name"], @@ -669,10 +673,10 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_current_user() - access_key = iam_backend.get_access_key_last_used(access_key_id) + access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] - keys = iam_backend.list_access_keys(user_name) + keys = self.backend.list_access_keys(user_name) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) return template.render(user_name=user_name, keys=keys) @@ -680,10 +684,10 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") access_key_id = self._get_param("AccessKeyId") if not user_name: - access_key = iam_backend.get_access_key_last_used(access_key_id) + access_key = self.backend.get_access_key_last_used(access_key_id) user_name = access_key["user_name"] - iam_backend.delete_access_key(access_key_id, user_name) + self.backend.delete_access_key(access_key_id, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteAccessKey") @@ -691,7 +695,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") ssh_public_key_body = self._get_param("SSHPublicKeyBody") - key = iam_backend.upload_ssh_public_key(user_name, ssh_public_key_body) + key = self.backend.upload_ssh_public_key(user_name, ssh_public_key_body) template = self.response_template(UPLOAD_SSH_PUBLIC_KEY_TEMPLATE) return template.render(key=key) @@ -699,14 +703,14 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") ssh_public_key_id = self._get_param("SSHPublicKeyId") - key = iam_backend.get_ssh_public_key(user_name, ssh_public_key_id) + key = self.backend.get_ssh_public_key(user_name, ssh_public_key_id) template = self.response_template(GET_SSH_PUBLIC_KEY_TEMPLATE) return template.render(key=key) def list_ssh_public_keys(self): user_name = self._get_param("UserName") - keys = iam_backend.get_all_ssh_public_keys(user_name) + keys = self.backend.get_all_ssh_public_keys(user_name) template = self.response_template(LIST_SSH_PUBLIC_KEYS_TEMPLATE) return template.render(keys=keys) @@ -715,7 +719,7 @@ class IamResponse(BaseResponse): ssh_public_key_id = self._get_param("SSHPublicKeyId") status = self._get_param("Status") - iam_backend.update_ssh_public_key(user_name, ssh_public_key_id, status) + self.backend.update_ssh_public_key(user_name, ssh_public_key_id, status) template = self.response_template(UPDATE_SSH_PUBLIC_KEY_TEMPLATE) return template.render() @@ -723,7 +727,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") ssh_public_key_id = self._get_param("SSHPublicKeyId") - iam_backend.delete_ssh_public_key(user_name, ssh_public_key_id) + self.backend.delete_ssh_public_key(user_name, ssh_public_key_id) template = self.response_template(DELETE_SSH_PUBLIC_KEY_TEMPLATE) return template.render() @@ -731,7 +735,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") serial_number = self._get_param("SerialNumber") - iam_backend.deactivate_mfa_device(user_name, serial_number) + self.backend.deactivate_mfa_device(user_name, serial_number) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeactivateMFADevice") @@ -741,7 +745,7 @@ class IamResponse(BaseResponse): authentication_code_1 = self._get_param("AuthenticationCode1") authentication_code_2 = self._get_param("AuthenticationCode2") - iam_backend.enable_mfa_device( + self.backend.enable_mfa_device( user_name, serial_number, authentication_code_1, authentication_code_2 ) template = self.response_template(GENERIC_EMPTY_TEMPLATE) @@ -749,7 +753,7 @@ class IamResponse(BaseResponse): def list_mfa_devices(self): user_name = self._get_param("UserName") - devices = iam_backend.list_mfa_devices(user_name) + devices = self.backend.list_mfa_devices(user_name) template = self.response_template(LIST_MFA_DEVICES_TEMPLATE) return template.render(user_name=user_name, devices=devices) @@ -757,7 +761,7 @@ class IamResponse(BaseResponse): path = self._get_param("Path") virtual_mfa_device_name = self._get_param("VirtualMFADeviceName") - virtual_mfa_device = iam_backend.create_virtual_mfa_device( + virtual_mfa_device = self.backend.create_virtual_mfa_device( virtual_mfa_device_name, path ) @@ -767,7 +771,7 @@ class IamResponse(BaseResponse): def delete_virtual_mfa_device(self): serial_number = self._get_param("SerialNumber") - iam_backend.delete_virtual_mfa_device(serial_number) + self.backend.delete_virtual_mfa_device(serial_number) template = self.response_template(DELETE_VIRTUAL_MFA_DEVICE_TEMPLATE) return template.render() @@ -777,7 +781,7 @@ class IamResponse(BaseResponse): marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) - devices, marker = iam_backend.list_virtual_mfa_devices( + devices, marker = self.backend.list_virtual_mfa_devices( assignment_status, marker, max_items ) @@ -786,54 +790,54 @@ class IamResponse(BaseResponse): def delete_user(self): user_name = self._get_param("UserName") - iam_backend.delete_user(user_name) + self.backend.delete_user(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteUser") def delete_policy(self): policy_arn = self._get_param("PolicyArn") - iam_backend.delete_policy(policy_arn) + self.backend.delete_policy(policy_arn) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeletePolicy") def delete_login_profile(self): user_name = self._get_param("UserName") - iam_backend.delete_login_profile(user_name) + self.backend.delete_login_profile(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteLoginProfile") def generate_credential_report(self): - if iam_backend.report_generated(): + if self.backend.report_generated(): template = self.response_template(CREDENTIAL_REPORT_GENERATED) else: template = self.response_template(CREDENTIAL_REPORT_GENERATING) - iam_backend.generate_report() + self.backend.generate_report() return template.render() def get_credential_report(self): - report = iam_backend.get_credential_report() + report = self.backend.get_credential_report() template = self.response_template(CREDENTIAL_REPORT) return template.render(report=report) def list_account_aliases(self): - aliases = iam_backend.list_account_aliases() + aliases = self.backend.list_account_aliases() template = self.response_template(LIST_ACCOUNT_ALIASES_TEMPLATE) return template.render(aliases=aliases) def create_account_alias(self): alias = self._get_param("AccountAlias") - iam_backend.create_account_alias(alias) + self.backend.create_account_alias(alias) template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def delete_account_alias(self): - iam_backend.delete_account_alias() + self.backend.delete_account_alias() template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def get_account_authorization_details(self): filter_param = self._get_multi_param("Filter.member") - account_details = iam_backend.get_account_authorization_details(filter_param) + account_details = self.backend.get_account_authorization_details(filter_param) template = self.response_template(GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE) return template.render( instance_profiles=account_details["instance_profiles"], @@ -841,13 +845,13 @@ class IamResponse(BaseResponse): users=account_details["users"], groups=account_details["groups"], roles=account_details["roles"], - get_groups_for_user=iam_backend.get_groups_for_user, + get_groups_for_user=self.backend.get_groups_for_user, ) def create_saml_provider(self): saml_provider_name = self._get_param("Name") saml_metadata_document = self._get_param("SAMLMetadataDocument") - saml_provider = iam_backend.create_saml_provider( + saml_provider = self.backend.create_saml_provider( saml_provider_name, saml_metadata_document ) @@ -857,7 +861,7 @@ class IamResponse(BaseResponse): def update_saml_provider(self): saml_provider_arn = self._get_param("SAMLProviderArn") saml_metadata_document = self._get_param("SAMLMetadataDocument") - saml_provider = iam_backend.update_saml_provider( + saml_provider = self.backend.update_saml_provider( saml_provider_arn, saml_metadata_document ) @@ -866,20 +870,20 @@ class IamResponse(BaseResponse): def delete_saml_provider(self): saml_provider_arn = self._get_param("SAMLProviderArn") - iam_backend.delete_saml_provider(saml_provider_arn) + self.backend.delete_saml_provider(saml_provider_arn) template = self.response_template(DELETE_SAML_PROVIDER_TEMPLATE) return template.render() def list_saml_providers(self): - saml_providers = iam_backend.list_saml_providers() + saml_providers = self.backend.list_saml_providers() template = self.response_template(LIST_SAML_PROVIDERS_TEMPLATE) return template.render(saml_providers=saml_providers) def get_saml_provider(self): saml_provider_arn = self._get_param("SAMLProviderArn") - saml_provider = iam_backend.get_saml_provider(saml_provider_arn) + saml_provider = self.backend.get_saml_provider(saml_provider_arn) template = self.response_template(GET_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) @@ -888,7 +892,7 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") cert_body = self._get_param("CertificateBody") - cert = iam_backend.upload_signing_certificate(user_name, cert_body) + cert = self.backend.upload_signing_certificate(user_name, cert_body) template = self.response_template(UPLOAD_SIGNING_CERTIFICATE_TEMPLATE) return template.render(cert=cert) @@ -897,7 +901,7 @@ class IamResponse(BaseResponse): cert_id = self._get_param("CertificateId") status = self._get_param("Status") - iam_backend.update_signing_certificate(user_name, cert_id, status) + self.backend.update_signing_certificate(user_name, cert_id, status) template = self.response_template(UPDATE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() @@ -905,14 +909,14 @@ class IamResponse(BaseResponse): user_name = self._get_param("UserName") cert_id = self._get_param("CertificateId") - iam_backend.delete_signing_certificate(user_name, cert_id) + self.backend.delete_signing_certificate(user_name, cert_id) template = self.response_template(DELETE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() def list_signing_certificates(self): user_name = self._get_param("UserName") - certs = iam_backend.list_signing_certificates(user_name) + certs = self.backend.list_signing_certificates(user_name) template = self.response_template(LIST_SIGNING_CERTIFICATES_TEMPLATE) return template.render(user_name=user_name, certificates=certs) @@ -921,7 +925,7 @@ class IamResponse(BaseResponse): marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) - tags, marker = iam_backend.list_role_tags(role_name, marker, max_items) + tags, marker = self.backend.list_role_tags(role_name, marker, max_items) template = self.response_template(LIST_ROLE_TAG_TEMPLATE) return template.render(tags=tags, marker=marker) @@ -930,7 +934,7 @@ class IamResponse(BaseResponse): role_name = self._get_param("RoleName") tags = self._get_multi_param("Tags.member") - iam_backend.tag_role(role_name, tags) + self.backend.tag_role(role_name, tags) template = self.response_template(TAG_ROLE_TEMPLATE) return template.render() @@ -939,7 +943,7 @@ class IamResponse(BaseResponse): role_name = self._get_param("RoleName") tag_keys = self._get_multi_param("TagKeys.member") - iam_backend.untag_role(role_name, tag_keys) + self.backend.untag_role(role_name, tag_keys) template = self.response_template(UNTAG_ROLE_TEMPLATE) return template.render() @@ -950,7 +954,7 @@ class IamResponse(BaseResponse): client_id_list = self._get_multi_param("ClientIDList.member") tags = self._get_multi_param("Tags.member") - open_id_provider = iam_backend.create_open_id_connect_provider( + open_id_provider = self.backend.create_open_id_connect_provider( open_id_provider_url, thumbprint_list, client_id_list, tags ) @@ -961,7 +965,7 @@ class IamResponse(BaseResponse): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") thumbprint_list = self._get_multi_param("ThumbprintList.member") - iam_backend.update_open_id_connect_provider_thumbprint( + self.backend.update_open_id_connect_provider_thumbprint( open_id_provider_arn, thumbprint_list ) @@ -972,7 +976,7 @@ class IamResponse(BaseResponse): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") tags = self._get_multi_param("Tags.member") - iam_backend.tag_open_id_connect_provider(open_id_provider_arn, tags) + self.backend.tag_open_id_connect_provider(open_id_provider_arn, tags) template = self.response_template(TAG_OPEN_ID_CONNECT_PROVIDER) return template.render() @@ -981,7 +985,7 @@ class IamResponse(BaseResponse): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") tag_keys = self._get_multi_param("TagKeys.member") - iam_backend.untag_open_id_connect_provider(open_id_provider_arn, tag_keys) + self.backend.untag_open_id_connect_provider(open_id_provider_arn, tag_keys) template = self.response_template(UNTAG_OPEN_ID_CONNECT_PROVIDER) return template.render() @@ -990,7 +994,7 @@ class IamResponse(BaseResponse): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") marker = self._get_param("Marker") max_items = self._get_param("MaxItems", 100) - tags, marker = iam_backend.list_open_id_connect_provider_tags( + tags, marker = self.backend.list_open_id_connect_provider_tags( open_id_provider_arn, marker, max_items ) template = self.response_template(LIST_OPEN_ID_CONNECT_PROVIDER_TAGS) @@ -999,7 +1003,7 @@ class IamResponse(BaseResponse): def delete_open_id_connect_provider(self): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") - iam_backend.delete_open_id_connect_provider(open_id_provider_arn) + self.backend.delete_open_id_connect_provider(open_id_provider_arn) template = self.response_template(DELETE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render() @@ -1007,7 +1011,7 @@ class IamResponse(BaseResponse): def get_open_id_connect_provider(self): open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") - open_id_provider = iam_backend.get_open_id_connect_provider( + open_id_provider = self.backend.get_open_id_connect_provider( open_id_provider_arn ) @@ -1015,7 +1019,7 @@ class IamResponse(BaseResponse): return template.render(open_id_provider=open_id_provider) def list_open_id_connect_providers(self): - open_id_provider_arns = iam_backend.list_open_id_connect_providers() + open_id_provider_arns = self.backend.list_open_id_connect_providers() template = self.response_template(LIST_OPEN_ID_CONNECT_PROVIDERS_TEMPLATE) return template.render(open_id_provider_arns=open_id_provider_arns) @@ -1037,7 +1041,7 @@ class IamResponse(BaseResponse): "RequireUppercaseCharacters", False ) - iam_backend.update_account_password_policy( + self.backend.update_account_password_policy( allow_change_password, hard_expiry, max_password_age, @@ -1053,19 +1057,19 @@ class IamResponse(BaseResponse): return template.render() def get_account_password_policy(self): - account_password_policy = iam_backend.get_account_password_policy() + account_password_policy = self.backend.get_account_password_policy() template = self.response_template(GET_ACCOUNT_PASSWORD_POLICY_TEMPLATE) return template.render(password_policy=account_password_policy) def delete_account_password_policy(self): - iam_backend.delete_account_password_policy() + self.backend.delete_account_password_policy() template = self.response_template(DELETE_ACCOUNT_PASSWORD_POLICY_TEMPLATE) return template.render() def get_account_summary(self): - account_summary = iam_backend.get_account_summary() + account_summary = self.backend.get_account_summary() template = self.response_template(GET_ACCOUNT_SUMMARY_TEMPLATE) return template.render(summary_map=account_summary.summary_map) @@ -1074,7 +1078,7 @@ class IamResponse(BaseResponse): name = self._get_param("UserName") tags = self._get_multi_param("Tags.member") - iam_backend.tag_user(name, tags) + self.backend.tag_user(name, tags) template = self.response_template(TAG_USER_TEMPLATE) return template.render() @@ -1083,7 +1087,7 @@ class IamResponse(BaseResponse): name = self._get_param("UserName") tag_keys = self._get_multi_param("TagKeys.member") - iam_backend.untag_user(name, tag_keys) + self.backend.untag_user(name, tag_keys) template = self.response_template(UNTAG_USER_TEMPLATE) return template.render() @@ -1093,7 +1097,9 @@ class IamResponse(BaseResponse): description = self._get_param("Description") suffix = self._get_param("CustomSuffix") - role = iam_backend.create_service_linked_role(service_name, description, suffix) + role = self.backend.create_service_linked_role( + service_name, description, suffix + ) template = self.response_template(CREATE_SERVICE_LINKED_ROLE_TEMPLATE) return template.render(role=role) @@ -1101,13 +1107,13 @@ class IamResponse(BaseResponse): def delete_service_linked_role(self): role_name = self._get_param("RoleName") - deletion_task_id = iam_backend.delete_service_linked_role(role_name) + deletion_task_id = self.backend.delete_service_linked_role(role_name) template = self.response_template(DELETE_SERVICE_LINKED_ROLE_TEMPLATE) return template.render(deletion_task_id=deletion_task_id) def get_service_linked_role_deletion_status(self): - iam_backend.get_service_linked_role_deletion_status() + self.backend.get_service_linked_role_deletion_status() template = self.response_template( GET_SERVICE_LINKED_ROLE_DELETION_STATUS_TEMPLATE @@ -1757,26 +1763,6 @@ LIST_GROUPS_TEMPLATE = """ """ -LIST_GROUPS_FOR_USER_TEMPLATE = """ - - - {% for group in groups %} - - {{ group.path }} - {{ group.name }} - {{ group.id }} - {{ group.arn }} - {{ group.created_iso_8601 }} - - {% endfor %} - - false - - - 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE - -""" - LIST_GROUP_POLICIES_TEMPLATE = """ {% if marker is none %} diff --git a/moto/logs/models.py b/moto/logs/models.py index 2436496ee..4cf5cc33e 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -13,7 +13,7 @@ from moto.logs.exceptions import ( InvalidParameterException, LimitExceededException, ) -from moto.s3.models import s3_backend +from moto.s3.models import s3_backends from .utils import PAGINATION_MODEL MAX_RESOURCE_POLICIES_PER_REGION = 10 @@ -940,7 +940,7 @@ class LogsBackend(BaseBackend): return query_id def create_export_task(self, log_group_name, destination): - s3_backend.get_bucket(destination) + s3_backends["global"].get_bucket(destination) if log_group_name not in self.groups: raise ResourceNotFoundException() task_id = uuid.uuid4() diff --git a/moto/managedblockchain/responses.py b/moto/managedblockchain/responses.py index f23112bf2..3c0f28908 100644 --- a/moto/managedblockchain/responses.py +++ b/moto/managedblockchain/responses.py @@ -5,7 +5,6 @@ from moto.core.responses import BaseResponse from .exceptions import exception_handler from .models import managedblockchain_backends from .utils import ( - region_from_managedblckchain_url, networkid_from_managedblockchain_url, proposalid_from_managedblockchain_url, invitationid_from_managedblockchain_url, @@ -15,29 +14,21 @@ from .utils import ( class ManagedBlockchainResponse(BaseResponse): - def __init__(self, backend): - super().__init__() - self.backend = backend + @property + def backend(self): + return managedblockchain_backends[self.region] - @classmethod @exception_handler - def network_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._network_response(request, headers) + def network_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._network_response(request, headers) def _network_response(self, request, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data if method == "GET": return self._all_networks_response(headers) elif method == "POST": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._network_response_post(json_body, headers) def _all_networks_response(self, headers): @@ -70,14 +61,10 @@ class ManagedBlockchainResponse(BaseResponse): ) return 200, headers, json.dumps(response) - @classmethod @exception_handler - def networkid_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._networkid_response(request, full_url, headers) + def networkid_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._networkid_response(request, full_url, headers) def _networkid_response(self, request, full_url, headers): method = request.method @@ -92,26 +79,18 @@ class ManagedBlockchainResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, response - @classmethod @exception_handler - def proposal_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._proposal_response(request, full_url, headers) + def proposal_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._proposal_response(request, full_url, headers) def _proposal_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data network_id = networkid_from_managedblockchain_url(full_url) if method == "GET": return self._all_proposals_response(network_id, headers) elif method == "POST": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._proposal_response_post(network_id, json_body, headers) def _all_proposals_response(self, network_id, headers): @@ -134,14 +113,10 @@ class ManagedBlockchainResponse(BaseResponse): ) return 200, headers, json.dumps(response) - @classmethod @exception_handler - def proposalid_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._proposalid_response(request, full_url, headers) + def proposalid_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._proposalid_response(request, full_url, headers) def _proposalid_response(self, request, full_url, headers): method = request.method @@ -156,27 +131,19 @@ class ManagedBlockchainResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, response - @classmethod @exception_handler - def proposal_votes_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._proposal_votes_response(request, full_url, headers) + def proposal_votes_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._proposal_votes_response(request, full_url, headers) def _proposal_votes_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data network_id = networkid_from_managedblockchain_url(full_url) proposal_id = proposalid_from_managedblockchain_url(full_url) if method == "GET": return self._all_proposal_votes_response(network_id, proposal_id, headers) elif method == "POST": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._proposal_votes_response_post( network_id, proposal_id, json_body, headers ) @@ -196,14 +163,10 @@ class ManagedBlockchainResponse(BaseResponse): self.backend.vote_on_proposal(network_id, proposal_id, votermemberid, vote) return 200, headers, "" - @classmethod @exception_handler - def invitation_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._invitation_response(request, headers) + def invitation_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._invitation_response(request, headers) def _invitation_response(self, request, headers): method = request.method @@ -218,14 +181,10 @@ class ManagedBlockchainResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, response - @classmethod @exception_handler - def invitationid_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._invitationid_response(request, full_url, headers) + def invitationid_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._invitationid_response(request, full_url, headers) def _invitationid_response(self, request, full_url, headers): method = request.method @@ -238,26 +197,18 @@ class ManagedBlockchainResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, "" - @classmethod @exception_handler - def member_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._member_response(request, full_url, headers) + def member_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._member_response(request, full_url, headers) def _member_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data network_id = networkid_from_managedblockchain_url(full_url) if method == "GET": return self._all_members_response(network_id, headers) elif method == "POST": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._member_response_post(network_id, json_body, headers) def _all_members_response(self, network_id, headers): @@ -275,27 +226,19 @@ class ManagedBlockchainResponse(BaseResponse): ) return 200, headers, json.dumps(response) - @classmethod @exception_handler - def memberid_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._memberid_response(request, full_url, headers) + def memberid_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._memberid_response(request, full_url, headers) def _memberid_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data network_id = networkid_from_managedblockchain_url(full_url) - member_id = memberid_from_managedblockchain_request(full_url, body) + member_id = memberid_from_managedblockchain_request(full_url, self.body) if method == "GET": return self._memberid_response_get(network_id, member_id, headers) elif method == "PATCH": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._memberid_response_patch( network_id, member_id, json_body, headers ) @@ -318,32 +261,24 @@ class ManagedBlockchainResponse(BaseResponse): headers["content-type"] = "application/json" return 200, headers, "" - @classmethod @exception_handler - def node_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._node_response(request, full_url, headers) + def node_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._node_response(request, full_url, headers) def _node_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data parsed_url = urlparse(full_url) querystring = parse_qs(parsed_url.query, keep_blank_values=True) network_id = networkid_from_managedblockchain_url(full_url) - member_id = memberid_from_managedblockchain_request(full_url, body) + member_id = memberid_from_managedblockchain_request(full_url, self.body) if method == "GET": status = None if "status" in querystring: status = querystring["status"][0] return self._all_nodes_response(network_id, member_id, status, headers) elif method == "POST": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._node_response_post(network_id, member_id, json_body, headers) def _all_nodes_response(self, network_id, member_id, status, headers): @@ -368,28 +303,20 @@ class ManagedBlockchainResponse(BaseResponse): ) return 200, headers, json.dumps(response) - @classmethod @exception_handler - def nodeid_response(clazz, request, full_url, headers): - region_name = region_from_managedblckchain_url(full_url) - response_instance = ManagedBlockchainResponse( - managedblockchain_backends[region_name] - ) - return response_instance._nodeid_response(request, full_url, headers) + def nodeid_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + return self._nodeid_response(request, full_url, headers) def _nodeid_response(self, request, full_url, headers): method = request.method - if hasattr(request, "body"): - body = request.body - else: - body = request.data network_id = networkid_from_managedblockchain_url(full_url) - member_id = memberid_from_managedblockchain_request(full_url, body) + member_id = memberid_from_managedblockchain_request(full_url, self.body) node_id = nodeid_from_managedblockchain_url(full_url) if method == "GET": return self._nodeid_response_get(network_id, member_id, node_id, headers) elif method == "PATCH": - json_body = json.loads(body.decode("utf-8")) + json_body = json.loads(self.body) return self._nodeid_response_patch( network_id, member_id, node_id, json_body, headers ) diff --git a/moto/managedblockchain/urls.py b/moto/managedblockchain/urls.py index 6fa1c1109..685ab0de0 100644 --- a/moto/managedblockchain/urls.py +++ b/moto/managedblockchain/urls.py @@ -3,19 +3,19 @@ from .responses import ManagedBlockchainResponse url_bases = [r"https?://managedblockchain\.(.+)\.amazonaws.com"] url_paths = { - "{0}/networks$": ManagedBlockchainResponse.network_response, - "{0}/networks/(?P[^/.]+)$": ManagedBlockchainResponse.networkid_response, - "{0}/networks/(?P[^/.]+)/proposals$": ManagedBlockchainResponse.proposal_response, - "{0}/networks/(?P[^/.]+)/proposals/(?P[^/.]+)$": ManagedBlockchainResponse.proposalid_response, - "{0}/networks/(?P[^/.]+)/proposals/(?P[^/.]+)/votes$": ManagedBlockchainResponse.proposal_votes_response, - "{0}/invitations$": ManagedBlockchainResponse.invitation_response, - "{0}/invitations/(?P[^/.]+)$": ManagedBlockchainResponse.invitationid_response, - "{0}/networks/(?P[^/.]+)/members$": ManagedBlockchainResponse.member_response, - "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)$": ManagedBlockchainResponse.memberid_response, - "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes$": ManagedBlockchainResponse.node_response, - "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes?(?P[^/.]+)$": ManagedBlockchainResponse.node_response, - "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes/(?P[^/.]+)$": ManagedBlockchainResponse.nodeid_response, + "{0}/networks$": ManagedBlockchainResponse().network_response, + "{0}/networks/(?P[^/.]+)$": ManagedBlockchainResponse().networkid_response, + "{0}/networks/(?P[^/.]+)/proposals$": ManagedBlockchainResponse().proposal_response, + "{0}/networks/(?P[^/.]+)/proposals/(?P[^/.]+)$": ManagedBlockchainResponse().proposalid_response, + "{0}/networks/(?P[^/.]+)/proposals/(?P[^/.]+)/votes$": ManagedBlockchainResponse().proposal_votes_response, + "{0}/invitations$": ManagedBlockchainResponse().invitation_response, + "{0}/invitations/(?P[^/.]+)$": ManagedBlockchainResponse().invitationid_response, + "{0}/networks/(?P[^/.]+)/members$": ManagedBlockchainResponse().member_response, + "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)$": ManagedBlockchainResponse().memberid_response, + "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes$": ManagedBlockchainResponse().node_response, + "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes?(?P[^/.]+)$": ManagedBlockchainResponse().node_response, + "{0}/networks/(?P[^/.]+)/members/(?P[^/.]+)/nodes/(?P[^/.]+)$": ManagedBlockchainResponse().nodeid_response, # >= botocore 1.19.41 (API change - memberId is now part of query-string or body) - "{0}/networks/(?P[^/.]+)/nodes$": ManagedBlockchainResponse.node_response, - "{0}/networks/(?P[^/.]+)/nodes/(?P[^/.]+)$": ManagedBlockchainResponse.nodeid_response, + "{0}/networks/(?P[^/.]+)/nodes$": ManagedBlockchainResponse().node_response, + "{0}/networks/(?P[^/.]+)/nodes/(?P[^/.]+)$": ManagedBlockchainResponse().nodeid_response, } diff --git a/moto/managedblockchain/utils.py b/moto/managedblockchain/utils.py index 96214031f..280f108fc 100644 --- a/moto/managedblockchain/utils.py +++ b/moto/managedblockchain/utils.py @@ -6,14 +6,6 @@ import string from urllib.parse import parse_qs, urlparse -def region_from_managedblckchain_url(url): - domain = urlparse(url).netloc - region = "us-east-1" - if "." in domain: - region = domain.split(".")[1] - return region - - def networkid_from_managedblockchain_url(full_url): id_search = re.search(r"\/n-[A-Z0-9]{26}", full_url, re.IGNORECASE) return_id = None diff --git a/moto/s3/models.py b/moto/s3/models.py index a0b862067..c0f11809b 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -121,7 +121,6 @@ class FakeKey(BaseModel, ManagedState): lock_mode=None, lock_legal_status=None, lock_until=None, - s3_backend=None, ): ManagedState.__init__( self, @@ -162,8 +161,6 @@ class FakeKey(BaseModel, ManagedState): # Default metadata values self._metadata["Content-Type"] = "binary/octet-stream" - self.s3_backend = s3_backend - def safe_name(self, encoding_type=None): if encoding_type == "url": return urllib.parse.quote(self.name, safe="") @@ -292,7 +289,7 @@ class FakeKey(BaseModel, ManagedState): res["x-amz-object-lock-retain-until-date"] = self.lock_until if self.lock_mode: res["x-amz-object-lock-mode"] = self.lock_mode - tags = s3_backend.tagger.get_tag_dict_for_resource(self.arn) + tags = s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn) if tags: res["x-amz-tagging-count"] = len(tags.keys()) @@ -1228,13 +1225,13 @@ class FakeBucket(CloudFormationModel): def create_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name, **kwargs ): - bucket = s3_backend.create_bucket(resource_name, region_name) + bucket = s3_backends["global"].create_bucket(resource_name, region_name) properties = cloudformation_json.get("Properties", {}) if "BucketEncryption" in properties: bucket_encryption = cfn_to_api_encryption(properties["BucketEncryption"]) - s3_backend.put_bucket_encryption( + s3_backends["global"].put_bucket_encryption( bucket_name=resource_name, encryption=bucket_encryption ) @@ -1264,7 +1261,7 @@ class FakeBucket(CloudFormationModel): bucket_encryption = cfn_to_api_encryption( properties["BucketEncryption"] ) - s3_backend.put_bucket_encryption( + s3_backends["global"].put_bucket_encryption( bucket_name=original_resource.name, encryption=bucket_encryption ) return original_resource @@ -1273,7 +1270,7 @@ class FakeBucket(CloudFormationModel): def delete_from_cloudformation_json( cls, resource_name, cloudformation_json, region_name ): - s3_backend.delete_bucket(resource_name) + s3_backends["global"].delete_bucket(resource_name) def to_config_dict(self): """Return the AWS Config JSON format of this S3 bucket. @@ -1298,7 +1295,7 @@ class FakeBucket(CloudFormationModel): "resourceCreationTime": str(self.creation_date), "relatedEvents": [], "relationships": [], - "tags": s3_backend.tagger.get_tag_dict_for_resource(self.arn), + "tags": s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn), "configuration": { "name": self.name, "owner": {"id": OWNER}, @@ -1449,7 +1446,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): @classmethod def get_cloudwatch_metrics(cls): metrics = [] - for name, bucket in s3_backend.buckets.items(): + for name, bucket in s3_backends["global"].buckets.items(): metrics.append( MetricDatum( namespace="AWS/S3", @@ -1700,7 +1697,6 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): lock_mode=lock_mode, lock_legal_status=lock_legal_status, lock_until=lock_until, - s3_backend=s3_backend, ) keys = [ @@ -2173,4 +2169,3 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): s3_backends = BackendDict( S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"] ) -s3_backend = s3_backends["global"] diff --git a/moto/s3/responses.py b/moto/s3/responses.py index d6eaa8105..b0a9ddbea 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse import xmltodict -from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin +from moto.core.responses import BaseResponse from moto.core.utils import path_url from moto.core import get_account_id @@ -51,7 +51,8 @@ from .exceptions import ( InvalidRange, LockNotEnabled, ) -from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey +from .models import s3_backends +from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey from .utils import bucket_name_from_url, metadata_from_headers, parse_region_from_url from xml.dom import minidom @@ -151,14 +152,10 @@ def is_delete_keys(request, path): ) -class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): - def __init__(self, backend): - super().__init__() - self.backend = backend - self.method = "" - self.path = "" - self.data = {} - self.headers = {} +class S3Response(BaseResponse): + @property + def backend(self): + return s3_backends["global"] @property def should_autoescape(self): @@ -253,15 +250,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return self.bucket_response(request, full_url, headers) @amzn_request_id - def bucket_response( - self, request, full_url, headers - ): # pylint: disable=unused-argument - self.method = request.method - self.path = self._get_path(request) - # Make a copy of request.headers because it's immutable - self.headers = dict(request.headers) - if "Host" not in self.headers: - self.headers["Host"] = urlparse(full_url).netloc + def bucket_response(self, request, full_url, headers): + self.setup_class(request, full_url, headers, use_raw_body=True) try: response = self._bucket_response(request, full_url) except S3ClientError as s3error: @@ -297,30 +287,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.data["BucketName"] = bucket_name - if hasattr(request, "body"): - # Boto - body = request.body - else: - # Flask server - body = request.data - if body is None: - body = b"" - if isinstance(body, bytes): - body = body.decode("utf-8") - body = "{0}".format(body).encode("utf-8") - if method == "HEAD": return self._bucket_response_head(bucket_name, querystring) elif method == "GET": return self._bucket_response_get(bucket_name, querystring) elif method == "PUT": return self._bucket_response_put( - request, body, region_name, bucket_name, querystring + request, region_name, bucket_name, querystring ) elif method == "DELETE": return self._bucket_response_delete(bucket_name, querystring) elif method == "POST": - return self._bucket_response_post(request, body, bucket_name) + return self._bucket_response_post(request, bucket_name) elif method == "OPTIONS": return self._response_options(bucket_name) else: @@ -379,28 +357,26 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): for cors_rule in bucket.cors: if cors_rule.allowed_methods is not None: - self.headers["Access-Control-Allow-Methods"] = _to_string( + self.response_headers["Access-Control-Allow-Methods"] = _to_string( cors_rule.allowed_methods ) if cors_rule.allowed_origins is not None: - self.headers["Access-Control-Allow-Origin"] = _to_string( + self.response_headers["Access-Control-Allow-Origin"] = _to_string( cors_rule.allowed_origins ) if cors_rule.allowed_headers is not None: - self.headers["Access-Control-Allow-Headers"] = _to_string( + self.response_headers["Access-Control-Allow-Headers"] = _to_string( cors_rule.allowed_headers ) if cors_rule.exposed_headers is not None: - self.headers["Access-Control-Expose-Headers"] = _to_string( + self.response_headers["Access-Control-Expose-Headers"] = _to_string( cors_rule.exposed_headers ) if cors_rule.max_age_seconds is not None: - self.headers["Access-Control-Max-Age"] = _to_string( + self.response_headers["Access-Control-Max-Age"] = _to_string( cors_rule.max_age_seconds ) - return self.headers - def _response_options(self, bucket_name): # Return 200 with the headers from the bucket CORS configuration self._authenticate_and_authorize_s3_action() @@ -415,7 +391,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._set_cors_headers(bucket) - return 200, self.headers, "" + return 200, self.response_headers, "" def _bucket_response_get(self, bucket_name, querystring): self._set_action("BUCKET", "GET", querystring) @@ -728,15 +704,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): pass return False - def _parse_pab_config(self, body): - parsed_xml = xmltodict.parse(body) + def _parse_pab_config(self): + parsed_xml = xmltodict.parse(self.body) parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None) return parsed_xml - def _bucket_response_put( - self, request, body, region_name, bucket_name, querystring - ): + def _bucket_response_put(self, request, region_name, bucket_name, querystring): if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" @@ -744,8 +718,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._authenticate_and_authorize_s3_action() if "object-lock" in querystring: - body_decoded = body.decode() - config = self._lock_config_from_xml(body_decoded) + config = self._lock_config_from_body() if not self.backend.get_bucket(bucket_name).object_lock_enabled: raise BucketMustHaveLockeEnabled @@ -760,7 +733,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 200, {}, "" if "versioning" in querystring: - ver = re.search("([A-Za-z]+)", body.decode()) + body = self.body.decode("utf-8") + ver = re.search(r"([A-Za-z]+)", body) if ver: self.backend.set_bucket_versioning(bucket_name, ver.group(1)) template = self.response_template(S3_BUCKET_VERSIONING) @@ -768,47 +742,45 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: return 404, {}, "" elif "lifecycle" in querystring: - rules = xmltodict.parse(body)["LifecycleConfiguration"]["Rule"] + rules = xmltodict.parse(self.body)["LifecycleConfiguration"]["Rule"] if not isinstance(rules, list): # If there is only one rule, xmldict returns just the item rules = [rules] self.backend.put_bucket_lifecycle(bucket_name, rules) return "" elif "policy" in querystring: - self.backend.put_bucket_policy(bucket_name, body) + self.backend.put_bucket_policy(bucket_name, self.body) return "True" elif "acl" in querystring: # Headers are first. If not set, then look at the body (consistent with the documentation): acls = self._acl_from_headers(request.headers) if not acls: - acls = self._acl_from_xml(body) + acls = self._acl_from_body() self.backend.put_bucket_acl(bucket_name, acls) return "" elif "tagging" in querystring: - tagging = self._bucket_tagging_from_xml(body) + tagging = self._bucket_tagging_from_body() self.backend.put_bucket_tagging(bucket_name, tagging) return "" elif "website" in querystring: - self.backend.set_bucket_website_configuration(bucket_name, body) + self.backend.set_bucket_website_configuration(bucket_name, self.body) return "" elif "cors" in querystring: try: - self.backend.put_bucket_cors(bucket_name, self._cors_from_xml(body)) + self.backend.put_bucket_cors(bucket_name, self._cors_from_body()) return "" except KeyError: raise MalformedXML() elif "logging" in querystring: try: - self.backend.put_bucket_logging( - bucket_name, self._logging_from_xml(body) - ) + self.backend.put_bucket_logging(bucket_name, self._logging_from_body()) return "" except KeyError: raise MalformedXML() elif "notification" in querystring: try: self.backend.put_bucket_notification_configuration( - bucket_name, self._notification_config_from_xml(body) + bucket_name, self._notification_config_from_body() ) return "" except KeyError: @@ -817,7 +789,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise e elif "accelerate" in querystring: try: - accelerate_status = self._accelerate_config_from_xml(body) + accelerate_status = self._accelerate_config_from_body() self.backend.put_bucket_accelerate_configuration( bucket_name, accelerate_status ) @@ -828,7 +800,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise e elif "publicAccessBlock" in querystring: - pab_config = self._parse_pab_config(body) + pab_config = self._parse_pab_config() self.backend.put_bucket_public_access_block( bucket_name, pab_config["PublicAccessBlockConfiguration"] ) @@ -836,7 +808,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): elif "encryption" in querystring: try: self.backend.put_bucket_encryption( - bucket_name, self._encryption_config_from_xml(body) + bucket_name, self._encryption_config_from_body() ) return "" except KeyError: @@ -848,7 +820,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if not bucket.is_versioned: template = self.response_template(S3_NO_VERSIONING_ENABLED) return 400, {}, template.render(bucket_name=bucket_name) - replication_config = self._replication_config_from_xml(body) + replication_config = self._replication_config_from_xml(self.body) self.backend.put_bucket_replication(bucket_name, replication_config) return "" else: @@ -858,17 +830,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # - LocationConstraint has to be specified if outside us-east-1 if ( region_name != DEFAULT_REGION_NAME - and not self._body_contains_location_constraint(body) + and not self._body_contains_location_constraint(self.body) ): raise IllegalLocationConstraintException() - if body: - if self._create_bucket_configuration_is_empty(body): + if self.body: + if self._create_bucket_configuration_is_empty(self.body): raise MalformedXML() try: - forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ - "LocationConstraint" - ] + forced_region = xmltodict.parse(self.body)[ + "CreateBucketConfiguration" + ]["LocationConstraint"] if forced_region == DEFAULT_REGION_NAME: raise S3ClientError( @@ -950,21 +922,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) return 409, {}, template.render(bucket=removed_bucket) - def _bucket_response_post(self, request, body, bucket_name): + def _bucket_response_post(self, request, bucket_name): response_headers = {} if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" - path = self._get_path(request) + self.path = self._get_path(request) - if self.is_delete_keys(request, path, bucket_name): + if self.is_delete_keys(request, self.path, bucket_name): self.data["Action"] = "DeleteObject" try: self._authenticate_and_authorize_s3_action() - return self._bucket_response_delete_keys(body, bucket_name) + return self._bucket_response_delete_keys(bucket_name) except BucketAccessDeniedError: return self._bucket_response_delete_keys( - body, bucket_name, authenticated=False + bucket_name, authenticated=False ) self.data["Action"] = "PutObject" @@ -1027,9 +999,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else path_url(request.url) ) - def _bucket_response_delete_keys(self, body, bucket_name, authenticated=True): + def _bucket_response_delete_keys(self, bucket_name, authenticated=True): template = self.response_template(S3_DELETE_KEYS_RESPONSE) - body_dict = xmltodict.parse(body, strip_whitespace=False) + body_dict = xmltodict.parse(self.body, strip_whitespace=False) objects = body_dict["Delete"].get("Object", []) if not isinstance(objects, list): @@ -1098,16 +1070,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return bytes(new_body) @amzn_request_id - def key_response( - self, request, full_url, headers - ): # pylint: disable=unused-argument + def key_response(self, request, full_url, headers): # Key and Control are lumped in because splitting out the regex is too much of a pain :/ - self.method = request.method - self.path = self._get_path(request) - # Make a copy of request.headers because it's immutable - self.headers = dict(request.headers) - if "Host" not in self.headers: - self.headers["Host"] = urlparse(full_url).netloc + self.setup_class(request, full_url, headers, use_raw_body=True) response_headers = {} try: @@ -1300,7 +1265,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 304, response_headers, "Not Modified" if "acl" in query: - acl = s3_backend.get_object_acl(key) + acl = self.backend.get_object_acl(key) template = self.response_template(S3_OBJECT_ACL_RESPONSE) return 200, response_headers, template.render(acl=acl) if "tagging" in query: @@ -1411,7 +1376,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if not lock_enabled: raise LockNotEnabled version_id = query.get("VersionId") - retention = self._mode_until_from_xml(body) + retention = self._mode_until_from_body() self.backend.put_object_retention( bucket_name, key_name, version_id=version_id, retention=retention ) @@ -1573,9 +1538,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: return 404, response_headers, "" - def _lock_config_from_xml(self, xml): + def _lock_config_from_body(self): response_dict = {"enabled": False, "mode": None, "days": None, "years": None} - parsed_xml = xmltodict.parse(xml) + parsed_xml = xmltodict.parse(self.body) enabled = ( parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled" ) @@ -1596,8 +1561,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return response_dict - def _acl_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _acl_from_body(self): + parsed_xml = xmltodict.parse(self.body) if not parsed_xml.get("AccessControlPolicy"): raise MalformedACLError() @@ -1713,8 +1678,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return tags - def _bucket_tagging_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _bucket_tagging_from_body(self): + parsed_xml = xmltodict.parse(self.body) tags = {} # Optional if no tags are being sent: @@ -1737,16 +1702,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return tags - def _cors_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _cors_from_body(self): + parsed_xml = xmltodict.parse(self.body) if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list): return [cors for cors in parsed_xml["CORSConfiguration"]["CORSRule"]] return [parsed_xml["CORSConfiguration"]["CORSRule"]] - def _mode_until_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _mode_until_from_body(self): + parsed_xml = xmltodict.parse(self.body) return ( parsed_xml.get("Retention", None).get("Mode", None), parsed_xml.get("Retention", None).get("RetainUntilDate", None), @@ -1756,8 +1721,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): parsed_xml = xmltodict.parse(xml) return parsed_xml["LegalHold"]["Status"] - def _encryption_config_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _encryption_config_from_body(self): + parsed_xml = xmltodict.parse(self.body) if ( not parsed_xml["ServerSideEncryptionConfiguration"].get("Rule") @@ -1772,8 +1737,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return parsed_xml["ServerSideEncryptionConfiguration"] - def _logging_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _logging_from_body(self): + parsed_xml = xmltodict.parse(self.body) if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"): return {} @@ -1817,8 +1782,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"] - def _notification_config_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _notification_config_from_body(self): + parsed_xml = xmltodict.parse(self.body) if not len(parsed_xml["NotificationConfiguration"]): return {} @@ -1892,8 +1857,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return parsed_xml["NotificationConfiguration"] - def _accelerate_config_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml) + def _accelerate_config_from_body(self): + parsed_xml = xmltodict.parse(self.body) config = parsed_xml["AccelerateConfiguration"] return config["Status"] @@ -2028,7 +1993,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return False -S3ResponseInstance = ResponseObject(s3_backend) +S3ResponseInstance = S3Response() S3_ALL_BUCKETS = """ diff --git a/moto/s3control/__init__.py b/moto/s3control/__init__.py index 2c34680ea..1ef58b939 100644 --- a/moto/s3control/__init__.py +++ b/moto/s3control/__init__.py @@ -1,6 +1,5 @@ """s3control module initialization; sets value for base decorator.""" -from .models import s3control_backend +from .models import s3control_backends from ..core.models import base_decorator -s3control_backends = {"global": s3control_backend} mock_s3control = base_decorator(s3control_backends) diff --git a/moto/s3control/models.py b/moto/s3control/models.py index e0c935ec1..2b3719d5d 100644 --- a/moto/s3control/models.py +++ b/moto/s3control/models.py @@ -1,7 +1,7 @@ from collections import defaultdict from datetime import datetime from moto.core import get_account_id, BaseBackend, BaseModel -from moto.core.utils import get_random_hex +from moto.core.utils import get_random_hex, BackendDict from moto.s3.exceptions import ( WrongPublicAccessBlockAccountIdError, NoSuchPublicAccessBlockConfiguration, @@ -43,16 +43,11 @@ class AccessPoint(BaseModel): class S3ControlBackend(BaseBackend): - def __init__(self, region_name=None): - self.region_name = region_name + def __init__(self, region_name, account_id): + super().__init__(region_name, account_id) self.public_access_block = None self.access_points = defaultdict(dict) - def reset(self): - region_name = self.region_name - self.__dict__ = {} - self.__init__(region_name) - def get_public_access_block(self, account_id): # The account ID should equal the account id that is set for Moto: if account_id != get_account_id(): @@ -129,4 +124,9 @@ class S3ControlBackend(BaseBackend): return True -s3control_backend = S3ControlBackend() +s3control_backends = BackendDict( + S3ControlBackend, + "s3control", + use_boto3_regions=False, + additional_regions=["global"], +) diff --git a/moto/s3control/responses.py b/moto/s3control/responses.py index 103335d30..07cc33ee5 100644 --- a/moto/s3control/responses.py +++ b/moto/s3control/responses.py @@ -5,10 +5,14 @@ from moto.core.responses import BaseResponse from moto.core.utils import amzn_request_id from moto.s3.exceptions import S3ClientError from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION -from .models import s3control_backend +from .models import s3control_backends class S3ControlResponse(BaseResponse): + @property + def backend(self): + return s3control_backends["global"] + @amzn_request_id def public_access_block( self, request, full_url, headers @@ -25,7 +29,7 @@ class S3ControlResponse(BaseResponse): def get_public_access_block(self, request): account_id = request.headers.get("x-amz-account-id") - public_block_config = s3control_backend.get_public_access_block( + public_block_config = self.backend.get_public_access_block( account_id=account_id ) template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION) @@ -35,14 +39,14 @@ class S3ControlResponse(BaseResponse): account_id = request.headers.get("x-amz-account-id") data = request.body if hasattr(request, "body") else request.data pab_config = self._parse_pab_config(data) - s3control_backend.put_public_access_block( + self.backend.put_public_access_block( account_id, pab_config["PublicAccessBlockConfiguration"] ) return 201, {}, json.dumps({}) def delete_public_access_block(self, request): account_id = request.headers.get("x-amz-account-id") - s3control_backend.delete_public_access_block(account_id=account_id) + self.backend.delete_public_access_block(account_id=account_id) return 204, {}, json.dumps({}) def _parse_pab_config(self, body): @@ -82,7 +86,7 @@ class S3ControlResponse(BaseResponse): bucket = params["Bucket"] vpc_configuration = params.get("VpcConfiguration") public_access_block_configuration = params.get("PublicAccessBlockConfiguration") - access_point = s3control_backend.create_access_point( + access_point = self.backend.create_access_point( account_id=account_id, name=name, bucket=bucket, @@ -95,38 +99,36 @@ class S3ControlResponse(BaseResponse): def get_access_point(self, full_url): account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) - access_point = s3control_backend.get_access_point( - account_id=account_id, name=name - ) + access_point = self.backend.get_access_point(account_id=account_id, name=name) template = self.response_template(GET_ACCESS_POINT_TEMPLATE) return 200, {}, template.render(access_point=access_point) def delete_access_point(self, full_url): account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) - s3control_backend.delete_access_point(account_id=account_id, name=name) + self.backend.delete_access_point(account_id=account_id, name=name) return 204, {}, "" def create_access_point_policy(self, full_url): account_id, name = self._get_accountid_and_name_from_policy(full_url) params = xmltodict.parse(self.body) policy = params["PutAccessPointPolicyRequest"]["Policy"] - s3control_backend.create_access_point_policy(account_id, name, policy) + self.backend.create_access_point_policy(account_id, name, policy) return 200, {}, "" def get_access_point_policy(self, full_url): account_id, name = self._get_accountid_and_name_from_policy(full_url) - policy = s3control_backend.get_access_point_policy(account_id, name) + policy = self.backend.get_access_point_policy(account_id, name) template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE) return 200, {}, template.render(policy=policy) def delete_access_point_policy(self, full_url): account_id, name = self._get_accountid_and_name_from_policy(full_url) - s3control_backend.delete_access_point_policy(account_id=account_id, name=name) + self.backend.delete_access_point_policy(account_id=account_id, name=name) return 204, {}, "" def get_access_point_policy_status(self, full_url): account_id, name = self._get_accountid_and_name_from_policy(full_url) - s3control_backend.get_access_point_policy_status(account_id, name) + self.backend.get_access_point_policy_status(account_id, name) template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE) return 200, {}, template.render() diff --git a/moto/ses/models.py b/moto/ses/models.py index 964e9b23c..feaafbae3 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -6,6 +6,7 @@ from email.mime.base import MIMEBase from email.utils import parseaddr from email.mime.multipart import MIMEMultipart from email.encoders import encode_7or8bit +from typing import Mapping from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict @@ -537,7 +538,6 @@ class SESBackend(BaseBackend): return attributes_by_identity -ses_backends = BackendDict( +ses_backends: Mapping[str, SESBackend] = BackendDict( SESBackend, "ses", use_boto3_regions=False, additional_regions=["global"] ) -ses_backend = ses_backends["global"] diff --git a/moto/ses/responses.py b/moto/ses/responses.py index b49a6c5c1..3fb5f8cc0 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -1,48 +1,52 @@ import base64 from moto.core.responses import BaseResponse -from .models import ses_backend +from .models import ses_backends from datetime import datetime class EmailResponse(BaseResponse): + @property + def backend(self): + return ses_backends["global"] + def verify_email_identity(self): address = self.querystring.get("EmailAddress")[0] - ses_backend.verify_email_identity(address) + self.backend.verify_email_identity(address) template = self.response_template(VERIFY_EMAIL_IDENTITY) return template.render() def verify_email_address(self): address = self.querystring.get("EmailAddress")[0] - ses_backend.verify_email_address(address) + self.backend.verify_email_address(address) template = self.response_template(VERIFY_EMAIL_ADDRESS) return template.render() def list_identities(self): - identities = ses_backend.list_identities() + identities = self.backend.list_identities() template = self.response_template(LIST_IDENTITIES_RESPONSE) return template.render(identities=identities) def list_verified_email_addresses(self): - email_addresses = ses_backend.list_verified_email_addresses() + email_addresses = self.backend.list_verified_email_addresses() template = self.response_template(LIST_VERIFIED_EMAIL_RESPONSE) return template.render(email_addresses=email_addresses) def verify_domain_dkim(self): domain = self.querystring.get("Domain")[0] - ses_backend.verify_domain(domain) + self.backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE) return template.render() def verify_domain_identity(self): domain = self.querystring.get("Domain")[0] - ses_backend.verify_domain(domain) + self.backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE) return template.render() def delete_identity(self): domain = self.querystring.get("Identity")[0] - ses_backend.delete_identity(domain) + self.backend.delete_identity(domain) template = self.response_template(DELETE_IDENTITY_RESPONSE) return template.render() @@ -63,7 +67,7 @@ class EmailResponse(BaseResponse): break destinations[dest_type].append(address[0]) - message = ses_backend.send_email( + message = self.backend.send_email( source, subject, body, destinations, self.region ) template = self.response_template(SEND_EMAIL_RESPONSE) @@ -84,7 +88,7 @@ class EmailResponse(BaseResponse): break destinations[dest_type].append(address[0]) - message = ses_backend.send_templated_email( + message = self.backend.send_templated_email( source, template, template_data, destinations, self.region ) template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE) @@ -107,27 +111,27 @@ class EmailResponse(BaseResponse): break destinations.append(address[0]) - message = ses_backend.send_raw_email( + message = self.backend.send_raw_email( source, destinations, raw_data, self.region ) template = self.response_template(SEND_RAW_EMAIL_RESPONSE) return template.render(message=message) def get_send_quota(self): - quota = ses_backend.get_send_quota() + quota = self.backend.get_send_quota() template = self.response_template(GET_SEND_QUOTA_RESPONSE) return template.render(quota=quota) def get_identity_notification_attributes(self): identities = self._get_params()["Identities"] - identities = ses_backend.get_identity_notification_attributes(identities) + identities = self.backend.get_identity_notification_attributes(identities) template = self.response_template(GET_IDENTITY_NOTIFICATION_ATTRIBUTES) return template.render(identities=identities) def set_identity_feedback_forwarding_enabled(self): identity = self._get_param("Identity") enabled = self._get_bool_param("ForwardingEnabled") - ses_backend.set_identity_feedback_forwarding_enabled(identity, enabled) + self.backend.set_identity_feedback_forwarding_enabled(identity, enabled) template = self.response_template(SET_IDENTITY_FORWARDING_ENABLED_RESPONSE) return template.render() @@ -139,18 +143,18 @@ class EmailResponse(BaseResponse): if sns_topic: sns_topic = sns_topic[0] - ses_backend.set_identity_notification_topic(identity, not_type, sns_topic) + self.backend.set_identity_notification_topic(identity, not_type, sns_topic) template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE) return template.render() def get_send_statistics(self): - statistics = ses_backend.get_send_statistics() + statistics = self.backend.get_send_statistics() template = self.response_template(GET_SEND_STATISTICS) return template.render(all_statistics=[statistics]) def create_configuration_set(self): configuration_set_name = self.querystring.get("ConfigurationSet.Name")[0] - ses_backend.create_configuration_set( + self.backend.create_configuration_set( configuration_set_name=configuration_set_name ) template = self.response_template(CREATE_CONFIGURATION_SET) @@ -177,7 +181,7 @@ class EmailResponse(BaseResponse): "SNSDestination": event_topic_arn, } - ses_backend.create_configuration_set_event_destination( + self.backend.create_configuration_set_event_destination( configuration_set_name=configuration_set_name, event_destination=event_destination, ) @@ -193,7 +197,7 @@ class EmailResponse(BaseResponse): template_info["template_name"] = template_data.get("._name", "") template_info["subject_part"] = template_data.get("._subject_part", "") template_info["Timestamp"] = datetime.utcnow() - ses_backend.add_template(template_info=template_info) + self.backend.add_template(template_info=template_info) template = self.response_template(CREATE_TEMPLATE) return template.render() @@ -205,44 +209,44 @@ class EmailResponse(BaseResponse): template_info["template_name"] = template_data.get("._name", "") template_info["subject_part"] = template_data.get("._subject_part", "") template_info["Timestamp"] = datetime.utcnow() - ses_backend.update_template(template_info=template_info) + self.backend.update_template(template_info=template_info) template = self.response_template(UPDATE_TEMPLATE) return template.render() def get_template(self): template_name = self._get_param("TemplateName") - template_data = ses_backend.get_template(template_name) + template_data = self.backend.get_template(template_name) template = self.response_template(GET_TEMPLATE) return template.render(template_data=template_data) def list_templates(self): - email_templates = ses_backend.list_templates() + email_templates = self.backend.list_templates() template = self.response_template(LIST_TEMPLATES) return template.render(templates=email_templates) def test_render_template(self): render_info = self._get_dict_param("Template") - rendered_template = ses_backend.render_template(render_info) + rendered_template = self.backend.render_template(render_info) template = self.response_template(RENDER_TEMPLATE) return template.render(template=rendered_template) def create_receipt_rule_set(self): rule_set_name = self._get_param("RuleSetName") - ses_backend.create_receipt_rule_set(rule_set_name) + self.backend.create_receipt_rule_set(rule_set_name) template = self.response_template(CREATE_RECEIPT_RULE_SET) return template.render() def create_receipt_rule(self): rule_set_name = self._get_param("RuleSetName") rule = self._get_dict_param("Rule.") - ses_backend.create_receipt_rule(rule_set_name, rule) + self.backend.create_receipt_rule(rule_set_name, rule) template = self.response_template(CREATE_RECEIPT_RULE) return template.render() def describe_receipt_rule_set(self): rule_set_name = self._get_param("RuleSetName") - rule_set = ses_backend.describe_receipt_rule_set(rule_set_name) + rule_set = self.backend.describe_receipt_rule_set(rule_set_name) for i, rule in enumerate(rule_set): formatted_rule = {} @@ -260,7 +264,7 @@ class EmailResponse(BaseResponse): rule_set_name = self._get_param("RuleSetName") rule_name = self._get_param("RuleName") - receipt_rule = ses_backend.describe_receipt_rule(rule_set_name, rule_name) + receipt_rule = self.backend.describe_receipt_rule(rule_set_name, rule_name) rule = {} @@ -274,7 +278,7 @@ class EmailResponse(BaseResponse): rule_set_name = self._get_param("RuleSetName") rule = self._get_dict_param("Rule.") - ses_backend.update_receipt_rule(rule_set_name, rule) + self.backend.update_receipt_rule(rule_set_name, rule) template = self.response_template(UPDATE_RECEIPT_RULE) return template.render() @@ -284,7 +288,7 @@ class EmailResponse(BaseResponse): mail_from_domain = self._get_param("MailFromDomain") behavior_on_mx_failure = self._get_param("BehaviorOnMXFailure") - ses_backend.set_identity_mail_from_domain( + self.backend.set_identity_mail_from_domain( identity, mail_from_domain, behavior_on_mx_failure ) @@ -293,7 +297,7 @@ class EmailResponse(BaseResponse): def get_identity_mail_from_domain_attributes(self): identities = self._get_multi_param("Identities.member.") - identities = ses_backend.get_identity_mail_from_domain_attributes(identities) + identities = self.backend.get_identity_mail_from_domain_attributes(identities) template = self.response_template(GET_IDENTITY_MAIL_FROM_DOMAIN_ATTRIBUTES) return template.render(identities=identities) @@ -301,7 +305,7 @@ class EmailResponse(BaseResponse): def get_identity_verification_attributes(self): params = self._get_params() identities = params.get("Identities") - verification_attributes = ses_backend.get_identity_verification_attributes( + verification_attributes = self.backend.get_identity_verification_attributes( identities=identities, ) diff --git a/moto/sts/models.py b/moto/sts/models.py index 677bc0ffe..10109897c 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -11,6 +11,7 @@ from moto.sts.utils import ( random_assumed_role_id, DEFAULT_STS_SESSION_DURATION, ) +from typing import Mapping class Token(BaseModel): @@ -138,7 +139,6 @@ class STSBackend(BaseBackend): pass -sts_backends = BackendDict( +sts_backends: Mapping[str, STSBackend] = BackendDict( STSBackend, "sts", use_boto3_regions=False, additional_regions=["global"] ) -sts_backend = sts_backends["global"] diff --git a/moto/sts/responses.py b/moto/sts/responses.py index 291b22450..0622ac82c 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -1,16 +1,20 @@ from moto.core.responses import BaseResponse from moto.core import get_account_id -from moto.iam import iam_backend +from moto.iam import iam_backends from .exceptions import STSValidationError -from .models import sts_backend +from .models import sts_backends MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 class TokenResponse(BaseResponse): + @property + def backend(self): + return sts_backends["global"] + def get_session_token(self): duration = int(self.querystring.get("DurationSeconds", [43200])[0]) - token = sts_backend.get_session_token(duration=duration) + token = self.backend.get_session_token(duration=duration) template = self.response_template(GET_SESSION_TOKEN_RESPONSE) return template.render(token=token) @@ -27,7 +31,7 @@ class TokenResponse(BaseResponse): ) name = self.querystring.get("Name")[0] - token = sts_backend.get_federation_token(duration=duration, name=name) + token = self.backend.get_federation_token(duration=duration, name=name) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) return template.render(token=token, account_id=get_account_id()) @@ -39,7 +43,7 @@ class TokenResponse(BaseResponse): duration = int(self.querystring.get("DurationSeconds", [3600])[0]) external_id = self.querystring.get("ExternalId", [None])[0] - role = sts_backend.assume_role( + role = self.backend.assume_role( role_session_name=role_session_name, role_arn=role_arn, policy=policy, @@ -57,7 +61,7 @@ class TokenResponse(BaseResponse): duration = int(self.querystring.get("DurationSeconds", [3600])[0]) external_id = self.querystring.get("ExternalId", [None])[0] - role = sts_backend.assume_role_with_web_identity( + role = self.backend.assume_role_with_web_identity( role_session_name=role_session_name, role_arn=role_arn, policy=policy, @@ -72,7 +76,7 @@ class TokenResponse(BaseResponse): principal_arn = self.querystring.get("PrincipalArn")[0] saml_assertion = self.querystring.get("SAMLAssertion")[0] - role = sts_backend.assume_role_with_saml( + role = self.backend.assume_role_with_saml( role_arn=role_arn, principal_arn=principal_arn, saml_assertion=saml_assertion, @@ -88,12 +92,12 @@ class TokenResponse(BaseResponse): arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=get_account_id()) access_key_id = self.get_current_user() - assumed_role = sts_backend.get_assumed_role_from_access_key(access_key_id) + assumed_role = self.backend.get_assumed_role_from_access_key(access_key_id) if assumed_role: user_id = assumed_role.user_id arn = assumed_role.arn - user = iam_backend.get_user_from_access_key_id(access_key_id) + user = iam_backends["global"].get_user_from_access_key_id(access_key_id) if user: user_id = user.id arn = user.arn diff --git a/tests/test_core/test_server.py b/tests/test_core/test_server.py index be37ba69c..a2a3ffb3a 100644 --- a/tests/test_core/test_server.py +++ b/tests/test_core/test_server.py @@ -45,4 +45,4 @@ def test_domain_dispatched_with_service(): dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) keys = set(backend_app.view_functions.keys()) - keys.should.contain("ResponseObject.key_response") + keys.should.contain("S3Response.key_response") diff --git a/tests/test_kms/test_kms_boto3.py b/tests/test_kms/test_kms_boto3.py index ff1006078..57fd9d4ed 100644 --- a/tests/test_kms/test_kms_boto3.py +++ b/tests/test_kms/test_kms_boto3.py @@ -647,7 +647,10 @@ def test_generate_data_key_all_valid_key_ids(prefix, append_key_id): if append_key_id: target_id += key_id - client.generate_data_key(KeyId=target_id, NumberOfBytes=32) + resp = client.generate_data_key(KeyId=target_id, NumberOfBytes=32) + resp.should.have.key("KeyId").equals( + f"arn:aws:kms:us-east-1:123456789012:key/{key_id}" + ) @mock_kms diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index ed0160258..09fbd09cc 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -75,6 +75,9 @@ def test_s3_server_ignore_subdomain_for_bucketnames(): def test_s3_server_bucket_versioning(): test_client = authenticated_client() + res = test_client.put("/", "http://foobaz.localhost:5000/") + res.status_code.should.equal(200) + # Just enough XML to enable versioning body = "Enabled" res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body)