Add pagination for list experiments/trials/components (#4686)
This commit is contained in:
parent
ecc00606c4
commit
de9aa9a8e3
@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.sagemaker import validators
|
||||
from moto.utilities.paginator import paginate
|
||||
from .exceptions import (
|
||||
MissingModel,
|
||||
ValidationError,
|
||||
@ -13,6 +14,31 @@ from .exceptions import (
|
||||
)
|
||||
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_experiments": {
|
||||
"input_token": "NextToken",
|
||||
"limit_key": "MaxResults",
|
||||
"limit_default": 100,
|
||||
"unique_attribute": "experiment_arn",
|
||||
"fail_on_invalid_token": True,
|
||||
},
|
||||
"list_trials": {
|
||||
"input_token": "NextToken",
|
||||
"limit_key": "MaxResults",
|
||||
"limit_default": 100,
|
||||
"unique_attribute": "trial_arn",
|
||||
"fail_on_invalid_token": True,
|
||||
},
|
||||
"list_trial_components": {
|
||||
"input_token": "NextToken",
|
||||
"limit_key": "MaxResults",
|
||||
"limit_default": 100,
|
||||
"unique_attribute": "trial_component_arn",
|
||||
"fail_on_invalid_token": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
def camelCase(self, key):
|
||||
words = []
|
||||
@ -1114,25 +1140,9 @@ class SageMakerModelBackend(BaseBackend):
|
||||
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
|
||||
]
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_experiments(self):
|
||||
next_index = None
|
||||
|
||||
experiments_fetched = list(self.experiments.values())
|
||||
|
||||
experiment_summaries = [
|
||||
{
|
||||
"ExperimentName": experiment_data.experiment_name,
|
||||
"ExperimentArn": experiment_data.experiment_arn,
|
||||
"CreationTime": experiment_data.creation_time,
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
for experiment_data in experiments_fetched
|
||||
]
|
||||
|
||||
return {
|
||||
"ExperimentSummaries": experiment_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
return list(self.experiments.values())
|
||||
|
||||
def search(self, resource=None, search_expression=None):
|
||||
next_index = None
|
||||
@ -1333,9 +1343,8 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_trials(self, experiment_name=None, trial_component_name=None):
|
||||
next_index = None
|
||||
|
||||
trials_fetched = list(self.trials.values())
|
||||
|
||||
def evaluate_filter_expression(trial_data):
|
||||
@ -1349,22 +1358,12 @@ class SageMakerModelBackend(BaseBackend):
|
||||
|
||||
return True
|
||||
|
||||
trial_summaries = [
|
||||
{
|
||||
"TrialName": trial_data.trial_name,
|
||||
"TrialArn": trial_data.trial_arn,
|
||||
"CreationTime": trial_data.creation_time,
|
||||
"LastModifiedTime": trial_data.last_modified_time,
|
||||
}
|
||||
return [
|
||||
trial_data
|
||||
for trial_data in trials_fetched
|
||||
if evaluate_filter_expression(trial_data)
|
||||
]
|
||||
|
||||
return {
|
||||
"TrialSummaries": trial_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def create_trial_component(
|
||||
self, trial_component_name, trial_name,
|
||||
):
|
||||
@ -1416,27 +1415,16 @@ class SageMakerModelBackend(BaseBackend):
|
||||
def _update_trial_component_details(self, trial_component_name, details_json):
|
||||
self.trial_components[trial_component_name].update(details_json)
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_trial_components(self, trial_name=None):
|
||||
next_index = None
|
||||
|
||||
trial_components_fetched = list(self.trial_components.values())
|
||||
|
||||
trial_component_summaries = [
|
||||
{
|
||||
"TrialComponentName": trial_component_data.trial_component_name,
|
||||
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||
"CreationTime": trial_component_data.creation_time,
|
||||
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||
}
|
||||
return [
|
||||
trial_component_data
|
||||
for trial_component_data in trial_components_fetched
|
||||
if trial_name is None or trial_component_data.trial_name == trial_name
|
||||
]
|
||||
|
||||
return {
|
||||
"TrialComponentSummaries": trial_component_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def associate_trial_component(self, params):
|
||||
trial_name = params["TrialName"]
|
||||
trial_component_name = params["TrialComponentName"]
|
||||
|
@ -345,7 +345,30 @@ class SageMakerResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def list_experiments(self):
|
||||
response = self.sagemaker_backend.list_experiments()
|
||||
MaxResults = self._get_param("MaxResults")
|
||||
NextToken = self._get_param("NextToken")
|
||||
|
||||
paged_results, next_token = self.sagemaker_backend.list_experiments(
|
||||
MaxResults=MaxResults, NextToken=NextToken,
|
||||
)
|
||||
|
||||
experiment_summaries = [
|
||||
{
|
||||
"ExperimentName": experiment_data.experiment_name,
|
||||
"ExperimentArn": experiment_data.experiment_arn,
|
||||
"CreationTime": experiment_data.creation_time,
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
for experiment_data in paged_results
|
||||
]
|
||||
|
||||
response = {
|
||||
"ExperimentSummaries": experiment_summaries,
|
||||
}
|
||||
|
||||
if next_token:
|
||||
response["NextToken"] = next_token
|
||||
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
@ -371,10 +394,33 @@ class SageMakerResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def list_trials(self):
|
||||
response = self.sagemaker_backend.list_trials(
|
||||
MaxResults = self._get_param("MaxResults")
|
||||
NextToken = self._get_param("NextToken")
|
||||
|
||||
paged_results, next_token = self.sagemaker_backend.list_trials(
|
||||
NextToken=NextToken,
|
||||
MaxResults=MaxResults,
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
trial_component_name=self._get_param("TrialComponentName"),
|
||||
)
|
||||
|
||||
trial_summaries = [
|
||||
{
|
||||
"TrialName": trial_data.trial_name,
|
||||
"TrialArn": trial_data.trial_arn,
|
||||
"CreationTime": trial_data.creation_time,
|
||||
"LastModifiedTime": trial_data.last_modified_time,
|
||||
}
|
||||
for trial_data in paged_results
|
||||
]
|
||||
|
||||
response = {
|
||||
"TrialSummaries": trial_summaries,
|
||||
}
|
||||
|
||||
if next_token:
|
||||
response["NextToken"] = next_token
|
||||
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
@ -390,9 +436,32 @@ class SageMakerResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def list_trial_components(self):
|
||||
response = self.sagemaker_backend.list_trial_components(
|
||||
MaxResults = self._get_param("MaxResults")
|
||||
NextToken = self._get_param("NextToken")
|
||||
|
||||
paged_results, next_token = self.sagemaker_backend.list_trial_components(
|
||||
NextToken=NextToken,
|
||||
MaxResults=MaxResults,
|
||||
trial_name=self._get_param("TrialName"),
|
||||
)
|
||||
|
||||
trial_component_summaries = [
|
||||
{
|
||||
"TrialComponentName": trial_component_data.trial_component_name,
|
||||
"TrialComponentArn": trial_component_data.trial_component_arn,
|
||||
"CreationTime": trial_component_data.creation_time,
|
||||
"LastModifiedTime": trial_component_data.last_modified_time,
|
||||
}
|
||||
for trial_component_data in paged_results
|
||||
]
|
||||
|
||||
response = {
|
||||
"TrialComponentSummaries": trial_component_summaries,
|
||||
}
|
||||
|
||||
if next_token:
|
||||
response["NextToken"] = next_token
|
||||
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
|
@ -26,6 +26,35 @@ def test_create_experiment():
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_experiments():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
|
||||
|
||||
for experiment_name in experiment_names:
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments(MaxResults=1)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 1
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_experiments(MaxResults=2, NextToken=next_token)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 2
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_experiments(NextToken=next_token)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 7
|
||||
|
||||
assert resp.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
@ -30,6 +30,38 @@ def test_create_trial():
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trials():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_names = [f"some-trial-name-{i}" for i in range(10)]
|
||||
|
||||
for trial_name in trial_names:
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
resp = client.list_trials(MaxResults=1)
|
||||
|
||||
assert len(resp["TrialSummaries"]) == 1
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_trials(MaxResults=2, NextToken=next_token)
|
||||
|
||||
assert len(resp["TrialSummaries"]) == 2
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_trials(NextToken=next_token)
|
||||
|
||||
assert len(resp["TrialSummaries"]) == 7
|
||||
|
||||
assert resp.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trials_by_trial_component_name():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
@ -31,6 +31,34 @@ def test_create__trial_component():
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trial_components():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_names = [f"some-trial-component-name-{i}" for i in range(10)]
|
||||
|
||||
for trial_component_name in trial_component_names:
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = client.list_trial_components(MaxResults=1)
|
||||
|
||||
assert len(resp["TrialComponentSummaries"]) == 1
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_trial_components(MaxResults=2, NextToken=next_token)
|
||||
|
||||
assert len(resp["TrialComponentSummaries"]) == 2
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_trial_components(NextToken=next_token)
|
||||
|
||||
assert len(resp["TrialComponentSummaries"]) == 7
|
||||
|
||||
assert resp.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete__trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
Loading…
Reference in New Issue
Block a user