Simplify error handling (#4936)

This commit is contained in:
Bert Blommers 2022-03-15 15:42:46 -01:00 committed by GitHub
parent 5ae0ced349
commit 67ab7f857a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 483 additions and 878 deletions

View File

@ -2,7 +2,7 @@ import json
import base64 import base64
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import acm_backends, AWSError, AWSValidationException from .models import acm_backends, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse): class AWSCertificateManagerResponse(BaseResponse):
@ -37,10 +37,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: self.acm_backend.add_tags_to_certificate(arn, tags)
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
return "" return ""
@ -54,10 +51,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: self.acm_backend.delete_certificate(arn)
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
return "" return ""
@ -71,10 +65,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: cert_bundle = self.acm_backend.get_certificate(arn)
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
return json.dumps(cert_bundle.describe()) return json.dumps(cert_bundle.describe())
@ -88,10 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: cert_bundle = self.acm_backend.get_certificate(arn)
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = { result = {
"Certificate": cert_bundle.cert.decode(), "Certificate": cert_bundle.cert.decode(),
@ -123,29 +111,26 @@ class AWSCertificateManagerResponse(BaseResponse):
try: try:
certificate = base64.standard_b64decode(certificate) certificate = base64.standard_b64decode(certificate)
except Exception: except Exception:
return AWSValidationException( raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid." "The certificate is not PEM-encoded or is not valid."
).response() )
try: try:
private_key = base64.standard_b64decode(private_key) private_key = base64.standard_b64decode(private_key)
except Exception: except Exception:
return AWSValidationException( raise AWSValidationException(
"The private key is not PEM-encoded or is not valid." "The private key is not PEM-encoded or is not valid."
).response() )
if chain is not None: if chain is not None:
try: try:
chain = base64.standard_b64decode(chain) chain = base64.standard_b64decode(chain)
except Exception: except Exception:
return AWSValidationException( raise AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid." "The certificate chain is not PEM-encoded or is not valid."
).response() )
try: arn = self.acm_backend.import_cert(
arn = self.acm_backend.import_cert( certificate, private_key, chain=chain, arn=current_arn, tags=tags
certificate, private_key, chain=chain, arn=current_arn, tags=tags )
)
except AWSError as err:
return err.response()
return json.dumps({"CertificateArn": arn}) return json.dumps({"CertificateArn": arn})
@ -170,10 +155,7 @@ class AWSCertificateManagerResponse(BaseResponse):
msg = "A required parameter for the specified action is not supplied." msg = "A required parameter for the specified action is not supplied."
return {"__type": "MissingParameter", "message": msg}, dict(status=400) return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try: cert_bundle = self.acm_backend.get_certificate(arn)
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {"Tags": []} result = {"Tags": []}
# Tag "objects" can not contain the Value part # Tag "objects" can not contain the Value part
@ -196,10 +178,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: self.acm_backend.remove_tags_from_certificate(arn, tags)
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
return "" return ""
@ -219,15 +198,12 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: arn = self.acm_backend.request_certificate(
arn = self.acm_backend.request_certificate( domain_name,
domain_name, idempotency_token,
idempotency_token, subject_alt_names,
subject_alt_names, tags,
tags, )
)
except AWSError as err:
return err.response()
return json.dumps({"CertificateArn": arn}) return json.dumps({"CertificateArn": arn})
@ -247,16 +223,12 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: cert_bundle = self.acm_backend.get_certificate(arn)
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain: if cert_bundle.common_name != domain:
msg = "Parameter Domain does not match certificate domain" msg = "Parameter Domain does not match certificate domain"
_type = "InvalidDomainValidationOptionsException" _type = "InvalidDomainValidationOptionsException"
return json.dumps({"__type": _type, "message": msg}), dict(status=400) return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err:
return err.response()
return "" return ""
@ -271,20 +243,17 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400), dict(status=400),
) )
try: (
( certificate,
certificate, certificate_chain,
certificate_chain, private_key,
private_key, ) = self.acm_backend.export_certificate(
) = self.acm_backend.export_certificate( certificate_arn=certificate_arn, passphrase=passphrase
certificate_arn=certificate_arn, passphrase=passphrase )
return json.dumps(
dict(
Certificate=certificate,
CertificateChain=certificate_chain,
PrivateKey=private_key,
) )
return json.dumps( )
dict(
Certificate=certificate,
CertificateChain=certificate_chain,
PrivateKey=private_key,
)
)
except AWSError as err:
return err.response()

View File

