Sagemaker models (#3105)
* First failing test, and enough framework to run it. * Rudimentary passing test. * Sagemaker Notebook Support, take-1: create, describe, start, stop, delete. * Added list_tags. * Merged in model support from https://github.com/porthunt/moto/tree/sagemaker-support. * Re-org'd * Fixed up describe_model exception when no matching model. * Segregated tests by Sagemaker entity. Model arn check by regex.. * Python2 compabitility changes. * Added sagemaker to list of known backends. Corrected urls. * Added sagemaker special case to moto.server.infer_service_region_host due to irregular url format (use of 'api' subdomain) to support server mode. * Changes for PR 3105 comments of July 10, 2020 * PR3105 July 10, 2020, 8:55 AM EDT comment: dropped unnecessary re-addition of arn when formulating model list response. * PR 3105 July 15, 2020 9:10 AM EDT Comment: clean-up SageMakerModelBackend.describe_models logic for finding the model in the dict. * Optimized imports Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
3e2a5e7ee8
commit
1b80b0a810
@ -95,6 +95,7 @@ mock_route53 = lazy_load(".route53", "mock_route53")
|
||||
mock_route53_deprecated = lazy_load(".route53", "mock_route53_deprecated")
|
||||
mock_s3 = lazy_load(".s3", "mock_s3")
|
||||
mock_s3_deprecated = lazy_load(".s3", "mock_s3_deprecated")
|
||||
mock_sagemaker = lazy_load(".sagemaker", "mock_sagemaker")
|
||||
mock_secretsmanager = lazy_load(".secretsmanager", "mock_secretsmanager")
|
||||
mock_ses = lazy_load(".ses", "mock_ses")
|
||||
mock_ses_deprecated = lazy_load(".ses", "mock_ses_deprecated")
|
||||
|
@ -58,6 +58,7 @@ BACKENDS = {
|
||||
"route53": ("route53", "route53_backends"),
|
||||
"s3": ("s3", "s3_backends"),
|
||||
"s3bucket_path": ("s3", "s3_backends"),
|
||||
"sagemaker": ("sagemaker", "sagemaker_backends"),
|
||||
"secretsmanager": ("secretsmanager", "secretsmanager_backends"),
|
||||
"ses": ("ses", "ses_backends"),
|
||||
"sns": ("sns", "sns_backends"),
|
||||
|
5
moto/sagemaker/__init__.py
Normal file
5
moto/sagemaker/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from __future__ import unicode_literals
|
||||
from .models import sagemaker_backends
|
||||
|
||||
sagemaker_backend = sagemaker_backends["us-east-1"]
|
||||
mock_sagemaker = sagemaker_backend.decorator
|
47
moto/sagemaker/exceptions.py
Normal file
47
moto/sagemaker/exceptions.py
Normal file
@ -0,0 +1,47 @@
|
||||
from __future__ import unicode_literals
|
||||
import json
|
||||
from moto.core.exceptions import RESTError
|
||||
|
||||
|
||||
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
||||
{% block extra %}<ModelName>{{ model }}</ModelName>{% endblock %}
|
||||
"""
|
||||
|
||||
|
||||
class SagemakerClientError(RESTError):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("template", "single_error")
|
||||
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
||||
super(SagemakerClientError, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ModelError(RESTError):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("template", "model_error")
|
||||
self.templates["model_error"] = ERROR_WITH_MODEL_NAME
|
||||
super(ModelError, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class MissingModel(ModelError):
|
||||
code = 404
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MissingModel, self).__init__(
|
||||
"NoSuchModel", "Could not find model", *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class AWSError(Exception):
|
||||
TYPE = None
|
||||
STATUS = 400
|
||||
|
||||
def __init__(self, message, type=None, status=None):
|
||||
self.message = message
|
||||
self.type = type if type is not None else self.TYPE
|
||||
self.status = status if status is not None else self.STATUS
|
||||
|
||||
def response(self):
|
||||
return (
|
||||
json.dumps({"__type": self.type, "message": self.message}),
|
||||
dict(status=self.status),
|
||||
)
|
398
moto/sagemaker/models.py
Normal file
398
moto/sagemaker/models.py
Normal file
@ -0,0 +1,398 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.ec2 import ec2_backends
|
||||
from moto.sagemaker import validators
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
from .exceptions import MissingModel
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
def camelCase(self, key):
|
||||
words = []
|
||||
for i, word in enumerate(key.split("_")):
|
||||
words.append(word.title())
|
||||
return "".join(words)
|
||||
|
||||
def gen_response_object(self):
|
||||
response_object = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
if "_" in key:
|
||||
response_object[self.camelCase(key)] = value
|
||||
else:
|
||||
response_object[key[0].upper() + key[1:]] = value
|
||||
return response_object
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
return self.gen_response_object()
|
||||
|
||||
|
||||
class Model(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
model_name,
|
||||
execution_role_arn,
|
||||
primary_container,
|
||||
vpc_config,
|
||||
containers=[],
|
||||
tags=[],
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.containers = containers
|
||||
self.tags = tags
|
||||
self.enable_network_isolation = False
|
||||
self.vpc_config = vpc_config
|
||||
self.primary_container = primary_container
|
||||
self.execution_role_arn = execution_role_arn or "arn:test"
|
||||
self.model_arn = self.arn_for_model_name(self.model_name, region_name)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"ModelArn": self.model_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_for_model_name(model_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":model/"
|
||||
+ model_name
|
||||
)
|
||||
|
||||
|
||||
class VpcConfig(BaseObject):
|
||||
def __init__(self, security_group_ids, subnets):
|
||||
self.security_group_ids = security_group_ids
|
||||
self.subnets = subnets
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
|
||||
class Container(BaseObject):
|
||||
def __init__(self, **kwargs):
|
||||
self.container_hostname = kwargs.get("container_hostname", "localhost")
|
||||
self.model_data_url = kwargs.get("data_url", "")
|
||||
self.model_package_name = kwargs.get("package_name", "pkg")
|
||||
self.image = kwargs.get("image", "")
|
||||
self.environment = kwargs.get("environment", {})
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
|
||||
class FakeSagemakerNotebookInstance:
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
notebook_instance_name,
|
||||
instance_type,
|
||||
role_arn,
|
||||
subnet_id,
|
||||
security_group_ids,
|
||||
kms_key_id,
|
||||
tags,
|
||||
lifecycle_config_name,
|
||||
direct_internet_access,
|
||||
volume_size_in_gb,
|
||||
accelerator_types,
|
||||
default_code_repository,
|
||||
additional_code_repositories,
|
||||
root_access,
|
||||
):
|
||||
self.validate_volume_size_in_gb(volume_size_in_gb)
|
||||
self.validate_instance_type(instance_type)
|
||||
|
||||
self.region_name = region_name
|
||||
self.notebook_instance_name = notebook_instance_name
|
||||
self.instance_type = instance_type
|
||||
self.role_arn = role_arn
|
||||
self.subnet_id = subnet_id
|
||||
self.security_group_ids = security_group_ids
|
||||
self.kms_key_id = kms_key_id
|
||||
self.tags = tags or []
|
||||
self.lifecycle_config_name = lifecycle_config_name
|
||||
self.direct_internet_access = direct_internet_access
|
||||
self.volume_size_in_gb = volume_size_in_gb
|
||||
self.accelerator_types = accelerator_types
|
||||
self.default_code_repository = default_code_repository
|
||||
self.additional_code_repositories = additional_code_repositories
|
||||
self.root_access = root_access
|
||||
self.status = None
|
||||
self.creation_time = self.last_modified_time = datetime.now()
|
||||
self.start()
|
||||
|
||||
def validate_volume_size_in_gb(self, volume_size_in_gb):
|
||||
if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True):
|
||||
message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def validate_instance_type(self, instance_type):
|
||||
VALID_INSTANCE_TYPES = [
|
||||
"ml.p2.xlarge",
|
||||
"ml.m5.4xlarge",
|
||||
"ml.m4.16xlarge",
|
||||
"ml.t3.xlarge",
|
||||
"ml.p3.16xlarge",
|
||||
"ml.t2.xlarge",
|
||||
"ml.p2.16xlarge",
|
||||
"ml.c4.2xlarge",
|
||||
"ml.c5.2xlarge",
|
||||
"ml.c4.4xlarge",
|
||||
"ml.c5d.2xlarge",
|
||||
"ml.c5.4xlarge",
|
||||
"ml.c5d.4xlarge",
|
||||
"ml.c4.8xlarge",
|
||||
"ml.c5d.xlarge",
|
||||
"ml.c5.9xlarge",
|
||||
"ml.c5.xlarge",
|
||||
"ml.c5d.9xlarge",
|
||||
"ml.c4.xlarge",
|
||||
"ml.t2.2xlarge",
|
||||
"ml.c5d.18xlarge",
|
||||
"ml.t3.2xlarge",
|
||||
"ml.t3.medium",
|
||||
"ml.t2.medium",
|
||||
"ml.c5.18xlarge",
|
||||
"ml.p3.2xlarge",
|
||||
"ml.m5.xlarge",
|
||||
"ml.m4.10xlarge",
|
||||
"ml.t2.large",
|
||||
"ml.m5.12xlarge",
|
||||
"ml.m4.xlarge",
|
||||
"ml.t3.large",
|
||||
"ml.m5.24xlarge",
|
||||
"ml.m4.2xlarge",
|
||||
"ml.p2.8xlarge",
|
||||
"ml.m5.2xlarge",
|
||||
"ml.p3.8xlarge",
|
||||
"ml.m4.4xlarge",
|
||||
]
|
||||
if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
|
||||
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
|
||||
instance_type, VALID_INSTANCE_TYPES
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
@property
|
||||
def arn(self):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ self.region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":notebook-instance/"
|
||||
+ self.notebook_instance_name
|
||||
)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return "{}.notebook.{}.sagemaker.aws".format(
|
||||
self.notebook_instance_name, self.region_name
|
||||
)
|
||||
|
||||
def start(self):
|
||||
self.status = "InService"
|
||||
|
||||
@property
|
||||
def is_deletable(self):
|
||||
return self.status in ["Stopped", "Failed"]
|
||||
|
||||
def stop(self):
|
||||
self.status = "Stopped"
|
||||
|
||||
|
||||
class SageMakerModelBackend(BaseBackend):
|
||||
def __init__(self, region_name=None):
|
||||
self._models = {}
|
||||
self.notebook_instances = {}
|
||||
self.region_name = region_name
|
||||
|
||||
def reset(self):
|
||||
region_name = self.region_name
|
||||
self.__dict__ = {}
|
||||
self.__init__(region_name)
|
||||
|
||||
def create_model(self, **kwargs):
|
||||
model_obj = Model(
|
||||
region_name=self.region_name,
|
||||
model_name=kwargs.get("ModelName"),
|
||||
execution_role_arn=kwargs.get("ExecutionRoleArn"),
|
||||
primary_container=kwargs.get("PrimaryContainer", {}),
|
||||
vpc_config=kwargs.get("VpcConfig", {}),
|
||||
containers=kwargs.get("Containers", []),
|
||||
tags=kwargs.get("Tags", []),
|
||||
)
|
||||
|
||||
self._models[kwargs.get("ModelName")] = model_obj
|
||||
return model_obj.response_create
|
||||
|
||||
def describe_model(self, model_name=None):
|
||||
model = self._models.get(model_name)
|
||||
if model:
|
||||
return model.response_object
|
||||
message = "Could not find model '{}'.".format(
|
||||
Model.arn_for_model_name(model_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException", message=message, template="error_json",
|
||||
)
|
||||
|
||||
def list_models(self):
|
||||
models = []
|
||||
for model in self._models.values():
|
||||
model_response = deepcopy(model.response_object)
|
||||
models.append(model_response)
|
||||
return {"Models": models}
|
||||
|
||||
def delete_model(self, model_name=None):
|
||||
for model in self._models.values():
|
||||
if model.model_name == model_name:
|
||||
self._models.pop(model.model_name)
|
||||
break
|
||||
else:
|
||||
raise MissingModel(model=model_name)
|
||||
|
||||
def create_notebook_instance(
|
||||
self,
|
||||
notebook_instance_name,
|
||||
instance_type,
|
||||
role_arn,
|
||||
subnet_id=None,
|
||||
security_group_ids=None,
|
||||
kms_key_id=None,
|
||||
tags=None,
|
||||
lifecycle_config_name=None,
|
||||
direct_internet_access="Enabled",
|
||||
volume_size_in_gb=5,
|
||||
accelerator_types=None,
|
||||
default_code_repository=None,
|
||||
additional_code_repositories=None,
|
||||
root_access=None,
|
||||
):
|
||||
self._validate_unique_notebook_instance_name(notebook_instance_name)
|
||||
|
||||
notebook_instance = FakeSagemakerNotebookInstance(
|
||||
self.region_name,
|
||||
notebook_instance_name,
|
||||
instance_type,
|
||||
role_arn,
|
||||
subnet_id=subnet_id,
|
||||
security_group_ids=security_group_ids,
|
||||
kms_key_id=kms_key_id,
|
||||
tags=tags,
|
||||
lifecycle_config_name=lifecycle_config_name,
|
||||
direct_internet_access=direct_internet_access
|
||||
if direct_internet_access is not None
|
||||
else "Enabled",
|
||||
volume_size_in_gb=volume_size_in_gb if volume_size_in_gb is not None else 5,
|
||||
accelerator_types=accelerator_types,
|
||||
default_code_repository=default_code_repository,
|
||||
additional_code_repositories=additional_code_repositories,
|
||||
root_access=root_access,
|
||||
)
|
||||
self.notebook_instances[notebook_instance_name] = notebook_instance
|
||||
return notebook_instance
|
||||
|
||||
def _validate_unique_notebook_instance_name(self, notebook_instance_name):
|
||||
if notebook_instance_name in self.notebook_instances:
|
||||
duplicate_arn = self.notebook_instances[notebook_instance_name].arn
|
||||
message = "Cannot create a duplicate Notebook Instance ({})".format(
|
||||
duplicate_arn
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def get_notebook_instance(self, notebook_instance_name):
|
||||
try:
|
||||
return self.notebook_instances[notebook_instance_name]
|
||||
except KeyError:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def get_notebook_instance_by_arn(self, arn):
|
||||
instances = [
|
||||
notebook_instance
|
||||
for notebook_instance in self.notebook_instances.values()
|
||||
if notebook_instance.arn == arn
|
||||
]
|
||||
if len(instances) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
return instances[0]
|
||||
|
||||
def start_notebook_instance(self, notebook_instance_name):
|
||||
notebook_instance = self.get_notebook_instance(notebook_instance_name)
|
||||
notebook_instance.start()
|
||||
|
||||
def stop_notebook_instance(self, notebook_instance_name):
|
||||
notebook_instance = self.get_notebook_instance(notebook_instance_name)
|
||||
notebook_instance.stop()
|
||||
|
||||
def delete_notebook_instance(self, notebook_instance_name):
|
||||
notebook_instance = self.get_notebook_instance(notebook_instance_name)
|
||||
if not notebook_instance.is_deletable:
|
||||
message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
|
||||
notebook_instance.status, notebook_instance.arn
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
del self.notebook_instances[notebook_instance_name]
|
||||
|
||||
def get_notebook_instance_tags(self, arn):
|
||||
try:
|
||||
notebook_instance = self.get_notebook_instance_by_arn(arn)
|
||||
return notebook_instance.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
|
||||
sagemaker_backends = {}
|
||||
for region, ec2_backend in ec2_backends.items():
|
||||
sagemaker_backends[region] = SageMakerModelBackend(region)
|
127
moto/sagemaker/responses.py
Normal file
127
moto/sagemaker/responses.py
Normal file
@ -0,0 +1,127 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.utils import amzn_request_id
|
||||
from .exceptions import AWSError
|
||||
from .models import sagemaker_backends
|
||||
|
||||
|
||||
class SageMakerResponse(BaseResponse):
|
||||
@property
|
||||
def sagemaker_backend(self):
|
||||
return sagemaker_backends[self.region]
|
||||
|
||||
@property
|
||||
def request_params(self):
|
||||
try:
|
||||
return json.loads(self.body)
|
||||
except ValueError:
|
||||
return {}
|
||||
|
||||
def describe_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
response = self.sagemaker_backend.describe_model(model_name)
|
||||
return json.dumps(response)
|
||||
|
||||
def create_model(self):
|
||||
response = self.sagemaker_backend.create_model(**self.request_params)
|
||||
return json.dumps(response)
|
||||
|
||||
def delete_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
response = self.sagemaker_backend.delete_model(model_name)
|
||||
return json.dumps(response)
|
||||
|
||||
def list_models(self):
|
||||
response = self.sagemaker_backend.list_models(**self.request_params)
|
||||
return json.dumps(response)
|
||||
|
||||
def _get_param(self, param, if_none=None):
|
||||
return self.request_params.get(param, if_none)
|
||||
|
||||
@amzn_request_id
|
||||
def create_notebook_instance(self):
|
||||
try:
|
||||
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
|
||||
notebook_instance_name=self._get_param("NotebookInstanceName"),
|
||||
instance_type=self._get_param("InstanceType"),
|
||||
subnet_id=self._get_param("SubnetId"),
|
||||
security_group_ids=self._get_param("SecurityGroupIds"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
kms_key_id=self._get_param("KmsKeyId"),
|
||||
tags=self._get_param("Tags"),
|
||||
lifecycle_config_name=self._get_param("LifecycleConfigName"),
|
||||
direct_internet_access=self._get_param("DirectInternetAccess"),
|
||||
volume_size_in_gb=self._get_param("VolumeSizeInGB"),
|
||||
accelerator_types=self._get_param("AcceleratorTypes"),
|
||||
default_code_repository=self._get_param("DefaultCodeRepository"),
|
||||
additional_code_repositories=self._get_param(
|
||||
"AdditionalCodeRepositories"
|
||||
),
|
||||
root_access=self._get_param("RootAccess"),
|
||||
)
|
||||
response = {
|
||||
"NotebookInstanceArn": sagemaker_notebook.arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
try:
|
||||
notebook_instance = self.sagemaker_backend.get_notebook_instance(
|
||||
notebook_instance_name
|
||||
)
|
||||
response = {
|
||||
"NotebookInstanceArn": notebook_instance.arn,
|
||||
"NotebookInstanceName": notebook_instance.notebook_instance_name,
|
||||
"NotebookInstanceStatus": notebook_instance.status,
|
||||
"Url": notebook_instance.url,
|
||||
"InstanceType": notebook_instance.instance_type,
|
||||
"SubnetId": notebook_instance.subnet_id,
|
||||
"SecurityGroups": notebook_instance.security_group_ids,
|
||||
"RoleArn": notebook_instance.role_arn,
|
||||
"KmsKeyId": notebook_instance.kms_key_id,
|
||||
# ToDo: NetworkInterfaceId
|
||||
"LastModifiedTime": str(notebook_instance.last_modified_time),
|
||||
"CreationTime": str(notebook_instance.creation_time),
|
||||
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
|
||||
"DirectInternetAccess": notebook_instance.direct_internet_access,
|
||||
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
|
||||
"AcceleratorTypes": notebook_instance.accelerator_types,
|
||||
"DefaultCodeRepository": notebook_instance.default_code_repository,
|
||||
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
|
||||
"RootAccess": notebook_instance.root_access,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def start_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def stop_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def delete_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def list_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
11
moto/sagemaker/urls.py
Normal file
11
moto/sagemaker/urls.py
Normal file
@ -0,0 +1,11 @@
|
||||
from __future__ import unicode_literals
|
||||
from .responses import SageMakerResponse
|
||||
|
||||
url_bases = [
|
||||
"https?://api.sagemaker.(.+).amazonaws.com",
|
||||
"https?://api-fips.sagemaker.(.+).amazonaws.com",
|
||||
]
|
||||
|
||||
url_paths = {
|
||||
"{0}/$": SageMakerResponse.dispatch,
|
||||
}
|
20
moto/sagemaker/validators.py
Normal file
20
moto/sagemaker/validators.py
Normal file
@ -0,0 +1,20 @@
|
||||
def is_integer_between(x, mn=None, mx=None, optional=False):
|
||||
if optional and x is None:
|
||||
return True
|
||||
try:
|
||||
if mn is not None and mx is not None:
|
||||
return int(x) >= mn and int(x) < mx
|
||||
elif mn is not None:
|
||||
return int(x) >= mn
|
||||
elif mx is not None:
|
||||
return int(x) < mx
|
||||
else:
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_one_of(x, choices, optional=False):
|
||||
if optional and x is None:
|
||||
return True
|
||||
return x in choices
|
@ -102,6 +102,10 @@ class DomainDispatcherApplication(object):
|
||||
# If Newer API version, use dynamodb2
|
||||
if dynamo_api_version > "20111205":
|
||||
host = "dynamodb2"
|
||||
elif service == "sagemaker":
|
||||
host = "api.sagemaker.{region}.amazonaws.com".format(
|
||||
service=service, region=region
|
||||
)
|
||||
else:
|
||||
host = "{service}.{region}.amazonaws.com".format(
|
||||
service=service, region=region
|
||||
|
0
tests/test_sagemaker/__init__.py
Normal file
0
tests/test_sagemaker/__init__.py
Normal file
122
tests/test_sagemaker/test_sagemaker_models.py
Normal file
122
tests/test_sagemaker/test_sagemaker_models.py
Normal file
@ -0,0 +1,122 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import boto3
|
||||
import tests.backport_assert_raises # noqa
|
||||
from botocore.exceptions import ClientError
|
||||
from nose.tools import assert_raises
|
||||
from moto import mock_sagemaker
|
||||
|
||||
import sure # noqa
|
||||
|
||||
from moto.sagemaker.models import VpcConfig
|
||||
|
||||
|
||||
class MySageMakerModel(object):
|
||||
def __init__(self, name, arn, container=None, vpc_config=None):
|
||||
self.name = name
|
||||
self.arn = arn
|
||||
self.container = container if container else {}
|
||||
self.vpc_config = (
|
||||
vpc_config if vpc_config else {"sg-groups": ["sg-123"], "subnets": ["123"]}
|
||||
)
|
||||
|
||||
def save(self):
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
vpc_config = VpcConfig(
|
||||
self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
|
||||
)
|
||||
client.create_model(
|
||||
ModelName=self.name,
|
||||
ExecutionRoleArn=self.arn,
|
||||
VpcConfig=vpc_config.response_object,
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
test_model = MySageMakerModel(
|
||||
name="blah",
|
||||
arn="arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar",
|
||||
vpc_config={"sg-groups": ["sg-123"], "subnets": ["123"]},
|
||||
)
|
||||
test_model.save()
|
||||
model = client.describe_model(ModelName="blah")
|
||||
assert model.get("ModelName").should.equal("blah")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
|
||||
exec_role_arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
name = "blah"
|
||||
model = client.create_model(
|
||||
ModelName=name,
|
||||
ExecutionRoleArn=exec_role_arn,
|
||||
VpcConfig=vpc_config.response_object,
|
||||
)
|
||||
|
||||
model["ModelArn"].should.match(r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name))
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model = MySageMakerModel(name=name, arn=arn)
|
||||
test_model.save()
|
||||
|
||||
assert len(client.list_models()["Models"]).should.equal(1)
|
||||
client.delete_model(ModelName=name)
|
||||
assert len(client.list_models()["Models"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model_not_found():
|
||||
with assert_raises(ClientError) as err:
|
||||
boto3.client("sagemaker", region_name="us-east-1").delete_model(
|
||||
ModelName="blah"
|
||||
)
|
||||
assert err.exception.response["Error"]["Code"].should.equal("404")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model = MySageMakerModel(name=name, arn=arn)
|
||||
test_model.save()
|
||||
models = client.list_models()
|
||||
assert len(models["Models"]).should.equal(1)
|
||||
assert models["Models"][0]["ModelName"].should.equal(name)
|
||||
assert models["Models"][0]["ModelArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name)
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models_multiple():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
name_model_1 = "blah"
|
||||
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1)
|
||||
test_model_1.save()
|
||||
|
||||
name_model_2 = "blah2"
|
||||
arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2"
|
||||
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
|
||||
test_model_2.save()
|
||||
models = client.list_models()
|
||||
assert len(models["Models"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
models = client.list_models()
|
||||
assert len(models["Models"]).should.equal(0)
|
227
tests/test_sagemaker/test_sagemaker_notebooks.py
Normal file
227
tests/test_sagemaker/test_sagemaker_notebooks.py
Normal file
@ -0,0 +1,227 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, ParamValidationError
|
||||
import sure # noqa
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
from nose.tools import assert_true, assert_equal, assert_raises
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
FAKE_SUBNET_ID = "subnet-012345678"
|
||||
FAKE_SECURITY_GROUP_IDS = ["sg-0123456789abcdef0", "sg-0123456789abcdef1"]
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
FAKE_KMS_KEY_ID = "62d4509a-9f96-446c-a9ba-6b1c353c8c58"
|
||||
GENERIC_TAGS_PARAM = [
|
||||
{"Key": "newkey1", "Value": "newval1"},
|
||||
{"Key": "newkey2", "Value": "newval2"},
|
||||
]
|
||||
FAKE_LIFECYCLE_CONFIG_NAME = "FakeLifecycleConfigName"
|
||||
FAKE_DEFAULT_CODE_REPO = "https://github.com/user/repo1"
|
||||
FAKE_ADDL_CODE_REPOS = [
|
||||
"https://github.com/user/repo2",
|
||||
"https://github.com/user/repo2",
|
||||
]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_minimal_params():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
|
||||
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
|
||||
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
|
||||
assert_equal(resp["NotebookInstanceName"], NAME_PARAM)
|
||||
assert_equal(resp["NotebookInstanceStatus"], "InService")
|
||||
assert_equal(
|
||||
resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME)
|
||||
)
|
||||
assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM)
|
||||
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
|
||||
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
|
||||
assert_equal(resp["DirectInternetAccess"], "Enabled")
|
||||
assert_equal(resp["VolumeSizeInGB"], 5)
|
||||
|
||||
|
||||
# assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults...
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_params():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
DIRECT_INTERNET_ACCESS_PARAM = "Enabled"
|
||||
VOLUME_SIZE_IN_GB_PARAM = 7
|
||||
ACCELERATOR_TYPES_PARAM = ["ml.eia1.medium", "ml.eia2.medium"]
|
||||
ROOT_ACCESS_PARAM = "Disabled"
|
||||
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"SubnetId": FAKE_SUBNET_ID,
|
||||
"SecurityGroupIds": FAKE_SECURITY_GROUP_IDS,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"KmsKeyId": FAKE_KMS_KEY_ID,
|
||||
"Tags": GENERIC_TAGS_PARAM,
|
||||
"LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME,
|
||||
"DirectInternetAccess": DIRECT_INTERNET_ACCESS_PARAM,
|
||||
"VolumeSizeInGB": VOLUME_SIZE_IN_GB_PARAM,
|
||||
"AcceleratorTypes": ACCELERATOR_TYPES_PARAM,
|
||||
"DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO,
|
||||
"AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS,
|
||||
"RootAccess": ROOT_ACCESS_PARAM,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
|
||||
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
|
||||
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
|
||||
assert_equal(resp["NotebookInstanceName"], NAME_PARAM)
|
||||
assert_equal(resp["NotebookInstanceStatus"], "InService")
|
||||
assert_equal(
|
||||
resp["Url"], "{}.notebook.{}.sagemaker.aws".format(NAME_PARAM, TEST_REGION_NAME)
|
||||
)
|
||||
assert_equal(resp["InstanceType"], INSTANCE_TYPE_PARAM)
|
||||
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
|
||||
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
|
||||
assert_equal(resp["DirectInternetAccess"], "Enabled")
|
||||
assert_equal(resp["VolumeSizeInGB"], VOLUME_SIZE_IN_GB_PARAM)
|
||||
# assert_equal(resp["RootAccess"], True) # ToDo: Not sure if this defaults...
|
||||
assert_equal(resp["SubnetId"], FAKE_SUBNET_ID)
|
||||
assert_equal(resp["SecurityGroups"], FAKE_SECURITY_GROUP_IDS)
|
||||
assert_equal(resp["KmsKeyId"], FAKE_KMS_KEY_ID)
|
||||
assert_equal(
|
||||
resp["NotebookInstanceLifecycleConfigName"], FAKE_LIFECYCLE_CONFIG_NAME
|
||||
)
|
||||
assert_equal(resp["AcceleratorTypes"], ACCELERATOR_TYPES_PARAM)
|
||||
assert_equal(resp["DefaultCodeRepository"], FAKE_DEFAULT_CODE_REPO)
|
||||
assert_equal(resp["AdditionalCodeRepositories"], FAKE_ADDL_CODE_REPOS)
|
||||
|
||||
resp = sagemaker.list_tags(ResourceArn=resp["NotebookInstanceArn"])
|
||||
assert_equal(resp["Tags"], GENERIC_TAGS_PARAM)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_bad_volume_size():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
vol_size = 2
|
||||
args = {
|
||||
"NotebookInstanceName": "MyNotebookInstance",
|
||||
"InstanceType": "ml.t2.medium",
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"VolumeSizeInGB": vol_size,
|
||||
}
|
||||
with assert_raises(ParamValidationError) as ex:
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert_equal(
|
||||
ex.exception.args[0],
|
||||
"Parameter validation failed:\nInvalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf".format(
|
||||
vol_size
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_invalid_instance_type():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
instance_type = "undefined_instance_type"
|
||||
args = {
|
||||
"NotebookInstanceName": "MyNotebookInstance",
|
||||
"InstanceType": instance_type,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
with assert_raises(ClientError) as ex:
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert_equal(ex.exception.response["Error"]["Code"], "ValidationException")
|
||||
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
|
||||
instance_type
|
||||
)
|
||||
|
||||
assert_true(expected_message in ex.exception.response["Error"]["Message"])
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_notebook_instance_lifecycle():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert_true(resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker"))
|
||||
assert_true(resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]))
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
notebook_instance_arn = resp["NotebookInstanceArn"]
|
||||
|
||||
with assert_raises(ClientError) as ex:
|
||||
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_equal(ex.exception.response["Error"]["Code"], "ValidationException")
|
||||
expected_message = "Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
|
||||
notebook_instance_arn
|
||||
)
|
||||
assert_true(expected_message in ex.exception.response["Error"]["Message"])
|
||||
|
||||
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_equal(resp["NotebookInstanceStatus"], "Stopped")
|
||||
|
||||
sagemaker.start_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_equal(resp["NotebookInstanceStatus"], "InService")
|
||||
|
||||
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_equal(resp["NotebookInstanceStatus"], "Stopped")
|
||||
|
||||
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
|
||||
with assert_raises(ClientError) as ex:
|
||||
sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert_equal(ex.exception.response["Error"]["Message"], "RecordNotFound")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_nonexistent_model():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
resp = sagemaker.describe_model(ModelName="Nonexistent")
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith("Could not find model")
|
||||
)
|
Loading…
Reference in New Issue
Block a user