Techdebt - Align models-responses integration for all services (#5207)

This commit is contained in:
Bert Blommers 2022-06-09 17:40:22 +00:00 committed by GitHub
parent 47fe052c6f
commit a2c2c06243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 609 additions and 747 deletions

View File

@ -4,6 +4,7 @@ from collections import defaultdict
import copy import copy
import datetime import datetime
from gzip import GzipFile from gzip import GzipFile
from typing import Mapping
from sys import platform from sys import platform
import docker import docker
@ -25,10 +26,10 @@ import requests.exceptions
from moto.awslambda.policy import Policy from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend from moto.iam.models import iam_backends
from moto.iam.exceptions import IAMNotFoundException from moto.iam.exceptions import IAMNotFoundException
from moto.core.utils import unix_time_millis, BackendDict from moto.core.utils import unix_time_millis, BackendDict
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from moto.logs.models import logs_backends from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings from moto import settings
@ -182,7 +183,7 @@ def _validate_s3_bucket_and_key(data):
key = None key = None
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_object(data["S3Bucket"], data["S3Key"]) key = s3_backends["global"].get_object(data["S3Bucket"], data["S3Key"])
except MissingBucket: except MissingBucket:
if do_validate_s3(): if do_validate_s3():
raise InvalidParameterValueException( raise InvalidParameterValueException(
@ -585,7 +586,7 @@ class LambdaFunction(CloudFormationModel, DockerModel):
key = None key = None
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_object( key = s3_backends["global"].get_object(
updated_spec["S3Bucket"], updated_spec["S3Key"] updated_spec["S3Bucket"], updated_spec["S3Key"]
) )
except MissingBucket: except MissingBucket:
@ -1121,7 +1122,7 @@ class LambdaStorage(object):
if account != get_account_id(): if account != get_account_id():
raise CrossAccountNotAllowed() raise CrossAccountNotAllowed()
try: try:
iam_backend.get_role_by_arn(fn.role) iam_backends["global"].get_role_by_arn(fn.role)
except IAMNotFoundException: except IAMNotFoundException:
raise InvalidParameterValueException( raise InvalidParameterValueException(
"The role defined for the function cannot be assumed by Lambda." "The role defined for the function cannot be assumed by Lambda."
@ -1666,4 +1667,4 @@ def do_validate_s3():
return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"] return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"]
lambda_backends = BackendDict(LambdaBackend, "lambda") lambda_backends: Mapping[str, LambdaBackend] = BackendDict(LambdaBackend, "lambda")

View File

@ -47,7 +47,7 @@ from moto.ssm import models # noqa # pylint: disable=all
# End ugly list of imports # End ugly list of imports
from moto.core import get_account_id, CloudFormationModel from moto.core import get_account_id, CloudFormationModel
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from moto.s3.utils import bucket_and_name_from_url from moto.s3.utils import bucket_and_name_from_url
from moto.ssm import ssm_backends from moto.ssm import ssm_backends
from .utils import random_suffix from .utils import random_suffix
@ -528,7 +528,7 @@ class ResourceMap(collections_abc.Mapping):
if name == "AWS::Include": if name == "AWS::Include":
location = params["Location"] location = params["Location"]
bucket_name, name = bucket_and_name_from_url(location) bucket_name, name = bucket_and_name_from_url(location)
key = s3_backend.get_object(bucket_name, name) key = s3_backends["global"].get_object(bucket_name, name)
self._parsed_resources.update(json.loads(key.value)) self._parsed_resources.update(json.loads(key.value))
def parse_ssm_parameter(self, value, value_type): def parse_ssm_parameter(self, value, value_type):

View File

@ -6,7 +6,7 @@ from yaml.scanner import ScannerError # pylint:disable=c-extension-no-member
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from moto.s3.exceptions import S3ClientError from moto.s3.exceptions import S3ClientError
from moto.core import get_account_id from moto.core import get_account_id
from .models import cloudformation_backends from .models import cloudformation_backends
@ -68,7 +68,7 @@ class CloudFormationResponse(BaseResponse):
bucket_name = template_url_parts.netloc.split(".")[0] bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/") key_name = template_url_parts.path.lstrip("/")
key = s3_backend.get_object(bucket_name, key_name) key = s3_backends["global"].get_object(bucket_name, key_name)
return key.value.decode("utf-8") return key.value.decode("utf-8")
def _get_params_from_list(self, parameters_list): def _get_params_from_list(self, parameters_list):

View File

@ -255,4 +255,3 @@ cloudfront_backends = BackendDict(
use_boto3_regions=False, use_boto3_regions=False,
additional_regions=["global"], additional_regions=["global"],
) )
cloudfront_backend = cloudfront_backends["global"]

View File

@ -1,7 +1,7 @@
import xmltodict import xmltodict
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cloudfront_backend from .models import cloudfront_backends
XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/"
@ -11,6 +11,10 @@ class CloudFrontResponse(BaseResponse):
def _get_xml_body(self): def _get_xml_body(self):
return xmltodict.parse(self.body, dict_constructor=dict) return xmltodict.parse(self.body, dict_constructor=dict)
@property
def backend(self):
return cloudfront_backends["global"]
def distributions(self, request, full_url, headers): def distributions(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
@ -21,7 +25,7 @@ class CloudFrontResponse(BaseResponse):
def create_distribution(self): def create_distribution(self):
params = self._get_xml_body() params = self._get_xml_body()
distribution_config = params.get("DistributionConfig") distribution_config = params.get("DistributionConfig")
distribution, location, e_tag = cloudfront_backend.create_distribution( distribution, location, e_tag = self.backend.create_distribution(
distribution_config=distribution_config distribution_config=distribution_config
) )
template = self.response_template(CREATE_DISTRIBUTION_TEMPLATE) template = self.response_template(CREATE_DISTRIBUTION_TEMPLATE)
@ -30,7 +34,7 @@ class CloudFrontResponse(BaseResponse):
return 200, headers, response return 200, headers, response
def list_distributions(self): def list_distributions(self):
distributions = cloudfront_backend.list_distributions() distributions = self.backend.list_distributions()
template = self.response_template(LIST_TEMPLATE) template = self.response_template(LIST_TEMPLATE)
response = template.render(distributions=distributions) response = template.render(distributions=distributions)
return 200, {}, response return 200, {}, response
@ -40,10 +44,10 @@ class CloudFrontResponse(BaseResponse):
distribution_id = full_url.split("/")[-1] distribution_id = full_url.split("/")[-1]
if request.method == "DELETE": if request.method == "DELETE":
if_match = self._get_param("If-Match") if_match = self._get_param("If-Match")
cloudfront_backend.delete_distribution(distribution_id, if_match) self.backend.delete_distribution(distribution_id, if_match)
return 204, {}, "" return 204, {}, ""
if request.method == "GET": if request.method == "GET":
dist, etag = cloudfront_backend.get_distribution(distribution_id) dist, etag = self.backend.get_distribution(distribution_id)
template = self.response_template(GET_DISTRIBUTION_TEMPLATE) template = self.response_template(GET_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS) response = template.render(distribution=dist, xmlns=XMLNS)
return 200, {"ETag": etag}, response return 200, {"ETag": etag}, response
@ -55,7 +59,7 @@ class CloudFrontResponse(BaseResponse):
dist_id = full_url.split("/")[-2] dist_id = full_url.split("/")[-2]
if_match = headers["If-Match"] if_match = headers["If-Match"]
dist, location, e_tag = cloudfront_backend.update_distribution( dist, location, e_tag = self.backend.update_distribution(
DistributionConfig=distribution_config, DistributionConfig=distribution_config,
Id=dist_id, Id=dist_id,
IfMatch=if_match, IfMatch=if_match,

View File

@ -130,10 +130,10 @@ class Trail(BaseModel):
raise TrailNameInvalidChars() raise TrailNameInvalidChars()
def check_bucket_exists(self): def check_bucket_exists(self):
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
try: try:
s3_backend.get_bucket(self.bucket_name) s3_backends["global"].get_bucket(self.bucket_name)
except Exception: except Exception:
raise S3BucketDoesNotExistException( raise S3BucketDoesNotExistException(
f"S3 bucket {self.bucket_name} does not exist!" f"S3 bucket {self.bucket_name} does not exist!"

View File

@ -544,7 +544,6 @@ class CloudWatchBackend(BaseBackend):
unit=None, unit=None,
): ):
period_delta = timedelta(seconds=period) period_delta = timedelta(seconds=period)
# TODO: Also filter by unit and dimensions
filtered_data = [ filtered_data = [
md md
for md in self.get_all_metrics() for md in self.get_all_metrics()

View File

@ -4,6 +4,10 @@ from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse): class CognitoIdentityResponse(BaseResponse):
@property
def backend(self):
return cognitoidentity_backends[self.region]
def create_identity_pool(self): def create_identity_pool(self):
identity_pool_name = self._get_param("IdentityPoolName") identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param( allow_unauthenticated_identities = self._get_param(
@ -16,7 +20,7 @@ class CognitoIdentityResponse(BaseResponse):
saml_provider_arns = self._get_param("SamlProviderARNs") saml_provider_arns = self._get_param("SamlProviderARNs")
pool_tags = self._get_param("IdentityPoolTags") pool_tags = self._get_param("IdentityPoolTags")
return cognitoidentity_backends[self.region].create_identity_pool( return self.backend.create_identity_pool(
identity_pool_name=identity_pool_name, identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers, supported_login_providers=supported_login_providers,
@ -38,7 +42,7 @@ class CognitoIdentityResponse(BaseResponse):
saml_providers = self._get_param("SamlProviderARNs") saml_providers = self._get_param("SamlProviderARNs")
pool_tags = self._get_param("IdentityPoolTags") pool_tags = self._get_param("IdentityPoolTags")
return cognitoidentity_backends[self.region].update_identity_pool( return self.backend.update_identity_pool(
identity_pool_id=pool_id, identity_pool_id=pool_id,
identity_pool_name=pool_name, identity_pool_name=pool_name,
allow_unauthenticated=allow_unauthenticated, allow_unauthenticated=allow_unauthenticated,
@ -51,19 +55,13 @@ class CognitoIdentityResponse(BaseResponse):
) )
def get_id(self): def get_id(self):
return cognitoidentity_backends[self.region].get_id( return self.backend.get_id(identity_pool_id=self._get_param("IdentityPoolId"))
identity_pool_id=self._get_param("IdentityPoolId")
)
def describe_identity_pool(self): def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool( return self.backend.describe_identity_pool(self._get_param("IdentityPoolId"))
self._get_param("IdentityPoolId")
)
def get_credentials_for_identity(self): def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity( return self.backend.get_credentials_for_identity(self._get_param("IdentityId"))
self._get_param("IdentityId")
)
def get_open_id_token_for_developer_identity(self): def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[ return cognitoidentity_backends[
@ -73,11 +71,11 @@ class CognitoIdentityResponse(BaseResponse):
) )
def get_open_id_token(self): def get_open_id_token(self):
return cognitoidentity_backends[self.region].get_open_id_token( return self.backend.get_open_id_token(
self._get_param("IdentityId") or get_random_identity_id(self.region) self._get_param("IdentityId") or get_random_identity_id(self.region)
) )
def list_identities(self): def list_identities(self):
return cognitoidentity_backends[self.region].list_identities( return self.backend.list_identities(
self._get_param("IdentityPoolId") or get_random_identity_id(self.region) self._get_param("IdentityPoolId") or get_random_identity_id(self.region)
) )

View File

@ -20,12 +20,14 @@ class CognitoIdpResponse(BaseResponse):
def parameters(self): def parameters(self):
return json.loads(self.body) return json.loads(self.body)
@property
def backend(self):
return cognitoidp_backends[self.region]
# User pool # User pool
def create_user_pool(self): def create_user_pool(self):
name = self.parameters.pop("PoolName") name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool( user_pool = self.backend.create_user_pool(name, self.parameters)
name, self.parameters
)
return json.dumps({"UserPool": user_pool.to_json(extended=True)}) return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def set_user_pool_mfa_config(self): def set_user_pool_mfa_config(self):
@ -50,22 +52,20 @@ class CognitoIdpResponse(BaseResponse):
"[SmsConfiguration] is a required member of [SoftwareTokenMfaConfiguration]." "[SmsConfiguration] is a required member of [SoftwareTokenMfaConfiguration]."
) )
response = cognitoidp_backends[self.region].set_user_pool_mfa_config( response = self.backend.set_user_pool_mfa_config(
user_pool_id, sms_config, token_config, mfa_config user_pool_id, sms_config, token_config, mfa_config
) )
return json.dumps(response) return json.dumps(response)
def get_user_pool_mfa_config(self): def get_user_pool_mfa_config(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
response = cognitoidp_backends[self.region].get_user_pool_mfa_config( response = self.backend.get_user_pool_mfa_config(user_pool_id)
user_pool_id
)
return json.dumps(response) return json.dumps(response)
def list_user_pools(self): def list_user_pools(self):
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( user_pools, next_token = self.backend.list_user_pools(
max_results=max_results, next_token=next_token max_results=max_results, next_token=next_token
) )
response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]} response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
@ -75,16 +75,16 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self): def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) user_pool = self.backend.describe_user_pool(user_pool_id)
return json.dumps({"UserPool": user_pool.to_json(extended=True)}) return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def update_user_pool(self): def update_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].update_user_pool(user_pool_id, self.parameters) self.backend.update_user_pool(user_pool_id, self.parameters)
def delete_user_pool(self): def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].delete_user_pool(user_pool_id) self.backend.delete_user_pool(user_pool_id)
return "" return ""
# User pool domain # User pool domain
@ -92,7 +92,7 @@ class CognitoIdpResponse(BaseResponse):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
custom_domain_config = self._get_param("CustomDomainConfig") custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain( user_pool_domain = self.backend.create_user_pool_domain(
user_pool_id, domain, custom_domain_config user_pool_id, domain, custom_domain_config
) )
domain_description = user_pool_domain.to_json(extended=False) domain_description = user_pool_domain.to_json(extended=False)
@ -102,9 +102,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_domain(self): def describe_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain( user_pool_domain = self.backend.describe_user_pool_domain(domain)
domain
)
domain_description = {} domain_description = {}
if user_pool_domain: if user_pool_domain:
domain_description = user_pool_domain.to_json() domain_description = user_pool_domain.to_json()
@ -113,13 +111,13 @@ class CognitoIdpResponse(BaseResponse):
def delete_user_pool_domain(self): def delete_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
cognitoidp_backends[self.region].delete_user_pool_domain(domain) self.backend.delete_user_pool_domain(domain)
return "" return ""
def update_user_pool_domain(self): def update_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
custom_domain_config = self._get_param("CustomDomainConfig") custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain( user_pool_domain = self.backend.update_user_pool_domain(
domain, custom_domain_config domain, custom_domain_config
) )
domain_description = user_pool_domain.to_json(extended=False) domain_description = user_pool_domain.to_json(extended=False)
@ -131,7 +129,7 @@ class CognitoIdpResponse(BaseResponse):
def create_user_pool_client(self): def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
generate_secret = self.parameters.pop("GenerateSecret", False) generate_secret = self.parameters.pop("GenerateSecret", False)
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client( user_pool_client = self.backend.create_user_pool_client(
user_pool_id, generate_secret, self.parameters user_pool_id, generate_secret, self.parameters
) )
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -157,7 +155,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self): def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client( user_pool_client = self.backend.describe_user_pool_client(
user_pool_id, client_id user_pool_id, client_id
) )
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -165,7 +163,7 @@ class CognitoIdpResponse(BaseResponse):
def update_user_pool_client(self): def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId") client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client( user_pool_client = self.backend.update_user_pool_client(
user_pool_id, client_id, self.parameters user_pool_id, client_id, self.parameters
) )
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -173,16 +171,14 @@ class CognitoIdpResponse(BaseResponse):
def delete_user_pool_client(self): def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client( self.backend.delete_user_pool_client(user_pool_id, client_id)
user_pool_id, client_id
)
return "" return ""
# Identity provider # Identity provider
def create_identity_provider(self): def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName") name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider( identity_provider = self.backend.create_identity_provider(
user_pool_id, name, self.parameters user_pool_id, name, self.parameters
) )
return json.dumps( return json.dumps(
@ -210,9 +206,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self): def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider( identity_provider = self.backend.describe_identity_provider(user_pool_id, name)
user_pool_id, name
)
return json.dumps( return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)} {"IdentityProvider": identity_provider.to_json(extended=True)}
) )
@ -220,7 +214,7 @@ class CognitoIdpResponse(BaseResponse):
def update_identity_provider(self): def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider( identity_provider = self.backend.update_identity_provider(
user_pool_id, name, self.parameters user_pool_id, name, self.parameters
) )
return json.dumps( return json.dumps(
@ -230,7 +224,7 @@ class CognitoIdpResponse(BaseResponse):
def delete_identity_provider(self): def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
cognitoidp_backends[self.region].delete_identity_provider(user_pool_id, name) self.backend.delete_identity_provider(user_pool_id, name)
return "" return ""
# Group # Group
@ -241,7 +235,7 @@ class CognitoIdpResponse(BaseResponse):
role_arn = self._get_param("RoleArn") role_arn = self._get_param("RoleArn")
precedence = self._get_param("Precedence") precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group( group = self.backend.create_group(
user_pool_id, group_name, description, role_arn, precedence user_pool_id, group_name, description, role_arn, precedence
) )
@ -250,18 +244,18 @@ class CognitoIdpResponse(BaseResponse):
def get_group(self): def get_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) group = self.backend.get_group(user_pool_id, group_name)
return json.dumps({"Group": group.to_json()}) return json.dumps({"Group": group.to_json()})
def list_groups(self): def list_groups(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id) groups = self.backend.list_groups(user_pool_id)
return json.dumps({"Groups": [group.to_json() for group in groups]}) return json.dumps({"Groups": [group.to_json() for group in groups]})
def delete_group(self): def delete_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].delete_group(user_pool_id, group_name) self.backend.delete_group(user_pool_id, group_name)
return "" return ""
def update_group(self): def update_group(self):
@ -271,7 +265,7 @@ class CognitoIdpResponse(BaseResponse):
role_arn = self._get_param("RoleArn") role_arn = self._get_param("RoleArn")
precedence = self._get_param("Precedence") precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].update_group( group = self.backend.update_group(
user_pool_id, group_name, description, role_arn, precedence user_pool_id, group_name, description, role_arn, precedence
) )
@ -282,26 +276,20 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group( self.backend.admin_add_user_to_group(user_pool_id, group_name, username)
user_pool_id, group_name, username
)
return "" return ""
def list_users_in_group(self): def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group( users = self.backend.list_users_in_group(user_pool_id, group_name)
user_pool_id, group_name
)
return json.dumps({"Users": [user.to_json(extended=True) for user in users]}) return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self): def admin_list_groups_for_user(self):
username = self._get_param("Username") username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user( groups = self.backend.admin_list_groups_for_user(user_pool_id, username)
user_pool_id, username
)
return json.dumps({"Groups": [group.to_json() for group in groups]}) return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self): def admin_remove_user_from_group(self):
@ -309,18 +297,14 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group( self.backend.admin_remove_user_from_group(user_pool_id, group_name, username)
user_pool_id, group_name, username
)
return "" return ""
def admin_reset_user_password(self): def admin_reset_user_password(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends[self.region].admin_reset_user_password( self.backend.admin_reset_user_password(user_pool_id, username)
user_pool_id, username
)
return "" return ""
# User # User
@ -329,7 +313,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
message_action = self._get_param("MessageAction") message_action = self._get_param("MessageAction")
temporary_password = self._get_param("TemporaryPassword") temporary_password = self._get_param("TemporaryPassword")
user = cognitoidp_backends[self.region].admin_create_user( user = self.backend.admin_create_user(
user_pool_id, user_pool_id,
username, username,
message_action, message_action,
@ -342,14 +326,12 @@ class CognitoIdpResponse(BaseResponse):
def admin_confirm_sign_up(self): def admin_confirm_sign_up(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
return cognitoidp_backends[self.region].admin_confirm_sign_up( return self.backend.admin_confirm_sign_up(user_pool_id, username)
user_pool_id, username
)
def admin_get_user(self): def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) user = self.backend.admin_get_user(user_pool_id, username)
return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes")) return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
def get_user(self): def get_user(self):
@ -363,7 +345,7 @@ class CognitoIdpResponse(BaseResponse):
token = self._get_param("PaginationToken") token = self._get_param("PaginationToken")
filt = self._get_param("Filter") filt = self._get_param("Filter")
attributes_to_get = self._get_param("AttributesToGet") attributes_to_get = self._get_param("AttributesToGet")
users, token = cognitoidp_backends[self.region].list_users( users, token = self.backend.list_users(
user_pool_id, limit=limit, pagination_token=token user_pool_id, limit=limit, pagination_token=token
) )
if filt: if filt:
@ -420,19 +402,19 @@ class CognitoIdpResponse(BaseResponse):
def admin_disable_user(self): def admin_disable_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends[self.region].admin_disable_user(user_pool_id, username) self.backend.admin_disable_user(user_pool_id, username)
return "" return ""
def admin_enable_user(self): def admin_enable_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends[self.region].admin_enable_user(user_pool_id, username) self.backend.admin_enable_user(user_pool_id, username)
return "" return ""
def admin_delete_user(self): def admin_delete_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends[self.region].admin_delete_user(user_pool_id, username) self.backend.admin_delete_user(user_pool_id, username)
return "" return ""
def admin_initiate_auth(self): def admin_initiate_auth(self):
@ -441,7 +423,7 @@ class CognitoIdpResponse(BaseResponse):
auth_flow = self._get_param("AuthFlow") auth_flow = self._get_param("AuthFlow")
auth_parameters = self._get_param("AuthParameters") auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth( auth_result = self.backend.admin_initiate_auth(
user_pool_id, client_id, auth_flow, auth_parameters user_pool_id, client_id, auth_flow, auth_parameters
) )
@ -501,31 +483,25 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
attributes = self._get_param("UserAttributes") attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes( self.backend.admin_update_user_attributes(user_pool_id, username, attributes)
user_pool_id, username, attributes
)
return "" return ""
def admin_delete_user_attributes(self): def admin_delete_user_attributes(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
attributes = self._get_param("UserAttributeNames") attributes = self._get_param("UserAttributeNames")
cognitoidp_backends[self.region].admin_delete_user_attributes( self.backend.admin_delete_user_attributes(user_pool_id, username, attributes)
user_pool_id, username, attributes
)
return "" return ""
def admin_user_global_sign_out(self): def admin_user_global_sign_out(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
cognitoidp_backends[self.region].admin_user_global_sign_out( self.backend.admin_user_global_sign_out(user_pool_id, username)
user_pool_id, username
)
return "" return ""
def global_sign_out(self): def global_sign_out(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
cognitoidp_backends[self.region].global_sign_out(access_token) self.backend.global_sign_out(access_token)
return "" return ""
# Resource Server # Resource Server
@ -534,7 +510,7 @@ class CognitoIdpResponse(BaseResponse):
identifier = self._get_param("Identifier") identifier = self._get_param("Identifier")
name = self._get_param("Name") name = self._get_param("Name")
scopes = self._get_param("Scopes") scopes = self._get_param("Scopes")
resource_server = cognitoidp_backends[self.region].create_resource_server( resource_server = self.backend.create_resource_server(
user_pool_id, identifier, name, scopes user_pool_id, identifier, name, scopes
) )
return json.dumps({"ResourceServer": resource_server.to_json()}) return json.dumps({"ResourceServer": resource_server.to_json()})
@ -575,19 +551,19 @@ class CognitoIdpResponse(BaseResponse):
def associate_software_token(self): def associate_software_token(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
result = cognitoidp_backends[self.region].associate_software_token(access_token) result = self.backend.associate_software_token(access_token)
return json.dumps(result) return json.dumps(result)
def verify_software_token(self): def verify_software_token(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
result = cognitoidp_backends[self.region].verify_software_token(access_token) result = self.backend.verify_software_token(access_token)
return json.dumps(result) return json.dumps(result)
def set_user_mfa_preference(self): def set_user_mfa_preference(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings") software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings") sms_mfa_settings = self._get_param("SMSMfaSettings")
cognitoidp_backends[self.region].set_user_mfa_preference( self.backend.set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings access_token, software_token_mfa_settings, sms_mfa_settings
) )
return "" return ""
@ -597,7 +573,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings") software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings") sms_mfa_settings = self._get_param("SMSMfaSettings")
cognitoidp_backends[self.region].admin_set_user_mfa_preference( self.backend.admin_set_user_mfa_preference(
user_pool_id, username, software_token_mfa_settings, sms_mfa_settings user_pool_id, username, software_token_mfa_settings, sms_mfa_settings
) )
return "" return ""
@ -607,7 +583,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
password = self._get_param("Password") password = self._get_param("Password")
permanent = self._get_param("Permanent") permanent = self._get_param("Permanent")
cognitoidp_backends[self.region].admin_set_user_password( self.backend.admin_set_user_password(
user_pool_id, username, password, permanent user_pool_id, username, password, permanent
) )
return "" return ""
@ -615,17 +591,13 @@ class CognitoIdpResponse(BaseResponse):
def add_custom_attributes(self): def add_custom_attributes(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
custom_attributes = self._get_param("CustomAttributes") custom_attributes = self._get_param("CustomAttributes")
cognitoidp_backends[self.region].add_custom_attributes( self.backend.add_custom_attributes(user_pool_id, custom_attributes)
user_pool_id, custom_attributes
)
return "" return ""
def update_user_attributes(self): def update_user_attributes(self):
access_token = self._get_param("AccessToken") access_token = self._get_param("AccessToken")
attributes = self._get_param("UserAttributes") attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].update_user_attributes( self.backend.update_user_attributes(access_token, attributes)
access_token, attributes
)
return json.dumps({}) return json.dumps({})

View File

@ -1468,14 +1468,11 @@ class ConfigBackend(BaseBackend):
backend_query_region = ( backend_query_region = (
backend_region # Always provide the backend this request arrived from. backend_region # Always provide the backend this request arrived from.
) )
print(RESOURCE_MAP[resource_type].backends)
if RESOURCE_MAP[resource_type].backends.get("global"): if RESOURCE_MAP[resource_type].backends.get("global"):
print("yes, its global")
backend_region = "global" backend_region = "global"
# If the backend region isn't implemented then we won't find the item: # If the backend region isn't implemented then we won't find the item:
if not RESOURCE_MAP[resource_type].backends.get(backend_region): if not RESOURCE_MAP[resource_type].backends.get(backend_region):
print(f"cant find {backend_region} for {resource_type}")
raise ResourceNotDiscoveredException(resource_type, resource_id) raise ResourceNotDiscoveredException(resource_type, resource_id)
# Get the item: # Get the item:
@ -1483,7 +1480,6 @@ class ConfigBackend(BaseBackend):
resource_id, backend_region=backend_query_region resource_id, backend_region=backend_query_region
) )
if not item: if not item:
print("item not found")
raise ResourceNotDiscoveredException(resource_type, resource_id) raise ResourceNotDiscoveredException(resource_type, resource_id)
item["accountId"] = get_account_id() item["accountId"] = get_account_id()

View File

@ -204,7 +204,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def dispatch(cls, *args, **kwargs): def dispatch(cls, *args, **kwargs):
return cls()._dispatch(*args, **kwargs) return cls()._dispatch(*args, **kwargs)
def setup_class(self, request, full_url, headers): def setup_class(self, request, full_url, headers, use_raw_body=False):
"""
use_raw_body: Use incoming bytes if True, encode to string otherwise
"""
querystring = OrderedDict() querystring = OrderedDict()
if hasattr(request, "body"): if hasattr(request, "body"):
# Boto # Boto
@ -222,7 +225,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring[key] = [value] querystring[key] = [value]
raw_body = self.body raw_body = self.body
if isinstance(self.body, bytes): if isinstance(self.body, bytes) and not use_raw_body:
self.body = self.body.decode("utf-8") self.body = self.body.decode("utf-8")
if not querystring: if not querystring:
@ -244,7 +247,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
flat = flatten_json_request_body("", decoded, input_spec) flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items(): for key, value in flat.items():
querystring[key] = [value] querystring[key] = [value]
elif self.body: elif self.body and not use_raw_body:
try: try:
querystring.update( querystring.update(
OrderedDict( OrderedDict(
@ -254,7 +257,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
) )
) )
except (UnicodeEncodeError, UnicodeDecodeError): except (UnicodeEncodeError, UnicodeDecodeError, AttributeError):
pass # ignore encoding errors, as the body may not contain a legitimate querystring pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring: if not querystring:
querystring.update(headers) querystring.update(headers)

View File

@ -412,6 +412,54 @@ def extract_region_from_aws_authorization(string):
backend_lock = RLock() backend_lock = RLock()
class AccountSpecificBackend(dict):
"""
Dictionary storing the data for a service in a specific account.
Data access pattern:
account_specific_backend[region: str] = backend: BaseBackend
"""
def __init__(
self, service_name, account_id, backend, use_boto3_regions, additional_regions
):
self.service_name = service_name
self.account_id = account_id
self.backend = backend
self.regions = []
if use_boto3_regions:
sess = Session()
self.regions.extend(sess.get_available_regions(service_name))
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
self.regions.extend(additional_regions or [])
def reset(self):
for region_specific_backend in self.values():
region_specific_backend.reset()
def __contains__(self, region):
return region in self.regions or region in self.keys()
def __getitem__(self, region_name):
if region_name in self.keys():
return super().__getitem__(region_name)
# Create the backend for a specific region
with backend_lock:
if region_name in self.regions and region_name not in self.keys():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
if region_name not in self.regions and allow_unknown_region():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
return super().__getitem__(region_name)
class BackendDict(dict): class BackendDict(dict):
""" """
Data Structure to store everything related to a specific service. Data Structure to store everything related to a specific service.
@ -484,51 +532,3 @@ class BackendDict(dict):
use_boto3_regions=self._use_boto3_regions, use_boto3_regions=self._use_boto3_regions,
additional_regions=self._additional_regions, additional_regions=self._additional_regions,
) )
class AccountSpecificBackend(dict):
"""
Dictionary storing the data for a service in a specific account.
Data access pattern:
account_specific_backend[region: str] = backend: BaseBackend
"""
def __init__(
self, service_name, account_id, backend, use_boto3_regions, additional_regions
):
self.service_name = service_name
self.account_id = account_id
self.backend = backend
self.regions = []
if use_boto3_regions:
sess = Session()
self.regions.extend(sess.get_available_regions(service_name))
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
self.regions.extend(additional_regions or [])
def reset(self):
for region_specific_backend in self.values():
region_specific_backend.reset()
def __contains__(self, region):
return region in self.regions or region in self.keys()
def __getitem__(self, region_name):
if region_name in self.keys():
return super().__getitem__(region_name)
# Create the backend for a specific region
with backend_lock:
if region_name in self.regions and region_name not in self.keys():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
if region_name not in self.regions and allow_unknown_region():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
return super().__getitem__(region_name)

View File

@ -5,23 +5,15 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse): class DataPipelineResponse(BaseResponse):
@property
def parameters(self):
# TODO this should really be moved to core/responses.py
if self.body:
return json.loads(self.body)
else:
return self.querystring
@property @property
def datapipeline_backend(self): def datapipeline_backend(self):
return datapipeline_backends[self.region] return datapipeline_backends[self.region]
def create_pipeline(self): def create_pipeline(self):
name = self.parameters.get("name") name = self._get_param("name")
unique_id = self.parameters.get("uniqueId") unique_id = self._get_param("uniqueId")
description = self.parameters.get("description", "") description = self._get_param("description", "")
tags = self.parameters.get("tags", []) tags = self._get_param("tags", [])
pipeline = self.datapipeline_backend.create_pipeline( pipeline = self.datapipeline_backend.create_pipeline(
name, unique_id, description=description, tags=tags name, unique_id, description=description, tags=tags
) )
@ -31,7 +23,7 @@ class DataPipelineResponse(BaseResponse):
pipelines = list(self.datapipeline_backend.list_pipelines()) pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50 max_pipelines = 50
marker = self.parameters.get("marker") marker = self._get_param("marker")
if marker: if marker:
start = pipeline_ids.index(marker) + 1 start = pipeline_ids.index(marker) + 1
else: else:
@ -53,7 +45,7 @@ class DataPipelineResponse(BaseResponse):
) )
def describe_pipelines(self): def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"] pipeline_ids = self._get_param("pipelineIds")
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps( return json.dumps(
@ -61,19 +53,19 @@ class DataPipelineResponse(BaseResponse):
) )
def delete_pipeline(self): def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self._get_param("pipelineId")
self.datapipeline_backend.delete_pipeline(pipeline_id) self.datapipeline_backend.delete_pipeline(pipeline_id)
return json.dumps({}) return json.dumps({})
def put_pipeline_definition(self): def put_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self._get_param("pipelineId")
pipeline_objects = self.parameters["pipelineObjects"] pipeline_objects = self._get_param("pipelineObjects")
self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects) self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
return json.dumps({"errored": False}) return json.dumps({"errored": False})
def get_pipeline_definition(self): def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self._get_param("pipelineId")
pipeline_definition = self.datapipeline_backend.get_pipeline_definition( pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id pipeline_id
) )
@ -86,8 +78,8 @@ class DataPipelineResponse(BaseResponse):
) )
def describe_objects(self): def describe_objects(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self._get_param("pipelineId")
object_ids = self.parameters["objectIds"] object_ids = self._get_param("objectIds")
pipeline_objects = self.datapipeline_backend.describe_objects( pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id object_ids, pipeline_id
@ -103,6 +95,6 @@ class DataPipelineResponse(BaseResponse):
) )
def activate_pipeline(self): def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self._get_param("pipelineId")
self.datapipeline_backend.activate_pipeline(pipeline_id) self.datapipeline_backend.activate_pipeline(pipeline_id)
return json.dumps({}) return json.dumps({})

