2102 lines
75 KiB
Python
2102 lines
75 KiB
Python
import json
|
|
import os
|
|
from datetime import datetime
|
|
from moto.core import BaseBackend, BaseModel, CloudFormationModel
|
|
from moto.core.utils import BackendDict
|
|
from moto.sagemaker import validators
|
|
from moto.utilities.paginator import paginate
|
|
from .exceptions import (
|
|
MissingModel,
|
|
ValidationError,
|
|
AWSValidationException,
|
|
ResourceNotFound,
|
|
)
|
|
|
|
|
|
# Pagination settings per list operation (presumably consumed by the
# ``paginate`` decorator imported above). All operations share the same
# request token / limit parameter names and reject unknown tokens; they
# differ only in the default page size and in the attribute that uniquely
# identifies an item when resuming from a token.
PAGINATION_MODEL = {
    operation: {
        "input_token": "NextToken",
        "limit_key": "MaxResults",
        "limit_default": page_size,
        "unique_attribute": unique_attribute,
        "fail_on_invalid_token": True,
    }
    for operation, unique_attribute, page_size in (
        ("list_experiments", "experiment_arn", 100),
        ("list_trials", "trial_arn", 100),
        ("list_trial_components", "trial_component_arn", 100),
        ("list_tags", "Key", 50),
    )
}
|
|
|
|
|
|
def arn_formatter(_type, _id, account_id, region_name):
    """Build a SageMaker ARN.

    Returns a string of the form
    ``arn:aws:sagemaker:<region_name>:<account_id>:<_type>/<_id>``.
    """
    template = "arn:aws:sagemaker:{}:{}:{}/{}"
    return template.format(region_name, account_id, _type, _id)
|
|
|
|
|
|
class BaseObject(BaseModel):
    """Shared helpers for SageMaker model objects.

    Instance attributes are stored in snake_case. ``gen_response_object``
    re-keys them in PascalCase for API responses.
    """

    def camelCase(self, key):
        """Convert a snake_case *key* to PascalCase.

        Each underscore-separated word is title-cased and the words are
        concatenated, e.g. ``"model_arn"`` -> ``"ModelArn"``.
        """
        return "".join(map(str.title, key.split("_")))

    def update(self, details_json):
        """Set one attribute per top-level key of the JSON object string."""
        for attr_name, attr_value in json.loads(details_json).items():
            setattr(self, attr_name, attr_value)

    def gen_response_object(self):
        """Return this object's instance attributes keyed in PascalCase.

        Names containing ``_`` go through ``camelCase``. Other names only
        have their first character upper-cased.
        """
        rendered = {}
        for attr_name, attr_value in vars(self).items():
            if "_" not in attr_name:
                response_key = attr_name[0].upper() + attr_name[1:]
            else:
                response_key = self.camelCase(attr_name)
            rendered[response_key] = attr_value
        return rendered

    @property
    def response_object(self):
        """Unfiltered API representation; subclasses may override."""
        return self.gen_response_object()
|
|
|
|
|
|
class FakeProcessingJob(BaseObject):
    """In-memory record of a SageMaker processing job.

    The job is marked ``Completed`` at construction time. Its creation,
    last-modified and end timestamps are all set to that same instant.
    """

    def __init__(
        self,
        app_specification,
        experiment_config,
        network_config,
        processing_inputs,
        processing_job_name,
        processing_output_config,
        account_id,
        region_name,
        role_arn,
        tags,
        stopping_condition,
    ):
        job_arn = arn_formatter(
            "processing-job", processing_job_name, account_id, region_name
        )
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Assignment order fixes the key order of ``gen_response_object``,
        # which walks the instance ``__dict__``; keep it stable.
        self.processing_job_name = processing_job_name
        self.processing_job_arn = job_arn
        self.creation_time = self.last_modified_time = self.processing_end_time = (
            timestamp
        )
        self.tags = tags or []
        self.role_arn = role_arn
        self.app_specification = app_specification
        self.experiment_config = experiment_config
        self.network_config = network_config
        self.processing_inputs = processing_inputs
        self.processing_job_status = "Completed"
        self.processing_output_config = processing_output_config
        self.stopping_condition = stopping_condition

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        everything = self.gen_response_object()
        return {
            name: value
            for name, value in everything.items()
            if not (value is None or value == [None])
        }

    @property
    def response_create(self):
        """Return the CreateProcessingJob response payload."""
        return dict(ProcessingJobArn=self.processing_job_arn)
|
|
|
|
|
|
class FakeTrainingJob(BaseObject):
    """In-memory record of a SageMaker training job.

    The job is reported as ``Completed`` immediately. Model artifacts, metric
    definitions, secondary status transitions and final metrics are canned
    values derived from the creation instant and the output path.
    """

    def __init__(
        self,
        account_id,
        region_name,
        training_job_name,
        hyper_parameters,
        algorithm_specification,
        role_arn,
        input_data_config,
        output_data_config,
        resource_config,
        vpc_config,
        stopping_condition,
        tags,
        enable_network_isolation,
        enable_inter_container_traffic_encryption,
        enable_managed_spot_training,
        checkpoint_config,
        debug_hook_config,
        debug_rule_configurations,
        tensor_board_output_config,
        experiment_config,
    ):
        # S3 URIs always use "/" separators; ``os.path`` follows the host OS
        # and would produce backslashes on Windows.
        import posixpath

        self.training_job_name = training_job_name
        self.hyper_parameters = hyper_parameters
        # Shallow copy: MetricDefinitions is injected below and must not leak
        # into the dict owned by the caller.
        self.algorithm_specification = dict(algorithm_specification)
        self.role_arn = role_arn
        self.input_data_config = input_data_config
        self.output_data_config = output_data_config
        self.resource_config = resource_config
        self.vpc_config = vpc_config
        self.stopping_condition = stopping_condition
        self.tags = tags or []
        self.enable_network_isolation = enable_network_isolation
        self.enable_inter_container_traffic_encryption = (
            enable_inter_container_traffic_encryption
        )
        self.enable_managed_spot_training = enable_managed_spot_training
        self.checkpoint_config = checkpoint_config
        self.debug_hook_config = debug_hook_config
        self.debug_rule_configurations = debug_rule_configurations
        self.tensor_board_output_config = tensor_board_output_config
        self.experiment_config = experiment_config
        self.training_job_arn = arn_formatter(
            "training-job", training_job_name, account_id, region_name
        )

        # One timestamp is used for every time field of this completed job.
        now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = self.last_modified_time = now_string

        self.model_artifacts = {
            "S3ModelArtifacts": posixpath.join(
                self.output_data_config["S3OutputPath"],
                self.training_job_name,
                "output",
                "model.tar.gz",
            )
        }
        self.training_job_status = "Completed"
        self.secondary_status = "Completed"
        self.algorithm_specification["MetricDefinitions"] = [
            {
                "Name": "test:dcg",
                "Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)",
            }
        ]
        self.training_start_time = now_string
        self.training_end_time = now_string
        self.secondary_status_transitions = [
            {
                "Status": "Starting",
                "StartTime": self.creation_time,
                "EndTime": self.creation_time,
                "StatusMessage": "Preparing the instances for training",
            }
        ]
        self.final_metric_data_list = [
            {
                "MetricName": "train:progress",
                "Value": 100.0,
                "Timestamp": self.creation_time,
            }
        ]

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        """Return the CreateTrainingJob response payload."""
        return {"TrainingJobArn": self.training_job_arn}
