From de9aa9a8e3b998902381a565e25b5da17960c7f7 Mon Sep 17 00:00:00 2001 From: Bogdan Girman Date: Wed, 15 Dec 2021 11:32:19 +0100 Subject: [PATCH] Add pagination for list experiments/trials/components (#4686) --- moto/sagemaker/models.py | 80 ++++++++----------- moto/sagemaker/responses.py | 75 ++++++++++++++++- .../test_sagemaker_experiment.py | 29 +++++++ tests/test_sagemaker/test_sagemaker_trial.py | 32 ++++++++ .../test_sagemaker_trial_component.py | 28 +++++++ 5 files changed, 195 insertions(+), 49 deletions(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index a999630e5..3a0fca030 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -5,6 +5,7 @@ from datetime import datetime from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import RESTError from moto.sagemaker import validators +from moto.utilities.paginator import paginate from .exceptions import ( MissingModel, ValidationError, @@ -13,6 +14,31 @@ from .exceptions import ( ) +PAGINATION_MODEL = { + "list_experiments": { + "input_token": "NextToken", + "limit_key": "MaxResults", + "limit_default": 100, + "unique_attribute": "experiment_arn", + "fail_on_invalid_token": True, + }, + "list_trials": { + "input_token": "NextToken", + "limit_key": "MaxResults", + "limit_default": 100, + "unique_attribute": "trial_arn", + "fail_on_invalid_token": True, + }, + "list_trial_components": { + "input_token": "NextToken", + "limit_key": "MaxResults", + "limit_default": 100, + "unique_attribute": "trial_component_arn", + "fail_on_invalid_token": True, + }, +} + + class BaseObject(BaseModel): def camelCase(self, key): words = [] @@ -1114,25 +1140,9 @@ class SageMakerModelBackend(BaseBackend): tag for tag in trial_component.tags if tag["Key"] not in tag_keys ] + @paginate(pagination_model=PAGINATION_MODEL) def list_experiments(self): - next_index = None - - experiments_fetched = list(self.experiments.values()) - - experiment_summaries = [ - { - "ExperimentName": experiment_data.experiment_name, - "ExperimentArn": experiment_data.experiment_arn, - "CreationTime": experiment_data.creation_time, - "LastModifiedTime": experiment_data.last_modified_time, - } - for experiment_data in experiments_fetched - ] - - return { - "ExperimentSummaries": experiment_summaries, - "NextToken": str(next_index) if next_index is not None else None, - } + return list(self.experiments.values()) def search(self, resource=None, search_expression=None): next_index = None @@ -1333,9 +1343,8 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] + @paginate(pagination_model=PAGINATION_MODEL) def list_trials(self, experiment_name=None, trial_component_name=None): - next_index = None - trials_fetched = list(self.trials.values()) def evaluate_filter_expression(trial_data): @@ -1349,22 +1358,12 @@ class SageMakerModelBackend(BaseBackend): return True - trial_summaries = [ - { - "TrialName": trial_data.trial_name, - "TrialArn": trial_data.trial_arn, - "CreationTime": trial_data.creation_time, - "LastModifiedTime": trial_data.last_modified_time, - } + return [ + trial_data for trial_data in trials_fetched if evaluate_filter_expression(trial_data) ] - return { - "TrialSummaries": trial_summaries, - "NextToken": str(next_index) if next_index is not None else None, - } - def create_trial_component( self, trial_component_name, trial_name, ): @@ -1416,27 +1415,16 @@ class SageMakerModelBackend(BaseBackend): def _update_trial_component_details(self, trial_component_name, details_json): self.trial_components[trial_component_name].update(details_json) + @paginate(pagination_model=PAGINATION_MODEL) def list_trial_components(self, trial_name=None): - next_index = None - trial_components_fetched = list(self.trial_components.values()) - trial_component_summaries = [ - { - "TrialComponentName": trial_component_data.trial_component_name, - "TrialComponentArn": trial_component_data.trial_component_arn, - "CreationTime": trial_component_data.creation_time, - "LastModifiedTime": trial_component_data.last_modified_time, - } + return [ + trial_component_data for trial_component_data in trial_components_fetched if trial_name is None or trial_component_data.trial_name == trial_name ] - return { - "TrialComponentSummaries": trial_component_summaries, - "NextToken": str(next_index) if next_index is not None else None, - } - def associate_trial_component(self, params): trial_name = params["TrialName"] trial_component_name = params["TrialComponentName"] diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 46d92b077..0aac52ce3 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -345,7 +345,30 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def list_experiments(self): - response = self.sagemaker_backend.list_experiments() + MaxResults = self._get_param("MaxResults") + NextToken = self._get_param("NextToken") + + paged_results, next_token = self.sagemaker_backend.list_experiments( + MaxResults=MaxResults, NextToken=NextToken, + ) + + experiment_summaries = [ + { + "ExperimentName": experiment_data.experiment_name, + "ExperimentArn": experiment_data.experiment_arn, + "CreationTime": experiment_data.creation_time, + "LastModifiedTime": experiment_data.last_modified_time, + } + for experiment_data in paged_results + ] + + response = { + "ExperimentSummaries": experiment_summaries, + } + + if next_token: + response["NextToken"] = next_token + return 200, {}, json.dumps(response) @amzn_request_id @@ -371,10 +394,33 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def list_trials(self): - response = self.sagemaker_backend.list_trials( + MaxResults = self._get_param("MaxResults") + NextToken = self._get_param("NextToken") + + paged_results, next_token = self.sagemaker_backend.list_trials( + NextToken=NextToken, + MaxResults=MaxResults, experiment_name=self._get_param("ExperimentName"), trial_component_name=self._get_param("TrialComponentName"), ) + + trial_summaries = [ + { + "TrialName": trial_data.trial_name, + "TrialArn": trial_data.trial_arn, + "CreationTime": trial_data.creation_time, + "LastModifiedTime": trial_data.last_modified_time, + } + for trial_data in paged_results + ] + + response = { + "TrialSummaries": trial_summaries, + } + + if next_token: + response["NextToken"] = next_token + return 200, {}, json.dumps(response) @amzn_request_id @@ -390,9 +436,32 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def list_trial_components(self): - response = self.sagemaker_backend.list_trial_components( + MaxResults = self._get_param("MaxResults") + NextToken = self._get_param("NextToken") + + paged_results, next_token = self.sagemaker_backend.list_trial_components( + NextToken=NextToken, + MaxResults=MaxResults, trial_name=self._get_param("TrialName"), ) + + trial_component_summaries = [ + { + "TrialComponentName": trial_component_data.trial_component_name, + "TrialComponentArn": trial_component_data.trial_component_arn, + "CreationTime": trial_component_data.creation_time, + "LastModifiedTime": trial_component_data.last_modified_time, + } + for trial_component_data in paged_results + ] + + response = { + "TrialComponentSummaries": trial_component_summaries, + } + + if next_token: + response["NextToken"] = next_token + return 200, {}, json.dumps(response) @amzn_request_id diff --git a/tests/test_sagemaker/test_sagemaker_experiment.py b/tests/test_sagemaker/test_sagemaker_experiment.py index 8447cb900..810a82ee3 100644 --- a/tests/test_sagemaker/test_sagemaker_experiment.py +++ b/tests/test_sagemaker/test_sagemaker_experiment.py @@ -26,6 +26,35 @@ def test_create_experiment(): ) +@mock_sagemaker +def test_list_experiments(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_names = [f"some-experiment-name-{i}" for i in range(10)] + + for experiment_name in experiment_names: + resp = client.create_experiment(ExperimentName=experiment_name) + assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 + + resp = client.list_experiments(MaxResults=1) + + assert len(resp["ExperimentSummaries"]) == 1 + + next_token = resp["NextToken"] + + resp = client.list_experiments(MaxResults=2, NextToken=next_token) + + assert len(resp["ExperimentSummaries"]) == 2 + + next_token = resp["NextToken"] + + resp = client.list_experiments(NextToken=next_token) + + assert len(resp["ExperimentSummaries"]) == 7 + + assert resp.get("NextToken") is None + + @mock_sagemaker def test_delete_experiment(): client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) diff --git a/tests/test_sagemaker/test_sagemaker_trial.py b/tests/test_sagemaker/test_sagemaker_trial.py index 2f5007f77..cf02b26bd 100644 --- a/tests/test_sagemaker/test_sagemaker_trial.py +++ b/tests/test_sagemaker/test_sagemaker_trial.py @@ -30,6 +30,38 @@ def test_create_trial(): ) +@mock_sagemaker +def test_list_trials(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + experiment_name = "some-experiment-name" + + resp = client.create_experiment(ExperimentName=experiment_name) + + trial_names = [f"some-trial-name-{i}" for i in range(10)] + + for trial_name in trial_names: + resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + + resp = client.list_trials(MaxResults=1) + + assert len(resp["TrialSummaries"]) == 1 + + next_token = resp["NextToken"] + + resp = client.list_trials(MaxResults=2, NextToken=next_token) + + assert len(resp["TrialSummaries"]) == 2 + + next_token = resp["NextToken"] + + resp = client.list_trials(NextToken=next_token) + + assert len(resp["TrialSummaries"]) == 7 + + assert resp.get("NextToken") is None + + @mock_sagemaker def test_list_trials_by_trial_component_name(): client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py index 1bf15d008..e120f7ccd 100644 --- a/tests/test_sagemaker/test_sagemaker_trial_component.py +++ b/tests/test_sagemaker/test_sagemaker_trial_component.py @@ -31,6 +31,34 @@ def test_create__trial_component(): ) +@mock_sagemaker +def test_list_trial_components(): + client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + trial_component_names = [f"some-trial-component-name-{i}" for i in range(10)] + + for trial_component_name in trial_component_names: + resp = client.create_trial_component(TrialComponentName=trial_component_name) + + resp = client.list_trial_components(MaxResults=1) + + assert len(resp["TrialComponentSummaries"]) == 1 + + next_token = resp["NextToken"] + + resp = client.list_trial_components(MaxResults=2, NextToken=next_token) + + assert len(resp["TrialComponentSummaries"]) == 2 + + next_token = resp["NextToken"] + + resp = client.list_trial_components(NextToken=next_token) + + assert len(resp["TrialComponentSummaries"]) == 7 + + assert resp.get("NextToken") is None + + @mock_sagemaker def test_delete__trial_component(): client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)