From f8d2ce2e6a68215195c2e06f94456c78892b6c5d Mon Sep 17 00:00:00 2001 From: jweite Date: Fri, 30 Oct 2020 17:05:06 -0400 Subject: [PATCH] Notebook Lifecycle Config create, describe and delete (#3417) * Notebook Lifecycle Config create, describe and delete * PR3417 comment changes: raise on create with duplicate name, derive a ValidationException class and use it instead of RESTException, unit test for delete non-existing. Co-authored-by: Joseph Weitekamp --- moto/sagemaker/exceptions.py | 8 +- moto/sagemaker/models.py | 206 ++++++++++-------- moto/sagemaker/responses.py | 35 +++ .../test_sagemaker_notebooks.py | 65 ++++++ 4 files changed, 220 insertions(+), 94 deletions(-) diff --git a/moto/sagemaker/exceptions.py b/moto/sagemaker/exceptions.py index dc2ce915a..e2d01e82e 100644 --- a/moto/sagemaker/exceptions.py +++ b/moto/sagemaker/exceptions.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import json -from moto.core.exceptions import RESTError - +from moto.core.exceptions import RESTError, JsonRESTError ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %} {% block extra %}{{ model }}{% endblock %} @@ -45,3 +44,8 @@ class AWSError(Exception): json.dumps({"__type": self.type, "message": self.message}), dict(status=self.status), ) + + +class ValidationError(JsonRESTError): + def __init__(self, message, **kwargs): + super(ValidationError, self).__init__("ValidationException", message, **kwargs) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 9c394cc23..8fef306b8 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -8,7 +8,7 @@ from datetime import datetime from moto.core import ACCOUNT_ID, BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.sagemaker import validators -from .exceptions import MissingModel +from .exceptions import MissingModel, ValidationError class BaseObject(BaseModel): @@ -285,11 +285,7 @@ class FakeEndpointConfig(BaseObject): message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format( instance_type, VALID_INSTANCE_TYPES ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) @property def response_object(self): @@ -431,11 +427,7 @@ class FakeSagemakerNotebookInstance: def validate_volume_size_in_gb(self, volume_size_in_gb): if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True): message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf" - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def validate_instance_type(self, instance_type): VALID_INSTANCE_TYPES = [ @@ -482,11 +474,7 @@ class FakeSagemakerNotebookInstance: message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format( instance_type, VALID_INSTANCE_TYPES ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) @property def arn(self): @@ -516,6 +504,46 @@ class FakeSagemakerNotebookInstance: self.status = "Stopped" +class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): + def __init__( + self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start + ): + self.region_name = region_name + self.notebook_instance_lifecycle_config_name = ( + notebook_instance_lifecycle_config_name + ) + self.on_create = on_create + self.on_start = on_start + self.creation_time = self.last_modified_time = datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + self.notebook_instance_lifecycle_config_arn = FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( + self.notebook_instance_lifecycle_config_name, self.region_name + ) + + @staticmethod + def arn_formatter(notebook_instance_lifecycle_config_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":notebook-instance-lifecycle-configuration/" + + notebook_instance_lifecycle_config_name + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"TrainingJobArn": self.training_job_arn} + + class SageMakerModelBackend(BaseBackend): def __init__(self, region_name=None): self._models = {} @@ -523,6 +551,7 @@ class SageMakerModelBackend(BaseBackend): self.endpoint_configs = {} self.endpoints = {} self.training_jobs = {} + self.notebook_instance_lifecycle_configurations = {} self.region_name = region_name def reset(self): @@ -551,9 +580,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find model '{}'.".format( Model.arn_for_model_name(model_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", message=message, template="error_json", - ) + raise ValidationError(message=message) def list_models(self): models = [] @@ -617,22 +644,13 @@ class SageMakerModelBackend(BaseBackend): message = "Cannot create a duplicate Notebook Instance ({})".format( duplicate_arn ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def get_notebook_instance(self, notebook_instance_name): try: return self.notebook_instances[notebook_instance_name] except KeyError: - message = "RecordNotFound" - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message="RecordNotFound") def get_notebook_instance_by_arn(self, arn): instances = [ @@ -641,12 +659,7 @@ class SageMakerModelBackend(BaseBackend): if notebook_instance.arn == arn ] if len(instances) == 0: - message = "RecordNotFound" - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message="RecordNotFound") return instances[0] def start_notebook_instance(self, notebook_instance_name): @@ -663,11 +676,7 @@ class SageMakerModelBackend(BaseBackend): message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format( notebook_instance.status, notebook_instance.arn ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) del self.notebook_instances[notebook_instance_name] def get_notebook_instance_tags(self, arn): @@ -677,6 +686,60 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] + def create_notebook_instance_lifecycle_config( + self, notebook_instance_lifecycle_config_name, on_create, on_start + ): + if ( + notebook_instance_lifecycle_config_name + in self.notebook_instance_lifecycle_configurations + ): + message = "Unable to create Notebook Instance Lifecycle Config {}. (Details: Notebook Instance Lifecycle Config already exists.)".format( + FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( + notebook_instance_lifecycle_config_name, self.region_name + ) + ) + raise ValidationError(message=message) + lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig( + region_name=self.region_name, + notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name, + on_create=on_create, + on_start=on_start, + ) + self.notebook_instance_lifecycle_configurations[ + notebook_instance_lifecycle_config_name + ] = lifecycle_config + return lifecycle_config + + def describe_notebook_instance_lifecycle_config( + self, notebook_instance_lifecycle_config_name + ): + try: + return self.notebook_instance_lifecycle_configurations[ + notebook_instance_lifecycle_config_name + ].response_object + except KeyError: + message = "Unable to describe Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format( + FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( + notebook_instance_lifecycle_config_name, self.region_name + ) + ) + raise ValidationError(message=message) + + def delete_notebook_instance_lifecycle_config( + self, notebook_instance_lifecycle_config_name + ): + try: + del self.notebook_instance_lifecycle_configurations[ + notebook_instance_lifecycle_config_name + ] + except KeyError: + message = "Unable to delete Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format( + FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter( + notebook_instance_lifecycle_config_name, self.region_name + ) + ) + raise ValidationError(message=message) + def create_endpoint_config( self, endpoint_config_name, @@ -706,11 +769,7 @@ class SageMakerModelBackend(BaseBackend): production_variant["ModelName"], self.region_name ) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def describe_endpoint_config(self, endpoint_config_name): try: @@ -719,11 +778,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint configuration '{}'.".format( FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def delete_endpoint_config(self, endpoint_config_name): try: @@ -732,11 +787,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint configuration '{}'.".format( FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def create_endpoint( self, endpoint_name, endpoint_config_name, tags, @@ -747,11 +798,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint_config '{}'.".format( FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) endpoint = FakeEndpoint( region_name=self.region_name, @@ -772,11 +819,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint configuration '{}'.".format( FakeEndpoint.arn_formatter(endpoint_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def delete_endpoint(self, endpoint_name): try: @@ -785,11 +828,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint configuration '{}'.".format( FakeEndpoint.arn_formatter(endpoint_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def get_endpoint_by_arn(self, arn): endpoints = [ @@ -799,11 +838,7 @@ class SageMakerModelBackend(BaseBackend): ] if len(endpoints) == 0: message = "RecordNotFound" - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) return endpoints[0] def get_endpoint_tags(self, arn): @@ -865,11 +900,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find training job '{}'.".format( FakeTrainingJob.arn_formatter(training_job_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def delete_training_job(self, training_job_name): try: @@ -878,11 +909,7 @@ class SageMakerModelBackend(BaseBackend): message = "Could not find endpoint configuration '{}'.".format( FakeTrainingJob.arn_formatter(training_job_name, self.region_name) ) - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message=message) def get_training_job_by_arn(self, arn): training_jobs = [ @@ -891,12 +918,7 @@ class SageMakerModelBackend(BaseBackend): if training_job.training_job_arn == arn ] if len(training_jobs) == 0: - message = "RecordNotFound" - raise RESTError( - error_type="ValidationException", - message=message, - template="error_json", - ) + raise ValidationError(message="RecordNotFound") return training_jobs[0] def get_training_job_tags(self, arn): diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 48a3a6432..749ac787f 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -239,3 +239,38 @@ class SageMakerResponse(BaseResponse): training_job_name = self._get_param("TrainingJobName") self.sagemaker_backend.delete_training_job(training_job_name) return 200, {}, json.dumps("{}") + + @amzn_request_id + def create_notebook_instance_lifecycle_config(self): + try: + lifecycle_configuration = self.sagemaker_backend.create_notebook_instance_lifecycle_config( + notebook_instance_lifecycle_config_name=self._get_param( + "NotebookInstanceLifecycleConfigName" + ), + on_create=self._get_param("OnCreate"), + on_start=self._get_param("OnStart"), + ) + response = { + "NotebookInstanceLifecycleConfigArn": lifecycle_configuration.notebook_instance_lifecycle_config_arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_notebook_instance_lifecycle_config(self): + response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config( + notebook_instance_lifecycle_config_name=self._get_param( + "NotebookInstanceLifecycleConfigName" + ) + ) + return json.dumps(response) + + @amzn_request_id + def delete_notebook_instance_lifecycle_config(self): + self.sagemaker_backend.delete_notebook_instance_lifecycle_config( + notebook_instance_lifecycle_config_name=self._get_param( + "NotebookInstanceLifecycleConfigName" + ) + ) + return 200, {}, json.dumps("{}") diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py index 70cdc9423..c04618c77 100644 --- a/tests/test_sagemaker/test_sagemaker_notebooks.py +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ -225,3 +225,68 @@ def test_describe_nonexistent_model(): assert_true( e.exception.response["Error"]["Message"].startswith("Could not find model") ) + + +@mock_sagemaker +def test_notebook_instance_lifecycle_config(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + name = "MyLifeCycleConfig" + on_create = [{"Content": "Create Script Line 1"}] + on_start = [{"Content": "Start Script Line 1"}] + resp = sagemaker.create_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start + ) + assert_true( + resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker") + ) + assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name)) + + with assert_raises(ClientError) as e: + resp = sagemaker.create_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, + OnCreate=on_create, + OnStart=on_start, + ) + assert_true( + e.exception.response["Error"]["Message"].endswith( + "Notebook Instance Lifecycle Config already exists.)" + ) + ) + + resp = sagemaker.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, + ) + assert_equal(resp["NotebookInstanceLifecycleConfigName"], name) + assert_true( + resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker") + ) + assert_true(resp["NotebookInstanceLifecycleConfigArn"].endswith(name)) + assert_equal(resp["OnStart"], on_start) + assert_equal(resp["OnCreate"], on_create) + assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime)) + assert_true(isinstance(resp["CreationTime"], datetime.datetime)) + + sagemaker.delete_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, + ) + + with assert_raises(ClientError) as e: + sagemaker.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, + ) + assert_true( + e.exception.response["Error"]["Message"].endswith( + "Notebook Instance Lifecycle Config does not exist.)" + ) + ) + + with assert_raises(ClientError) as e: + sagemaker.delete_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=name, + ) + assert_true( + e.exception.response["Error"]["Message"].endswith( + "Notebook Instance Lifecycle Config does not exist.)" + ) + )