|
|
|
|
|
|
class FakeEndpoint(BaseObject, CloudFormationModel):
    """In-memory SageMaker endpoint, also exposed as the
    ``AWS::SageMaker::Endpoint`` CloudFormation resource.

    Endpoints are reported as ``InService`` immediately after creation.
    """

    def __init__(
        self,
        account_id,
        region_name,
        endpoint_name,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
    ):
        self.endpoint_name = endpoint_name
        self.endpoint_arn = FakeEndpoint.arn_formatter(
            endpoint_name, account_id, region_name
        )
        self.endpoint_config_name = endpoint_config_name
        self.production_variants = self._process_production_variants(
            production_variants
        )
        self.data_capture_config = data_capture_config
        self.tags = tags or []
        self.endpoint_status = "InService"
        self.failure_reason = None
        self.creation_time = self.last_modified_time = datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )

    def _process_production_variants(self, production_variants):
        """Convert endpoint-config production variants to the endpoint shape.

        Each ``Initial*`` / ``ServerlessConfig`` input is reported as both a
        ``Current*`` and a ``Desired*`` field. Numeric inputs are copied
        whenever they are present, including zero: the API permits an
        ``InitialVariantWeight`` of 0, and a truthiness test would drop it.
        """
        endpoint_variants = []
        for production_variant in production_variants:
            # VariantName is the only required param
            temp_variant = {"VariantName": production_variant["VariantName"]}

            instance_count = production_variant.get("InitialInstanceCount")
            if instance_count is not None:
                temp_variant["CurrentInstanceCount"] = instance_count
                temp_variant["DesiredInstanceCount"] = instance_count

            weight = production_variant.get("InitialVariantWeight")
            if weight is not None:
                temp_variant["CurrentWeight"] = weight
                temp_variant["DesiredWeight"] = weight

            serverless_config = production_variant.get("ServerlessConfig")
            if serverless_config:
                temp_variant["CurrentServerlessConfig"] = serverless_config
                temp_variant["DesiredServerlessConfig"] = serverless_config

            endpoint_variants.append(temp_variant)

        return endpoint_variants

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        """Return the CreateEndpoint response payload."""
        return {"EndpointArn": self.endpoint_arn}

    @staticmethod
    def arn_formatter(endpoint_name, account_id, region_name):
        """Build the ARN for an endpoint with the given name."""
        return arn_formatter("endpoint", endpoint_name, account_id, region_name)

    @property
    def physical_resource_id(self):
        """CloudFormation physical id: the endpoint ARN (not its name)."""
        return self.endpoint_arn

    @classmethod
    def has_cfn_attr(cls, attr):
        """Only ``EndpointName`` is supported for ``Fn::GetAtt``."""
        return attr in ["EndpointName"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt``; raises for any unsupported attribute."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "EndpointName":
            return self.endpoint_name
        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html
        return "AWS::SageMaker::Endpoint"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Create the endpoint through the regional backend."""
        sagemaker_backend = sagemaker_backends[account_id][region_name]

        # Get required properties from provided CloudFormation template
        properties = cloudformation_json["Properties"]
        endpoint_config_name = properties["EndpointConfigName"]

        endpoint = sagemaker_backend.create_endpoint(
            endpoint_name=resource_name,
            endpoint_config_name=endpoint_config_name,
            tags=properties.get("Tags", []),
        )
        return endpoint

    @classmethod
    def update_from_cloudformation_json(
        cls,
        original_resource,
        new_resource_name,
        cloudformation_json,
        account_id,
        region_name,
    ):
        """Replace the endpoint: delete it, then re-create it under its
        existing name with the new properties."""
        # Changes to the Endpoint will not change resource name
        cls.delete_from_cloudformation_json(
            original_resource.endpoint_arn, cloudformation_json, account_id, region_name
        )
        new_resource = cls.create_from_cloudformation_json(
            original_resource.endpoint_name,
            cloudformation_json,
            account_id,
            region_name,
        )
        return new_resource

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name
    ):
        """Delete the endpoint identified by its ARN (the physical id)."""
        # Get actual name because resource_name actually provides the ARN
        # since the Physical Resource ID is the ARN despite SageMaker
        # using the name for most of its operations.
        endpoint_name = resource_name.split("/")[-1]

        sagemaker_backends[account_id][region_name].delete_endpoint(endpoint_name)
|
|
|
|
|
|
class FakeEndpointConfig(BaseObject, CloudFormationModel):
    """In-memory SageMaker endpoint configuration, also exposed as the
    ``AWS::SageMaker::EndpointConfig`` CloudFormation resource.

    Production variants are validated at construction time.
    """

    def __init__(
        self,
        account_id,
        region_name,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
        kms_key_id,
    ):
        # Default ``None`` to an empty list *before* validating, so a missing
        # value is not iterated (which would raise TypeError).
        production_variants = production_variants or []
        self.validate_production_variants(production_variants)

        self.endpoint_config_name = endpoint_config_name
        self.endpoint_config_arn = FakeEndpointConfig.arn_formatter(
            endpoint_config_name, account_id, region_name
        )
        self.production_variants = production_variants
        self.data_capture_config = data_capture_config or {}
        self.tags = tags or []
        self.kms_key_id = kms_key_id
        self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def validate_production_variants(self, production_variants):
        """Check that each variant declares an instance type or a serverless
        config, and validate whichever one it declares.

        Raises:
            ValidationError: if a variant has neither key, or its value is
                not accepted.
        """
        for production_variant in production_variants:
            if "InstanceType" in production_variant:
                self.validate_instance_type(production_variant["InstanceType"])
            elif "ServerlessConfig" in production_variant:
                self.validate_serverless_config(production_variant["ServerlessConfig"])
            else:
                message = "Invalid Keys for ProductionVariant: received {} but expected it to contain one of {}".format(
                    production_variant.keys(), ["InstanceType", "ServerlessConfig"]
                )
                raise ValidationError(message=message)

    def validate_serverless_config(self, serverless_config):
        """Validate ``MemorySizeInMB`` against the allowed sizes.

        A missing ``MemorySizeInMB`` is reported as a ValidationError (value
        ``None``) rather than escaping as a KeyError.
        """
        VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144]
        memory_size = serverless_config.get("MemorySizeInMB")
        if not validators.is_one_of(memory_size, VALID_SERVERLESS_MEMORY_SIZE):
            message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
                memory_size, VALID_SERVERLESS_MEMORY_SIZE
            )
            raise ValidationError(message=message)

    def validate_instance_type(self, instance_type):
        """Validate *instance_type* against the instance types this mock
        accepts for endpoints; raises ValidationError otherwise."""
        VALID_INSTANCE_TYPES = [
            "ml.r5d.12xlarge",
            "ml.r5.12xlarge",
            "ml.p2.xlarge",
            "ml.m5.4xlarge",
            "ml.m4.16xlarge",
            "ml.r5d.24xlarge",
            "ml.r5.24xlarge",
            "ml.p3.16xlarge",
            "ml.m5d.xlarge",
            "ml.m5.large",
            "ml.t2.xlarge",
            "ml.p2.16xlarge",
            "ml.m5d.12xlarge",
            "ml.inf1.2xlarge",
            "ml.m5d.24xlarge",
            "ml.c4.2xlarge",
            "ml.c5.2xlarge",
            "ml.c4.4xlarge",
            "ml.inf1.6xlarge",
            "ml.c5d.2xlarge",
            "ml.c5.4xlarge",
            "ml.g4dn.xlarge",
            "ml.g4dn.12xlarge",
            "ml.c5d.4xlarge",
            "ml.g4dn.2xlarge",
            "ml.c4.8xlarge",
            "ml.c4.large",
            "ml.c5d.xlarge",
            "ml.c5.large",
            "ml.g4dn.4xlarge",
            "ml.c5.9xlarge",
            "ml.g4dn.16xlarge",
            "ml.c5d.large",
            "ml.c5.xlarge",
            "ml.c5d.9xlarge",
            "ml.c4.xlarge",
            "ml.inf1.xlarge",
            "ml.g4dn.8xlarge",
            "ml.inf1.24xlarge",
            "ml.m5d.2xlarge",
            "ml.t2.2xlarge",
            "ml.c5d.18xlarge",
            "ml.m5d.4xlarge",
            "ml.t2.medium",
            "ml.c5.18xlarge",
            "ml.r5d.2xlarge",
            "ml.r5.2xlarge",
            "ml.p3.2xlarge",
            "ml.m5d.large",
            "ml.m5.xlarge",
            "ml.m4.10xlarge",
            "ml.t2.large",
            "ml.r5d.4xlarge",
            "ml.r5.4xlarge",
            "ml.m5.12xlarge",
            "ml.m4.xlarge",
            "ml.m5.24xlarge",
            "ml.m4.2xlarge",
            "ml.p2.8xlarge",
            "ml.m5.2xlarge",
            "ml.r5d.xlarge",
            "ml.r5d.large",
            "ml.r5.xlarge",
            "ml.r5.large",
            "ml.p3.8xlarge",
            "ml.m4.4xlarge",
        ]
        if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
            message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
                instance_type, VALID_INSTANCE_TYPES
            )
            raise ValidationError(message=message)

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        """Return the CreateEndpointConfig response payload."""
        return {"EndpointConfigArn": self.endpoint_config_arn}

    @staticmethod
    def arn_formatter(endpoint_config_name, account_id, region_name):
        """Build the ARN for an endpoint config with the given name."""
        return arn_formatter(
            "endpoint-config", endpoint_config_name, account_id, region_name
        )

    @property
    def physical_resource_id(self):
        """CloudFormation physical id: the endpoint-config ARN."""
        return self.endpoint_config_arn

    @classmethod
    def has_cfn_attr(cls, attr):
        """Only ``EndpointConfigName`` is supported for ``Fn::GetAtt``."""
        return attr in ["EndpointConfigName"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt``; raises for any unsupported attribute."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "EndpointConfigName":
            return self.endpoint_config_name
        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html
        return "AWS::SageMaker::EndpointConfig"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Create the endpoint config through the regional backend."""
        sagemaker_backend = sagemaker_backends[account_id][region_name]

        # Get required properties from provided CloudFormation template
        properties = cloudformation_json["Properties"]
        production_variants = properties["ProductionVariants"]

        endpoint_config = sagemaker_backend.create_endpoint_config(
            endpoint_config_name=resource_name,
            production_variants=production_variants,
            data_capture_config=properties.get("DataCaptureConfig", {}),
            kms_key_id=properties.get("KmsKeyId"),
            tags=properties.get("Tags", []),
        )
        return endpoint_config

    @classmethod
    def update_from_cloudformation_json(
        cls,
        original_resource,
        new_resource_name,
        cloudformation_json,
        account_id,
        region_name,
    ):
        """Replace the config: delete the old one, create under the new name."""
        # Most changes to the endpoint config will change resource name for EndpointConfigs
        cls.delete_from_cloudformation_json(
            original_resource.endpoint_config_arn,
            cloudformation_json,
            account_id,
            region_name,
        )
        new_resource = cls.create_from_cloudformation_json(
            new_resource_name, cloudformation_json, account_id, region_name
        )
        return new_resource

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name
    ):
        """Delete the endpoint config identified by its ARN (physical id)."""
        # Get actual name because resource_name actually provides the ARN
        # since the Physical Resource ID is the ARN despite SageMaker
        # using the name for most of its operations.
        endpoint_config_name = resource_name.split("/")[-1]

        sagemaker_backends[account_id][region_name].delete_endpoint_config(
            endpoint_config_name
        )
