From 1b80b0a8109cf1e45199757b8c98f07f9f3d3107 Mon Sep 17 00:00:00 2001 From: jweite Date: Thu, 16 Jul 2020 08:12:25 -0400 Subject: [PATCH] Sagemaker models (#3105) * First failing test, and enough framework to run it. * Rudimentary passing test. * Sagemaker Notebook Support, take-1: create, describe, start, stop, delete. * Added list_tags. * Merged in model support from https://github.com/porthunt/moto/tree/sagemaker-support. * Re-org'd * Fixed up describe_model exception when no matching model. * Segregated tests by Sagemaker entity. Model arn check by regex.. * Python2 compabitility changes. * Added sagemaker to list of known backends. Corrected urls. * Added sagemaker special case to moto.server.infer_service_region_host due to irregular url format (use of 'api' subdomain) to support server mode. * Changes for PR 3105 comments of July 10, 2020 * PR3105 July 10, 2020, 8:55 AM EDT comment: dropped unnecessary re-addition of arn when formulating model list response. * PR 3105 July 15, 2020 9:10 AM EDT Comment: clean-up SageMakerModelBackend.describe_models logic for finding the model in the dict. * Optimized imports Co-authored-by: Joseph Weitekamp --- moto/__init__.py | 1 + moto/backends.py | 1 + moto/sagemaker/__init__.py | 5 + moto/sagemaker/exceptions.py | 47 +++ moto/sagemaker/models.py | 398 ++++++++++++++++++ moto/sagemaker/responses.py | 127 ++++++ moto/sagemaker/urls.py | 11 + moto/sagemaker/validators.py | 20 + moto/server.py | 4 + tests/test_sagemaker/__init__.py | 0 tests/test_sagemaker/test_sagemaker_models.py | 122 ++++++ .../test_sagemaker_notebooks.py | 227 ++++++++++ 12 files changed, 963 insertions(+) create mode 100644 moto/sagemaker/__init__.py create mode 100644 moto/sagemaker/exceptions.py create mode 100644 moto/sagemaker/models.py create mode 100644 moto/sagemaker/responses.py create mode 100644 moto/sagemaker/urls.py create mode 100644 moto/sagemaker/validators.py create mode 100644 tests/test_sagemaker/__init__.py create mode 100644 tests/test_sagemaker/test_sagemaker_models.py create mode 100644 tests/test_sagemaker/test_sagemaker_notebooks.py diff --git a/moto/__init__.py b/moto/__init__.py index b4375bfc6..5143a4933 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -95,6 +95,7 @@ mock_route53 = lazy_load(".route53", "mock_route53") mock_route53_deprecated = lazy_load(".route53", "mock_route53_deprecated") mock_s3 = lazy_load(".s3", "mock_s3") mock_s3_deprecated = lazy_load(".s3", "mock_s3_deprecated") +mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker") mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager") mock_ses = lazy_load(".ses", "mock_ses") mock_ses_deprecated = lazy_load(".ses", "mock_ses_deprecated") diff --git a/moto/backends.py b/moto/backends.py index 6f612bf1f..a73940909 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -58,6 +58,7 @@ BACKENDS = { "route53": ("route53", "route53_backends"), "s3": ("s3", "s3_backends"), "s3bucket_path": ("s3", "s3_backends"), + "sagemaker": ("sagemaker", "sagemaker_backends"), "secretsmanager": ("secretsmanager", "secretsmanager_backends"), "ses": ("ses", "ses_backends"), "sns": ("sns", "sns_backends"), diff --git a/moto/sagemaker/__init__.py b/moto/sagemaker/__init__.py new file mode 100644 index 000000000..85e635380 --- /dev/null +++ b/moto/sagemaker/__init__.py @@ -0,0 +1,5 @@ +from __future__ import unicode_literals +from .models import sagemaker_backends + +sagemaker_backend = sagemaker_backends["us-east-1"] +mock_sagemaker = sagemaker_backend.decorator diff --git a/moto/sagemaker/exceptions.py b/moto/sagemaker/exceptions.py new file mode 100644 index 000000000..dc2ce915a --- /dev/null +++ b/moto/sagemaker/exceptions.py @@ -0,0 +1,47 @@ +from __future__ import unicode_literals +import json +from moto.core.exceptions import RESTError + + +ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %} +{% block extra %}{{ model }}{% endblock %} +""" + + +class SagemakerClientError(RESTError): + def __init__(self, *args, **kwargs): + kwargs.setdefault("template", "single_error") + self.templates["model_error"] = ERROR_WITH_MODEL_NAME + super(SagemakerClientError, self).__init__(*args, **kwargs) + + +class ModelError(RESTError): + def __init__(self, *args, **kwargs): + kwargs.setdefault("template", "model_error") + self.templates["model_error"] = ERROR_WITH_MODEL_NAME + super(ModelError, self).__init__(*args, **kwargs) + + +class MissingModel(ModelError): + code = 404 + + def __init__(self, *args, **kwargs): + super(MissingModel, self).__init__( + "NoSuchModel", "Could not find model", *args, **kwargs + ) + + +class AWSError(Exception): + TYPE = None + STATUS = 400 + + def __init__(self, message, type=None, status=None): + self.message = message + self.type = type if type is not None else self.TYPE + self.status = status if status is not None else self.STATUS + + def response(self): + return ( + json.dumps({"__type": self.type, "message": self.message}), + dict(status=self.status), + ) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py new file mode 100644 index 000000000..3e0dce87b --- /dev/null +++ b/moto/sagemaker/models.py @@ -0,0 +1,398 @@ +from __future__ import unicode_literals + +from copy import deepcopy +from datetime import datetime + +from moto.core import BaseBackend, BaseModel +from moto.core.exceptions import RESTError +from moto.ec2 import ec2_backends +from moto.sagemaker import validators +from moto.sts.models import ACCOUNT_ID +from .exceptions import MissingModel + + +class BaseObject(BaseModel): + def camelCase(self, key): + words = [] + for i, word in enumerate(key.split("_")): + words.append(word.title()) + return "".join(words) + + def gen_response_object(self): + response_object = dict() + for key, value in self.__dict__.items(): + if "_" in key: + response_object[self.camelCase(key)] = value + else: + response_object[key[0].upper() + key[1:]] = value + return response_object + + @property + def response_object(self): + return self.gen_response_object() + + +class Model(BaseObject): + def __init__( + self, + region_name, + model_name, + execution_role_arn, + primary_container, + vpc_config, + containers=[], + tags=[], + ): + self.model_name = model_name + self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.containers = containers + self.tags = tags + self.enable_network_isolation = False + self.vpc_config = vpc_config + self.primary_container = primary_container + self.execution_role_arn = execution_role_arn or "arn:test" + self.model_arn = self.arn_for_model_name(self.model_name, region_name) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"ModelArn": self.model_arn} + + @staticmethod + def arn_for_model_name(model_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":model/" + + model_name + ) + + +class VpcConfig(BaseObject): + def __init__(self, security_group_ids, subnets): + self.security_group_ids = security_group_ids + self.subnets = subnets + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + +class Container(BaseObject): + def __init__(self, **kwargs): + self.container_hostname = kwargs.get("container_hostname", "localhost") + self.model_data_url = kwargs.get("data_url", "") + self.model_package_name = kwargs.get("package_name", "pkg") + self.image = kwargs.get("image", "") + self.environment = kwargs.get("environment", {}) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + +class FakeSagemakerNotebookInstance: + def __init__( + self, + region_name, + notebook_instance_name, + instance_type, + role_arn, + subnet_id, + security_group_ids, + kms_key_id, + tags, + lifecycle_config_name, + direct_internet_access, + volume_size_in_gb, + accelerator_types, + default_code_repository, + additional_code_repositories, + root_access, + ): + self.validate_volume_size_in_gb(volume_size_in_gb) + self.validate_instance_type(instance_type) + + self.region_name = region_name + self.notebook_instance_name = notebook_instance_name + self.instance_type = instance_type + self.role_arn = role_arn + self.subnet_id = subnet_id + self.security_group_ids = security_group_ids + self.kms_key_id = kms_key_id + self.tags = tags or [] + self.lifecycle_config_name = lifecycle_config_name + self.direct_internet_access = direct_internet_access + self.volume_size_in_gb = volume_size_in_gb + self.accelerator_types = accelerator_types + self.default_code_repository = default_code_repository + self.additional_code_repositories = additional_code_repositories + self.root_access = root_access + self.status = None + self.creation_time = self.last_modified_time = datetime.now() + self.start() + + def validate_volume_size_in_gb(self, volume_size_in_gb): + if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True): + message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf" + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def validate_instance_type(self, instance_type): + VALID_INSTANCE_TYPES = [ + "ml.p2.xlarge", + "ml.m5.4xlarge", + "ml.m4.16xlarge", + "ml.t3.xlarge", + "ml.p3.16xlarge", + "ml.t2.xlarge", + "ml.p2.16xlarge", + "ml.c4.2xlarge", + "ml.c5.2xlarge", + "ml.c4.4xlarge", + "ml.c5d.2xlarge", + "ml.c5.4xlarge", + "ml.c5d.4xlarge", + "ml.c4.8xlarge", + "ml.c5d.xlarge", + "ml.c5.9xlarge", + "ml.c5.xlarge", + "ml.c5d.9xlarge", + "ml.c4.xlarge", + "ml.t2.2xlarge", + "ml.c5d.18xlarge", + "ml.t3.2xlarge", + "ml.t3.medium", + "ml.t2.medium", + "ml.c5.18xlarge", + "ml.p3.2xlarge", + "ml.m5.xlarge", + "ml.m4.10xlarge", + "ml.t2.large", + "ml.m5.12xlarge", + "ml.m4.xlarge", + "ml.t3.large", + "ml.m5.24xlarge", + "ml.m4.2xlarge", + "ml.p2.8xlarge", + "ml.m5.2xlarge", + "ml.p3.8xlarge", + "ml.m4.4xlarge", + ] + if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES): + message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format( + instance_type, VALID_INSTANCE_TYPES + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + @property + def arn(self): + return ( + "arn:aws:sagemaker:" + + self.region_name + + ":" + + str(ACCOUNT_ID) + + ":notebook-instance/" + + self.notebook_instance_name + ) + + @property + def url(self): + return "{}.notebook.{}.sagemaker.aws".format( + self.notebook_instance_name, self.region_name + ) + + def start(self): + self.status = "InService" + + @property + def is_deletable(self): + return self.status in ["Stopped", "Failed"] + + def stop(self): + self.status = "Stopped" + + +class SageMakerModelBackend(BaseBackend): + def __init__(self, region_name=None): + self._models = {} + self.notebook_instances = {} + self.region_name = region_name + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def create_model(self, **kwargs): + model_obj = Model( + region_name=self.region_name, + model_name=kwargs.get("ModelName"), + execution_role_arn=kwargs.get("ExecutionRoleArn"), + primary_container=kwargs.get("PrimaryContainer", {}), + vpc_config=kwargs.get("VpcConfig", {}), + containers=kwargs.get("Containers", []), + tags=kwargs.get("Tags", []), + ) + + self._models[kwargs.get("ModelName")] = model_obj + return model_obj.response_create + + def describe_model(self, model_name=None): + model = self._models.get(model_name) + if model: + return model.response_object + message = "Could not find model '{}'.".format( + Model.arn_for_model_name(model_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", message=message, template="error_json", + ) + + def list_models(self): + models = [] + for model in self._models.values(): + model_response = deepcopy(model.response_object) + models.append(model_response) + return {"Models": models} + + def delete_model(self, model_name=None): + for model in self._models.values(): + if model.model_name == model_name: + self._models.pop(model.model_name) + break + else: + raise MissingModel(model=model_name) + + def create_notebook_instance( + self, + notebook_instance_name, + instance_type, + role_arn, + subnet_id=None, + security_group_ids=None, + kms_key_id=None, + tags=None, + lifecycle_config_name=None, + direct_internet_access="Enabled", + volume_size_in_gb=5, + accelerator_types=None, + default_code_repository=None, + additional_code_repositories=None, + root_access=None, + ): + self._validate_unique_notebook_instance_name(notebook_instance_name) + + notebook_instance = FakeSagemakerNotebookInstance( + self.region_name, + notebook_instance_name, + instance_type, + role_arn, + subnet_id=subnet_id, + security_group_ids=security_group_ids, + kms_key_id=kms_key_id, + tags=tags, + lifecycle_config_name=lifecycle_config_name, + direct_internet_access=direct_internet_access + if direct_internet_access is not None + else "Enabled", + volume_size_in_gb=volume_size_in_gb if volume_size_in_gb is not None else 5, + accelerator_types=accelerator_types, + default_code_repository=default_code_repository, + additional_code_repositories=additional_code_repositories, + root_access=root_access, + ) + self.notebook_instances[notebook_instance_name] = notebook_instance + return notebook_instance + + def _validate_unique_notebook_instance_name(self, notebook_instance_name): + if notebook_instance_name in self.notebook_instances: + duplicate_arn = self.notebook_instances[notebook_instance_name].arn + message = "Cannot create a duplicate Notebook Instance ({})".format( + duplicate_arn + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def get_notebook_instance(self, notebook_instance_name): + try: + return self.notebook_instances[notebook_instance_name] + except KeyError: + message = "RecordNotFound" + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def get_notebook_instance_by_arn(self, arn): + instances = [ + notebook_instance + for notebook_instance in self.notebook_instances.values() + if notebook_instance.arn == arn + ] + if len(instances) == 0: + message = "RecordNotFound" + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + return instances[0] + + def start_notebook_instance(self, notebook_instance_name): + notebook_instance = self.get_notebook_instance(notebook_instance_name) + notebook_instance.start() + + def stop_notebook_instance(self, notebook_instance_name): + notebook_instance = self.get_notebook_instance(notebook_instance_name) + notebook_instance.stop() + + def delete_notebook_instance(self, notebook_instance_name): + notebook_instance = self.get_notebook_instance(notebook_instance_name) + if not notebook_instance.is_deletable: + message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format( + notebook_instance.status, notebook_instance.arn + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + del self.notebook_instances[notebook_instance_name] + + def get_notebook_instance_tags(self, arn): + try: + notebook_instance = self.get_notebook_instance_by_arn(arn) + return notebook_instance.tags or [] + except RESTError: + return [] + + +sagemaker_backends = {} +for region, ec2_backend in ec2_backends.items(): + sagemaker_backends[region] = SageMakerModelBackend(region) diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py new file mode 100644 index 000000000..58e28ef01 --- /dev/null +++ b/moto/sagemaker/responses.py @@ -0,0 +1,127 @@ +from __future__ import unicode_literals + +import json + +from moto.core.responses import BaseResponse +from moto.core.utils import amzn_request_id +from .exceptions import AWSError +from .models import sagemaker_backends + + +class SageMakerResponse(BaseResponse): + @property + def sagemaker_backend(self): + return sagemaker_backends[self.region] + + @property + def request_params(self): + try: + return json.loads(self.body) + except ValueError: + return {} + + def describe_model(self): + model_name = self._get_param("ModelName") + response = self.sagemaker_backend.describe_model(model_name) + return json.dumps(response) + + def create_model(self): + response = self.sagemaker_backend.create_model(**self.request_params) + return json.dumps(response) + + def delete_model(self): + model_name = self._get_param("ModelName") + response = self.sagemaker_backend.delete_model(model_name) + return json.dumps(response) + + def list_models(self): + response = self.sagemaker_backend.list_models(**self.request_params) + return json.dumps(response) + + def _get_param(self, param, if_none=None): + return self.request_params.get(param, if_none) + + @amzn_request_id + def create_notebook_instance(self): + try: + sagemaker_notebook = self.sagemaker_backend.create_notebook_instance( + notebook_instance_name=self._get_param("NotebookInstanceName"), + instance_type=self._get_param("InstanceType"), + subnet_id=self._get_param("SubnetId"), + security_group_ids=self._get_param("SecurityGroupIds"), + role_arn=self._get_param("RoleArn"), + kms_key_id=self._get_param("KmsKeyId"), + tags=self._get_param("Tags"), + lifecycle_config_name=self._get_param("LifecycleConfigName"), + direct_internet_access=self._get_param("DirectInternetAccess"), + volume_size_in_gb=self._get_param("VolumeSizeInGB"), + accelerator_types=self._get_param("AcceleratorTypes"), + default_code_repository=self._get_param("DefaultCodeRepository"), + additional_code_repositories=self._get_param( + "AdditionalCodeRepositories" + ), + root_access=self._get_param("RootAccess"), + ) + response = { + "NotebookInstanceArn": sagemaker_notebook.arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_notebook_instance(self): + notebook_instance_name = self._get_param("NotebookInstanceName") + try: + notebook_instance = self.sagemaker_backend.get_notebook_instance( + notebook_instance_name + ) + response = { + "NotebookInstanceArn": notebook_instance.arn, + "NotebookInstanceName": notebook_instance.notebook_instance_name, + "NotebookInstanceStatus": notebook_instance.status, + "Url": notebook_instance.url, + "InstanceType": notebook_instance.instance_type, + "SubnetId": notebook_instance.subnet_id, + "SecurityGroups": notebook_instance.security_group_ids, + "RoleArn": notebook_instance.role_arn, + "KmsKeyId": notebook_instance.kms_key_id, + # ToDo: NetworkInterfaceId + "LastModifiedTime": str(notebook_instance.last_modified_time), + "CreationTime": str(notebook_instance.creation_time), + "NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name, + "DirectInternetAccess": notebook_instance.direct_internet_access, + "VolumeSizeInGB": notebook_instance.volume_size_in_gb, + "AcceleratorTypes": notebook_instance.accelerator_types, + "DefaultCodeRepository": notebook_instance.default_code_repository, + "AdditionalCodeRepositories": notebook_instance.additional_code_repositories, + "RootAccess": notebook_instance.root_access, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def start_notebook_instance(self): + notebook_instance_name = self._get_param("NotebookInstanceName") + self.sagemaker_backend.start_notebook_instance(notebook_instance_name) + return 200, {}, json.dumps("{}") + + @amzn_request_id + def stop_notebook_instance(self): + notebook_instance_name = self._get_param("NotebookInstanceName") + self.sagemaker_backend.stop_notebook_instance(notebook_instance_name) + return 200, {}, json.dumps("{}") + + @amzn_request_id + def delete_notebook_instance(self): + notebook_instance_name = self._get_param("NotebookInstanceName") + self.sagemaker_backend.delete_notebook_instance(notebook_instance_name) + return 200, {}, json.dumps("{}") + + @amzn_request_id + def list_tags(self): + arn = self._get_param("ResourceArn") + tags = self.sagemaker_backend.get_notebook_instance_tags(arn) + response = {"Tags": tags} + return 200, {}, json.dumps(response) diff --git a/moto/sagemaker/urls.py b/moto/sagemaker/urls.py new file mode 100644 index 000000000..224342ce5 --- /dev/null +++ b/moto/sagemaker/urls.py @@ -0,0 +1,11 @@ +from __future__ import unicode_literals +from .responses import SageMakerResponse + +url_bases = [ + "https?://api.sagemaker.(.+).amazonaws.com", + "https?://api-fips.sagemaker.(.+).amazonaws.com", +] + +url_paths = { + "{0}/$": SageMakerResponse.dispatch, +} diff --git a/moto/sagemaker/validators.py b/moto/sagemaker/validators.py new file mode 100644 index 000000000..69cbee2a5 --- /dev/null +++ b/moto/sagemaker/validators.py @@ -0,0 +1,20 @@ +def is_integer_between(x, mn=None, mx=None, optional=False): + if optional and x is None: + return True + try: + if mn is not None and mx is not None: + return int(x) >= mn and int(x) < mx + elif mn is not None: + return int(x) >= mn + elif mx is not None: + return int(x) < mx + else: + return True + except ValueError: + return False + + +def is_one_of(x, choices, optional=False): + if optional and x is None: + return True + return x in choices diff --git a/moto/server.py b/moto/server.py index 46e37d921..bf76095a6 100644 --- a/moto/server.py +++ b/moto/server.py @@ -102,6 +102,10 @@ class DomainDispatcherApplication(object): # If Newer API version, use dynamodb2 if dynamo_api_version > "20111205": host = "dynamodb2" + elif service == "sagemaker": + host = "api.sagemaker.{region}.amazonaws.com".format( + service=service, region=region + ) else: host = "{service}.{region}.amazonaws.com".format( service=service, region=region diff --git a/tests/test_sagemaker/__init__.py b/tests/test_sagemaker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_sagemaker/test_sagemaker_models.py b/tests/test_sagemaker/test_sagemaker_models.py new file mode 100644 index 000000000..4139ca575 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_models.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import boto3 +import tests.backport_assert_raises # noqa +from botocore.exceptions import ClientError +from nose.tools import assert_raises +from moto import mock_sagemaker + +import sure # noqa + +from moto.sagemaker.models import VpcConfig + + +class MySageMakerModel(object): + def __init__(self, name, arn, container=None, vpc_config=None): + self.name = name + self.arn = arn + self.container = container if container else {} + self.vpc_config = ( + vpc_config if vpc_config else {"sg-groups": ["sg-123"], "subnets": ["123"]} + ) + + def save(self): + client = boto3.client("sagemaker", region_name="us-east-1") + vpc_config = VpcConfig( + self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets") + ) + client.create_model( + ModelName=self.name, + ExecutionRoleArn=self.arn, + VpcConfig=vpc_config.response_object, + ) + + +@mock_sagemaker +def test_describe_model(): + client = boto3.client("sagemaker", region_name="us-east-1") + test_model = MySageMakerModel( + name="blah", + arn="arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar", + vpc_config={"sg-groups": ["sg-123"], "subnets": ["123"]}, + ) + test_model.save() + model = client.describe_model(ModelName="blah") + assert model.get("ModelName").should.equal("blah") + + +@mock_sagemaker +def test_create_model(): + client = boto3.client("sagemaker", region_name="us-east-1") + vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"]) + exec_role_arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" + name = "blah" + model = client.create_model( + ModelName=name, + ExecutionRoleArn=exec_role_arn, + VpcConfig=vpc_config.response_object, + ) + + model["ModelArn"].should.match(r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name)) + + +@mock_sagemaker +def test_delete_model(): + client = boto3.client("sagemaker", region_name="us-east-1") + name = "blah" + arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" + test_model = MySageMakerModel(name=name, arn=arn) + test_model.save() + + assert len(client.list_models()["Models"]).should.equal(1) + client.delete_model(ModelName=name) + assert len(client.list_models()["Models"]).should.equal(0) + + +@mock_sagemaker +def test_delete_model_not_found(): + with assert_raises(ClientError) as err: + boto3.client("sagemaker", region_name="us-east-1").delete_model( + ModelName="blah" + ) + assert err.exception.response["Error"]["Code"].should.equal("404") + + +@mock_sagemaker +def test_list_models(): + client = boto3.client("sagemaker", region_name="us-east-1") + name = "blah" + arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" + test_model = MySageMakerModel(name=name, arn=arn) + test_model.save() + models = client.list_models() + assert len(models["Models"]).should.equal(1) + assert models["Models"][0]["ModelName"].should.equal(name) + assert models["Models"][0]["ModelArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name) + ) + + +@mock_sagemaker +def test_list_models_multiple(): + client = boto3.client("sagemaker", region_name="us-east-1") + + name_model_1 = "blah" + arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" + test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1) + test_model_1.save() + + name_model_2 = "blah2" + arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2" + test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2) + test_model_2.save() + models = client.list_models() + assert len(models["Models"]).should.equal(2) + + +@mock_sagemaker +def test_list_models_none(): + client = boto3.client("sagemaker", region_name="us-east-1") + models = client.list_models() + assert len(models["Models"]).should.equal(0) diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py new file mode 100644 index 000000000..70cdc9423 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import datetime +import boto3 +from botocore.exceptions import ClientError, ParamValidationError +import sure # noqa + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID +from nose.tools import assert_true, assert_equal, assert_raises + +TEST_REGION_NAME = "us-east-1" +FAKE_SUBNET_ID = "subnet-012345678" +FAKE_SECURITY_GROUP_IDS = ["sg-0123456789abcdef0", "sg-0123456789abcdef1"] +FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +FAKE_KMS_KEY_ID = "62d4509a-9f96-446c-a9ba-6b1c353c8c58" +GENERIC_TAGS_PARAM = [ + {"Key": "newkey1", "Value": "newval1"}, + {"Key": "newkey2", "Value": "newval2"}, +] +FAKE_LIFECYCLE_CONFIG_NAME = "FakeLifecycleConfigName" +FAKE_DEFAULT_CODE_REPO = "https://github.com/user/repo1" +FAKE_ADDL_CODE_REPOS = [ + "https://github.com/user/repo2", + "https://github.com/user/repo2", +] + + +@mock_sagemaker +def test_create_notebook_instance_minimal_params(): + + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + NAME_PARAM = "MyNotebookInstance" + INSTANCE_TYPE_PARAM = "ml.t2.medium" + + args = { + "NotebookInstanceName": NAME_PARAM, + "InstanceType": INSTANCE_TYPE_PARAM, + "RoleArn": FAKE_ROLE_ARN, + } + resp = sagemaker.create_notebook_instance(**args) + assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")) + assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")) + assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])) + assert_equal(resp["NotebookInstanceName"], NAME_PARAM) + assert_equal(resp["NotebookInstanceStatus"], "InService") + assert_equal( + resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME) + ) + assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM) + assert_equal(resp["RoleArn"], FAKE_ROLE_ARN) + assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime)) + assert_true(isinstance(resp["CreationTime"], datetime.datetime)) + assert_equal(resp["DirectInternetAccess"], "Enabled") + assert_equal(resp["VolumeSizeInGB"], 5) + + +# assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults... + + +@mock_sagemaker +def test_create_notebook_instance_params(): + + sagemaker = boto3.client("sagemaker", region_name="us-east-1") + + NAME_PARAM = "MyNotebookInstance" + INSTANCE_TYPE_PARAM = "ml.t2.medium" + DIRECT_INTERNET_ACCESS_PARAM = "Enabled" + VOLUME_SIZE_IN_GB_PARAM = 7 + ACCELERATOR_TYPES_PARAM = ["ml.eia1.medium", "ml.eia2.medium"] + ROOT_ACCESS_PARAM = "Disabled" + + args = { + "NotebookInstanceName": NAME_PARAM, + "InstanceType": INSTANCE_TYPE_PARAM, + "SubnetId": FAKE_SUBNET_ID, + "SecurityGroupIds": FAKE_SECURITY_GROUP_IDS, + "RoleArn": FAKE_ROLE_ARN, + "KmsKeyId": FAKE_KMS_KEY_ID, + "Tags": GENERIC_TAGS_PARAM, + "LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME, + "DirectInternetAccess": DIRECT_INTERNET_ACCESS_PARAM, + "VolumeSizeInGB": VOLUME_SIZE_IN_GB_PARAM, + "AcceleratorTypes": ACCELERATOR_TYPES_PARAM, + "DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO, + "AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS, + "RootAccess": ROOT_ACCESS_PARAM, + } + resp = sagemaker.create_notebook_instance(**args) + assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")) + assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")) + assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])) + assert_equal(resp["NotebookInstanceName"], NAME_PARAM) + assert_equal(resp["NotebookInstanceStatus"], "InService") + assert_equal( + resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME) + ) + assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM) + assert_equal(resp["RoleArn"], FAKE_ROLE_ARN) + assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime)) + assert_true(isinstance(resp["CreationTime"], datetime.datetime)) + assert_equal(resp["DirectInternetAccess"], "Enabled") + assert_equal(resp["VolumeSizeInGB"], VOLUME_SIZE_IN_GB_PARAM) + # assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults... + assert_equal(resp["SubnetId"], FAKE_SUBNET_ID) + assert_equal(resp["SecurityGroups"], FAKE_SECURITY_GROUP_IDS) + assert_equal(resp["KmsKeyId"], FAKE_KMS_KEY_ID) + assert_equal( + resp["NotebookInstanceLifecycleConfigName"], FAKE_LIFECYCLE_CONFIG_NAME + ) + assert_equal(resp["AcceleratorTypes"], ACCELERATOR_TYPES_PARAM) + assert_equal(resp["DefaultCodeRepository"], FAKE_DEFAULT_CODE_REPO) + assert_equal(resp["AdditionalCodeRepositories"], FAKE_ADDL_CODE_REPOS) + + resp = sagemaker.list_tags(ResourceArn=resp["NotebookInstanceArn"]) + assert_equal(resp["Tags"], GENERIC_TAGS_PARAM) + + +@mock_sagemaker +def test_create_notebook_instance_bad_volume_size(): + + sagemaker = boto3.client("sagemaker", region_name="us-east-1") + + vol_size = 2 + args = { + "NotebookInstanceName": "MyNotebookInstance", + "InstanceType": "ml.t2.medium", + "RoleArn": FAKE_ROLE_ARN, + "VolumeSizeInGB": vol_size, + } + with assert_raises(ParamValidationError) as ex: + resp = sagemaker.create_notebook_instance(**args) + assert_equal( + ex.exception.args[0], + "Parameter validation failed:\nInvalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf".format( + vol_size + ), + ) + + +@mock_sagemaker +def test_create_notebook_instance_invalid_instance_type(): + + sagemaker = boto3.client("sagemaker", region_name="us-east-1") + + instance_type = "undefined_instance_type" + args = { + "NotebookInstanceName": "MyNotebookInstance", + "InstanceType": instance_type, + "RoleArn": FAKE_ROLE_ARN, + } + with assert_raises(ClientError) as ex: + resp = sagemaker.create_notebook_instance(**args) + assert_equal(ex.exception.response["Error"]["Code"], "ValidationException") + expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format( + instance_type + ) + + assert_true(expected_message in ex.exception.response["Error"]["Message"]) + + +@mock_sagemaker +def test_notebook_instance_lifecycle(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + NAME_PARAM = "MyNotebookInstance" + INSTANCE_TYPE_PARAM = "ml.t2.medium" + + args = { + "NotebookInstanceName": NAME_PARAM, + "InstanceType": INSTANCE_TYPE_PARAM, + "RoleArn": FAKE_ROLE_ARN, + } + resp = sagemaker.create_notebook_instance(**args) + assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")) + assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + notebook_instance_arn = resp["NotebookInstanceArn"] + + with assert_raises(ClientError) as ex: + sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_equal(ex.exception.response["Error"]["Code"], "ValidationException") + expected_message = "Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format( + notebook_instance_arn + ) + assert_true(expected_message in ex.exception.response["Error"]["Message"]) + + sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_equal(resp["NotebookInstanceStatus"], "Stopped") + + sagemaker.start_notebook_instance(NotebookInstanceName=NAME_PARAM) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_equal(resp["NotebookInstanceStatus"], "InService") + + sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM) + + resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_equal(resp["NotebookInstanceStatus"], "Stopped") + + sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM) + + with assert_raises(ClientError) as ex: + sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + assert_equal(ex.exception.response["Error"]["Message"], "RecordNotFound") + + +@mock_sagemaker +def test_describe_nonexistent_model(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + with assert_raises(ClientError) as e: + resp = sagemaker.describe_model(ModelName="Nonexistent") + assert_true( + e.exception.response["Error"]["Message"].startswith("Could not find model") + )