Sagemaker models (#3105)

* First failing test, and enough framework to run it.

* Rudimentary passing test.

* Sagemaker Notebook Support, take-1: create, describe, start, stop, delete.

* Added list_tags.

* Merged in model support from https://github.com/porthunt/moto/tree/sagemaker-support.

* Re-org'd

* Fixed up describe_model exception when no matching model.

* Segregated tests by Sagemaker entity.  Model arn check by regex..

* Python2 compabitility changes.

* Added sagemaker to list of known backends.  Corrected urls.

* Added sagemaker special case to moto.server.infer_service_region_host due to irregular url format (use of 'api' subdomain) to support server mode.

* Changes for PR 3105 comments of July 10, 2020

* PR3105 July 10, 2020, 8:55 AM EDT comment: dropped unnecessary re-addition of arn when formulating model list response.

* PR 3105 July 15, 2020 9:10 AM EDT Comment: clean-up SageMakerModelBackend.describe_models logic for finding the model in the dict.

* Optimized imports

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
jweite 2020-07-16 08:12:25 -04:00 committed by GitHub
parent 3e2a5e7ee8
commit 1b80b0a810
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 963 additions and 0 deletions

View File

@ -95,6 +95,7 @@ mock_route53 = lazy_load(".route53", "mock_route53")
mock_route53_deprecated = lazy_load(".route53", "mock_route53_deprecated")
mock_s3 = lazy_load(".s3", "mock_s3")
mock_s3_deprecated = lazy_load(".s3", "mock_s3_deprecated")
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
mock_ses = lazy_load(".ses", "mock_ses")
mock_ses_deprecated = lazy_load(".ses", "mock_ses_deprecated")

View File

@ -58,6 +58,7 @@ BACKENDS = {
"route53": ("route53", "route53_backends"),
"s3": ("s3", "s3_backends"),
"s3bucket_path": ("s3", "s3_backends"),
"sagemaker": ("sagemaker", "sagemaker_backends"),
"secretsmanager": ("secretsmanager", "secretsmanager_backends"),
"ses": ("ses", "ses_backends"),
"sns": ("sns", "sns_backends"),

View File

@ -0,0 +1,5 @@
from __future__ import unicode_literals
from .models import sagemaker_backends
sagemaker_backend = sagemaker_backends["us-east-1"]
mock_sagemaker = sagemaker_backend.decorator

View File

@ -0,0 +1,47 @@
from __future__ import unicode_literals
import json
from moto.core.exceptions import RESTError
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
{% block extra %}<ModelName>{{ model }}</ModelName>{% endblock %}
"""
class SagemakerClientError(RESTError):
def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "single_error")
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
super(SagemakerClientError, self).__init__(*args, **kwargs)
class ModelError(RESTError):
def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "model_error")
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
super(ModelError, self).__init__(*args, **kwargs)
class MissingModel(ModelError):
code = 404
def __init__(self, *args, **kwargs):
super(MissingModel, self).__init__(
"NoSuchModel", "Could not find model", *args, **kwargs
)
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message, type=None, status=None):
self.message = message
self.type = type if type is not None else self.TYPE
self.status = status if status is not None else self.STATUS
def response(self):
return (
json.dumps({"__type": self.type, "message": self.message}),
dict(status=self.status),
)

398
moto/sagemaker/models.py Normal file
View File

@ -0,0 +1,398 @@
from __future__ import unicode_literals
from copy import deepcopy
from datetime import datetime
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.ec2 import ec2_backends
from moto.sagemaker import validators
from moto.sts.models import ACCOUNT_ID
from .exceptions import MissingModel
class BaseObject(BaseModel):
def camelCase(self, key):
words = []
for i, word in enumerate(key.split("_")):
words.append(word.title())
return "".join(words)
def gen_response_object(self):
response_object = dict()
for key, value in self.__dict__.items():
if "_" in key:
response_object[self.camelCase(key)] = value
else:
response_object[key[0].upper() + key[1:]] = value
return response_object
@property
def response_object(self):
return self.gen_response_object()
class Model(BaseObject):
def __init__(
self,
region_name,
model_name,
execution_role_arn,
primary_container,
vpc_config,
containers=[],
tags=[],
):
self.model_name = model_name
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.containers = containers
self.tags = tags
self.enable_network_isolation = False
self.vpc_config = vpc_config
self.primary_container = primary_container
self.execution_role_arn = execution_role_arn or "arn:test"
self.model_arn = self.arn_for_model_name(self.model_name, region_name)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"ModelArn": self.model_arn}
@staticmethod
def arn_for_model_name(model_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":model/"
+ model_name
)
class VpcConfig(BaseObject):
def __init__(self, security_group_ids, subnets):
self.security_group_ids = security_group_ids
self.subnets = subnets
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
class Container(BaseObject):
def __init__(self, **kwargs):
self.container_hostname = kwargs.get("container_hostname", "localhost")
self.model_data_url = kwargs.get("data_url", "")
self.model_package_name = kwargs.get("package_name", "pkg")
self.image = kwargs.get("image", "")
self.environment = kwargs.get("environment", {})
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
class FakeSagemakerNotebookInstance:
def __init__(
self,
region_name,
notebook_instance_name,
instance_type,
role_arn,
subnet_id,
security_group_ids,
kms_key_id,
tags,
lifecycle_config_name,
direct_internet_access,
volume_size_in_gb,
accelerator_types,
default_code_repository,
additional_code_repositories,
root_access,
):
self.validate_volume_size_in_gb(volume_size_in_gb)
self.validate_instance_type(instance_type)
self.region_name = region_name
self.notebook_instance_name = notebook_instance_name
self.instance_type = instance_type
self.role_arn = role_arn
self.subnet_id = subnet_id
self.security_group_ids = security_group_ids
self.kms_key_id = kms_key_id
self.tags = tags or []
self.lifecycle_config_name = lifecycle_config_name
self.direct_internet_access = direct_internet_access
self.volume_size_in_gb = volume_size_in_gb
self.accelerator_types = accelerator_types
self.default_code_repository = default_code_repository
self.additional_code_repositories = additional_code_repositories
self.root_access = root_access
self.status = None
self.creation_time = self.last_modified_time = datetime.now()
self.start()
def validate_volume_size_in_gb(self, volume_size_in_gb):
if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True):
message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def validate_instance_type(self, instance_type):
VALID_INSTANCE_TYPES = [
"ml.p2.xlarge",
"ml.m5.4xlarge",
"ml.m4.16xlarge",
"ml.t3.xlarge",
"ml.p3.16xlarge",
"ml.t2.xlarge",
"ml.p2.16xlarge",
"ml.c4.2xlarge",
"ml.c5.2xlarge",
"ml.c4.4xlarge",
"ml.c5d.2xlarge",
"ml.c5.4xlarge",
"ml.c5d.4xlarge",
"ml.c4.8xlarge",
"ml.c5d.xlarge",
"ml.c5.9xlarge",
"ml.c5.xlarge",
"ml.c5d.9xlarge",
"ml.c4.xlarge",
"ml.t2.2xlarge",
"ml.c5d.18xlarge",
"ml.t3.2xlarge",
"ml.t3.medium",
"ml.t2.medium",
"ml.c5.18xlarge",
"ml.p3.2xlarge",
"ml.m5.xlarge",
"ml.m4.10xlarge",
"ml.t2.large",
"ml.m5.12xlarge",
"ml.m4.xlarge",
"ml.t3.large",
"ml.m5.24xlarge",
"ml.m4.2xlarge",
"ml.p2.8xlarge",
"ml.m5.2xlarge",
"ml.p3.8xlarge",
"ml.m4.4xlarge",
]
if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
instance_type, VALID_INSTANCE_TYPES
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
@property
def arn(self):
return (
"arn:aws:sagemaker:"
+ self.region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":notebook-instance/"
+ self.notebook_instance_name
)
@property
def url(self):
return "{}.notebook.{}.sagemaker.aws".format(
self.notebook_instance_name, self.region_name
)
def start(self):
self.status = "InService"
@property
def is_deletable(self):
return self.status in ["Stopped", "Failed"]
def stop(self):
self.status = "Stopped"
class SageMakerModelBackend(BaseBackend):
def __init__(self, region_name=None):
self._models = {}
self.notebook_instances = {}
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_model(self, **kwargs):
model_obj = Model(
region_name=self.region_name,
model_name=kwargs.get("ModelName"),
execution_role_arn=kwargs.get("ExecutionRoleArn"),
primary_container=kwargs.get("PrimaryContainer", {}),
vpc_config=kwargs.get("VpcConfig", {}),
containers=kwargs.get("Containers", []),
tags=kwargs.get("Tags", []),
)
self._models[kwargs.get("ModelName")] = model_obj
return model_obj.response_create
def describe_model(self, model_name=None):
model = self._models.get(model_name)
if model:
return model.response_object
message = "Could not find model '{}'.".format(
Model.arn_for_model_name(model_name, self.region_name)
)
raise RESTError(
error_type="ValidationException", message=message, template="error_json",
)
def list_models(self):
models = []
for model in self._models.values():
model_response = deepcopy(model.response_object)
models.append(model_response)
return {"Models": models}
def delete_model(self, model_name=None):
for model in self._models.values():
if model.model_name == model_name:
self._models.pop(model.model_name)
break
else:
raise MissingModel(model=model_name)
def create_notebook_instance(
self,
notebook_instance_name,
instance_type,
role_arn,
subnet_id=None,
security_group_ids=None,
kms_key_id=None,
tags=None,
lifecycle_config_name=None,
direct_internet_access="Enabled",
volume_size_in_gb=5,
accelerator_types=None,
default_code_repository=None,
additional_code_repositories=None,
root_access=None,
):
self._validate_unique_notebook_instance_name(notebook_instance_name)
notebook_instance = FakeSagemakerNotebookInstance(
self.region_name,
notebook_instance_name,
instance_type,
role_arn,
subnet_id=subnet_id,
security_group_ids=security_group_ids,
kms_key_id=kms_key_id,
tags=tags,
lifecycle_config_name=lifecycle_config_name,
direct_internet_access=direct_internet_access
if direct_internet_access is not None
else "Enabled",
volume_size_in_gb=volume_size_in_gb if volume_size_in_gb is not None else 5,
accelerator_types=accelerator_types,
default_code_repository=default_code_repository,
additional_code_repositories=additional_code_repositories,
root_access=root_access,
)
self.notebook_instances[notebook_instance_name] = notebook_instance
return notebook_instance
def _validate_unique_notebook_instance_name(self, notebook_instance_name):
if notebook_instance_name in self.notebook_instances:
duplicate_arn = self.notebook_instances[notebook_instance_name].arn
message = "Cannot create a duplicate Notebook Instance ({})".format(
duplicate_arn
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_notebook_instance(self, notebook_instance_name):
try:
return self.notebook_instances[notebook_instance_name]
except KeyError:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_notebook_instance_by_arn(self, arn):
instances = [
notebook_instance
for notebook_instance in self.notebook_instances.values()
if notebook_instance.arn == arn
]
if len(instances) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
return instances[0]
def start_notebook_instance(self, notebook_instance_name):
notebook_instance = self.get_notebook_instance(notebook_instance_name)
notebook_instance.start()
def stop_notebook_instance(self, notebook_instance_name):
notebook_instance = self.get_notebook_instance(notebook_instance_name)
notebook_instance.stop()
def delete_notebook_instance(self, notebook_instance_name):
notebook_instance = self.get_notebook_instance(notebook_instance_name)
if not notebook_instance.is_deletable:
message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
notebook_instance.status, notebook_instance.arn
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
del self.notebook_instances[notebook_instance_name]
def get_notebook_instance_tags(self, arn):
try:
notebook_instance = self.get_notebook_instance_by_arn(arn)
return notebook_instance.tags or []
except RESTError:
return []
sagemaker_backends = {}
for region, ec2_backend in ec2_backends.items():
sagemaker_backends[region] = SageMakerModelBackend(region)

127
moto/sagemaker/responses.py Normal file
View File

@ -0,0 +1,127 @@
from __future__ import unicode_literals
import json
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import sagemaker_backends
class SageMakerResponse(BaseResponse):
@property
def sagemaker_backend(self):
return sagemaker_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def describe_model(self):
model_name = self._get_param("ModelName")
response = self.sagemaker_backend.describe_model(model_name)
return json.dumps(response)
def create_model(self):
response = self.sagemaker_backend.create_model(**self.request_params)
return json.dumps(response)
def delete_model(self):
model_name = self._get_param("ModelName")
response = self.sagemaker_backend.delete_model(model_name)
return json.dumps(response)
def list_models(self):
response = self.sagemaker_backend.list_models(**self.request_params)
return json.dumps(response)
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
@amzn_request_id
def create_notebook_instance(self):
try:
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
notebook_instance_name=self._get_param("NotebookInstanceName"),
instance_type=self._get_param("InstanceType"),
subnet_id=self._get_param("SubnetId"),
security_group_ids=self._get_param("SecurityGroupIds"),
role_arn=self._get_param("RoleArn"),
kms_key_id=self._get_param("KmsKeyId"),
tags=self._get_param("Tags"),
lifecycle_config_name=self._get_param("LifecycleConfigName"),
direct_internet_access=self._get_param("DirectInternetAccess"),
volume_size_in_gb=self._get_param("VolumeSizeInGB"),
accelerator_types=self._get_param("AcceleratorTypes"),
default_code_repository=self._get_param("DefaultCodeRepository"),
additional_code_repositories=self._get_param(
"AdditionalCodeRepositories"
),
root_access=self._get_param("RootAccess"),
)
response = {
"NotebookInstanceArn": sagemaker_notebook.arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName")
try:
notebook_instance = self.sagemaker_backend.get_notebook_instance(
notebook_instance_name
)
response = {
"NotebookInstanceArn": notebook_instance.arn,
"NotebookInstanceName": notebook_instance.notebook_instance_name,
"NotebookInstanceStatus": notebook_instance.status,
"Url": notebook_instance.url,
"InstanceType": notebook_instance.instance_type,
"SubnetId": notebook_instance.subnet_id,
"SecurityGroups": notebook_instance.security_group_ids,
"RoleArn": notebook_instance.role_arn,
"KmsKeyId": notebook_instance.kms_key_id,
# ToDo: NetworkInterfaceId
"LastModifiedTime": str(notebook_instance.last_modified_time),
"CreationTime": str(notebook_instance.creation_time),
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
"DirectInternetAccess": notebook_instance.direct_internet_access,
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
"AcceleratorTypes": notebook_instance.accelerator_types,
"DefaultCodeRepository": notebook_instance.default_code_repository,
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
"RootAccess": notebook_instance.root_access,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def start_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName")
self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def stop_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName")
self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def delete_notebook_instance(self):
notebook_instance_name = self._get_param("NotebookInstanceName")
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def list_tags(self):
arn = self._get_param("ResourceArn")
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
response = {"Tags": tags}
return 200, {}, json.dumps(response)

11
moto/sagemaker/urls.py Normal file
View File

@ -0,0 +1,11 @@
from __future__ import unicode_literals
from .responses import SageMakerResponse
url_bases = [
"https?://api.sagemaker.(.+).amazonaws.com",
"https?://api-fips.sagemaker.(.+).amazonaws.com",
]
url_paths = {
"{0}/$": SageMakerResponse.dispatch,
}

View File

@ -0,0 +1,20 @@
def is_integer_between(x, mn=None, mx=None, optional=False):
if optional and x is None:
return True
try:
if mn is not None and mx is not None:
return int(x) >= mn and int(x) < mx
elif mn is not None:
return int(x) >= mn
elif mx is not None:
return int(x) < mx
else:
return True
except ValueError:
return False
def is_one_of(x, choices, optional=False):
if optional and x is None:
return True
return x in choices

View File

@ -102,6 +102,10 @@ class DomainDispatcherApplication(object):
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
elif service == "sagemaker":
host = "api.sagemaker.{region}.amazonaws.com".format(
service=service, region=region
)
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region

View File

View File

@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import boto3
import tests.backport_assert_raises # noqa
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_sagemaker
import sure # noqa
from moto.sagemaker.models import VpcConfig
class MySageMakerModel(object):
def __init__(self, name, arn, container=None, vpc_config=None):
self.name = name
self.arn = arn
self.container = container if container else {}
self.vpc_config = (
vpc_config if vpc_config else {"sg-groups": ["sg-123"], "subnets": ["123"]}
)
def save(self):
client = boto3.client("sagemaker", region_name="us-east-1")
vpc_config = VpcConfig(
self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
)
client.create_model(
ModelName=self.name,
ExecutionRoleArn=self.arn,
VpcConfig=vpc_config.response_object,
)
@mock_sagemaker
def test_describe_model():
client = boto3.client("sagemaker", region_name="us-east-1")
test_model = MySageMakerModel(
name="blah",
arn="arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar",
vpc_config={"sg-groups": ["sg-123"], "subnets": ["123"]},
)
test_model.save()
model = client.describe_model(ModelName="blah")
assert model.get("ModelName").should.equal("blah")
@mock_sagemaker
def test_create_model():
client = boto3.client("sagemaker", region_name="us-east-1")
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
exec_role_arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
name = "blah"
model = client.create_model(
ModelName=name,
ExecutionRoleArn=exec_role_arn,
VpcConfig=vpc_config.response_object,
)
model["ModelArn"].should.match(r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name))
@mock_sagemaker
def test_delete_model():
client = boto3.client("sagemaker", region_name="us-east-1")
name = "blah"
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
test_model = MySageMakerModel(name=name, arn=arn)
test_model.save()
assert len(client.list_models()["Models"]).should.equal(1)
client.delete_model(ModelName=name)
assert len(client.list_models()["Models"]).should.equal(0)
@mock_sagemaker
def test_delete_model_not_found():
with assert_raises(ClientError) as err:
boto3.client("sagemaker", region_name="us-east-1").delete_model(
ModelName="blah"
)
assert err.exception.response["Error"]["Code"].should.equal("404")
@mock_sagemaker
def test_list_models():
client = boto3.client("sagemaker", region_name="us-east-1")
name = "blah"
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
test_model = MySageMakerModel(name=name, arn=arn)
test_model.save()
models = client.list_models()
assert len(models["Models"]).should.equal(1)
assert models["Models"][0]["ModelName"].should.equal(name)
assert models["Models"][0]["ModelArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name)
)
@mock_sagemaker
def test_list_models_multiple():
client = boto3.client("sagemaker", region_name="us-east-1")
name_model_1 = "blah"
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1)
test_model_1.save()
name_model_2 = "blah2"
arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2"
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
test_model_2.save()
models = client.list_models()
assert len(models["Models"]).should.equal(2)
@mock_sagemaker
def test_list_models_none():
client = boto3.client("sagemaker", region_name="us-east-1")
models = client.list_models()
assert len(models["Models"]).should.equal(0)

View File

@ -0,0 +1,227 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import datetime
import boto3
from botocore.exceptions import ClientError, ParamValidationError
import sure # noqa
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from nose.tools import assert_true, assert_equal, assert_raises
TEST_REGION_NAME = "us-east-1"
FAKE_SUBNET_ID = "subnet-012345678"
FAKE_SECURITY_GROUP_IDS = ["sg-0123456789abcdef0", "sg-0123456789abcdef1"]
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
FAKE_KMS_KEY_ID = "62d4509a-9f96-446c-a9ba-6b1c353c8c58"
GENERIC_TAGS_PARAM = [
{"Key": "newkey1", "Value": "newval1"},
{"Key": "newkey2", "Value": "newval2"},
]
FAKE_LIFECYCLE_CONFIG_NAME = "FakeLifecycleConfigName"
FAKE_DEFAULT_CODE_REPO = "https://github.com/user/repo1"
FAKE_ADDL_CODE_REPOS = [
"https://github.com/user/repo2",
"https://github.com/user/repo2",
]
@mock_sagemaker
def test_create_notebook_instance_minimal_params():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
NAME_PARAM = "MyNotebookInstance"
INSTANCE_TYPE_PARAM = "ml.t2.medium"
args = {
"NotebookInstanceName": NAME_PARAM,
"InstanceType": INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker.create_notebook_instance(**args)
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
assert_equal(resp["NotebookInstanceName"], NAME_PARAM)
assert_equal(resp["NotebookInstanceStatus"], "InService")
assert_equal(
resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME)
)
assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM)
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
assert_equal(resp["DirectInternetAccess"], "Enabled")
assert_equal(resp["VolumeSizeInGB"], 5)
# assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults...
@mock_sagemaker
def test_create_notebook_instance_params():
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
NAME_PARAM = "MyNotebookInstance"
INSTANCE_TYPE_PARAM = "ml.t2.medium"
DIRECT_INTERNET_ACCESS_PARAM = "Enabled"
VOLUME_SIZE_IN_GB_PARAM = 7
ACCELERATOR_TYPES_PARAM = ["ml.eia1.medium", "ml.eia2.medium"]
ROOT_ACCESS_PARAM = "Disabled"
args = {
"NotebookInstanceName": NAME_PARAM,
"InstanceType": INSTANCE_TYPE_PARAM,
"SubnetId": FAKE_SUBNET_ID,
"SecurityGroupIds": FAKE_SECURITY_GROUP_IDS,
"RoleArn": FAKE_ROLE_ARN,
"KmsKeyId": FAKE_KMS_KEY_ID,
"Tags": GENERIC_TAGS_PARAM,
"LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME,
"DirectInternetAccess": DIRECT_INTERNET_ACCESS_PARAM,
"VolumeSizeInGB": VOLUME_SIZE_IN_GB_PARAM,
"AcceleratorTypes": ACCELERATOR_TYPES_PARAM,
"DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO,
"AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS,
"RootAccess": ROOT_ACCESS_PARAM,
}
resp = sagemaker.create_notebook_instance(**args)
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
assert_equal(resp["NotebookInstanceName"], NAME_PARAM)
assert_equal(resp["NotebookInstanceStatus"], "InService")
assert_equal(
resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME)
)
assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM)
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
assert_equal(resp["DirectInternetAccess"], "Enabled")
assert_equal(resp["VolumeSizeInGB"], VOLUME_SIZE_IN_GB_PARAM)
# assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults...
assert_equal(resp["SubnetId"], FAKE_SUBNET_ID)
assert_equal(resp["SecurityGroups"], FAKE_SECURITY_GROUP_IDS)
assert_equal(resp["KmsKeyId"], FAKE_KMS_KEY_ID)
assert_equal(
resp["NotebookInstanceLifecycleConfigName"], FAKE_LIFECYCLE_CONFIG_NAME
)
assert_equal(resp["AcceleratorTypes"], ACCELERATOR_TYPES_PARAM)
assert_equal(resp["DefaultCodeRepository"], FAKE_DEFAULT_CODE_REPO)
assert_equal(resp["AdditionalCodeRepositories"], FAKE_ADDL_CODE_REPOS)
resp = sagemaker.list_tags(ResourceArn=resp["NotebookInstanceArn"])
assert_equal(resp["Tags"], GENERIC_TAGS_PARAM)
@mock_sagemaker
def test_create_notebook_instance_bad_volume_size():
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
vol_size = 2
args = {
"NotebookInstanceName": "MyNotebookInstance",
"InstanceType": "ml.t2.medium",
"RoleArn": FAKE_ROLE_ARN,
"VolumeSizeInGB": vol_size,
}
with assert_raises(ParamValidationError) as ex:
resp = sagemaker.create_notebook_instance(**args)
assert_equal(
ex.exception.args[0],
"Parameter validation failed:\nInvalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf".format(
vol_size
),
)
@mock_sagemaker
def test_create_notebook_instance_invalid_instance_type():
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
instance_type = "undefined_instance_type"
args = {
"NotebookInstanceName": "MyNotebookInstance",
"InstanceType": instance_type,
"RoleArn": FAKE_ROLE_ARN,
}
with assert_raises(ClientError) as ex:
resp = sagemaker.create_notebook_instance(**args)
assert_equal(ex.exception.response["Error"]["Code"], "ValidationException")
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
instance_type
)
assert_true(expected_message in ex.exception.response["Error"]["Message"])
@mock_sagemaker
def test_notebook_instance_lifecycle():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
NAME_PARAM = "MyNotebookInstance"
INSTANCE_TYPE_PARAM = "ml.t2.medium"
args = {
"NotebookInstanceName": NAME_PARAM,
"InstanceType": INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker.create_notebook_instance(**args)
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
notebook_instance_arn = resp["NotebookInstanceArn"]
with assert_raises(ClientError) as ex:
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_equal(ex.exception.response["Error"]["Code"], "ValidationException")
expected_message = "Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
notebook_instance_arn
)
assert_true(expected_message in ex.exception.response["Error"]["Message"])
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_equal(resp["NotebookInstanceStatus"], "Stopped")
sagemaker.start_notebook_instance(NotebookInstanceName=NAME_PARAM)
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_equal(resp["NotebookInstanceStatus"], "InService")
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_equal(resp["NotebookInstanceStatus"], "Stopped")
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
with assert_raises(ClientError) as ex:
sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
assert_equal(ex.exception.response["Error"]["Message"], "RecordNotFound")
@mock_sagemaker
def test_describe_nonexistent_model():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
with assert_raises(ClientError) as e:
resp = sagemaker.describe_model(ModelName="Nonexistent")
assert_true(
e.exception.response["Error"]["Message"].startswith("Could not find model")
)