|
|
|
|
|
|
class Model(BaseObject, CloudFormationModel):
    """In-memory SageMaker model, also exposed as the
    ``AWS::SageMaker::Model`` CloudFormation resource."""

    def __init__(
        self,
        account_id,
        region_name,
        model_name,
        execution_role_arn,
        primary_container,
        vpc_config,
        containers=None,
        tags=None,
    ):
        self.model_name = model_name
        self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.containers = containers if containers else []
        self.tags = tags if tags else []
        self.enable_network_isolation = False
        self.vpc_config = vpc_config
        self.primary_container = primary_container
        # Falsy role ARNs are replaced by a fixed placeholder value.
        self.execution_role_arn = (
            execution_role_arn if execution_role_arn else "arn:test"
        )
        self.model_arn = arn_formatter(
            "model", self.model_name, account_id, region_name
        )

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        everything = self.gen_response_object()
        return {
            name: value
            for name, value in everything.items()
            if not (value is None or value == [None])
        }

    @property
    def response_create(self):
        """Return the CreateModel response payload."""
        return dict(ModelArn=self.model_arn)

    @property
    def physical_resource_id(self):
        """CloudFormation physical id: the model ARN."""
        return self.model_arn

    @classmethod
    def has_cfn_attr(cls, attr):
        """``ModelName`` is the only attribute available to ``Fn::GetAtt``."""
        return attr in ["ModelName"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt``; anything but ``ModelName`` raises."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html#aws-resource-sagemaker-model-return-values
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name != "ModelName":
            raise UnformattedGetAttTemplateException()
        return self.model_name

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html
        return "AWS::SageMaker::Model"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Create the model via the regional backend's ``create_model``,
        which takes PascalCase keyword arguments."""
        backend = sagemaker_backends[account_id][region_name]

        props = cloudformation_json["Properties"]
        create_kwargs = dict(
            ModelName=resource_name,
            ExecutionRoleArn=props["ExecutionRoleArn"],
            PrimaryContainer=props["PrimaryContainer"],
            VpcConfig=props.get("VpcConfig", {}),
            Containers=props.get("Containers", []),
            Tags=props.get("Tags", []),
        )
        return backend.create_model(**create_kwargs)

    @classmethod
    def update_from_cloudformation_json(
        cls,
        original_resource,
        new_resource_name,
        cloudformation_json,
        account_id,
        region_name,
    ):
        """Replace the model: delete the old one, then create under the new
        logical name."""
        cls.delete_from_cloudformation_json(
            original_resource.model_arn, cloudformation_json, account_id, region_name
        )
        return cls.create_from_cloudformation_json(
            new_resource_name, cloudformation_json, account_id, region_name
        )

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name
    ):
        """Delete the model named by the final path segment of the ARN that
        CloudFormation passes as *resource_name*."""
        sagemaker_backends[account_id][region_name].delete_model(
            resource_name.rsplit("/", 1)[-1]
        )
|
|
|
|
|
|
class VpcConfig(BaseObject):
    """Security-group and subnet settings for a SageMaker resource."""

    def __init__(self, security_group_ids, subnets):
        self.security_group_ids, self.subnets = security_group_ids, subnets

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        everything = self.gen_response_object()
        return {
            name: value
            for name, value in everything.items()
            if not (value is None or value == [None])
        }
|
|
|
|
|
|
class Container(BaseObject):
    """A model container definition with placeholder defaults.

    Note the keyword names differ from the stored attribute names:
    ``data_url`` -> ``model_data_url``, ``package_name`` ->
    ``model_package_name``.
    """

    def __init__(self, **kwargs):
        option = kwargs.get
        self.container_hostname = option("container_hostname", "localhost")
        self.model_data_url = option("data_url", "")
        self.model_package_name = option("package_name", "pkg")
        self.image = option("image", "")
        self.environment = option("environment", {})

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        everything = self.gen_response_object()
        return {
            name: value
            for name, value in everything.items()
            if not (value is None or value == [None])
        }
|
|
|
|
|
|
class FakeSagemakerNotebookInstance(CloudFormationModel):
    """In-memory SageMaker notebook instance, also exposed as the
    ``AWS::SageMaker::NotebookInstance`` CloudFormation resource.

    The instance is started (status ``InService``) on construction and can
    only be deleted once its status is ``Stopped`` or ``Failed``.
    """

    def __init__(
        self,
        account_id,
        region_name,
        notebook_instance_name,
        instance_type,
        role_arn,
        subnet_id,
        security_group_ids,
        kms_key_id,
        tags,
        lifecycle_config_name,
        direct_internet_access,
        volume_size_in_gb,
        accelerator_types,
        default_code_repository,
        additional_code_repositories,
        root_access,
    ):
        self.validate_volume_size_in_gb(volume_size_in_gb)
        self.validate_instance_type(instance_type)

        self.region_name = region_name
        self.notebook_instance_name = notebook_instance_name
        self.instance_type = instance_type
        self.role_arn = role_arn
        self.subnet_id = subnet_id
        self.security_group_ids = security_group_ids
        self.kms_key_id = kms_key_id
        self.tags = tags or []
        self.lifecycle_config_name = lifecycle_config_name
        self.direct_internet_access = direct_internet_access
        self.volume_size_in_gb = volume_size_in_gb
        self.accelerator_types = accelerator_types
        self.default_code_repository = default_code_repository
        self.additional_code_repositories = additional_code_repositories
        self.root_access = root_access
        self.status = None
        # Kept as a datetime object (unlike the string timestamps used by the
        # other models in this module).
        self.creation_time = self.last_modified_time = datetime.now()
        self.arn = arn_formatter(
            "notebook-instance", notebook_instance_name, account_id, region_name
        )
        self.start()

    def validate_volume_size_in_gb(self, volume_size_in_gb):
        """Require an optional integer volume size of at least 5 GB.

        Raises:
            ValidationError: naming the rejected value.
        """
        if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True):
            message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf".format(
                volume_size_in_gb
            )
            raise ValidationError(message=message)

    def validate_instance_type(self, instance_type):
        """Validate *instance_type* against the instance types this mock
        accepts for notebooks; raises ValidationError otherwise."""
        VALID_INSTANCE_TYPES = [
            "ml.p2.xlarge",
            "ml.m5.4xlarge",
            "ml.m4.16xlarge",
            "ml.t3.xlarge",
            "ml.p3.16xlarge",
            "ml.t2.xlarge",
            "ml.p2.16xlarge",
            "ml.c4.2xlarge",
            "ml.c5.2xlarge",
            "ml.c4.4xlarge",
            "ml.c5d.2xlarge",
            "ml.c5.4xlarge",
            "ml.c5d.4xlarge",
            "ml.c4.8xlarge",
            "ml.c5d.xlarge",
            "ml.c5.9xlarge",
            "ml.c5.xlarge",
            "ml.c5d.9xlarge",
            "ml.c4.xlarge",
            "ml.t2.2xlarge",
            "ml.c5d.18xlarge",
            "ml.t3.2xlarge",
            "ml.t3.medium",
            "ml.t2.medium",
            "ml.c5.18xlarge",
            "ml.p3.2xlarge",
            "ml.m5.xlarge",
            "ml.m4.10xlarge",
            "ml.t2.large",
            "ml.m5.12xlarge",
            "ml.m4.xlarge",
            "ml.t3.large",
            "ml.m5.24xlarge",
            "ml.m4.2xlarge",
            "ml.p2.8xlarge",
            "ml.m5.2xlarge",
            "ml.p3.8xlarge",
            "ml.m4.4xlarge",
        ]
        if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
            message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
                instance_type, VALID_INSTANCE_TYPES
            )
            raise ValidationError(message=message)

    @property
    def url(self):
        """Notebook URL host derived from the instance name and region."""
        return "{}.notebook.{}.sagemaker.aws".format(
            self.notebook_instance_name, self.region_name
        )

    def start(self):
        self.status = "InService"

    @property
    def is_deletable(self):
        """True only when the instance is ``Stopped`` or ``Failed``."""
        return self.status in ["Stopped", "Failed"]

    def stop(self):
        self.status = "Stopped"

    @property
    def physical_resource_id(self):
        """CloudFormation physical id: the notebook instance ARN."""
        return self.arn

    @classmethod
    def has_cfn_attr(cls, attr):
        """Only ``NotebookInstanceName`` is supported for ``Fn::GetAtt``."""
        return attr in ["NotebookInstanceName"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt``; raises for any unsupported attribute."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html#aws-resource-sagemaker-notebookinstance-return-values
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "NotebookInstanceName":
            return self.notebook_instance_name
        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html
        return "AWS::SageMaker::NotebookInstance"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Create the notebook instance from the template's ``InstanceType``
        and ``RoleArn``; other template properties are not forwarded."""
        # Get required properties from provided CloudFormation template
        properties = cloudformation_json["Properties"]
        instance_type = properties["InstanceType"]
        role_arn = properties["RoleArn"]

        notebook = sagemaker_backends[account_id][region_name].create_notebook_instance(
            notebook_instance_name=resource_name,
            instance_type=instance_type,
            role_arn=role_arn,
        )
        return notebook

    @classmethod
    def update_from_cloudformation_json(
        cls,
        original_resource,
        new_resource_name,
        cloudformation_json,
        account_id,
        region_name,
    ):
        """Mimic an update by deleting and re-creating under the same name."""
        # Operations keep same resource name so delete old and create new to mimic update
        cls.delete_from_cloudformation_json(
            original_resource.arn, cloudformation_json, account_id, region_name
        )
        new_resource = cls.create_from_cloudformation_json(
            original_resource.notebook_instance_name,
            cloudformation_json,
            account_id,
            region_name,
        )
        return new_resource

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name
    ):
        """Stop, then delete, the instance identified by its ARN."""
        # Get actual name because resource_name actually provides the ARN
        # since the Physical Resource ID is the ARN despite SageMaker
        # using the name for most of its operations.
        notebook_instance_name = resource_name.split("/")[-1]

        backend = sagemaker_backends[account_id][region_name]
        # Stopping first is required because only Stopped/Failed instances
        # are deletable (see ``is_deletable``).
        backend.stop_notebook_instance(notebook_instance_name)
        backend.delete_notebook_instance(notebook_instance_name)