@ -1,28 +1,16 @@
import json import json
from functools import wraps
from urllib.parse import unquote from urllib.parse import unquote
from moto.utilities.utils import merge_multiple_dicts from moto.utilities.utils import merge_multiple_dicts
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import apigateway_backends from .models import apigateway_backends
from .exceptions import ApiGatewayException, InvalidRequestInput from .exceptions import InvalidRequestInput
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"] AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"] ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except ApiGatewayException as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400): def error(self, type_, message, status=400):
headers = self.response_headers or {} headers = self.response_headers or {}
@ -103,7 +91,6 @@ class APIGatewayResponse(BaseResponse):
value = op["value"] value = op["value"]
return self.__validate_api_key_source(value) return self.__validate_api_key_source(value)
@error_handler
def restapis_individual(self, request, full_url, headers): def restapis_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -133,7 +120,6 @@ class APIGatewayResponse(BaseResponse):
json.dumps({"item": [resource.to_dict() for resource in resources]}), json.dumps({"item": [resource.to_dict() for resource in resources]}),
) )
@error_handler
def gateway_response(self, request, full_url, headers): def gateway_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
@ -148,7 +134,6 @@ class APIGatewayResponse(BaseResponse):
if request.method == "GET": if request.method == "GET":
return self.get_gateway_responses() return self.get_gateway_responses()
@error_handler
def resource_individual(self, request, full_url, headers): def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -163,7 +148,6 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.delete_resource(function_id, resource_id) resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict()) return 200, {}, json.dumps(resource.to_dict())
@error_handler
def resource_methods(self, request, full_url, headers): def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -295,7 +279,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(authorizer_response) return 200, {}, json.dumps(authorizer_response)
@error_handler
def request_validators(self, request, full_url, headers): def request_validators(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -316,7 +299,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(validator) return 200, {}, json.dumps(validator)
@error_handler
def request_validator_individual(self, request, full_url, headers): def request_validator_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -336,7 +318,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(validator) return 200, {}, json.dumps(validator)
@error_handler
def authorizers(self, request, full_url, headers): def authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -355,7 +336,6 @@ class APIGatewayResponse(BaseResponse):
return 202, {}, "{}" return 202, {}, "{}"
return 200, {}, json.dumps(authorizer_response) return 200, {}, json.dumps(authorizer_response)
@error_handler
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -406,7 +386,6 @@ class APIGatewayResponse(BaseResponse):
stage["tags"].pop(tag, None) stage["tags"].pop(tag, None)
return 200, {}, json.dumps({"item": ""}) return 200, {}, json.dumps({"item": ""})
@error_handler
def stages(self, request, full_url, headers): def stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -426,7 +405,6 @@ class APIGatewayResponse(BaseResponse):
return 202, {}, "{}" return 202, {}, "{}"
return 200, {}, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
@error_handler
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -474,7 +452,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
@error_handler
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -509,7 +486,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
@error_handler
def deployments(self, request, full_url, headers): def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -526,7 +502,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
@error_handler
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -540,7 +515,6 @@ class APIGatewayResponse(BaseResponse):
deployment = self.backend.delete_deployment(function_id, deployment_id) deployment = self.backend.delete_deployment(function_id, deployment_id)
return 202, {}, json.dumps(deployment) return 202, {}, json.dumps(deployment)
@error_handler
def apikeys(self, request, full_url, headers): def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -553,7 +527,6 @@ class APIGatewayResponse(BaseResponse):
apikeys_response = self.backend.get_api_keys(include_values=include_values) apikeys_response = self.backend.get_api_keys(include_values=include_values)
return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps({"item": apikeys_response})
@error_handler
def apikey_individual(self, request, full_url, headers): def apikey_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -585,7 +558,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_individual(self, request, full_url, headers): def usage_plan_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -603,7 +575,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_keys(self, request, full_url, headers): def usage_plan_keys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -619,7 +590,6 @@ class APIGatewayResponse(BaseResponse):
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@error_handler
def usage_plan_key_individual(self, request, full_url, headers): def usage_plan_key_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -635,7 +605,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def domain_names(self, request, full_url, headers): def domain_names(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -672,7 +641,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(domain_name_resp) return 200, {}, json.dumps(domain_name_resp)
@error_handler
def domain_name_induvidual(self, request, full_url, headers): def domain_name_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -695,7 +663,6 @@ class APIGatewayResponse(BaseResponse):
return 404, {}, json.dumps({"error": msg}) return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names) return 200, {}, json.dumps(domain_names)
@error_handler
def models(self, request, full_url, headers): def models(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0] rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -723,7 +690,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(model) return 200, {}, json.dumps(model)
@error_handler
def model_induvidual(self, request, full_url, headers): def model_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -734,7 +700,6 @@ class APIGatewayResponse(BaseResponse):
model_info = self.backend.get_model(rest_api_id, model_name) model_info = self.backend.get_model(rest_api_id, model_name)
return 200, {}, json.dumps(model_info) return 200, {}, json.dumps(model_info)
@error_handler
def base_path_mappings(self, request, full_url, headers): def base_path_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -754,7 +719,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 201, {}, json.dumps(base_path_mapping_resp) return 201, {}, json.dumps(base_path_mapping_resp)
@error_handler
def base_path_mapping_individual(self, request, full_url, headers): def base_path_mapping_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -778,7 +742,6 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(base_path_mapping) return 200, {}, json.dumps(base_path_mapping)
@error_handler
def vpc_link(self, request, full_url, headers): def vpc_link(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")

View File

@ -1,25 +1,13 @@
"""Handles incoming apigatewayv2 requests, invokes methods, returns responses.""" """Handles incoming apigatewayv2 requests, invokes methods, returns responses."""
import json import json
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from urllib.parse import unquote from urllib.parse import unquote
from .exceptions import APIGatewayV2Error, UnknownProtocol from .exceptions import UnknownProtocol
from .models import apigatewayv2_backends from .models import apigatewayv2_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except APIGatewayV2Error as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class ApiGatewayV2Response(BaseResponse): class ApiGatewayV2Response(BaseResponse):
"""Handler for ApiGatewayV2 requests and responses.""" """Handler for ApiGatewayV2 requests and responses."""
@ -28,7 +16,6 @@ class ApiGatewayV2Response(BaseResponse):
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return apigatewayv2_backends[self.region] return apigatewayv2_backends[self.region]
@error_handler
def apis(self, request, full_url, headers): def apis(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -37,7 +24,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "GET": if self.method == "GET":
return self.get_apis() return self.get_apis()
@error_handler
def api(self, request, full_url, headers): def api(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -50,7 +36,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE": if self.method == "DELETE":
return self.delete_api() return self.delete_api()
@error_handler
def authorizer(self, request, full_url, headers): def authorizer(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -67,7 +52,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_authorizer() return self.create_authorizer()
@error_handler
def cors(self, request, full_url, headers): def cors(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -80,7 +64,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE": if self.method == "DELETE":
return self.delete_route_request_parameter() return self.delete_route_request_parameter()
@error_handler
def model(self, request, full_url, headers): def model(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -97,7 +80,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_model() return self.create_model()
@error_handler
def integration(self, request, full_url, headers): def integration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -108,7 +90,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "PATCH": if self.method == "PATCH":
return self.update_integration() return self.update_integration()
@error_handler
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -117,7 +98,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_integration() return self.create_integration()
@error_handler
def integration_response(self, request, full_url, headers): def integration_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -136,7 +116,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_integration_response() return self.create_integration_response()
@error_handler
def route(self, request, full_url, headers): def route(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -147,7 +126,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "PATCH": if self.method == "PATCH":
return self.update_route() return self.update_route()
@error_handler
def routes(self, request, full_url, headers): def routes(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -156,7 +134,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_route() return self.create_route()
@error_handler
def route_response(self, request, full_url, headers): def route_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -171,7 +148,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST": if self.method == "POST":
return self.create_route_response() return self.create_route_response()
@error_handler
def tags(self, request, full_url, headers): def tags(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -182,7 +158,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE": if self.method == "DELETE":
return self.untag_resource() return self.untag_resource()
@error_handler
def vpc_link(self, request, full_url, headers): def vpc_link(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)

View File

@ -1,24 +1,11 @@
"""Handles incoming appsync requests, invokes methods, returns responses.""" """Handles incoming appsync requests, invokes methods, returns responses."""
import json import json
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from urllib.parse import unquote from urllib.parse import unquote
from .exceptions import AppSyncExceptions
from .models import appsync_backends from .models import appsync_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except AppSyncExceptions as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class AppSyncResponse(BaseResponse): class AppSyncResponse(BaseResponse):
"""Handler for AppSync requests and responses.""" """Handler for AppSync requests and responses."""
@ -34,7 +21,6 @@ class AppSyncResponse(BaseResponse):
if request.method == "GET": if request.method == "GET":
return self.list_graphql_apis() return self.list_graphql_apis()
@error_handler
def graph_ql_individual(self, request, full_url, headers): def graph_ql_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":

View File

@ -3,24 +3,11 @@ import sys
from urllib.parse import unquote from urllib.parse import unquote
from functools import wraps
from moto.core.utils import amz_crc32, amzn_request_id, path_url from moto.core.utils import amz_crc32, amzn_request_id, path_url
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import LambdaClientError
from .models import lambda_backends from .models import lambda_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except LambdaClientError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class LambdaResponse(BaseResponse): class LambdaResponse(BaseResponse):
@property @property
def json_body(self): def json_body(self):
@ -39,7 +26,6 @@ class LambdaResponse(BaseResponse):
""" """
return lambda_backends[self.region] return lambda_backends[self.region]
@error_handler
def root(self, request, full_url, headers): def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
@ -95,7 +81,6 @@ class LambdaResponse(BaseResponse):
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@error_handler
def versions(self, request, full_url, headers): def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
@ -139,7 +124,6 @@ class LambdaResponse(BaseResponse):
else: else:
raise ValueError("Cannot handle {0} request".format(request.method)) raise ValueError("Cannot handle {0} request".format(request.method))
@error_handler
def policy(self, request, full_url, headers): def policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":

View File

@ -2,8 +2,6 @@ from moto.core.responses import BaseResponse
from .models import batch_backends from .models import batch_backends
from urllib.parse import urlsplit, unquote from urllib.parse import urlsplit, unquote
from .exceptions import AWSError
import json import json
@ -48,16 +46,13 @@ class BatchResponse(BaseResponse):
state = self._get_param("state") state = self._get_param("state")
_type = self._get_param("type") _type = self._get_param("type")
try: name, arn = self.batch_backend.create_compute_environment(
name, arn = self.batch_backend.create_compute_environment( compute_environment_name=compute_env_name,
compute_environment_name=compute_env_name, _type=_type,
_type=_type, state=state,
state=state, compute_resources=compute_resource,
compute_resources=compute_resource, service_role=service_role,
service_role=service_role, )
)
except AWSError as err:
return err.response()
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
@ -76,10 +71,7 @@ class BatchResponse(BaseResponse):
def deletecomputeenvironment(self): def deletecomputeenvironment(self):
compute_environment = self._get_param("computeEnvironment") compute_environment = self._get_param("computeEnvironment")
try: self.batch_backend.delete_compute_environment(compute_environment)
self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err:
return err.response()
return "" return ""
@ -90,15 +82,12 @@ class BatchResponse(BaseResponse):
service_role = self._get_param("serviceRole") service_role = self._get_param("serviceRole")
state = self._get_param("state") state = self._get_param("state")
try: name, arn = self.batch_backend.update_compute_environment(
name, arn = self.batch_backend.update_compute_environment( compute_environment_name=compute_env_name,
compute_environment_name=compute_env_name, compute_resources=compute_resource,
compute_resources=compute_resource, service_role=service_role,
service_role=service_role, state=state,
state=state, )
)
except AWSError as err:
return err.response()
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
@ -112,16 +101,13 @@ class BatchResponse(BaseResponse):
state = self._get_param("state") state = self._get_param("state")
tags = self._get_param("tags") tags = self._get_param("tags")
try: name, arn = self.batch_backend.create_job_queue(
name, arn = self.batch_backend.create_job_queue( queue_name=queue_name,
queue_name=queue_name, priority=priority,
priority=priority, state=state,
state=state, compute_env_order=compute_env_order,
compute_env_order=compute_env_order, tags=tags,
tags=tags, )
)
except AWSError as err:
return err.response()
result = {"jobQueueArn": arn, "jobQueueName": name} result = {"jobQueueArn": arn, "jobQueueName": name}
@ -143,15 +129,12 @@ class BatchResponse(BaseResponse):
priority = self._get_param("priority") priority = self._get_param("priority")
state = self._get_param("state") state = self._get_param("state")
try: name, arn = self.batch_backend.update_job_queue(
name, arn = self.batch_backend.update_job_queue( queue_name=queue_name,
queue_name=queue_name, priority=priority,
priority=priority, state=state,
state=state, compute_env_order=compute_env_order,
compute_env_order=compute_env_order, )
)
except AWSError as err:
return err.response()
result = {"jobQueueArn": arn, "jobQueueName": name} result = {"jobQueueArn": arn, "jobQueueName": name}
@ -176,20 +159,17 @@ class BatchResponse(BaseResponse):
timeout = self._get_param("timeout") timeout = self._get_param("timeout")
platform_capabilities = self._get_param("platformCapabilities") platform_capabilities = self._get_param("platformCapabilities")
propagate_tags = self._get_param("propagateTags") propagate_tags = self._get_param("propagateTags")
try: name, arn, revision = self.batch_backend.register_job_definition(
name, arn, revision = self.batch_backend.register_job_definition( def_name=def_name,
def_name=def_name, parameters=parameters,
parameters=parameters, _type=_type,
_type=_type, tags=tags,
tags=tags, retry_strategy=retry_strategy,
retry_strategy=retry_strategy, container_properties=container_properties,
container_properties=container_properties, timeout=timeout,
timeout=timeout, platform_capabilities=platform_capabilities,
platform_capabilities=platform_capabilities, propagate_tags=propagate_tags,
propagate_tags=propagate_tags, )
)
except AWSError as err:
return err.response()
result = { result = {
"jobDefinitionArn": arn, "jobDefinitionArn": arn,
@ -229,17 +209,14 @@ class BatchResponse(BaseResponse):
job_queue = self._get_param("jobQueue") job_queue = self._get_param("jobQueue")
timeout = self._get_param("timeout") timeout = self._get_param("timeout")
try: name, job_id = self.batch_backend.submit_job(
name, job_id = self.batch_backend.submit_job( job_name,
job_name, job_def,
job_def, job_queue,
job_queue, depends_on=depends_on,
depends_on=depends_on, container_overrides=container_overrides,
container_overrides=container_overrides, timeout=timeout,
timeout=timeout, )
)
except AWSError as err:
return err.response()
result = {"jobId": job_id, "jobName": name} result = {"jobId": job_id, "jobName": name}
@ -249,20 +226,14 @@ class BatchResponse(BaseResponse):
def describejobs(self): def describejobs(self):
jobs = self._get_param("jobs") jobs = self._get_param("jobs")
try: return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err:
return err.response()
# ListJobs # ListJobs
def listjobs(self): def listjobs(self):
job_queue = self._get_param("jobQueue") job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus") job_status = self._get_param("jobStatus")
try: jobs = self.batch_backend.list_jobs(job_queue, job_status)
jobs = self.batch_backend.list_jobs(job_queue, job_status)
except AWSError as err:
return err.response()
result = {"jobSummaryList": [job.describe_short() for job in jobs]} result = {"jobSummaryList": [job.describe_short() for job in jobs]}
return json.dumps(result) return json.dumps(result)
@ -272,10 +243,7 @@ class BatchResponse(BaseResponse):
job_id = self._get_param("jobId") job_id = self._get_param("jobId")
reason = self._get_param("reason") reason = self._get_param("reason")
try: self.batch_backend.terminate_job(job_id, reason)
self.batch_backend.terminate_job(job_id, reason)
except AWSError as err:
return err.response()
return "" return ""

View File

@ -1,30 +1,16 @@
import xmltodict import xmltodict
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cloudfront_backend from .models import cloudfront_backend
from .exceptions import CloudFrontException
XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/"
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except CloudFrontException as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class CloudFrontResponse(BaseResponse): class CloudFrontResponse(BaseResponse):
def _get_xml_body(self): def _get_xml_body(self):
return xmltodict.parse(self.body, dict_constructor=dict) return xmltodict.parse(self.body, dict_constructor=dict)
@error_handler
def distributions(self, request, full_url, headers): def distributions(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
@ -49,7 +35,6 @@ class CloudFrontResponse(BaseResponse):
response = template.render(distributions=distributions) response = template.render(distributions=distributions)
return 200, {}, response return 200, {}, response
@error_handler
def individual_distribution(self, request, full_url, headers): def individual_distribution(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
distribution_id = full_url.split("/")[-1] distribution_id = full_url.split("/")[-1]

View File

@ -132,20 +132,13 @@ class AuthFailureError(RESTError):
) )
class AWSError(Exception): class AWSError(JsonRESTError):
TYPE = None TYPE = None
STATUS = 400 STATUS = 400
def __init__(self, message, exception_type=None, status=None): def __init__(self, message, exception_type=None, status=None):
self.message = message super().__init__(exception_type or self.TYPE, message)
self.type = exception_type or self.TYPE self.code = status or self.STATUS
self.status = status or self.STATUS
def response(self):
return (
json.dumps({"__type": self.type, "message": self.message}),
dict(status=self.status),
)
class InvalidNextTokenException(JsonRESTError): class InvalidNextTokenException(JsonRESTError):
@ -160,8 +153,7 @@ class InvalidNextTokenException(JsonRESTError):
class InvalidToken(AWSError): class InvalidToken(AWSError):
TYPE = "InvalidToken" code = 400
STATUS = 400
def __init__(self, message="Invalid token"): def __init__(self, message="Invalid token"):
super().__init__("Invalid Token: {}".format(message)) super().__init__("Invalid Token: {}".format(message), "InvalidToken")

View File

@ -17,6 +17,7 @@ from botocore.awsrequest import AWSResponse
from types import FunctionType from types import FunctionType
from moto import settings from moto import settings
from moto.core.exceptions import HTTPException
import responses import responses
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
@ -283,9 +284,14 @@ class BotocoreStubber:
for header, value in request.headers.items(): for header, value in request.headers.items():
if isinstance(value, bytes): if isinstance(value, bytes):
request.headers[header] = value.decode("utf-8") request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback( try:
request, request.url, request.headers status, headers, body = response_callback(
) request, request.url, request.headers
)
except HTTPException as e:
status = e.code
headers = e.get_headers()
body = e.get_body()
body = MockRawResponse(body) body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body) response = AWSResponse(request.url, status, headers, body)

View File

@ -3,7 +3,6 @@ from urllib.parse import urlparse
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from .exceptions import DataBrewClientError
from .models import databrew_backends from .models import databrew_backends
@ -58,8 +57,5 @@ class DataBrewResponse(BaseResponse):
recipe_name = parsed_url.path.rstrip("/").rsplit("/", 1)[1] recipe_name = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
try: recipe = self.databrew_backend.get_recipe(recipe_name)
recipe = self.databrew_backend.get_recipe(recipe_name) return json.dumps(recipe.as_dict())
return json.dumps(recipe.as_dict())
except DataBrewClientError as e:
return e.code, e.get_headers(), e.get_body()

View File

@ -5,9 +5,10 @@ from moto.core.exceptions import AWSError
class EKSError(AWSError): class EKSError(AWSError):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(AWSError, self).__init__() super(AWSError, self).__init__(error_type=self.TYPE, message="")
self.description = json.dumps(kwargs) self.description = json.dumps(kwargs)
self.headers = {"status": self.STATUS, "x-amzn-ErrorType": self.TYPE} self.headers = {"status": self.STATUS, "x-amzn-ErrorType": self.TYPE}
self.code = self.STATUS
def response(self): def response(self):
return self.STATUS, self.headers, self.description return self.STATUS, self.headers, self.description

View File

@ -2,12 +2,6 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import (
InvalidParameterException,
InvalidRequestException,
ResourceInUseException,
ResourceNotFoundException,
)
from .models import eks_backends from .models import eks_backends
DEFAULT_MAX_RESULTS = 100 DEFAULT_MAX_RESULTS = 100
@ -32,27 +26,19 @@ class EKSResponse(BaseResponse):
tags = self._get_param("tags") tags = self._get_param("tags")
encryption_config = self._get_param("encryptionConfig") encryption_config = self._get_param("encryptionConfig")
try: cluster = self.eks_backend.create_cluster(
cluster = self.eks_backend.create_cluster( name=name,
name=name, version=version,
version=version, role_arn=role_arn,
role_arn=role_arn, resources_vpc_config=resources_vpc_config,
resources_vpc_config=resources_vpc_config, kubernetes_network_config=kubernetes_network_config,
kubernetes_network_config=kubernetes_network_config, logging=logging,
logging=logging, client_request_token=client_request_token,
client_request_token=client_request_token, tags=tags,
tags=tags, encryption_config=encryption_config,
encryption_config=encryption_config, )
)
return 200, {}, json.dumps({"cluster": dict(cluster)}) return 200, {}, json.dumps({"cluster": dict(cluster)})
except (
ResourceInUseException,
ResourceNotFoundException,
InvalidParameterException,
) as e:
# Backend will capture this and re-raise it as a ClientError.
return e.response()
def create_fargate_profile(self): def create_fargate_profile(self):
fargate_profile_name = self._get_param("fargateProfileName") fargate_profile_name = self._get_param("fargateProfileName")
@ -63,25 +49,17 @@ class EKSResponse(BaseResponse):
client_request_token = self._get_param("clientRequestToken") client_request_token = self._get_param("clientRequestToken")
tags = self._get_param("tags") tags = self._get_param("tags")
try: fargate_profile = self.eks_backend.create_fargate_profile(
fargate_profile = self.eks_backend.create_fargate_profile( fargate_profile_name=fargate_profile_name,
fargate_profile_name=fargate_profile_name, cluster_name=cluster_name,
cluster_name=cluster_name, pod_execution_role_arn=pod_execution_role_arn,
pod_execution_role_arn=pod_execution_role_arn, subnets=subnets,
subnets=subnets, selectors=selectors,
selectors=selectors, client_request_token=client_request_token,
client_request_token=client_request_token, tags=tags,
tags=tags, )
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)}) return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except (
ResourceNotFoundException,
ResourceInUseException,
InvalidParameterException,
InvalidRequestException,
) as e:
return e.response()
def create_nodegroup(self): def create_nodegroup(self):
cluster_name = self._get_param("name") cluster_name = self._get_param("name")
@ -101,69 +79,52 @@ class EKSResponse(BaseResponse):
version = self._get_param("version") version = self._get_param("version")
release_version = self._get_param("releaseVersion") release_version = self._get_param("releaseVersion")
try: nodegroup = self.eks_backend.create_nodegroup(
nodegroup = self.eks_backend.create_nodegroup( cluster_name=cluster_name,
cluster_name=cluster_name, nodegroup_name=nodegroup_name,
nodegroup_name=nodegroup_name, scaling_config=scaling_config,
scaling_config=scaling_config, disk_size=disk_size,
disk_size=disk_size, subnets=subnets,
subnets=subnets, instance_types=instance_types,
instance_types=instance_types, ami_type=ami_type,
ami_type=ami_type, remote_access=remote_access,
remote_access=remote_access, node_role=node_role,
node_role=node_role, labels=labels,
labels=labels, tags=tags,
tags=tags, client_request_token=client_request_token,
client_request_token=client_request_token, launch_template=launch_template,
launch_template=launch_template, capacity_type=capacity_type,
capacity_type=capacity_type, version=version,
version=version, release_version=release_version,
release_version=release_version, )
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)}) return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (
ResourceInUseException,
ResourceNotFoundException,
InvalidRequestException,
InvalidParameterException,
) as e:
return e.response()
def describe_cluster(self): def describe_cluster(self):
name = self._get_param("name") name = self._get_param("name")
try: cluster = self.eks_backend.describe_cluster(name=name)
cluster = self.eks_backend.describe_cluster(name=name)
return 200, {}, json.dumps({"cluster": dict(cluster)}) return 200, {}, json.dumps({"cluster": dict(cluster)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
def describe_fargate_profile(self): def describe_fargate_profile(self):
cluster_name = self._get_param("name") cluster_name = self._get_param("name")
fargate_profile_name = self._get_param("fargateProfileName") fargate_profile_name = self._get_param("fargateProfileName")
try: fargate_profile = self.eks_backend.describe_fargate_profile(
fargate_profile = self.eks_backend.describe_fargate_profile( cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name )
) return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
def describe_nodegroup(self): def describe_nodegroup(self):
cluster_name = self._get_param("name") cluster_name = self._get_param("name")
nodegroup_name = self._get_param("nodegroupName") nodegroup_name = self._get_param("nodegroupName")
try: nodegroup = self.eks_backend.describe_nodegroup(
nodegroup = self.eks_backend.describe_nodegroup( cluster_name=cluster_name, nodegroup_name=nodegroup_name
cluster_name=cluster_name, nodegroup_name=nodegroup_name )
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)}) return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
def list_clusters(self): def list_clusters(self):
max_results = self._get_int_param("maxResults", DEFAULT_MAX_RESULTS) max_results = self._get_int_param("maxResults", DEFAULT_MAX_RESULTS)
@ -206,35 +167,26 @@ class EKSResponse(BaseResponse):
def delete_cluster(self): def delete_cluster(self):
name = self._get_param("name") name = self._get_param("name")
try: cluster = self.eks_backend.delete_cluster(name=name)
cluster = self.eks_backend.delete_cluster(name=name)
return 200, {}, json.dumps({"cluster": dict(cluster)}) return 200, {}, json.dumps({"cluster": dict(cluster)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
def delete_fargate_profile(self): def delete_fargate_profile(self):
cluster_name = self._get_param("name") cluster_name = self._get_param("name")
fargate_profile_name = self._get_param("fargateProfileName") fargate_profile_name = self._get_param("fargateProfileName")
try: fargate_profile = self.eks_backend.delete_fargate_profile(
fargate_profile = self.eks_backend.delete_fargate_profile( cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name )
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)}) return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except ResourceNotFoundException as e:
return e.response()
def delete_nodegroup(self): def delete_nodegroup(self):
cluster_name = self._get_param("name") cluster_name = self._get_param("name")
nodegroup_name = self._get_param("nodegroupName") nodegroup_name = self._get_param("nodegroupName")
try: nodegroup = self.eks_backend.delete_nodegroup(
nodegroup = self.eks_backend.delete_nodegroup( cluster_name=cluster_name, nodegroup_name=nodegroup_name
cluster_name=cluster_name, nodegroup_name=nodegroup_name )
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)}) return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()

View File

@ -1,23 +1,11 @@
import json import json
import re import re
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import ElasticSearchError, InvalidDomainName from .exceptions import InvalidDomainName
from .models import es_backends from .models import es_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except ElasticSearchError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class ElasticsearchServiceResponse(BaseResponse): class ElasticsearchServiceResponse(BaseResponse):
"""Handler for ElasticsearchService requests and responses.""" """Handler for ElasticsearchService requests and responses."""
@ -27,7 +15,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return es_backends[self.region] return es_backends[self.region]
@classmethod @classmethod
@error_handler
def list_domains(cls, request, full_url, headers): def list_domains(cls, request, full_url, headers):
response = ElasticsearchServiceResponse() response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers) response.setup_class(request, full_url, headers)
@ -35,7 +22,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return response.list_domain_names() return response.list_domain_names()
@classmethod @classmethod
@error_handler
def domains(cls, request, full_url, headers): def domains(cls, request, full_url, headers):
response = ElasticsearchServiceResponse() response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers) response.setup_class(request, full_url, headers)
@ -43,7 +29,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return response.create_elasticsearch_domain() return response.create_elasticsearch_domain()
@classmethod @classmethod
@error_handler
def domain(cls, request, full_url, headers): def domain(cls, request, full_url, headers):
response = ElasticsearchServiceResponse() response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers) response.setup_class(request, full_url, headers)

View File

@ -2,7 +2,6 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import forecast_backends from .models import forecast_backends
@ -18,57 +17,45 @@ class ForecastResponse(BaseResponse):
dataset_arns = self._get_param("DatasetArns") dataset_arns = self._get_param("DatasetArns")
tags = self._get_param("Tags") tags = self._get_param("Tags")
try: dataset_group = self.forecast_backend.create_dataset_group(
dataset_group = self.forecast_backend.create_dataset_group( dataset_group_name=dataset_group_name,
dataset_group_name=dataset_group_name, domain=domain,
domain=domain, dataset_arns=dataset_arns,
dataset_arns=dataset_arns, tags=tags,
tags=tags, )
) response = {"DatasetGroupArn": dataset_group.arn}
response = {"DatasetGroupArn": dataset_group.arn} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_dataset_group(self): def describe_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn") dataset_group_arn = self._get_param("DatasetGroupArn")
try: dataset_group = self.forecast_backend.describe_dataset_group(
dataset_group = self.forecast_backend.describe_dataset_group( dataset_group_arn=dataset_group_arn
dataset_group_arn=dataset_group_arn )
) response = {
response = { "CreationTime": dataset_group.creation_date,
"CreationTime": dataset_group.creation_date, "DatasetArns": dataset_group.dataset_arns,
"DatasetArns": dataset_group.dataset_arns, "DatasetGroupArn": dataset_group.arn,
"DatasetGroupArn": dataset_group.arn, "DatasetGroupName": dataset_group.dataset_group_name,
"DatasetGroupName": dataset_group.dataset_group_name, "Domain": dataset_group.domain,
"Domain": dataset_group.domain, "LastModificationTime": dataset_group.modified_date,
"LastModificationTime": dataset_group.modified_date, "Status": "ACTIVE",
"Status": "ACTIVE", }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def delete_dataset_group(self): def delete_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn") dataset_group_arn = self._get_param("DatasetGroupArn")
try: self.forecast_backend.delete_dataset_group(dataset_group_arn)
self.forecast_backend.delete_dataset_group(dataset_group_arn) return 200, {}, None
return 200, {}, None
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def update_dataset_group(self): def update_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn") dataset_group_arn = self._get_param("DatasetGroupArn")
dataset_arns = self._get_param("DatasetArns") dataset_arns = self._get_param("DatasetArns")
try: self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns)
self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns) return 200, {}, None
return 200, {}, None
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def list_dataset_groups(self): def list_dataset_groups(self):

View File

@ -11,7 +11,7 @@ class UnknownBroker(MQError):
super().__init__("NotFoundException", "Can't find requested broker") super().__init__("NotFoundException", "Can't find requested broker")
self.broker_id = broker_id self.broker_id = broker_id
def get_body(self): def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "broker-id", "errorAttribute": "broker-id",
"message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.", "message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.",
@ -24,7 +24,7 @@ class UnknownConfiguration(MQError):
super().__init__("NotFoundException", "Can't find requested configuration") super().__init__("NotFoundException", "Can't find requested configuration")
self.config_id = config_id self.config_id = config_id
def get_body(self): def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "configuration_id", "errorAttribute": "configuration_id",
"message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.", "message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.",
@ -37,7 +37,7 @@ class UnknownUser(MQError):
super().__init__("NotFoundException", "Can't find requested user") super().__init__("NotFoundException", "Can't find requested user")
self.username = username self.username = username
def get_body(self): def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "username", "errorAttribute": "username",
"message": f"Can't find requested user [{self.username}]. Make sure your user exists.", "message": f"Can't find requested user [{self.username}]. Make sure your user exists.",
@ -50,7 +50,7 @@ class UnsupportedEngineType(MQError):
super().__init__("BadRequestException", "") super().__init__("BadRequestException", "")
self.engine_type = engine_type self.engine_type = engine_type
def get_body(self): def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "engineType", "errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] does not support configuration.", "message": f"Broker engine type [{self.engine_type}] does not support configuration.",
@ -63,7 +63,7 @@ class UnknownEngineType(MQError):
super().__init__("BadRequestException", "") super().__init__("BadRequestException", "")
self.engine_type = engine_type self.engine_type = engine_type
def get_body(self): def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = { body = {
"errorAttribute": "engineType", "errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]", "message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]",

View File

@ -1,24 +1,11 @@
"""Handles incoming mq requests, invokes methods, returns responses.""" """Handles incoming mq requests, invokes methods, returns responses."""
import json import json
from functools import wraps
from urllib.parse import unquote from urllib.parse import unquote
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import MQError
from .models import mq_backends from .models import mq_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except MQError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class MQResponse(BaseResponse): class MQResponse(BaseResponse):
"""Handler for MQ requests and responses.""" """Handler for MQ requests and responses."""
@ -27,7 +14,6 @@ class MQResponse(BaseResponse):
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return mq_backends[self.region] return mq_backends[self.region]
@error_handler
def broker(self, request, full_url, headers): def broker(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
@ -44,7 +30,6 @@ class MQResponse(BaseResponse):
if request.method == "GET": if request.method == "GET":
return self.list_brokers() return self.list_brokers()
@error_handler
def configuration(self, request, full_url, headers): def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
@ -52,7 +37,6 @@ class MQResponse(BaseResponse):
if request.method == "PUT": if request.method == "PUT":
return self.update_configuration() return self.update_configuration()
@error_handler
def configurations(self, request, full_url, headers): def configurations(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
@ -72,7 +56,6 @@ class MQResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_tags() return self.delete_tags()
@error_handler
def user(self, request, full_url, headers): def user(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":

View File

@ -1,24 +1,11 @@
"""Handles incoming pinpoint requests, invokes methods, returns responses.""" """Handles incoming pinpoint requests, invokes methods, returns responses."""
import json import json
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from urllib.parse import unquote from urllib.parse import unquote
from .exceptions import PinpointExceptions
from .models import pinpoint_backends from .models import pinpoint_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except PinpointExceptions as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class PinpointResponse(BaseResponse): class PinpointResponse(BaseResponse):
"""Handler for Pinpoint requests and responses.""" """Handler for Pinpoint requests and responses."""
@ -27,7 +14,6 @@ class PinpointResponse(BaseResponse):
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return pinpoint_backends[self.region] return pinpoint_backends[self.region]
@error_handler
def app(self, request, full_url, headers): def app(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "DELETE": if request.method == "DELETE":
@ -49,7 +35,6 @@ class PinpointResponse(BaseResponse):
if request.method == "PUT": if request.method == "PUT":
return self.update_application_settings() return self.update_application_settings()
@error_handler
def eventstream(self, request, full_url, headers): def eventstream(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "DELETE": if request.method == "DELETE":

View File

@ -1,28 +1,16 @@
"""Handles Route53 API requests, invokes method and returns response.""" """Handles Route53 API requests, invokes method and returns response."""
from functools import wraps
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from jinja2 import Template from jinja2 import Template
import xmltodict import xmltodict
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.route53.exceptions import Route53ClientError, InvalidChangeBatch from moto.route53.exceptions import InvalidChangeBatch
from moto.route53.models import route53_backend from moto.route53.models import route53_backend
XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/" XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except Route53ClientError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class Route53(BaseResponse): class Route53(BaseResponse):
"""Handler for Route53 requests and responses.""" """Handler for Route53 requests and responses."""
@ -36,7 +24,6 @@ class Route53(BaseResponse):
return False return False
@error_handler
def list_or_create_hostzone_response(self, request, full_url, headers): def list_or_create_hostzone_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -115,7 +102,6 @@ class Route53(BaseResponse):
template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE) template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE)
return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS) return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS)
@error_handler
def get_or_delete_hostzone_response(self, request, full_url, headers): def get_or_delete_hostzone_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -129,7 +115,6 @@ class Route53(BaseResponse):
route53_backend.delete_hosted_zone(zoneid) route53_backend.delete_hosted_zone(zoneid)
return 200, headers, DELETE_HOSTED_ZONE_RESPONSE return 200, headers, DELETE_HOSTED_ZONE_RESPONSE
@error_handler
def rrset_response(self, request, full_url, headers): def rrset_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -298,7 +283,6 @@ class Route53(BaseResponse):
template = Template(GET_CHANGE_RESPONSE) template = Template(GET_CHANGE_RESPONSE)
return 200, headers, template.render(change_id=change_id, xmlns=XMLNS) return 200, headers, template.render(change_id=change_id, xmlns=XMLNS)
@error_handler
def list_or_create_query_logging_config_response(self, request, full_url, headers): def list_or_create_query_logging_config_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -343,7 +327,6 @@ class Route53(BaseResponse):
), ),
) )
@error_handler
def get_or_delete_query_logging_config_response(self, request, full_url, headers): def get_or_delete_query_logging_config_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -394,7 +377,6 @@ class Route53(BaseResponse):
template.render(delegation_set=delegation_set), template.render(delegation_set=delegation_set),
) )
@error_handler
def reusable_delegation_set(self, request, full_url, headers): def reusable_delegation_set(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)

