Add pagination for list experiments/trials/components (#4686)

This commit is contained in:
Bogdan Girman 2021-12-15 11:32:19 +01:00 committed by GitHub
parent ecc00606c4
commit de9aa9a8e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 195 additions and 49 deletions

View File

@ -5,6 +5,7 @@ from datetime import datetime
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.sagemaker import validators
from moto.utilities.paginator import paginate
from .exceptions import (
MissingModel,
ValidationError,
@ -13,6 +14,31 @@ from .exceptions import (
)
PAGINATION_MODEL = {
"list_experiments": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "experiment_arn",
"fail_on_invalid_token": True,
},
"list_trials": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "trial_arn",
"fail_on_invalid_token": True,
},
"list_trial_components": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "trial_component_arn",
"fail_on_invalid_token": True,
},
}
class BaseObject(BaseModel):
def camelCase(self, key):
words = []
@ -1114,25 +1140,9 @@ class SageMakerModelBackend(BaseBackend):
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
]
@paginate(pagination_model=PAGINATION_MODEL)
def list_experiments(self):
next_index = None
experiments_fetched = list(self.experiments.values())
experiment_summaries = [
{
"ExperimentName": experiment_data.experiment_name,
"ExperimentArn": experiment_data.experiment_arn,
"CreationTime": experiment_data.creation_time,
"LastModifiedTime": experiment_data.last_modified_time,
}
for experiment_data in experiments_fetched
]
return {
"ExperimentSummaries": experiment_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
return list(self.experiments.values())
def search(self, resource=None, search_expression=None):
next_index = None
@ -1333,9 +1343,8 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
@paginate(pagination_model=PAGINATION_MODEL)
def list_trials(self, experiment_name=None, trial_component_name=None):
next_index = None
trials_fetched = list(self.trials.values())
def evaluate_filter_expression(trial_data):
@ -1349,22 +1358,12 @@ class SageMakerModelBackend(BaseBackend):
return True
trial_summaries = [
{
"TrialName": trial_data.trial_name,
"TrialArn": trial_data.trial_arn,
"CreationTime": trial_data.creation_time,
"LastModifiedTime": trial_data.last_modified_time,
}
return [
trial_data
for trial_data in trials_fetched
if evaluate_filter_expression(trial_data)
]
return {
"TrialSummaries": trial_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def create_trial_component(
self, trial_component_name, trial_name,
):
@ -1416,27 +1415,16 @@ class SageMakerModelBackend(BaseBackend):
def _update_trial_component_details(self, trial_component_name, details_json):
self.trial_components[trial_component_name].update(details_json)
@paginate(pagination_model=PAGINATION_MODEL)
def list_trial_components(self, trial_name=None):
next_index = None
trial_components_fetched = list(self.trial_components.values())
trial_component_summaries = [
{
"TrialComponentName": trial_component_data.trial_component_name,
"TrialComponentArn": trial_component_data.trial_component_arn,
"CreationTime": trial_component_data.creation_time,
"LastModifiedTime": trial_component_data.last_modified_time,
}
return [
trial_component_data
for trial_component_data in trial_components_fetched
if trial_name is None or trial_component_data.trial_name == trial_name
]
return {
"TrialComponentSummaries": trial_component_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def associate_trial_component(self, params):
trial_name = params["TrialName"]
trial_component_name = params["TrialComponentName"]

View File

@ -345,7 +345,30 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def list_experiments(self):
response = self.sagemaker_backend.list_experiments()
MaxResults = self._get_param("MaxResults")
NextToken = self._get_param("NextToken")
paged_results, next_token = self.sagemaker_backend.list_experiments(
MaxResults=MaxResults, NextToken=NextToken,
)
experiment_summaries = [
{
"ExperimentName": experiment_data.experiment_name,
"ExperimentArn": experiment_data.experiment_arn,
"CreationTime": experiment_data.creation_time,
"LastModifiedTime": experiment_data.last_modified_time,
}
for experiment_data in paged_results
]
response = {
"ExperimentSummaries": experiment_summaries,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response)
@amzn_request_id
@ -371,10 +394,33 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def list_trials(self):
response = self.sagemaker_backend.list_trials(
MaxResults = self._get_param("MaxResults")
NextToken = self._get_param("NextToken")
paged_results, next_token = self.sagemaker_backend.list_trials(
NextToken=NextToken,
MaxResults=MaxResults,
experiment_name=self._get_param("ExperimentName"),
trial_component_name=self._get_param("TrialComponentName"),
)
trial_summaries = [
{
"TrialName": trial_data.trial_name,
"TrialArn": trial_data.trial_arn,
"CreationTime": trial_data.creation_time,
"LastModifiedTime": trial_data.last_modified_time,
}
for trial_data in paged_results
]
response = {
"TrialSummaries": trial_summaries,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response)
@amzn_request_id
@ -390,9 +436,32 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def list_trial_components(self):
response = self.sagemaker_backend.list_trial_components(
MaxResults = self._get_param("MaxResults")
NextToken = self._get_param("NextToken")
paged_results, next_token = self.sagemaker_backend.list_trial_components(
NextToken=NextToken,
MaxResults=MaxResults,
trial_name=self._get_param("TrialName"),
)
trial_component_summaries = [
{
"TrialComponentName": trial_component_data.trial_component_name,
"TrialComponentArn": trial_component_data.trial_component_arn,
"CreationTime": trial_component_data.creation_time,
"LastModifiedTime": trial_component_data.last_modified_time,
}
for trial_component_data in paged_results
]
response = {
"TrialComponentSummaries": trial_component_summaries,
}
if next_token:
response["NextToken"] = next_token
return 200, {}, json.dumps(response)
@amzn_request_id

View File

@ -26,6 +26,35 @@ def test_create_experiment():
)
@mock_sagemaker
def test_list_experiments():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
for experiment_name in experiment_names:
resp = client.create_experiment(ExperimentName=experiment_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
resp = client.list_experiments(MaxResults=1)
assert len(resp["ExperimentSummaries"]) == 1
next_token = resp["NextToken"]
resp = client.list_experiments(MaxResults=2, NextToken=next_token)
assert len(resp["ExperimentSummaries"]) == 2
next_token = resp["NextToken"]
resp = client.list_experiments(NextToken=next_token)
assert len(resp["ExperimentSummaries"]) == 7
assert resp.get("NextToken") is None
@mock_sagemaker
def test_delete_experiment():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

View File

@ -30,6 +30,38 @@ def test_create_trial():
)
@mock_sagemaker
def test_list_trials():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_names = [f"some-trial-name-{i}" for i in range(10)]
for trial_name in trial_names:
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
resp = client.list_trials(MaxResults=1)
assert len(resp["TrialSummaries"]) == 1
next_token = resp["NextToken"]
resp = client.list_trials(MaxResults=2, NextToken=next_token)
assert len(resp["TrialSummaries"]) == 2
next_token = resp["NextToken"]
resp = client.list_trials(NextToken=next_token)
assert len(resp["TrialSummaries"]) == 7
assert resp.get("NextToken") is None
@mock_sagemaker
def test_list_trials_by_trial_component_name():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

View File

@ -31,6 +31,34 @@ def test_create__trial_component():
)
@mock_sagemaker
def test_list_trial_components():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
trial_component_names = [f"some-trial-component-name-{i}" for i in range(10)]
for trial_component_name in trial_component_names:
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.list_trial_components(MaxResults=1)
assert len(resp["TrialComponentSummaries"]) == 1
next_token = resp["NextToken"]
resp = client.list_trial_components(MaxResults=2, NextToken=next_token)
assert len(resp["TrialComponentSummaries"]) == 2
next_token = resp["NextToken"]
resp = client.list_trial_components(NextToken=next_token)
assert len(resp["TrialComponentSummaries"]) == 7
assert resp.get("NextToken") is None
@mock_sagemaker
def test_delete__trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)