|
|
|
|
|
|
class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel):
    """In-memory notebook-instance lifecycle configuration, also exposed as
    the ``AWS::SageMaker::NotebookInstanceLifecycleConfig`` CloudFormation
    resource. Stores the ``on_create`` / ``on_start`` script payloads as
    given."""

    def __init__(
        self,
        account_id,
        region_name,
        notebook_instance_lifecycle_config_name,
        on_create,
        on_start,
    ):
        self.region_name = region_name
        self.notebook_instance_lifecycle_config_name = (
            notebook_instance_lifecycle_config_name
        )
        self.on_create = on_create
        self.on_start = on_start
        self.creation_time = self.last_modified_time = datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.notebook_instance_lifecycle_config_arn = (
            FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
                self.notebook_instance_lifecycle_config_name, account_id, region_name
            )
        )

    @staticmethod
    def arn_formatter(name, account_id, region_name):
        """Build the ARN for a lifecycle configuration with the given name."""
        return arn_formatter(
            "notebook-instance-lifecycle-configuration", name, account_id, region_name
        )

    @property
    def response_object(self):
        """Return the API representation, omitting ``None`` and ``[None]``."""
        response_object = self.gen_response_object()
        return {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }

    @property
    def response_create(self):
        """Return the CreateNotebookInstanceLifecycleConfig response payload.

        This class has no ``training_job_arn``; it must report its own ARN.
        """
        return {
            "NotebookInstanceLifecycleConfigArn": self.notebook_instance_lifecycle_config_arn
        }

    @property
    def physical_resource_id(self):
        """CloudFormation physical id: the lifecycle config ARN."""
        return self.notebook_instance_lifecycle_config_arn

    @classmethod
    def has_cfn_attr(cls, attr):
        """Only ``NotebookInstanceLifecycleConfigName`` is supported for
        ``Fn::GetAtt``."""
        return attr in ["NotebookInstanceLifecycleConfigName"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt``; raises for any unsupported attribute."""
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html#aws-resource-sagemaker-notebookinstancelifecycleconfig-return-values
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "NotebookInstanceLifecycleConfigName":
            return self.notebook_instance_lifecycle_config_name
        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html
        return "AWS::SageMaker::NotebookInstanceLifecycleConfig"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Create the lifecycle config from the template's optional
        ``OnCreate`` / ``OnStart`` properties."""
        properties = cloudformation_json["Properties"]

        config = sagemaker_backends[account_id][
            region_name
        ].create_notebook_instance_lifecycle_config(
            notebook_instance_lifecycle_config_name=resource_name,
            on_create=properties.get("OnCreate"),
            on_start=properties.get("OnStart"),
        )
        return config

    @classmethod
    def update_from_cloudformation_json(
        cls,
        original_resource,
        new_resource_name,
        cloudformation_json,
        account_id,
        region_name,
    ):
        """Mimic an update by deleting and re-creating under the same name."""
        # Operations keep same resource name so delete old and create new to mimic update
        cls.delete_from_cloudformation_json(
            original_resource.notebook_instance_lifecycle_config_arn,
            cloudformation_json,
            account_id,
            region_name,
        )
        new_resource = cls.create_from_cloudformation_json(
            original_resource.notebook_instance_lifecycle_config_name,
            cloudformation_json,
            account_id,
            region_name,
        )
        return new_resource

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name
    ):
        """Delete the lifecycle config identified by its ARN (physical id)."""
        # Get actual name because resource_name actually provides the ARN
        # since the Physical Resource ID is the ARN despite SageMaker
        # using the name for most of its operations.
        config_name = resource_name.split("/")[-1]

        backend = sagemaker_backends[account_id][region_name]
        backend.delete_notebook_instance_lifecycle_config(config_name)