View File

@ -1,26 +1,13 @@
import json import json
import xmltodict import xmltodict
from functools import wraps
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from moto.s3.exceptions import S3ClientError from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from .exceptions import S3ControlError
from .models import s3control_backend from .models import s3control_backend
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except S3ControlError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class S3ControlResponse(BaseResponse): class S3ControlResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def public_access_block( def public_access_block(
@ -64,7 +51,6 @@ class S3ControlResponse(BaseResponse):
return parsed_xml return parsed_xml
@error_handler
def access_point(self, request, full_url, headers): def access_point(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
@ -74,7 +60,6 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_access_point(full_url) return self.delete_access_point(full_url)
@error_handler
def access_point_policy(self, request, full_url, headers): def access_point_policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
@ -84,7 +69,6 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_access_point_policy(full_url) return self.delete_access_point_policy(full_url)
@error_handler
def access_point_policy_status(self, request, full_url, headers): def access_point_policy_status(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":

View File

@ -46,63 +46,55 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_notebook_instance(self): def create_notebook_instance(self):
try: sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance( notebook_instance_name=self._get_param("NotebookInstanceName"),
notebook_instance_name=self._get_param("NotebookInstanceName"), instance_type=self._get_param("InstanceType"),
instance_type=self._get_param("InstanceType"), subnet_id=self._get_param("SubnetId"),
subnet_id=self._get_param("SubnetId"), security_group_ids=self._get_param("SecurityGroupIds"),
security_group_ids=self._get_param("SecurityGroupIds"), role_arn=self._get_param("RoleArn"),
role_arn=self._get_param("RoleArn"), kms_key_id=self._get_param("KmsKeyId"),
kms_key_id=self._get_param("KmsKeyId"), tags=self._get_param("Tags"),
tags=self._get_param("Tags"), lifecycle_config_name=self._get_param("LifecycleConfigName"),
lifecycle_config_name=self._get_param("LifecycleConfigName"), direct_internet_access=self._get_param("DirectInternetAccess"),
direct_internet_access=self._get_param("DirectInternetAccess"), volume_size_in_gb=self._get_param("VolumeSizeInGB"),
volume_size_in_gb=self._get_param("VolumeSizeInGB"), accelerator_types=self._get_param("AcceleratorTypes"),
accelerator_types=self._get_param("AcceleratorTypes"), default_code_repository=self._get_param("DefaultCodeRepository"),
default_code_repository=self._get_param("DefaultCodeRepository"), additional_code_repositories=self._get_param("AdditionalCodeRepositories"),
additional_code_repositories=self._get_param( root_access=self._get_param("RootAccess"),
"AdditionalCodeRepositories" )
), response = {
root_access=self._get_param("RootAccess"), "NotebookInstanceArn": sagemaker_notebook.arn,
) }
response = { return 200, {}, json.dumps(response)
"NotebookInstanceArn": sagemaker_notebook.arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_notebook_instance(self): def describe_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName") notebook_instance_name = self._get_param("NotebookInstanceName")
try: notebook_instance = self.sagemaker_backend.get_notebook_instance(
notebook_instance = self.sagemaker_backend.get_notebook_instance( notebook_instance_name
notebook_instance_name )
) response = {
response = { "NotebookInstanceArn": notebook_instance.arn,
"NotebookInstanceArn": notebook_instance.arn, "NotebookInstanceName": notebook_instance.notebook_instance_name,
"NotebookInstanceName": notebook_instance.notebook_instance_name, "NotebookInstanceStatus": notebook_instance.status,
"NotebookInstanceStatus": notebook_instance.status, "Url": notebook_instance.url,
"Url": notebook_instance.url, "InstanceType": notebook_instance.instance_type,
"InstanceType": notebook_instance.instance_type, "SubnetId": notebook_instance.subnet_id,
"SubnetId": notebook_instance.subnet_id, "SecurityGroups": notebook_instance.security_group_ids,
"SecurityGroups": notebook_instance.security_group_ids, "RoleArn": notebook_instance.role_arn,
"RoleArn": notebook_instance.role_arn, "KmsKeyId": notebook_instance.kms_key_id,
"KmsKeyId": notebook_instance.kms_key_id, # ToDo: NetworkInterfaceId
# ToDo: NetworkInterfaceId "LastModifiedTime": str(notebook_instance.last_modified_time),
"LastModifiedTime": str(notebook_instance.last_modified_time), "CreationTime": str(notebook_instance.creation_time),
"CreationTime": str(notebook_instance.creation_time), "NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name, "DirectInternetAccess": notebook_instance.direct_internet_access,
"DirectInternetAccess": notebook_instance.direct_internet_access, "VolumeSizeInGB": notebook_instance.volume_size_in_gb,
"VolumeSizeInGB": notebook_instance.volume_size_in_gb, "AcceleratorTypes": notebook_instance.accelerator_types,
"AcceleratorTypes": notebook_instance.accelerator_types, "DefaultCodeRepository": notebook_instance.default_code_repository,
"DefaultCodeRepository": notebook_instance.default_code_repository, "AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories, "RootAccess": notebook_instance.root_access,
"RootAccess": notebook_instance.root_access, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def start_notebook_instance(self): def start_notebook_instance(self):
@ -171,20 +163,17 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_endpoint_config(self): def create_endpoint_config(self):
try: endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config = self.sagemaker_backend.create_endpoint_config( endpoint_config_name=self._get_param("EndpointConfigName"),
endpoint_config_name=self._get_param("EndpointConfigName"), production_variants=self._get_param("ProductionVariants"),
production_variants=self._get_param("ProductionVariants"), data_capture_config=self._get_param("DataCaptureConfig"),
data_capture_config=self._get_param("DataCaptureConfig"), tags=self._get_param("Tags"),
tags=self._get_param("Tags"), kms_key_id=self._get_param("KmsKeyId"),
kms_key_id=self._get_param("KmsKeyId"), )
) response = {
response = { "EndpointConfigArn": endpoint_config.endpoint_config_arn,
"EndpointConfigArn": endpoint_config.endpoint_config_arn, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_endpoint_config(self): def describe_endpoint_config(self):
@ -200,18 +189,15 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_endpoint(self): def create_endpoint(self):
try: endpoint = self.sagemaker_backend.create_endpoint(
endpoint = self.sagemaker_backend.create_endpoint( endpoint_name=self._get_param("EndpointName"),
endpoint_name=self._get_param("EndpointName"), endpoint_config_name=self._get_param("EndpointConfigName"),
endpoint_config_name=self._get_param("EndpointConfigName"), tags=self._get_param("Tags"),
tags=self._get_param("Tags"), )
) response = {
response = { "EndpointArn": endpoint.endpoint_arn,
"EndpointArn": endpoint.endpoint_arn, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_endpoint(self): def describe_endpoint(self):
@ -227,23 +213,20 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_processing_job(self): def create_processing_job(self):
try: processing_job = self.sagemaker_backend.create_processing_job(
processing_job = self.sagemaker_backend.create_processing_job( app_specification=self._get_param("AppSpecification"),
app_specification=self._get_param("AppSpecification"), experiment_config=self._get_param("ExperimentConfig"),
experiment_config=self._get_param("ExperimentConfig"), network_config=self._get_param("NetworkConfig"),
network_config=self._get_param("NetworkConfig"), processing_inputs=self._get_param("ProcessingInputs"),
processing_inputs=self._get_param("ProcessingInputs"), processing_job_name=self._get_param("ProcessingJobName"),
processing_job_name=self._get_param("ProcessingJobName"), processing_output_config=self._get_param("ProcessingOutputConfig"),
processing_output_config=self._get_param("ProcessingOutputConfig"), role_arn=self._get_param("RoleArn"),
role_arn=self._get_param("RoleArn"), stopping_condition=self._get_param("StoppingCondition"),
stopping_condition=self._get_param("StoppingCondition"), )
) response = {
response = { "ProcessingJobArn": processing_job.processing_job_arn,
"ProcessingJobArn": processing_job.processing_job_arn, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_processing_job(self): def describe_processing_job(self):
@ -253,39 +236,34 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_training_job(self): def create_training_job(self):
try: training_job = self.sagemaker_backend.create_training_job(
training_job = self.sagemaker_backend.create_training_job( training_job_name=self._get_param("TrainingJobName"),
training_job_name=self._get_param("TrainingJobName"), hyper_parameters=self._get_param("HyperParameters"),
hyper_parameters=self._get_param("HyperParameters"), algorithm_specification=self._get_param("AlgorithmSpecification"),
algorithm_specification=self._get_param("AlgorithmSpecification"), role_arn=self._get_param("RoleArn"),
role_arn=self._get_param("RoleArn"), input_data_config=self._get_param("InputDataConfig"),
input_data_config=self._get_param("InputDataConfig"), output_data_config=self._get_param("OutputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"), resource_config=self._get_param("ResourceConfig"),
resource_config=self._get_param("ResourceConfig"), vpc_config=self._get_param("VpcConfig"),
vpc_config=self._get_param("VpcConfig"), stopping_condition=self._get_param("StoppingCondition"),
stopping_condition=self._get_param("StoppingCondition"), tags=self._get_param("Tags"),
tags=self._get_param("Tags"), enable_network_isolation=self._get_param("EnableNetworkIsolation", False),
enable_network_isolation=self._get_param( enable_inter_container_traffic_encryption=self._get_param(
"EnableNetworkIsolation", False "EnableInterContainerTrafficEncryption", False
), ),
enable_inter_container_traffic_encryption=self._get_param( enable_managed_spot_training=self._get_param(
"EnableInterContainerTrafficEncryption", False "EnableManagedSpotTraining", False
), ),
enable_managed_spot_training=self._get_param( checkpoint_config=self._get_param("CheckpointConfig"),
"EnableManagedSpotTraining", False debug_hook_config=self._get_param("DebugHookConfig"),
), debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
checkpoint_config=self._get_param("CheckpointConfig"), tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
debug_hook_config=self._get_param("DebugHookConfig"), experiment_config=self._get_param("ExperimentConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"), )
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"), response = {
experiment_config=self._get_param("ExperimentConfig"), "TrainingJobArn": training_job.training_job_arn,
) }
response = { return 200, {}, json.dumps(response)
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_training_job(self): def describe_training_job(self):
@ -301,22 +279,19 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_notebook_instance_lifecycle_config(self): def create_notebook_instance_lifecycle_config(self):
try: lifecycle_configuration = (
lifecycle_configuration = ( self.sagemaker_backend.create_notebook_instance_lifecycle_config(
self.sagemaker_backend.create_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param(
notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName"
"NotebookInstanceLifecycleConfigName" ),
), on_create=self._get_param("OnCreate"),
on_create=self._get_param("OnCreate"), on_start=self._get_param("OnStart"),
on_start=self._get_param("OnStart"),
)
) )
response = { )
"NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn, response = {
} "NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
return 200, {}, json.dumps(response) }
except AWSError as err: return 200, {}, json.dumps(response)
return err.response()
@amzn_request_id @amzn_request_id
def describe_notebook_instance_lifecycle_config(self): def describe_notebook_instance_lifecycle_config(self):
@ -426,14 +401,11 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_trial(self): def create_trial(self):
try: response = self.sagemaker_backend.create_trial(
response = self.sagemaker_backend.create_trial( trial_name=self._get_param("TrialName"),
trial_name=self._get_param("TrialName"), experiment_name=self._get_param("ExperimentName"),
experiment_name=self._get_param("ExperimentName"), )
) return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def list_trial_components(self): def list_trial_components(self):
@ -467,14 +439,11 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def create_trial_component(self): def create_trial_component(self):
try: response = self.sagemaker_backend.create_trial_component(
response = self.sagemaker_backend.create_trial_component( trial_component_name=self._get_param("TrialComponentName"),
trial_component_name=self._get_param("TrialComponentName"), trial_name=self._get_param("TrialName"),
trial_name=self._get_param("TrialName"), )
) return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_trial(self): def describe_trial(self):
@ -530,52 +499,47 @@ class SageMakerResponse(BaseResponse):
"Failed", "Failed",
] ]
try: max_results = self._get_int_param("MaxResults")
max_results = self._get_int_param("MaxResults") sort_by = self._get_param("SortBy", "CreationTime")
sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Ascending")
sort_order = self._get_param("SortOrder", "Ascending") status_equals = self._get_param("StatusEquals")
status_equals = self._get_param("StatusEquals") next_token = self._get_param("NextToken")
next_token = self._get_param("NextToken") errors = []
errors = [] if max_results and max_results not in max_results_range:
if max_results and max_results not in max_results_range: errors.append(
errors.append( "Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format( max_results, max_results_range[-1]
max_results, max_results_range[-1]
)
) )
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(
status_equals, "statusEquals", allowed_status_equals
)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_processing_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
) )
return 200, {}, json.dumps(response)
except AWSError as err: if sort_by not in allowed_sort_by:
return err.response() errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(status_equals, "statusEquals", allowed_status_equals)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_processing_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def list_training_jobs(self): def list_training_jobs(self):
@ -590,49 +554,44 @@ class SageMakerResponse(BaseResponse):
"Failed", "Failed",
] ]
try: max_results = self._get_int_param("MaxResults")
max_results = self._get_int_param("MaxResults") sort_by = self._get_param("SortBy", "CreationTime")
sort_by = self._get_param("SortBy", "CreationTime") sort_order = self._get_param("SortOrder", "Ascending")
sort_order = self._get_param("SortOrder", "Ascending") status_equals = self._get_param("StatusEquals")
status_equals = self._get_param("StatusEquals") next_token = self._get_param("NextToken")
next_token = self._get_param("NextToken") errors = []
errors = [] if max_results and max_results not in max_results_range:
if max_results and max_results not in max_results_range: errors.append(
errors.append( "Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format( max_results, max_results_range[-1]
max_results, max_results_range[-1]
)
) )
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(
status_equals, "statusEquals", allowed_status_equals
)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_training_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
) )
return 200, {}, json.dumps(response)
except AWSError as err: if sort_by not in allowed_sort_by:
return err.response() errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(status_equals, "statusEquals", allowed_status_equals)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_training_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)

