Simplify error handling (#4936)

This commit is contained in:
Bert Blommers 2022-03-15 15:42:46 -01:00 committed by GitHub
parent 5ae0ced349
commit 67ab7f857a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 483 additions and 878 deletions

View File

@ -2,7 +2,7 @@ import json
import base64
from moto.core.responses import BaseResponse
from .models import acm_backends, AWSError, AWSValidationException
from .models import acm_backends, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse):
@ -37,10 +37,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
self.acm_backend.add_tags_to_certificate(arn, tags)
return ""
@ -54,10 +51,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
self.acm_backend.delete_certificate(arn)
return ""
@ -71,10 +65,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
cert_bundle = self.acm_backend.get_certificate(arn)
return json.dumps(cert_bundle.describe())
@ -88,10 +79,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
cert_bundle = self.acm_backend.get_certificate(arn)
result = {
"Certificate": cert_bundle.cert.decode(),
@ -123,29 +111,26 @@ class AWSCertificateManagerResponse(BaseResponse):
try:
certificate = base64.standard_b64decode(certificate)
except Exception:
return AWSValidationException(
raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
).response()
)
try:
private_key = base64.standard_b64decode(private_key)
except Exception:
return AWSValidationException(
raise AWSValidationException(
"The private key is not PEM-encoded or is not valid."
).response()
)
if chain is not None:
try:
chain = base64.standard_b64decode(chain)
except Exception:
return AWSValidationException(
raise AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid."
).response()
)
try:
arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn, tags=tags
)
except AWSError as err:
return err.response()
arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn, tags=tags
)
return json.dumps({"CertificateArn": arn})
@ -170,10 +155,7 @@ class AWSCertificateManagerResponse(BaseResponse):
msg = "A required parameter for the specified action is not supplied."
return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
cert_bundle = self.acm_backend.get_certificate(arn)
result = {"Tags": []}
# Tag "objects" can not contain the Value part
@ -196,10 +178,7 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
self.acm_backend.remove_tags_from_certificate(arn, tags)
return ""
@ -219,15 +198,12 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
arn = self.acm_backend.request_certificate(
domain_name,
idempotency_token,
subject_alt_names,
tags,
)
except AWSError as err:
return err.response()
arn = self.acm_backend.request_certificate(
domain_name,
idempotency_token,
subject_alt_names,
tags,
)
return json.dumps({"CertificateArn": arn})
@ -247,16 +223,12 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain:
msg = "Parameter Domain does not match certificate domain"
_type = "InvalidDomainValidationOptionsException"
return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err:
return err.response()
if cert_bundle.common_name != domain:
msg = "Parameter Domain does not match certificate domain"
_type = "InvalidDomainValidationOptionsException"
return json.dumps({"__type": _type, "message": msg}), dict(status=400)
return ""
@ -271,20 +243,17 @@ class AWSCertificateManagerResponse(BaseResponse):
dict(status=400),
)
try:
(
certificate,
certificate_chain,
private_key,
) = self.acm_backend.export_certificate(
certificate_arn=certificate_arn, passphrase=passphrase
(
certificate,
certificate_chain,
private_key,
) = self.acm_backend.export_certificate(
certificate_arn=certificate_arn, passphrase=passphrase
)
return json.dumps(
dict(
Certificate=certificate,
CertificateChain=certificate_chain,
PrivateKey=private_key,
)
return json.dumps(
dict(
Certificate=certificate,
CertificateChain=certificate_chain,
PrivateKey=private_key,
)
)
except AWSError as err:
return err.response()
)

View File

@ -1,28 +1,16 @@
import json
from functools import wraps
from urllib.parse import unquote
from moto.utilities.utils import merge_multiple_dicts
from moto.core.responses import BaseResponse
from .models import apigateway_backends
from .exceptions import ApiGatewayException, InvalidRequestInput
from .exceptions import InvalidRequestInput
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except ApiGatewayException as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400):
headers = self.response_headers or {}
@ -103,7 +91,6 @@ class APIGatewayResponse(BaseResponse):
value = op["value"]
return self.__validate_api_key_source(value)
@error_handler
def restapis_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -133,7 +120,6 @@ class APIGatewayResponse(BaseResponse):
json.dumps({"item": [resource.to_dict() for resource in resources]}),
)
@error_handler
def gateway_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
@ -148,7 +134,6 @@ class APIGatewayResponse(BaseResponse):
if request.method == "GET":
return self.get_gateway_responses()
@error_handler
def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -163,7 +148,6 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
@error_handler
def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -295,7 +279,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(authorizer_response)
@error_handler
def request_validators(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -316,7 +299,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(validator)
@error_handler
def request_validator_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -336,7 +318,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(validator)
@error_handler
def authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -355,7 +336,6 @@ class APIGatewayResponse(BaseResponse):
return 202, {}, "{}"
return 200, {}, json.dumps(authorizer_response)
@error_handler
def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -406,7 +386,6 @@ class APIGatewayResponse(BaseResponse):
stage["tags"].pop(tag, None)
return 200, {}, json.dumps({"item": ""})
@error_handler
def stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -426,7 +405,6 @@ class APIGatewayResponse(BaseResponse):
return 202, {}, "{}"
return 200, {}, json.dumps(stage_response)
@error_handler
def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -474,7 +452,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(integration_response)
@error_handler
def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -509,7 +486,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(integration_response)
@error_handler
def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -526,7 +502,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(deployment)
@error_handler
def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -540,7 +515,6 @@ class APIGatewayResponse(BaseResponse):
deployment = self.backend.delete_deployment(function_id, deployment_id)
return 202, {}, json.dumps(deployment)
@error_handler
def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -553,7 +527,6 @@ class APIGatewayResponse(BaseResponse):
apikeys_response = self.backend.get_api_keys(include_values=include_values)
return 200, {}, json.dumps({"item": apikeys_response})
@error_handler
def apikey_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -585,7 +558,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps({"item": usage_plans_response})
return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -603,7 +575,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_keys(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -619,7 +590,6 @@ class APIGatewayResponse(BaseResponse):
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response})
@error_handler
def usage_plan_key_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -635,7 +605,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(usage_plan_response)
@error_handler
def domain_names(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -672,7 +641,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(domain_name_resp)
@error_handler
def domain_name_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -695,7 +663,6 @@ class APIGatewayResponse(BaseResponse):
return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names)
@error_handler
def models(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -723,7 +690,6 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(model)
@error_handler
def model_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -734,7 +700,6 @@ class APIGatewayResponse(BaseResponse):
model_info = self.backend.get_model(rest_api_id, model_name)
return 200, {}, json.dumps(model_info)
@error_handler
def base_path_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -754,7 +719,6 @@ class APIGatewayResponse(BaseResponse):
)
return 201, {}, json.dumps(base_path_mapping_resp)
@error_handler
def base_path_mapping_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -778,7 +742,6 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(base_path_mapping)
@error_handler
def vpc_link(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")

View File

@ -1,25 +1,13 @@
"""Handles incoming apigatewayv2 requests, invokes methods, returns responses."""
import json
from functools import wraps
from moto.core.responses import BaseResponse
from urllib.parse import unquote
from .exceptions import APIGatewayV2Error, UnknownProtocol
from .exceptions import UnknownProtocol
from .models import apigatewayv2_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except APIGatewayV2Error as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class ApiGatewayV2Response(BaseResponse):
"""Handler for ApiGatewayV2 requests and responses."""
@ -28,7 +16,6 @@ class ApiGatewayV2Response(BaseResponse):
"""Return backend instance specific for this region."""
return apigatewayv2_backends[self.region]
@error_handler
def apis(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -37,7 +24,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "GET":
return self.get_apis()
@error_handler
def api(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -50,7 +36,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE":
return self.delete_api()
@error_handler
def authorizer(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -67,7 +52,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_authorizer()
@error_handler
def cors(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -80,7 +64,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE":
return self.delete_route_request_parameter()
@error_handler
def model(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -97,7 +80,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_model()
@error_handler
def integration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -108,7 +90,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "PATCH":
return self.update_integration()
@error_handler
def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -117,7 +98,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_integration()
@error_handler
def integration_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -136,7 +116,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_integration_response()
@error_handler
def route(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -147,7 +126,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "PATCH":
return self.update_route()
@error_handler
def routes(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -156,7 +134,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_route()
@error_handler
def route_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -171,7 +148,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "POST":
return self.create_route_response()
@error_handler
def tags(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -182,7 +158,6 @@ class ApiGatewayV2Response(BaseResponse):
if self.method == "DELETE":
return self.untag_resource()
@error_handler
def vpc_link(self, request, full_url, headers):
self.setup_class(request, full_url, headers)

View File

@ -1,24 +1,11 @@
"""Handles incoming appsync requests, invokes methods, returns responses."""
import json
from functools import wraps
from moto.core.responses import BaseResponse
from urllib.parse import unquote
from .exceptions import AppSyncExceptions
from .models import appsync_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except AppSyncExceptions as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class AppSyncResponse(BaseResponse):
"""Handler for AppSync requests and responses."""
@ -34,7 +21,6 @@ class AppSyncResponse(BaseResponse):
if request.method == "GET":
return self.list_graphql_apis()
@error_handler
def graph_ql_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":

View File

@ -3,24 +3,11 @@ import sys
from urllib.parse import unquote
from functools import wraps
from moto.core.utils import amz_crc32, amzn_request_id, path_url
from moto.core.responses import BaseResponse
from .exceptions import LambdaClientError
from .models import lambda_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except LambdaClientError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class LambdaResponse(BaseResponse):
@property
def json_body(self):
@ -39,7 +26,6 @@ class LambdaResponse(BaseResponse):
"""
return lambda_backends[self.region]
@error_handler
def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
@ -95,7 +81,6 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
@error_handler
def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
@ -139,7 +124,6 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle {0} request".format(request.method))
@error_handler
def policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":

View File

@ -2,8 +2,6 @@ from moto.core.responses import BaseResponse
from .models import batch_backends
from urllib.parse import urlsplit, unquote
from .exceptions import AWSError
import json
@ -48,16 +46,13 @@ class BatchResponse(BaseResponse):
state = self._get_param("state")
_type = self._get_param("type")
try:
name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type,
state=state,
compute_resources=compute_resource,
service_role=service_role,
)
except AWSError as err:
return err.response()
name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type,
state=state,
compute_resources=compute_resource,
service_role=service_role,
)
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
@ -76,10 +71,7 @@ class BatchResponse(BaseResponse):
def deletecomputeenvironment(self):
compute_environment = self._get_param("computeEnvironment")
try:
self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err:
return err.response()
self.batch_backend.delete_compute_environment(compute_environment)
return ""
@ -90,15 +82,12 @@ class BatchResponse(BaseResponse):
service_role = self._get_param("serviceRole")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name,
compute_resources=compute_resource,
service_role=service_role,
state=state,
)
except AWSError as err:
return err.response()
name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name,
compute_resources=compute_resource,
service_role=service_role,
state=state,
)
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
@ -112,16 +101,13 @@ class BatchResponse(BaseResponse):
state = self._get_param("state")
tags = self._get_param("tags")
try:
name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order,
tags=tags,
)
except AWSError as err:
return err.response()
name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order,
tags=tags,
)
result = {"jobQueueArn": arn, "jobQueueName": name}
@ -143,15 +129,12 @@ class BatchResponse(BaseResponse):
priority = self._get_param("priority")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order,
)
except AWSError as err:
return err.response()
name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order,
)
result = {"jobQueueArn": arn, "jobQueueName": name}
@ -176,20 +159,17 @@ class BatchResponse(BaseResponse):
timeout = self._get_param("timeout")
platform_capabilities = self._get_param("platformCapabilities")
propagate_tags = self._get_param("propagateTags")
try:
name, arn, revision = self.batch_backend.register_job_definition(
def_name=def_name,
parameters=parameters,
_type=_type,
tags=tags,
retry_strategy=retry_strategy,
container_properties=container_properties,
timeout=timeout,
platform_capabilities=platform_capabilities,
propagate_tags=propagate_tags,
)
except AWSError as err:
return err.response()
name, arn, revision = self.batch_backend.register_job_definition(
def_name=def_name,
parameters=parameters,
_type=_type,
tags=tags,
retry_strategy=retry_strategy,
container_properties=container_properties,
timeout=timeout,
platform_capabilities=platform_capabilities,
propagate_tags=propagate_tags,
)
result = {
"jobDefinitionArn": arn,
@ -229,17 +209,14 @@ class BatchResponse(BaseResponse):
job_queue = self._get_param("jobQueue")
timeout = self._get_param("timeout")
try:
name, job_id = self.batch_backend.submit_job(
job_name,
job_def,
job_queue,
depends_on=depends_on,
container_overrides=container_overrides,
timeout=timeout,
)
except AWSError as err:
return err.response()
name, job_id = self.batch_backend.submit_job(
job_name,
job_def,
job_queue,
depends_on=depends_on,
container_overrides=container_overrides,
timeout=timeout,
)
result = {"jobId": job_id, "jobName": name}
@ -249,20 +226,14 @@ class BatchResponse(BaseResponse):
def describejobs(self):
jobs = self._get_param("jobs")
try:
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err:
return err.response()
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
# ListJobs
def listjobs(self):
job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus")
try:
jobs = self.batch_backend.list_jobs(job_queue, job_status)
except AWSError as err:
return err.response()
jobs = self.batch_backend.list_jobs(job_queue, job_status)
result = {"jobSummaryList": [job.describe_short() for job in jobs]}
return json.dumps(result)
@ -272,10 +243,7 @@ class BatchResponse(BaseResponse):
job_id = self._get_param("jobId")
reason = self._get_param("reason")
try:
self.batch_backend.terminate_job(job_id, reason)
except AWSError as err:
return err.response()
self.batch_backend.terminate_job(job_id, reason)
return ""

View File

@ -1,30 +1,16 @@
import xmltodict
from functools import wraps
from moto.core.responses import BaseResponse
from .models import cloudfront_backend
from .exceptions import CloudFrontException
XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/"
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except CloudFrontException as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class CloudFrontResponse(BaseResponse):
def _get_xml_body(self):
return xmltodict.parse(self.body, dict_constructor=dict)
@error_handler
def distributions(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "POST":
@ -49,7 +35,6 @@ class CloudFrontResponse(BaseResponse):
response = template.render(distributions=distributions)
return 200, {}, response
@error_handler
def individual_distribution(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
distribution_id = full_url.split("/")[-1]

View File

@ -132,20 +132,13 @@ class AuthFailureError(RESTError):
)
class AWSError(Exception):
class AWSError(JsonRESTError):
TYPE = None
STATUS = 400
def __init__(self, message, exception_type=None, status=None):
self.message = message
self.type = exception_type or self.TYPE
self.status = status or self.STATUS
def response(self):
return (
json.dumps({"__type": self.type, "message": self.message}),
dict(status=self.status),
)
super().__init__(exception_type or self.TYPE, message)
self.code = status or self.STATUS
class InvalidNextTokenException(JsonRESTError):
@ -160,8 +153,7 @@ class InvalidNextTokenException(JsonRESTError):
class InvalidToken(AWSError):
TYPE = "InvalidToken"
STATUS = 400
code = 400
def __init__(self, message="Invalid token"):
super().__init__("Invalid Token: {}".format(message))
super().__init__("Invalid Token: {}".format(message), "InvalidToken")

View File

@ -17,6 +17,7 @@ from botocore.awsrequest import AWSResponse
from types import FunctionType
from moto import settings
from moto.core.exceptions import HTTPException
import responses
import unittest
from unittest.mock import patch
@ -283,9 +284,14 @@ class BotocoreStubber:
for header, value in request.headers.items():
if isinstance(value, bytes):
request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(
request, request.url, request.headers
)
try:
status, headers, body = response_callback(
request, request.url, request.headers
)
except HTTPException as e:
status = e.code
headers = e.get_headers()
body = e.get_body()
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)

View File

@ -3,7 +3,6 @@ from urllib.parse import urlparse
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import DataBrewClientError
from .models import databrew_backends
@ -58,8 +57,5 @@ class DataBrewResponse(BaseResponse):
recipe_name = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
try:
recipe = self.databrew_backend.get_recipe(recipe_name)
return json.dumps(recipe.as_dict())
except DataBrewClientError as e:
return e.code, e.get_headers(), e.get_body()
recipe = self.databrew_backend.get_recipe(recipe_name)
return json.dumps(recipe.as_dict())

View File

@ -5,9 +5,10 @@ from moto.core.exceptions import AWSError
class EKSError(AWSError):
def __init__(self, **kwargs):
super(AWSError, self).__init__()
super(AWSError, self).__init__(error_type=self.TYPE, message="")
self.description = json.dumps(kwargs)
self.headers = {"status": self.STATUS, "x-amzn-ErrorType": self.TYPE}
self.code = self.STATUS
def response(self):
return self.STATUS, self.headers, self.description

View File

@ -2,12 +2,6 @@ import json
from moto.core.responses import BaseResponse
from .exceptions import (
InvalidParameterException,
InvalidRequestException,
ResourceInUseException,
ResourceNotFoundException,
)
from .models import eks_backends
DEFAULT_MAX_RESULTS = 100
@ -32,27 +26,19 @@ class EKSResponse(BaseResponse):
tags = self._get_param("tags")
encryption_config = self._get_param("encryptionConfig")
try:
cluster = self.eks_backend.create_cluster(
name=name,
version=version,
role_arn=role_arn,
resources_vpc_config=resources_vpc_config,
kubernetes_network_config=kubernetes_network_config,
logging=logging,
client_request_token=client_request_token,
tags=tags,
encryption_config=encryption_config,
)
cluster = self.eks_backend.create_cluster(
name=name,
version=version,
role_arn=role_arn,
resources_vpc_config=resources_vpc_config,
kubernetes_network_config=kubernetes_network_config,
logging=logging,
client_request_token=client_request_token,
tags=tags,
encryption_config=encryption_config,
)
return 200, {}, json.dumps({"cluster": dict(cluster)})
except (
ResourceInUseException,
ResourceNotFoundException,
InvalidParameterException,
) as e:
# Backend will capture this and re-raise it as a ClientError.
return e.response()
return 200, {}, json.dumps({"cluster": dict(cluster)})
def create_fargate_profile(self):
fargate_profile_name = self._get_param("fargateProfileName")
@ -63,25 +49,17 @@ class EKSResponse(BaseResponse):
client_request_token = self._get_param("clientRequestToken")
tags = self._get_param("tags")
try:
fargate_profile = self.eks_backend.create_fargate_profile(
fargate_profile_name=fargate_profile_name,
cluster_name=cluster_name,
pod_execution_role_arn=pod_execution_role_arn,
subnets=subnets,
selectors=selectors,
client_request_token=client_request_token,
tags=tags,
)
fargate_profile = self.eks_backend.create_fargate_profile(
fargate_profile_name=fargate_profile_name,
cluster_name=cluster_name,
pod_execution_role_arn=pod_execution_role_arn,
subnets=subnets,
selectors=selectors,
client_request_token=client_request_token,
tags=tags,
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except (
ResourceNotFoundException,
ResourceInUseException,
InvalidParameterException,
InvalidRequestException,
) as e:
return e.response()
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
def create_nodegroup(self):
cluster_name = self._get_param("name")
@ -101,69 +79,52 @@ class EKSResponse(BaseResponse):
version = self._get_param("version")
release_version = self._get_param("releaseVersion")
try:
nodegroup = self.eks_backend.create_nodegroup(
cluster_name=cluster_name,
nodegroup_name=nodegroup_name,
scaling_config=scaling_config,
disk_size=disk_size,
subnets=subnets,
instance_types=instance_types,
ami_type=ami_type,
remote_access=remote_access,
node_role=node_role,
labels=labels,
tags=tags,
client_request_token=client_request_token,
launch_template=launch_template,
capacity_type=capacity_type,
version=version,
release_version=release_version,
)
nodegroup = self.eks_backend.create_nodegroup(
cluster_name=cluster_name,
nodegroup_name=nodegroup_name,
scaling_config=scaling_config,
disk_size=disk_size,
subnets=subnets,
instance_types=instance_types,
ami_type=ami_type,
remote_access=remote_access,
node_role=node_role,
labels=labels,
tags=tags,
client_request_token=client_request_token,
launch_template=launch_template,
capacity_type=capacity_type,
version=version,
release_version=release_version,
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (
ResourceInUseException,
ResourceNotFoundException,
InvalidRequestException,
InvalidParameterException,
) as e:
return e.response()
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
def describe_cluster(self):
name = self._get_param("name")
try:
cluster = self.eks_backend.describe_cluster(name=name)
cluster = self.eks_backend.describe_cluster(name=name)
return 200, {}, json.dumps({"cluster": dict(cluster)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
return 200, {}, json.dumps({"cluster": dict(cluster)})
def describe_fargate_profile(self):
cluster_name = self._get_param("name")
fargate_profile_name = self._get_param("fargateProfileName")
try:
fargate_profile = self.eks_backend.describe_fargate_profile(
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
fargate_profile = self.eks_backend.describe_fargate_profile(
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
def describe_nodegroup(self):
cluster_name = self._get_param("name")
nodegroup_name = self._get_param("nodegroupName")
try:
nodegroup = self.eks_backend.describe_nodegroup(
cluster_name=cluster_name, nodegroup_name=nodegroup_name
)
nodegroup = self.eks_backend.describe_nodegroup(
cluster_name=cluster_name, nodegroup_name=nodegroup_name
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
def list_clusters(self):
max_results = self._get_int_param("maxResults", DEFAULT_MAX_RESULTS)
@ -206,35 +167,26 @@ class EKSResponse(BaseResponse):
def delete_cluster(self):
name = self._get_param("name")
try:
cluster = self.eks_backend.delete_cluster(name=name)
cluster = self.eks_backend.delete_cluster(name=name)
return 200, {}, json.dumps({"cluster": dict(cluster)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
return 200, {}, json.dumps({"cluster": dict(cluster)})
def delete_fargate_profile(self):
cluster_name = self._get_param("name")
fargate_profile_name = self._get_param("fargateProfileName")
try:
fargate_profile = self.eks_backend.delete_fargate_profile(
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
)
fargate_profile = self.eks_backend.delete_fargate_profile(
cluster_name=cluster_name, fargate_profile_name=fargate_profile_name
)
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
except ResourceNotFoundException as e:
return e.response()
return 200, {}, json.dumps({"fargateProfile": dict(fargate_profile)})
def delete_nodegroup(self):
cluster_name = self._get_param("name")
nodegroup_name = self._get_param("nodegroupName")
try:
nodegroup = self.eks_backend.delete_nodegroup(
cluster_name=cluster_name, nodegroup_name=nodegroup_name
)
nodegroup = self.eks_backend.delete_nodegroup(
cluster_name=cluster_name, nodegroup_name=nodegroup_name
)
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})
except (ResourceInUseException, ResourceNotFoundException) as e:
return e.response()
return 200, {}, json.dumps({"nodegroup": dict(nodegroup)})

View File

@ -1,23 +1,11 @@
import json
import re
from functools import wraps
from moto.core.responses import BaseResponse
from .exceptions import ElasticSearchError, InvalidDomainName
from .exceptions import InvalidDomainName
from .models import es_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except ElasticSearchError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class ElasticsearchServiceResponse(BaseResponse):
"""Handler for ElasticsearchService requests and responses."""
@ -27,7 +15,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return es_backends[self.region]
@classmethod
@error_handler
def list_domains(cls, request, full_url, headers):
response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers)
@ -35,7 +22,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return response.list_domain_names()
@classmethod
@error_handler
def domains(cls, request, full_url, headers):
response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers)
@ -43,7 +29,6 @@ class ElasticsearchServiceResponse(BaseResponse):
return response.create_elasticsearch_domain()
@classmethod
@error_handler
def domain(cls, request, full_url, headers):
response = ElasticsearchServiceResponse()
response.setup_class(request, full_url, headers)

View File

@ -2,7 +2,6 @@ import json
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import forecast_backends
@ -18,57 +17,45 @@ class ForecastResponse(BaseResponse):
dataset_arns = self._get_param("DatasetArns")
tags = self._get_param("Tags")
try:
dataset_group = self.forecast_backend.create_dataset_group(
dataset_group_name=dataset_group_name,
domain=domain,
dataset_arns=dataset_arns,
tags=tags,
)
response = {"DatasetGroupArn": dataset_group.arn}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
dataset_group = self.forecast_backend.create_dataset_group(
dataset_group_name=dataset_group_name,
domain=domain,
dataset_arns=dataset_arns,
tags=tags,
)
response = {"DatasetGroupArn": dataset_group.arn}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn")
try:
dataset_group = self.forecast_backend.describe_dataset_group(
dataset_group_arn=dataset_group_arn
)
response = {
"CreationTime": dataset_group.creation_date,
"DatasetArns": dataset_group.dataset_arns,
"DatasetGroupArn": dataset_group.arn,
"DatasetGroupName": dataset_group.dataset_group_name,
"Domain": dataset_group.domain,
"LastModificationTime": dataset_group.modified_date,
"Status": "ACTIVE",
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
dataset_group = self.forecast_backend.describe_dataset_group(
dataset_group_arn=dataset_group_arn
)
response = {
"CreationTime": dataset_group.creation_date,
"DatasetArns": dataset_group.dataset_arns,
"DatasetGroupArn": dataset_group.arn,
"DatasetGroupName": dataset_group.dataset_group_name,
"Domain": dataset_group.domain,
"LastModificationTime": dataset_group.modified_date,
"Status": "ACTIVE",
}
return 200, {}, json.dumps(response)
@amzn_request_id
def delete_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn")
try:
self.forecast_backend.delete_dataset_group(dataset_group_arn)
return 200, {}, None
except AWSError as err:
return err.response()
self.forecast_backend.delete_dataset_group(dataset_group_arn)
return 200, {}, None
@amzn_request_id
def update_dataset_group(self):
dataset_group_arn = self._get_param("DatasetGroupArn")
dataset_arns = self._get_param("DatasetArns")
try:
self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns)
return 200, {}, None
except AWSError as err:
return err.response()
self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns)
return 200, {}, None
@amzn_request_id
def list_dataset_groups(self):

View File

@ -11,7 +11,7 @@ class UnknownBroker(MQError):
super().__init__("NotFoundException", "Can't find requested broker")
self.broker_id = broker_id
def get_body(self):
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = {
"errorAttribute": "broker-id",
"message": f"Can't find requested broker [{self.broker_id}]. Make sure your broker exists.",
@ -24,7 +24,7 @@ class UnknownConfiguration(MQError):
super().__init__("NotFoundException", "Can't find requested configuration")
self.config_id = config_id
def get_body(self):
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = {
"errorAttribute": "configuration_id",
"message": f"Can't find requested configuration [{self.config_id}]. Make sure your configuration exists.",
@ -37,7 +37,7 @@ class UnknownUser(MQError):
super().__init__("NotFoundException", "Can't find requested user")
self.username = username
def get_body(self):
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = {
"errorAttribute": "username",
"message": f"Can't find requested user [{self.username}]. Make sure your user exists.",
@ -50,7 +50,7 @@ class UnsupportedEngineType(MQError):
super().__init__("BadRequestException", "")
self.engine_type = engine_type
def get_body(self):
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = {
"errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] does not support configuration.",
@ -63,7 +63,7 @@ class UnknownEngineType(MQError):
super().__init__("BadRequestException", "")
self.engine_type = engine_type
def get_body(self):
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
body = {
"errorAttribute": "engineType",
"message": f"Broker engine type [{self.engine_type}] is invalid. Valid values are: [ACTIVEMQ]",

View File

@ -1,24 +1,11 @@
"""Handles incoming mq requests, invokes methods, returns responses."""
import json
from functools import wraps
from urllib.parse import unquote
from moto.core.responses import BaseResponse
from .exceptions import MQError
from .models import mq_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except MQError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class MQResponse(BaseResponse):
"""Handler for MQ requests and responses."""
@ -27,7 +14,6 @@ class MQResponse(BaseResponse):
"""Return backend instance specific for this region."""
return mq_backends[self.region]
@error_handler
def broker(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
@ -44,7 +30,6 @@ class MQResponse(BaseResponse):
if request.method == "GET":
return self.list_brokers()
@error_handler
def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
@ -52,7 +37,6 @@ class MQResponse(BaseResponse):
if request.method == "PUT":
return self.update_configuration()
@error_handler
def configurations(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "POST":
@ -72,7 +56,6 @@ class MQResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_tags()
@error_handler
def user(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "POST":

View File

@ -1,24 +1,11 @@
"""Handles incoming pinpoint requests, invokes methods, returns responses."""
import json
from functools import wraps
from moto.core.responses import BaseResponse
from urllib.parse import unquote
from .exceptions import PinpointExceptions
from .models import pinpoint_backends
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except PinpointExceptions as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class PinpointResponse(BaseResponse):
"""Handler for Pinpoint requests and responses."""
@ -27,7 +14,6 @@ class PinpointResponse(BaseResponse):
"""Return backend instance specific for this region."""
return pinpoint_backends[self.region]
@error_handler
def app(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "DELETE":
@ -49,7 +35,6 @@ class PinpointResponse(BaseResponse):
if request.method == "PUT":
return self.update_application_settings()
@error_handler
def eventstream(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "DELETE":

View File

@ -1,28 +1,16 @@
"""Handles Route53 API requests, invokes method and returns response."""
from functools import wraps
from urllib.parse import parse_qs, urlparse
from jinja2 import Template
import xmltodict
from moto.core.responses import BaseResponse
from moto.route53.exceptions import Route53ClientError, InvalidChangeBatch
from moto.route53.exceptions import InvalidChangeBatch
from moto.route53.models import route53_backend
XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except Route53ClientError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class Route53(BaseResponse):
"""Handler for Route53 requests and responses."""
@ -36,7 +24,6 @@ class Route53(BaseResponse):
return False
@error_handler
def list_or_create_hostzone_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -115,7 +102,6 @@ class Route53(BaseResponse):
template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE)
return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS)
@error_handler
def get_or_delete_hostzone_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url)
@ -129,7 +115,6 @@ class Route53(BaseResponse):
route53_backend.delete_hosted_zone(zoneid)
return 200, headers, DELETE_HOSTED_ZONE_RESPONSE
@error_handler
def rrset_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -298,7 +283,6 @@ class Route53(BaseResponse):
template = Template(GET_CHANGE_RESPONSE)
return 200, headers, template.render(change_id=change_id, xmlns=XMLNS)
@error_handler
def list_or_create_query_logging_config_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -343,7 +327,6 @@ class Route53(BaseResponse):
),
)
@error_handler
def get_or_delete_query_logging_config_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url)
@ -394,7 +377,6 @@ class Route53(BaseResponse):
template.render(delegation_set=delegation_set),
)
@error_handler
def reusable_delegation_set(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url)

View File

@ -1,26 +1,13 @@
import json
import xmltodict
from functools import wraps
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from .exceptions import S3ControlError
from .models import s3control_backend
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except S3ControlError as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class S3ControlResponse(BaseResponse):
@amzn_request_id
def public_access_block(
@ -64,7 +51,6 @@ class S3ControlResponse(BaseResponse):
return parsed_xml
@error_handler
def access_point(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
@ -74,7 +60,6 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_access_point(full_url)
@error_handler
def access_point_policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
@ -84,7 +69,6 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_access_point_policy(full_url)
@error_handler
def access_point_policy_status(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":

View File

@ -46,63 +46,55 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_notebook_instance(self):
try:
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
notebook_instance_name=self._get_param("NotebookInstanceName"),
instance_type=self._get_param("InstanceType"),
subnet_id=self._get_param("SubnetId"),
security_group_ids=self._get_param("SecurityGroupIds"),
role_arn=self._get_param("RoleArn"),
kms_key_id=self._get_param("KmsKeyId"),
tags=self._get_param("Tags"),
lifecycle_config_name=self._get_param("LifecycleConfigName"),
direct_internet_access=self._get_param("DirectInternetAccess"),
volume_size_in_gb=self._get_param("VolumeSizeInGB"),
accelerator_types=self._get_param("AcceleratorTypes"),
default_code_repository=self._get_param("DefaultCodeRepository"),
additional_code_repositories=self._get_param(
"AdditionalCodeRepositories"
),
root_access=self._get_param("RootAccess"),
)
response = {
"NotebookInstanceArn": sagemaker_notebook.arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
notebook_instance_name=self._get_param("NotebookInstanceName"),
instance_type=self._get_param("InstanceType"),
subnet_id=self._get_param("SubnetId"),
security_group_ids=self._get_param("SecurityGroupIds"),
role_arn=self._get_param("RoleArn"),
kms_key_id=self._get_param("KmsKeyId"),
tags=self._get_param("Tags"),
lifecycle_config_name=self._get_param("LifecycleConfigName"),
direct_internet_access=self._get_param("DirectInternetAccess"),
volume_size_in_gb=self._get_param("VolumeSizeInGB"),
accelerator_types=self._get_param("AcceleratorTypes"),
default_code_repository=self._get_param("DefaultCodeRepository"),
additional_code_repositories=self._get_param("AdditionalCodeRepositories"),
root_access=self._get_param("RootAccess"),
)
response = {
"NotebookInstanceArn": sagemaker_notebook.arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName")
try:
notebook_instance = self.sagemaker_backend.get_notebook_instance(
notebook_instance_name
)
response = {
"NotebookInstanceArn": notebook_instance.arn,
"NotebookInstanceName": notebook_instance.notebook_instance_name,
"NotebookInstanceStatus": notebook_instance.status,
"Url": notebook_instance.url,
"InstanceType": notebook_instance.instance_type,
"SubnetId": notebook_instance.subnet_id,
"SecurityGroups": notebook_instance.security_group_ids,
"RoleArn": notebook_instance.role_arn,
"KmsKeyId": notebook_instance.kms_key_id,
# ToDo: NetworkInterfaceId
"LastModifiedTime": str(notebook_instance.last_modified_time),
"CreationTime": str(notebook_instance.creation_time),
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
"DirectInternetAccess": notebook_instance.direct_internet_access,
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
"AcceleratorTypes": notebook_instance.accelerator_types,
"DefaultCodeRepository": notebook_instance.default_code_repository,
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
"RootAccess": notebook_instance.root_access,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
notebook_instance = self.sagemaker_backend.get_notebook_instance(
notebook_instance_name
)
response = {
"NotebookInstanceArn": notebook_instance.arn,
"NotebookInstanceName": notebook_instance.notebook_instance_name,
"NotebookInstanceStatus": notebook_instance.status,
"Url": notebook_instance.url,
"InstanceType": notebook_instance.instance_type,
"SubnetId": notebook_instance.subnet_id,
"SecurityGroups": notebook_instance.security_group_ids,
"RoleArn": notebook_instance.role_arn,
"KmsKeyId": notebook_instance.kms_key_id,
# ToDo: NetworkInterfaceId
"LastModifiedTime": str(notebook_instance.last_modified_time),
"CreationTime": str(notebook_instance.creation_time),
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
"DirectInternetAccess": notebook_instance.direct_internet_access,
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
"AcceleratorTypes": notebook_instance.accelerator_types,
"DefaultCodeRepository": notebook_instance.default_code_repository,
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
"RootAccess": notebook_instance.root_access,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def start_notebook_instance(self):
@ -171,20 +163,17 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_endpoint_config(self):
try:
endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config_name=self._get_param("EndpointConfigName"),
production_variants=self._get_param("ProductionVariants"),
data_capture_config=self._get_param("DataCaptureConfig"),
tags=self._get_param("Tags"),
kms_key_id=self._get_param("KmsKeyId"),
)
response = {
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config_name=self._get_param("EndpointConfigName"),
production_variants=self._get_param("ProductionVariants"),
data_capture_config=self._get_param("DataCaptureConfig"),
tags=self._get_param("Tags"),
kms_key_id=self._get_param("KmsKeyId"),
)
response = {
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_endpoint_config(self):
@ -200,18 +189,15 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_endpoint(self):
try:
endpoint = self.sagemaker_backend.create_endpoint(
endpoint_name=self._get_param("EndpointName"),
endpoint_config_name=self._get_param("EndpointConfigName"),
tags=self._get_param("Tags"),
)
response = {
"EndpointArn": endpoint.endpoint_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
endpoint = self.sagemaker_backend.create_endpoint(
endpoint_name=self._get_param("EndpointName"),
endpoint_config_name=self._get_param("EndpointConfigName"),
tags=self._get_param("Tags"),
)
response = {
"EndpointArn": endpoint.endpoint_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_endpoint(self):
@ -227,23 +213,20 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_processing_job(self):
try:
processing_job = self.sagemaker_backend.create_processing_job(
app_specification=self._get_param("AppSpecification"),
experiment_config=self._get_param("ExperimentConfig"),
network_config=self._get_param("NetworkConfig"),
processing_inputs=self._get_param("ProcessingInputs"),
processing_job_name=self._get_param("ProcessingJobName"),
processing_output_config=self._get_param("ProcessingOutputConfig"),
role_arn=self._get_param("RoleArn"),
stopping_condition=self._get_param("StoppingCondition"),
)
response = {
"ProcessingJobArn": processing_job.processing_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
processing_job = self.sagemaker_backend.create_processing_job(
app_specification=self._get_param("AppSpecification"),
experiment_config=self._get_param("ExperimentConfig"),
network_config=self._get_param("NetworkConfig"),
processing_inputs=self._get_param("ProcessingInputs"),
processing_job_name=self._get_param("ProcessingJobName"),
processing_output_config=self._get_param("ProcessingOutputConfig"),
role_arn=self._get_param("RoleArn"),
stopping_condition=self._get_param("StoppingCondition"),
)
response = {
"ProcessingJobArn": processing_job.processing_job_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_processing_job(self):
@ -253,39 +236,34 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_training_job(self):
try:
training_job = self.sagemaker_backend.create_training_job(
training_job_name=self._get_param("TrainingJobName"),
hyper_parameters=self._get_param("HyperParameters"),
algorithm_specification=self._get_param("AlgorithmSpecification"),
role_arn=self._get_param("RoleArn"),
input_data_config=self._get_param("InputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"),
resource_config=self._get_param("ResourceConfig"),
vpc_config=self._get_param("VpcConfig"),
stopping_condition=self._get_param("StoppingCondition"),
tags=self._get_param("Tags"),
enable_network_isolation=self._get_param(
"EnableNetworkIsolation", False
),
enable_inter_container_traffic_encryption=self._get_param(
"EnableInterContainerTrafficEncryption", False
),
enable_managed_spot_training=self._get_param(
"EnableManagedSpotTraining", False
),
checkpoint_config=self._get_param("CheckpointConfig"),
debug_hook_config=self._get_param("DebugHookConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
training_job = self.sagemaker_backend.create_training_job(
training_job_name=self._get_param("TrainingJobName"),
hyper_parameters=self._get_param("HyperParameters"),
algorithm_specification=self._get_param("AlgorithmSpecification"),
role_arn=self._get_param("RoleArn"),
input_data_config=self._get_param("InputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"),
resource_config=self._get_param("ResourceConfig"),
vpc_config=self._get_param("VpcConfig"),
stopping_condition=self._get_param("StoppingCondition"),
tags=self._get_param("Tags"),
enable_network_isolation=self._get_param("EnableNetworkIsolation", False),
enable_inter_container_traffic_encryption=self._get_param(
"EnableInterContainerTrafficEncryption", False
),
enable_managed_spot_training=self._get_param(
"EnableManagedSpotTraining", False
),
checkpoint_config=self._get_param("CheckpointConfig"),
debug_hook_config=self._get_param("DebugHookConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_training_job(self):
@ -301,22 +279,19 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_notebook_instance_lifecycle_config(self):
try:
lifecycle_configuration = (
self.sagemaker_backend.create_notebook_instance_lifecycle_config(
notebook_instance_lifecycle_config_name=self._get_param(
"NotebookInstanceLifecycleConfigName"
),
on_create=self._get_param("OnCreate"),
on_start=self._get_param("OnStart"),
)
lifecycle_configuration = (
self.sagemaker_backend.create_notebook_instance_lifecycle_config(
notebook_instance_lifecycle_config_name=self._get_param(
"NotebookInstanceLifecycleConfigName"
),
on_create=self._get_param("OnCreate"),
on_start=self._get_param("OnStart"),
)
response = {
"NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
)
response = {
"NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_notebook_instance_lifecycle_config(self):
@ -426,14 +401,11 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_trial(self):
try:
response = self.sagemaker_backend.create_trial(
trial_name=self._get_param("TrialName"),
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
response = self.sagemaker_backend.create_trial(
trial_name=self._get_param("TrialName"),
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_trial_components(self):
@ -467,14 +439,11 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def create_trial_component(self):
try:
response = self.sagemaker_backend.create_trial_component(
trial_component_name=self._get_param("TrialComponentName"),
trial_name=self._get_param("TrialName"),
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
response = self.sagemaker_backend.create_trial_component(
trial_component_name=self._get_param("TrialComponentName"),
trial_name=self._get_param("TrialName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_trial(self):
@ -530,52 +499,47 @@ class SageMakerResponse(BaseResponse):
"Failed",
]
try:
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
max_results, max_results_range[-1]
)
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
max_results, max_results_range[-1]
)
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(
status_equals, "statusEquals", allowed_status_equals
)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_processing_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(status_equals, "statusEquals", allowed_status_equals)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_processing_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_training_jobs(self):
@ -590,49 +554,44 @@ class SageMakerResponse(BaseResponse):
"Failed",
]
try:
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
max_results, max_results_range[-1]
)
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
max_results, max_results_range[-1]
)
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(
status_equals, "statusEquals", allowed_status_equals
)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_training_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(status_equals, "statusEquals", allowed_status_equals)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_training_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)

View File

@ -546,6 +546,13 @@ class StepFunctionBackend(BaseBackend):
)
return execution.get_execution_history(state_machine.roleArn)
def list_tags_for_resource(self, arn):
try:
state_machine = self.describe_state_machine(arn)
return state_machine.tags or []
except StateMachineDoesNotExist:
return []
def tag_resource(self, resource_arn, tags):
try:
state_machine = self.describe_state_machine(resource_arn)

View File

@ -2,7 +2,6 @@ import json
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import stepfunction_backends
@ -17,17 +16,14 @@ class StepFunctionResponse(BaseResponse):
definition = self._get_param("definition")
roleArn = self._get_param("roleArn")
tags = self._get_param("tags")
try:
state_machine = self.stepfunction_backend.create_state_machine(
name=name, definition=definition, roleArn=roleArn, tags=tags
)
response = {
"creationDate": state_machine.creation_date,
"stateMachineArn": state_machine.arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
state_machine = self.stepfunction_backend.create_state_machine(
name=name, definition=definition, roleArn=roleArn, tags=tags
)
response = {
"creationDate": state_machine.creation_date,
"stateMachineArn": state_machine.arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_state_machines(self):
@ -56,55 +52,42 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id
def _describe_state_machine(self, state_machine_arn):
try:
state_machine = self.stepfunction_backend.describe_state_machine(
state_machine_arn
)
response = {
"creationDate": state_machine.creation_date,
"stateMachineArn": state_machine.arn,
"definition": state_machine.definition,
"name": state_machine.name,
"roleArn": state_machine.roleArn,
"status": "ACTIVE",
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
state_machine = self.stepfunction_backend.describe_state_machine(
state_machine_arn
)
response = {
"creationDate": state_machine.creation_date,
"stateMachineArn": state_machine.arn,
"definition": state_machine.definition,
"name": state_machine.name,
"roleArn": state_machine.roleArn,
"status": "ACTIVE",
}
return 200, {}, json.dumps(response)
@amzn_request_id
def delete_state_machine(self):
arn = self._get_param("stateMachineArn")
try:
self.stepfunction_backend.delete_state_machine(arn)
return 200, {}, json.dumps("{}")
except AWSError as err:
return err.response()
self.stepfunction_backend.delete_state_machine(arn)
return 200, {}, json.dumps("{}")
@amzn_request_id
def update_state_machine(self):
arn = self._get_param("stateMachineArn")
definition = self._get_param("definition")
role_arn = self._get_param("roleArn")
try:
state_machine = self.stepfunction_backend.update_state_machine(
arn=arn, definition=definition, role_arn=role_arn
)
response = {
"updateDate": state_machine.update_date,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
state_machine = self.stepfunction_backend.update_state_machine(
arn=arn, definition=definition, role_arn=role_arn
)
response = {
"updateDate": state_machine.update_date,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_tags_for_resource(self):
arn = self._get_param("resourceArn")
try:
state_machine = self.stepfunction_backend.describe_state_machine(arn)
tags = state_machine.tags or []
except AWSError:
tags = []
tags = self.stepfunction_backend.list_tags_for_resource(arn)
response = {"tags": tags}
return 200, {}, json.dumps(response)
@ -112,20 +95,14 @@ class StepFunctionResponse(BaseResponse):
def tag_resource(self):
arn = self._get_param("resourceArn")
tags = self._get_param("tags", [])
try:
self.stepfunction_backend.tag_resource(arn, tags)
except AWSError as err:
return err.response()
self.stepfunction_backend.tag_resource(arn, tags)
return 200, {}, json.dumps({})
@amzn_request_id
def untag_resource(self):
arn = self._get_param("resourceArn")
tag_keys = self._get_param("tagKeys", [])
try:
self.stepfunction_backend.untag_resource(arn, tag_keys)
except AWSError as err:
return err.response()
self.stepfunction_backend.untag_resource(arn, tag_keys)
return 200, {}, json.dumps({})
@amzn_request_id
@ -133,12 +110,9 @@ class StepFunctionResponse(BaseResponse):
arn = self._get_param("stateMachineArn")
name = self._get_param("name")
execution_input = self._get_param("input", if_none="{}")
try:
execution = self.stepfunction_backend.start_execution(
arn, name, execution_input
)
except AWSError as err:
return err.response()
execution = self.stepfunction_backend.start_execution(
arn, name, execution_input
)
response = {
"executionArn": execution.execution_arn,
"startDate": execution.start_date,
@ -151,16 +125,13 @@ class StepFunctionResponse(BaseResponse):
next_token = self._get_param("nextToken")
arn = self._get_param("stateMachineArn")
status_filter = self._get_param("statusFilter")
try:
state_machine = self.stepfunction_backend.describe_state_machine(arn)
results, next_token = self.stepfunction_backend.list_executions(
arn,
status_filter=status_filter,
max_results=max_results,
next_token=next_token,
)
except AWSError as err:
return err.response()
state_machine = self.stepfunction_backend.describe_state_machine(arn)
results, next_token = self.stepfunction_backend.list_executions(
arn,
status_filter=status_filter,
max_results=max_results,
next_token=next_token,
)
executions = [
{
"executionArn": execution.execution_arn,
@ -179,48 +150,36 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id
def describe_execution(self):
arn = self._get_param("executionArn")
try:
execution = self.stepfunction_backend.describe_execution(arn)
response = {
"executionArn": arn,
"input": execution.execution_input,
"name": execution.name,
"startDate": execution.start_date,
"stateMachineArn": execution.state_machine_arn,
"status": execution.status,
"stopDate": execution.stop_date,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
execution = self.stepfunction_backend.describe_execution(arn)
response = {
"executionArn": arn,
"input": execution.execution_input,
"name": execution.name,
"startDate": execution.start_date,
"stateMachineArn": execution.state_machine_arn,
"status": execution.status,
"stopDate": execution.stop_date,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_state_machine_for_execution(self):
arn = self._get_param("executionArn")
try:
execution = self.stepfunction_backend.describe_execution(arn)
return self._describe_state_machine(execution.state_machine_arn)
except AWSError as err:
return err.response()
execution = self.stepfunction_backend.describe_execution(arn)
return self._describe_state_machine(execution.state_machine_arn)
@amzn_request_id
def stop_execution(self):
arn = self._get_param("executionArn")
try:
execution = self.stepfunction_backend.stop_execution(arn)
response = {"stopDate": execution.stop_date}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
execution = self.stepfunction_backend.stop_execution(arn)
response = {"stopDate": execution.stop_date}
return 200, {}, json.dumps(response)
@amzn_request_id
def get_execution_history(self):
execution_arn = self._get_param("executionArn")
try:
execution_history = self.stepfunction_backend.get_execution_history(
execution_arn
)
response = {"events": execution_history}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
execution_history = self.stepfunction_backend.get_execution_history(
execution_arn
)
response = {"events": execution_history}
return 200, {}, json.dumps(response)

View File

@ -35,10 +35,7 @@ class XRayResponse(BaseResponse):
# PutTelemetryRecords
def telemetry_records(self):
try:
self.xray_backend.add_telemetry_records(self.request_params)
except AWSError as err:
return err.response()
self.xray_backend.add_telemetry_records(self.request_params)
return ""
@ -109,7 +106,7 @@ class XRayResponse(BaseResponse):
start_time, end_time, filter_expression
)
except AWSError as err:
return err.response()
raise err
except Exception as err:
return (
json.dumps({"__type": "InternalFailure", "message": str(err)}),
@ -132,7 +129,7 @@ class XRayResponse(BaseResponse):
try:
result = self.xray_backend.get_trace_ids(trace_ids)
except AWSError as err:
return err.response()
raise err
except Exception as err:
return (
json.dumps({"__type": "InternalFailure", "message": str(err)}),

View File

@ -161,7 +161,7 @@ class TestWithSetupMethod:
@mock_kinesis
class TestKinesisUsingSetupMethod:
def setup_method(self, *args):
def setup_method(self, *args): # pylint: disable=unused-argument
self.stream_name = "test_stream"
self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1")
self.boto3_kinesis_client.create_stream(