|
|
|
|
|
|
class SageMakerModelBackend(BaseBackend):
    """In-memory SageMaker backend for a single account and region.

    Every resource store below is a dict keyed by the resource's name.
    """

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self._models = {}
        self.notebook_instances = {}
        self.endpoint_configs = {}
        self.endpoints = {}
        self.experiments = {}
        self.processing_jobs = {}
        self.trials = {}
        self.trial_components = {}
        self.training_jobs = {}
        self.notebook_instance_lifecycle_configurations = {}

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint services.

        Returns the generic ``api.sagemaker`` entries followed by the
        SageMaker notebook and studio interface endpoint services.
        """
        api_service = BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "api.sagemaker", special_service_name="sagemaker.api"
        )

        def interface_service(prefix, service_id):
            # Notebook and Studio services differ only in DNS prefix and service id.
            return {
                "AcceptanceRequired": False,
                "AvailabilityZones": zones,
                "BaseEndpointDnsNames": [
                    f"{service_id}.{service_region}.vpce.amazonaws.com",
                    f"{prefix}.{service_region}.vpce.sagemaker.aws",
                ],
                "ManagesVpcEndpoints": False,
                "Owner": "amazon",
                "PrivateDnsName": f"*.{prefix}.{service_region}.sagemaker.aws",
                "PrivateDnsNameVerificationState": "verified",
                "PrivateDnsNames": [
                    {"PrivateDnsName": f"*.{prefix}.{service_region}.sagemaker.aws"}
                ],
                "ServiceId": service_id,
                "ServiceName": f"aws.sagemaker.{service_region}.{prefix}",
                "ServiceType": [{"ServiceType": "Interface"}],
                "Tags": [],
                "VpcEndpointPolicySupported": True,
            }

        notebook_service_id = f"vpce-svc-{BaseBackend.vpce_random_number()}"
        studio_service_id = f"vpce-svc-{BaseBackend.vpce_random_number()}"
        return api_service + [
            interface_service("notebook", notebook_service_id),
            interface_service("studio", studio_service_id),
        ]

    def create_model(self, **kwargs):
        """Create and store a Model from CreateModel request parameters."""
        model_obj = Model(
            account_id=self.account_id,
            region_name=self.region_name,
            model_name=kwargs.get("ModelName"),
            execution_role_arn=kwargs.get("ExecutionRoleArn"),
            primary_container=kwargs.get("PrimaryContainer", {}),
            vpc_config=kwargs.get("VpcConfig", {}),
            containers=kwargs.get("Containers", []),
            tags=kwargs.get("Tags", []),
        )
        self._models[kwargs.get("ModelName")] = model_obj
        return model_obj

    def describe_model(self, model_name=None):
        """Return the stored Model.

        Raises:
            ValidationError: if no model with that name exists.
        """
        model = self._models.get(model_name)
        if model:
            return model
        arn = arn_formatter("model", model_name, self.account_id, self.region_name)
        raise ValidationError(message=f"Could not find model '{arn}'.")

    def list_models(self):
        """Return a view of all stored models."""
        return self._models.values()

    def delete_model(self, model_name=None):
        """Delete the model stored under ``model_name``.

        Raises:
            MissingModel: if no such model exists.
        """
        # Models are keyed by their name (see create_model), so a direct
        # lookup replaces the former linear scan over all models.
        if model_name not in self._models:
            raise MissingModel(model=model_name)
        del self._models[model_name]

    def create_experiment(self, experiment_name):
        """Create an experiment and return its CreateExperiment response."""
        experiment = FakeExperiment(
            account_id=self.account_id,
            region_name=self.region_name,
            experiment_name=experiment_name,
            tags=[],
        )
        self.experiments[experiment_name] = experiment
        return experiment.response_create

    def describe_experiment(self, experiment_name):
        """Return the DescribeExperiment response for ``experiment_name``.

        Raises:
            ResourceNotFound: if the experiment does not exist (previously a
                raw KeyError escaped).
        """
        experiment_data = self.experiments.get(experiment_name)
        if experiment_data is None:
            arn = arn_formatter(
                "experiment", experiment_name, self.account_id, self.region_name
            )
            raise ResourceNotFound(message=f"Experiment '{arn}' does not exist.")
        return {
            "ExperimentName": experiment_data.experiment_name,
            "ExperimentArn": experiment_data.experiment_arn,
            "CreationTime": experiment_data.creation_time,
            "LastModifiedTime": experiment_data.last_modified_time,
        }

    def _get_resource_from_arn(self, arn):
        """Look up a taggable resource from its ARN.

        The resource type and name are taken from the ARN's last ``:``
        segment, split at the first ``/``.

        Raises:
            ValidationError: if the type is unknown or no such resource exists.
        """
        resources = {
            "model": self._models,
            "notebook-instance": self.notebook_instances,
            "endpoint": self.endpoints,
            "endpoint-config": self.endpoint_configs,
            "training-job": self.training_jobs,
            "experiment": self.experiments,
            "experiment-trial": self.trials,
            "experiment-trial-component": self.trial_components,
            "processing-job": self.processing_jobs,
        }
        target_resource, _, target_name = arn.split(":")[-1].partition("/")
        # dict.get never raises KeyError, so missing entries must be checked
        # explicitly; otherwise callers crash later on ``None.tags``.
        store = resources.get(target_resource)
        resource = store.get(target_name) if store is not None else None
        if resource is None:
            message = f"Could not find {target_resource} with name {target_name}"
            raise ValidationError(message=message)
        return resource

    def add_tags(self, arn, tags):
        """Append ``tags`` to the resource identified by ``arn``."""
        resource = self._get_resource_from_arn(arn)
        resource.tags.extend(tags)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_tags(self, arn):
        """Return the tags of the resource identified by ``arn`` (paginated)."""
        resource = self._get_resource_from_arn(arn)
        return resource.tags

    def delete_tags(self, arn, tag_keys):
        """Remove every tag whose ``Key`` is in ``tag_keys``."""
        resource = self._get_resource_from_arn(arn)
        resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys]

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_experiments(self):
        """Return all experiments (paginated)."""
        return list(self.experiments.values())

    def search(self, resource=None, search_expression=None):
        """Search experiments, trials or trial components.

        Supports ``Equals`` filters on ``Tags.<key>``, ``ExperimentName`` and
        ``Parents.TrialName``; ``TrialName`` is rejected as AWS does.
        Other resource types in the accepted enum return no results.

        Raises:
            AWSValidationException: for an unknown ``resource`` or a
                ``TrialName`` filter.
            ValidationError: for ``ExperimentName`` on items lacking it.
        """
        valid_resources = [
            "Pipeline",
            "ModelPackageGroup",
            "TrainingJob",
            "ExperimentTrialComponent",
            "FeatureGroup",
            "Endpoint",
            "PipelineExecution",
            "Project",
            "ExperimentTrial",
            "Image",
            "ImageVersion",
            "ModelPackage",
            "Experiment",
        ]

        if resource not in valid_resources:
            raise AWSValidationException(
                f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}"
            )

        def evaluate_search_expression(item):
            filters = None
            if search_expression is not None:
                filters = search_expression.get("Filters")

            if filters is not None:
                for flt in filters:
                    if flt["Operator"] == "Equals":
                        if flt["Name"].startswith("Tags."):
                            key = flt["Name"][5:]
                            value = flt["Value"]
                            if not any(
                                tag["Key"] == key and tag["Value"] == value
                                for tag in item.tags
                            ):
                                return False
                        if flt["Name"] == "ExperimentName":
                            experiment_name = flt["Value"]
                            if hasattr(item, "experiment_name"):
                                if getattr(item, "experiment_name") != experiment_name:
                                    return False
                            else:
                                raise ValidationError(
                                    message="Unknown property name: ExperimentName"
                                )

                        if flt["Name"] == "TrialName":
                            raise AWSValidationException(
                                f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {flt['Name']}"
                            )

                        if flt["Name"] == "Parents.TrialName":
                            trial_name = flt["Value"]
                            if getattr(item, "trial_name") != trial_name:
                                return False
            return True

        # Search pagination is not implemented, so NextToken is always None.
        result = {"Results": [], "NextToken": None}

        if resource == "Experiment":
            for experiment_data in list(self.experiments.values()):
                if evaluate_search_expression(experiment_data):
                    result["Results"].append(
                        {
                            "Experiment": {
                                "ExperimentName": experiment_data.experiment_name,
                                "ExperimentArn": experiment_data.experiment_arn,
                                "CreationTime": experiment_data.creation_time,
                                "LastModifiedTime": experiment_data.last_modified_time,
                            }
                        }
                    )

        if resource == "ExperimentTrial":
            for trial_data in list(self.trials.values()):
                if evaluate_search_expression(trial_data):
                    result["Results"].append(
                        {
                            "Trial": {
                                "TrialName": trial_data.trial_name,
                                "TrialArn": trial_data.trial_arn,
                                "CreationTime": trial_data.creation_time,
                                "LastModifiedTime": trial_data.last_modified_time,
                            }
                        }
                    )

        if resource == "ExperimentTrialComponent":
            for trial_component_data in list(self.trial_components.values()):
                if evaluate_search_expression(trial_component_data):
                    result["Results"].append(
                        {
                            "TrialComponent": {
                                "TrialComponentName": trial_component_data.trial_component_name,
                                "TrialComponentArn": trial_component_data.trial_component_arn,
                                "CreationTime": trial_component_data.creation_time,
                                "LastModifiedTime": trial_component_data.last_modified_time,
                            }
                        }
                    )
        return result

    def delete_experiment(self, experiment_name):
        """Delete an experiment.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.experiments[experiment_name]
        except KeyError:
            # FakeTrial has no arn_formatter; use the module-level helper.
            message = "Could not find experiment configuration '{}'.".format(
                arn_formatter(
                    "experiment", experiment_name, self.account_id, self.region_name
                )
            )
            raise ValidationError(message=message)

    def create_trial(self, trial_name, experiment_name):
        """Create a trial under ``experiment_name`` and return its CreateTrial response."""
        trial = FakeTrial(
            account_id=self.account_id,
            region_name=self.region_name,
            trial_name=trial_name,
            experiment_name=experiment_name,
            tags=[],
            trial_components=[],
        )
        self.trials[trial_name] = trial
        return trial.response_create

    def describe_trial(self, trial_name):
        """Return the response object of a trial.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.trials[trial_name].response_object
        except KeyError:
            message = "Could not find trial '{}'.".format(
                arn_formatter(
                    "experiment-trial", trial_name, self.account_id, self.region_name
                )
            )
            raise ValidationError(message=message)

    def delete_trial(self, trial_name):
        """Delete a trial.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.trials[trial_name]
        except KeyError:
            message = "Could not find trial configuration '{}'.".format(
                arn_formatter(
                    "experiment-trial", trial_name, self.account_id, self.region_name
                )
            )
            raise ValidationError(message=message)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_trials(self, experiment_name=None, trial_component_name=None):
        """Return trials, optionally filtered by experiment and/or member component."""

        def matches(trial_data):
            if experiment_name is not None and trial_data.experiment_name != experiment_name:
                return False
            if (
                trial_component_name is not None
                and trial_component_name not in trial_data.trial_components
            ):
                return False
            return True

        return [trial_data for trial_data in self.trials.values() if matches(trial_data)]

    def create_trial_component(self, trial_component_name, trial_name):
        """Create a trial component and return its CreateTrialComponent response."""
        trial_component = FakeTrialComponent(
            account_id=self.account_id,
            region_name=self.region_name,
            trial_component_name=trial_component_name,
            trial_name=trial_name,
            tags=[],
        )
        self.trial_components[trial_component_name] = trial_component
        return trial_component.response_create

    def delete_trial_component(self, trial_component_name):
        """Delete a trial component.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.trial_components[trial_component_name]
        except KeyError:
            message = "Could not find trial-component configuration '{}'.".format(
                FakeTrialComponent.arn_formatter(
                    trial_component_name, self.account_id, self.region_name
                )
            )
            raise ValidationError(message=message)

    def describe_trial_component(self, trial_component_name):
        """Return the response object of a trial component.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.trial_components[trial_component_name].response_object
        except KeyError:
            message = "Could not find trial component '{}'.".format(
                FakeTrialComponent.arn_formatter(
                    trial_component_name, self.account_id, self.region_name
                )
            )
            raise ValidationError(message=message)

    def _update_trial_component_details(self, trial_component_name, details_json):
        """Overwrite trial-component attributes from a JSON object string."""
        self.trial_components[trial_component_name].update(details_json)

    @paginate(pagination_model=PAGINATION_MODEL)
    def list_trial_components(self, trial_name=None):
        """Return trial components, optionally only those belonging to ``trial_name``."""
        return [
            trial_component_data
            for trial_component_data in self.trial_components.values()
            if trial_name is None or trial_component_data.trial_name == trial_name
        ]

    def associate_trial_component(self, params):
        """Link a trial component to a trial.

        Both resources are validated before anything is mutated.

        Raises:
            ResourceNotFound: if the trial or the trial component is missing.
        """
        trial_name = params["TrialName"]
        trial_component_name = params["TrialComponentName"]

        if trial_name not in self.trials:
            trial_arn = arn_formatter(
                "experiment-trial", trial_name, self.account_id, self.region_name
            )
            raise ResourceNotFound(message=f"Trial '{trial_arn}' does not exist.")
        if trial_component_name not in self.trial_components:
            # Previously the trial was mutated first and a KeyError escaped here.
            component_arn = FakeTrialComponent.arn_formatter(
                trial_component_name, self.account_id, self.region_name
            )
            raise ResourceNotFound(
                message=f"Trial Component '{component_arn}' does not exist."
            )

        trial = self.trials[trial_name]
        trial_component = self.trial_components[trial_component_name]
        trial.trial_components.append(trial_component_name)
        trial_component.trial_name = trial_name

        return {
            "TrialComponentArn": trial_component.trial_component_arn,
            "TrialArn": trial.trial_arn,
        }

    def disassociate_trial_component(self, params):
        """Unlink a trial component from a trial (missing resources are ignored)."""
        trial_component_name = params["TrialComponentName"]
        trial_name = params["TrialName"]

        if trial_component_name in self.trial_components:
            self.trial_components[trial_component_name].trial_name = None

        if trial_name in self.trials:
            trial = self.trials[trial_name]
            trial.trial_components = [
                name for name in trial.trial_components if name != trial_component_name
            ]

        return {
            "TrialComponentArn": arn_formatter(
                "experiment-trial-component",
                trial_component_name,
                self.account_id,
                self.region_name,
            ),
            "TrialArn": arn_formatter(
                "experiment-trial", trial_name, self.account_id, self.region_name
            ),
        }

    def create_notebook_instance(
        self,
        notebook_instance_name,
        instance_type,
        role_arn,
        subnet_id=None,
        security_group_ids=None,
        kms_key_id=None,
        tags=None,
        lifecycle_config_name=None,
        direct_internet_access="Enabled",
        volume_size_in_gb=5,
        accelerator_types=None,
        default_code_repository=None,
        additional_code_repositories=None,
        root_access=None,
    ):
        """Create a notebook instance.

        Explicit ``None`` for ``direct_internet_access`` / ``volume_size_in_gb``
        falls back to "Enabled" / 5.

        Raises:
            ValidationError: if the name is already taken.
        """
        self._validate_unique_notebook_instance_name(notebook_instance_name)

        notebook_instance = FakeSagemakerNotebookInstance(
            account_id=self.account_id,
            region_name=self.region_name,
            notebook_instance_name=notebook_instance_name,
            instance_type=instance_type,
            role_arn=role_arn,
            subnet_id=subnet_id,
            security_group_ids=security_group_ids,
            kms_key_id=kms_key_id,
            tags=tags,
            lifecycle_config_name=lifecycle_config_name,
            direct_internet_access=direct_internet_access
            if direct_internet_access is not None
            else "Enabled",
            volume_size_in_gb=volume_size_in_gb if volume_size_in_gb is not None else 5,
            accelerator_types=accelerator_types,
            default_code_repository=default_code_repository,
            additional_code_repositories=additional_code_repositories,
            root_access=root_access,
        )
        self.notebook_instances[notebook_instance_name] = notebook_instance
        return notebook_instance

    def _validate_unique_notebook_instance_name(self, notebook_instance_name):
        """Raise ValidationError if a notebook instance with this name exists."""
        if notebook_instance_name in self.notebook_instances:
            duplicate_arn = self.notebook_instances[notebook_instance_name].arn
            message = "Cannot create a duplicate Notebook Instance ({})".format(
                duplicate_arn
            )
            raise ValidationError(message=message)

    def get_notebook_instance(self, notebook_instance_name):
        """Return a notebook instance or raise ValidationError("RecordNotFound")."""
        try:
            return self.notebook_instances[notebook_instance_name]
        except KeyError:
            raise ValidationError(message="RecordNotFound")

    def start_notebook_instance(self, notebook_instance_name):
        """Start an existing notebook instance."""
        notebook_instance = self.get_notebook_instance(notebook_instance_name)
        notebook_instance.start()

    def stop_notebook_instance(self, notebook_instance_name):
        """Stop an existing notebook instance."""
        notebook_instance = self.get_notebook_instance(notebook_instance_name)
        notebook_instance.stop()

    def delete_notebook_instance(self, notebook_instance_name):
        """Delete a notebook instance.

        Raises:
            ValidationError: if it is missing or not in a deletable state.
        """
        notebook_instance = self.get_notebook_instance(notebook_instance_name)
        if not notebook_instance.is_deletable:
            message = "Status ({}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
                notebook_instance.status, notebook_instance.arn
            )
            raise ValidationError(message=message)
        del self.notebook_instances[notebook_instance_name]

    def _lifecycle_config_arn(self, notebook_instance_lifecycle_config_name):
        """ARN of a notebook lifecycle config in this account/region (for messages)."""
        return FakeSageMakerNotebookInstanceLifecycleConfig.arn_formatter(
            notebook_instance_lifecycle_config_name,
            self.account_id,
            self.region_name,
        )

    def create_notebook_instance_lifecycle_config(
        self, notebook_instance_lifecycle_config_name, on_create, on_start
    ):
        """Create a notebook lifecycle config.

        Raises:
            ValidationError: if one with the same name already exists.
        """
        if (
            notebook_instance_lifecycle_config_name
            in self.notebook_instance_lifecycle_configurations
        ):
            message = "Unable to create Notebook Instance Lifecycle Config {}. (Details: Notebook Instance Lifecycle Config already exists.)".format(
                self._lifecycle_config_arn(notebook_instance_lifecycle_config_name)
            )
            raise ValidationError(message=message)
        lifecycle_config = FakeSageMakerNotebookInstanceLifecycleConfig(
            account_id=self.account_id,
            region_name=self.region_name,
            notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name,
            on_create=on_create,
            on_start=on_start,
        )
        self.notebook_instance_lifecycle_configurations[
            notebook_instance_lifecycle_config_name
        ] = lifecycle_config
        return lifecycle_config

    def describe_notebook_instance_lifecycle_config(
        self, notebook_instance_lifecycle_config_name
    ):
        """Return a lifecycle config's response object.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.notebook_instance_lifecycle_configurations[
                notebook_instance_lifecycle_config_name
            ].response_object
        except KeyError:
            message = "Unable to describe Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
                self._lifecycle_config_arn(notebook_instance_lifecycle_config_name)
            )
            raise ValidationError(message=message)

    def delete_notebook_instance_lifecycle_config(
        self, notebook_instance_lifecycle_config_name
    ):
        """Delete a lifecycle config.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.notebook_instance_lifecycle_configurations[
                notebook_instance_lifecycle_config_name
            ]
        except KeyError:
            message = "Unable to delete Notebook Instance Lifecycle Config '{}'. (Details: Notebook Instance Lifecycle Config does not exist.)".format(
                self._lifecycle_config_arn(notebook_instance_lifecycle_config_name)
            )
            raise ValidationError(message=message)

    def create_endpoint_config(
        self,
        endpoint_config_name,
        production_variants,
        data_capture_config,
        tags,
        kms_key_id,
    ):
        """Create an endpoint config.

        Raises:
            ValidationError: if a production variant references an unknown model.
        """
        endpoint_config = FakeEndpointConfig(
            account_id=self.account_id,
            region_name=self.region_name,
            endpoint_config_name=endpoint_config_name,
            production_variants=production_variants,
            data_capture_config=data_capture_config,
            tags=tags,
            kms_key_id=kms_key_id,
        )
        self.validate_production_variants(production_variants)

        self.endpoint_configs[endpoint_config_name] = endpoint_config
        return endpoint_config

    def validate_production_variants(self, production_variants):
        """Raise ValidationError for the first variant whose ModelName is unknown."""
        for production_variant in production_variants:
            if production_variant["ModelName"] not in self._models:
                arn = arn_formatter(
                    "model",
                    production_variant["ModelName"],
                    self.account_id,
                    self.region_name,
                )
                raise ValidationError(message=f"Could not find model '{arn}'.")

    def describe_endpoint_config(self, endpoint_config_name):
        """Return an endpoint config's response object.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.endpoint_configs[endpoint_config_name].response_object
        except KeyError:
            arn = FakeEndpointConfig.arn_formatter(
                endpoint_config_name, self.account_id, self.region_name
            )
            raise ValidationError(
                message=f"Could not find endpoint configuration '{arn}'."
            )

    def delete_endpoint_config(self, endpoint_config_name):
        """Delete an endpoint config.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.endpoint_configs[endpoint_config_name]
        except KeyError:
            arn = FakeEndpointConfig.arn_formatter(
                endpoint_config_name, self.account_id, self.region_name
            )
            raise ValidationError(
                message=f"Could not find endpoint configuration '{arn}'."
            )

    def create_endpoint(self, endpoint_name, endpoint_config_name, tags):
        """Create an endpoint from an existing endpoint config.

        Raises:
            ValidationError: if the endpoint config does not exist (raised by
                describe_endpoint_config; the former ``except KeyError`` here
                could never trigger).
        """
        endpoint_config = self.describe_endpoint_config(endpoint_config_name)

        endpoint = FakeEndpoint(
            account_id=self.account_id,
            region_name=self.region_name,
            endpoint_name=endpoint_name,
            endpoint_config_name=endpoint_config_name,
            production_variants=endpoint_config["ProductionVariants"],
            data_capture_config=endpoint_config["DataCaptureConfig"],
            tags=tags,
        )
        self.endpoints[endpoint_name] = endpoint
        return endpoint

    def describe_endpoint(self, endpoint_name):
        """Return an endpoint's response object.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.endpoints[endpoint_name].response_object
        except KeyError:
            arn = FakeEndpoint.arn_formatter(
                endpoint_name, self.account_id, self.region_name
            )
            raise ValidationError(message=f"Could not find endpoint '{arn}'.")

    def delete_endpoint(self, endpoint_name):
        """Delete an endpoint.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.endpoints[endpoint_name]
        except KeyError:
            arn = FakeEndpoint.arn_formatter(
                endpoint_name, self.account_id, self.region_name
            )
            raise ValidationError(message=f"Could not find endpoint '{arn}'.")

    def create_processing_job(
        self,
        app_specification,
        experiment_config,
        network_config,
        processing_inputs,
        processing_job_name,
        processing_output_config,
        role_arn,
        tags,
        stopping_condition,
    ):
        """Create and store a processing job."""
        processing_job = FakeProcessingJob(
            app_specification=app_specification,
            experiment_config=experiment_config,
            network_config=network_config,
            processing_inputs=processing_inputs,
            processing_job_name=processing_job_name,
            processing_output_config=processing_output_config,
            account_id=self.account_id,
            region_name=self.region_name,
            role_arn=role_arn,
            stopping_condition=stopping_condition,
            tags=tags,
        )
        self.processing_jobs[processing_job_name] = processing_job
        return processing_job

    def describe_processing_job(self, processing_job_name):
        """Return a processing job's response object.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.processing_jobs[processing_job_name].response_object
        except KeyError:
            arn = FakeProcessingJob.arn_formatter(
                processing_job_name, self.account_id, self.region_name
            )
            raise ValidationError(message=f"Could not find processing job '{arn}'.")

    @staticmethod
    def _paginate_by_index(resources, next_token, max_results):
        """Slice a name->resource dict using an integer-offset pagination token.

        Returns ``(page, next_index)``. Without ``max_results`` the token is
        still validated but every resource is returned and ``next_index`` is None.

        Raises:
            AWSValidationException: if ``next_token`` is not an int-like string
                or lies beyond the end of ``resources``.
        """
        total = len(resources)
        if next_token:
            try:
                starting_index = int(next_token)
                if starting_index > total:
                    raise ValueError  # invalid next_token
            except ValueError:
                raise AWSValidationException('Invalid pagination token because "{0}".')
        else:
            starting_index = 0

        if not max_results:
            return list(resources.values()), None

        end_index = max_results + starting_index
        page = list(resources.values())[starting_index:end_index]
        next_index = None if end_index >= total else end_index
        return page, next_index

    def list_processing_jobs(
        self,
        next_token,
        max_results,
        creation_time_after,
        creation_time_before,
        last_modified_time_after,
        last_modified_time_before,
        name_contains,
        status_equals,
    ):
        """List processing-job summaries.

        Filters are applied to the current page (pagination happens first).
        """
        processing_jobs_fetched, next_index = self._paginate_by_index(
            self.processing_jobs, next_token, max_results
        )

        if name_contains is not None:
            processing_jobs_fetched = filter(
                lambda x: name_contains in x.processing_job_name,
                processing_jobs_fetched,
            )
        if creation_time_after is not None:
            processing_jobs_fetched = filter(
                lambda x: x.creation_time > creation_time_after, processing_jobs_fetched
            )
        if creation_time_before is not None:
            processing_jobs_fetched = filter(
                lambda x: x.creation_time < creation_time_before,
                processing_jobs_fetched,
            )
        if last_modified_time_after is not None:
            processing_jobs_fetched = filter(
                lambda x: x.last_modified_time > last_modified_time_after,
                processing_jobs_fetched,
            )
        if last_modified_time_before is not None:
            processing_jobs_fetched = filter(
                lambda x: x.last_modified_time < last_modified_time_before,
                processing_jobs_fetched,
            )
        if status_equals is not None:
            # Processing jobs expose processing_job_status (see the summary
            # below); the former training_job_status raised AttributeError.
            processing_jobs_fetched = filter(
                lambda x: x.processing_job_status == status_equals,
                processing_jobs_fetched,
            )

        processing_job_summaries = [
            {
                "ProcessingJobName": processing_job_data.processing_job_name,
                "ProcessingJobArn": processing_job_data.processing_job_arn,
                "CreationTime": processing_job_data.creation_time,
                "ProcessingEndTime": processing_job_data.processing_end_time,
                "LastModifiedTime": processing_job_data.last_modified_time,
                "ProcessingJobStatus": processing_job_data.processing_job_status,
            }
            for processing_job_data in processing_jobs_fetched
        ]

        return {
            "ProcessingJobSummaries": processing_job_summaries,
            "NextToken": str(next_index) if next_index is not None else None,
        }

    def create_training_job(
        self,
        training_job_name,
        hyper_parameters,
        algorithm_specification,
        role_arn,
        input_data_config,
        output_data_config,
        resource_config,
        vpc_config,
        stopping_condition,
        tags,
        enable_network_isolation,
        enable_inter_container_traffic_encryption,
        enable_managed_spot_training,
        checkpoint_config,
        debug_hook_config,
        debug_rule_configurations,
        tensor_board_output_config,
        experiment_config,
    ):
        """Create and store a training job."""
        training_job = FakeTrainingJob(
            account_id=self.account_id,
            region_name=self.region_name,
            training_job_name=training_job_name,
            hyper_parameters=hyper_parameters,
            algorithm_specification=algorithm_specification,
            role_arn=role_arn,
            input_data_config=input_data_config,
            output_data_config=output_data_config,
            resource_config=resource_config,
            vpc_config=vpc_config,
            stopping_condition=stopping_condition,
            tags=tags,
            enable_network_isolation=enable_network_isolation,
            enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption,
            enable_managed_spot_training=enable_managed_spot_training,
            checkpoint_config=checkpoint_config,
            debug_hook_config=debug_hook_config,
            debug_rule_configurations=debug_rule_configurations,
            tensor_board_output_config=tensor_board_output_config,
            experiment_config=experiment_config,
        )
        self.training_jobs[training_job_name] = training_job
        return training_job

    def _training_job_arn(self, training_job_name):
        """ARN of a training job in this account/region (for error messages)."""
        return arn_formatter(
            "training-job", training_job_name, self.account_id, self.region_name
        )

    def describe_training_job(self, training_job_name):
        """Return a training job's response object.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            return self.training_jobs[training_job_name].response_object
        except KeyError:
            message = "Could not find training job '{}'.".format(
                self._training_job_arn(training_job_name)
            )
            raise ValidationError(message=message)

    def delete_training_job(self, training_job_name):
        """Delete a training job.

        Raises:
            ValidationError: if it does not exist.
        """
        try:
            del self.training_jobs[training_job_name]
        except KeyError:
            # Message previously (wrongly) referred to an endpoint configuration.
            message = "Could not find training job '{}'.".format(
                self._training_job_arn(training_job_name)
            )
            raise ValidationError(message=message)

    def _update_training_job_details(self, training_job_name, details_json):
        """Overwrite training-job attributes from a JSON object string."""
        self.training_jobs[training_job_name].update(details_json)

    def list_training_jobs(
        self,
        next_token,
        max_results,
        creation_time_after,
        creation_time_before,
        last_modified_time_after,
        last_modified_time_before,
        name_contains,
        status_equals,
    ):
        """List training-job summaries.

        Filters are applied to the current page (pagination happens first).
        """
        training_jobs_fetched, next_index = self._paginate_by_index(
            self.training_jobs, next_token, max_results
        )

        if name_contains is not None:
            training_jobs_fetched = filter(
                lambda x: name_contains in x.training_job_name, training_jobs_fetched
            )
        if creation_time_after is not None:
            training_jobs_fetched = filter(
                lambda x: x.creation_time > creation_time_after, training_jobs_fetched
            )
        if creation_time_before is not None:
            training_jobs_fetched = filter(
                lambda x: x.creation_time < creation_time_before, training_jobs_fetched
            )
        if last_modified_time_after is not None:
            training_jobs_fetched = filter(
                lambda x: x.last_modified_time > last_modified_time_after,
                training_jobs_fetched,
            )
        if last_modified_time_before is not None:
            training_jobs_fetched = filter(
                lambda x: x.last_modified_time < last_modified_time_before,
                training_jobs_fetched,
            )
        if status_equals is not None:
            training_jobs_fetched = filter(
                lambda x: x.training_job_status == status_equals, training_jobs_fetched
            )

        training_job_summaries = [
            {
                "TrainingJobName": training_job_data.training_job_name,
                "TrainingJobArn": training_job_data.training_job_arn,
                "CreationTime": training_job_data.creation_time,
                "TrainingEndTime": training_job_data.training_end_time,
                "LastModifiedTime": training_job_data.last_modified_time,
                "TrainingJobStatus": training_job_data.training_job_status,
            }
            for training_job_data in training_jobs_fetched
        ]

        return {
            "TrainingJobSummaries": training_job_summaries,
            "NextToken": str(next_index) if next_index is not None else None,
        }

    def update_endpoint_weights_and_capacities(
        self, endpoint_name, desired_weights_and_capacities
    ):
        """Apply desired weights/instance counts to an endpoint's variants.

        All variant names are validated before any variant is changed; the
        desired values are copied into the "Current" fields immediately.

        Returns:
            The endpoint ARN.

        Raises:
            AWSValidationException: for an unknown endpoint, duplicate variant
                names, or variant names absent from the endpoint.
        """
        endpoint = self.endpoints.get(endpoint_name, None)
        if not endpoint:
            arn = FakeEndpoint.arn_formatter(
                endpoint_name, self.account_id, self.region_name
            )
            raise AWSValidationException(f'Could not find endpoint "{arn}".')

        names_checked = set()
        for variant_config in desired_weights_and_capacities:
            name = variant_config.get("VariantName")

            if name in names_checked:
                raise AWSValidationException(
                    f'The variant name "{name}" was non-unique within the request.'
                )

            if not any(
                variant["VariantName"] == name
                for variant in endpoint.production_variants
            ):
                raise AWSValidationException(
                    f'The variant name(s) "{name}" is/are not present within endpoint configuration "{endpoint.endpoint_config_name}".'
                )

            names_checked.add(name)

        endpoint.endpoint_status = "Updating"

        for variant_config in desired_weights_and_capacities:
            name = variant_config.get("VariantName")
            desired_weight = variant_config.get("DesiredWeight")
            desired_instance_count = variant_config.get("DesiredInstanceCount")

            for variant in endpoint.production_variants:
                if variant.get("VariantName") == name:
                    variant["DesiredWeight"] = desired_weight
                    variant["CurrentWeight"] = desired_weight
                    variant["DesiredInstanceCount"] = desired_instance_count
                    variant["CurrentInstanceCount"] = desired_instance_count
                    break

        endpoint.endpoint_status = "InService"
        return endpoint.endpoint_arn
