Add more functionality to sagemaker (#4491)
This commit is contained in:
parent
6f5cae98ad
commit
3259a3307a
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from boto3 import Session
|
||||
from datetime import datetime
|
||||
@ -14,6 +15,11 @@ class BaseObject(BaseModel):
|
||||
words.append(word.title())
|
||||
return "".join(words)
|
||||
|
||||
def update(self, details_json):
|
||||
details = json.loads(details_json)
|
||||
for k in details.keys():
|
||||
setattr(self, k, details[k])
|
||||
|
||||
def gen_response_object(self):
|
||||
response_object = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
@ -866,6 +872,9 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.notebook_instances = {}
|
||||
self.endpoint_configs = {}
|
||||
self.endpoints = {}
|
||||
self.experiments = {}
|
||||
self.trials = {}
|
||||
self.trial_components = {}
|
||||
self.training_jobs = {}
|
||||
self.notebook_instance_lifecycle_configurations = {}
|
||||
self.region_name = region_name
|
||||
@ -961,6 +970,385 @@ class SageMakerModelBackend(BaseBackend):
|
||||
else:
|
||||
raise MissingModel(model=model_name)
|
||||
|
||||
def create_experiment(self, experiment_name):
|
||||
experiment = FakeExperiment(
|
||||
region_name=self.region_name, experiment_name=experiment_name, tags=[]
|
||||
)
|
||||
self.experiments[experiment_name] = experiment
|
||||
return experiment.response_create
|
||||
|
||||
def describe_experiment(self, experiment_name):
|
||||
experiment_data = self.experiments[experiment_name]
|
||||
return {
|
||||
"ExperimentName": experiment_data.experiment_name,
|
||||
"ExperimentArn": experiment_data.experiment_arn,
|
||||
"CreationTime": experiment_data.creation_time,
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
|
||||
def add_tags_to_experiment(self, experiment_arn, tags):
|
||||
experiment = [
|
||||
self.experiments[i]
|
||||
for i in self.experiments
|
||||
if self.experiments[i].experiment_arn == experiment_arn
|
||||
][0]
|
||||
experiment.tags.extend(tags)
|
||||
|
||||
def add_tags_to_trial(self, trial_arn, tags):
|
||||
trial = [
|
||||
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||
][0]
|
||||
trial.tags.extend(tags)
|
||||
|
||||
def add_tags_to_trial_component(self, trial_component_arn, tags):
|
||||
trial_component = [
|
||||
self.trial_components[i]
|
||||
for i in self.trial_components
|
||||
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||
][0]
|
||||
trial_component.tags.extend(tags)
|
||||
|
||||
def delete_tags_from_experiment(self, experiment_arn, tag_keys):
|
||||
experiment = [
|
||||
self.experiments[i]
|
||||
for i in self.experiments
|
||||
if self.experiments[i].experiment_arn == experiment_arn
|
||||
][0]
|
||||
experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys]
|
||||
|
||||
def delete_tags_from_trial(self, trial_arn, tag_keys):
|
||||
trial = [
|
||||
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||
][0]
|
||||
trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys]
|
||||
|
||||
def delete_tags_from_trial_component(self, trial_component_arn, tag_keys):
|
||||
trial_component = [
|
||||
self.trial_components[i]
|
||||
for i in self.trial_components
|
||||
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||
][0]
|
||||
trial_component.tags = [
|
||||
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
|
||||
]
|
||||
|
||||
def list_experiments(self):
|
||||
next_index = None
|
||||
|
||||
experiments_fetched = list(self.experiments.values())
|
||||
|
||||
experiment_summaries = [
|
||||
{
|
||||
"ExperimentName": experiment_data.experiment_name,
|
||||
"ExperimentArn": experiment_data.experiment_arn,
|
||||
"CreationTime": experiment_data.creation_time,
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
for experiment_data in experiments_fetched
|
||||
]
|
||||
|
||||
return {
|
||||
"ExperimentSummaries": experiment_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def search(self, resource=None, search_expression=None):
|
||||
next_index = None
|
||||
|
||||
valid_resources = [
|
||||
"Pipeline",
|
||||
"ModelPackageGroup",
|
||||
"TrainingJob",
|
||||
"ExperimentTrialComponent",
|
||||
"FeatureGroup",
|
||||
"Endpoint",
|
||||
"PipelineExecution",
|
||||
"Project",
|
||||
"ExperimentTrial",
|
||||
"Image",
|
||||
"ImageVersion",
|
||||
"ModelPackage",
|
||||
"Experiment",
|
||||
]
|
||||
|
||||
if resource not in valid_resources:
|
||||
raise AWSValidationException(
|
||||
f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}"
|
||||
)
|
||||
|
||||
def evaluate_search_expression(item):
|
||||
filters = None
|
||||
if search_expression is not None:
|
||||
filters = search_expression.get("Filters")
|
||||
|
||||
if filters is not None:
|
||||
for f in filters:
|
||||
if f["Operator"] == "Equals":
|
||||
if f["Name"].startswith("Tags."):
|
||||
key = f["Name"][5:]
|
||||
value = f["Value"]
|
||||
|
||||
if (
|
||||
len(
|
||||
[
|
||||
e
|
||||
for e in item.tags
|
||||
if e["Key"] == key and e["Value"] == value
|
||||
]
|
||||
)
|
||||
== 0
|
||||
):
|
||||
return False
|
||||
if f["Name"] == "ExperimentName":
|
||||
experiment_name = f["Value"]
|
||||
|
||||
if getattr(item, "experiment_name") != experiment_name:
|
||||
return False
|
||||
|
||||
if f["Name"] == "TrialName":
|
||||
raise AWSValidationException(
|
||||
f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {f['Name']}"
|
||||
)
|
||||
|
||||
if f["Name"] == "Parents.TrialName":
|
||||
trial_name = f["Value"]
|
||||
|
||||
if getattr(item, "trial_name") != trial_name:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
result = {
|
||||
"Results": [],
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
if resource == "Experiment":
|
||||
experiments_fetched = list(self.experiments.values())
|
||||
|
||||
experiment_summaries = [
|
||||
{
|
||||
"ExperimentName": experiment_data.experiment_name,
|
||||
"ExperimentArn": experiment_data.experiment_arn,
|
||||
"CreationTime": experiment_data.creation_time,
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
for experiment_data in experiments_fetched
|
||||
if evaluate_search_expression(experiment_data)
|
||||
]
|
||||
|
||||
for experiment_summary in experiment_summaries:
|
||||
result["Results"].append({"Experiment": experiment_summary})
|
||||
|
||||
if resource == "ExperimentTrial":
|
||||
trials_fetched = list(self.trials.values())
|
||||
|
||||
trial_summaries = [
|
||||
{
|
||||
"TrialName": trial_data.trial_name,
|
||||
"TrialArn": trial_data.trial_arn,
|
||||
"CreationTime": trial_data.creation_time,
|
||||
"LastModifiedTime": trial_data.last_modified_time,
|
||||
}
|
||||
for trial_data in trials_fetched
|
||||
if evaluate_search_expression(trial_data)
|
||||
]
|
||||
|
||||
for trial_summary in trial_summaries:
|
||||
result["Results"].append({"Trial": trial_summary})
|
||||
|
||||
if resource == "ExperimentTrialComponent":
|
||||
trial_components_fetched = list(self.trial_components.values())
|
||||
|
||||
trial_component_summaries = [
|
||||
{
|
||||
"TrialComponentName": trial_component_data.trial_component_name,
|
||||
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||
"CreationTime": trial_component_data.creation_time,
|
||||
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||
}
|
||||
for trial_component_data in trial_components_fetched
|
||||
if evaluate_search_expression(trial_component_data)
|
||||
]
|
||||
|
||||
for trial_component_summary in trial_component_summaries:
|
||||
result["Results"].append({"TrialComponent": trial_component_summary})
|
||||
return result
|
||||
|
||||
def delete_experiment(self, experiment_name):
|
||||
try:
|
||||
del self.experiments[experiment_name]
|
||||
except KeyError:
|
||||
message = "Could not find experiment configuration '{}'.".format(
|
||||
FakeTrial.arn_formatter(experiment_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_experiment_by_arn(self, arn):
|
||||
experiments = [
|
||||
experiment
|
||||
for experiment in self.experiments.values()
|
||||
if experiment.experiment_arn == arn
|
||||
]
|
||||
if len(experiments) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return experiments[0]
|
||||
|
||||
def get_experiment_tags(self, arn):
|
||||
try:
|
||||
experiment = self.get_experiment_by_arn(arn)
|
||||
return experiment.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_trial(
|
||||
self, trial_name, experiment_name,
|
||||
):
|
||||
trial = FakeTrial(
|
||||
region_name=self.region_name,
|
||||
trial_name=trial_name,
|
||||
experiment_name=experiment_name,
|
||||
tags=[],
|
||||
)
|
||||
self.trials[trial_name] = trial
|
||||
return trial.response_create
|
||||
|
||||
def describe_trial(self, trial_name):
|
||||
try:
|
||||
return self.trials[trial_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find trial '{}'.".format(
|
||||
FakeTrial.arn_formatter(trial_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def delete_trial(self, trial_name):
|
||||
try:
|
||||
del self.trials[trial_name]
|
||||
except KeyError:
|
||||
message = "Could not find trial configuration '{}'.".format(
|
||||
FakeTrial.arn_formatter(trial_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_trial_by_arn(self, arn):
|
||||
trials = [trial for trial in self.trials.values() if trial.trial_arn == arn]
|
||||
if len(trials) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return trials[0]
|
||||
|
||||
def get_trial_tags(self, arn):
|
||||
try:
|
||||
trial = self.get_trial_by_arn(arn)
|
||||
return trial.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def list_trials(self, experiment_name=None):
|
||||
next_index = None
|
||||
|
||||
trials_fetched = list(self.trials.values())
|
||||
|
||||
trial_summaries = [
|
||||
{
|
||||
"TrialName": trial_data.trial_name,
|
||||
"TrialArn": trial_data.trial_arn,
|
||||
"CreationTime": trial_data.creation_time,
|
||||
"LastModifiedTime": trial_data.last_modified_time,
|
||||
}
|
||||
for trial_data in trials_fetched
|
||||
if experiment_name is None or trial_data.experiment_name == experiment_name
|
||||
]
|
||||
|
||||
return {
|
||||
"TrialSummaries": trial_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def create_trial_component(
|
||||
self, trial_component_name, trial_name,
|
||||
):
|
||||
trial_component = FakeTrialComponent(
|
||||
region_name=self.region_name,
|
||||
trial_component_name=trial_component_name,
|
||||
trial_name=trial_name,
|
||||
tags=[],
|
||||
)
|
||||
self.trial_components[trial_component_name] = trial_component
|
||||
return trial_component.response_create
|
||||
|
||||
def delete_trial_component(self, trial_component_name):
|
||||
try:
|
||||
del self.trial_components[trial_component_name]
|
||||
except KeyError:
|
||||
message = "Could not find trial-component configuration '{}'.".format(
|
||||
FakeTrial.arn_formatter(trial_component_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_trial_component_by_arn(self, arn):
|
||||
trial_components = [
|
||||
trial_component
|
||||
for trial_component in self.trial_components.values()
|
||||
if trial_component.trial_component_arn == arn
|
||||
]
|
||||
if len(trial_components) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return trial_components[0]
|
||||
|
||||
def get_trial_component_tags(self, arn):
|
||||
try:
|
||||
trial_component = self.get_trial_component_by_arn(arn)
|
||||
return trial_component.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def describe_trial_component(self, trial_component_name):
|
||||
try:
|
||||
return self.trial_components[trial_component_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find trial component '{}'.".format(
|
||||
FakeTrialComponent.arn_formatter(trial_component_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def _update_trial_component_details(self, trial_component_name, details_json):
|
||||
self.trial_components[trial_component_name].update(details_json)
|
||||
|
||||
def list_trial_components(self, trial_name=None):
|
||||
next_index = None
|
||||
|
||||
trial_components_fetched = list(self.trial_components.values())
|
||||
|
||||
trial_component_summaries = [
|
||||
{
|
||||
"TrialComponentName": trial_component_data.trial_component_name,
|
||||
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||
"CreationTime": trial_component_data.creation_time,
|
||||
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||
}
|
||||
for trial_component_data in trial_components_fetched
|
||||
if trial_name is None or trial_component_data.trial_name == trial_name
|
||||
]
|
||||
|
||||
return {
|
||||
"TrialComponentSummaries": trial_component_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def associate_trial_component(self, params):
|
||||
trial_name = params["TrialName"]
|
||||
trial_component_name = params["TrialComponentName"]
|
||||
|
||||
self.trial_components[trial_component_name].trial_name = trial_name
|
||||
|
||||
def disassociate_trial_component(self, params):
|
||||
trial_component_name = params["TrialComponentName"]
|
||||
|
||||
self.trial_components[trial_component_name].trial_name = None
|
||||
|
||||
def create_notebook_instance(
|
||||
self,
|
||||
notebook_instance_name,
|
||||
@ -1292,6 +1680,9 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def _update_training_job_details(self, training_job_name, details_json):
|
||||
self.training_jobs[training_job_name].update(details_json)
|
||||
|
||||
def list_training_jobs(
|
||||
self,
|
||||
next_token,
|
||||
@ -1377,6 +1768,111 @@ class SageMakerModelBackend(BaseBackend):
|
||||
}
|
||||
|
||||
|
||||
class FakeExperiment(BaseObject):
|
||||
def __init__(
|
||||
self, region_name, experiment_name, tags,
|
||||
):
|
||||
self.experiment_name = experiment_name
|
||||
self.experiment_arn = FakeExperiment.arn_formatter(experiment_name, region_name)
|
||||
self.tags = tags
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"ExperimentArn": self.experiment_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(experiment_arn, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":experiment/"
|
||||
+ experiment_arn
|
||||
)
|
||||
|
||||
|
||||
class FakeTrial(BaseObject):
|
||||
def __init__(
|
||||
self, region_name, trial_name, experiment_name, tags,
|
||||
):
|
||||
self.trial_name = trial_name
|
||||
self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name)
|
||||
self.tags = tags
|
||||
self.experiment_name = experiment_name
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"TrialArn": self.trial_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(trial_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":experiment-trial/"
|
||||
+ trial_name
|
||||
)
|
||||
|
||||
|
||||
class FakeTrialComponent(BaseObject):
|
||||
def __init__(
|
||||
self, region_name, trial_component_name, trial_name, tags,
|
||||
):
|
||||
self.trial_component_name = trial_component_name
|
||||
self.trial_component_arn = FakeTrialComponent.arn_formatter(
|
||||
trial_component_name, region_name
|
||||
)
|
||||
self.tags = tags
|
||||
self.trial_name = trial_name
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = self.last_modified_time = now_string
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"TrialComponentArn": self.trial_component_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(trial_component_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":experiment-trial-component/"
|
||||
+ trial_component_name
|
||||
)
|
||||
|
||||
|
||||
sagemaker_backends = {}
|
||||
for region in Session().get_available_regions("sagemaker"):
|
||||
sagemaker_backends[region] = SageMakerModelBackend(region)
|
||||
|
@ -132,6 +132,12 @@ class SageMakerResponse(BaseResponse):
|
||||
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
||||
elif ":training-job/" in arn:
|
||||
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
||||
elif ":experiment/" in arn:
|
||||
tags = self.sagemaker_backend.get_experiment_tags(arn)
|
||||
elif ":experiment-trial/" in arn:
|
||||
tags = self.sagemaker_backend.get_trial_tags(arn)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
tags = self.sagemaker_backend.get_trial_component_tags(arn)
|
||||
else:
|
||||
tags = []
|
||||
except AWSError:
|
||||
@ -139,6 +145,30 @@ class SageMakerResponse(BaseResponse):
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def add_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self._get_param("Tags")
|
||||
if ":experiment/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_experiment(arn, tags)
|
||||
elif ":experiment-trial/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_trial(arn, tags)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_trial_component(arn, tags)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
def delete_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tag_keys = self._get_param("TagKeys")
|
||||
if ":experiment/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys)
|
||||
elif ":experiment-trial/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
def create_endpoint_config(self):
|
||||
try:
|
||||
@ -278,6 +308,117 @@ class SageMakerResponse(BaseResponse):
|
||||
)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def search(self):
|
||||
response = self.sagemaker_backend.search(
|
||||
resource=self._get_param("Resource"),
|
||||
search_expression=self._get_param("SearchExpression"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_experiments(self):
|
||||
response = self.sagemaker_backend.list_experiments()
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_experiment(self):
|
||||
self.sagemaker_backend.delete_experiment(
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
def create_experiment(self, *args, **kwargs):
|
||||
response = self.sagemaker_backend.create_experiment(
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def describe_experiment(self, *args, **kwargs):
|
||||
response = self.sagemaker_backend.describe_experiment(
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_trials(self):
|
||||
response = self.sagemaker_backend.list_trials(
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_trial(self, *args, **kwargs):
|
||||
try:
|
||||
response = self.sagemaker_backend.create_trial(
|
||||
trial_name=self._get_param("TrialName"),
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def list_trial_components(self):
|
||||
response = self.sagemaker_backend.list_trial_components(
|
||||
trial_name=self._get_param("TrialName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_trial_component(self, *args, **kwargs):
|
||||
try:
|
||||
response = self.sagemaker_backend.create_trial_component(
|
||||
trial_component_name=self._get_param("TrialComponentName"),
|
||||
trial_name=self._get_param("TrialName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_trial(self, *args, **kwargs):
|
||||
trial_name = self._get_param("TrialName")
|
||||
response = self.sagemaker_backend.describe_trial(trial_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_trial(self):
|
||||
trial_name = self._get_param("TrialName")
|
||||
self.sagemaker_backend.delete_trial(trial_name)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
def delete_trial_component(self):
|
||||
trial_component_name = self._get_param("TrialComponentName")
|
||||
self.sagemaker_backend.delete_trial_component(trial_component_name)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
def describe_trial_component(self, *args, **kwargs):
|
||||
trial_component_name = self._get_param("TrialComponentName")
|
||||
response = self.sagemaker_backend.describe_trial_component(trial_component_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def associate_trial_component(self, *args, **kwargs):
|
||||
self.sagemaker_backend.associate_trial_component(self.request_params)
|
||||
response = {}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def disassociate_trial_component(self, *args, **kwargs):
|
||||
self.sagemaker_backend.disassociate_trial_component(self.request_params)
|
||||
response = {}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_associations(self, *args, **kwargs):
|
||||
response = self.sagemaker_backend.list_associations(self.request_params)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_training_jobs(self):
|
||||
max_results_range = range(1, 101)
|
||||
|
91
tests/test_sagemaker/test_sagemaker_experiment.py
Normal file
91
tests/test_sagemaker/test_sagemaker_experiment.py
Normal file
@ -0,0 +1,91 @@
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments()
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 1
|
||||
assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name
|
||||
assert (
|
||||
resp["ExperimentSummaries"][0]["ExperimentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.delete_experiment(ExperimentName=experiment_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments()
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||
|
||||
arn = resp["ExperimentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_to_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||
|
||||
arn = resp["ExperimentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == []
|
61
tests/test_sagemaker/test_sagemaker_search.py
Normal file
61
tests/test_sagemaker/test_sagemaker_search.py
Normal file
@ -0,0 +1,61 @@
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_search():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
another_trial_component_name = "another-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.create_trial_component(
|
||||
TrialComponentName=another_trial_component_name
|
||||
)
|
||||
|
||||
resp = client.search(Resource="ExperimentTrialComponent")
|
||||
|
||||
assert len(resp["Results"]) == 2
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
trial_component_arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "key-name", "Value": "some-value"}]
|
||||
|
||||
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
|
||||
resp = client.search(
|
||||
Resource="ExperimentTrialComponent",
|
||||
SearchExpression={
|
||||
"Filters": [
|
||||
{"Name": "Tags.key-name", "Operator": "Equals", "Value": "some-value"}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
assert len(resp["Results"]) == 1
|
||||
assert (
|
||||
resp["Results"][0]["TrialComponent"]["TrialComponentName"]
|
||||
== trial_component_name
|
||||
)
|
||||
|
||||
resp = client.search(Resource="Experiment")
|
||||
assert len(resp["Results"]) == 1
|
||||
assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name
|
||||
|
||||
resp = client.search(Resource="ExperimentTrial")
|
||||
assert len(resp["Results"]) == 1
|
||||
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
|
106
tests/test_sagemaker/test_sagemaker_trial.py
Normal file
106
tests/test_sagemaker/test_sagemaker_trial.py
Normal file
@ -0,0 +1,106 @@
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_trial():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_trials()
|
||||
|
||||
assert len(resp["TrialSummaries"]) == 1
|
||||
assert resp["TrialSummaries"][0]["TrialName"] == trial_name
|
||||
assert (
|
||||
resp["TrialSummaries"][0]["TrialArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_trial():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
resp = client.delete_trial(TrialName=trial_name)
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_trials()
|
||||
|
||||
assert len(resp["TrialSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_trial():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
resp = client.describe_trial(TrialName=trial_name)
|
||||
|
||||
arn = resp["TrialArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_to_trial():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
resp = client.describe_trial(TrialName=trial_name)
|
||||
|
||||
arn = resp["TrialArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == []
|
122
tests/test_sagemaker/test_sagemaker_trial_component.py
Normal file
122
tests/test_sagemaker/test_sagemaker_trial_component.py
Normal file
@ -0,0 +1,122 @@
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create__trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_trial_components()
|
||||
|
||||
assert len(resp["TrialComponentSummaries"]) == 1
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||
)
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete__trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.delete_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_trial_components()
|
||||
|
||||
assert len(resp["TrialComponentSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_to_trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == []
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_associate_trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
client.associate_trial_component(
|
||||
TrialComponentName=trial_component_name, TrialName=trial_name
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_trial_components(TrialName=trial_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||
)
|
Loading…
Reference in New Issue
Block a user