From d390aa673c68279bd63f9d725a0f00ecaf8e426d Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Tue, 31 Oct 2023 07:56:42 -0100 Subject: [PATCH] Sagemaker: Fix pagination for ModelPackages(Groups) (#6972) --- moto/sagemaker/models.py | 4 +- .../test_sagemaker_model_package_groups.py | 48 +++++++++---------- .../test_sagemaker_model_packages.py | 21 ++++++-- 3 files changed, 42 insertions(+), 31 deletions(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index a61d87f8a..aa855b321 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -52,13 +52,13 @@ PAGINATION_MODEL = { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, - "unique_attribute": "ModelPackageGroupArn", + "unique_attribute": "model_package_group_arn", }, "list_model_packages": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, - "unique_attribute": "ModelPackageArn", + "unique_attribute": "model_package_arn", }, "list_notebook_instances": { "input_token": "next_token", diff --git a/tests/test_sagemaker/test_sagemaker_model_package_groups.py b/tests/test_sagemaker/test_sagemaker_model_package_groups.py index 147975064..966286ed0 100644 --- a/tests/test_sagemaker/test_sagemaker_model_package_groups.py +++ b/tests/test_sagemaker/test_sagemaker_model_package_groups.py @@ -31,35 +31,35 @@ def test_create_model_package_group(): @mock_sagemaker def test_list_model_package_groups(): client = boto3.client("sagemaker", region_name="eu-west-1") + group1 = "test-model-package-group-1" + desc1 = "test-model-package-group-description-1" client.create_model_package_group( - ModelPackageGroupName="test-model-package-group-1", - ModelPackageGroupDescription="test-model-package-group-description-1", + ModelPackageGroupName=group1, ModelPackageGroupDescription=desc1 ) - client.create_model_package_group( - ModelPackageGroupName="test-model-package-group-2", - ModelPackageGroupDescription="test-model-package-group-description-2", - ) - resp = client.list_model_package_groups() - assert ( - resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"] - == "test-model-package-group-1" - ) - assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0] - assert ( - resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"] - == "test-model-package-group-description-1" - ) - assert ( - resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] - == "test-model-package-group-2" - ) - assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1] - assert ( - resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"] - == "test-model-package-group-description-2" + group2 = "test-model-package-group-2" + desc2 = "test-model-package-group-description-2" + client.create_model_package_group( + ModelPackageGroupName=group2, + ModelPackageGroupDescription=desc2, ) + summary = client.list_model_package_groups()["ModelPackageGroupSummaryList"] + + assert summary[0]["ModelPackageGroupName"] == group1 + assert summary[0]["ModelPackageGroupDescription"] == desc1 + + assert summary[1]["ModelPackageGroupName"] == group2 + assert summary[1]["ModelPackageGroupDescription"] == desc2 + + # Pagination + resp = client.list_model_package_groups(MaxResults=1) + assert len(resp["ModelPackageGroupSummaryList"]) == 1 + + resp = client.list_model_package_groups(MaxResults=1, NextToken=resp["NextToken"]) + assert len(resp["ModelPackageGroupSummaryList"]) == 1 + assert "NextToken" not in resp + @mock_sagemaker def test_list_model_package_groups_creation_time_before(): diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py index c54e6e8a2..4d9f326ae 100644 --- a/tests/test_sagemaker/test_sagemaker_model_packages.py +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -133,29 +133,40 @@ def test_list_model_packages_approval_status(): @mock_sagemaker def test_list_model_packages_model_package_group_name(): client = boto3.client("sagemaker", region_name="eu-west-1") + group1 = "test-model-package-group" client.create_model_package( ModelPackageName="test-model-package", ModelPackageDescription="test-model-package-description", - ModelPackageGroupName="test-model-package-group", + ModelPackageGroupName=group1, ) client.create_model_package( ModelPackageName="test-model-package", ModelPackageDescription="test-model-package-description-2", - ModelPackageGroupName="test-model-package-group", + ModelPackageGroupName=group1, ) client.create_model_package( ModelPackageName="test-model-package-2", ModelPackageDescription="test-model-package-description-3", - ModelPackageGroupName="test-model-package-group", + ModelPackageGroupName=group1, ) client.create_model_package( ModelPackageName="test-model-package-without-group", - ModelPackageDescription="test-model-package-description-without-group", + ModelPackageDescription="diff_group", ) - resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group") + resp = client.list_model_packages(ModelPackageGroupName=group1) assert len(resp["ModelPackageSummaryList"]) == 3 + # Pagination + resp = client.list_model_packages(ModelPackageGroupName=group1, MaxResults=2) + assert len(resp["ModelPackageSummaryList"]) == 2 + + resp = client.list_model_packages( + ModelPackageGroupName=group1, MaxResults=2, NextToken=resp["NextToken"] + ) + assert len(resp["ModelPackageSummaryList"]) == 1 + assert "NextToken" not in resp + @mock_sagemaker def test_list_model_packages_model_package_type():