|
|
|
|
|
|
class FakeExperiment(BaseObject):
    """In-memory record of a SageMaker experiment.

    Response keys are derived from the attribute names by
    ``gen_response_object``. The attributes are assigned in a fixed order so
    that the key order of the responses stays stable.
    """

    def __init__(self, account_id, region_name, experiment_name, tags):
        """Create the experiment record and stamp its timestamps.

        :param account_id: Account id embedded in the ARN.
        :param region_name: Region embedded in the ARN.
        :param experiment_name: Experiment name, also used as the ARN resource id.
        :param tags: Caller-supplied tags, stored as given.
        """
        self.experiment_name = experiment_name
        self.experiment_arn = arn_formatter(
            "experiment", experiment_name, account_id, region_name
        )
        self.tags = tags
        # Creation and last-modified times share one local-clock string.
        created = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = created
        self.last_modified_time = created

    @property
    def response_object(self):
        """Return the attribute-derived response dict.

        Entries whose value is ``None`` or the single-element list ``[None]``
        are omitted.
        """
        payload = {}
        for key, value in self.gen_response_object().items():
            if value is None or value == [None]:
                continue
            payload[key] = value
        return payload

    @property
    def response_create(self):
        """Return the create response, which holds only the experiment ARN."""
        return {"ExperimentArn": self.experiment_arn}
