Add more functionality to sagemaker (#4491)

This commit is contained in:
Bogdan Girman 2021-10-28 22:21:20 +02:00 committed by GitHub
parent 6f5cae98ad
commit 3259a3307a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1017 additions and 0 deletions

View File

@ -1,3 +1,4 @@
import json
import os
from boto3 import Session
from datetime import datetime
@ -14,6 +15,11 @@ class BaseObject(BaseModel):
words.append(word.title())
return "".join(words)
def update(self, details_json):
details = json.loads(details_json)
for k in details.keys():
setattr(self, k, details[k])
def gen_response_object(self):
response_object = dict()
for key, value in self.__dict__.items():
@ -866,6 +872,9 @@ class SageMakerModelBackend(BaseBackend):
self.notebook_instances = {}
self.endpoint_configs = {}
self.endpoints = {}
self.experiments = {}
self.trials = {}
self.trial_components = {}
self.training_jobs = {}
self.notebook_instance_lifecycle_configurations = {}
self.region_name = region_name
@ -961,6 +970,385 @@ class SageMakerModelBackend(BaseBackend):
else:
raise MissingModel(model=model_name)
def create_experiment(self, experiment_name):
experiment = FakeExperiment(
region_name=self.region_name, experiment_name=experiment_name, tags=[]
)
self.experiments[experiment_name] = experiment
return experiment.response_create
def describe_experiment(self, experiment_name):
experiment_data = self.experiments[experiment_name]
return {
"ExperimentName": experiment_data.experiment_name,
"ExperimentArn": experiment_data.experiment_arn,
"CreationTime": experiment_data.creation_time,
"LastModifiedTime": experiment_data.last_modified_time,
}
def add_tags_to_experiment(self, experiment_arn, tags):
experiment = [
self.experiments[i]
for i in self.experiments
if self.experiments[i].experiment_arn == experiment_arn
][0]
experiment.tags.extend(tags)
def add_tags_to_trial(self, trial_arn, tags):
trial = [
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
][0]
trial.tags.extend(tags)
def add_tags_to_trial_component(self, trial_component_arn, tags):
trial_component = [
self.trial_components[i]
for i in self.trial_components
if self.trial_components[i].trial_component_arn == trial_component_arn
][0]
trial_component.tags.extend(tags)
def delete_tags_from_experiment(self, experiment_arn, tag_keys):
experiment = [
self.experiments[i]
for i in self.experiments
if self.experiments[i].experiment_arn == experiment_arn
][0]
experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys]
def delete_tags_from_trial(self, trial_arn, tag_keys):
trial = [
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
][0]
trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys]
def delete_tags_from_trial_component(self, trial_component_arn, tag_keys):
trial_component = [
self.trial_components[i]
for i in self.trial_components
if self.trial_components[i].trial_component_arn == trial_component_arn
][0]
trial_component.tags = [
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
]
def list_experiments(self):
next_index = None
experiments_fetched = list(self.experiments.values())
experiment_summaries = [
{
"ExperimentName": experiment_data.experiment_name,
"ExperimentArn": experiment_data.experiment_arn,
"CreationTime": experiment_data.creation_time,
"LastModifiedTime": experiment_data.last_modified_time,
}
for experiment_data in experiments_fetched
]
return {
"ExperimentSummaries": experiment_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def search(self, resource=None, search_expression=None):
next_index = None
valid_resources = [
"Pipeline",
"ModelPackageGroup",
"TrainingJob",
"ExperimentTrialComponent",
"FeatureGroup",
"Endpoint",
"PipelineExecution",
"Project",
"ExperimentTrial",
"Image",
"ImageVersion",
"ModelPackage",
"Experiment",
]
if resource not in valid_resources:
raise AWSValidationException(
f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}"
)
def evaluate_search_expression(item):
filters = None
if search_expression is not None:
filters = search_expression.get("Filters")
if filters is not None:
for f in filters:
if f["Operator"] == "Equals":
if f["Name"].startswith("Tags."):
key = f["Name"][5:]
value = f["Value"]
if (
len(
[
e
for e in item.tags
if e["Key"] == key and e["Value"] == value
]
)
== 0
):
return False
if f["Name"] == "ExperimentName":
experiment_name = f["Value"]
if getattr(item, "experiment_name") != experiment_name:
return False
if f["Name"] == "TrialName":
raise AWSValidationException(
f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {f['Name']}"
)
if f["Name"] == "Parents.TrialName":
trial_name = f["Value"]
if getattr(item, "trial_name") != trial_name:
return False
return True
result = {
"Results": [],
"NextToken": str(next_index) if next_index is not None else None,
}
if resource == "Experiment":
experiments_fetched = list(self.experiments.values())
experiment_summaries = [
{
"ExperimentName": experiment_data.experiment_name,
"ExperimentArn": experiment_data.experiment_arn,
"CreationTime": experiment_data.creation_time,
"LastModifiedTime": experiment_data.last_modified_time,
}
for experiment_data in experiments_fetched
if evaluate_search_expression(experiment_data)
]
for experiment_summary in experiment_summaries:
result["Results"].append({"Experiment": experiment_summary})
if resource == "ExperimentTrial":
trials_fetched = list(self.trials.values())
trial_summaries = [
{
"TrialName": trial_data.trial_name,
"TrialArn": trial_data.trial_arn,
"CreationTime": trial_data.creation_time,
"LastModifiedTime": trial_data.last_modified_time,
}
for trial_data in trials_fetched
if evaluate_search_expression(trial_data)
]
for trial_summary in trial_summaries:
result["Results"].append({"Trial": trial_summary})
if resource == "ExperimentTrialComponent":
trial_components_fetched = list(self.trial_components.values())
trial_component_summaries = [
{
"TrialComponentName": trial_component_data.trial_component_name,
"TrialComponentArn": trial_component_data.trial_component_arn,
"CreationTime": trial_component_data.creation_time,
"LastModifiedTime": trial_component_data.last_modified_time,
}
for trial_component_data in trial_components_fetched
if evaluate_search_expression(trial_component_data)
]
for trial_component_summary in trial_component_summaries:
result["Results"].append({"TrialComponent": trial_component_summary})
return result
def delete_experiment(self, experiment_name):
try:
del self.experiments[experiment_name]
except KeyError:
message = "Could not find experiment configuration '{}'.".format(
FakeTrial.arn_formatter(experiment_name, self.region_name)
)
raise ValidationError(message=message)
def get_experiment_by_arn(self, arn):
experiments = [
experiment
for experiment in self.experiments.values()
if experiment.experiment_arn == arn
]
if len(experiments) == 0:
message = "RecordNotFound"
raise ValidationError(message=message)
return experiments[0]
def get_experiment_tags(self, arn):
try:
experiment = self.get_experiment_by_arn(arn)
return experiment.tags or []
except RESTError:
return []
def create_trial(
self, trial_name, experiment_name,
):
trial = FakeTrial(
region_name=self.region_name,
trial_name=trial_name,
experiment_name=experiment_name,
tags=[],
)
self.trials[trial_name] = trial
return trial.response_create
def describe_trial(self, trial_name):
try:
return self.trials[trial_name].response_object
except KeyError:
message = "Could not find trial '{}'.".format(
FakeTrial.arn_formatter(trial_name, self.region_name)
)
raise ValidationError(message=message)
def delete_trial(self, trial_name):
try:
del self.trials[trial_name]
except KeyError:
message = "Could not find trial configuration '{}'.".format(
FakeTrial.arn_formatter(trial_name, self.region_name)
)
raise ValidationError(message=message)
def get_trial_by_arn(self, arn):
trials = [trial for trial in self.trials.values() if trial.trial_arn == arn]
if len(trials) == 0:
message = "RecordNotFound"
raise ValidationError(message=message)
return trials[0]
def get_trial_tags(self, arn):
try:
trial = self.get_trial_by_arn(arn)
return trial.tags or []
except RESTError:
return []
def list_trials(self, experiment_name=None):
next_index = None
trials_fetched = list(self.trials.values())
trial_summaries = [
{
"TrialName": trial_data.trial_name,
"TrialArn": trial_data.trial_arn,
"CreationTime": trial_data.creation_time,
"LastModifiedTime": trial_data.last_modified_time,
}
for trial_data in trials_fetched
if experiment_name is None or trial_data.experiment_name == experiment_name
]
return {
"TrialSummaries": trial_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def create_trial_component(
self, trial_component_name, trial_name,
):
trial_component = FakeTrialComponent(
region_name=self.region_name,
trial_component_name=trial_component_name,
trial_name=trial_name,
tags=[],
)
self.trial_components[trial_component_name] = trial_component
return trial_component.response_create
def delete_trial_component(self, trial_component_name):
try:
del self.trial_components[trial_component_name]
except KeyError:
message = "Could not find trial-component configuration '{}'.".format(
FakeTrial.arn_formatter(trial_component_name, self.region_name)
)
raise ValidationError(message=message)
def get_trial_component_by_arn(self, arn):
trial_components = [
trial_component
for trial_component in self.trial_components.values()
if trial_component.trial_component_arn == arn
]
if len(trial_components) == 0:
message = "RecordNotFound"
raise ValidationError(message=message)
return trial_components[0]
def get_trial_component_tags(self, arn):
try:
trial_component = self.get_trial_component_by_arn(arn)
return trial_component.tags or []
except RESTError:
return []
def describe_trial_component(self, trial_component_name):
try:
return self.trial_components[trial_component_name].response_object
except KeyError:
message = "Could not find trial component '{}'.".format(
FakeTrialComponent.arn_formatter(trial_component_name, self.region_name)
)
raise ValidationError(message=message)
def _update_trial_component_details(self, trial_component_name, details_json):
self.trial_components[trial_component_name].update(details_json)
def list_trial_components(self, trial_name=None):
next_index = None
trial_components_fetched = list(self.trial_components.values())
trial_component_summaries = [
{
"TrialComponentName": trial_component_data.trial_component_name,
"TrialComponentArn": trial_component_data.trial_component_arn,
"CreationTime": trial_component_data.creation_time,
"LastModifiedTime": trial_component_data.last_modified_time,
}
for trial_component_data in trial_components_fetched
if trial_name is None or trial_component_data.trial_name == trial_name
]
return {
"TrialComponentSummaries": trial_component_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def associate_trial_component(self, params):
trial_name = params["TrialName"]
trial_component_name = params["TrialComponentName"]
self.trial_components[trial_component_name].trial_name = trial_name
def disassociate_trial_component(self, params):
trial_component_name = params["TrialComponentName"]
self.trial_components[trial_component_name].trial_name = None
def create_notebook_instance(
self,
notebook_instance_name,
@ -1292,6 +1680,9 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
def _update_training_job_details(self, training_job_name, details_json):
self.training_jobs[training_job_name].update(details_json)
def list_training_jobs(
self,
next_token,
@ -1377,6 +1768,111 @@ class SageMakerModelBackend(BaseBackend):
}
class FakeExperiment(BaseObject):
def __init__(
self, region_name, experiment_name, tags,
):
self.experiment_name = experiment_name
self.experiment_arn = FakeExperiment.arn_formatter(experiment_name, region_name)
self.tags = tags
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"ExperimentArn": self.experiment_arn}
@staticmethod
def arn_formatter(experiment_arn, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":experiment/"
+ experiment_arn
)
class FakeTrial(BaseObject):
def __init__(
self, region_name, trial_name, experiment_name, tags,
):
self.trial_name = trial_name
self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name)
self.tags = tags
self.experiment_name = experiment_name
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"TrialArn": self.trial_arn}
@staticmethod
def arn_formatter(trial_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":experiment-trial/"
+ trial_name
)
class FakeTrialComponent(BaseObject):
def __init__(
self, region_name, trial_component_name, trial_name, tags,
):
self.trial_component_name = trial_component_name
self.trial_component_arn = FakeTrialComponent.arn_formatter(
trial_component_name, region_name
)
self.tags = tags
self.trial_name = trial_name
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = self.last_modified_time = now_string
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"TrialComponentArn": self.trial_component_arn}
@staticmethod
def arn_formatter(trial_component_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":experiment-trial-component/"
+ trial_component_name
)
sagemaker_backends = {}
for region in Session().get_available_regions("sagemaker"):
sagemaker_backends[region] = SageMakerModelBackend(region)

