Sagemaker: Fix pagination for ModelPackages(Groups) (#6972)

This commit is contained in:
Bert Blommers 2023-10-31 07:56:42 -01:00 committed by GitHub
parent aa3770086b
commit d390aa673c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 31 deletions

View File

@ -52,13 +52,13 @@ PAGINATION_MODEL = {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageGroupArn",
"unique_attribute": "model_package_group_arn",
},
"list_model_packages": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageArn",
"unique_attribute": "model_package_arn",
},
"list_notebook_instances": {
"input_token": "next_token",

View File

@ -31,35 +31,35 @@ def test_create_model_package_group():
@mock_sagemaker
def test_list_model_package_groups():
client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group-1"
desc1 = "test-model-package-group-description-1"
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
ModelPackageGroupName=group1, ModelPackageGroupDescription=desc1
)
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups()
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
== "test-model-package-group-1"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0]
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"]
== "test-model-package-group-description-1"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1]
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"]
== "test-model-package-group-description-2"
group2 = "test-model-package-group-2"
desc2 = "test-model-package-group-description-2"
client.create_model_package_group(
ModelPackageGroupName=group2,
ModelPackageGroupDescription=desc2,
)
summary = client.list_model_package_groups()["ModelPackageGroupSummaryList"]
assert summary[0]["ModelPackageGroupName"] == group1
assert summary[0]["ModelPackageGroupDescription"] == desc1
assert summary[1]["ModelPackageGroupName"] == group2
assert summary[1]["ModelPackageGroupDescription"] == desc2
# Pagination
resp = client.list_model_package_groups(MaxResults=1)
assert len(resp["ModelPackageGroupSummaryList"]) == 1
resp = client.list_model_package_groups(MaxResults=1, NextToken=resp["NextToken"])
assert len(resp["ModelPackageGroupSummaryList"]) == 1
assert "NextToken" not in resp
@mock_sagemaker
def test_list_model_package_groups_creation_time_before():

View File

@ -133,29 +133,40 @@ def test_list_model_packages_approval_status():
@mock_sagemaker
def test_list_model_packages_model_package_group_name():
client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group"
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package-without-group",
ModelPackageDescription="test-model-package-description-without-group",
ModelPackageDescription="diff_group",
)
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
resp = client.list_model_packages(ModelPackageGroupName=group1)
assert len(resp["ModelPackageSummaryList"]) == 3
# Pagination
resp = client.list_model_packages(ModelPackageGroupName=group1, MaxResults=2)
assert len(resp["ModelPackageSummaryList"]) == 2
resp = client.list_model_packages(
ModelPackageGroupName=group1, MaxResults=2, NextToken=resp["NextToken"]
)
assert len(resp["ModelPackageSummaryList"]) == 1
assert "NextToken" not in resp
@mock_sagemaker
def test_list_model_packages_model_package_type():