|
|
|
|
|
|
class FakeTrial(BaseObject):
    """In-memory record of a SageMaker trial.

    The record holds the trial's name, ARN, tags, trial components and the
    name of its experiment, plus creation and last-modified timestamps.
    """

    def __init__(
        self,
        account_id,
        region_name,
        trial_name,
        experiment_name,
        tags,
        trial_components,
    ):
        """Create the trial record.

        :param account_id: Account id embedded in the ARN.
        :param region_name: Region embedded in the ARN.
        :param trial_name: Trial name, also used as the ARN resource id.
        :param experiment_name: Name of the experiment, stored as given.
        :param tags: Caller-supplied tags, stored as given.
        :param trial_components: Trial components, stored as given.
        """
        trial_arn = arn_formatter(
            "experiment-trial", trial_name, account_id, region_name
        )
        # Attribute order matters: gen_response_object walks __dict__ in
        # insertion order when building the response.
        self.trial_name = trial_name
        self.trial_arn = trial_arn
        self.tags = tags
        self.trial_components = trial_components
        self.experiment_name = experiment_name
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = timestamp
        self.last_modified_time = timestamp

    @property
    def response_object(self):
        """Return the attribute-derived response dict without unset values."""

        def keep(value):
            # Drop None and the single-element [None] placeholder.
            return value is not None and value != [None]

        fields = self.gen_response_object()
        return {name: value for name, value in fields.items() if keep(value)}

    @property
    def response_create(self):
        """Return the create response, which holds only the trial ARN."""
        return {"TrialArn": self.trial_arn}