View File

@ -132,6 +132,12 @@ class SageMakerResponse(BaseResponse):
tags = self.sagemaker_backend.get_endpoint_tags(arn)
elif ":training-job/" in arn:
tags = self.sagemaker_backend.get_training_job_tags(arn)
elif ":experiment/" in arn:
tags = self.sagemaker_backend.get_experiment_tags(arn)
elif ":experiment-trial/" in arn:
tags = self.sagemaker_backend.get_trial_tags(arn)
elif ":experiment-trial-component/" in arn:
tags = self.sagemaker_backend.get_trial_component_tags(arn)
else:
tags = []
except AWSError:
@ -139,6 +145,30 @@ class SageMakerResponse(BaseResponse):
response = {"Tags": tags}
return 200, {}, json.dumps(response)
@amzn_request_id
def add_tags(self):
arn = self._get_param("ResourceArn")
tags = self._get_param("Tags")
if ":experiment/" in arn:
self.sagemaker_backend.add_tags_to_experiment(arn, tags)
elif ":experiment-trial/" in arn:
self.sagemaker_backend.add_tags_to_trial(arn, tags)
elif ":experiment-trial-component/" in arn:
self.sagemaker_backend.add_tags_to_trial_component(arn, tags)
return 200, {}, json.dumps({})
@amzn_request_id
def delete_tags(self):
arn = self._get_param("ResourceArn")
tag_keys = self._get_param("TagKeys")
if ":experiment/" in arn:
self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys)
elif ":experiment-trial/" in arn:
self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys)
elif ":experiment-trial-component/" in arn:
self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys)
return 200, {}, json.dumps({})
@amzn_request_id
def create_endpoint_config(self):
try:
@ -278,6 +308,117 @@ class SageMakerResponse(BaseResponse):
)
return 200, {}, json.dumps("{}")
@amzn_request_id
def search(self):
response = self.sagemaker_backend.search(
resource=self._get_param("Resource"),
search_expression=self._get_param("SearchExpression"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_experiments(self):
response = self.sagemaker_backend.list_experiments()
return 200, {}, json.dumps(response)
@amzn_request_id
def delete_experiment(self):
self.sagemaker_backend.delete_experiment(
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps({})
@amzn_request_id
def create_experiment(self, *args, **kwargs):
response = self.sagemaker_backend.create_experiment(
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_experiment(self, *args, **kwargs):
response = self.sagemaker_backend.describe_experiment(
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_trials(self):
response = self.sagemaker_backend.list_trials(
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def create_trial(self, *args, **kwargs):
try:
response = self.sagemaker_backend.create_trial(
trial_name=self._get_param("TrialName"),
experiment_name=self._get_param("ExperimentName"),
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def list_trial_components(self):
response = self.sagemaker_backend.list_trial_components(
trial_name=self._get_param("TrialName"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def create_trial_component(self, *args, **kwargs):
try:
response = self.sagemaker_backend.create_trial_component(
trial_component_name=self._get_param("TrialComponentName"),
trial_name=self._get_param("TrialName"),
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_trial(self, *args, **kwargs):
trial_name = self._get_param("TrialName")
response = self.sagemaker_backend.describe_trial(trial_name)
return json.dumps(response)
@amzn_request_id
def delete_trial(self):
trial_name = self._get_param("TrialName")
self.sagemaker_backend.delete_trial(trial_name)
return 200, {}, json.dumps({})
@amzn_request_id
def delete_trial_component(self):
trial_component_name = self._get_param("TrialComponentName")
self.sagemaker_backend.delete_trial_component(trial_component_name)
return 200, {}, json.dumps({})
@amzn_request_id
def describe_trial_component(self, *args, **kwargs):
trial_component_name = self._get_param("TrialComponentName")
response = self.sagemaker_backend.describe_trial_component(trial_component_name)
return json.dumps(response)
@amzn_request_id
def associate_trial_component(self, *args, **kwargs):
self.sagemaker_backend.associate_trial_component(self.request_params)
response = {}
return 200, {}, json.dumps(response)
@amzn_request_id
def disassociate_trial_component(self, *args, **kwargs):
self.sagemaker_backend.disassociate_trial_component(self.request_params)
response = {}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_associations(self, *args, **kwargs):
response = self.sagemaker_backend.list_associations(self.request_params)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_training_jobs(self):
max_results_range = range(1, 101)

View File

@ -0,0 +1,91 @@
import boto3
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
TEST_REGION_NAME = "us-east-1"
@mock_sagemaker
def test_create_experiment():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_experiments()
assert len(resp["ExperimentSummaries"]) == 1
assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name
assert (
resp["ExperimentSummaries"][0]["ExperimentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}"
)
@mock_sagemaker
def test_delete_experiment():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
resp = client.delete_experiment(ExperimentName=experiment_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_experiments()
assert len(resp["ExperimentSummaries"]) == 0
@mock_sagemaker
def test_add_tags_to_experiment():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
resp = client.describe_experiment(ExperimentName=experiment_name)
arn = resp["ExperimentArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == tags
@mock_sagemaker
def test_delete_tags_to_experiment():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
resp = client.describe_experiment(ExperimentName=experiment_name)
arn = resp["ExperimentArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == []

View File

@ -0,0 +1,61 @@
import boto3
from moto import mock_sagemaker
TEST_REGION_NAME = "us-east-1"
@mock_sagemaker
def test_search():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
trial_component_name = "some-trial-component-name"
another_trial_component_name = "another-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.create_trial_component(
TrialComponentName=another_trial_component_name
)
resp = client.search(Resource="ExperimentTrialComponent")
assert len(resp["Results"]) == 2
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
trial_component_arn = resp["TrialComponentArn"]
tags = [{"Key": "key-name", "Value": "some-value"}]
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
resp = client.search(
Resource="ExperimentTrialComponent",
SearchExpression={
"Filters": [
{"Name": "Tags.key-name", "Operator": "Equals", "Value": "some-value"}
]
},
)
assert len(resp["Results"]) == 1
assert (
resp["Results"][0]["TrialComponent"]["TrialComponentName"]
== trial_component_name
)
resp = client.search(Resource="Experiment")
assert len(resp["Results"]) == 1
assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name
resp = client.search(Resource="ExperimentTrial")
assert len(resp["Results"]) == 1
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name

View File

@ -0,0 +1,106 @@
import boto3
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
TEST_REGION_NAME = "us-east-1"
@mock_sagemaker
def test_create_trial():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_trials()
assert len(resp["TrialSummaries"]) == 1
assert resp["TrialSummaries"][0]["TrialName"] == trial_name
assert (
resp["TrialSummaries"][0]["TrialArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
)
@mock_sagemaker
def test_delete_trial():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
resp = client.delete_trial(TrialName=trial_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_trials()
assert len(resp["TrialSummaries"]) == 0
@mock_sagemaker
def test_add_tags_to_trial():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
resp = client.describe_trial(TrialName=trial_name)
arn = resp["TrialArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == tags
@mock_sagemaker
def test_delete_tags_to_trial():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
resp = client.describe_trial(TrialName=trial_name)
arn = resp["TrialArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == []

View File

@ -0,0 +1,122 @@
import boto3
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
TEST_REGION_NAME = "us-east-1"
@mock_sagemaker
def test_create__trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_trial_components()
assert len(resp["TrialComponentSummaries"]) == 1
assert (
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
)
assert (
resp["TrialComponentSummaries"][0]["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
)
@mock_sagemaker
def test_delete__trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.delete_trial_component(TrialComponentName=trial_component_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_trial_components()
assert len(resp["TrialComponentSummaries"]) == 0
@mock_sagemaker
def test_add_tags_to_trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
arn = resp["TrialComponentArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == tags
@mock_sagemaker
def test_delete_tags_to_trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
arn = resp["TrialComponentArn"]
tags = [{"Key": "name", "Value": "value"}]
client.add_tags(ResourceArn=arn, Tags=tags)
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_tags(ResourceArn=arn)
assert resp["Tags"] == []
@mock_sagemaker
def test_associate_trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
client.associate_trial_component(
TrialComponentName=trial_component_name, TrialName=trial_name
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_trial_components(TrialName=trial_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
)