View File

@ -546,6 +546,13 @@ class StepFunctionBackend(BaseBackend):
) )
return execution.get_execution_history(state_machine.roleArn) return execution.get_execution_history(state_machine.roleArn)
def list_tags_for_resource(self, arn):
try:
state_machine = self.describe_state_machine(arn)
return state_machine.tags or []
except StateMachineDoesNotExist:
return []
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn, tags):
try: try:
state_machine = self.describe_state_machine(resource_arn) state_machine = self.describe_state_machine(resource_arn)

View File

@ -2,7 +2,6 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import stepfunction_backends from .models import stepfunction_backends
@ -17,17 +16,14 @@ class StepFunctionResponse(BaseResponse):
definition = self._get_param("definition") definition = self._get_param("definition")
roleArn = self._get_param("roleArn") roleArn = self._get_param("roleArn")
tags = self._get_param("tags") tags = self._get_param("tags")
try: state_machine = self.stepfunction_backend.create_state_machine(
state_machine = self.stepfunction_backend.create_state_machine( name=name, definition=definition, roleArn=roleArn, tags=tags
name=name, definition=definition, roleArn=roleArn, tags=tags )
) response = {
response = { "creationDate": state_machine.creation_date,
"creationDate": state_machine.creation_date, "stateMachineArn": state_machine.arn,
"stateMachineArn": state_machine.arn, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def list_state_machines(self): def list_state_machines(self):
@ -56,55 +52,42 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def _describe_state_machine(self, state_machine_arn): def _describe_state_machine(self, state_machine_arn):
try: state_machine = self.stepfunction_backend.describe_state_machine(
state_machine = self.stepfunction_backend.describe_state_machine( state_machine_arn
state_machine_arn )
) response = {
response = { "creationDate": state_machine.creation_date,
"creationDate": state_machine.creation_date, "stateMachineArn": state_machine.arn,
"stateMachineArn": state_machine.arn, "definition": state_machine.definition,
"definition": state_machine.definition, "name": state_machine.name,
"name": state_machine.name, "roleArn": state_machine.roleArn,
"roleArn": state_machine.roleArn, "status": "ACTIVE",
"status": "ACTIVE", }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def delete_state_machine(self): def delete_state_machine(self):
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
try: self.stepfunction_backend.delete_state_machine(arn)
self.stepfunction_backend.delete_state_machine(arn) return 200, {}, json.dumps("{}")
return 200, {}, json.dumps("{}")
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def update_state_machine(self): def update_state_machine(self):
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
definition = self._get_param("definition") definition = self._get_param("definition")
role_arn = self._get_param("roleArn") role_arn = self._get_param("roleArn")
try: state_machine = self.stepfunction_backend.update_state_machine(
state_machine = self.stepfunction_backend.update_state_machine( arn=arn, definition=definition, role_arn=role_arn
arn=arn, definition=definition, role_arn=role_arn )
) response = {
response = { "updateDate": state_machine.update_date,
"updateDate": state_machine.update_date, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def list_tags_for_resource(self): def list_tags_for_resource(self):
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
try: tags = self.stepfunction_backend.list_tags_for_resource(arn)
state_machine = self.stepfunction_backend.describe_state_machine(arn)
tags = state_machine.tags or []
except AWSError:
tags = []
response = {"tags": tags} response = {"tags": tags}
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@ -112,20 +95,14 @@ class StepFunctionResponse(BaseResponse):
def tag_resource(self): def tag_resource(self):
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
tags = self._get_param("tags", []) tags = self._get_param("tags", [])
try: self.stepfunction_backend.tag_resource(arn, tags)
self.stepfunction_backend.tag_resource(arn, tags)
except AWSError as err:
return err.response()
return 200, {}, json.dumps({}) return 200, {}, json.dumps({})
@amzn_request_id @amzn_request_id
def untag_resource(self): def untag_resource(self):
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
tag_keys = self._get_param("tagKeys", []) tag_keys = self._get_param("tagKeys", [])
try: self.stepfunction_backend.untag_resource(arn, tag_keys)
self.stepfunction_backend.untag_resource(arn, tag_keys)
except AWSError as err:
return err.response()
return 200, {}, json.dumps({}) return 200, {}, json.dumps({})
@amzn_request_id @amzn_request_id
@ -133,12 +110,9 @@ class StepFunctionResponse(BaseResponse):
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
name = self._get_param("name") name = self._get_param("name")
execution_input = self._get_param("input", if_none="{}") execution_input = self._get_param("input", if_none="{}")
try: execution = self.stepfunction_backend.start_execution(
execution = self.stepfunction_backend.start_execution( arn, name, execution_input
arn, name, execution_input )
)
except AWSError as err:
return err.response()
response = { response = {
"executionArn": execution.execution_arn, "executionArn": execution.execution_arn,
"startDate": execution.start_date, "startDate": execution.start_date,
@ -151,16 +125,13 @@ class StepFunctionResponse(BaseResponse):
next_token = self._get_param("nextToken") next_token = self._get_param("nextToken")
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
status_filter = self._get_param("statusFilter") status_filter = self._get_param("statusFilter")
try: state_machine = self.stepfunction_backend.describe_state_machine(arn)
state_machine = self.stepfunction_backend.describe_state_machine(arn) results, next_token = self.stepfunction_backend.list_executions(
results, next_token = self.stepfunction_backend.list_executions( arn,
arn, status_filter=status_filter,
status_filter=status_filter, max_results=max_results,
max_results=max_results, next_token=next_token,
next_token=next_token, )
)
except AWSError as err:
return err.response()
executions = [ executions = [
{ {
"executionArn": execution.execution_arn, "executionArn": execution.execution_arn,
@ -179,48 +150,36 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def describe_execution(self): def describe_execution(self):
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
try: execution = self.stepfunction_backend.describe_execution(arn)
execution = self.stepfunction_backend.describe_execution(arn) response = {
response = { "executionArn": arn,
"executionArn": arn, "input": execution.execution_input,
"input": execution.execution_input, "name": execution.name,
"name": execution.name, "startDate": execution.start_date,
"startDate": execution.start_date, "stateMachineArn": execution.state_machine_arn,
"stateMachineArn": execution.state_machine_arn, "status": execution.status,
"status": execution.status, "stopDate": execution.stop_date,
"stopDate": execution.stop_date, }
} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def describe_state_machine_for_execution(self): def describe_state_machine_for_execution(self):
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
try: execution = self.stepfunction_backend.describe_execution(arn)
execution = self.stepfunction_backend.describe_execution(arn) return self._describe_state_machine(execution.state_machine_arn)
return self._describe_state_machine(execution.state_machine_arn)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def stop_execution(self): def stop_execution(self):
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
try: execution = self.stepfunction_backend.stop_execution(arn)
execution = self.stepfunction_backend.stop_execution(arn) response = {"stopDate": execution.stop_date}
response = {"stopDate": execution.stop_date} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id @amzn_request_id
def get_execution_history(self): def get_execution_history(self):
execution_arn = self._get_param("executionArn") execution_arn = self._get_param("executionArn")
try: execution_history = self.stepfunction_backend.get_execution_history(
execution_history = self.stepfunction_backend.get_execution_history( execution_arn
execution_arn )
) response = {"events": execution_history}
response = {"events": execution_history} return 200, {}, json.dumps(response)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()

View File

@ -35,10 +35,7 @@ class XRayResponse(BaseResponse):
# PutTelemetryRecords # PutTelemetryRecords
def telemetry_records(self): def telemetry_records(self):
try: self.xray_backend.add_telemetry_records(self.request_params)
self.xray_backend.add_telemetry_records(self.request_params)
except AWSError as err:
return err.response()
return "" return ""
@ -109,7 +106,7 @@ class XRayResponse(BaseResponse):
start_time, end_time, filter_expression start_time, end_time, filter_expression
) )
except AWSError as err: except AWSError as err:
return err.response() raise err
except Exception as err: except Exception as err:
return ( return (
json.dumps({"__type": "InternalFailure", "message": str(err)}), json.dumps({"__type": "InternalFailure", "message": str(err)}),
@ -132,7 +129,7 @@ class XRayResponse(BaseResponse):
try: try:
result = self.xray_backend.get_trace_ids(trace_ids) result = self.xray_backend.get_trace_ids(trace_ids)
except AWSError as err: except AWSError as err:
return err.response() raise err
except Exception as err: except Exception as err:
return ( return (
json.dumps({"__type": "InternalFailure", "message": str(err)}), json.dumps({"__type": "InternalFailure", "message": str(err)}),

View File

@ -161,7 +161,7 @@ class TestWithSetupMethod:
@mock_kinesis @mock_kinesis
class TestKinesisUsingSetupMethod: class TestKinesisUsingSetupMethod:
def setup_method(self, *args): def setup_method(self, *args): # pylint: disable=unused-argument
self.stream_name = "test_stream" self.stream_name = "test_stream"
self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1") self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1")
self.boto3_kinesis_client.create_stream( self.boto3_kinesis_client.create_stream(