Techdebt - Align models-responses integration for all services (#5207)

This commit is contained in:
Bert Blommers 2022-06-09 17:40:22 +00:00 committed by GitHub
parent 47fe052c6f
commit a2c2c06243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 609 additions and 747 deletions

View File

@ -4,6 +4,7 @@ from collections import defaultdict
import copy
import datetime
from gzip import GzipFile
from typing import Mapping
from sys import platform
import docker
@ -25,10 +26,10 @@ import requests.exceptions
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend
from moto.iam.models import iam_backends
from moto.iam.exceptions import IAMNotFoundException
from moto.core.utils import unix_time_millis, BackendDict
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
@ -182,7 +183,7 @@ def _validate_s3_bucket_and_key(data):
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_object(data["S3Bucket"], data["S3Key"])
key = s3_backends["global"].get_object(data["S3Bucket"], data["S3Key"])
except MissingBucket:
if do_validate_s3():
raise InvalidParameterValueException(
@ -585,7 +586,7 @@ class LambdaFunction(CloudFormationModel, DockerModel):
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_object(
key = s3_backends["global"].get_object(
updated_spec["S3Bucket"], updated_spec["S3Key"]
)
except MissingBucket:
@ -1121,7 +1122,7 @@ class LambdaStorage(object):
if account != get_account_id():
raise CrossAccountNotAllowed()
try:
iam_backend.get_role_by_arn(fn.role)
iam_backends["global"].get_role_by_arn(fn.role)
except IAMNotFoundException:
raise InvalidParameterValueException(
"The role defined for the function cannot be assumed by Lambda."
@ -1666,4 +1667,4 @@ def do_validate_s3():
return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"]
lambda_backends = BackendDict(LambdaBackend, "lambda")
lambda_backends: Mapping[str, LambdaBackend] = BackendDict(LambdaBackend, "lambda")

View File

@ -47,7 +47,7 @@ from moto.ssm import models # noqa # pylint: disable=all
# End ugly list of imports
from moto.core import get_account_id, CloudFormationModel
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from moto.s3.utils import bucket_and_name_from_url
from moto.ssm import ssm_backends
from .utils import random_suffix
@ -528,7 +528,7 @@ class ResourceMap(collections_abc.Mapping):
if name == "AWS::Include":
location = params["Location"]
bucket_name, name = bucket_and_name_from_url(location)
key = s3_backend.get_object(bucket_name, name)
key = s3_backends["global"].get_object(bucket_name, name)
self._parsed_resources.update(json.loads(key.value))
def parse_ssm_parameter(self, value, value_type):

View File

@ -6,7 +6,7 @@ from yaml.scanner import ScannerError # pylint:disable=c-extension-no-member
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from moto.s3.exceptions import S3ClientError
from moto.core import get_account_id
from .models import cloudformation_backends
@ -68,7 +68,7 @@ class CloudFormationResponse(BaseResponse):
bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/")
key = s3_backend.get_object(bucket_name, key_name)
key = s3_backends["global"].get_object(bucket_name, key_name)
return key.value.decode("utf-8")
def _get_params_from_list(self, parameters_list):

View File

@ -255,4 +255,3 @@ cloudfront_backends = BackendDict(
use_boto3_regions=False,
additional_regions=["global"],
)
cloudfront_backend = cloudfront_backends["global"]

View File

@ -1,7 +1,7 @@
import xmltodict
from moto.core.responses import BaseResponse
from .models import cloudfront_backend
from .models import cloudfront_backends
XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/"
@ -11,6 +11,10 @@ class CloudFrontResponse(BaseResponse):
def _get_xml_body(self):
return xmltodict.parse(self.body, dict_constructor=dict)
@property
def backend(self):
return cloudfront_backends["global"]
def distributions(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "POST":
@ -21,7 +25,7 @@ class CloudFrontResponse(BaseResponse):
def create_distribution(self):
params = self._get_xml_body()
distribution_config = params.get("DistributionConfig")
distribution, location, e_tag = cloudfront_backend.create_distribution(
distribution, location, e_tag = self.backend.create_distribution(
distribution_config=distribution_config
)
template = self.response_template(CREATE_DISTRIBUTION_TEMPLATE)
@ -30,7 +34,7 @@ class CloudFrontResponse(BaseResponse):
return 200, headers, response
def list_distributions(self):
distributions = cloudfront_backend.list_distributions()
distributions = self.backend.list_distributions()
template = self.response_template(LIST_TEMPLATE)
response = template.render(distributions=distributions)
return 200, {}, response
@ -40,10 +44,10 @@ class CloudFrontResponse(BaseResponse):
distribution_id = full_url.split("/")[-1]
if request.method == "DELETE":
if_match = self._get_param("If-Match")
cloudfront_backend.delete_distribution(distribution_id, if_match)
self.backend.delete_distribution(distribution_id, if_match)
return 204, {}, ""
if request.method == "GET":
dist, etag = cloudfront_backend.get_distribution(distribution_id)
dist, etag = self.backend.get_distribution(distribution_id)
template = self.response_template(GET_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS)
return 200, {"ETag": etag}, response
@ -55,7 +59,7 @@ class CloudFrontResponse(BaseResponse):
dist_id = full_url.split("/")[-2]
if_match = headers["If-Match"]
dist, location, e_tag = cloudfront_backend.update_distribution(
dist, location, e_tag = self.backend.update_distribution(
DistributionConfig=distribution_config,
Id=dist_id,
IfMatch=if_match,

View File

@ -130,10 +130,10 @@ class Trail(BaseModel):
raise TrailNameInvalidChars()
def check_bucket_exists(self):
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
try:
s3_backend.get_bucket(self.bucket_name)
s3_backends["global"].get_bucket(self.bucket_name)
except Exception:
raise S3BucketDoesNotExistException(
f"S3 bucket {self.bucket_name} does not exist!"

View File

@ -544,7 +544,6 @@ class CloudWatchBackend(BaseBackend):
unit=None,
):
period_delta = timedelta(seconds=period)
# TODO: Also filter by unit and dimensions
filtered_data = [
md
for md in self.get_all_metrics()

View File

@ -4,6 +4,10 @@ from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse):
@property
def backend(self):
return cognitoidentity_backends[self.region]
def create_identity_pool(self):
identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param(
@ -16,7 +20,7 @@ class CognitoIdentityResponse(BaseResponse):
saml_provider_arns = self._get_param("SamlProviderARNs")
pool_tags = self._get_param("IdentityPoolTags")
return cognitoidentity_backends[self.region].create_identity_pool(
return self.backend.create_identity_pool(
identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers,
@ -38,7 +42,7 @@ class CognitoIdentityResponse(BaseResponse):
saml_providers = self._get_param("SamlProviderARNs")
pool_tags = self._get_param("IdentityPoolTags")
return cognitoidentity_backends[self.region].update_identity_pool(
return self.backend.update_identity_pool(
identity_pool_id=pool_id,
identity_pool_name=pool_name,
allow_unauthenticated=allow_unauthenticated,
@ -51,19 +55,13 @@ class CognitoIdentityResponse(BaseResponse):
)
def get_id(self):
return cognitoidentity_backends[self.region].get_id(
identity_pool_id=self._get_param("IdentityPoolId")
)
return self.backend.get_id(identity_pool_id=self._get_param("IdentityPoolId"))
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(
self._get_param("IdentityPoolId")
)
return self.backend.describe_identity_pool(self._get_param("IdentityPoolId"))
def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(
self._get_param("IdentityId")
)
return self.backend.get_credentials_for_identity(self._get_param("IdentityId"))
def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[
@ -73,11 +71,11 @@ class CognitoIdentityResponse(BaseResponse):
)
def get_open_id_token(self):
return cognitoidentity_backends[self.region].get_open_id_token(
return self.backend.get_open_id_token(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)
def list_identities(self):
return cognitoidentity_backends[self.region].list_identities(
return self.backend.list_identities(
self._get_param("IdentityPoolId") or get_random_identity_id(self.region)
)

View File

@ -20,12 +20,14 @@ class CognitoIdpResponse(BaseResponse):
def parameters(self):
return json.loads(self.body)
@property
def backend(self):
return cognitoidp_backends[self.region]
# User pool
def create_user_pool(self):
name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool(
name, self.parameters
)
user_pool = self.backend.create_user_pool(name, self.parameters)
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def set_user_pool_mfa_config(self):
@ -50,22 +52,20 @@ class CognitoIdpResponse(BaseResponse):
"[SmsConfiguration] is a required member of [SoftwareTokenMfaConfiguration]."
)
response = cognitoidp_backends[self.region].set_user_pool_mfa_config(
response = self.backend.set_user_pool_mfa_config(
user_pool_id, sms_config, token_config, mfa_config
)
return json.dumps(response)
def get_user_pool_mfa_config(self):
user_pool_id = self._get_param("UserPoolId")
response = cognitoidp_backends[self.region].get_user_pool_mfa_config(
user_pool_id
)
response = self.backend.get_user_pool_mfa_config(user_pool_id)
return json.dumps(response)
def list_user_pools(self):
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken")
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
user_pools, next_token = self.backend.list_user_pools(
max_results=max_results, next_token=next_token
)
response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
@ -75,16 +75,16 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id)
user_pool = self.backend.describe_user_pool(user_pool_id)
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def update_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].update_user_pool(user_pool_id, self.parameters)
self.backend.update_user_pool(user_pool_id, self.parameters)
def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].delete_user_pool(user_pool_id)
self.backend.delete_user_pool(user_pool_id)
return ""
# User pool domain
@ -92,7 +92,7 @@ class CognitoIdpResponse(BaseResponse):
domain = self._get_param("Domain")
user_pool_id = self._get_param("UserPoolId")
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain(
user_pool_domain = self.backend.create_user_pool_domain(
user_pool_id, domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
@ -102,9 +102,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_domain(self):
domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(
domain
)
user_pool_domain = self.backend.describe_user_pool_domain(domain)
domain_description = {}
if user_pool_domain:
domain_description = user_pool_domain.to_json()
@ -113,13 +111,13 @@ class CognitoIdpResponse(BaseResponse):
def delete_user_pool_domain(self):
domain = self._get_param("Domain")
cognitoidp_backends[self.region].delete_user_pool_domain(domain)
self.backend.delete_user_pool_domain(domain)
return ""
def update_user_pool_domain(self):
domain = self._get_param("Domain")
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain(
user_pool_domain = self.backend.update_user_pool_domain(
domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
@ -131,7 +129,7 @@ class CognitoIdpResponse(BaseResponse):
def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
generate_secret = self.parameters.pop("GenerateSecret", False)
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
user_pool_client = self.backend.create_user_pool_client(
user_pool_id, generate_secret, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -157,7 +155,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(
user_pool_client = self.backend.describe_user_pool_client(
user_pool_id, client_id
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -165,7 +163,7 @@ class CognitoIdpResponse(BaseResponse):
def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(
user_pool_client = self.backend.update_user_pool_client(
user_pool_id, client_id, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
@ -173,16 +171,14 @@ class CognitoIdpResponse(BaseResponse):
def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client(
user_pool_id, client_id
)
self.backend.delete_user_pool_client(user_pool_id, client_id)
return ""
# Identity provider
def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider(
identity_provider = self.backend.create_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
@ -210,9 +206,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(
user_pool_id, name
)
identity_provider = self.backend.describe_identity_provider(user_pool_id, name)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
@ -220,7 +214,7 @@ class CognitoIdpResponse(BaseResponse):
def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider(
identity_provider = self.backend.update_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
@ -230,7 +224,7 @@ class CognitoIdpResponse(BaseResponse):
def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
cognitoidp_backends[self.region].delete_identity_provider(user_pool_id, name)
self.backend.delete_identity_provider(user_pool_id, name)
return ""
# Group
@ -241,7 +235,7 @@ class CognitoIdpResponse(BaseResponse):
role_arn = self._get_param("RoleArn")
precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group(
group = self.backend.create_group(
user_pool_id, group_name, description, role_arn, precedence
)
@ -250,18 +244,18 @@ class CognitoIdpResponse(BaseResponse):
def get_group(self):
group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name)
group = self.backend.get_group(user_pool_id, group_name)
return json.dumps({"Group": group.to_json()})
def list_groups(self):
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id)
groups = self.backend.list_groups(user_pool_id)
return json.dumps({"Groups": [group.to_json() for group in groups]})
def delete_group(self):
group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].delete_group(user_pool_id, group_name)
self.backend.delete_group(user_pool_id, group_name)
return ""
def update_group(self):
@ -271,7 +265,7 @@ class CognitoIdpResponse(BaseResponse):
role_arn = self._get_param("RoleArn")
precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].update_group(
group = self.backend.update_group(
user_pool_id, group_name, description, role_arn, precedence
)
@ -282,26 +276,20 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group(
user_pool_id, group_name, username
)
self.backend.admin_add_user_to_group(user_pool_id, group_name, username)
return ""
def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group(
user_pool_id, group_name
)
users = self.backend.list_users_in_group(user_pool_id, group_name)
return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self):
username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(
user_pool_id, username
)
groups = self.backend.admin_list_groups_for_user(user_pool_id, username)
return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self):
@ -309,18 +297,14 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group(
user_pool_id, group_name, username
)
self.backend.admin_remove_user_from_group(user_pool_id, group_name, username)
return ""
def admin_reset_user_password(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
cognitoidp_backends[self.region].admin_reset_user_password(
user_pool_id, username
)
self.backend.admin_reset_user_password(user_pool_id, username)
return ""
# User
@ -329,7 +313,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
message_action = self._get_param("MessageAction")
temporary_password = self._get_param("TemporaryPassword")
user = cognitoidp_backends[self.region].admin_create_user(
user = self.backend.admin_create_user(
user_pool_id,
username,
message_action,
@ -342,14 +326,12 @@ class CognitoIdpResponse(BaseResponse):
def admin_confirm_sign_up(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
return cognitoidp_backends[self.region].admin_confirm_sign_up(
user_pool_id, username
)
return self.backend.admin_confirm_sign_up(user_pool_id, username)
def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username)
user = self.backend.admin_get_user(user_pool_id, username)
return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
def get_user(self):
@ -363,7 +345,7 @@ class CognitoIdpResponse(BaseResponse):
token = self._get_param("PaginationToken")
filt = self._get_param("Filter")
attributes_to_get = self._get_param("AttributesToGet")
users, token = cognitoidp_backends[self.region].list_users(
users, token = self.backend.list_users(
user_pool_id, limit=limit, pagination_token=token
)
if filt:
@ -420,19 +402,19 @@ class CognitoIdpResponse(BaseResponse):
def admin_disable_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
cognitoidp_backends[self.region].admin_disable_user(user_pool_id, username)
self.backend.admin_disable_user(user_pool_id, username)
return ""
def admin_enable_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
cognitoidp_backends[self.region].admin_enable_user(user_pool_id, username)
self.backend.admin_enable_user(user_pool_id, username)
return ""
def admin_delete_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
cognitoidp_backends[self.region].admin_delete_user(user_pool_id, username)
self.backend.admin_delete_user(user_pool_id, username)
return ""
def admin_initiate_auth(self):
@ -441,7 +423,7 @@ class CognitoIdpResponse(BaseResponse):
auth_flow = self._get_param("AuthFlow")
auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth(
auth_result = self.backend.admin_initiate_auth(
user_pool_id, client_id, auth_flow, auth_parameters
)
@ -501,31 +483,25 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes(
user_pool_id, username, attributes
)
self.backend.admin_update_user_attributes(user_pool_id, username, attributes)
return ""
def admin_delete_user_attributes(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
attributes = self._get_param("UserAttributeNames")
cognitoidp_backends[self.region].admin_delete_user_attributes(
user_pool_id, username, attributes
)
self.backend.admin_delete_user_attributes(user_pool_id, username, attributes)
return ""
def admin_user_global_sign_out(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
cognitoidp_backends[self.region].admin_user_global_sign_out(
user_pool_id, username
)
self.backend.admin_user_global_sign_out(user_pool_id, username)
return ""
def global_sign_out(self):
access_token = self._get_param("AccessToken")
cognitoidp_backends[self.region].global_sign_out(access_token)
self.backend.global_sign_out(access_token)
return ""
# Resource Server
@ -534,7 +510,7 @@ class CognitoIdpResponse(BaseResponse):
identifier = self._get_param("Identifier")
name = self._get_param("Name")
scopes = self._get_param("Scopes")
resource_server = cognitoidp_backends[self.region].create_resource_server(
resource_server = self.backend.create_resource_server(
user_pool_id, identifier, name, scopes
)
return json.dumps({"ResourceServer": resource_server.to_json()})
@ -575,19 +551,19 @@ class CognitoIdpResponse(BaseResponse):
def associate_software_token(self):
access_token = self._get_param("AccessToken")
result = cognitoidp_backends[self.region].associate_software_token(access_token)
result = self.backend.associate_software_token(access_token)
return json.dumps(result)
def verify_software_token(self):
access_token = self._get_param("AccessToken")
result = cognitoidp_backends[self.region].verify_software_token(access_token)
result = self.backend.verify_software_token(access_token)
return json.dumps(result)
def set_user_mfa_preference(self):
access_token = self._get_param("AccessToken")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings")
cognitoidp_backends[self.region].set_user_mfa_preference(
self.backend.set_user_mfa_preference(
access_token, software_token_mfa_settings, sms_mfa_settings
)
return ""
@ -597,7 +573,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
software_token_mfa_settings = self._get_param("SoftwareTokenMfaSettings")
sms_mfa_settings = self._get_param("SMSMfaSettings")
cognitoidp_backends[self.region].admin_set_user_mfa_preference(
self.backend.admin_set_user_mfa_preference(
user_pool_id, username, software_token_mfa_settings, sms_mfa_settings
)
return ""
@ -607,7 +583,7 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
password = self._get_param("Password")
permanent = self._get_param("Permanent")
cognitoidp_backends[self.region].admin_set_user_password(
self.backend.admin_set_user_password(
user_pool_id, username, password, permanent
)
return ""
@ -615,17 +591,13 @@ class CognitoIdpResponse(BaseResponse):
def add_custom_attributes(self):
user_pool_id = self._get_param("UserPoolId")
custom_attributes = self._get_param("CustomAttributes")
cognitoidp_backends[self.region].add_custom_attributes(
user_pool_id, custom_attributes
)
self.backend.add_custom_attributes(user_pool_id, custom_attributes)
return ""
def update_user_attributes(self):
access_token = self._get_param("AccessToken")
attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].update_user_attributes(
access_token, attributes
)
self.backend.update_user_attributes(access_token, attributes)
return json.dumps({})

View File

@ -1468,14 +1468,11 @@ class ConfigBackend(BaseBackend):
backend_query_region = (
backend_region # Always provide the backend this request arrived from.
)
print(RESOURCE_MAP[resource_type].backends)
if RESOURCE_MAP[resource_type].backends.get("global"):
print("yes, its global")
backend_region = "global"
# If the backend region isn't implemented then we won't find the item:
if not RESOURCE_MAP[resource_type].backends.get(backend_region):
print(f"cant find {backend_region} for {resource_type}")
raise ResourceNotDiscoveredException(resource_type, resource_id)
# Get the item:
@ -1483,7 +1480,6 @@ class ConfigBackend(BaseBackend):
resource_id, backend_region=backend_query_region
)
if not item:
print("item not found")
raise ResourceNotDiscoveredException(resource_type, resource_id)
item["accountId"] = get_account_id()

View File

@ -204,7 +204,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def dispatch(cls, *args, **kwargs):
return cls()._dispatch(*args, **kwargs)
def setup_class(self, request, full_url, headers):
def setup_class(self, request, full_url, headers, use_raw_body=False):
"""
use_raw_body: Use incoming bytes if True, encode to string otherwise
"""
querystring = OrderedDict()
if hasattr(request, "body"):
# Boto
@ -222,7 +225,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring[key] = [value]
raw_body = self.body
if isinstance(self.body, bytes):
if isinstance(self.body, bytes) and not use_raw_body:
self.body = self.body.decode("utf-8")
if not querystring:
@ -244,7 +247,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items():
querystring[key] = [value]
elif self.body:
elif self.body and not use_raw_body:
try:
querystring.update(
OrderedDict(
@ -254,7 +257,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
)
)
except (UnicodeEncodeError, UnicodeDecodeError):
except (UnicodeEncodeError, UnicodeDecodeError, AttributeError):
pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring:
querystring.update(headers)

View File

@ -412,6 +412,54 @@ def extract_region_from_aws_authorization(string):
backend_lock = RLock()
class AccountSpecificBackend(dict):
"""
Dictionary storing the data for a service in a specific account.
Data access pattern:
account_specific_backend[region: str] = backend: BaseBackend
"""
def __init__(
self, service_name, account_id, backend, use_boto3_regions, additional_regions
):
self.service_name = service_name
self.account_id = account_id
self.backend = backend
self.regions = []
if use_boto3_regions:
sess = Session()
self.regions.extend(sess.get_available_regions(service_name))
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
self.regions.extend(additional_regions or [])
def reset(self):
for region_specific_backend in self.values():
region_specific_backend.reset()
def __contains__(self, region):
return region in self.regions or region in self.keys()
def __getitem__(self, region_name):
if region_name in self.keys():
return super().__getitem__(region_name)
# Create the backend for a specific region
with backend_lock:
if region_name in self.regions and region_name not in self.keys():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
if region_name not in self.regions and allow_unknown_region():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
return super().__getitem__(region_name)
class BackendDict(dict):
"""
Data Structure to store everything related to a specific service.
@ -484,51 +532,3 @@ class BackendDict(dict):
use_boto3_regions=self._use_boto3_regions,
additional_regions=self._additional_regions,
)
class AccountSpecificBackend(dict):
"""
Dictionary storing the data for a service in a specific account.
Data access pattern:
account_specific_backend[region: str] = backend: BaseBackend
"""
def __init__(
self, service_name, account_id, backend, use_boto3_regions, additional_regions
):
self.service_name = service_name
self.account_id = account_id
self.backend = backend
self.regions = []
if use_boto3_regions:
sess = Session()
self.regions.extend(sess.get_available_regions(service_name))
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-us-gov")
)
self.regions.extend(
sess.get_available_regions(service_name, partition_name="aws-cn")
)
self.regions.extend(additional_regions or [])
def reset(self):
for region_specific_backend in self.values():
region_specific_backend.reset()
def __contains__(self, region):
return region in self.regions or region in self.keys()
def __getitem__(self, region_name):
if region_name in self.keys():
return super().__getitem__(region_name)
# Create the backend for a specific region
with backend_lock:
if region_name in self.regions and region_name not in self.keys():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
if region_name not in self.regions and allow_unknown_region():
super().__setitem__(
region_name, self.backend(region_name, account_id=self.account_id)
)
return super().__getitem__(region_name)

View File

@ -5,23 +5,15 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse):
@property
def parameters(self):
# TODO this should really be moved to core/responses.py
if self.body:
return json.loads(self.body)
else:
return self.querystring
@property
def datapipeline_backend(self):
return datapipeline_backends[self.region]
def create_pipeline(self):
name = self.parameters.get("name")
unique_id = self.parameters.get("uniqueId")
description = self.parameters.get("description", "")
tags = self.parameters.get("tags", [])
name = self._get_param("name")
unique_id = self._get_param("uniqueId")
description = self._get_param("description", "")
tags = self._get_param("tags", [])
pipeline = self.datapipeline_backend.create_pipeline(
name, unique_id, description=description, tags=tags
)
@ -31,7 +23,7 @@ class DataPipelineResponse(BaseResponse):
pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50
marker = self.parameters.get("marker")
marker = self._get_param("marker")
if marker:
start = pipeline_ids.index(marker) + 1
else:
@ -53,7 +45,7 @@ class DataPipelineResponse(BaseResponse):
)
def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"]
pipeline_ids = self._get_param("pipelineIds")
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps(
@ -61,19 +53,19 @@ class DataPipelineResponse(BaseResponse):
)
def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_id = self._get_param("pipelineId")
self.datapipeline_backend.delete_pipeline(pipeline_id)
return json.dumps({})
def put_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"]
pipeline_id = self._get_param("pipelineId")
pipeline_objects = self._get_param("pipelineObjects")
self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
return json.dumps({"errored": False})
def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_id = self._get_param("pipelineId")
pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id
)
@ -86,8 +78,8 @@ class DataPipelineResponse(BaseResponse):
)
def describe_objects(self):
pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"]
pipeline_id = self._get_param("pipelineId")
object_ids = self._get_param("objectIds")
pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id
@ -103,6 +95,6 @@ class DataPipelineResponse(BaseResponse):
)
def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_id = self._get_param("pipelineId")
self.datapipeline_backend.activate_pipeline(pipeline_id)
return json.dumps({})

View File

@ -397,4 +397,3 @@ dynamodb_backends = BackendDict(
use_boto3_regions=False,
additional_regions=["global"],
)
dynamodb_backend = dynamodb_backends["global"]

View File

@ -2,7 +2,7 @@ import json
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backend, dynamo_json_dump
from .models import dynamodb_backends, dynamo_json_dump
class DynamoHandler(BaseResponse):
@ -36,15 +36,19 @@ class DynamoHandler(BaseResponse):
else:
return 404, self.response_headers, ""
@property
def backend(self):
return dynamodb_backends["global"]
def list_tables(self):
body = self.body
limit = body.get("Limit")
if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName")
start = list(dynamodb_backend.tables.keys()).index(last) + 1
start = list(self.backend.tables.keys()).index(last) + 1
else:
start = 0
all_tables = list(dynamodb_backend.tables.keys())
all_tables = list(self.backend.tables.keys())
if limit:
tables = all_tables[start : start + limit]
else:
@ -71,7 +75,7 @@ class DynamoHandler(BaseResponse):
read_units = throughput["ReadCapacityUnits"]
write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.create_table(
table = self.backend.create_table(
name,
hash_key_attr=hash_key_attr,
hash_key_type=hash_key_type,
@ -84,7 +88,7 @@ class DynamoHandler(BaseResponse):
def delete_table(self):
name = self.body["TableName"]
table = dynamodb_backend.delete_table(name)
table = self.backend.delete_table(name)
if table:
return dynamo_json_dump(table.describe)
else:
@ -96,7 +100,7 @@ class DynamoHandler(BaseResponse):
throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"]
new_write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.update_table_throughput(
table = self.backend.update_table_throughput(
name, new_read_units, new_write_units
)
return dynamo_json_dump(table.describe)
@ -104,7 +108,7 @@ class DynamoHandler(BaseResponse):
def describe_table(self):
name = self.body["TableName"]
try:
table = dynamodb_backend.tables[name]
table = self.backend.tables[name]
except KeyError:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
@ -113,7 +117,7 @@ class DynamoHandler(BaseResponse):
def put_item(self):
name = self.body["TableName"]
item = self.body["Item"]
result = dynamodb_backend.put_item(name, item)
result = self.backend.put_item(name, item)
if result:
item_dict = result.to_json()
item_dict["ConsumedCapacityUnits"] = 1
@ -132,12 +136,12 @@ class DynamoHandler(BaseResponse):
if request_type == "PutRequest":
item = request["Item"]
dynamodb_backend.put_item(table_name, item)
self.backend.put_item(table_name, item)
elif request_type == "DeleteRequest":
key = request["Key"]
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
item = dynamodb_backend.delete_item(table_name, hash_key, range_key)
self.backend.delete_item(table_name, hash_key, range_key)
response = {
"Responses": {
@ -156,7 +160,7 @@ class DynamoHandler(BaseResponse):
range_key = key.get("RangeKeyElement")
attrs_to_get = self.body.get("AttributesToGet")
try:
item = dynamodb_backend.get_item(name, hash_key, range_key)
item = self.backend.get_item(name, hash_key, range_key)
except ValueError:
er = "com.amazon.coral.validate#ValidationException"
return self.error(er, status=400)
@ -181,7 +185,7 @@ class DynamoHandler(BaseResponse):
for key in keys:
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
item = dynamodb_backend.get_item(table_name, hash_key, range_key)
item = self.backend.get_item(table_name, hash_key, range_key)
if item:
item_describe = item.describe_attrs(attributes_to_get)
items.append(item_describe)
@ -202,9 +206,7 @@ class DynamoHandler(BaseResponse):
range_comparison = None
range_values = []
items, _ = dynamodb_backend.query(
name, hash_key, range_comparison, range_values
)
items, _ = self.backend.query(name, hash_key, range_comparison, range_values)
if items is None:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
@ -236,7 +238,7 @@ class DynamoHandler(BaseResponse):
comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values)
items, scanned_count, _ = dynamodb_backend.scan(name, filters)
items, scanned_count, _ = self.backend.scan(name, filters)
if items is None:
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
@ -263,7 +265,7 @@ class DynamoHandler(BaseResponse):
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
return_values = self.body.get("ReturnValues", "")
item = dynamodb_backend.delete_item(name, hash_key, range_key)
item = self.backend.delete_item(name, hash_key, range_key)
if item:
if return_values == "ALL_OLD":
item_dict = item.to_json()
@ -282,7 +284,7 @@ class DynamoHandler(BaseResponse):
range_key = key.get("RangeKeyElement")
updates = self.body["AttributeUpdates"]
item = dynamodb_backend.update_item(name, hash_key, range_key, updates)
item = self.backend.update_item(name, hash_key, range_key, updates)
if item:
item_dict = item.to_json()

View File

@ -214,12 +214,12 @@ class FlowLogsBackend:
self.get_network_interface(resource_id)
if log_destination_type == "s3":
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from moto.s3.exceptions import MissingBucket
arn = log_destination.split(":", 5)[5]
try:
s3_backend.get_bucket(arn)
s3_backends["global"].get_bucket(arn)
except MissingBucket:
# Instead of creating FlowLog report
# the unsuccessful status for the

View File

@ -1547,9 +1547,9 @@ Member must satisfy regular expression pattern: {}".format(
except AWSResourceNotFoundException:
pass
from moto.iam import iam_backend
from moto.iam import iam_backends
cert = iam_backend.get_certificate_by_arn(certificate_arn)
cert = iam_backends["global"].get_certificate_by_arn(certificate_arn)
if cert is not None:
return True

View File

@ -37,7 +37,7 @@ from moto.firehose.exceptions import (
ResourceNotFoundException,
ValidationException,
)
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from moto.utilities.tagging_service import TaggingService
MAX_TAGS_PER_DELIVERY_STREAM = 50
@ -447,7 +447,7 @@ class FirehoseBackend(BaseBackend):
batched_data = b"".join([b64decode(r["Data"]) for r in records])
try:
s3_backend.put_object(bucket_name, object_path, batched_data)
s3_backends["global"].put_object(bucket_name, object_path, batched_data)
except Exception as exc:
# This could be better ...
raise RuntimeError(

View File

@ -1,5 +1,4 @@
from .models import iam_backend
from .models import iam_backends
from ..core.models import base_decorator
iam_backends = {"global": iam_backend}
mock_iam = base_decorator(iam_backends)

View File

@ -39,8 +39,8 @@ from moto.s3.exceptions import (
BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError,
)
from moto.sts.models import sts_backend
from .models import iam_backend, Policy
from moto.sts.models import sts_backends
from .models import iam_backends, Policy
log = logging.getLogger(__name__)
@ -53,8 +53,12 @@ def create_access_key(access_key_id, headers):
class IAMUserAccessKey(object):
@property
def backend(self):
return iam_backends["global"]
def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users("/", None, None)
iam_users = self.backend.list_users("/", None, None)
for iam_user in iam_users:
for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id:
@ -78,28 +82,30 @@ class IAMUserAccessKey(object):
def collect_policies(self):
user_policies = []
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
inline_policy_names = self.backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(
inline_policy = self.backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(
attached_policies, _ = self.backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
user_groups = self.backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
inline_group_policy_names = self.backend.list_group_policies(
user_group.name
)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(
inline_user_group_policy = self.backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(
attached_group_policies, _ = self.backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies
@ -108,8 +114,12 @@ class IAMUserAccessKey(object):
class AssumedRoleAccessKey(object):
@property
def backend(self):
return iam_backends["global"]
def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles:
for assumed_role in sts_backends["global"].assumed_roles:
if assumed_role.access_key_id == access_key_id:
self._access_key_id = access_key_id
self._secret_access_key = assumed_role.secret_access_key
@ -139,14 +149,14 @@ class AssumedRoleAccessKey(object):
def collect_policies(self):
role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
inline_policy_names = self.backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(
_, inline_policy = self.backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(
attached_policies, _ = self.backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies

View File

@ -19,6 +19,7 @@ from moto.core import BaseBackend, BaseModel, get_account_id, CloudFormationMode
from moto.core.utils import (
iso_8601_datetime_without_milliseconds,
iso_8601_datetime_with_milliseconds,
BackendDict,
)
from moto.iam.policy_validation import IAMPolicyDocumentValidator
from moto.utilities.utils import md5_hash
@ -362,7 +363,7 @@ class ManagedPolicy(Policy, CloudFormationModel):
role_names = properties.get("Roles", [])
tags = properties.get("Tags", {})
policy = iam_backend.create_policy(
policy = iam_backends["global"].create_policy(
description=description,
path=path,
policy_document=policy_document,
@ -370,13 +371,17 @@ class ManagedPolicy(Policy, CloudFormationModel):
tags=tags,
)
for group_name in group_names:
iam_backend.attach_group_policy(
iam_backends["global"].attach_group_policy(
group_name=group_name, policy_arn=policy.arn
)
for user_name in user_names:
iam_backend.attach_user_policy(user_name=user_name, policy_arn=policy.arn)
iam_backends["global"].attach_user_policy(
user_name=user_name, policy_arn=policy.arn
)
for role_name in role_names:
iam_backend.attach_role_policy(role_name=role_name, policy_arn=policy.arn)
iam_backends["global"].attach_role_policy(
role_name=role_name, policy_arn=policy.arn
)
return policy
@property
@ -466,7 +471,7 @@ class InlinePolicy(CloudFormationModel):
role_names = properties.get("Roles")
group_names = properties.get("Groups")
return iam_backend.create_inline_policy(
return iam_backends["global"].create_inline_policy(
resource_name,
policy_name,
policy_document,
@ -502,7 +507,7 @@ class InlinePolicy(CloudFormationModel):
role_names = properties.get("Roles")
group_names = properties.get("Groups")
return iam_backend.update_inline_policy(
return iam_backends["global"].update_inline_policy(
original_resource.name,
policy_name,
policy_document,
@ -515,7 +520,7 @@ class InlinePolicy(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
iam_backend.delete_inline_policy(resource_name)
iam_backends["global"].delete_inline_policy(resource_name)
@staticmethod
def is_replacement_update(properties):
@ -606,7 +611,7 @@ class Role(CloudFormationModel):
properties = cloudformation_json["Properties"]
role_name = properties.get("RoleName", resource_name)
role = iam_backend.create_role(
role = iam_backends["global"].create_role(
role_name=role_name,
assume_role_policy_document=properties["AssumeRolePolicyDocument"],
path=properties.get("Path", "/"),
@ -628,14 +633,14 @@ class Role(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
for profile in iam_backend.instance_profiles.values():
for profile in iam_backends["global"].instance_profiles.values():
profile.delete_role(role_name=resource_name)
for role in iam_backend.roles.values():
for role in iam_backends["global"].roles.values():
if role.name == resource_name:
for arn in role.policies.keys():
role.delete_policy(arn)
iam_backend.delete_role(resource_name)
iam_backends["global"].delete_role(resource_name)
@property
def arn(self):
@ -649,7 +654,10 @@ class Role(CloudFormationModel):
_managed_policies = []
for key in self.managed_policies.keys():
_managed_policies.append(
{"policyArn": key, "policyName": iam_backend.managed_policies[key].name}
{
"policyArn": key,
"policyName": iam_backends["global"].managed_policies[key].name,
}
)
_role_policy_list = []
@ -659,7 +667,7 @@ class Role(CloudFormationModel):
)
_instance_profiles = []
for key, instance_profile in iam_backend.instance_profiles.items():
for key, instance_profile in iam_backends["global"].instance_profiles.items():
for _ in instance_profile.roles:
_instance_profiles.append(instance_profile.to_embedded_config_dict())
break
@ -808,7 +816,7 @@ class InstanceProfile(CloudFormationModel):
properties = cloudformation_json["Properties"]
role_names = properties["Roles"]
return iam_backend.create_instance_profile(
return iam_backends["global"].create_instance_profile(
name=resource_name,
path=properties.get("Path", "/"),
role_names=role_names,
@ -818,7 +826,7 @@ class InstanceProfile(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
iam_backend.delete_instance_profile(resource_name)
iam_backends["global"].delete_instance_profile(resource_name)
def delete_role(self, role_name):
self.roles = [role for role in self.roles if role.name != role_name]
@ -964,7 +972,7 @@ class AccessKey(CloudFormationModel):
user_name = properties.get("UserName")
status = properties.get("Status", "Active")
return iam_backend.create_access_key(user_name, status=status)
return iam_backends["global"].create_access_key(user_name, status=status)
@classmethod
def update_from_cloudformation_json(
@ -984,7 +992,7 @@ class AccessKey(CloudFormationModel):
else: # No Interruption
properties = cloudformation_json.get("Properties", {})
status = properties.get("Status")
return iam_backend.update_access_key(
return iam_backends["global"].update_access_key(
original_resource.user_name, original_resource.access_key_id, status
)
@ -992,7 +1000,7 @@ class AccessKey(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
iam_backend.delete_access_key_by_name(resource_name)
iam_backends["global"].delete_access_key_by_name(resource_name)
@staticmethod
def is_replacement_update(properties):
@ -1303,7 +1311,7 @@ class User(CloudFormationModel):
):
properties = cloudformation_json.get("Properties", {})
path = properties.get("Path")
user, _ = iam_backend.create_user(resource_name, path)
user, _ = iam_backends["global"].create_user(resource_name, path)
return user
@classmethod
@ -1334,7 +1342,7 @@ class User(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
iam_backend.delete_user(resource_name)
iam_backends["global"].delete_user(resource_name)
@staticmethod
def is_replacement_update(properties):
@ -2043,7 +2051,7 @@ class IAMBackend(BaseBackend):
instance_profile_id = random_resource_id()
roles = [iam_backend.get_role(role_name) for role_name in role_names]
roles = [iam_backends["global"].get_role(role_name) for role_name in role_names]
instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags)
self.instance_profiles[name] = instance_profile
return instance_profile
@ -2838,12 +2846,10 @@ class IAMBackend(BaseBackend):
return inline_policy
def get_inline_policy(self, policy_id):
inline_policy = None
try:
inline_policy = self.inline_policies[policy_id]
return self.inline_policies[policy_id]
except KeyError:
raise IAMNotFoundException("Inline policy {0} not found".format(policy_id))
return inline_policy
def update_inline_policy(
self,
@ -2924,4 +2930,6 @@ class IAMBackend(BaseBackend):
return True
iam_backend = IAMBackend("global")
iam_backends = BackendDict(
IAMBackend, "iam", use_boto3_regions=False, additional_regions=["global"]
)

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,7 @@ from moto.logs.exceptions import (
InvalidParameterException,
LimitExceededException,
)
from moto.s3.models import s3_backend
from moto.s3.models import s3_backends
from .utils import PAGINATION_MODEL
MAX_RESOURCE_POLICIES_PER_REGION = 10
@ -940,7 +940,7 @@ class LogsBackend(BaseBackend):
return query_id
def create_export_task(self, log_group_name, destination):
s3_backend.get_bucket(destination)
s3_backends["global"].get_bucket(destination)
if log_group_name not in self.groups:
raise ResourceNotFoundException()
task_id = uuid.uuid4()

View File

@ -5,7 +5,6 @@ from moto.core.responses import BaseResponse
from .exceptions import exception_handler
from .models import managedblockchain_backends
from .utils import (
region_from_managedblckchain_url,
networkid_from_managedblockchain_url,
proposalid_from_managedblockchain_url,
invitationid_from_managedblockchain_url,
@ -15,29 +14,21 @@ from .utils import (
class ManagedBlockchainResponse(BaseResponse):
def __init__(self, backend):
super().__init__()
self.backend = backend
@property
def backend(self):
return managedblockchain_backends[self.region]
@classmethod
@exception_handler
def network_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._network_response(request, headers)
def network_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._network_response(request, headers)
def _network_response(self, request, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
if method == "GET":
return self._all_networks_response(headers)
elif method == "POST":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._network_response_post(json_body, headers)
def _all_networks_response(self, headers):
@ -70,14 +61,10 @@ class ManagedBlockchainResponse(BaseResponse):
)
return 200, headers, json.dumps(response)
@classmethod
@exception_handler
def networkid_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._networkid_response(request, full_url, headers)
def networkid_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._networkid_response(request, full_url, headers)
def _networkid_response(self, request, full_url, headers):
method = request.method
@ -92,26 +79,18 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, response
@classmethod
@exception_handler
def proposal_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._proposal_response(request, full_url, headers)
def proposal_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._proposal_response(request, full_url, headers)
def _proposal_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url)
if method == "GET":
return self._all_proposals_response(network_id, headers)
elif method == "POST":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._proposal_response_post(network_id, json_body, headers)
def _all_proposals_response(self, network_id, headers):
@ -134,14 +113,10 @@ class ManagedBlockchainResponse(BaseResponse):
)
return 200, headers, json.dumps(response)
@classmethod
@exception_handler
def proposalid_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._proposalid_response(request, full_url, headers)
def proposalid_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._proposalid_response(request, full_url, headers)
def _proposalid_response(self, request, full_url, headers):
method = request.method
@ -156,27 +131,19 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, response
@classmethod
@exception_handler
def proposal_votes_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._proposal_votes_response(request, full_url, headers)
def proposal_votes_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._proposal_votes_response(request, full_url, headers)
def _proposal_votes_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url)
proposal_id = proposalid_from_managedblockchain_url(full_url)
if method == "GET":
return self._all_proposal_votes_response(network_id, proposal_id, headers)
elif method == "POST":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._proposal_votes_response_post(
network_id, proposal_id, json_body, headers
)
@ -196,14 +163,10 @@ class ManagedBlockchainResponse(BaseResponse):
self.backend.vote_on_proposal(network_id, proposal_id, votermemberid, vote)
return 200, headers, ""
@classmethod
@exception_handler
def invitation_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._invitation_response(request, headers)
def invitation_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._invitation_response(request, headers)
def _invitation_response(self, request, headers):
method = request.method
@ -218,14 +181,10 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, response
@classmethod
@exception_handler
def invitationid_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._invitationid_response(request, full_url, headers)
def invitationid_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._invitationid_response(request, full_url, headers)
def _invitationid_response(self, request, full_url, headers):
method = request.method
@ -238,26 +197,18 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, ""
@classmethod
@exception_handler
def member_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._member_response(request, full_url, headers)
def member_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._member_response(request, full_url, headers)
def _member_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url)
if method == "GET":
return self._all_members_response(network_id, headers)
elif method == "POST":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._member_response_post(network_id, json_body, headers)
def _all_members_response(self, network_id, headers):
@ -275,27 +226,19 @@ class ManagedBlockchainResponse(BaseResponse):
)
return 200, headers, json.dumps(response)
@classmethod
@exception_handler
def memberid_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._memberid_response(request, full_url, headers)
def memberid_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._memberid_response(request, full_url, headers)
def _memberid_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body)
member_id = memberid_from_managedblockchain_request(full_url, self.body)
if method == "GET":
return self._memberid_response_get(network_id, member_id, headers)
elif method == "PATCH":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._memberid_response_patch(
network_id, member_id, json_body, headers
)
@ -318,32 +261,24 @@ class ManagedBlockchainResponse(BaseResponse):
headers["content-type"] = "application/json"
return 200, headers, ""
@classmethod
@exception_handler
def node_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._node_response(request, full_url, headers)
def node_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._node_response(request, full_url, headers)
def _node_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body)
member_id = memberid_from_managedblockchain_request(full_url, self.body)
if method == "GET":
status = None
if "status" in querystring:
status = querystring["status"][0]
return self._all_nodes_response(network_id, member_id, status, headers)
elif method == "POST":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._node_response_post(network_id, member_id, json_body, headers)
def _all_nodes_response(self, network_id, member_id, status, headers):
@ -368,28 +303,20 @@ class ManagedBlockchainResponse(BaseResponse):
)
return 200, headers, json.dumps(response)
@classmethod
@exception_handler
def nodeid_response(clazz, request, full_url, headers):
region_name = region_from_managedblckchain_url(full_url)
response_instance = ManagedBlockchainResponse(
managedblockchain_backends[region_name]
)
return response_instance._nodeid_response(request, full_url, headers)
def nodeid_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
return self._nodeid_response(request, full_url, headers)
def _nodeid_response(self, request, full_url, headers):
method = request.method
if hasattr(request, "body"):
body = request.body
else:
body = request.data
network_id = networkid_from_managedblockchain_url(full_url)
member_id = memberid_from_managedblockchain_request(full_url, body)
member_id = memberid_from_managedblockchain_request(full_url, self.body)
node_id = nodeid_from_managedblockchain_url(full_url)
if method == "GET":
return self._nodeid_response_get(network_id, member_id, node_id, headers)
elif method == "PATCH":
json_body = json.loads(body.decode("utf-8"))
json_body = json.loads(self.body)
return self._nodeid_response_patch(
network_id, member_id, node_id, json_body, headers
)

View File

@ -3,19 +3,19 @@ from .responses import ManagedBlockchainResponse
url_bases = [r"https?://managedblockchain\.(.+)\.amazonaws.com"]
url_paths = {
"{0}/networks$": ManagedBlockchainResponse.network_response,
"{0}/networks/(?P<networkid>[^/.]+)$": ManagedBlockchainResponse.networkid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals$": ManagedBlockchainResponse.proposal_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)$": ManagedBlockchainResponse.proposalid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)/votes$": ManagedBlockchainResponse.proposal_votes_response,
"{0}/invitations$": ManagedBlockchainResponse.invitation_response,
"{0}/invitations/(?P<invitationid>[^/.]+)$": ManagedBlockchainResponse.invitationid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members$": ManagedBlockchainResponse.member_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)$": ManagedBlockchainResponse.memberid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes$": ManagedBlockchainResponse.node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes?(?P<querys>[^/.]+)$": ManagedBlockchainResponse.node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse.nodeid_response,
"{0}/networks$": ManagedBlockchainResponse().network_response,
"{0}/networks/(?P<networkid>[^/.]+)$": ManagedBlockchainResponse().networkid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals$": ManagedBlockchainResponse().proposal_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)$": ManagedBlockchainResponse().proposalid_response,
"{0}/networks/(?P<networkid>[^/.]+)/proposals/(?P<proposalid>[^/.]+)/votes$": ManagedBlockchainResponse().proposal_votes_response,
"{0}/invitations$": ManagedBlockchainResponse().invitation_response,
"{0}/invitations/(?P<invitationid>[^/.]+)$": ManagedBlockchainResponse().invitationid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members$": ManagedBlockchainResponse().member_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)$": ManagedBlockchainResponse().memberid_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes?(?P<querys>[^/.]+)$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/members/(?P<memberid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse().nodeid_response,
# >= botocore 1.19.41 (API change - memberId is now part of query-string or body)
"{0}/networks/(?P<networkid>[^/.]+)/nodes$": ManagedBlockchainResponse.node_response,
"{0}/networks/(?P<networkid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse.nodeid_response,
"{0}/networks/(?P<networkid>[^/.]+)/nodes$": ManagedBlockchainResponse().node_response,
"{0}/networks/(?P<networkid>[^/.]+)/nodes/(?P<nodeid>[^/.]+)$": ManagedBlockchainResponse().nodeid_response,
}

View File

@ -6,14 +6,6 @@ import string
from urllib.parse import parse_qs, urlparse
def region_from_managedblckchain_url(url):
domain = urlparse(url).netloc
region = "us-east-1"
if "." in domain:
region = domain.split(".")[1]
return region
def networkid_from_managedblockchain_url(full_url):
id_search = re.search(r"\/n-[A-Z0-9]{26}", full_url, re.IGNORECASE)
return_id = None

View File

@ -121,7 +121,6 @@ class FakeKey(BaseModel, ManagedState):
lock_mode=None,
lock_legal_status=None,
lock_until=None,
s3_backend=None,
):
ManagedState.__init__(
self,
@ -162,8 +161,6 @@ class FakeKey(BaseModel, ManagedState):
# Default metadata values
self._metadata["Content-Type"] = "binary/octet-stream"
self.s3_backend = s3_backend
def safe_name(self, encoding_type=None):
if encoding_type == "url":
return urllib.parse.quote(self.name, safe="")
@ -292,7 +289,7 @@ class FakeKey(BaseModel, ManagedState):
res["x-amz-object-lock-retain-until-date"] = self.lock_until
if self.lock_mode:
res["x-amz-object-lock-mode"] = self.lock_mode
tags = s3_backend.tagger.get_tag_dict_for_resource(self.arn)
tags = s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn)
if tags:
res["x-amz-tagging-count"] = len(tags.keys())
@ -1228,13 +1225,13 @@ class FakeBucket(CloudFormationModel):
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name, **kwargs
):
bucket = s3_backend.create_bucket(resource_name, region_name)
bucket = s3_backends["global"].create_bucket(resource_name, region_name)
properties = cloudformation_json.get("Properties", {})
if "BucketEncryption" in properties:
bucket_encryption = cfn_to_api_encryption(properties["BucketEncryption"])
s3_backend.put_bucket_encryption(
s3_backends["global"].put_bucket_encryption(
bucket_name=resource_name, encryption=bucket_encryption
)
@ -1264,7 +1261,7 @@ class FakeBucket(CloudFormationModel):
bucket_encryption = cfn_to_api_encryption(
properties["BucketEncryption"]
)
s3_backend.put_bucket_encryption(
s3_backends["global"].put_bucket_encryption(
bucket_name=original_resource.name, encryption=bucket_encryption
)
return original_resource
@ -1273,7 +1270,7 @@ class FakeBucket(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
s3_backend.delete_bucket(resource_name)
s3_backends["global"].delete_bucket(resource_name)
def to_config_dict(self):
"""Return the AWS Config JSON format of this S3 bucket.
@ -1298,7 +1295,7 @@ class FakeBucket(CloudFormationModel):
"resourceCreationTime": str(self.creation_date),
"relatedEvents": [],
"relationships": [],
"tags": s3_backend.tagger.get_tag_dict_for_resource(self.arn),
"tags": s3_backends["global"].tagger.get_tag_dict_for_resource(self.arn),
"configuration": {
"name": self.name,
"owner": {"id": OWNER},
@ -1449,7 +1446,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
@classmethod
def get_cloudwatch_metrics(cls):
metrics = []
for name, bucket in s3_backend.buckets.items():
for name, bucket in s3_backends["global"].buckets.items():
metrics.append(
MetricDatum(
namespace="AWS/S3",
@ -1700,7 +1697,6 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
lock_mode=lock_mode,
lock_legal_status=lock_legal_status,
lock_until=lock_until,
s3_backend=s3_backend,
)
keys = [
@ -2173,4 +2169,3 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
s3_backends = BackendDict(
S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"]
)
s3_backend = s3_backends["global"]

View File

@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse
import xmltodict
from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin
from moto.core.responses import BaseResponse
from moto.core.utils import path_url
from moto.core import get_account_id
@ -51,7 +51,8 @@ from .exceptions import (
InvalidRange,
LockNotEnabled,
)
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
from .models import s3_backends
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
from .utils import bucket_name_from_url, metadata_from_headers, parse_region_from_url
from xml.dom import minidom
@ -151,14 +152,10 @@ def is_delete_keys(request, path):
)
class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def __init__(self, backend):
super().__init__()
self.backend = backend
self.method = ""
self.path = ""
self.data = {}
self.headers = {}
class S3Response(BaseResponse):
@property
def backend(self):
return s3_backends["global"]
@property
def should_autoescape(self):
@ -253,15 +250,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self.bucket_response(request, full_url, headers)
@amzn_request_id
def bucket_response(
self, request, full_url, headers
): # pylint: disable=unused-argument
self.method = request.method
self.path = self._get_path(request)
# Make a copy of request.headers because it's immutable
self.headers = dict(request.headers)
if "Host" not in self.headers:
self.headers["Host"] = urlparse(full_url).netloc
def bucket_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers, use_raw_body=True)
try:
response = self._bucket_response(request, full_url)
except S3ClientError as s3error:
@ -297,30 +287,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.data["BucketName"] = bucket_name
if hasattr(request, "body"):
# Boto
body = request.body
else:
# Flask server
body = request.data
if body is None:
body = b""
if isinstance(body, bytes):
body = body.decode("utf-8")
body = "{0}".format(body).encode("utf-8")
if method == "HEAD":
return self._bucket_response_head(bucket_name, querystring)
elif method == "GET":
return self._bucket_response_get(bucket_name, querystring)
elif method == "PUT":
return self._bucket_response_put(
request, body, region_name, bucket_name, querystring
request, region_name, bucket_name, querystring
)
elif method == "DELETE":
return self._bucket_response_delete(bucket_name, querystring)
elif method == "POST":
return self._bucket_response_post(request, body, bucket_name)
return self._bucket_response_post(request, bucket_name)
elif method == "OPTIONS":
return self._response_options(bucket_name)
else:
@ -379,28 +357,26 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
for cors_rule in bucket.cors:
if cors_rule.allowed_methods is not None:
self.headers["Access-Control-Allow-Methods"] = _to_string(
self.response_headers["Access-Control-Allow-Methods"] = _to_string(
cors_rule.allowed_methods
)
if cors_rule.allowed_origins is not None:
self.headers["Access-Control-Allow-Origin"] = _to_string(
self.response_headers["Access-Control-Allow-Origin"] = _to_string(
cors_rule.allowed_origins
)
if cors_rule.allowed_headers is not None:
self.headers["Access-Control-Allow-Headers"] = _to_string(
self.response_headers["Access-Control-Allow-Headers"] = _to_string(
cors_rule.allowed_headers
)
if cors_rule.exposed_headers is not None:
self.headers["Access-Control-Expose-Headers"] = _to_string(
self.response_headers["Access-Control-Expose-Headers"] = _to_string(
cors_rule.exposed_headers
)
if cors_rule.max_age_seconds is not None:
self.headers["Access-Control-Max-Age"] = _to_string(
self.response_headers["Access-Control-Max-Age"] = _to_string(
cors_rule.max_age_seconds
)
return self.headers
def _response_options(self, bucket_name):
# Return 200 with the headers from the bucket CORS configuration
self._authenticate_and_authorize_s3_action()
@ -415,7 +391,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self._set_cors_headers(bucket)
return 200, self.headers, ""
return 200, self.response_headers, ""
def _bucket_response_get(self, bucket_name, querystring):
self._set_action("BUCKET", "GET", querystring)
@ -728,15 +704,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
pass
return False
def _parse_pab_config(self, body):
parsed_xml = xmltodict.parse(body)
def _parse_pab_config(self):
parsed_xml = xmltodict.parse(self.body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
return parsed_xml
def _bucket_response_put(
self, request, body, region_name, bucket_name, querystring
):
def _bucket_response_put(self, request, region_name, bucket_name, querystring):
if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required"
@ -744,8 +718,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self._authenticate_and_authorize_s3_action()
if "object-lock" in querystring:
body_decoded = body.decode()
config = self._lock_config_from_xml(body_decoded)
config = self._lock_config_from_body()
if not self.backend.get_bucket(bucket_name).object_lock_enabled:
raise BucketMustHaveLockeEnabled
@ -760,7 +733,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 200, {}, ""
if "versioning" in querystring:
ver = re.search("<Status>([A-Za-z]+)</Status>", body.decode())
body = self.body.decode("utf-8")
ver = re.search(r"<Status>([A-Za-z]+)</Status>", body)
if ver:
self.backend.set_bucket_versioning(bucket_name, ver.group(1))
template = self.response_template(S3_BUCKET_VERSIONING)
@ -768,47 +742,45 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
return 404, {}, ""
elif "lifecycle" in querystring:
rules = xmltodict.parse(body)["LifecycleConfiguration"]["Rule"]
rules = xmltodict.parse(self.body)["LifecycleConfiguration"]["Rule"]
if not isinstance(rules, list):
# If there is only one rule, xmldict returns just the item
rules = [rules]
self.backend.put_bucket_lifecycle(bucket_name, rules)
return ""
elif "policy" in querystring:
self.backend.put_bucket_policy(bucket_name, body)
self.backend.put_bucket_policy(bucket_name, self.body)
return "True"
elif "acl" in querystring:
# Headers are first. If not set, then look at the body (consistent with the documentation):
acls = self._acl_from_headers(request.headers)
if not acls:
acls = self._acl_from_xml(body)
acls = self._acl_from_body()
self.backend.put_bucket_acl(bucket_name, acls)
return ""
elif "tagging" in querystring:
tagging = self._bucket_tagging_from_xml(body)
tagging = self._bucket_tagging_from_body()
self.backend.put_bucket_tagging(bucket_name, tagging)
return ""
elif "website" in querystring:
self.backend.set_bucket_website_configuration(bucket_name, body)
self.backend.set_bucket_website_configuration(bucket_name, self.body)
return ""
elif "cors" in querystring:
try:
self.backend.put_bucket_cors(bucket_name, self._cors_from_xml(body))
self.backend.put_bucket_cors(bucket_name, self._cors_from_body())
return ""
except KeyError:
raise MalformedXML()
elif "logging" in querystring:
try:
self.backend.put_bucket_logging(
bucket_name, self._logging_from_xml(body)
)
self.backend.put_bucket_logging(bucket_name, self._logging_from_body())
return ""
except KeyError:
raise MalformedXML()
elif "notification" in querystring:
try:
self.backend.put_bucket_notification_configuration(
bucket_name, self._notification_config_from_xml(body)
bucket_name, self._notification_config_from_body()
)
return ""
except KeyError:
@ -817,7 +789,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e
elif "accelerate" in querystring:
try:
accelerate_status = self._accelerate_config_from_xml(body)
accelerate_status = self._accelerate_config_from_body()
self.backend.put_bucket_accelerate_configuration(
bucket_name, accelerate_status
)
@ -828,7 +800,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e
elif "publicAccessBlock" in querystring:
pab_config = self._parse_pab_config(body)
pab_config = self._parse_pab_config()
self.backend.put_bucket_public_access_block(
bucket_name, pab_config["PublicAccessBlockConfiguration"]
)
@ -836,7 +808,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
elif "encryption" in querystring:
try:
self.backend.put_bucket_encryption(
bucket_name, self._encryption_config_from_xml(body)
bucket_name, self._encryption_config_from_body()
)
return ""
except KeyError:
@ -848,7 +820,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if not bucket.is_versioned:
template = self.response_template(S3_NO_VERSIONING_ENABLED)
return 400, {}, template.render(bucket_name=bucket_name)
replication_config = self._replication_config_from_xml(body)
replication_config = self._replication_config_from_xml(self.body)
self.backend.put_bucket_replication(bucket_name, replication_config)
return ""
else:
@ -858,17 +830,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# - LocationConstraint has to be specified if outside us-east-1
if (
region_name != DEFAULT_REGION_NAME
and not self._body_contains_location_constraint(body)
and not self._body_contains_location_constraint(self.body)
):
raise IllegalLocationConstraintException()
if body:
if self._create_bucket_configuration_is_empty(body):
if self.body:
if self._create_bucket_configuration_is_empty(self.body):
raise MalformedXML()
try:
forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][
"LocationConstraint"
]
forced_region = xmltodict.parse(self.body)[
"CreateBucketConfiguration"
]["LocationConstraint"]
if forced_region == DEFAULT_REGION_NAME:
raise S3ClientError(
@ -950,21 +922,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name):
def _bucket_response_post(self, request, bucket_name):
response_headers = {}
if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required"
path = self._get_path(request)
self.path = self._get_path(request)
if self.is_delete_keys(request, path, bucket_name):
if self.is_delete_keys(request, self.path, bucket_name):
self.data["Action"] = "DeleteObject"
try:
self._authenticate_and_authorize_s3_action()
return self._bucket_response_delete_keys(body, bucket_name)
return self._bucket_response_delete_keys(bucket_name)
except BucketAccessDeniedError:
return self._bucket_response_delete_keys(
body, bucket_name, authenticated=False
bucket_name, authenticated=False
)
self.data["Action"] = "PutObject"
@ -1027,9 +999,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else path_url(request.url)
)
def _bucket_response_delete_keys(self, body, bucket_name, authenticated=True):
def _bucket_response_delete_keys(self, bucket_name, authenticated=True):
template = self.response_template(S3_DELETE_KEYS_RESPONSE)
body_dict = xmltodict.parse(body, strip_whitespace=False)
body_dict = xmltodict.parse(self.body, strip_whitespace=False)
objects = body_dict["Delete"].get("Object", [])
if not isinstance(objects, list):
@ -1098,16 +1070,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return bytes(new_body)
@amzn_request_id
def key_response(
self, request, full_url, headers
): # pylint: disable=unused-argument
def key_response(self, request, full_url, headers):
# Key and Control are lumped in because splitting out the regex is too much of a pain :/
self.method = request.method
self.path = self._get_path(request)
# Make a copy of request.headers because it's immutable
self.headers = dict(request.headers)
if "Host" not in self.headers:
self.headers["Host"] = urlparse(full_url).netloc
self.setup_class(request, full_url, headers, use_raw_body=True)
response_headers = {}
try:
@ -1300,7 +1265,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 304, response_headers, "Not Modified"
if "acl" in query:
acl = s3_backend.get_object_acl(key)
acl = self.backend.get_object_acl(key)
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(acl=acl)
if "tagging" in query:
@ -1411,7 +1376,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if not lock_enabled:
raise LockNotEnabled
version_id = query.get("VersionId")
retention = self._mode_until_from_xml(body)
retention = self._mode_until_from_body()
self.backend.put_object_retention(
bucket_name, key_name, version_id=version_id, retention=retention
)
@ -1573,9 +1538,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
return 404, response_headers, ""
def _lock_config_from_xml(self, xml):
def _lock_config_from_body(self):
response_dict = {"enabled": False, "mode": None, "days": None, "years": None}
parsed_xml = xmltodict.parse(xml)
parsed_xml = xmltodict.parse(self.body)
enabled = (
parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled"
)
@ -1596,8 +1561,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return response_dict
def _acl_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _acl_from_body(self):
parsed_xml = xmltodict.parse(self.body)
if not parsed_xml.get("AccessControlPolicy"):
raise MalformedACLError()
@ -1713,8 +1678,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return tags
def _bucket_tagging_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _bucket_tagging_from_body(self):
parsed_xml = xmltodict.parse(self.body)
tags = {}
# Optional if no tags are being sent:
@ -1737,16 +1702,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return tags
def _cors_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _cors_from_body(self):
parsed_xml = xmltodict.parse(self.body)
if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
return [cors for cors in parsed_xml["CORSConfiguration"]["CORSRule"]]
return [parsed_xml["CORSConfiguration"]["CORSRule"]]
def _mode_until_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _mode_until_from_body(self):
parsed_xml = xmltodict.parse(self.body)
return (
parsed_xml.get("Retention", None).get("Mode", None),
parsed_xml.get("Retention", None).get("RetainUntilDate", None),
@ -1756,8 +1721,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
parsed_xml = xmltodict.parse(xml)
return parsed_xml["LegalHold"]["Status"]
def _encryption_config_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _encryption_config_from_body(self):
parsed_xml = xmltodict.parse(self.body)
if (
not parsed_xml["ServerSideEncryptionConfiguration"].get("Rule")
@ -1772,8 +1737,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["ServerSideEncryptionConfiguration"]
def _logging_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _logging_from_body(self):
parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"):
return {}
@ -1817,8 +1782,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]
def _notification_config_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _notification_config_from_body(self):
parsed_xml = xmltodict.parse(self.body)
if not len(parsed_xml["NotificationConfiguration"]):
return {}
@ -1892,8 +1857,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return parsed_xml["NotificationConfiguration"]
def _accelerate_config_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
def _accelerate_config_from_body(self):
parsed_xml = xmltodict.parse(self.body)
config = parsed_xml["AccelerateConfiguration"]
return config["Status"]
@ -2028,7 +1993,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return False
S3ResponseInstance = ResponseObject(s3_backend)
S3ResponseInstance = S3Response()
S3_ALL_BUCKETS = """<ListAllMyBucketsResult xmlns="http://s3.amazonaws.com/doc/2006-03-01">
<Owner>

View File

@ -1,6 +1,5 @@
"""s3control module initialization; sets value for base decorator."""
from .models import s3control_backend
from .models import s3control_backends
from ..core.models import base_decorator
s3control_backends = {"global": s3control_backend}
mock_s3control = base_decorator(s3control_backends)

View File

@ -1,7 +1,7 @@
from collections import defaultdict
from datetime import datetime
from moto.core import get_account_id, BaseBackend, BaseModel
from moto.core.utils import get_random_hex
from moto.core.utils import get_random_hex, BackendDict
from moto.s3.exceptions import (
WrongPublicAccessBlockAccountIdError,
NoSuchPublicAccessBlockConfiguration,
@ -43,16 +43,11 @@ class AccessPoint(BaseModel):
class S3ControlBackend(BaseBackend):
def __init__(self, region_name=None):
self.region_name = region_name
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.public_access_block = None
self.access_points = defaultdict(dict)
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def get_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto:
if account_id != get_account_id():
@ -129,4 +124,9 @@ class S3ControlBackend(BaseBackend):
return True
s3control_backend = S3ControlBackend()
s3control_backends = BackendDict(
S3ControlBackend,
"s3control",
use_boto3_regions=False,
additional_regions=["global"],
)

View File

@ -5,10 +5,14 @@ from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from .models import s3control_backend
from .models import s3control_backends
class S3ControlResponse(BaseResponse):
@property
def backend(self):
return s3control_backends["global"]
@amzn_request_id
def public_access_block(
self, request, full_url, headers
@ -25,7 +29,7 @@ class S3ControlResponse(BaseResponse):
def get_public_access_block(self, request):
account_id = request.headers.get("x-amz-account-id")
public_block_config = s3control_backend.get_public_access_block(
public_block_config = self.backend.get_public_access_block(
account_id=account_id
)
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
@ -35,14 +39,14 @@ class S3ControlResponse(BaseResponse):
account_id = request.headers.get("x-amz-account-id")
data = request.body if hasattr(request, "body") else request.data
pab_config = self._parse_pab_config(data)
s3control_backend.put_public_access_block(
self.backend.put_public_access_block(
account_id, pab_config["PublicAccessBlockConfiguration"]
)
return 201, {}, json.dumps({})
def delete_public_access_block(self, request):
account_id = request.headers.get("x-amz-account-id")
s3control_backend.delete_public_access_block(account_id=account_id)
self.backend.delete_public_access_block(account_id=account_id)
return 204, {}, json.dumps({})
def _parse_pab_config(self, body):
@ -82,7 +86,7 @@ class S3ControlResponse(BaseResponse):
bucket = params["Bucket"]
vpc_configuration = params.get("VpcConfiguration")
public_access_block_configuration = params.get("PublicAccessBlockConfiguration")
access_point = s3control_backend.create_access_point(
access_point = self.backend.create_access_point(
account_id=account_id,
name=name,
bucket=bucket,
@ -95,38 +99,36 @@ class S3ControlResponse(BaseResponse):
def get_access_point(self, full_url):
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
access_point = s3control_backend.get_access_point(
account_id=account_id, name=name
)
access_point = self.backend.get_access_point(account_id=account_id, name=name)
template = self.response_template(GET_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point)
def delete_access_point(self, full_url):
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
s3control_backend.delete_access_point(account_id=account_id, name=name)
self.backend.delete_access_point(account_id=account_id, name=name)
return 204, {}, ""
def create_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url)
params = xmltodict.parse(self.body)
policy = params["PutAccessPointPolicyRequest"]["Policy"]
s3control_backend.create_access_point_policy(account_id, name, policy)
self.backend.create_access_point_policy(account_id, name, policy)
return 200, {}, ""
def get_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url)
policy = s3control_backend.get_access_point_policy(account_id, name)
policy = self.backend.get_access_point_policy(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE)
return 200, {}, template.render(policy=policy)
def delete_access_point_policy(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url)
s3control_backend.delete_access_point_policy(account_id=account_id, name=name)
self.backend.delete_access_point_policy(account_id=account_id, name=name)
return 204, {}, ""
def get_access_point_policy_status(self, full_url):
account_id, name = self._get_accountid_and_name_from_policy(full_url)
s3control_backend.get_access_point_policy_status(account_id, name)
self.backend.get_access_point_policy_status(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE)
return 200, {}, template.render()

View File

@ -6,6 +6,7 @@ from email.mime.base import MIMEBase
from email.utils import parseaddr
from email.mime.multipart import MIMEMultipart
from email.encoders import encode_7or8bit
from typing import Mapping
from moto.core import BaseBackend, BaseModel
from moto.core.utils import BackendDict
@ -537,7 +538,6 @@ class SESBackend(BaseBackend):
return attributes_by_identity
ses_backends = BackendDict(
ses_backends: Mapping[str, SESBackend] = BackendDict(
SESBackend, "ses", use_boto3_regions=False, additional_regions=["global"]
)
ses_backend = ses_backends["global"]

View File

@ -1,48 +1,52 @@
import base64
from moto.core.responses import BaseResponse
from .models import ses_backend
from .models import ses_backends
from datetime import datetime
class EmailResponse(BaseResponse):
@property
def backend(self):
return ses_backends["global"]
def verify_email_identity(self):
address = self.querystring.get("EmailAddress")[0]
ses_backend.verify_email_identity(address)
self.backend.verify_email_identity(address)
template = self.response_template(VERIFY_EMAIL_IDENTITY)
return template.render()
def verify_email_address(self):
address = self.querystring.get("EmailAddress")[0]
ses_backend.verify_email_address(address)
self.backend.verify_email_address(address)
template = self.response_template(VERIFY_EMAIL_ADDRESS)
return template.render()
def list_identities(self):
identities = ses_backend.list_identities()
identities = self.backend.list_identities()
template = self.response_template(LIST_IDENTITIES_RESPONSE)
return template.render(identities=identities)
def list_verified_email_addresses(self):
email_addresses = ses_backend.list_verified_email_addresses()
email_addresses = self.backend.list_verified_email_addresses()
template = self.response_template(LIST_VERIFIED_EMAIL_RESPONSE)
return template.render(email_addresses=email_addresses)
def verify_domain_dkim(self):
domain = self.querystring.get("Domain")[0]
ses_backend.verify_domain(domain)
self.backend.verify_domain(domain)
template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE)
return template.render()
def verify_domain_identity(self):
domain = self.querystring.get("Domain")[0]
ses_backend.verify_domain(domain)
self.backend.verify_domain(domain)
template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE)
return template.render()
def delete_identity(self):
domain = self.querystring.get("Identity")[0]
ses_backend.delete_identity(domain)
self.backend.delete_identity(domain)
template = self.response_template(DELETE_IDENTITY_RESPONSE)
return template.render()
@ -63,7 +67,7 @@ class EmailResponse(BaseResponse):
break
destinations[dest_type].append(address[0])
message = ses_backend.send_email(
message = self.backend.send_email(
source, subject, body, destinations, self.region
)
template = self.response_template(SEND_EMAIL_RESPONSE)
@ -84,7 +88,7 @@ class EmailResponse(BaseResponse):
break
destinations[dest_type].append(address[0])
message = ses_backend.send_templated_email(
message = self.backend.send_templated_email(
source, template, template_data, destinations, self.region
)
template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE)
@ -107,27 +111,27 @@ class EmailResponse(BaseResponse):
break
destinations.append(address[0])
message = ses_backend.send_raw_email(
message = self.backend.send_raw_email(
source, destinations, raw_data, self.region
)
template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message)
def get_send_quota(self):
quota = ses_backend.get_send_quota()
quota = self.backend.get_send_quota()
template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota)
def get_identity_notification_attributes(self):
identities = self._get_params()["Identities"]
identities = ses_backend.get_identity_notification_attributes(identities)
identities = self.backend.get_identity_notification_attributes(identities)
template = self.response_template(GET_IDENTITY_NOTIFICATION_ATTRIBUTES)
return template.render(identities=identities)
def set_identity_feedback_forwarding_enabled(self):
identity = self._get_param("Identity")
enabled = self._get_bool_param("ForwardingEnabled")
ses_backend.set_identity_feedback_forwarding_enabled(identity, enabled)
self.backend.set_identity_feedback_forwarding_enabled(identity, enabled)
template = self.response_template(SET_IDENTITY_FORWARDING_ENABLED_RESPONSE)
return template.render()
@ -139,18 +143,18 @@ class EmailResponse(BaseResponse):
if sns_topic:
sns_topic = sns_topic[0]
ses_backend.set_identity_notification_topic(identity, not_type, sns_topic)
self.backend.set_identity_notification_topic(identity, not_type, sns_topic)
template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
return template.render()
def get_send_statistics(self):
statistics = ses_backend.get_send_statistics()
statistics = self.backend.get_send_statistics()
template = self.response_template(GET_SEND_STATISTICS)
return template.render(all_statistics=[statistics])
def create_configuration_set(self):
configuration_set_name = self.querystring.get("ConfigurationSet.Name")[0]
ses_backend.create_configuration_set(
self.backend.create_configuration_set(
configuration_set_name=configuration_set_name
)
template = self.response_template(CREATE_CONFIGURATION_SET)
@ -177,7 +181,7 @@ class EmailResponse(BaseResponse):
"SNSDestination": event_topic_arn,
}
ses_backend.create_configuration_set_event_destination(
self.backend.create_configuration_set_event_destination(
configuration_set_name=configuration_set_name,
event_destination=event_destination,
)
@ -193,7 +197,7 @@ class EmailResponse(BaseResponse):
template_info["template_name"] = template_data.get("._name", "")
template_info["subject_part"] = template_data.get("._subject_part", "")
template_info["Timestamp"] = datetime.utcnow()
ses_backend.add_template(template_info=template_info)
self.backend.add_template(template_info=template_info)
template = self.response_template(CREATE_TEMPLATE)
return template.render()
@ -205,44 +209,44 @@ class EmailResponse(BaseResponse):
template_info["template_name"] = template_data.get("._name", "")
template_info["subject_part"] = template_data.get("._subject_part", "")
template_info["Timestamp"] = datetime.utcnow()
ses_backend.update_template(template_info=template_info)
self.backend.update_template(template_info=template_info)
template = self.response_template(UPDATE_TEMPLATE)
return template.render()
def get_template(self):
template_name = self._get_param("TemplateName")
template_data = ses_backend.get_template(template_name)
template_data = self.backend.get_template(template_name)
template = self.response_template(GET_TEMPLATE)
return template.render(template_data=template_data)
def list_templates(self):
email_templates = ses_backend.list_templates()
email_templates = self.backend.list_templates()
template = self.response_template(LIST_TEMPLATES)
return template.render(templates=email_templates)
def test_render_template(self):
render_info = self._get_dict_param("Template")
rendered_template = ses_backend.render_template(render_info)
rendered_template = self.backend.render_template(render_info)
template = self.response_template(RENDER_TEMPLATE)
return template.render(template=rendered_template)
def create_receipt_rule_set(self):
rule_set_name = self._get_param("RuleSetName")
ses_backend.create_receipt_rule_set(rule_set_name)
self.backend.create_receipt_rule_set(rule_set_name)
template = self.response_template(CREATE_RECEIPT_RULE_SET)
return template.render()
def create_receipt_rule(self):
rule_set_name = self._get_param("RuleSetName")
rule = self._get_dict_param("Rule.")
ses_backend.create_receipt_rule(rule_set_name, rule)
self.backend.create_receipt_rule(rule_set_name, rule)
template = self.response_template(CREATE_RECEIPT_RULE)
return template.render()
def describe_receipt_rule_set(self):
rule_set_name = self._get_param("RuleSetName")
rule_set = ses_backend.describe_receipt_rule_set(rule_set_name)
rule_set = self.backend.describe_receipt_rule_set(rule_set_name)
for i, rule in enumerate(rule_set):
formatted_rule = {}
@ -260,7 +264,7 @@ class EmailResponse(BaseResponse):
rule_set_name = self._get_param("RuleSetName")
rule_name = self._get_param("RuleName")
receipt_rule = ses_backend.describe_receipt_rule(rule_set_name, rule_name)
receipt_rule = self.backend.describe_receipt_rule(rule_set_name, rule_name)
rule = {}
@ -274,7 +278,7 @@ class EmailResponse(BaseResponse):
rule_set_name = self._get_param("RuleSetName")
rule = self._get_dict_param("Rule.")
ses_backend.update_receipt_rule(rule_set_name, rule)
self.backend.update_receipt_rule(rule_set_name, rule)
template = self.response_template(UPDATE_RECEIPT_RULE)
return template.render()
@ -284,7 +288,7 @@ class EmailResponse(BaseResponse):
mail_from_domain = self._get_param("MailFromDomain")
behavior_on_mx_failure = self._get_param("BehaviorOnMXFailure")
ses_backend.set_identity_mail_from_domain(
self.backend.set_identity_mail_from_domain(
identity, mail_from_domain, behavior_on_mx_failure
)
@ -293,7 +297,7 @@ class EmailResponse(BaseResponse):
def get_identity_mail_from_domain_attributes(self):
identities = self._get_multi_param("Identities.member.")
identities = ses_backend.get_identity_mail_from_domain_attributes(identities)
identities = self.backend.get_identity_mail_from_domain_attributes(identities)
template = self.response_template(GET_IDENTITY_MAIL_FROM_DOMAIN_ATTRIBUTES)
return template.render(identities=identities)
@ -301,7 +305,7 @@ class EmailResponse(BaseResponse):
def get_identity_verification_attributes(self):
params = self._get_params()
identities = params.get("Identities")
verification_attributes = ses_backend.get_identity_verification_attributes(
verification_attributes = self.backend.get_identity_verification_attributes(
identities=identities,
)

View File

@ -11,6 +11,7 @@ from moto.sts.utils import (
random_assumed_role_id,
DEFAULT_STS_SESSION_DURATION,
)
from typing import Mapping
class Token(BaseModel):
@ -138,7 +139,6 @@ class STSBackend(BaseBackend):
pass
sts_backends = BackendDict(
sts_backends: Mapping[str, STSBackend] = BackendDict(
STSBackend, "sts", use_boto3_regions=False, additional_regions=["global"]
)
sts_backend = sts_backends["global"]

View File

@ -1,16 +1,20 @@
from moto.core.responses import BaseResponse
from moto.core import get_account_id
from moto.iam import iam_backend
from moto.iam import iam_backends
from .exceptions import STSValidationError
from .models import sts_backend
from .models import sts_backends
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048
class TokenResponse(BaseResponse):
@property
def backend(self):
return sts_backends["global"]
def get_session_token(self):
duration = int(self.querystring.get("DurationSeconds", [43200])[0])
token = sts_backend.get_session_token(duration=duration)
token = self.backend.get_session_token(duration=duration)
template = self.response_template(GET_SESSION_TOKEN_RESPONSE)
return template.render(token=token)
@ -27,7 +31,7 @@ class TokenResponse(BaseResponse):
)
name = self.querystring.get("Name")[0]
token = sts_backend.get_federation_token(duration=duration, name=name)
token = self.backend.get_federation_token(duration=duration, name=name)
template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE)
return template.render(token=token, account_id=get_account_id())
@ -39,7 +43,7 @@ class TokenResponse(BaseResponse):
duration = int(self.querystring.get("DurationSeconds", [3600])[0])
external_id = self.querystring.get("ExternalId", [None])[0]
role = sts_backend.assume_role(
role = self.backend.assume_role(
role_session_name=role_session_name,
role_arn=role_arn,
policy=policy,
@ -57,7 +61,7 @@ class TokenResponse(BaseResponse):
duration = int(self.querystring.get("DurationSeconds", [3600])[0])
external_id = self.querystring.get("ExternalId", [None])[0]
role = sts_backend.assume_role_with_web_identity(
role = self.backend.assume_role_with_web_identity(
role_session_name=role_session_name,
role_arn=role_arn,
policy=policy,
@ -72,7 +76,7 @@ class TokenResponse(BaseResponse):
principal_arn = self.querystring.get("PrincipalArn")[0]
saml_assertion = self.querystring.get("SAMLAssertion")[0]
role = sts_backend.assume_role_with_saml(
role = self.backend.assume_role_with_saml(
role_arn=role_arn,
principal_arn=principal_arn,
saml_assertion=saml_assertion,
@ -88,12 +92,12 @@ class TokenResponse(BaseResponse):
arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=get_account_id())
access_key_id = self.get_current_user()
assumed_role = sts_backend.get_assumed_role_from_access_key(access_key_id)
assumed_role = self.backend.get_assumed_role_from_access_key(access_key_id)
if assumed_role:
user_id = assumed_role.user_id
arn = assumed_role.arn
user = iam_backend.get_user_from_access_key_id(access_key_id)
user = iam_backends["global"].get_user_from_access_key_id(access_key_id)
if user:
user_id = user.id
arn = user.arn

View File

@ -45,4 +45,4 @@ def test_domain_dispatched_with_service():
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"})
keys = set(backend_app.view_functions.keys())
keys.should.contain("ResponseObject.key_response")
keys.should.contain("S3Response.key_response")

View File

@ -647,7 +647,10 @@ def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
if append_key_id:
target_id += key_id
client.generate_data_key(KeyId=target_id, NumberOfBytes=32)
resp = client.generate_data_key(KeyId=target_id, NumberOfBytes=32)
resp.should.have.key("KeyId").equals(
f"arn:aws:kms:us-east-1:123456789012:key/{key_id}"
)
@mock_kms

View File

@ -75,6 +75,9 @@ def test_s3_server_ignore_subdomain_for_bucketnames():
def test_s3_server_bucket_versioning():
test_client = authenticated_client()
res = test_client.put("/", "http://foobaz.localhost:5000/")
res.status_code.should.equal(200)
# Just enough XML to enable versioning
body = "<Status>Enabled</Status>"
res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body)