Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)

* Basic upport for Endpoints, EndpointConfigs and TrainingJobs

* Dropped extraneous pass statement.

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
jweite 2020-07-19 10:06:48 -04:00 committed by GitHub
parent a123a22eeb
commit ba99c61477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1007 additions and 6 deletions

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
@ -32,6 +33,288 @@ class BaseObject(BaseModel):
return self.gen_response_object() return self.gen_response_object()
class FakeTrainingJob(BaseObject):
def __init__(
self,
region_name,
training_job_name,
hyper_parameters,
algorithm_specification,
role_arn,
input_data_config,
output_data_config,
resource_config,
vpc_config,
stopping_condition,
tags,
enable_network_isolation,
enable_inter_container_traffic_encryption,
enable_managed_spot_training,
checkpoint_config,
debug_hook_config,
debug_rule_configurations,
tensor_board_output_config,
experiment_config,
):
self.training_job_name = training_job_name
self.hyper_parameters = hyper_parameters
self.algorithm_specification = algorithm_specification
self.role_arn = role_arn
self.input_data_config = input_data_config
self.output_data_config = output_data_config
self.resource_config = resource_config
self.vpc_config = vpc_config
self.stopping_condition = stopping_condition
self.tags = tags
self.enable_network_isolation = enable_network_isolation
self.enable_inter_container_traffic_encryption = (
enable_inter_container_traffic_encryption
)
self.enable_managed_spot_training = enable_managed_spot_training
self.checkpoint_config = checkpoint_config
self.debug_hook_config = debug_hook_config
self.debug_rule_configurations = debug_rule_configurations
self.tensor_board_output_config = tensor_board_output_config
self.experiment_config = experiment_config
self.training_job_arn = FakeTrainingJob.arn_formatter(
training_job_name, region_name
)
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.model_artifacts = {
"S3ModelArtifacts": os.path.join(
self.output_data_config["S3OutputPath"],
self.training_job_name,
"output",
"model.tar.gz",
)
}
self.training_job_status = "Completed"
self.secondary_status = "Completed"
self.algorithm_specification["MetricDefinitions"] = [
{
"Name": "test:dcg",
"Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)",
}
]
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
self.last_modified_time = now_string
self.training_start_time = now_string
self.training_end_time = now_string
self.secondary_status_transitions = [
{
"Status": "Starting",
"StartTime": self.creation_time,
"EndTime": self.creation_time,
"StatusMessage": "Preparing the instances for training",
}
]
self.final_metric_data_list = [
{
"MetricName": "train:progress",
"Value": 100.0,
"Timestamp": self.creation_time,
}
]
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"TrainingJobArn": self.training_job_arn}
@staticmethod
def arn_formatter(endpoint_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":training-job/"
+ endpoint_name
)
class FakeEndpoint(BaseObject):
def __init__(
self,
region_name,
endpoint_name,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
):
self.endpoint_name = endpoint_name
self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
self.endpoint_config_name = endpoint_config_name
self.production_variants = production_variants
self.data_capture_config = data_capture_config
self.tags = tags or []
self.endpoint_status = "InService"
self.failure_reason = None
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"EndpointArn": self.endpoint_arn}
@staticmethod
def arn_formatter(endpoint_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":endpoint/"
+ endpoint_name
)
class FakeEndpointConfig(BaseObject):
def __init__(
self,
region_name,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
kms_key_id,
):
self.validate_production_variants(production_variants)
self.endpoint_config_name = endpoint_config_name
self.endpoint_config_arn = FakeEndpointConfig.arn_formatter(
endpoint_config_name, region_name
)
self.production_variants = production_variants or []
self.data_capture_config = data_capture_config or {}
self.tags = tags or []
self.kms_key_id = kms_key_id
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def validate_production_variants(self, production_variants):
for production_variant in production_variants:
self.validate_instance_type(production_variant["InstanceType"])
def validate_instance_type(self, instance_type):
VALID_INSTANCE_TYPES = [
"ml.r5d.12xlarge",
"ml.r5.12xlarge",
"ml.p2.xlarge",
"ml.m5.4xlarge",
"ml.m4.16xlarge",
"ml.r5d.24xlarge",
"ml.r5.24xlarge",
"ml.p3.16xlarge",
"ml.m5d.xlarge",
"ml.m5.large",
"ml.t2.xlarge",
"ml.p2.16xlarge",
"ml.m5d.12xlarge",
"ml.inf1.2xlarge",
"ml.m5d.24xlarge",
"ml.c4.2xlarge",
"ml.c5.2xlarge",
"ml.c4.4xlarge",
"ml.inf1.6xlarge",
"ml.c5d.2xlarge",
"ml.c5.4xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.12xlarge",
"ml.c5d.4xlarge",
"ml.g4dn.2xlarge",
"ml.c4.8xlarge",
"ml.c4.large",
"ml.c5d.xlarge",
"ml.c5.large",
"ml.g4dn.4xlarge",
"ml.c5.9xlarge",
"ml.g4dn.16xlarge",
"ml.c5d.large",
"ml.c5.xlarge",
"ml.c5d.9xlarge",
"ml.c4.xlarge",
"ml.inf1.xlarge",
"ml.g4dn.8xlarge",
"ml.inf1.24xlarge",
"ml.m5d.2xlarge",
"ml.t2.2xlarge",
"ml.c5d.18xlarge",
"ml.m5d.4xlarge",
"ml.t2.medium",
"ml.c5.18xlarge",
"ml.r5d.2xlarge",
"ml.r5.2xlarge",
"ml.p3.2xlarge",
"ml.m5d.large",
"ml.m5.xlarge",
"ml.m4.10xlarge",
"ml.t2.large",
"ml.r5d.4xlarge",
"ml.r5.4xlarge",
"ml.m5.12xlarge",
"ml.m4.xlarge",
"ml.m5.24xlarge",
"ml.m4.2xlarge",
"ml.p2.8xlarge",
"ml.m5.2xlarge",
"ml.r5d.xlarge",
"ml.r5d.large",
"ml.r5.xlarge",
"ml.r5.large",
"ml.p3.8xlarge",
"ml.m4.4xlarge",
]
if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
instance_type, VALID_INSTANCE_TYPES
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"EndpointConfigArn": self.endpoint_config_arn}
@staticmethod
def arn_formatter(model_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":endpoint-config/"
+ model_name
)
class Model(BaseObject): class Model(BaseObject):
def __init__( def __init__(
self, self,
@ -238,6 +521,9 @@ class SageMakerModelBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name=None):
self._models = {} self._models = {}
self.notebook_instances = {} self.notebook_instances = {}
self.endpoint_configs = {}
self.endpoints = {}
self.training_jobs = {}
self.region_name = region_name self.region_name = region_name
def reset(self): def reset(self):
@ -305,10 +591,10 @@ class SageMakerModelBackend(BaseBackend):
self._validate_unique_notebook_instance_name(notebook_instance_name) self._validate_unique_notebook_instance_name(notebook_instance_name)
notebook_instance = FakeSagemakerNotebookInstance( notebook_instance = FakeSagemakerNotebookInstance(
self.region_name, region_name=self.region_name,
notebook_instance_name, notebook_instance_name=notebook_instance_name,
instance_type, instance_type=instance_type,
role_arn, role_arn=role_arn,
subnet_id=subnet_id, subnet_id=subnet_id,
security_group_ids=security_group_ids, security_group_ids=security_group_ids,
kms_key_id=kms_key_id, kms_key_id=kms_key_id,
@ -392,6 +678,235 @@ class SageMakerModelBackend(BaseBackend):
except RESTError: except RESTError:
return [] return []
def create_endpoint_config(
self,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
kms_key_id,
):
endpoint_config = FakeEndpointConfig(
region_name=self.region_name,
endpoint_config_name=endpoint_config_name,
production_variants=production_variants,
data_capture_config=data_capture_config,
tags=tags,
kms_key_id=kms_key_id,
)
self.validate_production_variants(production_variants)
self.endpoint_configs[endpoint_config_name] = endpoint_config
return endpoint_config
def validate_production_variants(self, production_variants):
for production_variant in production_variants:
if production_variant["ModelName"] not in self._models:
message = "Could not find model '{}'.".format(
Model.arn_for_model_name(
production_variant["ModelName"], self.region_name
)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def describe_endpoint_config(self, endpoint_config_name):
try:
return self.endpoint_configs[endpoint_config_name].response_object
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_endpoint_config(self, endpoint_config_name):
try:
del self.endpoint_configs[endpoint_config_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def create_endpoint(
self, endpoint_name, endpoint_config_name, tags,
):
try:
endpoint_config = self.describe_endpoint_config(endpoint_config_name)
except KeyError:
message = "Could not find endpoint_config '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
endpoint = FakeEndpoint(
region_name=self.region_name,
endpoint_name=endpoint_name,
endpoint_config_name=endpoint_config_name,
production_variants=endpoint_config["ProductionVariants"],
data_capture_config=endpoint_config["DataCaptureConfig"],
tags=tags,
)
self.endpoints[endpoint_name] = endpoint
return endpoint
def describe_endpoint(self, endpoint_name):
try:
return self.endpoints[endpoint_name].response_object
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_endpoint(self, endpoint_name):
try:
del self.endpoints[endpoint_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_endpoint_by_arn(self, arn):
endpoints = [
endpoint
for endpoint in self.endpoints.values()
if endpoint.endpoint_arn == arn
]
if len(endpoints) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
return endpoints[0]
def get_endpoint_tags(self, arn):
try:
endpoint = self.get_endpoint_by_arn(arn)
return endpoint.tags or []
except RESTError:
return []
def create_training_job(
self,
training_job_name,
hyper_parameters,
algorithm_specification,
role_arn,
input_data_config,
output_data_config,
resource_config,
vpc_config,
stopping_condition,
tags,
enable_network_isolation,
enable_inter_container_traffic_encryption,
enable_managed_spot_training,
checkpoint_config,
debug_hook_config,
debug_rule_configurations,
tensor_board_output_config,
experiment_config,
):
training_job = FakeTrainingJob(
region_name=self.region_name,
training_job_name=training_job_name,
hyper_parameters=hyper_parameters,
algorithm_specification=algorithm_specification,
role_arn=role_arn,
input_data_config=input_data_config,
output_data_config=output_data_config,
resource_config=resource_config,
vpc_config=vpc_config,
stopping_condition=stopping_condition,
tags=tags,
enable_network_isolation=enable_network_isolation,
enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption,
enable_managed_spot_training=enable_managed_spot_training,
checkpoint_config=checkpoint_config,
debug_hook_config=debug_hook_config,
debug_rule_configurations=debug_rule_configurations,
tensor_board_output_config=tensor_board_output_config,
experiment_config=experiment_config,
)
self.training_jobs[training_job_name] = training_job
return training_job
def describe_training_job(self, training_job_name):
try:
return self.training_jobs[training_job_name].response_object
except KeyError:
message = "Could not find training job '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_training_job(self, training_job_name):
try:
del self.training_jobs[training_job_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_training_job_by_arn(self, arn):
training_jobs = [
training_job
for training_job in self.training_jobs.values()
if training_job.training_job_arn == arn
]
if len(training_jobs) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
return training_jobs[0]
def get_training_job_tags(self, arn):
try:
training_job = self.get_training_job_by_arn(arn)
return training_job.tags or []
except RESTError:
return []
sagemaker_backends = {} sagemaker_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():

View File

@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def list_tags(self): def list_tags(self):
arn = self._get_param("ResourceArn") arn = self._get_param("ResourceArn")
tags = self.sagemaker_backend.get_notebook_instance_tags(arn) try:
if ":notebook-instance/" in arn:
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
elif ":endpoint/" in arn:
tags = self.sagemaker_backend.get_endpoint_tags(arn)
elif ":training-job/" in arn:
tags = self.sagemaker_backend.get_training_job_tags(arn)
else:
tags = []
except AWSError:
tags = []
response = {"Tags": tags} response = {"Tags": tags}
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id
def create_endpoint_config(self):
try:
endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config_name=self._get_param("EndpointConfigName"),
production_variants=self._get_param("ProductionVariants"),
data_capture_config=self._get_param("DataCaptureConfig"),
tags=self._get_param("Tags"),
kms_key_id=self._get_param("KmsKeyId"),
)
response = {
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_endpoint(self):
try:
endpoint = self.sagemaker_backend.create_endpoint(
endpoint_name=self._get_param("EndpointName"),
endpoint_config_name=self._get_param("EndpointConfigName"),
tags=self._get_param("Tags"),
)
response = {
"EndpointArn": endpoint.endpoint_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint(self):
endpoint_name = self._get_param("EndpointName")
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint(self):
endpoint_name = self._get_param("EndpointName")
self.sagemaker_backend.delete_endpoint(endpoint_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_training_job(self):
try:
training_job = self.sagemaker_backend.create_training_job(
training_job_name=self._get_param("TrainingJobName"),
hyper_parameters=self._get_param("HyperParameters"),
algorithm_specification=self._get_param("AlgorithmSpecification"),
role_arn=self._get_param("RoleArn"),
input_data_config=self._get_param("InputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"),
resource_config=self._get_param("ResourceConfig"),
vpc_config=self._get_param("VpcConfig"),
stopping_condition=self._get_param("StoppingCondition"),
tags=self._get_param("Tags"),
enable_network_isolation=self._get_param(
"EnableNetworkIsolation", False
),
enable_inter_container_traffic_encryption=self._get_param(
"EnableInterContainerTrafficEncryption", False
),
enable_managed_spot_training=self._get_param(
"EnableManagedSpotTraining", False
),
checkpoint_config=self._get_param("CheckpointConfig"),
debug_hook_config=self._get_param("DebugHookConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_training_job(self):
training_job_name = self._get_param("TrainingJobName")
response = self.sagemaker_backend.describe_training_job(training_job_name)
return json.dumps(response)
@amzn_request_id
def delete_training_job(self):
training_job_name = self._get_param("TrainingJobName")
self.sagemaker_backend.delete_training_job(training_job_name)
return 200, {}, json.dumps("{}")

View File

@ -3,7 +3,6 @@ from .responses import SageMakerResponse
url_bases = [ url_bases = [
"https?://api.sagemaker.(.+).amazonaws.com", "https?://api.sagemaker.(.+).amazonaws.com",
"https?://api-fips.sagemaker.(.+).amazonaws.com",
] ]
url_paths = { url_paths = {

View File

@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import datetime
import boto3
from botocore.exceptions import ClientError, ParamValidationError
import sure # noqa
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from nose.tools import assert_true, assert_equal, assert_raises
TEST_REGION_NAME = "us-east-1"
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
GENERIC_TAGS_PARAM = [
{"Key": "newkey1", "Value": "newval1"},
{"Key": "newkey2", "Value": "newval2"},
]
@mock_sagemaker
def test_create_endpoint_config():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
model_name = "MyModel"
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
endpoint_config_name = "MyEndpointConfig"
with assert_raises(ClientError) as e:
sagemaker.create_endpoint_config(
EndpointConfigName=endpoint_config_name,
ProductionVariants=production_variants,
)
assert_true(
e.exception.response["Error"]["Message"].startswith("Could not find model")
)
_create_model(sagemaker, model_name)
resp = sagemaker.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
resp["EndpointConfigName"].should.equal(endpoint_config_name)
resp["ProductionVariants"].should.equal(production_variants)
@mock_sagemaker
def test_delete_endpoint_config():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
model_name = "MyModel"
_create_model(sagemaker, model_name)
endpoint_config_name = "MyEndpointConfig"
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
resp = sagemaker.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
with assert_raises(ClientError) as e:
sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
assert_true(
e.exception.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
)
with assert_raises(ClientError) as e:
sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
assert_true(
e.exception.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
)
pass
@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
model_name = "MyModel"
_create_model(sagemaker, model_name)
instance_type = "InvalidInstanceType"
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": instance_type,
},
]
endpoint_config_name = "MyEndpointConfig"
with assert_raises(ClientError) as e:
sagemaker.create_endpoint_config(
EndpointConfigName=endpoint_config_name,
ProductionVariants=production_variants,
)
assert_equal(e.exception.response["Error"]["Code"], "ValidationException")
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
instance_type
)
assert_true(expected_message in e.exception.response["Error"]["Message"])
@mock_sagemaker
def test_create_endpoint():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
endpoint_name = "MyEndpoint"
with assert_raises(ClientError) as e:
sagemaker.create_endpoint(
EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
)
assert_true(
e.exception.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
)
model_name = "MyModel"
_create_model(sagemaker, model_name)
endpoint_config_name = "MyEndpointConfig"
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
resp = sagemaker.create_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=endpoint_config_name,
Tags=GENERIC_TAGS_PARAM,
)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
)
resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
)
resp["EndpointName"].should.equal(endpoint_name)
resp["EndpointConfigName"].should.equal(endpoint_config_name)
resp["EndpointStatus"].should.equal("InService")
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
assert_equal(resp["Tags"], GENERIC_TAGS_PARAM)
@mock_sagemaker
def test_delete_endpoint():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
model_name = "MyModel"
_create_model(sagemaker, model_name)
endpoint_config_name = "MyEndpointConfig"
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
endpoint_name = "MyEndpoint"
_create_endpoint(sagemaker, endpoint_name, endpoint_config_name)
sagemaker.delete_endpoint(EndpointName=endpoint_name)
with assert_raises(ClientError) as e:
sagemaker.describe_endpoint(EndpointName=endpoint_name)
assert_true(
e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
)
with assert_raises(ClientError) as e:
sagemaker.delete_endpoint(EndpointName=endpoint_name)
assert_true(
e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
)
def _create_model(boto_client, model_name):
resp = boto_client.create_model(
ModelName=model_name,
PrimaryContainer={
"Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
"ModelDataUrl": "s3://MyBucket/model.tar.gz",
},
ExecutionRoleArn=FAKE_ROLE_ARN,
)
assert_equal(resp["ResponseMetadata"]["HTTPStatusCode"], 200)
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
resp = boto_client.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
resp = boto_client.create_endpoint(
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
)

View File

@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import boto3
import datetime
import sure # noqa
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from nose.tools import assert_true, assert_equal, assert_raises, assert_regexp_matches
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
TEST_REGION_NAME = "us-east-1"
@mock_sagemaker
def test_create_training_job():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
training_job_name = "MyTrainingJob"
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
bucket = "my-bucket"
prefix = "sagemaker/DEMO-breast-cancer-prediction/"
params = {
"RoleArn": FAKE_ROLE_ARN,
"TrainingJobName": training_job_name,
"AlgorithmSpecification": {
"TrainingImage": container,
"TrainingInputMode": "File",
},
"ResourceConfig": {
"InstanceCount": 1,
"InstanceType": "ml.c4.2xlarge",
"VolumeSizeInGB": 10,
},
"InputDataConfig": [
{
"ChannelName": "train",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": "s3://{}/{}/train/".format(bucket, prefix),
"S3DataDistributionType": "ShardedByS3Key",
}
},
"CompressionType": "None",
"RecordWrapperType": "None",
},
{
"ChannelName": "validation",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": "s3://{}/{}/validation/".format(bucket, prefix),
"S3DataDistributionType": "FullyReplicated",
}
},
"CompressionType": "None",
"RecordWrapperType": "None",
},
],
"OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)},
"HyperParameters": {
"feature_dim": "30",
"mini_batch_size": "100",
"predictor_type": "regressor",
"epochs": "10",
"num_models": "32",
"loss": "absolute_loss",
},
"StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
}
resp = sagemaker.create_training_job(**params)
resp["TrainingJobArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
)
resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
resp["TrainingJobName"].should.equal(training_job_name)
resp["TrainingJobArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
)
assert_true(
resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
params["OutputDataConfig"]["S3OutputPath"]
)
)
assert_true(training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]))
assert_true(
resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz")
)
assert_equal(resp["TrainingJobStatus"], "Completed")
assert_equal(resp["SecondaryStatus"], "Completed")
assert_equal(resp["HyperParameters"], params["HyperParameters"])
assert_equal(
resp["AlgorithmSpecification"]["TrainingImage"],
params["AlgorithmSpecification"]["TrainingImage"],
)
assert_equal(
resp["AlgorithmSpecification"]["TrainingInputMode"],
params["AlgorithmSpecification"]["TrainingInputMode"],
)
assert_true("MetricDefinitions" in resp["AlgorithmSpecification"])
assert_true("Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
assert_true("Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
assert_equal(resp["InputDataConfig"], params["InputDataConfig"])
assert_equal(resp["OutputDataConfig"], params["OutputDataConfig"])
assert_equal(resp["ResourceConfig"], params["ResourceConfig"])
assert_equal(resp["StoppingCondition"], params["StoppingCondition"])
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
assert_true(isinstance(resp["TrainingStartTime"], datetime.datetime))
assert_true(isinstance(resp["TrainingEndTime"], datetime.datetime))
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
assert_true("SecondaryStatusTransitions" in resp)
assert_true("Status" in resp["SecondaryStatusTransitions"][0])
assert_true("StartTime" in resp["SecondaryStatusTransitions"][0])
assert_true("EndTime" in resp["SecondaryStatusTransitions"][0])
assert_true("StatusMessage" in resp["SecondaryStatusTransitions"][0])
assert_true("FinalMetricDataList" in resp)
assert_true("MetricName" in resp["FinalMetricDataList"][0])
assert_true("Value" in resp["FinalMetricDataList"][0])
assert_true("Timestamp" in resp["FinalMetricDataList"][0])
pass