Sagemaker: Fix pagination for ModelPackages(Groups) (#6972)

This commit is contained in:
Bert Blommers 2023-10-31 07:56:42 -01:00 committed by GitHub
parent aa3770086b
commit d390aa673c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 31 deletions

View File

@ -52,13 +52,13 @@ PAGINATION_MODEL = {
"input_token": "next_token", "input_token": "next_token",
"limit_key": "max_results", "limit_key": "max_results",
"limit_default": 100, "limit_default": 100,
"unique_attribute": "ModelPackageGroupArn", "unique_attribute": "model_package_group_arn",
}, },
"list_model_packages": { "list_model_packages": {
"input_token": "next_token", "input_token": "next_token",
"limit_key": "max_results", "limit_key": "max_results",
"limit_default": 100, "limit_default": 100,
"unique_attribute": "ModelPackageArn", "unique_attribute": "model_package_arn",
}, },
"list_notebook_instances": { "list_notebook_instances": {
"input_token": "next_token", "input_token": "next_token",

View File

@ -31,35 +31,35 @@ def test_create_model_package_group():
@mock_sagemaker @mock_sagemaker
def test_list_model_package_groups(): def test_list_model_package_groups():
client = boto3.client("sagemaker", region_name="eu-west-1") client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group-1"
desc1 = "test-model-package-group-description-1"
client.create_model_package_group( client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1", ModelPackageGroupName=group1, ModelPackageGroupDescription=desc1
ModelPackageGroupDescription="test-model-package-group-description-1",
) )
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups()
assert ( group2 = "test-model-package-group-2"
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"] desc2 = "test-model-package-group-description-2"
== "test-model-package-group-1" client.create_model_package_group(
) ModelPackageGroupName=group2,
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0] ModelPackageGroupDescription=desc2,
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"]
== "test-model-package-group-description-1"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1]
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"]
== "test-model-package-group-description-2"
) )
summary = client.list_model_package_groups()["ModelPackageGroupSummaryList"]
assert summary[0]["ModelPackageGroupName"] == group1
assert summary[0]["ModelPackageGroupDescription"] == desc1
assert summary[1]["ModelPackageGroupName"] == group2
assert summary[1]["ModelPackageGroupDescription"] == desc2
# Pagination
resp = client.list_model_package_groups(MaxResults=1)
assert len(resp["ModelPackageGroupSummaryList"]) == 1
resp = client.list_model_package_groups(MaxResults=1, NextToken=resp["NextToken"])
assert len(resp["ModelPackageGroupSummaryList"]) == 1
assert "NextToken" not in resp
@mock_sagemaker @mock_sagemaker
def test_list_model_package_groups_creation_time_before(): def test_list_model_package_groups_creation_time_before():

View File

@ -133,29 +133,40 @@ def test_list_model_packages_approval_status():
@mock_sagemaker @mock_sagemaker
def test_list_model_packages_model_package_group_name(): def test_list_model_packages_model_package_group_name():
client = boto3.client("sagemaker", region_name="eu-west-1") client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group"
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package", ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description", ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName=group1,
) )
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package", ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2", ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName=group1,
) )
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package-2", ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3", ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName=group1,
) )
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package-without-group", ModelPackageName="test-model-package-without-group",
ModelPackageDescription="test-model-package-description-without-group", ModelPackageDescription="diff_group",
) )
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group") resp = client.list_model_packages(ModelPackageGroupName=group1)
assert len(resp["ModelPackageSummaryList"]) == 3 assert len(resp["ModelPackageSummaryList"]) == 3
# Pagination
resp = client.list_model_packages(ModelPackageGroupName=group1, MaxResults=2)
assert len(resp["ModelPackageSummaryList"]) == 2
resp = client.list_model_packages(
ModelPackageGroupName=group1, MaxResults=2, NextToken=resp["NextToken"]
)
assert len(resp["ModelPackageSummaryList"]) == 1
assert "NextToken" not in resp
@mock_sagemaker @mock_sagemaker
def test_list_model_packages_model_package_type(): def test_list_model_packages_model_package_type():