moto/moto/sagemaker/responses.py
Zach Churchill f6dda54a6c
Add CloudFormation support for SageMaker Models (#3861)
* Create a formal interface for SM Cloudformation test configurations

* Create SageMaker Models with CloudFormation

* Utilize six for adding metaclass to TestConfig

* Update SM backend to return Model objects instead of response objects
2021-04-16 15:23:05 +01:00

277 lines
12 KiB
Python

from __future__ import unicode_literals
import json
from moto.core.exceptions import AWSError
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .models import sagemaker_backends
class SageMakerResponse(BaseResponse):
    """Translate SageMaker JSON API requests into calls on the regional backend.

    Request parameters are read from the JSON request body; each handler
    delegates to ``sagemaker_backends[self.region]`` and serializes the
    result back to JSON.
    """

    @property
    def sagemaker_backend(self):
        """The SageMaker backend for the region this request targets."""
        return sagemaker_backends[self.region]

    @property
    def request_params(self):
        """The request body decoded as JSON, or ``{}`` if it is not valid JSON.

        NOTE(review): the body is re-decoded on every access, and every
        ``_get_param`` call goes through this property.
        """
        try:
            return json.loads(self.body)
        except ValueError:
            return {}

    def describe_model(self):
        """Handle DescribeModel: return the named model's description as JSON."""
        model_name = self._get_param("ModelName")
        model = self.sagemaker_backend.describe_model(model_name)
        return json.dumps(model.response_object)

    def create_model(self):
        """Handle CreateModel, forwarding every request parameter to the backend."""
        model = self.sagemaker_backend.create_model(**self.request_params)
        return json.dumps(model.response_create)

    def delete_model(self):
        """Handle DeleteModel and return the backend's response serialized as JSON."""
        model_name = self._get_param("ModelName")
        response = self.sagemaker_backend.delete_model(model_name)
        return json.dumps(response)

    def list_models(self):
        """Handle ListModels, forwarding every request parameter to the backend."""
        models = self.sagemaker_backend.list_models(**self.request_params)
        return json.dumps({"Models": [model.response_object for model in models]})

    def _get_param(self, param, if_none=None):
        """Return top-level key ``param`` of the JSON body, or ``if_none`` if absent."""
        return self.request_params.get(param, if_none)

    @amzn_request_id
    def create_notebook_instance(self):
        """Handle CreateNotebookInstance and return the new instance's ARN.

        A backend ``AWSError`` is converted into an error response via
        ``err.response()``.
        """
        try:
            sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
                notebook_instance_name=self._get_param("NotebookInstanceName"),
                instance_type=self._get_param("InstanceType"),
                subnet_id=self._get_param("SubnetId"),
                security_group_ids=self._get_param("SecurityGroupIds"),
                role_arn=self._get_param("RoleArn"),
                kms_key_id=self._get_param("KmsKeyId"),
                tags=self._get_param("Tags"),
                lifecycle_config_name=self._get_param("LifecycleConfigName"),
                direct_internet_access=self._get_param("DirectInternetAccess"),
                volume_size_in_gb=self._get_param("VolumeSizeInGB"),
                accelerator_types=self._get_param("AcceleratorTypes"),
                default_code_repository=self._get_param("DefaultCodeRepository"),
                additional_code_repositories=self._get_param(
                    "AdditionalCodeRepositories"
                ),
                root_access=self._get_param("RootAccess"),
            )
            response = {
                "NotebookInstanceArn": sagemaker_notebook.arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_notebook_instance(self):
        """Handle DescribeNotebookInstance, mapping instance attributes to API keys.

        A backend ``AWSError`` is converted into an error response via
        ``err.response()``.
        """
        notebook_instance_name = self._get_param("NotebookInstanceName")
        try:
            notebook_instance = self.sagemaker_backend.get_notebook_instance(
                notebook_instance_name
            )
            response = {
                "NotebookInstanceArn": notebook_instance.arn,
                "NotebookInstanceName": notebook_instance.notebook_instance_name,
                "NotebookInstanceStatus": notebook_instance.status,
                "Url": notebook_instance.url,
                "InstanceType": notebook_instance.instance_type,
                "SubnetId": notebook_instance.subnet_id,
                "SecurityGroups": notebook_instance.security_group_ids,
                "RoleArn": notebook_instance.role_arn,
                "KmsKeyId": notebook_instance.kms_key_id,
                # TODO: NetworkInterfaceId is not reported yet.
                # Timestamps are emitted as their str() form.
                "LastModifiedTime": str(notebook_instance.last_modified_time),
                "CreationTime": str(notebook_instance.creation_time),
                "NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
                "DirectInternetAccess": notebook_instance.direct_internet_access,
                "VolumeSizeInGB": notebook_instance.volume_size_in_gb,
                "AcceleratorTypes": notebook_instance.accelerator_types,
                "DefaultCodeRepository": notebook_instance.default_code_repository,
                "AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
                "RootAccess": notebook_instance.root_access,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def start_notebook_instance(self):
        """Handle StartNotebookInstance; the success body is an empty JSON object."""
        notebook_instance_name = self._get_param("NotebookInstanceName")
        self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
        # json.dumps({}) yields the object '{}'; json.dumps("{}") would yield
        # the JSON string '"{}"' instead.
        return 200, {}, json.dumps({})

    @amzn_request_id
    def stop_notebook_instance(self):
        """Handle StopNotebookInstance; the success body is an empty JSON object."""
        notebook_instance_name = self._get_param("NotebookInstanceName")
        self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
        return 200, {}, json.dumps({})

    @amzn_request_id
    def delete_notebook_instance(self):
        """Handle DeleteNotebookInstance; the success body is an empty JSON object."""
        notebook_instance_name = self._get_param("NotebookInstanceName")
        self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
        return 200, {}, json.dumps({})

    @amzn_request_id
    def list_tags(self):
        """Handle ListTags by dispatching on the resource type embedded in the ARN.

        ARNs of an unrecognized resource type, or a backend ``AWSError``,
        produce an empty tag list rather than an error.
        """
        arn = self._get_param("ResourceArn")
        try:
            if ":notebook-instance/" in arn:
                tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
            elif ":endpoint/" in arn:
                tags = self.sagemaker_backend.get_endpoint_tags(arn)
            elif ":training-job/" in arn:
                tags = self.sagemaker_backend.get_training_job_tags(arn)
            else:
                tags = []
        except AWSError:
            tags = []
        response = {"Tags": tags}
        return 200, {}, json.dumps(response)

    @amzn_request_id
    def create_endpoint_config(self):
        """Handle CreateEndpointConfig and return the new config's ARN.

        A backend ``AWSError`` is converted into an error response via
        ``err.response()``.
        """
        try:
            endpoint_config = self.sagemaker_backend.create_endpoint_config(
                endpoint_config_name=self._get_param("EndpointConfigName"),
                production_variants=self._get_param("ProductionVariants"),
                data_capture_config=self._get_param("DataCaptureConfig"),
                tags=self._get_param("Tags"),
                kms_key_id=self._get_param("KmsKeyId"),
            )
            response = {
                "EndpointConfigArn": endpoint_config.endpoint_config_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_endpoint_config(self):
        """Handle DescribeEndpointConfig; backend exceptions propagate unchanged."""
        endpoint_config_name = self._get_param("EndpointConfigName")
        response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_endpoint_config(self):
        """Handle DeleteEndpointConfig; the success body is an empty JSON object."""
        endpoint_config_name = self._get_param("EndpointConfigName")
        self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
        return 200, {}, json.dumps({})

    @amzn_request_id
    def create_endpoint(self):
        """Handle CreateEndpoint and return the new endpoint's ARN.

        A backend ``AWSError`` is converted into an error response via
        ``err.response()``.
        """
        try:
            endpoint = self.sagemaker_backend.create_endpoint(
                endpoint_name=self._get_param("EndpointName"),
                endpoint_config_name=self._get_param("EndpointConfigName"),
                tags=self._get_param("Tags"),
            )
            response = {
                "EndpointArn": endpoint.endpoint_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_endpoint(self):
        """Handle DescribeEndpoint; backend exceptions propagate unchanged."""
        endpoint_name = self._get_param("EndpointName")
        response = self.sagemaker_backend.describe_endpoint(endpoint_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_endpoint(self):
        """Handle DeleteEndpoint; the success body is an empty JSON object."""
        endpoint_name = self._get_param("EndpointName")
        self.sagemaker_backend.delete_endpoint(endpoint_name)
        return 200, {}, json.dumps({})

    @amzn_request_id
    def create_training_job(self):
        """Handle CreateTrainingJob and return the new job's ARN.

        The three boolean flags default to ``False`` when absent. A backend
        ``AWSError`` is converted into an error response via ``err.response()``.
        """
        try:
            training_job = self.sagemaker_backend.create_training_job(
                training_job_name=self._get_param("TrainingJobName"),
                hyper_parameters=self._get_param("HyperParameters"),
                algorithm_specification=self._get_param("AlgorithmSpecification"),
                role_arn=self._get_param("RoleArn"),
                input_data_config=self._get_param("InputDataConfig"),
                output_data_config=self._get_param("OutputDataConfig"),
                resource_config=self._get_param("ResourceConfig"),
                vpc_config=self._get_param("VpcConfig"),
                stopping_condition=self._get_param("StoppingCondition"),
                tags=self._get_param("Tags"),
                enable_network_isolation=self._get_param(
                    "EnableNetworkIsolation", False
                ),
                enable_inter_container_traffic_encryption=self._get_param(
                    "EnableInterContainerTrafficEncryption", False
                ),
                enable_managed_spot_training=self._get_param(
                    "EnableManagedSpotTraining", False
                ),
                checkpoint_config=self._get_param("CheckpointConfig"),
                debug_hook_config=self._get_param("DebugHookConfig"),
                debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
                tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
                experiment_config=self._get_param("ExperimentConfig"),
            )
            response = {
                "TrainingJobArn": training_job.training_job_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_training_job(self):
        """Handle DescribeTrainingJob; backend exceptions propagate unchanged."""
        training_job_name = self._get_param("TrainingJobName")
        response = self.sagemaker_backend.describe_training_job(training_job_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_training_job(self):
        """Handle DeleteTrainingJob; the success body is an empty JSON object."""
        training_job_name = self._get_param("TrainingJobName")
        self.sagemaker_backend.delete_training_job(training_job_name)
        return 200, {}, json.dumps({})

    @amzn_request_id
    def create_notebook_instance_lifecycle_config(self):
        """Handle CreateNotebookInstanceLifecycleConfig and return the config's ARN.

        A backend ``AWSError`` is converted into an error response via
        ``err.response()``.
        """
        try:
            lifecycle_configuration = self.sagemaker_backend.create_notebook_instance_lifecycle_config(
                notebook_instance_lifecycle_config_name=self._get_param(
                    "NotebookInstanceLifecycleConfigName"
                ),
                on_create=self._get_param("OnCreate"),
                on_start=self._get_param("OnStart"),
            )
            response = {
                "NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_notebook_instance_lifecycle_config(self):
        """Handle DescribeNotebookInstanceLifecycleConfig; backend exceptions propagate."""
        response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config(
            notebook_instance_lifecycle_config_name=self._get_param(
                "NotebookInstanceLifecycleConfigName"
            )
        )
        return json.dumps(response)

    @amzn_request_id
    def delete_notebook_instance_lifecycle_config(self):
        """Handle DeleteNotebookInstanceLifecycleConfig; success body is an empty JSON object."""
        self.sagemaker_backend.delete_notebook_instance_lifecycle_config(
            notebook_instance_lifecycle_config_name=self._get_param(
                "NotebookInstanceLifecycleConfigName"
            )
        )
        return 200, {}, json.dumps({})