View File

@ -397,4 +397,3 @@ dynamodb_backends = BackendDict(
use_boto3_regions=False, use_boto3_regions=False,
additional_regions=["global"], additional_regions=["global"],
) )
dynamodb_backend = dynamodb_backends["global"]

View File

@ -2,7 +2,7 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backend, dynamo_json_dump from .models import dynamodb_backends, dynamo_json_dump
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
@ -36,15 +36,19 @@ class DynamoHandler(BaseResponse):
else: else:
return 404, self.response_headers, "" return 404, self.response_headers, ""
@property
def backend(self):
return dynamodb_backends["global"]
def list_tables(self): def list_tables(self):
body = self.body body = self.body
limit = body.get("Limit") limit = body.get("Limit")
if body.get("ExclusiveStartTableName"): if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName") last = body.get("ExclusiveStartTableName")
start = list(dynamodb_backend.tables.keys()).index(last) + 1 start = list(self.backend.tables.keys()).index(last) + 1
else: else:
start = 0 start = 0
all_tables = list(dynamodb_backend.tables.keys()) all_tables = list(self.backend.tables.keys())
if limit: if limit:
tables = all_tables[start : start + limit] tables = all_tables[start : start + limit]
else: else:
@ -71,7 +75,7 @@ class DynamoHandler(BaseResponse):
read_units = throughput["ReadCapacityUnits"] read_units = throughput["ReadCapacityUnits"]
write_units = throughput["WriteCapacityUnits"] write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.create_table( table = self.backend.create_table(
name, name,
hash_key_attr=hash_key_attr, hash_key_attr=hash_key_attr,
hash_key_type=hash_key_type, hash_key_type=hash_key_type,
@ -84,7 +88,7 @@ class DynamoHandler(BaseResponse):
def delete_table(self): def delete_table(self):
name = self.body["TableName"] name = self.body["TableName"]
table = dynamodb_backend.delete_table(name) table = self.backend.delete_table(name)
if table: if table:
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
else: else:
@ -96,7 +100,7 @@ class DynamoHandler(BaseResponse):
throughput = self.body["ProvisionedThroughput"] throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"] new_read_units = throughput["ReadCapacityUnits"]
new_write_units = throughput["WriteCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.update_table_throughput( table = self.backend.update_table_throughput(
name, new_read_units, new_write_units name, new_read_units, new_write_units
) )
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
@ -104,7 +108,7 @@ class DynamoHandler(BaseResponse):
def describe_table(self): def describe_table(self):
name = self.body["TableName"] name = self.body["TableName"]
try: try:
table = dynamodb_backend.tables[name] table = self.backend.tables[name]
except KeyError: except KeyError:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
@ -113,7 +117,7 @@ class DynamoHandler(BaseResponse):
def put_item(self): def put_item(self):
name = self.body["TableName"] name = self.body["TableName"]
item = self.body["Item"] item = self.body["Item"]
result = dynamodb_backend.put_item(name, item) result = self.backend.put_item(name, item)
if result: if result:
item_dict = result.to_json() item_dict = result.to_json()
item_dict["ConsumedCapacityUnits"] = 1 item_dict["ConsumedCapacityUnits"] = 1
@ -132,12 +136,12 @@ class DynamoHandler(BaseResponse):
if request_type == "PutRequest": if request_type == "PutRequest":
item = request["Item"] item = request["Item"]
dynamodb_backend.put_item(table_name, item) self.backend.put_item(table_name, item)
elif request_type == "DeleteRequest": elif request_type == "DeleteRequest":
key = request["Key"] key = request["Key"]
hash_key = key["HashKeyElement"] hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
item = dynamodb_backend.delete_item(table_name, hash_key, range_key) self.backend.delete_item(table_name, hash_key, range_key)
response = { response = {
"Responses": { "Responses": {
@ -156,7 +160,7 @@ class DynamoHandler(BaseResponse):
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
attrs_to_get = self.body.get("AttributesToGet") attrs_to_get = self.body.get("AttributesToGet")
try: try:
item = dynamodb_backend.get_item(name, hash_key, range_key) item = self.backend.get_item(name, hash_key, range_key)
except ValueError: except ValueError:
er = "com.amazon.coral.validate#ValidationException" er = "com.amazon.coral.validate#ValidationException"
return self.error(er, status=400) return self.error(er, status=400)
@ -181,7 +185,7 @@ class DynamoHandler(BaseResponse):
for key in keys: for key in keys:
hash_key = key["HashKeyElement"] hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
item = dynamodb_backend.get_item(table_name, hash_key, range_key) item = self.backend.get_item(table_name, hash_key, range_key)
if item: if item:
item_describe = item.describe_attrs(attributes_to_get) item_describe = item.describe_attrs(attributes_to_get)
items.append(item_describe) items.append(item_describe)
@ -202,9 +206,7 @@ class DynamoHandler(BaseResponse):
range_comparison = None range_comparison = None
range_values = [] range_values = []
items, _ = dynamodb_backend.query( items, _ = self.backend.query(name, hash_key, range_comparison, range_values)
name, hash_key, range_comparison, range_values
)
if items is None: if items is None:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
@ -236,7 +238,7 @@ class DynamoHandler(BaseResponse):
comparison_values = scan_filter.get("AttributeValueList", []) comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values) filters[attribute_name] = (comparison_operator, comparison_values)
items, scanned_count, _ = dynamodb_backend.scan(name, filters) items, scanned_count, _ = self.backend.scan(name, filters)
if items is None: if items is None:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
@ -263,7 +265,7 @@ class DynamoHandler(BaseResponse):
hash_key = key["HashKeyElement"] hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
return_values = self.body.get("ReturnValues", "") return_values = self.body.get("ReturnValues", "")
item = dynamodb_backend.delete_item(name, hash_key, range_key) item = self.backend.delete_item(name, hash_key, range_key)
if item: if item:
if return_values == "ALL_OLD": if return_values == "ALL_OLD":
item_dict = item.to_json() item_dict = item.to_json()
@ -282,7 +284,7 @@ class DynamoHandler(BaseResponse):
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
updates = self.body["AttributeUpdates"] updates = self.body["AttributeUpdates"]
item = dynamodb_backend.update_item(name, hash_key, range_key, updates) item = self.backend.update_item(name, hash_key, range_key, updates)
if item: if item:
item_dict = item.to_json() item_dict = item.to_json()

View File

@ -214,12 +214,12 @@ class FlowLogsBackend:
self.get_network_interface(resource_id) self.get_network_interface(resource_id)
if log_destination_type == "s3": if log_destination_type == "s3":
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from moto.s3.exceptions import MissingBucket from moto.s3.exceptions import MissingBucket
arn = log_destination.split(":", 5)[5] arn = log_destination.split(":", 5)[5]
try: try:
s3_backend.get_bucket(arn) s3_backends["global"].get_bucket(arn)
except MissingBucket: except MissingBucket:
# Instead of creating FlowLog report # Instead of creating FlowLog report
# the unsuccessful status for the # the unsuccessful status for the

View File

@ -1547,9 +1547,9 @@ Member must satisfy regular expression pattern: {}".format(
except AWSResourceNotFoundException: except AWSResourceNotFoundException:
pass pass
from moto.iam import iam_backend from moto.iam import iam_backends
cert = iam_backend.get_certificate_by_arn(certificate_arn) cert = iam_backends["global"].get_certificate_by_arn(certificate_arn)
if cert is not None: if cert is not None:
return True return True

View File

@ -37,7 +37,7 @@ from moto.firehose.exceptions import (
ResourceNotFoundException, ResourceNotFoundException,
ValidationException, ValidationException,
) )
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
MAX_TAGS_PER_DELIVERY_STREAM = 50 MAX_TAGS_PER_DELIVERY_STREAM = 50
@ -447,7 +447,7 @@ class FirehoseBackend(BaseBackend):
batched_data = b"".join([b64decode(r["Data"]) for r in records]) batched_data = b"".join([b64decode(r["Data"]) for r in records])
try: try:
s3_backend.put_object(bucket_name, object_path, batched_data) s3_backends["global"].put_object(bucket_name, object_path, batched_data)
except Exception as exc: except Exception as exc:
# This could be better ... # This could be better ...
raise RuntimeError( raise RuntimeError(

View File

@ -1,5 +1,4 @@
from .models import iam_backend from .models import iam_backends
from ..core.models import base_decorator from ..core.models import base_decorator
iam_backends = {"global": iam_backend}
mock_iam = base_decorator(iam_backends) mock_iam = base_decorator(iam_backends)

View File

@ -39,8 +39,8 @@ from moto.s3.exceptions import (
BucketSignatureDoesNotMatchError, BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError, S3SignatureDoesNotMatchError,
) )
from moto.sts.models import sts_backend from moto.sts.models import sts_backends
from .models import iam_backend, Policy from .models import iam_backends, Policy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -53,8 +53,12 @@ def create_access_key(access_key_id, headers):
class IAMUserAccessKey(object): class IAMUserAccessKey(object):
@property
def backend(self):
return iam_backends["global"]
def __init__(self, access_key_id, headers): def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users("/", None, None) iam_users = self.backend.list_users("/", None, None)
for iam_user in iam_users: for iam_user in iam_users:
for access_key in iam_user.access_keys: for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id: if access_key.access_key_id == access_key_id:
@ -78,28 +82,30 @@ class IAMUserAccessKey(object):
def collect_policies(self): def collect_policies(self):
user_policies = [] user_policies = []
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name) inline_policy_names = self.backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names: for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy( inline_policy = self.backend.get_user_policy(
self._owner_user_name, inline_policy_name self._owner_user_name, inline_policy_name
) )
user_policies.append(inline_policy) user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies( attached_policies, _ = self.backend.list_attached_user_policies(
self._owner_user_name self._owner_user_name
) )
user_policies += attached_policies user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name) user_groups = self.backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups: for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name) inline_group_policy_names = self.backend.list_group_policies(
user_group.name
)
for inline_group_policy_name in inline_group_policy_names: for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy( inline_user_group_policy = self.backend.get_group_policy(
user_group.name, inline_group_policy_name user_group.name, inline_group_policy_name
) )
user_policies.append(inline_user_group_policy) user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies( attached_group_policies, _ = self.backend.list_attached_group_policies(
user_group.name user_group.name
) )
user_policies += attached_group_policies user_policies += attached_group_policies
@ -108,8 +114,12 @@ class IAMUserAccessKey(object):
class AssumedRoleAccessKey(object): class AssumedRoleAccessKey(object):
@property
def backend(self):
return iam_backends["global"]
def __init__(self, access_key_id, headers): def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles: for assumed_role in sts_backends["global"].assumed_roles:
if assumed_role.access_key_id == access_key_id: if assumed_role.access_key_id == access_key_id:
self._access_key_id = access_key_id self._access_key_id = access_key_id
self._secret_access_key = assumed_role.secret_access_key self._secret_access_key = assumed_role.secret_access_key
@ -139,14 +149,14 @@ class AssumedRoleAccessKey(object):
def collect_policies(self): def collect_policies(self):
role_policies = [] role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name) inline_policy_names = self.backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names: for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy( _, inline_policy = self.backend.get_role_policy(
self._owner_role_name, inline_policy_name self._owner_role_name, inline_policy_name
) )
role_policies.append(inline_policy) role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies( attached_policies, _ = self.backend.list_attached_role_policies(
self._owner_role_name self._owner_role_name
) )
role_policies += attached_policies role_policies += attached_policies

