From 7764a94491f1e13e0bc5358cb7966c3bfe841895 Mon Sep 17 00:00:00 2001 From: Bogdan Girman Date: Mon, 1 Nov 2021 23:30:07 +0100 Subject: [PATCH] Improve sagemaker (#4517) --- moto/sagemaker/exceptions.py | 5 ++ moto/sagemaker/models.py | 69 +++++++++++++-- moto/sagemaker/responses.py | 9 +- tests/test_sagemaker/test_sagemaker_search.py | 56 ++++++++++++ tests/test_sagemaker/test_sagemaker_trial.py | 22 +++++ .../test_sagemaker_trial_component.py | 88 ++++++++++++++++++- 6 files changed, 236 insertions(+), 13 deletions(-) diff --git a/moto/sagemaker/exceptions.py b/moto/sagemaker/exceptions.py index f157214c0..68a03f86d 100644 --- a/moto/sagemaker/exceptions.py +++ b/moto/sagemaker/exceptions.py @@ -35,3 +35,8 @@ class ValidationError(JsonRESTError): class AWSValidationException(AWSError): TYPE = "ValidationException" + + +class ResourceNotFound(JsonRESTError): + def __init__(self, message, **kwargs): + super(ResourceNotFound, self).__init__(__class__.__name__, message, **kwargs) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 9bb5b3356..4593aebe1 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -5,7 +5,12 @@ from datetime import datetime from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import RESTError from moto.sagemaker import validators -from .exceptions import MissingModel, ValidationError, AWSValidationException +from .exceptions import ( + MissingModel, + ValidationError, + AWSValidationException, + ResourceNotFound, +) class BaseObject(BaseModel): @@ -1102,8 +1107,13 @@ class SageMakerModelBackend(BaseBackend): if f["Name"] == "ExperimentName": experiment_name = f["Value"] - if getattr(item, "experiment_name") != experiment_name: - return False + if hasattr(item, "experiment_name"): + if getattr(item, "experiment_name") != experiment_name: + return False + else: + raise ValidationError( + message="Unknown property name: ExperimentName" + ) if f["Name"] == "TrialName": raise AWSValidationException( @@ -1209,6 +1219,7 @@ class SageMakerModelBackend(BaseBackend): trial_name=trial_name, experiment_name=experiment_name, tags=[], + trial_components=[], ) self.trials[trial_name] = trial return trial.response_create @@ -1245,11 +1256,22 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] - def list_trials(self, experiment_name=None): + def list_trials(self, experiment_name=None, trial_component_name=None): next_index = None trials_fetched = list(self.trials.values()) + def evaluate_filter_expression(trial_data): + if experiment_name is not None: + if trial_data.experiment_name != experiment_name: + return False + + if trial_component_name is not None: + if trial_component_name not in trial_data.trial_components: + return False + + return True + trial_summaries = [ { "TrialName": trial_data.trial_name, @@ -1258,7 +1280,7 @@ class SageMakerModelBackend(BaseBackend): "LastModifiedTime": trial_data.last_modified_time, } for trial_data in trials_fetched - if experiment_name is None or trial_data.experiment_name == experiment_name + if evaluate_filter_expression(trial_data) ] return { @@ -1342,12 +1364,42 @@ class SageMakerModelBackend(BaseBackend): trial_name = params["TrialName"] trial_component_name = params["TrialComponentName"] - self.trial_components[trial_component_name].trial_name = trial_name + if trial_name in self.trials.keys(): + self.trials[trial_name].trial_components.extend([trial_component_name]) + else: + raise ResourceNotFound( + message=f"Trial 'arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}' does not exist." + ) + + if trial_component_name in self.trial_components.keys(): + self.trial_components[trial_component_name].trial_name = trial_name + + return { + "TrialComponentArn": self.trial_components[ + trial_component_name + ].trial_component_arn, + "TrialArn": self.trials[trial_name].trial_arn, + } def disassociate_trial_component(self, params): trial_component_name = params["TrialComponentName"] + trial_name = params["TrialName"] - self.trial_components[trial_component_name].trial_name = None + if trial_component_name in self.trial_components.keys(): + self.trial_components[trial_component_name].trial_name = None + + if trial_name in self.trials.keys(): + self.trials[trial_name].trial_components = list( + filter( + lambda x: x != trial_component_name, + self.trials[trial_name].trial_components, + ) + ) + + return { + "TrialComponentArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}", + "TrialArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}", + } def create_notebook_instance( self, @@ -1804,11 +1856,12 @@ class FakeExperiment(BaseObject): class FakeTrial(BaseObject): def __init__( - self, region_name, trial_name, experiment_name, tags, + self, region_name, trial_name, experiment_name, tags, trial_components, ): self.trial_name = trial_name self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name) self.tags = tags + self.trial_components = trial_components self.experiment_name = experiment_name self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 867e877ea..8ba597618 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -346,6 +346,7 @@ class SageMakerResponse(BaseResponse): def list_trials(self): response = self.sagemaker_backend.list_trials( experiment_name=self._get_param("ExperimentName"), + trial_component_name=self._get_param("TrialComponentName"), ) return 200, {}, json.dumps(response) @@ -404,14 +405,14 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def associate_trial_component(self, *args, **kwargs): - self.sagemaker_backend.associate_trial_component(self.request_params) - response = {} + response = self.sagemaker_backend.associate_trial_component(self.request_params) return 200, {}, json.dumps(response) @amzn_request_id def disassociate_trial_component(self, *args, **kwargs): - self.sagemaker_backend.disassociate_trial_component(self.request_params) - response = {} + response = self.sagemaker_backend.disassociate_trial_component( + self.request_params + ) return 200, {}, json.dumps(response) @amzn_request_id diff --git a/tests/test_sagemaker/test_sagemaker_search.py b/tests/test_sagemaker/test_sagemaker_search.py index cbe42ac08..a5d3628f2 100644 --- a/tests/test_sagemaker/test_sagemaker_search.py +++ b/tests/test_sagemaker/test_sagemaker_search.py @@ -1,4 +1,7 @@ import boto3 +import pytest + +from botocore.exceptions import ClientError from moto import mock_sagemaker @@ -59,3 +62,56 @@ def test_search(): resp = client.search(Resource="ExperimentTrial") assert len(resp["Results"]) == 1 assert resp["Results"][0]["Trial"]["TrialName"] == trial_name + + +@mock_sagemaker +def test_search_trial_component_with_experiment_name(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + trial_component_name = "some-trial-component-name" + another_trial_component_name = "another-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + resp = client.create_trial_component( + TrialComponentName=another_trial_component_name + ) + + resp = client.search(Resource="ExperimentTrialComponent") + + assert len(resp["Results"]) == 2 + + resp = client.describe_trial_component(TrialComponentName=trial_component_name) + + trial_component_arn = resp["TrialComponentArn"] + + tags = [{"Key": "key-name", "Value": "some-value"}] + + client.add_tags(ResourceArn=trial_component_arn, Tags=tags) + + with pytest.raises(ClientError) as ex: + client.search( + Resource="ExperimentTrialComponent", + SearchExpression={ + "Filters": [ + { + "Name": "ExperimentName", + "Operator": "Equals", + "Value": experiment_name, + } + ] + }, + ) + + ex.value.response["Error"]["Code"].should.equal("ValidationException") + ex.value.response["Error"]["Message"].should.equal( + "Unknown property name: ExperimentName" + ) + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) diff --git a/tests/test_sagemaker/test_sagemaker_trial.py b/tests/test_sagemaker/test_sagemaker_trial.py index 38e376062..2f5007f77 100644 --- a/tests/test_sagemaker/test_sagemaker_trial.py +++ b/tests/test_sagemaker/test_sagemaker_trial.py @@ -30,6 +30,28 @@ def test_create_trial(): ) +@mock_sagemaker +def test_list_trials_by_trial_component_name(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + resp = client.list_trials(TrialComponentName=trial_component_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert len(resp["TrialSummaries"]) == 0 + + @mock_sagemaker def test_delete_trial(): client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py index 53220ae5f..1bf15d008 100644 --- a/tests/test_sagemaker/test_sagemaker_trial_component.py +++ b/tests/test_sagemaker/test_sagemaker_trial_component.py @@ -1,4 +1,7 @@ import boto3 +import pytest + +from botocore.exceptions import ClientError from moto import mock_sagemaker from moto.sts.models import ACCOUNT_ID @@ -108,11 +111,19 @@ def test_associate_trial_component(): resp = client.create_trial_component(TrialComponentName=trial_component_name) - client.associate_trial_component( + resp = client.associate_trial_component( TrialComponentName=trial_component_name, TrialName=trial_name ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert ( + resp["TrialComponentArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + ) + assert ( + resp["TrialArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}" + ) resp = client.list_trial_components(TrialName=trial_name) @@ -120,3 +131,78 @@ def test_associate_trial_component(): assert ( resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name ) + + resp = client.list_trials(TrialComponentName=trial_component_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert resp["TrialSummaries"][0]["TrialName"] == trial_name + + with pytest.raises(ClientError) as ex: + resp = client.associate_trial_component( + TrialComponentName="does-not-exist", TrialName="does-not-exist" + ) + + ex.value.response["Error"]["Code"].should.equal("ResourceNotFound") + ex.value.response["Error"]["Message"].should.equal( + f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist." + ) + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + + +@mock_sagemaker +def test_disassociate_trial_component(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_name = "some-trial-name" + + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + trial_component_name = "some-trial-component-name" + + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + client.associate_trial_component( + TrialComponentName=trial_component_name, TrialName=trial_name + ) + + resp = client.disassociate_trial_component( + TrialComponentName=trial_component_name, TrialName=trial_name + ) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert ( + resp["TrialComponentArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + ) + assert ( + resp["TrialArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}" + ) + + resp = client.list_trial_components(TrialName=trial_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert len(resp["TrialComponentSummaries"]) == 0 + + resp = client.list_trials(TrialComponentName=trial_component_name) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert len(resp["TrialSummaries"]) == 0 + + resp = client.disassociate_trial_component( + TrialComponentName="does-not-exist", TrialName="does-not-exist" + ) + + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert ( + resp["TrialComponentArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist" + ) + assert ( + resp["TrialArn"] + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist" + )