Notebook Lifecycle Config create, describe and delete (#3417)

* Notebook Lifecycle Config create, describe and delete

* PR3417 comment changes: raise on create with duplicate name, derive a ValidationException class and use it instead of RESTException, unit test for delete non-existing.

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
jweite 2020-10-30 17:05:06 -04:00 committed by GitHub
parent cbd4efb42d
commit f8d2ce2e6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 220 additions and 94 deletions

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
import json
from moto.core.exceptions import RESTError
from moto.core.exceptions import RESTError, JsonRESTError
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
{% block extra %}<ModelName>{{ model }}</ModelName>{% endblock %}
@ -45,3 +44,8 @@ class AWSError(Exception):
json.dumps({"__type": self.type, "message": self.message}),
dict(status=self.status),
)
class ValidationError(JsonRESTError):
def __init__(self, message, **kwargs):
super(ValidationError, self).__init__("ValidationException", message, **kwargs)

View File

@ -8,7 +8,7 @@ from datetime import datetime
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.sagemaker import validators
from .exceptions import MissingModel
from .exceptions import MissingModel, ValidationError
class BaseObject(BaseModel):
@ -285,11 +285,7 @@ class FakeEndpointConfig(BaseObject):
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
instance_type, VALID_INSTANCE_TYPES
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
@property
def response_object(self):
@ -431,11 +427,7 @@ class FakeSagemakerNotebookInstance:
def validate_volume_size_in_gb(self, volume_size_in_gb):
if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True):
message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def validate_instance_type(self, instance_type):
VALID_INSTANCE_TYPES = [
@ -482,11 +474,7 @@ class FakeSagemakerNotebookInstance:
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
instance_type, VALID_INSTANCE_TYPES
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
@property
def arn(self):
@ -516,6 +504,46 @@ class FakeSagemakerNotebookInstance:
self.status = "Stopped"
class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject):
def __init__(
self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start
):
self.region_name = region_name
self.notebook_instance_lifecycle_config_name = (
notebook_instance_lifecycle_config_name
)
self.on_create = on_create
self.on_start = on_start
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.notebook_instance_lifecycle_config_arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
self.notebook_instance_lifecycle_config_name, self.region_name
)
@staticmethod
def arn_formatter(notebook_instance_lifecycle_config_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":notebook-instance-lifecycle-configuration/"
+ notebook_instance_lifecycle_config_name
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"TrainingJobArn": self.training_job_arn}
class SageMakerModelBackend(BaseBackend):
def __init__(self, region_name=None):
self._models = {}
@ -523,6 +551,7 @@ class SageMakerModelBackend(BaseBackend):
self.endpoint_configs = {}
self.endpoints = {}
self.training_jobs = {}
self.notebook_instance_lifecycle_configurations = {}
self.region_name = region_name
def reset(self):
@ -551,9 +580,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find model '{}'.".format(
Model.arn_for_model_name(model_name, self.region_name)
)
raise RESTError(
error_type="ValidationException", message=message, template="error_json",
)
raise ValidationError(message=message)
def list_models(self):
models = []
@ -617,22 +644,13 @@ class SageMakerModelBackend(BaseBackend):
message = "Cannot create a duplicate Notebook Instance ({})".format(
duplicate_arn
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def get_notebook_instance(self, notebook_instance_name):
try:
return self.notebook_instances[notebook_instance_name]
except KeyError:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message="RecordNotFound")
def get_notebook_instance_by_arn(self, arn):
instances = [
@ -641,12 +659,7 @@ class SageMakerModelBackend(BaseBackend):
if notebook_instance.arn == arn
]
if len(instances) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message="RecordNotFound")
return instances[0]
def start_notebook_instance(self, notebook_instance_name):
@ -663,11 +676,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
notebook_instance.status, notebook_instance.arn
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
del self.notebook_instances[notebook_instance_name]
def get_notebook_instance_tags(self, arn):
@ -677,6 +686,60 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
def create_notebook_instance_lifecycle_config(
self, notebook_instance_lifecycle_config_name, on_create, on_start
):
if (
notebook_instance_lifecycle_config_name
in self.notebook_instance_lifecycle_configurations
):
message = "Unable to create Notebook Instance Lifecycle Config {}. (Details: Notebook Instance Lifecycle Config already exists.)".format(
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
notebook_instance_lifecycle_config_name, self.region_name
)
)
raise ValidationError(message=message)
lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig(
region_name=self.region_name,
notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name,
on_create=on_create,
on_start=on_start,
)
self.notebook_instance_lifecycle_configurations[
notebook_instance_lifecycle_config_name
] = lifecycle_config
return lifecycle_config
def describe_notebook_instance_lifecycle_config(
self, notebook_instance_lifecycle_config_name
):
try:
return self.notebook_instance_lifecycle_configurations[
notebook_instance_lifecycle_config_name
].response_object
except KeyError:
message = "Unable to describe Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
notebook_instance_lifecycle_config_name, self.region_name
)
)
raise ValidationError(message=message)
def delete_notebook_instance_lifecycle_config(
self, notebook_instance_lifecycle_config_name
):
try:
del self.notebook_instance_lifecycle_configurations[
notebook_instance_lifecycle_config_name
]
except KeyError:
message = "Unable to delete Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
notebook_instance_lifecycle_config_name, self.region_name
)
)
raise ValidationError(message=message)
def create_endpoint_config(
self,
endpoint_config_name,
@ -706,11 +769,7 @@ class SageMakerModelBackend(BaseBackend):
production_variant["ModelName"], self.region_name
)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def describe_endpoint_config(self, endpoint_config_name):
try:
@ -719,11 +778,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def delete_endpoint_config(self, endpoint_config_name):
try:
@ -732,11 +787,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def create_endpoint(
self, endpoint_name, endpoint_config_name, tags,
@ -747,11 +798,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint_config '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
endpoint = FakeEndpoint(
region_name=self.region_name,
@ -772,11 +819,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def delete_endpoint(self, endpoint_name):
try:
@ -785,11 +828,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def get_endpoint_by_arn(self, arn):
endpoints = [
@ -799,11 +838,7 @@ class SageMakerModelBackend(BaseBackend):
]
if len(endpoints) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
return endpoints[0]
def get_endpoint_tags(self, arn):
@ -865,11 +900,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find training job '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def delete_training_job(self, training_job_name):
try:
@ -878,11 +909,7 @@ class SageMakerModelBackend(BaseBackend):
message = "Could not find endpoint configuration '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message=message)
def get_training_job_by_arn(self, arn):
training_jobs = [
@ -891,12 +918,7 @@ class SageMakerModelBackend(BaseBackend):
if training_job.training_job_arn == arn
]
if len(training_jobs) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
raise ValidationError(message="RecordNotFound")
return training_jobs[0]
def get_training_job_tags(self, arn):

View File

@ -239,3 +239,38 @@ class SageMakerResponse(BaseResponse):
training_job_name = self._get_param("TrainingJobName")
self.sagemaker_backend.delete_training_job(training_job_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_notebook_instance_lifecycle_config(self):
try:
lifecycle_configuration = self.sagemaker_backend.create_notebook_instance_lifecycle_config(
notebook_instance_lifecycle_config_name=self._get_param(
"NotebookInstanceLifecycleConfigName"
),
on_create=self._get_param("OnCreate"),
on_start=self._get_param("OnStart"),
)
response = {
"NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_notebook_instance_lifecycle_config(self):
response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config(
notebook_instance_lifecycle_config_name=self._get_param(
"NotebookInstanceLifecycleConfigName"
)
)
return json.dumps(response)
@amzn_request_id
def delete_notebook_instance_lifecycle_config(self):
self.sagemaker_backend.delete_notebook_instance_lifecycle_config(
notebook_instance_lifecycle_config_name=self._get_param(
"NotebookInstanceLifecycleConfigName"
)
)
return 200, {}, json.dumps("{}")

View File

@ -225,3 +225,68 @@ def test_describe_nonexistent_model():
assert_true(
e.exception.response["Error"]["Message"].startswith("Could not find model")
)
@mock_sagemaker
def test_notebook_instance_lifecycle_config():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
name = "MyLifeCycleConfig"
on_create = [{"Content": "Create Script Line 1"}]
on_start = [{"Content": "Start Script Line 1"}]
resp = sagemaker.create_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start
)
assert_true(
resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
)
assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name))
with assert_raises(ClientError) as e:
resp = sagemaker.create_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
OnCreate=on_create,
OnStart=on_start,
)
assert_true(
e.exception.response["Error"]["Message"].endswith(
"Notebook Instance Lifecycle Config already exists.)"
)
)
resp = sagemaker.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
)
assert_equal(resp["NotebookInstanceLifecycleConfigName"], name)
assert_true(
resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
)
assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name))
assert_equal(resp["OnStart"], on_start)
assert_equal(resp["OnCreate"], on_create)
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
sagemaker.delete_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
)
with assert_raises(ClientError) as e:
sagemaker.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
)
assert_true(
e.exception.response["Error"]["Message"].endswith(
"Notebook Instance Lifecycle Config does not exist.)"
)
)
with assert_raises(ClientError) as e:
sagemaker.delete_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
)
assert_true(
e.exception.response["Error"]["Message"].endswith(
"Notebook Instance Lifecycle Config does not exist.)"
)
)