View File

@ -19,6 +19,7 @@ from moto.core import BaseBackend, BaseModel, get_account_id, CloudFormationMode
from moto.core.utils import ( from moto.core.utils import (
iso_8601_datetime_without_milliseconds, iso_8601_datetime_without_milliseconds,
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
BackendDict,
) )
from moto.iam.policy_validation import IAMPolicyDocumentValidator from moto.iam.policy_validation import IAMPolicyDocumentValidator
from moto.utilities.utils import md5_hash from moto.utilities.utils import md5_hash
@ -362,7 +363,7 @@ class ManagedPolicy(Policy, CloudFormationModel):
role_names = properties.get("Roles", []) role_names = properties.get("Roles", [])
tags = properties.get("Tags", {}) tags = properties.get("Tags", {})
policy = iam_backend.create_policy( policy = iam_backends["global"].create_policy(
description=description, description=description,
path=path, path=path,
policy_document=policy_document, policy_document=policy_document,
@ -370,13 +371,17 @@ class ManagedPolicy(Policy, CloudFormationModel):
tags=tags, tags=tags,
) )
for group_name in group_names: for group_name in group_names:
iam_backend.attach_group_policy( iam_backends["global"].attach_group_policy(
group_name=group_name, policy_arn=policy.arn group_name=group_name, policy_arn=policy.arn
) )
for user_name in user_names: for user_name in user_names:
iam_backend.attach_user_policy(user_name=user_name, policy_arn=policy.arn) iam_backends["global"].attach_user_policy(
user_name=user_name, policy_arn=policy.arn
)
for role_name in role_names: for role_name in role_names:
iam_backend.attach_role_policy(role_name=role_name, policy_arn=policy.arn) iam_backends["global"].attach_role_policy(
role_name=role_name, policy_arn=policy.arn
)
return policy return policy
@property @property
@ -466,7 +471,7 @@ class InlinePolicy(CloudFormationModel):
role_names = properties.get("Roles") role_names = properties.get("Roles")
group_names = properties.get("Groups") group_names = properties.get("Groups")
return iam_backend.create_inline_policy( return iam_backends["global"].create_inline_policy(
resource_name, resource_name,
policy_name, policy_name,
policy_document, policy_document,
@ -502,7 +507,7 @@ class InlinePolicy(CloudFormationModel):
role_names = properties.get("Roles") role_names = properties.get("Roles")
group_names = properties.get("Groups") group_names = properties.get("Groups")
return iam_backend.update_inline_policy( return iam_backends["global"].update_inline_policy(
original_resource.name, original_resource.name,
policy_name, policy_name,
policy_document, policy_document,
@ -515,7 +520,7 @@ class InlinePolicy(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
iam_backend.delete_inline_policy(resource_name) iam_backends["global"].delete_inline_policy(resource_name)
@staticmethod @staticmethod
def is_replacement_update(properties): def is_replacement_update(properties):
@ -606,7 +611,7 @@ class Role(CloudFormationModel):
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
role_name = properties.get("RoleName", resource_name) role_name = properties.get("RoleName", resource_name)
role = iam_backend.create_role( role = iam_backends["global"].create_role(
role_name=role_name, role_name=role_name,
assume_role_policy_document=properties["AssumeRolePolicyDocument"], assume_role_policy_document=properties["AssumeRolePolicyDocument"],
path=properties.get("Path", "/"), path=properties.get("Path", "/"),
@ -628,14 +633,14 @@ class Role(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
for profile in iam_backend.instance_profiles.values(): for profile in iam_backends["global"].instance_profiles.values():
profile.delete_role(role_name=resource_name) profile.delete_role(role_name=resource_name)
for role in iam_backend.roles.values(): for role in iam_backends["global"].roles.values():
if role.name == resource_name: if role.name == resource_name:
for arn in role.policies.keys(): for arn in role.policies.keys():
role.delete_policy(arn) role.delete_policy(arn)
iam_backend.delete_role(resource_name) iam_backends["global"].delete_role(resource_name)
@property @property
def arn(self): def arn(self):
@ -649,7 +654,10 @@ class Role(CloudFormationModel):
_managed_policies = [] _managed_policies = []
for key in self.managed_policies.keys(): for key in self.managed_policies.keys():
_managed_policies.append( _managed_policies.append(
{"policyArn": key, "policyName": iam_backend.managed_policies[key].name} {
"policyArn": key,
"policyName": iam_backends["global"].managed_policies[key].name,
}
) )
_role_policy_list = [] _role_policy_list = []
@ -659,7 +667,7 @@ class Role(CloudFormationModel):
) )
_instance_profiles = [] _instance_profiles = []
for key, instance_profile in iam_backend.instance_profiles.items(): for key, instance_profile in iam_backends["global"].instance_profiles.items():
for _ in instance_profile.roles: for _ in instance_profile.roles:
_instance_profiles.append(instance_profile.to_embedded_config_dict()) _instance_profiles.append(instance_profile.to_embedded_config_dict())
break break
@ -808,7 +816,7 @@ class InstanceProfile(CloudFormationModel):
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
role_names = properties["Roles"] role_names = properties["Roles"]
return iam_backend.create_instance_profile( return iam_backends["global"].create_instance_profile(
name=resource_name, name=resource_name,
path=properties.get("Path", "/"), path=properties.get("Path", "/"),
role_names=role_names, role_names=role_names,
@ -818,7 +826,7 @@ class InstanceProfile(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
iam_backend.delete_instance_profile(resource_name) iam_backends["global"].delete_instance_profile(resource_name)
def delete_role(self, role_name): def delete_role(self, role_name):
self.roles = [role for role in self.roles if role.name != role_name] self.roles = [role for role in self.roles if role.name != role_name]
@ -964,7 +972,7 @@ class AccessKey(CloudFormationModel):
user_name = properties.get("UserName") user_name = properties.get("UserName")
status = properties.get("Status", "Active") status = properties.get("Status", "Active")
return iam_backend.create_access_key(user_name, status=status) return iam_backends["global"].create_access_key(user_name, status=status)
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json(
@ -984,7 +992,7 @@ class AccessKey(CloudFormationModel):
else: # No Interruption else: # No Interruption
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
status = properties.get("Status") status = properties.get("Status")
return iam_backend.update_access_key( return iam_backends["global"].update_access_key(
original_resource.user_name, original_resource.access_key_id, status original_resource.user_name, original_resource.access_key_id, status
) )
@ -992,7 +1000,7 @@ class AccessKey(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
iam_backend.delete_access_key_by_name(resource_name) iam_backends["global"].delete_access_key_by_name(resource_name)
@staticmethod @staticmethod
def is_replacement_update(properties): def is_replacement_update(properties):
@ -1303,7 +1311,7 @@ class User(CloudFormationModel):
): ):
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
path = properties.get("Path") path = properties.get("Path")
user, _ = iam_backend.create_user(resource_name, path) user, _ = iam_backends["global"].create_user(resource_name, path)
return user return user
@classmethod @classmethod
@ -1334,7 +1342,7 @@ class User(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
iam_backend.delete_user(resource_name) iam_backends["global"].delete_user(resource_name)
@staticmethod @staticmethod
def is_replacement_update(properties): def is_replacement_update(properties):
@ -2043,7 +2051,7 @@ class IAMBackend(BaseBackend):
instance_profile_id = random_resource_id() instance_profile_id = random_resource_id()
roles = [iam_backend.get_role(role_name) for role_name in role_names] roles = [iam_backends["global"].get_role(role_name) for role_name in role_names]
instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags) instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags)
self.instance_profiles[name] = instance_profile self.instance_profiles[name] = instance_profile
return instance_profile return instance_profile
@ -2838,12 +2846,10 @@ class IAMBackend(BaseBackend):
return inline_policy return inline_policy
def get_inline_policy(self, policy_id): def get_inline_policy(self, policy_id):
inline_policy = None
try: try:
inline_policy = self.inline_policies[policy_id] return self.inline_policies[policy_id]
except KeyError: except KeyError:
raise IAMNotFoundException("Inline policy {0} not found".format(policy_id)) raise IAMNotFoundException("Inline policy {0} not found".format(policy_id))
return inline_policy
def update_inline_policy( def update_inline_policy(
self, self,
@ -2924,4 +2930,6 @@ class IAMBackend(BaseBackend):
return True return True
iam_backend = IAMBackend("global") iam_backends = BackendDict(
IAMBackend, "iam", use_boto3_regions=False, additional_regions=["global"]
)

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,7 @@ from moto.logs.exceptions import (
InvalidParameterException, InvalidParameterException,
LimitExceededException, LimitExceededException,
) )
from moto.s3.models import s3_backend from moto.s3.models import s3_backends
from .utils import PAGINATION_MODEL from .utils import PAGINATION_MODEL
MAX_RESOURCE_POLICIES_PER_REGION = 10 MAX_RESOURCE_POLICIES_PER_REGION = 10
@ -940,7 +940,7 @@ class LogsBackend(BaseBackend):
return query_id return query_id
def create_export_task(self, log_group_name, destination): def create_export_task(self, log_group_name, destination):
s3_backend.get_bucket(destination) s3_backends["global"].get_bucket(destination)
if log_group_name not in self.groups: if log_group_name not in self.groups:
raise ResourceNotFoundException() raise ResourceNotFoundException()
task_id = uuid.uuid4() task_id = uuid.uuid4()

View File

@ -5,7 +5,6 @@ from moto.core.responses import BaseResponse
from .exceptions import exception_handler from .exceptions import exception_handler
from .models import managedblockchain_backends from .models import managedblockchain_backends
from .utils import ( from .utils import (
region_from_managedblckchain_url,
networkid_from_managedblockchain_url, networkid_from_managedblockchain_url,
proposalid_from_managedblockchain_url, proposalid_from_managedblockchain_url,
invitationid_from_managedblockchain_url, invitationid_from_managedblockchain_url,
@ -15,29 +14,21 @@ from .utils import (
class ManagedBlockchainResponse(BaseResponse): class ManagedBlockchainResponse(BaseResponse):
def __init__(self, backend): @property
super().__init__() def backend(self):
self.backend = backend return managedblockchain_backends[self.region]
@classmethod
@exception_handler @exception_handler
def network_response(clazz, request, full_url, headers): def network_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._network_response(request, headers)
managedblockchain_backends[region_name]
)
return response_instance._network_response(request, headers)
def _network_response(self, request, headers): def _network_response(self, request, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
if method == "GET": if method == "GET":
return self._all_networks_response(headers) return self._all_networks_response(headers)
elif method == "POST": elif method == "POST":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._network_response_post(json_body, headers) return self._network_response_post(json_body, headers)
def _all_networks_response(self, headers): def _all_networks_response(self, headers):
@ -70,14 +61,10 @@ class ManagedBlockchainResponse(BaseResponse):
) )
return 200, headers, json.dumps(response) return 200, headers, json.dumps(response)
@classmethod
@exception_handler @exception_handler
def networkid_response(clazz, request, full_url, headers): def networkid_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._networkid_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._networkid_response(request, full_url, headers)
def _networkid_response(self, request, full_url, headers): def _networkid_response(self, request, full_url, headers):
method = request.method method = request.method
@ -92,26 +79,18 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, response return 200, headers, response
@classmethod
@exception_handler @exception_handler
def proposal_response(clazz, request, full_url, headers): def proposal_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._proposal_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._proposal_response(request, full_url, headers)
def _proposal_response(self, request, full_url, headers): def _proposal_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
if method == "GET": if method == "GET":
return self._all_proposals_response(network_id, headers) return self._all_proposals_response(network_id, headers)
elif method == "POST": elif method == "POST":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._proposal_response_post(network_id, json_body, headers) return self._proposal_response_post(network_id, json_body, headers)
def _all_proposals_response(self, network_id, headers): def _all_proposals_response(self, network_id, headers):
@ -134,14 +113,10 @@ class ManagedBlockchainResponse(BaseResponse):
) )
return 200, headers, json.dumps(response) return 200, headers, json.dumps(response)
@classmethod
@exception_handler @exception_handler
def proposalid_response(clazz, request, full_url, headers): def proposalid_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._proposalid_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._proposalid_response(request, full_url, headers)
def _proposalid_response(self, request, full_url, headers): def _proposalid_response(self, request, full_url, headers):
method = request.method method = request.method
@ -156,27 +131,19 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, response return 200, headers, response
@classmethod
@exception_handler @exception_handler
def proposal_votes_response(clazz, request, full_url, headers): def proposal_votes_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._proposal_votes_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._proposal_votes_response(request, full_url, headers)
def _proposal_votes_response(self, request, full_url, headers): def _proposal_votes_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
proposal_id = proposalid_from_managedblockchain_url(full_url) proposal_id = proposalid_from_managedblockchain_url(full_url)
if method == "GET": if method == "GET":
return self._all_proposal_votes_response(network_id, proposal_id, headers) return self._all_proposal_votes_response(network_id, proposal_id, headers)
elif method == "POST": elif method == "POST":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._proposal_votes_response_post( return self._proposal_votes_response_post(
network_id, proposal_id, json_body, headers network_id, proposal_id, json_body, headers
) )
@ -196,14 +163,10 @@ class ManagedBlockchainResponse(BaseResponse):
self.backend.vote_on_proposal(network_id, proposal_id, votermemberid, vote) self.backend.vote_on_proposal(network_id, proposal_id, votermemberid, vote)
return 200, headers, "" return 200, headers, ""
@classmethod
@exception_handler @exception_handler
def invitation_response(clazz, request, full_url, headers): def invitation_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._invitation_response(request, headers)
managedblockchain_backends[region_name]
)
return response_instance._invitation_response(request, headers)
def _invitation_response(self, request, headers): def _invitation_response(self, request, headers):
method = request.method method = request.method
@ -218,14 +181,10 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, response return 200, headers, response
@classmethod
@exception_handler @exception_handler
def invitationid_response(clazz, request, full_url, headers): def invitationid_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._invitationid_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._invitationid_response(request, full_url, headers)
def _invitationid_response(self, request, full_url, headers): def _invitationid_response(self, request, full_url, headers):
method = request.method method = request.method
@ -238,26 +197,18 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, "" return 200, headers, ""
@classmethod
@exception_handler @exception_handler
def member_response(clazz, request, full_url, headers): def member_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._member_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._member_response(request, full_url, headers)
def _member_response(self, request, full_url, headers): def _member_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
if method == "GET": if method == "GET":
return self._all_members_response(network_id, headers) return self._all_members_response(network_id, headers)
elif method == "POST": elif method == "POST":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._member_response_post(network_id, json_body, headers) return self._member_response_post(network_id, json_body, headers)
def _all_members_response(self, network_id, headers): def _all_members_response(self, network_id, headers):
@ -275,27 +226,19 @@ class ManagedBlockchainResponse(BaseResponse):
) )
return 200, headers, json.dumps(response) return 200, headers, json.dumps(response)
@classmethod
@exception_handler @exception_handler
def memberid_response(clazz, request, full_url, headers): def memberid_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._memberid_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._memberid_response(request, full_url, headers)
def _memberid_response(self, request, full_url, headers): def _memberid_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body) member_id = memberid_from_managedblockchain_request(full_url, self.body)
if method == "GET": if method == "GET":
return self._memberid_response_get(network_id, member_id, headers) return self._memberid_response_get(network_id, member_id, headers)
elif method == "PATCH": elif method == "PATCH":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._memberid_response_patch( return self._memberid_response_patch(
network_id, member_id, json_body, headers network_id, member_id, json_body, headers
) )
@ -318,32 +261,24 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json" headers["content-type"] = "application/json"
return 200, headers, "" return 200, headers, ""
@classmethod
@exception_handler @exception_handler
def node_response(clazz, request, full_url, headers): def node_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._node_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._node_response(request, full_url, headers)
def _node_response(self, request, full_url, headers): def _node_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True) querystring = parse_qs(parsed_url.query, keep_blank_values=True)
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body) member_id = memberid_from_managedblockchain_request(full_url, self.body)
if method == "GET": if method == "GET":
status = None status = None
if "status" in querystring: if "status" in querystring:
status = querystring["status"][0] status = querystring["status"][0]
return self._all_nodes_response(network_id, member_id, status, headers) return self._all_nodes_response(network_id, member_id, status, headers)
elif method == "POST": elif method == "POST":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._node_response_post(network_id, member_id, json_body, headers) return self._node_response_post(network_id, member_id, json_body, headers)
def _all_nodes_response(self, network_id, member_id, status, headers): def _all_nodes_response(self, network_id, member_id, status, headers):
@ -368,28 +303,20 @@ class ManagedBlockchainResponse(BaseResponse):
) )
return 200, headers, json.dumps(response) return 200, headers, json.dumps(response)
@classmethod
@exception_handler @exception_handler
def nodeid_response(clazz, request, full_url, headers): def nodeid_response(self, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url) self.setup_class(request, full_url, headers)
response_instance = ManagedBlockchainResponse( return self._nodeid_response(request, full_url, headers)
managedblockchain_backends[region_name]
)
return response_instance._nodeid_response(request, full_url, headers)
def _nodeid_response(self, request, full_url, headers): def _nodeid_response(self, request, full_url, headers):
method = request.method method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url) network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body) member_id = memberid_from_managedblockchain_request(full_url, self.body)
node_id = nodeid_from_managedblockchain_url(full_url) node_id = nodeid_from_managedblockchain_url(full_url)
if method == "GET": if method == "GET":
return self._nodeid_response_get(network_id, member_id, node_id, headers) return self._nodeid_response_get(network_id, member_id, node_id, headers)
elif method == "PATCH": elif method == "PATCH":
json_body = json.loads(body.decode("utf-8")) json_body = json.loads(self.body)
return self._nodeid_response_patch( return self._nodeid_response_patch(
network_id, member_id, node_id, json_body, headers network_id, member_id, node_id, json_body, headers
) )

