Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)
* Basic support for Endpoints, EndpointConfigs and TrainingJobs
* Dropped extraneous pass statement.

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
parent a123a22eeb
commit ba99c61477
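The hunks that follow touch the SageMaker model classes, the response handlers, and the URL patterns, and add two new test modules. As a quick orientation, here is a minimal usage sketch of the newly mocked endpoint APIs, assembled from the tests added in this change; the container image, role ARN, bucket and resource names are illustrative only, not part of the diff.

import boto3
from moto import mock_sagemaker


@mock_sagemaker
def exercise_endpoint_mocks():
    client = boto3.client("sagemaker", region_name="us-east-1")

    # An endpoint config must reference an existing model, otherwise the mock
    # raises a ValidationException ("Could not find model ...").
    client.create_model(
        ModelName="MyModel",
        PrimaryContainer={
            "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
            "ModelDataUrl": "s3://MyBucket/model.tar.gz",
        },
        ExecutionRoleArn="arn:aws:iam::123456789012:role/FakeRole",
    )

    # InstanceType is validated against the VALID_INSTANCE_TYPES list added below.
    client.create_endpoint_config(
        EndpointConfigName="MyEndpointConfig",
        ProductionVariants=[
            {
                "VariantName": "MyProductionVariant",
                "ModelName": "MyModel",
                "InitialInstanceCount": 1,
                "InstanceType": "ml.t2.medium",
            }
        ],
    )

    client.create_endpoint(
        EndpointName="MyEndpoint", EndpointConfigName="MyEndpointConfig"
    )
    # The mocked endpoint reports "InService" immediately.
    status = client.describe_endpoint(EndpointName="MyEndpoint")["EndpointStatus"]
    assert status == "InService"

Training jobs follow the same pattern through create_training_job, describe_training_job and delete_training_job (see tests/test_sagemaker/test_sagemaker_training.py below).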
@@ -1,5 +1,6 @@
from __future__ import unicode_literals

import os
from copy import deepcopy
from datetime import datetime
@@ -32,6 +33,288 @@ class BaseObject(BaseModel):
        return self.gen_response_object()


class FakeTrainingJob(BaseObject):
    def __init__(
        self,
        region_name,
        training_job_name,
        hyper_parameters,
        algorithm_specification,
        role_arn,
        input_data_config,
        output_data_config,
        resource_config,
        vpc_config,
        stopping_condition,
        tags,
        enable_network_isolation,
        enable_inter_container_traffic_encryption,
        enable_managed_spot_training,
        checkpoint_config,
        debug_hook_config,
        debug_rule_configurations,
        tensor_board_output_config,
        experiment_config,
    ):
        self.training_job_name = training_job_name
        self.hyper_parameters = hyper_parameters
        self.algorithm_specification = algorithm_specification
        self.role_arn = role_arn
        self.input_data_config = input_data_config
        self.output_data_config = output_data_config
        self.resource_config = resource_config
        self.vpc_config = vpc_config
        self.stopping_condition = stopping_condition
        self.tags = tags
        self.enable_network_isolation = enable_network_isolation
        self.enable_inter_container_traffic_encryption = (
            enable_inter_container_traffic_encryption
        )
        self.enable_managed_spot_training = enable_managed_spot_training
        self.checkpoint_config = checkpoint_config
        self.debug_hook_config = debug_hook_config
        self.debug_rule_configurations = debug_rule_configurations
        self.tensor_board_output_config = tensor_board_output_config
        self.experiment_config = experiment_config
        self.training_job_arn = FakeTrainingJob.arn_formatter(
            training_job_name, region_name
        )
        self.creation_time = self.last_modified_time = datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.model_artifacts = {
            "S3ModelArtifacts": os.path.join(
                self.output_data_config["S3OutputPath"],
                self.training_job_name,
                "output",
                "model.tar.gz",
            )
        }
        self.training_job_status = "Completed"
        self.secondary_status = "Completed"
        self.algorithm_specification["MetricDefinitions"] = [
            {
                "Name": "test:dcg",
                "Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)",
            }
        ]
        now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = now_string
        self.last_modified_time = now_string
        self.training_start_time = now_string
        self.training_end_time = now_string
        self.secondary_status_transitions = [
            {
                "Status": "Starting",
                "StartTime": self.creation_time,
                "EndTime": self.creation_time,
                "StatusMessage": "Preparing the instances for training",
            }
        ]
        self.final_metric_data_list = [
            {
                "MetricName": "train:progress",
                "Value": 100.0,
                "Timestamp": self.creation_time,
            }
        ]

    @property
    def response_object(self):
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        return {"TrainingJobArn": self.training_job_arn}

    @staticmethod
    def arn_formatter(endpoint_name, region_name):
        return (
            "arn:aws:sagemaker:"
            + region_name
            + ":"
            + str(ACCOUNT_ID)
            + ":training-job/"
            + endpoint_name
        )


class FakeEndpoint(BaseObject):
    def __init__(
        self,
        region_name,
        endpoint_name,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
    ):
        self.endpoint_name = endpoint_name
        self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
        self.endpoint_config_name = endpoint_config_name
        self.production_variants = production_variants
        self.data_capture_config = data_capture_config
        self.tags = tags or []
        self.endpoint_status = "InService"
        self.failure_reason = None
        self.creation_time = self.last_modified_time = datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )

    @property
    def response_object(self):
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        return {"EndpointArn": self.endpoint_arn}

    @staticmethod
    def arn_formatter(endpoint_name, region_name):
        return (
            "arn:aws:sagemaker:"
            + region_name
            + ":"
            + str(ACCOUNT_ID)
            + ":endpoint/"
            + endpoint_name
        )


class FakeEndpointConfig(BaseObject):
    def __init__(
        self,
        region_name,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
        kms_key_id,
    ):
        self.validate_production_variants(production_variants)

        self.endpoint_config_name = endpoint_config_name
        self.endpoint_config_arn = FakeEndpointConfig.arn_formatter(
            endpoint_config_name, region_name
        )
        self.production_variants = production_variants or []
        self.data_capture_config = data_capture_config or {}
        self.tags = tags or []
        self.kms_key_id = kms_key_id
        self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def validate_production_variants(self, production_variants):
        for production_variant in production_variants:
            self.validate_instance_type(production_variant["InstanceType"])

    def validate_instance_type(self, instance_type):
        VALID_INSTANCE_TYPES = [
            "ml.r5d.12xlarge",
            "ml.r5.12xlarge",
            "ml.p2.xlarge",
            "ml.m5.4xlarge",
            "ml.m4.16xlarge",
            "ml.r5d.24xlarge",
            "ml.r5.24xlarge",
            "ml.p3.16xlarge",
            "ml.m5d.xlarge",
            "ml.m5.large",
            "ml.t2.xlarge",
            "ml.p2.16xlarge",
            "ml.m5d.12xlarge",
            "ml.inf1.2xlarge",
            "ml.m5d.24xlarge",
            "ml.c4.2xlarge",
            "ml.c5.2xlarge",
            "ml.c4.4xlarge",
            "ml.inf1.6xlarge",
            "ml.c5d.2xlarge",
            "ml.c5.4xlarge",
            "ml.g4dn.xlarge",
            "ml.g4dn.12xlarge",
            "ml.c5d.4xlarge",
            "ml.g4dn.2xlarge",
            "ml.c4.8xlarge",
            "ml.c4.large",
            "ml.c5d.xlarge",
            "ml.c5.large",
            "ml.g4dn.4xlarge",
            "ml.c5.9xlarge",
            "ml.g4dn.16xlarge",
            "ml.c5d.large",
            "ml.c5.xlarge",
            "ml.c5d.9xlarge",
            "ml.c4.xlarge",
            "ml.inf1.xlarge",
            "ml.g4dn.8xlarge",
            "ml.inf1.24xlarge",
            "ml.m5d.2xlarge",
            "ml.t2.2xlarge",
            "ml.c5d.18xlarge",
            "ml.m5d.4xlarge",
            "ml.t2.medium",
            "ml.c5.18xlarge",
            "ml.r5d.2xlarge",
            "ml.r5.2xlarge",
            "ml.p3.2xlarge",
            "ml.m5d.large",
            "ml.m5.xlarge",
            "ml.m4.10xlarge",
            "ml.t2.large",
            "ml.r5d.4xlarge",
            "ml.r5.4xlarge",
            "ml.m5.12xlarge",
            "ml.m4.xlarge",
            "ml.m5.24xlarge",
            "ml.m4.2xlarge",
            "ml.p2.8xlarge",
            "ml.m5.2xlarge",
            "ml.r5d.xlarge",
            "ml.r5d.large",
            "ml.r5.xlarge",
            "ml.r5.large",
            "ml.p3.8xlarge",
            "ml.m4.4xlarge",
        ]
        if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
            message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
                instance_type, VALID_INSTANCE_TYPES
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    @property
    def response_object(self):
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        return {"EndpointConfigArn": self.endpoint_config_arn}

    @staticmethod
    def arn_formatter(model_name, region_name):
        return (
            "arn:aws:sagemaker:"
            + region_name
            + ":"
            + str(ACCOUNT_ID)
            + ":endpoint-config/"
            + model_name
        )


class Model(BaseObject):
    def __init__(
        self,
@@ -238,6 +521,9 @@ class SageMakerModelBackend(BaseBackend):
    def __init__(self, region_name=None):
        self._models = {}
        self.notebook_instances = {}
        self.endpoint_configs = {}
        self.endpoints = {}
        self.training_jobs = {}
        self.region_name = region_name

    def reset(self):
@@ -305,10 +591,10 @@ class SageMakerModelBackend(BaseBackend):
        self._validate_unique_notebook_instance_name(notebook_instance_name)

        notebook_instance = FakeSagemakerNotebookInstance(
            self.region_name,
            notebook_instance_name,
            instance_type,
            role_arn,
            region_name=self.region_name,
            notebook_instance_name=notebook_instance_name,
            instance_type=instance_type,
            role_arn=role_arn,
            subnet_id=subnet_id,
            security_group_ids=security_group_ids,
            kms_key_id=kms_key_id,
@@ -392,6 +678,235 @@ class SageMakerModelBackend(BaseBackend):
        except RESTError:
            return []

    def create_endpoint_config(
        self,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
        kms_key_id,
    ):
        endpoint_config = FakeEndpointConfig(
            region_name=self.region_name,
            endpoint_config_name=endpoint_config_name,
            production_variants=production_variants,
            data_capture_config=data_capture_config,
            tags=tags,
            kms_key_id=kms_key_id,
        )
        self.validate_production_variants(production_variants)

        self.endpoint_configs[endpoint_config_name] = endpoint_config
        return endpoint_config

    def validate_production_variants(self, production_variants):
        for production_variant in production_variants:
            if production_variant["ModelName"] not in self._models:
                message = "Could not find model '{}'.".format(
                    Model.arn_for_model_name(
                        production_variant["ModelName"], self.region_name
                    )
                )
                raise RESTError(
                    error_type="ValidationException",
                    message=message,
                    template="error_json",
                )

    def describe_endpoint_config(self, endpoint_config_name):
        try:
            return self.endpoint_configs[endpoint_config_name].response_object
        except KeyError:
            message = "Could not find endpoint configuration '{}'.".format(
                FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def delete_endpoint_config(self, endpoint_config_name):
        try:
            del self.endpoint_configs[endpoint_config_name]
        except KeyError:
            message = "Could not find endpoint configuration '{}'.".format(
                FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def create_endpoint(
        self, endpoint_name, endpoint_config_name, tags,
    ):
        try:
            endpoint_config = self.describe_endpoint_config(endpoint_config_name)
        except KeyError:
            message = "Could not find endpoint_config '{}'.".format(
                FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

        endpoint = FakeEndpoint(
            region_name=self.region_name,
            endpoint_name=endpoint_name,
            endpoint_config_name=endpoint_config_name,
            production_variants=endpoint_config["ProductionVariants"],
            data_capture_config=endpoint_config["DataCaptureConfig"],
            tags=tags,
        )

        self.endpoints[endpoint_name] = endpoint
        return endpoint

    def describe_endpoint(self, endpoint_name):
        try:
            return self.endpoints[endpoint_name].response_object
        except KeyError:
            message = "Could not find endpoint configuration '{}'.".format(
                FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def delete_endpoint(self, endpoint_name):
        try:
            del self.endpoints[endpoint_name]
        except KeyError:
            message = "Could not find endpoint configuration '{}'.".format(
                FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def get_endpoint_by_arn(self, arn):
        endpoints = [
            endpoint
            for endpoint in self.endpoints.values()
            if endpoint.endpoint_arn == arn
        ]
        if len(endpoints) == 0:
            message = "RecordNotFound"
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )
        return endpoints[0]

    def get_endpoint_tags(self, arn):
        try:
            endpoint = self.get_endpoint_by_arn(arn)
            return endpoint.tags or []
        except RESTError:
            return []

    def create_training_job(
        self,
        training_job_name,
        hyper_parameters,
        algorithm_specification,
        role_arn,
        input_data_config,
        output_data_config,
        resource_config,
        vpc_config,
        stopping_condition,
        tags,
        enable_network_isolation,
        enable_inter_container_traffic_encryption,
        enable_managed_spot_training,
        checkpoint_config,
        debug_hook_config,
        debug_rule_configurations,
        tensor_board_output_config,
        experiment_config,
    ):
        training_job = FakeTrainingJob(
            region_name=self.region_name,
            training_job_name=training_job_name,
            hyper_parameters=hyper_parameters,
            algorithm_specification=algorithm_specification,
            role_arn=role_arn,
            input_data_config=input_data_config,
            output_data_config=output_data_config,
            resource_config=resource_config,
            vpc_config=vpc_config,
            stopping_condition=stopping_condition,
            tags=tags,
            enable_network_isolation=enable_network_isolation,
            enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption,
            enable_managed_spot_training=enable_managed_spot_training,
            checkpoint_config=checkpoint_config,
            debug_hook_config=debug_hook_config,
            debug_rule_configurations=debug_rule_configurations,
            tensor_board_output_config=tensor_board_output_config,
            experiment_config=experiment_config,
        )
        self.training_jobs[training_job_name] = training_job
        return training_job

    def describe_training_job(self, training_job_name):
        try:
            return self.training_jobs[training_job_name].response_object
        except KeyError:
            message = "Could not find training job '{}'.".format(
                FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def delete_training_job(self, training_job_name):
        try:
            del self.training_jobs[training_job_name]
        except KeyError:
            message = "Could not find training job '{}'.".format(
                FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
            )
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )

    def get_training_job_by_arn(self, arn):
        training_jobs = [
            training_job
            for training_job in self.training_jobs.values()
            if training_job.training_job_arn == arn
        ]
        if len(training_jobs) == 0:
            message = "RecordNotFound"
            raise RESTError(
                error_type="ValidationException",
                message=message,
                template="error_json",
            )
        return training_jobs[0]

    def get_training_job_tags(self, arn):
        try:
            training_job = self.get_training_job_by_arn(arn)
            return training_job.tags or []
        except RESTError:
            return []


sagemaker_backends = {}
for region, ec2_backend in ec2_backends.items():
@@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
    @amzn_request_id
    def list_tags(self):
        arn = self._get_param("ResourceArn")
        tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
        try:
            if ":notebook-instance/" in arn:
                tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
            elif ":endpoint/" in arn:
                tags = self.sagemaker_backend.get_endpoint_tags(arn)
            elif ":training-job/" in arn:
                tags = self.sagemaker_backend.get_training_job_tags(arn)
            else:
                tags = []
        except AWSError:
            tags = []
        response = {"Tags": tags}
        return 200, {}, json.dumps(response)

    @amzn_request_id
    def create_endpoint_config(self):
        try:
            endpoint_config = self.sagemaker_backend.create_endpoint_config(
                endpoint_config_name=self._get_param("EndpointConfigName"),
                production_variants=self._get_param("ProductionVariants"),
                data_capture_config=self._get_param("DataCaptureConfig"),
                tags=self._get_param("Tags"),
                kms_key_id=self._get_param("KmsKeyId"),
            )
            response = {
                "EndpointConfigArn": endpoint_config.endpoint_config_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_endpoint_config(self):
        endpoint_config_name = self._get_param("EndpointConfigName")
        response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_endpoint_config(self):
        endpoint_config_name = self._get_param("EndpointConfigName")
        self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
        return 200, {}, json.dumps("{}")

    @amzn_request_id
    def create_endpoint(self):
        try:
            endpoint = self.sagemaker_backend.create_endpoint(
                endpoint_name=self._get_param("EndpointName"),
                endpoint_config_name=self._get_param("EndpointConfigName"),
                tags=self._get_param("Tags"),
            )
            response = {
                "EndpointArn": endpoint.endpoint_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_endpoint(self):
        endpoint_name = self._get_param("EndpointName")
        response = self.sagemaker_backend.describe_endpoint(endpoint_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_endpoint(self):
        endpoint_name = self._get_param("EndpointName")
        self.sagemaker_backend.delete_endpoint(endpoint_name)
        return 200, {}, json.dumps("{}")

    @amzn_request_id
    def create_training_job(self):
        try:
            training_job = self.sagemaker_backend.create_training_job(
                training_job_name=self._get_param("TrainingJobName"),
                hyper_parameters=self._get_param("HyperParameters"),
                algorithm_specification=self._get_param("AlgorithmSpecification"),
                role_arn=self._get_param("RoleArn"),
                input_data_config=self._get_param("InputDataConfig"),
                output_data_config=self._get_param("OutputDataConfig"),
                resource_config=self._get_param("ResourceConfig"),
                vpc_config=self._get_param("VpcConfig"),
                stopping_condition=self._get_param("StoppingCondition"),
                tags=self._get_param("Tags"),
                enable_network_isolation=self._get_param(
                    "EnableNetworkIsolation", False
                ),
                enable_inter_container_traffic_encryption=self._get_param(
                    "EnableInterContainerTrafficEncryption", False
                ),
                enable_managed_spot_training=self._get_param(
                    "EnableManagedSpotTraining", False
                ),
                checkpoint_config=self._get_param("CheckpointConfig"),
                debug_hook_config=self._get_param("DebugHookConfig"),
                debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
                tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
                experiment_config=self._get_param("ExperimentConfig"),
            )
            response = {
                "TrainingJobArn": training_job.training_job_arn,
            }
            return 200, {}, json.dumps(response)
        except AWSError as err:
            return err.response()

    @amzn_request_id
    def describe_training_job(self):
        training_job_name = self._get_param("TrainingJobName")
        response = self.sagemaker_backend.describe_training_job(training_job_name)
        return json.dumps(response)

    @amzn_request_id
    def delete_training_job(self):
        training_job_name = self._get_param("TrainingJobName")
        self.sagemaker_backend.delete_training_job(training_job_name)
        return 200, {}, json.dumps("{}")
@@ -3,7 +3,6 @@ from .responses import SageMakerResponse

url_bases = [
    "https?://api.sagemaker.(.+).amazonaws.com",
    "https?://api-fips.sagemaker.(.+).amazonaws.com",
]

url_paths = {
tests/test_sagemaker/test_sagemaker_endpoint.py (new file, 246 lines)
@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import datetime
import boto3
from botocore.exceptions import ClientError, ParamValidationError
import sure  # noqa

from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from nose.tools import assert_true, assert_equal, assert_raises

TEST_REGION_NAME = "us-east-1"
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]


@mock_sagemaker
def test_create_endpoint_config():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]

    endpoint_config_name = "MyEndpointConfig"
    with assert_raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert_true(
        e.exception.response["Error"]["Message"].startswith("Could not find model")
    )

    _create_model(sagemaker, model_name)
    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["ProductionVariants"].should.equal(production_variants)


@mock_sagemaker
def test_delete_endpoint_config():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]

    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    with assert_raises(ClientError) as e:
        sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith(
            "Could not find endpoint configuration"
        )
    )

    with assert_raises(ClientError) as e:
        sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith(
            "Could not find endpoint configuration"
        )
    )
    pass


@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    instance_type = "InvalidInstanceType"
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        },
    ]

    endpoint_config_name = "MyEndpointConfig"
    with assert_raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert_equal(e.exception.response["Error"]["Code"], "ValidationException")
    expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        instance_type
    )
    assert_true(expected_message in e.exception.response["Error"]["Message"])


@mock_sagemaker
def test_create_endpoint():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    endpoint_name = "MyEndpoint"
    with assert_raises(ClientError) as e:
        sagemaker.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
        )
    assert_true(
        e.exception.response["Error"]["Message"].startswith(
            "Could not find endpoint configuration"
        )
    )

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    resp = sagemaker.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )

    resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )
    resp["EndpointName"].should.equal(endpoint_name)
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["EndpointStatus"].should.equal("InService")
    assert_true(isinstance(resp["CreationTime"], datetime.datetime))
    assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
    resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")

    resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
    assert_equal(resp["Tags"], GENERIC_TAGS_PARAM)


@mock_sagemaker
def test_delete_endpoint():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    endpoint_name = "MyEndpoint"
    _create_endpoint(sagemaker, endpoint_name, endpoint_config_name)

    sagemaker.delete_endpoint(EndpointName=endpoint_name)
    with assert_raises(ClientError) as e:
        sagemaker.describe_endpoint(EndpointName=endpoint_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
    )

    with assert_raises(ClientError) as e:
        sagemaker.delete_endpoint(EndpointName=endpoint_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
    )


def _create_model(boto_client, model_name):
    resp = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer={
            "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
            "ModelDataUrl": "s3://MyBucket/model.tar.gz",
        },
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    assert_equal(resp["ResponseMetadata"]["HTTPStatusCode"], 200)


def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]
    resp = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )


def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    resp = boto_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )
tests/test_sagemaker/test_sagemaker_training.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import boto3
import datetime
import sure  # noqa

from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from nose.tools import assert_true, assert_equal, assert_raises, assert_regexp_matches

FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
TEST_REGION_NAME = "us-east-1"


@mock_sagemaker
def test_create_training_job():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    training_job_name = "MyTrainingJob"
    container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
    bucket = "my-bucket"
    prefix = "sagemaker/DEMO-breast-cancer-prediction/"

    params = {
        "RoleArn": FAKE_ROLE_ARN,
        "TrainingJobName": training_job_name,
        "AlgorithmSpecification": {
            "TrainingImage": container,
            "TrainingInputMode": "File",
        },
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c4.2xlarge",
            "VolumeSizeInGB": 10,
        },
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://{}/{}/train/".format(bucket, prefix),
                        "S3DataDistributionType": "ShardedByS3Key",
                    }
                },
                "CompressionType": "None",
                "RecordWrapperType": "None",
            },
            {
                "ChannelName": "validation",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix),
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "CompressionType": "None",
                "RecordWrapperType": "None",
            },
        ],
        "OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)},
        "HyperParameters": {
            "feature_dim": "30",
            "mini_batch_size": "100",
            "predictor_type": "regressor",
            "epochs": "10",
            "num_models": "32",
            "loss": "absolute_loss",
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
    }

    resp = sagemaker.create_training_job(**params)
    resp["TrainingJobArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
    )

    resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
    resp["TrainingJobName"].should.equal(training_job_name)
    resp["TrainingJobArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
    )
    assert_true(
        resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
            params["OutputDataConfig"]["S3OutputPath"]
        )
    )
    assert_true(training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]))
    assert_true(
        resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz")
    )
    assert_equal(resp["TrainingJobStatus"], "Completed")
    assert_equal(resp["SecondaryStatus"], "Completed")
    assert_equal(resp["HyperParameters"], params["HyperParameters"])
    assert_equal(
        resp["AlgorithmSpecification"]["TrainingImage"],
        params["AlgorithmSpecification"]["TrainingImage"],
    )
    assert_equal(
        resp["AlgorithmSpecification"]["TrainingInputMode"],
        params["AlgorithmSpecification"]["TrainingInputMode"],
    )
    assert_true("MetricDefinitions" in resp["AlgorithmSpecification"])
    assert_true("Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
    assert_true("Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
    assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
    assert_equal(resp["InputDataConfig"], params["InputDataConfig"])
    assert_equal(resp["OutputDataConfig"], params["OutputDataConfig"])
    assert_equal(resp["ResourceConfig"], params["ResourceConfig"])
    assert_equal(resp["StoppingCondition"], params["StoppingCondition"])
    assert_true(isinstance(resp["CreationTime"], datetime.datetime))
    assert_true(isinstance(resp["TrainingStartTime"], datetime.datetime))
    assert_true(isinstance(resp["TrainingEndTime"], datetime.datetime))
    assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
    assert_true("SecondaryStatusTransitions" in resp)
    assert_true("Status" in resp["SecondaryStatusTransitions"][0])
    assert_true("StartTime" in resp["SecondaryStatusTransitions"][0])
    assert_true("EndTime" in resp["SecondaryStatusTransitions"][0])
    assert_true("StatusMessage" in resp["SecondaryStatusTransitions"][0])
    assert_true("FinalMetricDataList" in resp)
    assert_true("MetricName" in resp["FinalMetricDataList"][0])
    assert_true("Value" in resp["FinalMetricDataList"][0])
    assert_true("Timestamp" in resp["FinalMetricDataList"][0])

    pass
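One path the new tests do not exercise is ListTags against a training-job ARN, which the updated list_tags handler routes to get_training_job_tags. Here is a minimal sketch in the same style as the tests above; the tag key/value and the literal account id in the role ARN are illustrative assumptions, not taken from the diff.

import boto3
from moto import mock_sagemaker


@mock_sagemaker
def list_training_job_tags_sketch():
    client = boto3.client("sagemaker", region_name="us-east-1")
    client.create_training_job(
        TrainingJobName="MyTrainingJob",
        RoleArn="arn:aws:iam::123456789012:role/FakeRole",
        AlgorithmSpecification={
            "TrainingImage": "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1",
            "TrainingInputMode": "File",
        },
        OutputDataConfig={"S3OutputPath": "s3://my-bucket/output/"},
        ResourceConfig={
            "InstanceCount": 1,
            "InstanceType": "ml.c4.2xlarge",
            "VolumeSizeInGB": 10,
        },
        StoppingCondition={"MaxRuntimeInSeconds": 60 * 60},
        Tags=[{"Key": "team", "Value": "ml"}],  # illustrative tag
    )
    arn = client.describe_training_job(TrainingJobName="MyTrainingJob")["TrainingJobArn"]
    # list_tags dispatches on ":training-job/" in the ARN and returns the tags
    # stored by the backend's create_training_job.
    assert client.list_tags(ResourceArn=arn)["Tags"] == [{"Key": "team", "Value": "ml"}]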