Techdebt: MyPy Sagemaker (#6243)
This commit is contained in:
parent
462ff57900
commit
ce3234a6a9
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from moto.core.exceptions import RESTError, JsonRESTError, AWSError
|
from moto.core.exceptions import RESTError, JsonRESTError, AWSError
|
||||||
|
|
||||||
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
||||||
@ -6,14 +7,14 @@ ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
|||||||
|
|
||||||
|
|
||||||
class SagemakerClientError(RESTError):
|
class SagemakerClientError(RESTError):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
kwargs.setdefault("template", "single_error")
|
kwargs.setdefault("template", "single_error")
|
||||||
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ModelError(RESTError):
|
class ModelError(RESTError):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
kwargs.setdefault("template", "model_error")
|
kwargs.setdefault("template", "model_error")
|
||||||
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -22,13 +23,13 @@ class ModelError(RESTError):
|
|||||||
class MissingModel(ModelError):
|
class MissingModel(ModelError):
|
||||||
code = 404
|
code = 404
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, model: str):
|
||||||
super().__init__("NoSuchModel", "Could not find model", *args, **kwargs)
|
super().__init__("NoSuchModel", "Could not find model", model=model)
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(JsonRESTError):
|
class ValidationError(JsonRESTError):
|
||||||
def __init__(self, message, **kwargs):
|
def __init__(self, message: str):
|
||||||
super().__init__("ValidationException", message, **kwargs)
|
super().__init__("ValidationException", message)
|
||||||
|
|
||||||
|
|
||||||
class AWSValidationException(AWSError):
|
class AWSValidationException(AWSError):
|
||||||
@ -36,5 +37,5 @@ class AWSValidationException(AWSError):
|
|||||||
|
|
||||||
|
|
||||||
class ResourceNotFound(JsonRESTError):
|
class ResourceNotFound(JsonRESTError):
|
||||||
def __init__(self, message, **kwargs):
|
def __init__(self, message: str):
|
||||||
super().__init__(__class__.__name__, message, **kwargs)
|
super().__init__(__class__.__name__, message) # type: ignore
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,54 +1,59 @@
|
|||||||
import json
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from moto.sagemaker.exceptions import AWSValidationException
|
from moto.sagemaker.exceptions import AWSValidationException
|
||||||
|
|
||||||
|
from moto.core.common_types import TYPE_RESPONSE
|
||||||
from moto.core.responses import BaseResponse
|
from moto.core.responses import BaseResponse
|
||||||
from moto.utilities.aws_headers import amzn_request_id
|
from moto.utilities.aws_headers import amzn_request_id
|
||||||
from .models import sagemaker_backends
|
from .models import sagemaker_backends, SageMakerModelBackend
|
||||||
|
|
||||||
|
|
||||||
def format_enum_error(value, attribute, allowed):
|
def format_enum_error(value: str, attribute: str, allowed: Any) -> str:
|
||||||
return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}"
|
return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}"
|
||||||
|
|
||||||
|
|
||||||
class SageMakerResponse(BaseResponse):
|
class SageMakerResponse(BaseResponse):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(service_name="sagemaker")
|
super().__init__(service_name="sagemaker")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sagemaker_backend(self):
|
def sagemaker_backend(self) -> SageMakerModelBackend:
|
||||||
return sagemaker_backends[self.current_account][self.region]
|
return sagemaker_backends[self.current_account][self.region]
|
||||||
|
|
||||||
@property
|
def describe_model(self) -> str:
|
||||||
def request_params(self):
|
|
||||||
try:
|
|
||||||
return json.loads(self.body)
|
|
||||||
except ValueError:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def describe_model(self):
|
|
||||||
model_name = self._get_param("ModelName")
|
model_name = self._get_param("ModelName")
|
||||||
model = self.sagemaker_backend.describe_model(model_name)
|
model = self.sagemaker_backend.describe_model(model_name)
|
||||||
return json.dumps(model.response_object)
|
return json.dumps(model.response_object)
|
||||||
|
|
||||||
def create_model(self):
|
def create_model(self) -> str:
|
||||||
model = self.sagemaker_backend.create_model(**self.request_params)
|
model_name = self._get_param("ModelName")
|
||||||
|
execution_role_arn = self._get_param("ExecutionRoleArn")
|
||||||
|
primary_container = self._get_param("PrimaryContainer")
|
||||||
|
vpc_config = self._get_param("VpcConfig")
|
||||||
|
containers = self._get_param("Containers")
|
||||||
|
tags = self._get_param("Tags")
|
||||||
|
model = self.sagemaker_backend.create_model(
|
||||||
|
model_name=model_name,
|
||||||
|
execution_role_arn=execution_role_arn,
|
||||||
|
primary_container=primary_container,
|
||||||
|
vpc_config=vpc_config,
|
||||||
|
containers=containers,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
return json.dumps(model.response_create)
|
return json.dumps(model.response_create)
|
||||||
|
|
||||||
def delete_model(self):
|
def delete_model(self) -> str:
|
||||||
model_name = self._get_param("ModelName")
|
model_name = self._get_param("ModelName")
|
||||||
response = self.sagemaker_backend.delete_model(model_name)
|
self.sagemaker_backend.delete_model(model_name)
|
||||||
return json.dumps(response)
|
return "{}"
|
||||||
|
|
||||||
def list_models(self):
|
def list_models(self) -> str:
|
||||||
models = self.sagemaker_backend.list_models(**self.request_params)
|
models = self.sagemaker_backend.list_models()
|
||||||
return json.dumps({"Models": [model.response_object for model in models]})
|
return json.dumps({"Models": [model.response_object for model in models]})
|
||||||
|
|
||||||
def _get_param(self, param_name, if_none=None):
|
|
||||||
return self.request_params.get(param_name, if_none)
|
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_notebook_instance(self):
|
def create_notebook_instance(self) -> TYPE_RESPONSE:
|
||||||
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
|
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
|
||||||
notebook_instance_name=self._get_param("NotebookInstanceName"),
|
notebook_instance_name=self._get_param("NotebookInstanceName"),
|
||||||
instance_type=self._get_param("InstanceType"),
|
instance_type=self._get_param("InstanceType"),
|
||||||
@ -65,13 +70,10 @@ class SageMakerResponse(BaseResponse):
|
|||||||
additional_code_repositories=self._get_param("AdditionalCodeRepositories"),
|
additional_code_repositories=self._get_param("AdditionalCodeRepositories"),
|
||||||
root_access=self._get_param("RootAccess"),
|
root_access=self._get_param("RootAccess"),
|
||||||
)
|
)
|
||||||
response = {
|
return 200, {}, json.dumps({"NotebookInstanceArn": sagemaker_notebook.arn})
|
||||||
"NotebookInstanceArn": sagemaker_notebook.arn,
|
|
||||||
}
|
|
||||||
return 200, {}, json.dumps(response)
|
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_notebook_instance(self):
|
def describe_notebook_instance(self) -> TYPE_RESPONSE:
|
||||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||||
notebook_instance = self.sagemaker_backend.get_notebook_instance(
|
notebook_instance = self.sagemaker_backend.get_notebook_instance(
|
||||||
notebook_instance_name
|
notebook_instance_name
|
||||||
@ -100,25 +102,25 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def start_notebook_instance(self):
|
def start_notebook_instance(self) -> TYPE_RESPONSE:
|
||||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||||
self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
|
self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def stop_notebook_instance(self):
|
def stop_notebook_instance(self) -> TYPE_RESPONSE:
|
||||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||||
self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
|
self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_notebook_instance(self):
|
def delete_notebook_instance(self) -> TYPE_RESPONSE:
|
||||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||||
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
|
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_tags(self):
|
def list_tags(self) -> TYPE_RESPONSE:
|
||||||
arn = self._get_param("ResourceArn")
|
arn = self._get_param("ResourceArn")
|
||||||
max_results = self._get_param("MaxResults")
|
max_results = self._get_param("MaxResults")
|
||||||
next_token = self._get_param("NextToken")
|
next_token = self._get_param("NextToken")
|
||||||
@ -131,22 +133,21 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def add_tags(self):
|
def add_tags(self) -> TYPE_RESPONSE:
|
||||||
arn = self._get_param("ResourceArn")
|
arn = self._get_param("ResourceArn")
|
||||||
tags = self._get_param("Tags")
|
tags = self._get_param("Tags")
|
||||||
tags = self.sagemaker_backend.add_tags(arn, tags)
|
tags = self.sagemaker_backend.add_tags(arn, tags)
|
||||||
response = {"Tags": tags}
|
return 200, {}, json.dumps({"Tags": tags})
|
||||||
return 200, {}, json.dumps(response)
|
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_tags(self):
|
def delete_tags(self) -> TYPE_RESPONSE:
|
||||||
arn = self._get_param("ResourceArn")
|
arn = self._get_param("ResourceArn")
|
||||||
tag_keys = self._get_param("TagKeys")
|
tag_keys = self._get_param("TagKeys")
|
||||||
self.sagemaker_backend.delete_tags(arn, tag_keys)
|
self.sagemaker_backend.delete_tags(arn, tag_keys)
|
||||||
return 200, {}, json.dumps({})
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_endpoint_config(self):
|
def create_endpoint_config(self) -> TYPE_RESPONSE:
|
||||||
endpoint_config = self.sagemaker_backend.create_endpoint_config(
|
endpoint_config = self.sagemaker_backend.create_endpoint_config(
|
||||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||||
production_variants=self._get_param("ProductionVariants"),
|
production_variants=self._get_param("ProductionVariants"),
|
||||||
@ -154,49 +155,47 @@ class SageMakerResponse(BaseResponse):
|
|||||||
tags=self._get_param("Tags"),
|
tags=self._get_param("Tags"),
|
||||||
kms_key_id=self._get_param("KmsKeyId"),
|
kms_key_id=self._get_param("KmsKeyId"),
|
||||||
)
|
)
|
||||||
response = {
|
return (
|
||||||
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
|
200,
|
||||||
}
|
{},
|
||||||
return 200, {}, json.dumps(response)
|
json.dumps({"EndpointConfigArn": endpoint_config.endpoint_config_arn}),
|
||||||
|
)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_endpoint_config(self):
|
def describe_endpoint_config(self) -> str:
|
||||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||||
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
|
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_endpoint_config(self):
|
def delete_endpoint_config(self) -> TYPE_RESPONSE:
|
||||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||||
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
|
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_endpoint(self):
|
def create_endpoint(self) -> TYPE_RESPONSE:
|
||||||
endpoint = self.sagemaker_backend.create_endpoint(
|
endpoint = self.sagemaker_backend.create_endpoint(
|
||||||
endpoint_name=self._get_param("EndpointName"),
|
endpoint_name=self._get_param("EndpointName"),
|
||||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||||
tags=self._get_param("Tags"),
|
tags=self._get_param("Tags"),
|
||||||
)
|
)
|
||||||
response = {
|
return 200, {}, json.dumps({"EndpointArn": endpoint.endpoint_arn})
|
||||||
"EndpointArn": endpoint.endpoint_arn,
|
|
||||||
}
|
|
||||||
return 200, {}, json.dumps(response)
|
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_endpoint(self):
|
def describe_endpoint(self) -> str:
|
||||||
endpoint_name = self._get_param("EndpointName")
|
endpoint_name = self._get_param("EndpointName")
|
||||||
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
|
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_endpoint(self):
|
def delete_endpoint(self) -> TYPE_RESPONSE:
|
||||||
endpoint_name = self._get_param("EndpointName")
|
endpoint_name = self._get_param("EndpointName")
|
||||||
self.sagemaker_backend.delete_endpoint(endpoint_name)
|
self.sagemaker_backend.delete_endpoint(endpoint_name)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_processing_job(self):
|
def create_processing_job(self) -> TYPE_RESPONSE:
|
||||||
processing_job = self.sagemaker_backend.create_processing_job(
|
processing_job = self.sagemaker_backend.create_processing_job(
|
||||||
app_specification=self._get_param("AppSpecification"),
|
app_specification=self._get_param("AppSpecification"),
|
||||||
experiment_config=self._get_param("ExperimentConfig"),
|
experiment_config=self._get_param("ExperimentConfig"),
|
||||||
@ -208,19 +207,17 @@ class SageMakerResponse(BaseResponse):
|
|||||||
stopping_condition=self._get_param("StoppingCondition"),
|
stopping_condition=self._get_param("StoppingCondition"),
|
||||||
tags=self._get_param("Tags"),
|
tags=self._get_param("Tags"),
|
||||||
)
|
)
|
||||||
response = {
|
response = {"ProcessingJobArn": processing_job.processing_job_arn}
|
||||||
"ProcessingJobArn": processing_job.processing_job_arn,
|
|
||||||
}
|
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_processing_job(self):
|
def describe_processing_job(self) -> str:
|
||||||
processing_job_name = self._get_param("ProcessingJobName")
|
processing_job_name = self._get_param("ProcessingJobName")
|
||||||
response = self.sagemaker_backend.describe_processing_job(processing_job_name)
|
response = self.sagemaker_backend.describe_processing_job(processing_job_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_training_job(self):
|
def create_training_job(self) -> TYPE_RESPONSE:
|
||||||
training_job = self.sagemaker_backend.create_training_job(
|
training_job = self.sagemaker_backend.create_training_job(
|
||||||
training_job_name=self._get_param("TrainingJobName"),
|
training_job_name=self._get_param("TrainingJobName"),
|
||||||
hyper_parameters=self._get_param("HyperParameters"),
|
hyper_parameters=self._get_param("HyperParameters"),
|
||||||
@ -251,13 +248,13 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_training_job(self):
|
def describe_training_job(self) -> str:
|
||||||
training_job_name = self._get_param("TrainingJobName")
|
training_job_name = self._get_param("TrainingJobName")
|
||||||
response = self.sagemaker_backend.describe_training_job(training_job_name)
|
response = self.sagemaker_backend.describe_training_job(training_job_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_notebook_instance_lifecycle_config(self):
|
def create_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE:
|
||||||
lifecycle_configuration = (
|
lifecycle_configuration = (
|
||||||
self.sagemaker_backend.create_notebook_instance_lifecycle_config(
|
self.sagemaker_backend.create_notebook_instance_lifecycle_config(
|
||||||
notebook_instance_lifecycle_config_name=self._get_param(
|
notebook_instance_lifecycle_config_name=self._get_param(
|
||||||
@ -273,7 +270,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_notebook_instance_lifecycle_config(self):
|
def describe_notebook_instance_lifecycle_config(self) -> str:
|
||||||
response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config(
|
response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config(
|
||||||
notebook_instance_lifecycle_config_name=self._get_param(
|
notebook_instance_lifecycle_config_name=self._get_param(
|
||||||
"NotebookInstanceLifecycleConfigName"
|
"NotebookInstanceLifecycleConfigName"
|
||||||
@ -282,7 +279,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_notebook_instance_lifecycle_config(self):
|
def delete_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE:
|
||||||
self.sagemaker_backend.delete_notebook_instance_lifecycle_config(
|
self.sagemaker_backend.delete_notebook_instance_lifecycle_config(
|
||||||
notebook_instance_lifecycle_config_name=self._get_param(
|
notebook_instance_lifecycle_config_name=self._get_param(
|
||||||
"NotebookInstanceLifecycleConfigName"
|
"NotebookInstanceLifecycleConfigName"
|
||||||
@ -291,7 +288,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def search(self):
|
def search(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.search(
|
response = self.sagemaker_backend.search(
|
||||||
resource=self._get_param("Resource"),
|
resource=self._get_param("Resource"),
|
||||||
search_expression=self._get_param("SearchExpression"),
|
search_expression=self._get_param("SearchExpression"),
|
||||||
@ -299,7 +296,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_experiments(self):
|
def list_experiments(self) -> TYPE_RESPONSE:
|
||||||
MaxResults = self._get_param("MaxResults")
|
MaxResults = self._get_param("MaxResults")
|
||||||
NextToken = self._get_param("NextToken")
|
NextToken = self._get_param("NextToken")
|
||||||
|
|
||||||
@ -327,28 +324,28 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_experiment(self):
|
def delete_experiment(self) -> TYPE_RESPONSE:
|
||||||
self.sagemaker_backend.delete_experiment(
|
self.sagemaker_backend.delete_experiment(
|
||||||
experiment_name=self._get_param("ExperimentName")
|
experiment_name=self._get_param("ExperimentName")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps({})
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_experiment(self):
|
def create_experiment(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.create_experiment(
|
response = self.sagemaker_backend.create_experiment(
|
||||||
experiment_name=self._get_param("ExperimentName")
|
experiment_name=self._get_param("ExperimentName")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_experiment(self):
|
def describe_experiment(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.describe_experiment(
|
response = self.sagemaker_backend.describe_experiment(
|
||||||
experiment_name=self._get_param("ExperimentName")
|
experiment_name=self._get_param("ExperimentName")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_trials(self):
|
def list_trials(self) -> TYPE_RESPONSE:
|
||||||
MaxResults = self._get_param("MaxResults")
|
MaxResults = self._get_param("MaxResults")
|
||||||
NextToken = self._get_param("NextToken")
|
NextToken = self._get_param("NextToken")
|
||||||
|
|
||||||
@ -379,7 +376,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_trial(self):
|
def create_trial(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.create_trial(
|
response = self.sagemaker_backend.create_trial(
|
||||||
trial_name=self._get_param("TrialName"),
|
trial_name=self._get_param("TrialName"),
|
||||||
experiment_name=self._get_param("ExperimentName"),
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
@ -387,7 +384,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_trial_components(self):
|
def list_trial_components(self) -> TYPE_RESPONSE:
|
||||||
MaxResults = self._get_param("MaxResults")
|
MaxResults = self._get_param("MaxResults")
|
||||||
NextToken = self._get_param("NextToken")
|
NextToken = self._get_param("NextToken")
|
||||||
|
|
||||||
@ -417,7 +414,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_trial_component(self):
|
def create_trial_component(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.create_trial_component(
|
response = self.sagemaker_backend.create_trial_component(
|
||||||
trial_component_name=self._get_param("TrialComponentName"),
|
trial_component_name=self._get_param("TrialComponentName"),
|
||||||
trial_name=self._get_param("TrialName"),
|
trial_name=self._get_param("TrialName"),
|
||||||
@ -425,55 +422,56 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_trial(self):
|
def describe_trial(self) -> str:
|
||||||
trial_name = self._get_param("TrialName")
|
trial_name = self._get_param("TrialName")
|
||||||
response = self.sagemaker_backend.describe_trial(trial_name)
|
response = self.sagemaker_backend.describe_trial(trial_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_trial(self):
|
def delete_trial(self) -> TYPE_RESPONSE:
|
||||||
trial_name = self._get_param("TrialName")
|
trial_name = self._get_param("TrialName")
|
||||||
self.sagemaker_backend.delete_trial(trial_name)
|
self.sagemaker_backend.delete_trial(trial_name)
|
||||||
return 200, {}, json.dumps({})
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_trial_component(self):
|
def delete_trial_component(self) -> TYPE_RESPONSE:
|
||||||
trial_component_name = self._get_param("TrialComponentName")
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
self.sagemaker_backend.delete_trial_component(trial_component_name)
|
self.sagemaker_backend.delete_trial_component(trial_component_name)
|
||||||
return 200, {}, json.dumps({})
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_trial_component(self):
|
def describe_trial_component(self) -> str:
|
||||||
trial_component_name = self._get_param("TrialComponentName")
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
response = self.sagemaker_backend.describe_trial_component(trial_component_name)
|
response = self.sagemaker_backend.describe_trial_component(trial_component_name)
|
||||||
return json.dumps(response)
|
return json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def associate_trial_component(self):
|
def associate_trial_component(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.associate_trial_component(self.request_params)
|
trial_name = self._get_param("TrialName")
|
||||||
return 200, {}, json.dumps(response)
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
|
response = self.sagemaker_backend.associate_trial_component(
|
||||||
@amzn_request_id
|
trial_name, trial_component_name
|
||||||
def disassociate_trial_component(self):
|
|
||||||
response = self.sagemaker_backend.disassociate_trial_component(
|
|
||||||
self.request_params
|
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_associations(self, *args, **kwargs): # pylint: disable=unused-argument
|
def disassociate_trial_component(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.list_associations(self.request_params)
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
|
trial_name = self._get_param("TrialName")
|
||||||
|
response = self.sagemaker_backend.disassociate_trial_component(
|
||||||
|
trial_name, trial_component_name
|
||||||
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_pipeline(self):
|
def describe_pipeline(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.describe_pipeline(
|
response = self.sagemaker_backend.describe_pipeline(
|
||||||
self._get_param("PipelineName")
|
self._get_param("PipelineName")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def start_pipeline_execution(self):
|
def start_pipeline_execution(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.start_pipeline_execution(
|
response = self.sagemaker_backend.start_pipeline_execution(
|
||||||
self._get_param("PipelineName"),
|
self._get_param("PipelineName"),
|
||||||
self._get_param("PipelineExecutionDisplayName"),
|
self._get_param("PipelineExecutionDisplayName"),
|
||||||
@ -485,35 +483,35 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_pipeline_execution(self):
|
def describe_pipeline_execution(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.describe_pipeline_execution(
|
response = self.sagemaker_backend.describe_pipeline_execution(
|
||||||
self._get_param("PipelineExecutionArn")
|
self._get_param("PipelineExecutionArn")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_pipeline_definition_for_execution(self):
|
def describe_pipeline_definition_for_execution(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.describe_pipeline_definition_for_execution(
|
response = self.sagemaker_backend.describe_pipeline_definition_for_execution(
|
||||||
self._get_param("PipelineExecutionArn")
|
self._get_param("PipelineExecutionArn")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_pipeline_parameters_for_execution(self):
|
def list_pipeline_parameters_for_execution(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.list_pipeline_parameters_for_execution(
|
response = self.sagemaker_backend.list_pipeline_parameters_for_execution(
|
||||||
self._get_param("PipelineExecutionArn")
|
self._get_param("PipelineExecutionArn")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_pipeline_executions(self):
|
def list_pipeline_executions(self) -> TYPE_RESPONSE:
|
||||||
response = self.sagemaker_backend.list_pipeline_executions(
|
response = self.sagemaker_backend.list_pipeline_executions(
|
||||||
self._get_param("PipelineName")
|
self._get_param("PipelineName")
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_pipeline(self):
|
def create_pipeline(self) -> TYPE_RESPONSE:
|
||||||
pipeline = self.sagemaker_backend.create_pipeline(
|
pipeline = self.sagemaker_backend.create_pipeline(
|
||||||
pipeline_name=self._get_param("PipelineName"),
|
pipeline_name=self._get_param("PipelineName"),
|
||||||
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
||||||
@ -533,7 +531,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def delete_pipeline(self):
|
def delete_pipeline(self) -> TYPE_RESPONSE:
|
||||||
pipeline_arn = self.sagemaker_backend.delete_pipeline(
|
pipeline_arn = self.sagemaker_backend.delete_pipeline(
|
||||||
pipeline_name=self._get_param("PipelineName"),
|
pipeline_name=self._get_param("PipelineName"),
|
||||||
)
|
)
|
||||||
@ -541,7 +539,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def update_pipeline(self):
|
def update_pipeline(self) -> TYPE_RESPONSE:
|
||||||
pipeline_arn = self.sagemaker_backend.update_pipeline(
|
pipeline_arn = self.sagemaker_backend.update_pipeline(
|
||||||
pipeline_name=self._get_param("PipelineName"),
|
pipeline_name=self._get_param("PipelineName"),
|
||||||
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
||||||
@ -558,7 +556,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_pipelines(self):
|
def list_pipelines(self) -> TYPE_RESPONSE:
|
||||||
max_results_range = range(1, 101)
|
max_results_range = range(1, 101)
|
||||||
allowed_sort_by = ("Name", "CreationTime")
|
allowed_sort_by = ("Name", "CreationTime")
|
||||||
allowed_sort_order = ("Ascending", "Descending")
|
allowed_sort_order = ("Ascending", "Descending")
|
||||||
@ -601,7 +599,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_processing_jobs(self):
|
def list_processing_jobs(self) -> TYPE_RESPONSE:
|
||||||
max_results_range = range(1, 101)
|
max_results_range = range(1, 101)
|
||||||
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
||||||
allowed_sort_order = ["Ascending", "Descending"]
|
allowed_sort_order = ["Ascending", "Descending"]
|
||||||
@ -654,7 +652,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_training_jobs(self):
|
def list_training_jobs(self) -> TYPE_RESPONSE:
|
||||||
max_results_range = range(1, 101)
|
max_results_range = range(1, 101)
|
||||||
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
||||||
allowed_sort_order = ["Ascending", "Descending"]
|
allowed_sort_order = ["Ascending", "Descending"]
|
||||||
@ -706,12 +704,11 @@ class SageMakerResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
def update_endpoint_weights_and_capacities(self):
|
def update_endpoint_weights_and_capacities(self) -> TYPE_RESPONSE:
|
||||||
endpoint_name = self._get_param("EndpointName")
|
endpoint_name = self._get_param("EndpointName")
|
||||||
desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities")
|
desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities")
|
||||||
endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities(
|
endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities(
|
||||||
endpoint_name=endpoint_name,
|
endpoint_name=endpoint_name,
|
||||||
desired_weights_and_capacities=desired_weights_and_capacities,
|
desired_weights_and_capacities=desired_weights_and_capacities,
|
||||||
)
|
)
|
||||||
response = {"EndpointArn": endpoint_arn}
|
return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
|
||||||
return 200, {}, json.dumps(response)
|
|
||||||
|
@ -1,35 +1,46 @@
|
|||||||
|
import typing
|
||||||
|
|
||||||
from moto.s3.models import s3_backends
|
from moto.s3.models import s3_backends
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict
|
||||||
from .exceptions import ValidationError
|
from .exceptions import ValidationError
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_from_name(pipelines, pipeline_name):
|
if typing.TYPE_CHECKING:
|
||||||
|
from .models import FakePipeline, FakePipelineExecution
|
||||||
|
|
||||||
|
|
||||||
|
def get_pipeline_from_name(
|
||||||
|
pipelines: Dict[str, "FakePipeline"], pipeline_name: str
|
||||||
|
) -> "FakePipeline":
|
||||||
try:
|
try:
|
||||||
pipeline = pipelines[pipeline_name]
|
return pipelines[pipeline_name]
|
||||||
return pipeline
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
message=f"Could not find pipeline with PipelineName {pipeline_name}."
|
message=f"Could not find pipeline with PipelineName {pipeline_name}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_name_from_execution_arn(pipeline_execution_arn):
|
def get_pipeline_name_from_execution_arn(pipeline_execution_arn: str) -> str:
|
||||||
return pipeline_execution_arn.split("/")[1].split(":")[-1]
|
return pipeline_execution_arn.split("/")[1].split(":")[-1]
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_execution_from_arn(pipelines, pipeline_execution_arn):
|
def get_pipeline_execution_from_arn(
|
||||||
|
pipelines: Dict[str, "FakePipeline"], pipeline_execution_arn: str
|
||||||
|
) -> "FakePipelineExecution":
|
||||||
try:
|
try:
|
||||||
pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
|
pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
|
||||||
pipeline = get_pipeline_from_name(pipelines, pipeline_name)
|
pipeline = get_pipeline_from_name(pipelines, pipeline_name)
|
||||||
pipeline_execution = pipeline.pipeline_executions[pipeline_execution_arn]
|
return pipeline.pipeline_executions[pipeline_execution_arn]
|
||||||
return pipeline_execution
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}."
|
message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id):
|
def load_pipeline_definition_from_s3(
|
||||||
|
pipeline_definition_s3_location: Dict[str, Any], account_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
s3_backend = s3_backends[account_id]["global"]
|
s3_backend = s3_backends[account_id]["global"]
|
||||||
result = s3_backend.get_object(
|
result = s3_backend.get_object(
|
||||||
bucket_name=pipeline_definition_s3_location["Bucket"],
|
bucket_name=pipeline_definition_s3_location["Bucket"],
|
||||||
@ -38,5 +49,5 @@ def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id
|
|||||||
return json.loads(result.value)
|
return json.loads(result.value)
|
||||||
|
|
||||||
|
|
||||||
def arn_formatter(_type, _id, account_id, region_name):
|
def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
|
||||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
def is_integer_between(x, mn=None, mx=None, optional=False):
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def is_integer_between(
|
||||||
|
x: int, mn: Optional[int] = None, mx: Optional[int] = None, optional: bool = False
|
||||||
|
) -> bool:
|
||||||
if optional and x is None:
|
if optional and x is None:
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
@ -14,7 +19,7 @@ def is_integer_between(x, mn=None, mx=None, optional=False):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_one_of(x, choices, optional=False):
|
def is_one_of(x: Any, choices: Any, optional: bool = False) -> bool:
|
||||||
if optional and x is None:
|
if optional and x is None:
|
||||||
return True
|
return True
|
||||||
return x in choices
|
return x in choices
|
||||||
|
@ -239,7 +239,7 @@ disable = W,C,R,E
|
|||||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
||||||
|
|
||||||
[mypy]
|
[mypy]
|
||||||
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/scheduler
|
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/scheduler
|
||||||
show_column_numbers=True
|
show_column_numbers=True
|
||||||
show_error_codes = True
|
show_error_codes = True
|
||||||
disable_error_code=abstract
|
disable_error_code=abstract
|
||||||
|
Loading…
Reference in New Issue
Block a user