|
|
|
|
|
|
class FakeTrialComponent(BaseObject):
    """In-memory record of a SageMaker trial component.

    The record holds the component's name, ARN, tags and the name of its
    trial, plus creation and last-modified timestamps.
    """

    def __init__(self, account_id, region_name, trial_component_name, trial_name, tags):
        """Create the trial-component record.

        :param account_id: Account id embedded in the ARN.
        :param region_name: Region embedded in the ARN.
        :param trial_component_name: Component name, also the ARN resource id.
        :param trial_name: Name of the trial, stored as given.
        :param tags: Caller-supplied tags, stored as given.
        """
        self.trial_component_name = trial_component_name
        self.trial_component_arn = FakeTrialComponent.arn_formatter(
            trial_component_name, account_id, region_name
        )
        self.tags = tags
        self.trial_name = trial_name
        # One local-clock string serves as both timestamps.
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = stamp
        self.last_modified_time = stamp

    @property
    def response_object(self):
        """Return the attribute-derived response dict.

        Values that are ``None`` or ``[None]`` are left out.
        """
        result = {}
        for name, value in self.gen_response_object().items():
            if value is not None and value != [None]:
                result[name] = value
        return result

    @property
    def response_create(self):
        """Return the create response, which holds only the component ARN."""
        return {"TrialComponentArn": self.trial_component_arn}

    @staticmethod
    def arn_formatter(trial_component_name, account_id, region_name):
        """Return the ARN for a trial component with the given name."""
        resource_type = "experiment-trial-component"
        # Inside a function body the class scope is not searched, so this
        # name resolves to the module-level arn_formatter, not this method.
        return arn_formatter(
            resource_type, trial_component_name, account_id, region_name
        )
|
|
|
|
|
|
# Module-level registry of SageMaker backends, built by moto's BackendDict
# (from moto.core.utils) for the "sagemaker" service. Presumably it lazily
# creates one SageMakerModelBackend per account/region; confirm against
# BackendDict's implementation.
sagemaker_backends = BackendDict(SageMakerModelBackend, "sagemaker")
|