View File

@ -3,19 +3,19 @@ from .responses import ManagedBlockchainResponse
url_bases = [r"https?://managedblockchain\.(.+)\.amazonaws.com"] url_bases = [r"https?://managedblockchain\.(.+)\.amazonaws.com"]
url_paths = { url_paths = {
"{0}/networks$": ManagedBlockchainResponse.network_response, "{0}/networks$": ManagedBlockchainResponse().network_response,
"{0}/networks/(?P<networkid>[^/.]+)$": ManagedBlockchainResponse.networkid_response, "{0}/networks/(?P<networkid>[^/.]+)$": ManagedBlockchainResponse().networkid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals$": ManagedBlockchainResponse.proposal_response, "{0}/networks/(?P<networkid>[^/.]+)/proposals$": ManagedBlockchainResponse().proposal_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)$": ManagedBlockchainResponse.proposalid_response, "{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)$": ManagedBlockchainResponse().proposalid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)/votes$": ManagedBlockchainResponse.proposal_votes_response, "{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)/votes$": ManagedBlockchainResponse().proposal_votes_response,
"{0}/invitations$": ManagedBlockchainResponse.invitation_response, "{0}/invitations$": ManagedBlockchainResponse().invitation_response,
"{0}/invitations/(?P<invitationid>[^/.]+)$": ManagedBlockchainResponse.invitationid_response, "{0}/invitations/(?P<invitationid>[^/.]+)$": ManagedBlockchainResponse().invitationid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members$": ManagedBlockchainResponse.member_response, "{0}/networks/(?P<networkid>[^/.]+)/members$": ManagedBlockchainResponse().member_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)$": ManagedBlockchainResponse.memberid_response, "{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)$": ManagedBlockchainResponse().memberid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes$": ManagedBlockchainResponse.node_response, "{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes?(?P<querys>[^/.]+)$": ManagedBlockchainResponse.node_response, "{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes?(?P<querys>[^/.]+)$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse.nodeid_response, "{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse().nodeid_response,
# >= botocore 1.19.41 (API change - memberId is now part of query-string or body) # >= botocore 1.19.41 (API change - memberId is now part of query-string or body)
"{0}/networks/(?P<networkid>[^/.]+)/nodes$": ManagedBlockchainResponse.node_response, "{0}/networks/(?P<networkid>[^/.]+)/nodes$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse.nodeid_response, "{0}/networks/(?P<networkid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse().nodeid_response,
} }

View File

@ -6,14 +6,6 @@ import string
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
def region_from_managedblckchain_url(url):
domain = urlparse(url).netloc
region = "us-east-1"
if "." in domain:
region = domain.split(".")[1]
return region
def networkid_from_managedblockchain_url(full_url): def networkid_from_managedblockchain_url(full_url):
id_search = re.search(r"\/n-[A-Z0-9]{26}", full_url, re.IGNORECASE) id_search = re.search(r"\/n-[A-Z0-9]{26}", full_url, re.IGNORECASE)
return_id = None return_id = None

View File

@ -121,7 +121,6 @@ class FakeKey(BaseModel, ManagedState):
lock_mode=None, lock_mode=None,
lock_legal_status=None, lock_legal_status=None,
lock_until=None, lock_until=None,
s3_backend=None,
): ):
ManagedState.__init__( ManagedState.__init__(
self, self,
@ -162,8 +161,6 @@ class FakeKey(BaseModel, ManagedState):
# Default metadata values # Default metadata values
self._metadata["Content-Type"] = "binary/octet-stream" self._metadata["Content-Type"] = "binary/octet-stream"
self.s3_backend = s3_backend
def safe_name(self, encoding_type=None): def safe_name(self, encoding_type=None):
if encoding_type == "url": if encoding_type == "url":
return urllib.parse.quote(self.name, safe="") return urllib.parse.quote(self.name, safe="")
@ -292,7 +289,7 @@ class FakeKey(BaseModel, ManagedState):
res["x-amz-object-lock-retain-until-date"] = self.lock_until res["x-amz-object-lock-retain-until-date"] = self.lock_until
if self.lock_mode: if self.lock_mode:
res["x-amz-object-lock-mode"] = self.lock_mode res["x-amz-object-lock-mode"] = self.lock_mode
tags = s3_backend.tagger.get_tag_dict_for_resource(self.arn) tags = s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn)
if tags: if tags:
res["x-amz-tagging-count"] = len(tags.keys()) res["x-amz-tagging-count"] = len(tags.keys())
@ -1228,13 +1225,13 @@ class FakeBucket(CloudFormationModel):
def create_from_cloudformation_json( def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name, **kwargs cls, resource_name, cloudformation_json, region_name, **kwargs
): ):
bucket = s3_backend.create_bucket(resource_name, region_name) bucket = s3_backends["global"].create_bucket(resource_name, region_name)
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
if "BucketEncryption" in properties: if "BucketEncryption" in properties:
bucket_encryption = cfn_to_api_encryption(properties["BucketEncryption"]) bucket_encryption = cfn_to_api_encryption(properties["BucketEncryption"])
s3_backend.put_bucket_encryption( s3_backends["global"].put_bucket_encryption(
bucket_name=resource_name, encryption=bucket_encryption bucket_name=resource_name, encryption=bucket_encryption
) )
@ -1264,7 +1261,7 @@ class FakeBucket(CloudFormationModel):
bucket_encryption = cfn_to_api_encryption( bucket_encryption = cfn_to_api_encryption(
properties["BucketEncryption"] properties["BucketEncryption"]
) )
s3_backend.put_bucket_encryption( s3_backends["global"].put_bucket_encryption(
bucket_name=original_resource.name, encryption=bucket_encryption bucket_name=original_resource.name, encryption=bucket_encryption
) )
return original_resource return original_resource
@ -1273,7 +1270,7 @@ class FakeBucket(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
): ):
s3_backend.delete_bucket(resource_name) s3_backends["global"].delete_bucket(resource_name)
def to_config_dict(self): def to_config_dict(self):
"""Return the AWS Config JSON format of this S3 bucket. """Return the AWS Config JSON format of this S3 bucket.
@ -1298,7 +1295,7 @@ class FakeBucket(CloudFormationModel):
"resourceCreationTime": str(self.creation_date), "resourceCreationTime": str(self.creation_date),
"relatedEvents": [], "relatedEvents": [],
"relationships": [], "relationships": [],
"tags": s3_backend.tagger.get_tag_dict_for_resource(self.arn), "tags": s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn),
"configuration": { "configuration": {
"name": self.name, "name": self.name,
"owner": {"id": OWNER}, "owner": {"id": OWNER},
@ -1449,7 +1446,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
@classmethod @classmethod
def get_cloudwatch_metrics(cls): def get_cloudwatch_metrics(cls):
metrics = [] metrics = []
for name, bucket in s3_backend.buckets.items(): for name, bucket in s3_backends["global"].buckets.items():
metrics.append( metrics.append(
MetricDatum( MetricDatum(
namespace="AWS/S3", namespace="AWS/S3",
@ -1700,7 +1697,6 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
lock_mode=lock_mode, lock_mode=lock_mode,
lock_legal_status=lock_legal_status, lock_legal_status=lock_legal_status,
lock_until=lock_until, lock_until=lock_until,
s3_backend=s3_backend,
) )
keys = [ keys = [
@ -2173,4 +2169,3 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
s3_backends = BackendDict( s3_backends = BackendDict(
S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"] S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"]
) )
s3_backend = s3_backends["global"]

View File

@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse
import xmltodict import xmltodict
from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin from moto.core.responses import BaseResponse
from moto.core.utils import path_url from moto.core.utils import path_url
from moto.core import get_account_id from moto.core import get_account_id
@ -51,7 +51,8 @@ from .exceptions import (
InvalidRange, InvalidRange,
LockNotEnabled, LockNotEnabled,
) )
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey from .models import s3_backends
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
from .utils import bucket_name_from_url, metadata_from_headers, parse_region_from_url from .utils import bucket_name_from_url, metadata_from_headers, parse_region_from_url
from xml.dom import minidom from xml.dom import minidom
@ -151,14 +152,10 @@ def is_delete_keys(request, path):
) )
class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): class S3Response(BaseResponse):
def __init__(self, backend): @property
super().__init__() def backend(self):
self.backend = backend return s3_backends["global"]
self.method = ""
self.path = ""
self.data = {}
self.headers = {}
@property @property
def should_autoescape(self): def should_autoescape(self):
@ -253,15 +250,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self.bucket_response(request, full_url, headers) return self.bucket_response(request, full_url, headers)
@amzn_request_id @amzn_request_id
def bucket_response( def bucket_response(self, request, full_url, headers):
self, request, full_url, headers self.setup_class(request, full_url, headers, use_raw_body=True)
): # pylint: disable=unused-argument
self.method = request.method
self.path = self._get_path(request)
# Make a copy of request.headers because it's immutable
self.headers = dict(request.headers)
if "Host" not in self.headers:
self.headers["Host"] = urlparse(full_url).netloc
try: try:
response = self._bucket_response(request, full_url) response = self._bucket_response(request, full_url)
except S3ClientError as s3error: except S3ClientError as s3error:
@ -297,30 +287,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.data["BucketName"] = bucket_name self.data["BucketName"] = bucket_name
if hasattr(request, "body"):
# Boto
body = request.body
else:
# Flask server
body = request.data
if body is None:
body = b""
if isinstance(body, bytes):
body = body.decode("utf-8")
body = "{0}".format(body).encode("utf-8")
if method == "HEAD": if method == "HEAD":
return self._bucket_response_head(bucket_name, querystring) return self._bucket_response_head(bucket_name, querystring)
elif method == "GET": elif method == "GET":
return self._bucket_response_get(bucket_name, querystring) return self._bucket_response_get(bucket_name, querystring)
elif method == "PUT": elif method == "PUT":
return self._bucket_response_put( return self._bucket_response_put(
request, body, region_name, bucket_name, querystring request, region_name, bucket_name, querystring
) )
elif method == "DELETE": elif method == "DELETE":
return self._bucket_response_delete(bucket_name, querystring) return self._bucket_response_delete(bucket_name, querystring)
elif method == "POST": elif method == "POST":
return self._bucket_response_post(request, body, bucket_name) return self._bucket_response_post(request, bucket_name)
elif method == "OPTIONS": elif method == "OPTIONS":
return self._response_options(bucket_name) return self._response_options(bucket_name)
else: else:
@ -379,28 +357,26 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
for cors_rule in bucket.cors: for cors_rule in bucket.cors:
if cors_rule.allowed_methods is not None: if cors_rule.allowed_methods is not None:
self.headers["Access-Control-Allow-Methods"] = _to_string( self.response_headers["Access-Control-Allow-Methods"] = _to_string(
cors_rule.allowed_methods cors_rule.allowed_methods
) )
if cors_rule.allowed_origins is not None: if cors_rule.allowed_origins is not None:
self.headers["Access-Control-Allow-Origin"] = _to_string( self.response_headers["Access-Control-Allow-Origin"] = _to_string(
cors_rule.allowed_origins cors_rule.allowed_origins
) )
if cors_rule.allowed_headers is not None: if cors_rule.allowed_headers is not None:
self.headers["Access-Control-Allow-Headers"] = _to_string( self.response_headers["Access-Control-Allow-Headers"] = _to_string(
cors_rule.allowed_headers cors_rule.allowed_headers
) )
if cors_rule.exposed_headers is not None: if cors_rule.exposed_headers is not None:
self.headers["Access-Control-Expose-Headers"] = _to_string( self.response_headers["Access-Control-Expose-Headers"] = _to_string(
cors_rule.exposed_headers cors_rule.exposed_headers
) )
if cors_rule.max_age_seconds is not None: if cors_rule.max_age_seconds is not None:
self.headers["Access-Control-Max-Age"] = _to_string( self.response_headers["Access-Control-Max-Age"] = _to_string(
cors_rule.max_age_seconds cors_rule.max_age_seconds
) )
return self.headers
def _response_options(self, bucket_name): def _response_options(self, bucket_name):
# Return 200 with the headers from the bucket CORS configuration # Return 200 with the headers from the bucket CORS configuration
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
@ -415,7 +391,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self._set_cors_headers(bucket) self._set_cors_headers(bucket)
return 200, self.headers, "" return 200, self.response_headers, ""
def _bucket_response_get(self, bucket_name, querystring): def _bucket_response_get(self, bucket_name, querystring):
self._set_action("BUCKET", "GET", querystring) self._set_action("BUCKET", "GET", querystring)
@ -728,15 +704,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
pass pass
return False return False
def _parse_pab_config(self, body): def _parse_pab_config(self):
parsed_xml = xmltodict.parse(body) parsed_xml = xmltodict.parse(self.body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None) parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
return parsed_xml return parsed_xml
def _bucket_response_put( def _bucket_response_put(self, request, region_name, bucket_name, querystring):
self, request, body, region_name, bucket_name, querystring
):
if not request.headers.get("Content-Length"): if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
@ -744,8 +718,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
if "object-lock" in querystring: if "object-lock" in querystring:
body_decoded = body.decode() config = self._lock_config_from_body()
config = self._lock_config_from_xml(body_decoded)
if not self.backend.get_bucket(bucket_name).object_lock_enabled: if not self.backend.get_bucket(bucket_name).object_lock_enabled:
raise BucketMustHaveLockeEnabled raise BucketMustHaveLockeEnabled
@ -760,7 +733,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 200, {}, "" return 200, {}, ""
if "versioning" in querystring: if "versioning" in querystring:
ver = re.search("<Status>([A-Za-z]+)</Status>", body.decode()) body = self.body.decode("utf-8")
ver = re.search(r"<Status>([A-Za-z]+)</Status>", body)
if ver: if ver:
self.backend.set_bucket_versioning(bucket_name, ver.group(1)) self.backend.set_bucket_versioning(bucket_name, ver.group(1))
template = self.response_template(S3_BUCKET_VERSIONING) template = self.response_template(S3_BUCKET_VERSIONING)
@ -768,47 +742,45 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else: else:
return 404, {}, "" return 404, {}, ""
elif "lifecycle" in querystring: elif "lifecycle" in querystring:
rules = xmltodict.parse(body)["LifecycleConfiguration"]["Rule"] rules = xmltodict.parse(self.body)["LifecycleConfiguration"]["Rule"]
if not isinstance(rules, list): if not isinstance(rules, list):
# If there is only one rule, xmldict returns just the item # If there is only one rule, xmldict returns just the item
rules = [rules] rules = [rules]
self.backend.put_bucket_lifecycle(bucket_name, rules) self.backend.put_bucket_lifecycle(bucket_name, rules)
return "" return ""
elif "policy" in querystring: elif "policy" in querystring:
self.backend.put_bucket_policy(bucket_name, body) self.backend.put_bucket_policy(bucket_name, self.body)
return "True" return "True"
elif "acl" in querystring: elif "acl" in querystring:
# Headers are first. If not set, then look at the body (consistent with the documentation): # Headers are first. If not set, then look at the body (consistent with the documentation):
acls = self._acl_from_headers(request.headers) acls = self._acl_from_headers(request.headers)
if not acls: if not acls:
acls = self._acl_from_xml(body) acls = self._acl_from_body()
self.backend.put_bucket_acl(bucket_name, acls) self.backend.put_bucket_acl(bucket_name, acls)
return "" return ""
elif "tagging" in querystring: elif "tagging" in querystring:
tagging = self._bucket_tagging_from_xml(body) tagging = self._bucket_tagging_from_body()
self.backend.put_bucket_tagging(bucket_name, tagging) self.backend.put_bucket_tagging(bucket_name, tagging)
return "" return ""
elif "website" in querystring: elif "website" in querystring:
self.backend.set_bucket_website_configuration(bucket_name, body) self.backend.set_bucket_website_configuration(bucket_name, self.body)
return "" return ""
elif "cors" in querystring: elif "cors" in querystring:
try: try:
self.backend.put_bucket_cors(bucket_name, self._cors_from_xml(body)) self.backend.put_bucket_cors(bucket_name, self._cors_from_body())
return "" return ""
except KeyError: except KeyError:
raise MalformedXML() raise MalformedXML()
elif "logging" in querystring: elif "logging" in querystring:
try: try:
self.backend.put_bucket_logging( self.backend.put_bucket_logging(bucket_name, self._logging_from_body())
bucket_name, self._logging_from_xml(body)
)
return "" return ""
except KeyError: except KeyError:
raise MalformedXML() raise MalformedXML()
elif "notification" in querystring: elif "notification" in querystring:
try: try:
self.backend.put_bucket_notification_configuration( self.backend.put_bucket_notification_configuration(
bucket_name, self._notification_config_from_xml(body) bucket_name, self._notification_config_from_body()
) )
return "" return ""
except KeyError: except KeyError:
@ -817,7 +789,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e raise e
elif "accelerate" in querystring: elif "accelerate" in querystring:
try: try:
accelerate_status = self._accelerate_config_from_xml(body) accelerate_status = self._accelerate_config_from_body()
self.backend.put_bucket_accelerate_configuration( self.backend.put_bucket_accelerate_configuration(
bucket_name, accelerate_status bucket_name, accelerate_status
) )
@ -828,7 +800,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e raise e
elif "publicAccessBlock" in querystring: elif "publicAccessBlock" in querystring:
pab_config = self._parse_pab_config(body) pab_config = self._parse_pab_config()
self.backend.put_bucket_public_access_block( self.backend.put_bucket_public_access_block(
bucket_name, pab_config["PublicAccessBlockConfiguration"] bucket_name, pab_config["PublicAccessBlockConfiguration"]
) )
@ -836,7 +808,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
elif "encryption" in querystring: elif "encryption" in querystring:
try: try:
self.backend.put_bucket_encryption( self.backend.put_bucket_encryption(
bucket_name, self._encryption_config_from_xml(body) bucket_name, self._encryption_config_from_body()
) )
return "" return ""
except KeyError: except KeyError:
@ -848,7 +820,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if not bucket.is_versioned: if not bucket.is_versioned:
template = self.response_template(S3_NO_VERSIONING_ENABLED) template = self.response_template(S3_NO_VERSIONING_ENABLED)
return 400, {}, template.render(bucket_name=bucket_name) return 400, {}, template.render(bucket_name=bucket_name)
replication_config = self._replication_config_from_xml(body) replication_config = self._replication_config_from_xml(self.body)
self.backend.put_bucket_replication(bucket_name, replication_config) self.backend.put_bucket_replication(bucket_name, replication_config)
return "" return ""
else: else:
@ -858,17 +830,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# - LocationConstraint has to be specified if outside us-east-1 # - LocationConstraint has to be specified if outside us-east-1
if ( if (
region_name != DEFAULT_REGION_NAME region_name != DEFAULT_REGION_NAME
and not self._body_contains_location_constraint(body) and not self._body_contains_location_constraint(self.body)
): ):
raise IllegalLocationConstraintException() raise IllegalLocationConstraintException()
if body: if self.body:
if self._create_bucket_configuration_is_empty(body): if self._create_bucket_configuration_is_empty(self.body):
raise MalformedXML() raise MalformedXML()
try: try:
forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ forced_region = xmltodict.parse(self.body)[
"LocationConstraint" "CreateBucketConfiguration"
] ]["LocationConstraint"]
if forced_region == DEFAULT_REGION_NAME: if forced_region == DEFAULT_REGION_NAME:
raise S3ClientError( raise S3ClientError(
@ -950,21 +922,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, {}, template.render(bucket=removed_bucket) return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name): def _bucket_response_post(self, request, bucket_name):
response_headers = {} response_headers = {}
if not request.headers.get("Content-Length"): if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
path = self._get_path(request) self.path = self._get_path(request)
if self.is_delete_keys(request, path, bucket_name): if self.is_delete_keys(request, self.path, bucket_name):
self.data["Action"] = "DeleteObject" self.data["Action"] = "DeleteObject"
try: try:
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
return self._bucket_response_delete_keys(body, bucket_name) return self._bucket_response_delete_keys(bucket_name)
except BucketAccessDeniedError: except BucketAccessDeniedError:
return self._bucket_response_delete_keys( return self._bucket_response_delete_keys(
body, bucket_name, authenticated=False bucket_name, authenticated=False
) )
self.data["Action"] = "PutObject" self.data["Action"] = "PutObject"
@ -1027,9 +999,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else path_url(request.url) else path_url(request.url)
) )
def _bucket_response_delete_keys(self, body, bucket_name, authenticated=True): def _bucket_response_delete_keys(self, bucket_name, authenticated=True):
template = self.response_template(S3_DELETE_KEYS_RESPONSE) template = self.response_template(S3_DELETE_KEYS_RESPONSE)
body_dict = xmltodict.parse(body, strip_whitespace=False) body_dict = xmltodict.parse(self.body, strip_whitespace=False)
objects = body_dict["Delete"].get("Object", []) objects = body_dict["Delete"].get("Object", [])
if not isinstance(objects, list): if not isinstance(objects, list):
@ -1098,16 +1070,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return bytes(new_body) return bytes(new_body)
@amzn_request_id @amzn_request_id
def key_response( def key_response(self, request, full_url, headers):
self, request, full_url, headers
): # pylint: disable=unused-argument
# Key and Control are lumped in because splitting out the regex is too much of a pain :/ # Key and Control are lumped in because splitting out the regex is too much of a pain :/
self.method = request.method self.setup_class(request, full_url, headers, use_raw_body=True)
self.path = self._get_path(request)
# Make a copy of request.headers because it's immutable
self.headers = dict(request.headers)
if "Host" not in self.headers:
self.headers["Host"] = urlparse(full_url).netloc
response_headers = {} response_headers = {}
try: try:
@ -1300,7 +1265,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 304, response_headers, "Not Modified" return 304, response_headers, "Not Modified"
if "acl" in query: if "acl" in query:
acl = s3_backend.get_object_acl(key) acl = self.backend.get_object_acl(key)
template = self.response_template(S3_OBJECT_ACL_RESPONSE) template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(acl=acl) return 200, response_headers, template.render(acl=acl)
if "tagging" in query: if "tagging" in query:
@ -1411,7 +1376,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if not lock_enabled: if not lock_enabled:
raise LockNotEnabled raise LockNotEnabled
version_id = query.get("VersionId") version_id = query.get("VersionId")
retention = self._mode_until_from_xml(body) retention = self._mode_until_from_body()
self.backend.put_object_retention( self.backend.put_object_retention(
bucket_name, key_name, version_id=version_id, retention=retention bucket_name, key_name, version_id=version_id, retention=retention
) )
@ -1573,9 +1538,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else: else:
return 404, response_headers, "" return 404, response_headers, ""
def _lock_config_from_xml(self, xml): def _lock_config_from_body(self):
response_dict = {"enabled": False, "mode": None, "days": None, "years": None} response_dict = {"enabled": False, "mode": None, "days": None, "years": None}
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
enabled = ( enabled = (
parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled" parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled"
) )
@ -1596,8 +1561,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return response_dict return response_dict
def _acl_from_xml(self, xml): def _acl_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
if not parsed_xml.get("AccessControlPolicy"): if not parsed_xml.get("AccessControlPolicy"):
raise MalformedACLError() raise MalformedACLError()
@ -1713,8 +1678,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return tags return tags
def _bucket_tagging_from_xml(self, xml): def _bucket_tagging_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
tags = {} tags = {}
# Optional if no tags are being sent: # Optional if no tags are being sent:
@ -1737,16 +1702,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return tags return tags
def _cors_from_xml(self, xml): def _cors_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list): if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
return [cors for cors in parsed_xml["CORSConfiguration"]["CORSRule"]] return [cors for cors in parsed_xml["CORSConfiguration"]["CORSRule"]]
return [parsed_xml["CORSConfiguration"]["CORSRule"]] return [parsed_xml["CORSConfiguration"]["CORSRule"]]
def _mode_until_from_xml(self, xml): def _mode_until_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
return ( return (
parsed_xml.get("Retention", None).get("Mode", None), parsed_xml.get("Retention", None).get("Mode", None),
parsed_xml.get("Retention", None).get("RetainUntilDate", None), parsed_xml.get("Retention", None).get("RetainUntilDate", None),
@ -1756,8 +1721,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(xml)
return parsed_xml["LegalHold"]["Status"] return parsed_xml["LegalHold"]["Status"]
def _encryption_config_from_xml(self, xml): def _encryption_config_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
if ( if (
not parsed_xml["ServerSideEncryptionConfiguration"].get("Rule") not parsed_xml["ServerSideEncryptionConfiguration"].get("Rule")
@ -1772,8 +1737,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["ServerSideEncryptionConfiguration"] return parsed_xml["ServerSideEncryptionConfiguration"]
def _logging_from_xml(self, xml): def _logging_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"): if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"):
return {} return {}
@ -1817,8 +1782,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"] return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]
def _notification_config_from_xml(self, xml): def _notification_config_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
if not len(parsed_xml["NotificationConfiguration"]): if not len(parsed_xml["NotificationConfiguration"]):
return {} return {}
@ -1892,8 +1857,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["NotificationConfiguration"] return parsed_xml["NotificationConfiguration"]
def _accelerate_config_from_xml(self, xml): def _accelerate_config_from_body(self):
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(self.body)
config = parsed_xml["AccelerateConfiguration"] config = parsed_xml["AccelerateConfiguration"]
return config["Status"] return config["Status"]
@ -2028,7 +1993,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return False return False
S3ResponseInstance = ResponseObject(s3_backend) S3ResponseInstance = S3Response()
S3_ALL_BUCKETS = """<ListAllMyBucketsResult xmlns="http://s3.amazonaws.com/doc/2006-03-01"> S3_ALL_BUCKETS = """<ListAllMyBucketsResult xmlns="http://s3.amazonaws.com/doc/2006-03-01">
<Owner> <Owner>

View File

@ -1,6 +1,5 @@
"""s3control module initialization; sets value for base decorator.""" """s3control module initialization; sets value for base decorator."""
from .models import s3control_backend from .models import s3control_backends
from ..core.models import base_decorator from ..core.models import base_decorator
s3control_backends = {"global": s3control_backend}
mock_s3control = base_decorator(s3control_backends) mock_s3control = base_decorator(s3control_backends)

View File

@ -1,7 +1,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from moto.core import get_account_id, BaseBackend, BaseModel from moto.core import get_account_id, BaseBackend, BaseModel
from moto.core.utils import get_random_hex from moto.core.utils import get_random_hex, BackendDict
from moto.s3.exceptions import ( from moto.s3.exceptions import (
WrongPublicAccessBlockAccountIdError, WrongPublicAccessBlockAccountIdError,
NoSuchPublicAccessBlockConfiguration, NoSuchPublicAccessBlockConfiguration,
@ -43,16 +43,11 @@ class AccessPoint(BaseModel):
class S3ControlBackend(BaseBackend): class S3ControlBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name, account_id):
self.region_name = region_name super().__init__(region_name, account_id)
self.public_access_block = None self.public_access_block = None
self.access_points = defaultdict(dict) self.access_points = defaultdict(dict)
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def get_public_access_block(self, account_id): def get_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto: # The account ID should equal the account id that is set for Moto:
if account_id != get_account_id(): if account_id != get_account_id():
@ -129,4 +124,9 @@ class S3ControlBackend(BaseBackend):
return True return True
s3control_backend = S3ControlBackend() s3control_backends = BackendDict(
S3ControlBackend,
"s3control",
use_boto3_regions=False,
additional_regions=["global"],
)

View File

@ -5,10 +5,14 @@ from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from moto.s3.exceptions import S3ClientError from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from .models import s3control_backend from .models import s3control_backends
class S3ControlResponse(BaseResponse): class S3ControlResponse(BaseResponse):
@property
def backend(self):
return s3control_backends["global"]
@amzn_request_id @amzn_request_id
def public_access_block( def public_access_block(
self, request, full_url, headers self, request, full_url, headers
@ -25,7 +29,7 @@ class S3ControlResponse(BaseResponse):
def get_public_access_block(self, request): def get_public_access_block(self, request):
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
public_block_config = s3control_backend.get_public_access_block( public_block_config = self.backend.get_public_access_block(
account_id=account_id account_id=account_id
) )
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION) template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
@ -35,14 +39,14 @@ class S3ControlResponse(BaseResponse):
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
data = request.body if hasattr(request, "body") else request.data data = request.body if hasattr(request, "body") else request.data
pab_config = self._parse_pab_config(data) pab_config = self._parse_pab_config(data)
s3control_backend.put_public_access_block( self.backend.put_public_access_block(
account_id, pab_config["PublicAccessBlockConfiguration"] account_id, pab_config["PublicAccessBlockConfiguration"]
) )
return 201, {}, json.dumps({}) return 201, {}, json.dumps({})
def delete_public_access_block(self, request): def delete_public_access_block(self, request):
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
s3control_backend.delete_public_access_block(account_id=account_id) self.backend.delete_public_access_block(account_id=account_id)
return 204, {}, json.dumps({}) return 204, {}, json.dumps({})
def _parse_pab_config(self, body): def _parse_pab_config(self, body):
@ -82,7 +86,7 @@ class S3ControlResponse(BaseResponse):
bucket = params["Bucket"] bucket = params["Bucket"]
vpc_configuration = params.get("VpcConfiguration") vpc_configuration = params.get("VpcConfiguration")
public_access_block_configuration = params.get("PublicAccessBlockConfiguration") public_access_block_configuration = params.get("PublicAccessBlockConfiguration")
access_point = s3control_backend.create_access_point( access_point = self.backend.create_access_point(
account_id=account_id, account_id=account_id,
name=name, name=name,
bucket=bucket, bucket=bucket,
@ -95,38 +99,36 @@ class S3ControlResponse(BaseResponse):
def get_access_point(self, full_url): def get_access_point(self, full_url):
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
access_point = s3control_backend.get_access_point( access_point = self.backend.get_access_point(account_id=account_id, name=name)
account_id=account_id, name=name
)
template = self.response_template(GET_ACCESS_POINT_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point) return 200, {}, template.render(access_point=access_point)
def delete_access_point(self, full_url): def delete_access_point(self, full_url):
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
s3control_backend.delete_access_point(account_id=account_id, name=name) self.backend.delete_access_point(account_id=account_id, name=name)
return 204, {}, "" return 204, {}, ""
def create_access_point_policy(self, full_url): def create_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
params = xmltodict.parse(self.body) params = xmltodict.parse(self.body)
policy = params["PutAccessPointPolicyRequest"]["Policy"] policy = params["PutAccessPointPolicyRequest"]["Policy"]
s3control_backend.create_access_point_policy(account_id, name, policy) self.backend.create_access_point_policy(account_id, name, policy)
return 200, {}, "" return 200, {}, ""
def get_access_point_policy(self, full_url): def get_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
policy = s3control_backend.get_access_point_policy(account_id, name) policy = self.backend.get_access_point_policy(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE)
return 200, {}, template.render(policy=policy) return 200, {}, template.render(policy=policy)
def delete_access_point_policy(self, full_url): def delete_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
s3control_backend.delete_access_point_policy(account_id=account_id, name=name) self.backend.delete_access_point_policy(account_id=account_id, name=name)
return 204, {}, "" return 204, {}, ""
def get_access_point_policy_status(self, full_url): def get_access_point_policy_status(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
s3control_backend.get_access_point_policy_status(account_id, name) self.backend.get_access_point_policy_status(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE)
return 200, {}, template.render() return 200, {}, template.render()

View File

@ -6,6 +6,7 @@ from email.mime.base import MIMEBase
from email.utils import parseaddr from email.utils import parseaddr
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.encoders import encode_7or8bit from email.encoders import encode_7or8bit
from typing import Mapping
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import BackendDict from moto.core.utils import BackendDict
@ -537,7 +538,6 @@ class SESBackend(BaseBackend):
return attributes_by_identity return attributes_by_identity
ses_backends = BackendDict( ses_backends: Mapping[str, SESBackend] = BackendDict(
SESBackend, "ses", use_boto3_regions=False, additional_regions=["global"] SESBackend, "ses", use_boto3_regions=False, additional_regions=["global"]
) )
ses_backend = ses_backends["global"]

View File

@ -1,48 +1,52 @@
import base64 import base64
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import ses_backend from .models import ses_backends
from datetime import datetime from datetime import datetime
class EmailResponse(BaseResponse): class EmailResponse(BaseResponse):
@property
def backend(self):
return ses_backends["global"]
def verify_email_identity(self): def verify_email_identity(self):
address = self.querystring.get("EmailAddress")[0] address = self.querystring.get("EmailAddress")[0]
ses_backend.verify_email_identity(address) self.backend.verify_email_identity(address)
template = self.response_template(VERIFY_EMAIL_IDENTITY) template = self.response_template(VERIFY_EMAIL_IDENTITY)
return template.render() return template.render()
def verify_email_address(self): def verify_email_address(self):
address = self.querystring.get("EmailAddress")[0] address = self.querystring.get("EmailAddress")[0]
ses_backend.verify_email_address(address) self.backend.verify_email_address(address)
template = self.response_template(VERIFY_EMAIL_ADDRESS) template = self.response_template(VERIFY_EMAIL_ADDRESS)
return template.render() return template.render()
def list_identities(self): def list_identities(self):
identities = ses_backend.list_identities() identities = self.backend.list_identities()
template = self.response_template(LIST_IDENTITIES_RESPONSE) template = self.response_template(LIST_IDENTITIES_RESPONSE)
return template.render(identities=identities) return template.render(identities=identities)
def list_verified_email_addresses(self): def list_verified_email_addresses(self):
email_addresses = ses_backend.list_verified_email_addresses() email_addresses = self.backend.list_verified_email_addresses()
template = self.response_template(LIST_VERIFIED_EMAIL_RESPONSE) template = self.response_template(LIST_VERIFIED_EMAIL_RESPONSE)
return template.render(email_addresses=email_addresses) return template.render(email_addresses=email_addresses)
def verify_domain_dkim(self): def verify_domain_dkim(self):
domain = self.querystring.get("Domain")[0] domain = self.querystring.get("Domain")[0]
ses_backend.verify_domain(domain) self.backend.verify_domain(domain)
template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE) template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE)
return template.render() return template.render()
def verify_domain_identity(self): def verify_domain_identity(self):
domain = self.querystring.get("Domain")[0] domain = self.querystring.get("Domain")[0]
ses_backend.verify_domain(domain) self.backend.verify_domain(domain)
template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE) template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE)
return template.render() return template.render()
def delete_identity(self): def delete_identity(self):
domain = self.querystring.get("Identity")[0] domain = self.querystring.get("Identity")[0]
ses_backend.delete_identity(domain) self.backend.delete_identity(domain)
template = self.response_template(DELETE_IDENTITY_RESPONSE) template = self.response_template(DELETE_IDENTITY_RESPONSE)
return template.render() return template.render()
@ -63,7 +67,7 @@ class EmailResponse(BaseResponse):
break break
destinations[dest_type].append(address[0]) destinations[dest_type].append(address[0])
message = ses_backend.send_email( message = self.backend.send_email(
source, subject, body, destinations, self.region source, subject, body, destinations, self.region
) )
template = self.response_template(SEND_EMAIL_RESPONSE) template = self.response_template(SEND_EMAIL_RESPONSE)
@ -84,7 +88,7 @@ class EmailResponse(BaseResponse):
break break
destinations[dest_type].append(address[0]) destinations[dest_type].append(address[0])
message = ses_backend.send_templated_email( message = self.backend.send_templated_email(
source, template, template_data, destinations, self.region source, template, template_data, destinations, self.region
) )
template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE) template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE)
@ -107,27 +111,27 @@ class EmailResponse(BaseResponse):
break break
destinations.append(address[0]) destinations.append(address[0])
message = ses_backend.send_raw_email( message = self.backend.send_raw_email(
source, destinations, raw_data, self.region source, destinations, raw_data, self.region
) )
template = self.response_template(SEND_RAW_EMAIL_RESPONSE) template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message) return template.render(message=message)
def get_send_quota(self): def get_send_quota(self):
quota = ses_backend.get_send_quota() quota = self.backend.get_send_quota()
template = self.response_template(GET_SEND_QUOTA_RESPONSE) template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota) return template.render(quota=quota)
def get_identity_notification_attributes(self): def get_identity_notification_attributes(self):
identities = self._get_params()["Identities"] identities = self._get_params()["Identities"]
identities = ses_backend.get_identity_notification_attributes(identities) identities = self.backend.get_identity_notification_attributes(identities)
template = self.response_template(GET_IDENTITY_NOTIFICATION_ATTRIBUTES) template = self.response_template(GET_IDENTITY_NOTIFICATION_ATTRIBUTES)
return template.render(identities=identities) return template.render(identities=identities)
def set_identity_feedback_forwarding_enabled(self): def set_identity_feedback_forwarding_enabled(self):
identity = self._get_param("Identity") identity = self._get_param("Identity")
enabled = self._get_bool_param("ForwardingEnabled") enabled = self._get_bool_param("ForwardingEnabled")
ses_backend.set_identity_feedback_forwarding_enabled(identity, enabled) self.backend.set_identity_feedback_forwarding_enabled(identity, enabled)
template = self.response_template(SET_IDENTITY_FORWARDING_ENABLED_RESPONSE) template = self.response_template(SET_IDENTITY_FORWARDING_ENABLED_RESPONSE)
return template.render() return template.render()
@ -139,18 +143,18 @@ class EmailResponse(BaseResponse):
if sns_topic: if sns_topic:
sns_topic = sns_topic[0] sns_topic = sns_topic[0]
ses_backend.set_identity_notification_topic(identity, not_type, sns_topic) self.backend.set_identity_notification_topic(identity, not_type, sns_topic)
template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE) template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
return template.render() return template.render()
def get_send_statistics(self): def get_send_statistics(self):
statistics = ses_backend.get_send_statistics() statistics = self.backend.get_send_statistics()
template = self.response_template(GET_SEND_STATISTICS) template = self.response_template(GET_SEND_STATISTICS)
return template.render(all_statistics=[statistics]) return template.render(all_statistics=[statistics])
def create_configuration_set(self): def create_configuration_set(self):
configuration_set_name = self.querystring.get("ConfigurationSet.Name")[0] configuration_set_name = self.querystring.get("ConfigurationSet.Name")[0]
ses_backend.create_configuration_set( self.backend.create_configuration_set(
configuration_set_name=configuration_set_name configuration_set_name=configuration_set_name
) )
template = self.response_template(CREATE_CONFIGURATION_SET) template = self.response_template(CREATE_CONFIGURATION_SET)
@ -177,7 +181,7 @@ class EmailResponse(BaseResponse):
"SNSDestination": event_topic_arn, "SNSDestination": event_topic_arn,
} }
ses_backend.create_configuration_set_event_destination( self.backend.create_configuration_set_event_destination(
configuration_set_name=configuration_set_name, configuration_set_name=configuration_set_name,
event_destination=event_destination, event_destination=event_destination,
) )
@ -193,7 +197,7 @@ class EmailResponse(BaseResponse):
template_info["template_name"] = template_data.get("._name", "") template_info["template_name"] = template_data.get("._name", "")
template_info["subject_part"] = template_data.get("._subject_part", "") template_info["subject_part"] = template_data.get("._subject_part", "")
template_info["Timestamp"] = datetime.utcnow() template_info["Timestamp"] = datetime.utcnow()
ses_backend.add_template(template_info=template_info) self.backend.add_template(template_info=template_info)
template = self.response_template(CREATE_TEMPLATE) template = self.response_template(CREATE_TEMPLATE)
return template.render() return template.render()
@ -205,44 +209,44 @@ class EmailResponse(BaseResponse):
template_info["template_name"] = template_data.get("._name", "") template_info["template_name"] = template_data.get("._name", "")
template_info["subject_part"] = template_data.get("._subject_part", "") template_info["subject_part"] = template_data.get("._subject_part", "")
template_info["Timestamp"] = datetime.utcnow() template_info["Timestamp"] = datetime.utcnow()
ses_backend.update_template(template_info=template_info) self.backend.update_template(template_info=template_info)
template = self.response_template(UPDATE_TEMPLATE) template = self.response_template(UPDATE_TEMPLATE)
return template.render() return template.render()
def get_template(self): def get_template(self):
template_name = self._get_param("TemplateName") template_name = self._get_param("TemplateName")
template_data = ses_backend.get_template(template_name) template_data = self.backend.get_template(template_name)
template = self.response_template(GET_TEMPLATE) template = self.response_template(GET_TEMPLATE)
return template.render(template_data=template_data) return template.render(template_data=template_data)
def list_templates(self): def list_templates(self):
email_templates = ses_backend.list_templates() email_templates = self.backend.list_templates()
template = self.response_template(LIST_TEMPLATES) template = self.response_template(LIST_TEMPLATES)
return template.render(templates=email_templates) return template.render(templates=email_templates)
def test_render_template(self): def test_render_template(self):
render_info = self._get_dict_param("Template") render_info = self._get_dict_param("Template")
rendered_template = ses_backend.render_template(render_info) rendered_template = self.backend.render_template(render_info)
template = self.response_template(RENDER_TEMPLATE) template = self.response_template(RENDER_TEMPLATE)
return template.render(template=rendered_template) return template.render(template=rendered_template)
def create_receipt_rule_set(self): def create_receipt_rule_set(self):
rule_set_name = self._get_param("RuleSetName") rule_set_name = self._get_param("RuleSetName")
ses_backend.create_receipt_rule_set(rule_set_name) self.backend.create_receipt_rule_set(rule_set_name)
template = self.response_template(CREATE_RECEIPT_RULE_SET) template = self.response_template(CREATE_RECEIPT_RULE_SET)
return template.render() return template.render()
def create_receipt_rule(self): def create_receipt_rule(self):
rule_set_name = self._get_param("RuleSetName") rule_set_name = self._get_param("RuleSetName")
rule = self._get_dict_param("Rule.") rule = self._get_dict_param("Rule.")
ses_backend.create_receipt_rule(rule_set_name, rule) self.backend.create_receipt_rule(rule_set_name, rule)
template = self.response_template(CREATE_RECEIPT_RULE) template = self.response_template(CREATE_RECEIPT_RULE)
return template.render() return template.render()
def describe_receipt_rule_set(self): def describe_receipt_rule_set(self):
rule_set_name = self._get_param("RuleSetName") rule_set_name = self._get_param("RuleSetName")
rule_set = ses_backend.describe_receipt_rule_set(rule_set_name) rule_set = self.backend.describe_receipt_rule_set(rule_set_name)
for i, rule in enumerate(rule_set): for i, rule in enumerate(rule_set):
formatted_rule = {} formatted_rule = {}
@ -260,7 +264,7 @@ class EmailResponse(BaseResponse):
rule_set_name = self._get_param("RuleSetName") rule_set_name = self._get_param("RuleSetName")
rule_name = self._get_param("RuleName") rule_name = self._get_param("RuleName")
receipt_rule = ses_backend.describe_receipt_rule(rule_set_name, rule_name) receipt_rule = self.backend.describe_receipt_rule(rule_set_name, rule_name)
rule = {} rule = {}
@ -274,7 +278,7 @@ class EmailResponse(BaseResponse):
rule_set_name = self._get_param("RuleSetName") rule_set_name = self._get_param("RuleSetName")
rule = self._get_dict_param("Rule.") rule = self._get_dict_param("Rule.")
ses_backend.update_receipt_rule(rule_set_name, rule) self.backend.update_receipt_rule(rule_set_name, rule)
template = self.response_template(UPDATE_RECEIPT_RULE) template = self.response_template(UPDATE_RECEIPT_RULE)
return template.render() return template.render()
@ -284,7 +288,7 @@ class EmailResponse(BaseResponse):
mail_from_domain = self._get_param("MailFromDomain") mail_from_domain = self._get_param("MailFromDomain")
behavior_on_mx_failure = self._get_param("BehaviorOnMXFailure") behavior_on_mx_failure = self._get_param("BehaviorOnMXFailure")
ses_backend.set_identity_mail_from_domain( self.backend.set_identity_mail_from_domain(
identity, mail_from_domain, behavior_on_mx_failure identity, mail_from_domain, behavior_on_mx_failure
) )
@ -293,7 +297,7 @@ class EmailResponse(BaseResponse):
def get_identity_mail_from_domain_attributes(self): def get_identity_mail_from_domain_attributes(self):
identities = self._get_multi_param("Identities.member.") identities = self._get_multi_param("Identities.member.")
identities = ses_backend.get_identity_mail_from_domain_attributes(identities) identities = self.backend.get_identity_mail_from_domain_attributes(identities)
template = self.response_template(GET_IDENTITY_MAIL_FROM_DOMAIN_ATTRIBUTES) template = self.response_template(GET_IDENTITY_MAIL_FROM_DOMAIN_ATTRIBUTES)
return template.render(identities=identities) return template.render(identities=identities)
@ -301,7 +305,7 @@ class EmailResponse(BaseResponse):
def get_identity_verification_attributes(self): def get_identity_verification_attributes(self):
params = self._get_params() params = self._get_params()
identities = params.get("Identities") identities = params.get("Identities")
verification_attributes = ses_backend.get_identity_verification_attributes( verification_attributes = self.backend.get_identity_verification_attributes(
identities=identities, identities=identities,
) )

