Notebook Lifecycle Config create, describe and delete (#3417)
* Notebook Lifecycle Config create, describe and delete * PR3417 comment changes: raise on create with duplicate name, derive a ValidationException class and use it instead of RESTException, unit test for delete non-existing. Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
cbd4efb42d
commit
f8d2ce2e6a
@ -1,7 +1,6 @@
|
||||
from __future__ import unicode_literals
|
||||
import json
|
||||
from moto.core.exceptions import RESTError
|
||||
|
||||
from moto.core.exceptions import RESTError, JsonRESTError
|
||||
|
||||
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
||||
{% block extra %}<ModelName>{{ model }}</ModelName>{% endblock %}
|
||||
@ -45,3 +44,8 @@ class AWSError(Exception):
|
||||
json.dumps({"__type": self.type, "message": self.message}),
|
||||
dict(status=self.status),
|
||||
)
|
||||
|
||||
|
||||
class ValidationError(JsonRESTError):
|
||||
def __init__(self, message, **kwargs):
|
||||
super(ValidationError, self).__init__("ValidationException", message, **kwargs)
|
||||
|
@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.sagemaker import validators
|
||||
from .exceptions import MissingModel
|
||||
from .exceptions import MissingModel, ValidationError
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
@ -285,11 +285,7 @@ class FakeEndpointConfig(BaseObject):
|
||||
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
|
||||
instance_type, VALID_INSTANCE_TYPES
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
@ -431,11 +427,7 @@ class FakeSagemakerNotebookInstance:
|
||||
def validate_volume_size_in_gb(self, volume_size_in_gb):
|
||||
if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True):
|
||||
message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def validate_instance_type(self, instance_type):
|
||||
VALID_INSTANCE_TYPES = [
|
||||
@ -482,11 +474,7 @@ class FakeSagemakerNotebookInstance:
|
||||
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
|
||||
instance_type, VALID_INSTANCE_TYPES
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
@property
|
||||
def arn(self):
|
||||
@ -516,6 +504,46 @@ class FakeSagemakerNotebookInstance:
|
||||
self.status = "Stopped"
|
||||
|
||||
|
||||
class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject):
|
||||
def __init__(
|
||||
self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start
|
||||
):
|
||||
self.region_name = region_name
|
||||
self.notebook_instance_lifecycle_config_name = (
|
||||
notebook_instance_lifecycle_config_name
|
||||
)
|
||||
self.on_create = on_create
|
||||
self.on_start = on_start
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
self.notebook_instance_lifecycle_config_arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
|
||||
self.notebook_instance_lifecycle_config_name, self.region_name
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(notebook_instance_lifecycle_config_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":notebook-instance-lifecycle-configuration/"
|
||||
+ notebook_instance_lifecycle_config_name
|
||||
)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"TrainingJobArn": self.training_job_arn}
|
||||
|
||||
|
||||
class SageMakerModelBackend(BaseBackend):
|
||||
def __init__(self, region_name=None):
|
||||
self._models = {}
|
||||
@ -523,6 +551,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.endpoint_configs = {}
|
||||
self.endpoints = {}
|
||||
self.training_jobs = {}
|
||||
self.notebook_instance_lifecycle_configurations = {}
|
||||
self.region_name = region_name
|
||||
|
||||
def reset(self):
|
||||
@ -551,9 +580,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find model '{}'.".format(
|
||||
Model.arn_for_model_name(model_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException", message=message, template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def list_models(self):
|
||||
models = []
|
||||
@ -617,22 +644,13 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Cannot create a duplicate Notebook Instance ({})".format(
|
||||
duplicate_arn
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_notebook_instance(self, notebook_instance_name):
|
||||
try:
|
||||
return self.notebook_instances[notebook_instance_name]
|
||||
except KeyError:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
|
||||
def get_notebook_instance_by_arn(self, arn):
|
||||
instances = [
|
||||
@ -641,12 +659,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
if notebook_instance.arn == arn
|
||||
]
|
||||
if len(instances) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
return instances[0]
|
||||
|
||||
def start_notebook_instance(self, notebook_instance_name):
|
||||
@ -663,11 +676,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
|
||||
notebook_instance.status, notebook_instance.arn
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
del self.notebook_instances[notebook_instance_name]
|
||||
|
||||
def get_notebook_instance_tags(self, arn):
|
||||
@ -677,6 +686,60 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_notebook_instance_lifecycle_config(
|
||||
self, notebook_instance_lifecycle_config_name, on_create, on_start
|
||||
):
|
||||
if (
|
||||
notebook_instance_lifecycle_config_name
|
||||
in self.notebook_instance_lifecycle_configurations
|
||||
):
|
||||
message = "Unable to create Notebook Instance Lifecycle Config {}. (Details: Notebook Instance Lifecycle Config already exists.)".format(
|
||||
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
|
||||
notebook_instance_lifecycle_config_name, self.region_name
|
||||
)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig(
|
||||
region_name=self.region_name,
|
||||
notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name,
|
||||
on_create=on_create,
|
||||
on_start=on_start,
|
||||
)
|
||||
self.notebook_instance_lifecycle_configurations[
|
||||
notebook_instance_lifecycle_config_name
|
||||
] = lifecycle_config
|
||||
return lifecycle_config
|
||||
|
||||
def describe_notebook_instance_lifecycle_config(
|
||||
self, notebook_instance_lifecycle_config_name
|
||||
):
|
||||
try:
|
||||
return self.notebook_instance_lifecycle_configurations[
|
||||
notebook_instance_lifecycle_config_name
|
||||
].response_object
|
||||
except KeyError:
|
||||
message = "Unable to describe Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
|
||||
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
|
||||
notebook_instance_lifecycle_config_name, self.region_name
|
||||
)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def delete_notebook_instance_lifecycle_config(
|
||||
self, notebook_instance_lifecycle_config_name
|
||||
):
|
||||
try:
|
||||
del self.notebook_instance_lifecycle_configurations[
|
||||
notebook_instance_lifecycle_config_name
|
||||
]
|
||||
except KeyError:
|
||||
message = "Unable to delete Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
|
||||
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
|
||||
notebook_instance_lifecycle_config_name, self.region_name
|
||||
)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def create_endpoint_config(
|
||||
self,
|
||||
endpoint_config_name,
|
||||
@ -706,11 +769,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
production_variant["ModelName"], self.region_name
|
||||
)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def describe_endpoint_config(self, endpoint_config_name):
|
||||
try:
|
||||
@ -719,11 +778,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def delete_endpoint_config(self, endpoint_config_name):
|
||||
try:
|
||||
@ -732,11 +787,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def create_endpoint(
|
||||
self, endpoint_name, endpoint_config_name, tags,
|
||||
@ -747,11 +798,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint_config '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
endpoint = FakeEndpoint(
|
||||
region_name=self.region_name,
|
||||
@ -772,11 +819,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def delete_endpoint(self, endpoint_name):
|
||||
try:
|
||||
@ -785,11 +828,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_endpoint_by_arn(self, arn):
|
||||
endpoints = [
|
||||
@ -799,11 +838,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
]
|
||||
if len(endpoints) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
return endpoints[0]
|
||||
|
||||
def get_endpoint_tags(self, arn):
|
||||
@ -865,11 +900,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find training job '{}'.".format(
|
||||
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def delete_training_job(self, training_job_name):
|
||||
try:
|
||||
@ -878,11 +909,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_training_job_by_arn(self, arn):
|
||||
training_jobs = [
|
||||
@ -891,12 +918,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
if training_job.training_job_arn == arn
|
||||
]
|
||||
if len(training_jobs) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
return training_jobs[0]
|
||||
|
||||
def get_training_job_tags(self, arn):
|
||||
|
@ -239,3 +239,38 @@ class SageMakerResponse(BaseResponse):
|
||||
training_job_name = self._get_param("TrainingJobName")
|
||||
self.sagemaker_backend.delete_training_job(training_job_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_notebook_instance_lifecycle_config(self):
|
||||
try:
|
||||
lifecycle_configuration = self.sagemaker_backend.create_notebook_instance_lifecycle_config(
|
||||
notebook_instance_lifecycle_config_name=self._get_param(
|
||||
"NotebookInstanceLifecycleConfigName"
|
||||
),
|
||||
on_create=self._get_param("OnCreate"),
|
||||
on_start=self._get_param("OnStart"),
|
||||
)
|
||||
response = {
|
||||
"NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_notebook_instance_lifecycle_config(self):
|
||||
response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config(
|
||||
notebook_instance_lifecycle_config_name=self._get_param(
|
||||
"NotebookInstanceLifecycleConfigName"
|
||||
)
|
||||
)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_notebook_instance_lifecycle_config(self):
|
||||
self.sagemaker_backend.delete_notebook_instance_lifecycle_config(
|
||||
notebook_instance_lifecycle_config_name=self._get_param(
|
||||
"NotebookInstanceLifecycleConfigName"
|
||||
)
|
||||
)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
@ -225,3 +225,68 @@ def test_describe_nonexistent_model():
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith("Could not find model")
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_notebook_instance_lifecycle_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
name = "MyLifeCycleConfig"
|
||||
on_create = [{"Content": "Create Script Line 1"}]
|
||||
on_start = [{"Content": "Start Script Line 1"}]
|
||||
resp = sagemaker.create_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start
|
||||
)
|
||||
assert_true(
|
||||
resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
|
||||
)
|
||||
assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name))
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
resp = sagemaker.create_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
OnCreate=on_create,
|
||||
OnStart=on_start,
|
||||
)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].endswith(
|
||||
"Notebook Instance Lifecycle Config already exists.)"
|
||||
)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
)
|
||||
assert_equal(resp["NotebookInstanceLifecycleConfigName"], name)
|
||||
assert_true(
|
||||
resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
|
||||
)
|
||||
assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name))
|
||||
assert_equal(resp["OnStart"], on_start)
|
||||
assert_equal(resp["OnCreate"], on_create)
|
||||
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
|
||||
|
||||
sagemaker.delete_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
)
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].endswith(
|
||||
"Notebook Instance Lifecycle Config does not exist.)"
|
||||
)
|
||||
)
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.delete_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].endswith(
|
||||
"Notebook Instance Lifecycle Config does not exist.)"
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user