Add more functionality to sagemaker (#4491)
This commit is contained in:
parent
6f5cae98ad
commit
3259a3307a
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -14,6 +15,11 @@ class BaseObject(BaseModel):
|
|||||||
words.append(word.title())
|
words.append(word.title())
|
||||||
return "".join(words)
|
return "".join(words)
|
||||||
|
|
||||||
|
def update(self, details_json):
|
||||||
|
details = json.loads(details_json)
|
||||||
|
for k in details.keys():
|
||||||
|
setattr(self, k, details[k])
|
||||||
|
|
||||||
def gen_response_object(self):
|
def gen_response_object(self):
|
||||||
response_object = dict()
|
response_object = dict()
|
||||||
for key, value in self.__dict__.items():
|
for key, value in self.__dict__.items():
|
||||||
@ -866,6 +872,9 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
self.notebook_instances = {}
|
self.notebook_instances = {}
|
||||||
self.endpoint_configs = {}
|
self.endpoint_configs = {}
|
||||||
self.endpoints = {}
|
self.endpoints = {}
|
||||||
|
self.experiments = {}
|
||||||
|
self.trials = {}
|
||||||
|
self.trial_components = {}
|
||||||
self.training_jobs = {}
|
self.training_jobs = {}
|
||||||
self.notebook_instance_lifecycle_configurations = {}
|
self.notebook_instance_lifecycle_configurations = {}
|
||||||
self.region_name = region_name
|
self.region_name = region_name
|
||||||
@ -961,6 +970,385 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
raise MissingModel(model=model_name)
|
raise MissingModel(model=model_name)
|
||||||
|
|
||||||
|
def create_experiment(self, experiment_name):
|
||||||
|
experiment = FakeExperiment(
|
||||||
|
region_name=self.region_name, experiment_name=experiment_name, tags=[]
|
||||||
|
)
|
||||||
|
self.experiments[experiment_name] = experiment
|
||||||
|
return experiment.response_create
|
||||||
|
|
||||||
|
def describe_experiment(self, experiment_name):
|
||||||
|
experiment_data = self.experiments[experiment_name]
|
||||||
|
return {
|
||||||
|
"ExperimentName": experiment_data.experiment_name,
|
||||||
|
"ExperimentArn": experiment_data.experiment_arn,
|
||||||
|
"CreationTime": experiment_data.creation_time,
|
||||||
|
"LastModifiedTime": experiment_data.last_modified_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_tags_to_experiment(self, experiment_arn, tags):
|
||||||
|
experiment = [
|
||||||
|
self.experiments[i]
|
||||||
|
for i in self.experiments
|
||||||
|
if self.experiments[i].experiment_arn == experiment_arn
|
||||||
|
][0]
|
||||||
|
experiment.tags.extend(tags)
|
||||||
|
|
||||||
|
def add_tags_to_trial(self, trial_arn, tags):
|
||||||
|
trial = [
|
||||||
|
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||||
|
][0]
|
||||||
|
trial.tags.extend(tags)
|
||||||
|
|
||||||
|
def add_tags_to_trial_component(self, trial_component_arn, tags):
|
||||||
|
trial_component = [
|
||||||
|
self.trial_components[i]
|
||||||
|
for i in self.trial_components
|
||||||
|
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||||
|
][0]
|
||||||
|
trial_component.tags.extend(tags)
|
||||||
|
|
||||||
|
def delete_tags_from_experiment(self, experiment_arn, tag_keys):
|
||||||
|
experiment = [
|
||||||
|
self.experiments[i]
|
||||||
|
for i in self.experiments
|
||||||
|
if self.experiments[i].experiment_arn == experiment_arn
|
||||||
|
][0]
|
||||||
|
experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys]
|
||||||
|
|
||||||
|
def delete_tags_from_trial(self, trial_arn, tag_keys):
|
||||||
|
trial = [
|
||||||
|
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||||
|
][0]
|
||||||
|
trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys]
|
||||||
|
|
||||||
|
def delete_tags_from_trial_component(self, trial_component_arn, tag_keys):
|
||||||
|
trial_component = [
|
||||||
|
self.trial_components[i]
|
||||||
|
for i in self.trial_components
|
||||||
|
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||||
|
][0]
|
||||||
|
trial_component.tags = [
|
||||||
|
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
|
||||||
|
]
|
||||||
|
|
||||||
|
def list_experiments(self):
|
||||||
|
next_index = None
|
||||||
|
|
||||||
|
experiments_fetched = list(self.experiments.values())
|
||||||
|
|
||||||
|
experiment_summaries = [
|
||||||
|
{
|
||||||
|
"ExperimentName": experiment_data.experiment_name,
|
||||||
|
"ExperimentArn": experiment_data.experiment_arn,
|
||||||
|
"CreationTime": experiment_data.creation_time,
|
||||||
|
"LastModifiedTime": experiment_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for experiment_data in experiments_fetched
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ExperimentSummaries": experiment_summaries,
|
||||||
|
"NextToken": str(next_index) if next_index is not None else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def search(self, resource=None, search_expression=None):
|
||||||
|
next_index = None
|
||||||
|
|
||||||
|
valid_resources = [
|
||||||
|
"Pipeline",
|
||||||
|
"ModelPackageGroup",
|
||||||
|
"TrainingJob",
|
||||||
|
"ExperimentTrialComponent",
|
||||||
|
"FeatureGroup",
|
||||||
|
"Endpoint",
|
||||||
|
"PipelineExecution",
|
||||||
|
"Project",
|
||||||
|
"ExperimentTrial",
|
||||||
|
"Image",
|
||||||
|
"ImageVersion",
|
||||||
|
"ModelPackage",
|
||||||
|
"Experiment",
|
||||||
|
]
|
||||||
|
|
||||||
|
if resource not in valid_resources:
|
||||||
|
raise AWSValidationException(
|
||||||
|
f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def evaluate_search_expression(item):
|
||||||
|
filters = None
|
||||||
|
if search_expression is not None:
|
||||||
|
filters = search_expression.get("Filters")
|
||||||
|
|
||||||
|
if filters is not None:
|
||||||
|
for f in filters:
|
||||||
|
if f["Operator"] == "Equals":
|
||||||
|
if f["Name"].startswith("Tags."):
|
||||||
|
key = f["Name"][5:]
|
||||||
|
value = f["Value"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(
|
||||||
|
[
|
||||||
|
e
|
||||||
|
for e in item.tags
|
||||||
|
if e["Key"] == key and e["Value"] == value
|
||||||
|
]
|
||||||
|
)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
if f["Name"] == "ExperimentName":
|
||||||
|
experiment_name = f["Value"]
|
||||||
|
|
||||||
|
if getattr(item, "experiment_name") != experiment_name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if f["Name"] == "TrialName":
|
||||||
|
raise AWSValidationException(
|
||||||
|
f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {f['Name']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if f["Name"] == "Parents.TrialName":
|
||||||
|
trial_name = f["Value"]
|
||||||
|
|
||||||
|
if getattr(item, "trial_name") != trial_name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"Results": [],
|
||||||
|
"NextToken": str(next_index) if next_index is not None else None,
|
||||||
|
}
|
||||||
|
if resource == "Experiment":
|
||||||
|
experiments_fetched = list(self.experiments.values())
|
||||||
|
|
||||||
|
experiment_summaries = [
|
||||||
|
{
|
||||||
|
"ExperimentName": experiment_data.experiment_name,
|
||||||
|
"ExperimentArn": experiment_data.experiment_arn,
|
||||||
|
"CreationTime": experiment_data.creation_time,
|
||||||
|
"LastModifiedTime": experiment_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for experiment_data in experiments_fetched
|
||||||
|
if evaluate_search_expression(experiment_data)
|
||||||
|
]
|
||||||
|
|
||||||
|
for experiment_summary in experiment_summaries:
|
||||||
|
result["Results"].append({"Experiment": experiment_summary})
|
||||||
|
|
||||||
|
if resource == "ExperimentTrial":
|
||||||
|
trials_fetched = list(self.trials.values())
|
||||||
|
|
||||||
|
trial_summaries = [
|
||||||
|
{
|
||||||
|
"TrialName": trial_data.trial_name,
|
||||||
|
"TrialArn": trial_data.trial_arn,
|
||||||
|
"CreationTime": trial_data.creation_time,
|
||||||
|
"LastModifiedTime": trial_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for trial_data in trials_fetched
|
||||||
|
if evaluate_search_expression(trial_data)
|
||||||
|
]
|
||||||
|
|
||||||
|
for trial_summary in trial_summaries:
|
||||||
|
result["Results"].append({"Trial": trial_summary})
|
||||||
|
|
||||||
|
if resource == "ExperimentTrialComponent":
|
||||||
|
trial_components_fetched = list(self.trial_components.values())
|
||||||
|
|
||||||
|
trial_component_summaries = [
|
||||||
|
{
|
||||||
|
"TrialComponentName": trial_component_data.trial_component_name,
|
||||||
|
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||||
|
"CreationTime": trial_component_data.creation_time,
|
||||||
|
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for trial_component_data in trial_components_fetched
|
||||||
|
if evaluate_search_expression(trial_component_data)
|
||||||
|
]
|
||||||
|
|
||||||
|
for trial_component_summary in trial_component_summaries:
|
||||||
|
result["Results"].append({"TrialComponent": trial_component_summary})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def delete_experiment(self, experiment_name):
|
||||||
|
try:
|
||||||
|
del self.experiments[experiment_name]
|
||||||
|
except KeyError:
|
||||||
|
message = "Could not find experiment configuration '{}'.".format(
|
||||||
|
FakeTrial.arn_formatter(experiment_name, self.region_name)
|
||||||
|
)
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
def get_experiment_by_arn(self, arn):
|
||||||
|
experiments = [
|
||||||
|
experiment
|
||||||
|
for experiment in self.experiments.values()
|
||||||
|
if experiment.experiment_arn == arn
|
||||||
|
]
|
||||||
|
if len(experiments) == 0:
|
||||||
|
message = "RecordNotFound"
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
return experiments[0]
|
||||||
|
|
||||||
|
def get_experiment_tags(self, arn):
|
||||||
|
try:
|
||||||
|
experiment = self.get_experiment_by_arn(arn)
|
||||||
|
return experiment.tags or []
|
||||||
|
except RESTError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def create_trial(
|
||||||
|
self, trial_name, experiment_name,
|
||||||
|
):
|
||||||
|
trial = FakeTrial(
|
||||||
|
region_name=self.region_name,
|
||||||
|
trial_name=trial_name,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
self.trials[trial_name] = trial
|
||||||
|
return trial.response_create
|
||||||
|
|
||||||
|
def describe_trial(self, trial_name):
|
||||||
|
try:
|
||||||
|
return self.trials[trial_name].response_object
|
||||||
|
except KeyError:
|
||||||
|
message = "Could not find trial '{}'.".format(
|
||||||
|
FakeTrial.arn_formatter(trial_name, self.region_name)
|
||||||
|
)
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
def delete_trial(self, trial_name):
|
||||||
|
try:
|
||||||
|
del self.trials[trial_name]
|
||||||
|
except KeyError:
|
||||||
|
message = "Could not find trial configuration '{}'.".format(
|
||||||
|
FakeTrial.arn_formatter(trial_name, self.region_name)
|
||||||
|
)
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
def get_trial_by_arn(self, arn):
|
||||||
|
trials = [trial for trial in self.trials.values() if trial.trial_arn == arn]
|
||||||
|
if len(trials) == 0:
|
||||||
|
message = "RecordNotFound"
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
return trials[0]
|
||||||
|
|
||||||
|
def get_trial_tags(self, arn):
|
||||||
|
try:
|
||||||
|
trial = self.get_trial_by_arn(arn)
|
||||||
|
return trial.tags or []
|
||||||
|
except RESTError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def list_trials(self, experiment_name=None):
|
||||||
|
next_index = None
|
||||||
|
|
||||||
|
trials_fetched = list(self.trials.values())
|
||||||
|
|
||||||
|
trial_summaries = [
|
||||||
|
{
|
||||||
|
"TrialName": trial_data.trial_name,
|
||||||
|
"TrialArn": trial_data.trial_arn,
|
||||||
|
"CreationTime": trial_data.creation_time,
|
||||||
|
"LastModifiedTime": trial_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for trial_data in trials_fetched
|
||||||
|
if experiment_name is None or trial_data.experiment_name == experiment_name
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"TrialSummaries": trial_summaries,
|
||||||
|
"NextToken": str(next_index) if next_index is not None else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_trial_component(
|
||||||
|
self, trial_component_name, trial_name,
|
||||||
|
):
|
||||||
|
trial_component = FakeTrialComponent(
|
||||||
|
region_name=self.region_name,
|
||||||
|
trial_component_name=trial_component_name,
|
||||||
|
trial_name=trial_name,
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
self.trial_components[trial_component_name] = trial_component
|
||||||
|
return trial_component.response_create
|
||||||
|
|
||||||
|
def delete_trial_component(self, trial_component_name):
|
||||||
|
try:
|
||||||
|
del self.trial_components[trial_component_name]
|
||||||
|
except KeyError:
|
||||||
|
message = "Could not find trial-component configuration '{}'.".format(
|
||||||
|
FakeTrial.arn_formatter(trial_component_name, self.region_name)
|
||||||
|
)
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
def get_trial_component_by_arn(self, arn):
|
||||||
|
trial_components = [
|
||||||
|
trial_component
|
||||||
|
for trial_component in self.trial_components.values()
|
||||||
|
if trial_component.trial_component_arn == arn
|
||||||
|
]
|
||||||
|
if len(trial_components) == 0:
|
||||||
|
message = "RecordNotFound"
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
return trial_components[0]
|
||||||
|
|
||||||
|
def get_trial_component_tags(self, arn):
|
||||||
|
try:
|
||||||
|
trial_component = self.get_trial_component_by_arn(arn)
|
||||||
|
return trial_component.tags or []
|
||||||
|
except RESTError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def describe_trial_component(self, trial_component_name):
|
||||||
|
try:
|
||||||
|
return self.trial_components[trial_component_name].response_object
|
||||||
|
except KeyError:
|
||||||
|
message = "Could not find trial component '{}'.".format(
|
||||||
|
FakeTrialComponent.arn_formatter(trial_component_name, self.region_name)
|
||||||
|
)
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
def _update_trial_component_details(self, trial_component_name, details_json):
|
||||||
|
self.trial_components[trial_component_name].update(details_json)
|
||||||
|
|
||||||
|
def list_trial_components(self, trial_name=None):
|
||||||
|
next_index = None
|
||||||
|
|
||||||
|
trial_components_fetched = list(self.trial_components.values())
|
||||||
|
|
||||||
|
trial_component_summaries = [
|
||||||
|
{
|
||||||
|
"TrialComponentName": trial_component_data.trial_component_name,
|
||||||
|
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||||
|
"CreationTime": trial_component_data.creation_time,
|
||||||
|
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||||
|
}
|
||||||
|
for trial_component_data in trial_components_fetched
|
||||||
|
if trial_name is None or trial_component_data.trial_name == trial_name
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"TrialComponentSummaries": trial_component_summaries,
|
||||||
|
"NextToken": str(next_index) if next_index is not None else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def associate_trial_component(self, params):
|
||||||
|
trial_name = params["TrialName"]
|
||||||
|
trial_component_name = params["TrialComponentName"]
|
||||||
|
|
||||||
|
self.trial_components[trial_component_name].trial_name = trial_name
|
||||||
|
|
||||||
|
def disassociate_trial_component(self, params):
|
||||||
|
trial_component_name = params["TrialComponentName"]
|
||||||
|
|
||||||
|
self.trial_components[trial_component_name].trial_name = None
|
||||||
|
|
||||||
def create_notebook_instance(
|
def create_notebook_instance(
|
||||||
self,
|
self,
|
||||||
notebook_instance_name,
|
notebook_instance_name,
|
||||||
@ -1292,6 +1680,9 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
except RESTError:
|
except RESTError:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _update_training_job_details(self, training_job_name, details_json):
|
||||||
|
self.training_jobs[training_job_name].update(details_json)
|
||||||
|
|
||||||
def list_training_jobs(
|
def list_training_jobs(
|
||||||
self,
|
self,
|
||||||
next_token,
|
next_token,
|
||||||
@ -1377,6 +1768,111 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeExperiment(BaseObject):
|
||||||
|
def __init__(
|
||||||
|
self, region_name, experiment_name, tags,
|
||||||
|
):
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.experiment_arn = FakeExperiment.arn_formatter(experiment_name, region_name)
|
||||||
|
self.tags = tags
|
||||||
|
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||||
|
"%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_object(self):
|
||||||
|
response_object = self.gen_response_object()
|
||||||
|
return {
|
||||||
|
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_create(self):
|
||||||
|
return {"ExperimentArn": self.experiment_arn}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arn_formatter(experiment_arn, region_name):
|
||||||
|
return (
|
||||||
|
"arn:aws:sagemaker:"
|
||||||
|
+ region_name
|
||||||
|
+ ":"
|
||||||
|
+ str(ACCOUNT_ID)
|
||||||
|
+ ":experiment/"
|
||||||
|
+ experiment_arn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTrial(BaseObject):
|
||||||
|
def __init__(
|
||||||
|
self, region_name, trial_name, experiment_name, tags,
|
||||||
|
):
|
||||||
|
self.trial_name = trial_name
|
||||||
|
self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name)
|
||||||
|
self.tags = tags
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||||
|
"%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_object(self):
|
||||||
|
response_object = self.gen_response_object()
|
||||||
|
return {
|
||||||
|
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_create(self):
|
||||||
|
return {"TrialArn": self.trial_arn}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arn_formatter(trial_name, region_name):
|
||||||
|
return (
|
||||||
|
"arn:aws:sagemaker:"
|
||||||
|
+ region_name
|
||||||
|
+ ":"
|
||||||
|
+ str(ACCOUNT_ID)
|
||||||
|
+ ":experiment-trial/"
|
||||||
|
+ trial_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTrialComponent(BaseObject):
|
||||||
|
def __init__(
|
||||||
|
self, region_name, trial_component_name, trial_name, tags,
|
||||||
|
):
|
||||||
|
self.trial_component_name = trial_component_name
|
||||||
|
self.trial_component_arn = FakeTrialComponent.arn_formatter(
|
||||||
|
trial_component_name, region_name
|
||||||
|
)
|
||||||
|
self.tags = tags
|
||||||
|
self.trial_name = trial_name
|
||||||
|
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
self.creation_time = self.last_modified_time = now_string
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_object(self):
|
||||||
|
response_object = self.gen_response_object()
|
||||||
|
return {
|
||||||
|
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_create(self):
|
||||||
|
return {"TrialComponentArn": self.trial_component_arn}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arn_formatter(trial_component_name, region_name):
|
||||||
|
return (
|
||||||
|
"arn:aws:sagemaker:"
|
||||||
|
+ region_name
|
||||||
|
+ ":"
|
||||||
|
+ str(ACCOUNT_ID)
|
||||||
|
+ ":experiment-trial-component/"
|
||||||
|
+ trial_component_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
sagemaker_backends = {}
|
sagemaker_backends = {}
|
||||||
for region in Session().get_available_regions("sagemaker"):
|
for region in Session().get_available_regions("sagemaker"):
|
||||||
sagemaker_backends[region] = SageMakerModelBackend(region)
|
sagemaker_backends[region] = SageMakerModelBackend(region)
|
||||||
|
@ -132,6 +132,12 @@ class SageMakerResponse(BaseResponse):
|
|||||||
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
||||||
elif ":training-job/" in arn:
|
elif ":training-job/" in arn:
|
||||||
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
||||||
|
elif ":experiment/" in arn:
|
||||||
|
tags = self.sagemaker_backend.get_experiment_tags(arn)
|
||||||
|
elif ":experiment-trial/" in arn:
|
||||||
|
tags = self.sagemaker_backend.get_trial_tags(arn)
|
||||||
|
elif ":experiment-trial-component/" in arn:
|
||||||
|
tags = self.sagemaker_backend.get_trial_component_tags(arn)
|
||||||
else:
|
else:
|
||||||
tags = []
|
tags = []
|
||||||
except AWSError:
|
except AWSError:
|
||||||
@ -139,6 +145,30 @@ class SageMakerResponse(BaseResponse):
|
|||||||
response = {"Tags": tags}
|
response = {"Tags": tags}
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def add_tags(self):
|
||||||
|
arn = self._get_param("ResourceArn")
|
||||||
|
tags = self._get_param("Tags")
|
||||||
|
if ":experiment/" in arn:
|
||||||
|
self.sagemaker_backend.add_tags_to_experiment(arn, tags)
|
||||||
|
elif ":experiment-trial/" in arn:
|
||||||
|
self.sagemaker_backend.add_tags_to_trial(arn, tags)
|
||||||
|
elif ":experiment-trial-component/" in arn:
|
||||||
|
self.sagemaker_backend.add_tags_to_trial_component(arn, tags)
|
||||||
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def delete_tags(self):
|
||||||
|
arn = self._get_param("ResourceArn")
|
||||||
|
tag_keys = self._get_param("TagKeys")
|
||||||
|
if ":experiment/" in arn:
|
||||||
|
self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys)
|
||||||
|
elif ":experiment-trial/" in arn:
|
||||||
|
self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys)
|
||||||
|
elif ":experiment-trial-component/" in arn:
|
||||||
|
self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys)
|
||||||
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def create_endpoint_config(self):
|
def create_endpoint_config(self):
|
||||||
try:
|
try:
|
||||||
@ -278,6 +308,117 @@ class SageMakerResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return 200, {}, json.dumps("{}")
|
return 200, {}, json.dumps("{}")
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def search(self):
|
||||||
|
response = self.sagemaker_backend.search(
|
||||||
|
resource=self._get_param("Resource"),
|
||||||
|
search_expression=self._get_param("SearchExpression"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def list_experiments(self):
|
||||||
|
response = self.sagemaker_backend.list_experiments()
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def delete_experiment(self):
|
||||||
|
self.sagemaker_backend.delete_experiment(
|
||||||
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def create_experiment(self, *args, **kwargs):
|
||||||
|
response = self.sagemaker_backend.create_experiment(
|
||||||
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def describe_experiment(self, *args, **kwargs):
|
||||||
|
response = self.sagemaker_backend.describe_experiment(
|
||||||
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def list_trials(self):
|
||||||
|
response = self.sagemaker_backend.list_trials(
|
||||||
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def create_trial(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
response = self.sagemaker_backend.create_trial(
|
||||||
|
trial_name=self._get_param("TrialName"),
|
||||||
|
experiment_name=self._get_param("ExperimentName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
except AWSError as err:
|
||||||
|
return err.response()
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def list_trial_components(self):
|
||||||
|
response = self.sagemaker_backend.list_trial_components(
|
||||||
|
trial_name=self._get_param("TrialName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def create_trial_component(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
response = self.sagemaker_backend.create_trial_component(
|
||||||
|
trial_component_name=self._get_param("TrialComponentName"),
|
||||||
|
trial_name=self._get_param("TrialName"),
|
||||||
|
)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
except AWSError as err:
|
||||||
|
return err.response()
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def describe_trial(self, *args, **kwargs):
|
||||||
|
trial_name = self._get_param("TrialName")
|
||||||
|
response = self.sagemaker_backend.describe_trial(trial_name)
|
||||||
|
return json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def delete_trial(self):
|
||||||
|
trial_name = self._get_param("TrialName")
|
||||||
|
self.sagemaker_backend.delete_trial(trial_name)
|
||||||
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def delete_trial_component(self):
|
||||||
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
|
self.sagemaker_backend.delete_trial_component(trial_component_name)
|
||||||
|
return 200, {}, json.dumps({})
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def describe_trial_component(self, *args, **kwargs):
|
||||||
|
trial_component_name = self._get_param("TrialComponentName")
|
||||||
|
response = self.sagemaker_backend.describe_trial_component(trial_component_name)
|
||||||
|
return json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def associate_trial_component(self, *args, **kwargs):
|
||||||
|
self.sagemaker_backend.associate_trial_component(self.request_params)
|
||||||
|
response = {}
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def disassociate_trial_component(self, *args, **kwargs):
|
||||||
|
self.sagemaker_backend.disassociate_trial_component(self.request_params)
|
||||||
|
response = {}
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def list_associations(self, *args, **kwargs):
|
||||||
|
response = self.sagemaker_backend.list_associations(self.request_params)
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_training_jobs(self):
|
def list_training_jobs(self):
|
||||||
max_results_range = range(1, 101)
|
max_results_range = range(1, 101)
|
||||||
|
91
tests/test_sagemaker/test_sagemaker_experiment.py
Normal file
91
tests/test_sagemaker/test_sagemaker_experiment.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import boto3
|
||||||
|
|
||||||
|
from moto import mock_sagemaker
|
||||||
|
from moto.sts.models import ACCOUNT_ID
|
||||||
|
|
||||||
|
TEST_REGION_NAME = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create_experiment():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_experiments()
|
||||||
|
|
||||||
|
assert len(resp["ExperimentSummaries"]) == 1
|
||||||
|
assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name
|
||||||
|
assert (
|
||||||
|
resp["ExperimentSummaries"][0]["ExperimentArn"]
|
||||||
|
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete_experiment():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
resp = client.delete_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_experiments()
|
||||||
|
|
||||||
|
assert len(resp["ExperimentSummaries"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_add_tags_to_experiment():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
arn = resp["ExperimentArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == tags
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete_tags_to_experiment():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
arn = resp["ExperimentArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == []
|
61
tests/test_sagemaker/test_sagemaker_search.py
Normal file
61
tests/test_sagemaker/test_sagemaker_search.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import boto3
|
||||||
|
|
||||||
|
from moto import mock_sagemaker
|
||||||
|
|
||||||
|
TEST_REGION_NAME = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_search():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
another_trial_component_name = "another-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
resp = client.create_trial_component(
|
||||||
|
TrialComponentName=another_trial_component_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.search(Resource="ExperimentTrialComponent")
|
||||||
|
|
||||||
|
assert len(resp["Results"]) == 2
|
||||||
|
|
||||||
|
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
trial_component_arn = resp["TrialComponentArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "key-name", "Value": "some-value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||||
|
|
||||||
|
resp = client.search(
|
||||||
|
Resource="ExperimentTrialComponent",
|
||||||
|
SearchExpression={
|
||||||
|
"Filters": [
|
||||||
|
{"Name": "Tags.key-name", "Operator": "Equals", "Value": "some-value"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(resp["Results"]) == 1
|
||||||
|
assert (
|
||||||
|
resp["Results"][0]["TrialComponent"]["TrialComponentName"]
|
||||||
|
== trial_component_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.search(Resource="Experiment")
|
||||||
|
assert len(resp["Results"]) == 1
|
||||||
|
assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name
|
||||||
|
|
||||||
|
resp = client.search(Resource="ExperimentTrial")
|
||||||
|
assert len(resp["Results"]) == 1
|
||||||
|
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
|
106
tests/test_sagemaker/test_sagemaker_trial.py
Normal file
106
tests/test_sagemaker/test_sagemaker_trial.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import boto3
|
||||||
|
|
||||||
|
from moto import mock_sagemaker
|
||||||
|
from moto.sts.models import ACCOUNT_ID
|
||||||
|
|
||||||
|
TEST_REGION_NAME = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create_trial():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_trials()
|
||||||
|
|
||||||
|
assert len(resp["TrialSummaries"]) == 1
|
||||||
|
assert resp["TrialSummaries"][0]["TrialName"] == trial_name
|
||||||
|
assert (
|
||||||
|
resp["TrialSummaries"][0]["TrialArn"]
|
||||||
|
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete_trial():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
resp = client.delete_trial(TrialName=trial_name)
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_trials()
|
||||||
|
|
||||||
|
assert len(resp["TrialSummaries"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_add_tags_to_trial():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
resp = client.describe_trial(TrialName=trial_name)
|
||||||
|
|
||||||
|
arn = resp["TrialArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == tags
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete_tags_to_trial():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
resp = client.describe_trial(TrialName=trial_name)
|
||||||
|
|
||||||
|
arn = resp["TrialArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == []
|
122
tests/test_sagemaker/test_sagemaker_trial_component.py
Normal file
122
tests/test_sagemaker/test_sagemaker_trial_component.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import boto3
|
||||||
|
|
||||||
|
from moto import mock_sagemaker
|
||||||
|
from moto.sts.models import ACCOUNT_ID
|
||||||
|
|
||||||
|
TEST_REGION_NAME = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create__trial_component():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_trial_components()
|
||||||
|
|
||||||
|
assert len(resp["TrialComponentSummaries"]) == 1
|
||||||
|
assert (
|
||||||
|
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["TrialComponentSummaries"][0]["TrialComponentArn"]
|
||||||
|
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete__trial_component():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
resp = client.delete_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_trial_components()
|
||||||
|
|
||||||
|
assert len(resp["TrialComponentSummaries"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_add_tags_to_trial_component():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
arn = resp["TrialComponentArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == tags
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_delete_tags_to_trial_component():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
arn = resp["TrialComponentArn"]
|
||||||
|
|
||||||
|
tags = [{"Key": "name", "Value": "value"}]
|
||||||
|
|
||||||
|
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||||
|
|
||||||
|
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_tags(ResourceArn=arn)
|
||||||
|
|
||||||
|
assert resp["Tags"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_associate_trial_component():
|
||||||
|
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
experiment_name = "some-experiment-name"
|
||||||
|
|
||||||
|
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||||
|
|
||||||
|
trial_name = "some-trial-name"
|
||||||
|
|
||||||
|
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||||
|
|
||||||
|
trial_component_name = "some-trial-component-name"
|
||||||
|
|
||||||
|
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||||
|
|
||||||
|
client.associate_trial_component(
|
||||||
|
TrialComponentName=trial_component_name, TrialName=trial_name
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
|
||||||
|
resp = client.list_trial_components(TrialName=trial_name)
|
||||||
|
|
||||||
|
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||||
|
assert (
|
||||||
|
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user