View File

@ -11,6 +11,7 @@ from moto.sts.utils import (
random_assumed_role_id, random_assumed_role_id,
DEFAULT_STS_SESSION_DURATION, DEFAULT_STS_SESSION_DURATION,
) )
from typing import Mapping
class Token(BaseModel): class Token(BaseModel):
@ -138,7 +139,6 @@ class STSBackend(BaseBackend):
pass pass
sts_backends = BackendDict( sts_backends: Mapping[str, STSBackend] = BackendDict(
STSBackend, "sts", use_boto3_regions=False, additional_regions=["global"] STSBackend, "sts", use_boto3_regions=False, additional_regions=["global"]
) )
sts_backend = sts_backends["global"]

View File

@ -1,16 +1,20 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core import get_account_id from moto.core import get_account_id
from moto.iam import iam_backend from moto.iam import iam_backends
from .exceptions import STSValidationError from .exceptions import STSValidationError
from .models import sts_backend from .models import sts_backends
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048
class TokenResponse(BaseResponse): class TokenResponse(BaseResponse):
@property
def backend(self):
return sts_backends["global"]
def get_session_token(self): def get_session_token(self):
duration = int(self.querystring.get("DurationSeconds", [43200])[0]) duration = int(self.querystring.get("DurationSeconds", [43200])[0])
token = sts_backend.get_session_token(duration=duration) token = self.backend.get_session_token(duration=duration)
template = self.response_template(GET_SESSION_TOKEN_RESPONSE) template = self.response_template(GET_SESSION_TOKEN_RESPONSE)
return template.render(token=token) return template.render(token=token)
@ -27,7 +31,7 @@ class TokenResponse(BaseResponse):
) )
name = self.querystring.get("Name")[0] name = self.querystring.get("Name")[0]
token = sts_backend.get_federation_token(duration=duration, name=name) token = self.backend.get_federation_token(duration=duration, name=name)
template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE)
return template.render(token=token, account_id=get_account_id()) return template.render(token=token, account_id=get_account_id())
@ -39,7 +43,7 @@ class TokenResponse(BaseResponse):
duration = int(self.querystring.get("DurationSeconds", [3600])[0]) duration = int(self.querystring.get("DurationSeconds", [3600])[0])
external_id = self.querystring.get("ExternalId", [None])[0] external_id = self.querystring.get("ExternalId", [None])[0]
role = sts_backend.assume_role( role = self.backend.assume_role(
role_session_name=role_session_name, role_session_name=role_session_name,
role_arn=role_arn, role_arn=role_arn,
policy=policy, policy=policy,
@ -57,7 +61,7 @@ class TokenResponse(BaseResponse):
duration = int(self.querystring.get("DurationSeconds", [3600])[0]) duration = int(self.querystring.get("DurationSeconds", [3600])[0])
external_id = self.querystring.get("ExternalId", [None])[0] external_id = self.querystring.get("ExternalId", [None])[0]
role = sts_backend.assume_role_with_web_identity( role = self.backend.assume_role_with_web_identity(
role_session_name=role_session_name, role_session_name=role_session_name,
role_arn=role_arn, role_arn=role_arn,
policy=policy, policy=policy,
@ -72,7 +76,7 @@ class TokenResponse(BaseResponse):
principal_arn = self.querystring.get("PrincipalArn")[0] principal_arn = self.querystring.get("PrincipalArn")[0]
saml_assertion = self.querystring.get("SAMLAssertion")[0] saml_assertion = self.querystring.get("SAMLAssertion")[0]
role = sts_backend.assume_role_with_saml( role = self.backend.assume_role_with_saml(
role_arn=role_arn, role_arn=role_arn,
principal_arn=principal_arn, principal_arn=principal_arn,
saml_assertion=saml_assertion, saml_assertion=saml_assertion,
@ -88,12 +92,12 @@ class TokenResponse(BaseResponse):
arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=get_account_id()) arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=get_account_id())
access_key_id = self.get_current_user() access_key_id = self.get_current_user()
assumed_role = sts_backend.get_assumed_role_from_access_key(access_key_id) assumed_role = self.backend.get_assumed_role_from_access_key(access_key_id)
if assumed_role: if assumed_role:
user_id = assumed_role.user_id user_id = assumed_role.user_id
arn = assumed_role.arn arn = assumed_role.arn
user = iam_backend.get_user_from_access_key_id(access_key_id) user = iam_backends["global"].get_user_from_access_key_id(access_key_id)
if user: if user:
user_id = user.id user_id = user.id
arn = user.arn arn = user.arn

View File

@ -45,4 +45,4 @@ def test_domain_dispatched_with_service():
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"})
keys = set(backend_app.view_functions.keys()) keys = set(backend_app.view_functions.keys())
keys.should.contain("ResponseObject.key_response") keys.should.contain("S3Response.key_response")

View File

@ -647,7 +647,10 @@ def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
if append_key_id: if append_key_id:
target_id += key_id target_id += key_id
client.generate_data_key(KeyId=target_id, NumberOfBytes=32) resp = client.generate_data_key(KeyId=target_id, NumberOfBytes=32)
resp.should.have.key("KeyId").equals(
f"arn:aws:kms:us-east-1:123456789012:key/{key_id}"
)
@mock_kms @mock_kms

View File

@ -75,6 +75,9 @@ def test_s3_server_ignore_subdomain_for_bucketnames():
def test_s3_server_bucket_versioning(): def test_s3_server_bucket_versioning():
test_client = authenticated_client() test_client = authenticated_client()
res = test_client.put("/", "http://foobaz.localhost:5000/")
res.status_code.should.equal(200)
# Just enough XML to enable versioning # Just enough XML to enable versioning
body = "<Status>Enabled</Status>" body = "<Status>Enabled</Status>"
res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body) res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body)