From 3259a3307aaa2272a4bf7a68204c80ad47df6800 Mon Sep 17 00:00:00 2001 From: Bogdan Girman Date: Thu, 28 Oct 2021 22:21:20 +0200 Subject: [PATCH] Add more functionality to sagemaker (#4491) --- moto/sagemaker/models.py | 496 ++++++++++++++++++ moto/sagemaker/responses.py | 141 +++++ .../test_sagemaker_experiment.py | 91 ++++ tests/test_sagemaker/test_sagemaker_search.py | 61 +++ tests/test_sagemaker/test_sagemaker_trial.py | 106 ++++ .../test_sagemaker_trial_component.py | 122 +++++ 6 files changed, 1017 insertions(+) create mode 100644 tests/test_sagemaker/test_sagemaker_experiment.py create mode 100644 tests/test_sagemaker/test_sagemaker_search.py create mode 100644 tests/test_sagemaker/test_sagemaker_trial.py create mode 100644 tests/test_sagemaker/test_sagemaker_trial_component.py diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 1b8b03c07..9bb5b3356 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1,3 +1,4 @@ +import json import os from boto3 import Session from datetime import datetime @@ -14,6 +15,11 @@ class BaseObject(BaseModel): words.append(word.title()) return "".join(words) + def update(self, details_json): + details = json.loads(details_json) + for k in details.keys(): + setattr(self, k, details[k]) + def gen_response_object(self): response_object = dict() for key, value in self.__dict__.items(): @@ -866,6 +872,9 @@ class SageMakerModelBackend(BaseBackend): self.notebook_instances = {} self.endpoint_configs = {} self.endpoints = {} + self.experiments = {} + self.trials = {} + self.trial_components = {} self.training_jobs = {} self.notebook_instance_lifecycle_configurations = {} self.region_name = region_name @@ -961,6 +970,385 @@ class SageMakerModelBackend(BaseBackend): else: raise MissingModel(model=model_name) + def create_experiment(self, experiment_name): + experiment = FakeExperiment( + region_name=self.region_name, experiment_name=experiment_name, tags=[] + ) + self.experiments[experiment_name] = experiment + return experiment.response_create + + def describe_experiment(self, experiment_name): + experiment_data = self.experiments[experiment_name] + return { + "ExperimentName": experiment_data.experiment_name, + "ExperimentArn": experiment_data.experiment_arn, + "CreationTime": experiment_data.creation_time, + "LastModifiedTime": experiment_data.last_modified_time, + } + + def add_tags_to_experiment(self, experiment_arn, tags): + experiment = [ + self.experiments[i] + for i in self.experiments + if self.experiments[i].experiment_arn == experiment_arn + ][0] + experiment.tags.extend(tags) + + def add_tags_to_trial(self, trial_arn, tags): + trial = [ + self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn + ][0] + trial.tags.extend(tags) + + def add_tags_to_trial_component(self, trial_component_arn, tags): + trial_component = [ + self.trial_components[i] + for i in self.trial_components + if self.trial_components[i].trial_component_arn == trial_component_arn + ][0] + trial_component.tags.extend(tags) + + def delete_tags_from_experiment(self, experiment_arn, tag_keys): + experiment = [ + self.experiments[i] + for i in self.experiments + if self.experiments[i].experiment_arn == experiment_arn + ][0] + experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys] + + def delete_tags_from_trial(self, trial_arn, tag_keys): + trial = [ + self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn + ][0] + trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys] + + def delete_tags_from_trial_component(self, trial_component_arn, tag_keys): + trial_component = [ + self.trial_components[i] + for i in self.trial_components + if self.trial_components[i].trial_component_arn == trial_component_arn + ][0] + trial_component.tags = [ + tag for tag in trial_component.tags if tag["Key"] not in tag_keys + ] + + def list_experiments(self): + next_index = None + + experiments_fetched = list(self.experiments.values()) + + experiment_summaries = [ + { + "ExperimentName": experiment_data.experiment_name, + "ExperimentArn": experiment_data.experiment_arn, + "CreationTime": experiment_data.creation_time, + "LastModifiedTime": experiment_data.last_modified_time, + } + for experiment_data in experiments_fetched + ] + + return { + "ExperimentSummaries": experiment_summaries, + "NextToken": str(next_index) if next_index is not None else None, + } + + def search(self, resource=None, search_expression=None): + next_index = None + + valid_resources = [ + "Pipeline", + "ModelPackageGroup", + "TrainingJob", + "ExperimentTrialComponent", + "FeatureGroup", + "Endpoint", + "PipelineExecution", + "Project", + "ExperimentTrial", + "Image", + "ImageVersion", + "ModelPackage", + "Experiment", + ] + + if resource not in valid_resources: + raise AWSValidationException( + f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}" + ) + + def evaluate_search_expression(item): + filters = None + if search_expression is not None: + filters = search_expression.get("Filters") + + if filters is not None: + for f in filters: + if f["Operator"] == "Equals": + if f["Name"].startswith("Tags."): + key = f["Name"][5:] + value = f["Value"] + + if ( + len( + [ + e + for e in item.tags + if e["Key"] == key and e["Value"] == value + ] + ) + == 0 + ): + return False + if f["Name"] == "ExperimentName": + experiment_name = f["Value"] + + if getattr(item, "experiment_name") != experiment_name: + return False + + if f["Name"] == "TrialName": + raise AWSValidationException( + f"An error occurred (ValidationException) when calling the Search operation: Unknown property name: {f['Name']}" + ) + + if f["Name"] == "Parents.TrialName": + trial_name = f["Value"] + + if getattr(item, "trial_name") != trial_name: + return False + + return True + + result = { + "Results": [], + "NextToken": str(next_index) if next_index is not None else None, + } + if resource == "Experiment": + experiments_fetched = list(self.experiments.values()) + + experiment_summaries = [ + { + "ExperimentName": experiment_data.experiment_name, + "ExperimentArn": experiment_data.experiment_arn, + "CreationTime": experiment_data.creation_time, + "LastModifiedTime": experiment_data.last_modified_time, + } + for experiment_data in experiments_fetched + if evaluate_search_expression(experiment_data) + ] + + for experiment_summary in experiment_summaries: + result["Results"].append({"Experiment": experiment_summary}) + + if resource == "ExperimentTrial": + trials_fetched = list(self.trials.values()) + + trial_summaries = [ + { + "TrialName": trial_data.trial_name, + "TrialArn": trial_data.trial_arn, + "CreationTime": trial_data.creation_time, + "LastModifiedTime": trial_data.last_modified_time, + } + for trial_data in trials_fetched + if evaluate_search_expression(trial_data) + ] + + for trial_summary in trial_summaries: + result["Results"].append({"Trial": trial_summary}) + + if resource == "ExperimentTrialComponent": + trial_components_fetched = list(self.trial_components.values()) + + trial_component_summaries = [ + { + "TrialComponentName": trial_component_data.trial_component_name, + "TrialComponentArn": trial_component_data.trial_component_arn, + "CreationTime": trial_component_data.creation_time, + "LastModifiedTime": trial_component_data.last_modified_time, + } + for trial_component_data in trial_components_fetched + if evaluate_search_expression(trial_component_data) + ] + + for trial_component_summary in trial_component_summaries: + result["Results"].append({"TrialComponent": trial_component_summary}) + return result + + def delete_experiment(self, experiment_name): + try: + del self.experiments[experiment_name] + except KeyError: + message = "Could not find experiment configuration '{}'.".format( + FakeTrial.arn_formatter(experiment_name, self.region_name) + ) + raise ValidationError(message=message) + + def get_experiment_by_arn(self, arn): + experiments = [ + experiment + for experiment in self.experiments.values() + if experiment.experiment_arn == arn + ] + if len(experiments) == 0: + message = "RecordNotFound" + raise ValidationError(message=message) + return experiments[0] + + def get_experiment_tags(self, arn): + try: + experiment = self.get_experiment_by_arn(arn) + return experiment.tags or [] + except RESTError: + return [] + + def create_trial( + self, trial_name, experiment_name, + ): + trial = FakeTrial( + region_name=self.region_name, + trial_name=trial_name, + experiment_name=experiment_name, + tags=[], + ) + self.trials[trial_name] = trial + return trial.response_create + + def describe_trial(self, trial_name): + try: + return self.trials[trial_name].response_object + except KeyError: + message = "Could not find trial '{}'.".format( + FakeTrial.arn_formatter(trial_name, self.region_name) + ) + raise ValidationError(message=message) + + def delete_trial(self, trial_name): + try: + del self.trials[trial_name] + except KeyError: + message = "Could not find trial configuration '{}'.".format( + FakeTrial.arn_formatter(trial_name, self.region_name) + ) + raise ValidationError(message=message) + + def get_trial_by_arn(self, arn): + trials = [trial for trial in self.trials.values() if trial.trial_arn == arn] + if len(trials) == 0: + message = "RecordNotFound" + raise ValidationError(message=message) + return trials[0] + + def get_trial_tags(self, arn): + try: + trial = self.get_trial_by_arn(arn) + return trial.tags or [] + except RESTError: + return [] + + def list_trials(self, experiment_name=None): + next_index = None + + trials_fetched = list(self.trials.values()) + + trial_summaries = [ + { + "TrialName": trial_data.trial_name, + "TrialArn": trial_data.trial_arn, + "CreationTime": trial_data.creation_time, + "LastModifiedTime": trial_data.last_modified_time, + } + for trial_data in trials_fetched + if experiment_name is None or trial_data.experiment_name == experiment_name + ] + + return { + "TrialSummaries": trial_summaries, + "NextToken": str(next_index) if next_index is not None else None, + } + + def create_trial_component( + self, trial_component_name, trial_name, + ): + trial_component = FakeTrialComponent( + region_name=self.region_name, + trial_component_name=trial_component_name, + trial_name=trial_name, + tags=[], + ) + self.trial_components[trial_component_name] = trial_component + return trial_component.response_create + + def delete_trial_component(self, trial_component_name): + try: + del self.trial_components[trial_component_name] + except KeyError: + message = "Could not find trial-component configuration '{}'.".format( + FakeTrial.arn_formatter(trial_component_name, self.region_name) + ) + raise ValidationError(message=message) + + def get_trial_component_by_arn(self, arn): + trial_components = [ + trial_component + for trial_component in self.trial_components.values() + if trial_component.trial_component_arn == arn + ] + if len(trial_components) == 0: + message = "RecordNotFound" + raise ValidationError(message=message) + return trial_components[0] + + def get_trial_component_tags(self, arn): + try: + trial_component = self.get_trial_component_by_arn(arn) + return trial_component.tags or [] + except RESTError: + return [] + + def describe_trial_component(self, trial_component_name): + try: + return self.trial_components[trial_component_name].response_object + except KeyError: + message = "Could not find trial component '{}'.".format( + FakeTrialComponent.arn_formatter(trial_component_name, self.region_name) + ) + raise ValidationError(message=message) + + def _update_trial_component_details(self, trial_component_name, details_json): + self.trial_components[trial_component_name].update(details_json) + + def list_trial_components(self, trial_name=None): + next_index = None + + trial_components_fetched = list(self.trial_components.values()) + + trial_component_summaries = [ + { + "TrialComponentName": trial_component_data.trial_component_name, + "TrialComponentArn": trial_component_data.trial_component_arn, + "CreationTime": trial_component_data.creation_time, + "LastModifiedTime": trial_component_data.last_modified_time, + } + for trial_component_data in trial_components_fetched + if trial_name is None or trial_component_data.trial_name == trial_name + ] + + return { + "TrialComponentSummaries": trial_component_summaries, + "NextToken": str(next_index) if next_index is not None else None, + } + + def associate_trial_component(self, params): + trial_name = params["TrialName"] + trial_component_name = params["TrialComponentName"] + + self.trial_components[trial_component_name].trial_name = trial_name + + def disassociate_trial_component(self, params): + trial_component_name = params["TrialComponentName"] + + self.trial_components[trial_component_name].trial_name = None + def create_notebook_instance( self, notebook_instance_name, @@ -1292,6 +1680,9 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] + def _update_training_job_details(self, training_job_name, details_json): + self.training_jobs[training_job_name].update(details_json) + def list_training_jobs( self, next_token, @@ -1377,6 +1768,111 @@ class SageMakerModelBackend(BaseBackend): } +class FakeExperiment(BaseObject): + def __init__( + self, region_name, experiment_name, tags, + ): + self.experiment_name = experiment_name + self.experiment_arn = FakeExperiment.arn_formatter(experiment_name, region_name) + self.tags = tags + self.creation_time = self.last_modified_time = datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"ExperimentArn": self.experiment_arn} + + @staticmethod + def arn_formatter(experiment_arn, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":experiment/" + + experiment_arn + ) + + +class FakeTrial(BaseObject): + def __init__( + self, region_name, trial_name, experiment_name, tags, + ): + self.trial_name = trial_name + self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name) + self.tags = tags + self.experiment_name = experiment_name + self.creation_time = self.last_modified_time = datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"TrialArn": self.trial_arn} + + @staticmethod + def arn_formatter(trial_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":experiment-trial/" + + trial_name + ) + + +class FakeTrialComponent(BaseObject): + def __init__( + self, region_name, trial_component_name, trial_name, tags, + ): + self.trial_component_name = trial_component_name + self.trial_component_arn = FakeTrialComponent.arn_formatter( + trial_component_name, region_name + ) + self.tags = tags + self.trial_name = trial_name + now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.creation_time = self.last_modified_time = now_string + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"TrialComponentArn": self.trial_component_arn} + + @staticmethod + def arn_formatter(trial_component_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":experiment-trial-component/" + + trial_component_name + ) + + sagemaker_backends = {} for region in Session().get_available_regions("sagemaker"): sagemaker_backends[region] = SageMakerModelBackend(region) diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 2ef014b2c..867e877ea 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -132,6 +132,12 @@ class SageMakerResponse(BaseResponse): tags = self.sagemaker_backend.get_endpoint_tags(arn) elif ":training-job/" in arn: tags = self.sagemaker_backend.get_training_job_tags(arn) + elif ":experiment/" in arn: + tags = self.sagemaker_backend.get_experiment_tags(arn) + elif ":experiment-trial/" in arn: + tags = self.sagemaker_backend.get_trial_tags(arn) + elif ":experiment-trial-component/" in arn: + tags = self.sagemaker_backend.get_trial_component_tags(arn) else: tags = [] except AWSError: @@ -139,6 +145,30 @@ class SageMakerResponse(BaseResponse): response = {"Tags": tags} return 200, {}, json.dumps(response) + @amzn_request_id + def add_tags(self): + arn = self._get_param("ResourceArn") + tags = self._get_param("Tags") + if ":experiment/" in arn: + self.sagemaker_backend.add_tags_to_experiment(arn, tags) + elif ":experiment-trial/" in arn: + self.sagemaker_backend.add_tags_to_trial(arn, tags) + elif ":experiment-trial-component/" in arn: + self.sagemaker_backend.add_tags_to_trial_component(arn, tags) + return 200, {}, json.dumps({}) + + @amzn_request_id + def delete_tags(self): + arn = self._get_param("ResourceArn") + tag_keys = self._get_param("TagKeys") + if ":experiment/" in arn: + self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys) + elif ":experiment-trial/" in arn: + self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys) + elif ":experiment-trial-component/" in arn: + self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys) + return 200, {}, json.dumps({}) + @amzn_request_id def create_endpoint_config(self): try: @@ -278,6 +308,117 @@ class SageMakerResponse(BaseResponse): ) return 200, {}, json.dumps("{}") + @amzn_request_id + def search(self): + response = self.sagemaker_backend.search( + resource=self._get_param("Resource"), + search_expression=self._get_param("SearchExpression"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_experiments(self): + response = self.sagemaker_backend.list_experiments() + return 200, {}, json.dumps(response) + + @amzn_request_id + def delete_experiment(self): + self.sagemaker_backend.delete_experiment( + experiment_name=self._get_param("ExperimentName"), + ) + return 200, {}, json.dumps({}) + + @amzn_request_id + def create_experiment(self, *args, **kwargs): + response = self.sagemaker_backend.create_experiment( + experiment_name=self._get_param("ExperimentName"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_experiment(self, *args, **kwargs): + response = self.sagemaker_backend.describe_experiment( + experiment_name=self._get_param("ExperimentName"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_trials(self): + response = self.sagemaker_backend.list_trials( + experiment_name=self._get_param("ExperimentName"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def create_trial(self, *args, **kwargs): + try: + response = self.sagemaker_backend.create_trial( + trial_name=self._get_param("TrialName"), + experiment_name=self._get_param("ExperimentName"), + ) + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def list_trial_components(self): + response = self.sagemaker_backend.list_trial_components( + trial_name=self._get_param("TrialName"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def create_trial_component(self, *args, **kwargs): + try: + response = self.sagemaker_backend.create_trial_component( + trial_component_name=self._get_param("TrialComponentName"), + trial_name=self._get_param("TrialName"), + ) + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_trial(self, *args, **kwargs): + trial_name = self._get_param("TrialName") + response = self.sagemaker_backend.describe_trial(trial_name) + return json.dumps(response) + + @amzn_request_id + def delete_trial(self): + trial_name = self._get_param("TrialName") + self.sagemaker_backend.delete_trial(trial_name) + return 200, {}, json.dumps({}) + + @amzn_request_id + def delete_trial_component(self): + trial_component_name = self._get_param("TrialComponentName") + self.sagemaker_backend.delete_trial_component(trial_component_name) + return 200, {}, json.dumps({}) + + @amzn_request_id + def describe_trial_component(self, *args, **kwargs): + trial_component_name = self._get_param("TrialComponentName") + response = self.sagemaker_backend.describe_trial_component(trial_component_name) + return json.dumps(response) + + @amzn_request_id + def associate_trial_component(self, *args, **kwargs): + self.sagemaker_backend.associate_trial_component(self.request_params) + response = {} + return 200, {}, json.dumps(response) + + @amzn_request_id + def disassociate_trial_component(self, *args, **kwargs): + self.sagemaker_backend.disassociate_trial_component(self.request_params) + response = {} + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_associations(self, *args, **kwargs): + response = self.sagemaker_backend.list_associations(self.request_params) + return 200, {}, json.dumps(response) + @amzn_request_id def list_training_jobs(self): max_results_range = range(1, 101) diff --git a/tests/test_sagemaker/test_sagemaker_experiment.py b/tests/test_sagemaker/test_sagemaker_experiment.py new file mode 100644 index 000000000..8447cb900 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_experiment.py @@ -0,0 +1,91 @@ +import boto3 + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID + +TEST_REGION_NAME = "us-east-1" + + +@mock_sagemaker +def test_create_experiment(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_experiments() + + assert len(resp["ExperimentSummaries"]) == 1 + assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name + assert ( + resp["ExperimentSummaries"][0]["ExperimentArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}" + ) + + +@mock_sagemaker +def test_delete_experiment(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + resp = client.delete_experiment(ExperimentName=experiment_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_experiments() + + assert len(resp["ExperimentSummaries"]) == 0 + + +@mock_sagemaker +def test_add_tags_to_experiment(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + resp = client.describe_experiment(ExperimentName=experiment_name) + + arn = resp["ExperimentArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == tags + + +@mock_sagemaker +def test_delete_tags_to_experiment(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + resp = client.describe_experiment(ExperimentName=experiment_name) + + arn = resp["ExperimentArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_search.py b/tests/test_sagemaker/test_sagemaker_search.py new file mode 100644 index 000000000..cbe42ac08 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_search.py @@ -0,0 +1,61 @@ +import boto3 + +from moto import mock_sagemaker + +TEST_REGION_NAME = "us-east-1" + + +@mock_sagemaker +def test_search(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + trial_component_name = "some-trial-component-name" + another_trial_component_name = "another-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + resp = client.create_trial_component( + TrialComponentName=another_trial_component_name + ) + + resp = client.search(Resource="ExperimentTrialComponent") + + assert len(resp["Results"]) == 2 + + resp = client.describe_trial_component(TrialComponentName=trial_component_name) + + trial_component_arn = resp["TrialComponentArn"] + + tags = [{"Key": "key-name", "Value": "some-value"}] + + client.add_tags(ResourceArn=trial_component_arn, Tags=tags) + + resp = client.search( + Resource="ExperimentTrialComponent", + SearchExpression={ + "Filters": [ + {"Name": "Tags.key-name", "Operator": "Equals", "Value": "some-value"} + ] + }, + ) + + assert len(resp["Results"]) == 1 + assert ( + resp["Results"][0]["TrialComponent"]["TrialComponentName"] + == trial_component_name + ) + + resp = client.search(Resource="Experiment") + assert len(resp["Results"]) == 1 + assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name + + resp = client.search(Resource="ExperimentTrial") + assert len(resp["Results"]) == 1 + assert resp["Results"][0]["Trial"]["TrialName"] == trial_name diff --git a/tests/test_sagemaker/test_sagemaker_trial.py b/tests/test_sagemaker/test_sagemaker_trial.py new file mode 100644 index 000000000..38e376062 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_trial.py @@ -0,0 +1,106 @@ +import boto3 + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID + +TEST_REGION_NAME = "us-east-1" + + +@mock_sagemaker +def test_create_trial(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_trials() + + assert len(resp["TrialSummaries"]) == 1 + assert resp["TrialSummaries"][0]["TrialName"] == trial_name + assert ( + resp["TrialSummaries"][0]["TrialArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}" + ) + + +@mock_sagemaker +def test_delete_trial(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + resp = client.delete_trial(TrialName=trial_name) + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_trials() + + assert len(resp["TrialSummaries"]) == 0 + + +@mock_sagemaker +def test_add_tags_to_trial(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + resp = client.describe_trial(TrialName=trial_name) + + arn = resp["TrialArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == tags + + +@mock_sagemaker +def test_delete_tags_to_trial(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + resp = client.describe_trial(TrialName=trial_name) + + arn = resp["TrialArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py new file mode 100644 index 000000000..53220ae5f --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_trial_component.py @@ -0,0 +1,122 @@ +import boto3 + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID + +TEST_REGION_NAME = "us-east-1" + + +@mock_sagemaker +def test_create__trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_trial_components() + + assert len(resp["TrialComponentSummaries"]) == 1 + assert ( + resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name + ) + assert ( + resp["TrialComponentSummaries"][0]["TrialComponentArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + ) + + +@mock_sagemaker +def test_delete__trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + resp = client.delete_trial_component(TrialComponentName=trial_component_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_trial_components() + + assert len(resp["TrialComponentSummaries"]) == 0 + + +@mock_sagemaker +def test_add_tags_to_trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + resp = client.describe_trial_component(TrialComponentName=trial_component_name) + + arn = resp["TrialComponentArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == tags + + +@mock_sagemaker +def test_delete_tags_to_trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + resp = client.describe_trial_component(TrialComponentName=trial_component_name) + + arn = resp["TrialComponentArn"] + + tags = [{"Key": "name", "Value": "value"}] + + client.add_tags(ResourceArn=arn, Tags=tags) + + client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_tags(ResourceArn=arn) + + assert resp["Tags"] == [] + + +@mock_sagemaker +def test_associate_trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + client.associate_trial_component( + TrialComponentName=trial_component_name, TrialName=trial_name + ) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_trial_components(TrialName=trial_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert ( + resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name + )