From c1bbae3604e0d09990fe7a9b4178e7733b828672 Mon Sep 17 00:00:00 2001 From: HALLOUARD <57447861+YHallouard@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:43:27 +0200 Subject: [PATCH] feat: Add sagemaker list_model_package_groups and fix versionned model packages (#6847) --- IMPLEMENTATION_COVERAGE.md | 2 +- moto/sagemaker/models.py | 74 +++++++- moto/sagemaker/responses.py | 30 ++++ .../test_sagemaker_model_package_groups.py | 165 ++++++++++++++++++ .../test_sagemaker_model_packages.py | 51 ++++-- 5 files changed, 304 insertions(+), 18 deletions(-) create mode 100644 tests/test_sagemaker/test_sagemaker_model_package_groups.py diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d6801d036..b9475906f 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6413,7 +6413,7 @@ - [ ] list_model_cards - [ ] list_model_explainability_job_definitions - [ ] list_model_metadata -- [ ] list_model_package_groups +- [X] list_model_package_groups - [X] list_model_packages - [ ] list_model_quality_job_definitions - [X] list_models diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 7d1419587..c1e94056e 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -48,6 +48,12 @@ PAGINATION_MODEL = { "limit_default": 50, "unique_attribute": "Key", }, + "list_model_package_groups": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "ModelPackageGroupArn", + }, "list_model_packages": { "input_token": "next_token", "limit_key": "max_results", @@ -943,10 +949,11 @@ class ModelPackageGroup(BaseObject): account_id=account_id, region_name=region_name, ) + datetime_now = datetime.now(tzutc()) self.model_package_group_name = model_package_group_name self.model_package_group_arn = model_package_group_arn self.model_package_group_description = model_package_group_description - self.creation_time = datetime.now() + self.creation_time = datetime_now self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, @@ -955,6 +962,20 @@ class ModelPackageGroup(BaseObject): self.model_package_group_status = "Completed" self.tags = tags + def gen_response_object(self) -> Dict[str, Any]: + response_object = super().gen_response_object() + for k, v in response_object.items(): + if isinstance(v, datetime): + response_object[k] = v.isoformat() + response_values = [ + "ModelPackageGroupName", + "ModelPackageGroupArn", + "ModelPackageGroupDescription", + "CreationTime", + "ModelPackageGroupStatus", + ] + return {k: v for k, v in response_object.items() if k in response_values} + class ModelPackage(BaseObject): def __init__( @@ -994,7 +1015,9 @@ class ModelPackage(BaseObject): region_name=region_name, account_id=account_id, _type="model-package", - _id=model_package_name, + _id=f"{model_package_name}/{model_package_version}" + if model_package_version + else model_package_name, ) datetime_now = datetime.now(tzutc()) self.model_package_name = model_package_name @@ -2912,6 +2935,53 @@ class SageMakerModelBackend(BaseBackend): return True raise ValueError(f"Invalid model package type: {model_package_type}") + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] + def list_model_package_groups( # type: ignore[misc] + self, + creation_time_after: Optional[int], + creation_time_before: Optional[int], + name_contains: Optional[str], + sort_by: Optional[str], + sort_order: Optional[str], + ) -> List[ModelPackageGroup]: + if isinstance(creation_time_before, int): + creation_time_before_datetime = datetime.fromtimestamp( + creation_time_before, tz=tzutc() + ) + if isinstance(creation_time_after, int): + creation_time_after_datetime = datetime.fromtimestamp( + creation_time_after, tz=tzutc() + ) + model_package_group_summary_list = list( + filter( + lambda x: ( + creation_time_after is None + or x.creation_time > creation_time_after_datetime + ) + and ( + creation_time_before is None + or x.creation_time < creation_time_before_datetime + ) + and ( + name_contains is None + or x.model_package_group_name.find(name_contains) != -1 + ), + self.model_package_groups.values(), + ) + ) + model_package_group_summary_list = list( + sorted( + model_package_group_summary_list, + key={ + "Name": lambda x: x.model_package_group_name, + "CreationTime": lambda x: x.creation_time, + None: lambda x: x.creation_time, + }[sort_by], + reverse=sort_order == "Descending", + ) + ) + return model_package_group_summary_list + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] def list_model_packages( # type: ignore[misc] self, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index b40f8b537..ec83c10e2 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -799,6 +799,36 @@ class SageMakerResponse(BaseResponse): ) return 200, {}, json.dumps({"EndpointArn": endpoint_arn}) + def list_model_package_groups(self) -> str: + creation_time_after = self._get_param("CreationTimeAfter") + creation_time_before = self._get_param("CreationTimeBefore") + max_results = self._get_param("MaxResults") + name_contains = self._get_param("NameContains") + next_token = self._get_param("NextToken") + sort_by = self._get_param("SortBy") + sort_order = self._get_param("SortOrder") + ( + model_package_group_summary_list, + next_token, + ) = self.sagemaker_backend.list_model_package_groups( + creation_time_after=creation_time_after, + creation_time_before=creation_time_before, + max_results=max_results, + name_contains=name_contains, + next_token=next_token, + sort_by=sort_by, + sort_order=sort_order, + ) + model_package_group_summary_list_response_object = [ + x.gen_response_object() for x in model_package_group_summary_list + ] + return json.dumps( + dict( + ModelPackageGroupSummaryList=model_package_group_summary_list_response_object, + NextToken=next_token, + ) + ) + def list_model_packages(self) -> str: creation_time_after = self._get_param("CreationTimeAfter") creation_time_before = self._get_param("CreationTimeBefore") diff --git a/tests/test_sagemaker/test_sagemaker_model_package_groups.py b/tests/test_sagemaker/test_sagemaker_model_package_groups.py new file mode 100644 index 000000000..49a1d99fb --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_model_package_groups.py @@ -0,0 +1,165 @@ +"""Unit tests for sagemaker-supported APIs.""" +from unittest import SkipTest + +import boto3 +from freezegun import freeze_time + +from moto import mock_sagemaker, settings + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_sagemaker +def test_create_model_package_group(): + client = boto3.client("sagemaker", region_name="us-east-2") + resp = client.create_model_package_group( + ModelPackageGroupName="test-model-package-group", + ModelPackageGroupDescription="test-model-package-group-description", + Tags=[ + {"Key": "test-key", "Value": "test-value"}, + ], + ) + assert ( + resp["ModelPackageGroupArn"] + == "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group" + ) + + +@mock_sagemaker +def test_list_model_package_groups(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + resp = client.list_model_package_groups() + + assert ( + resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"] + == "test-model-package-group-1" + ) + assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0] + assert ( + resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"] + == "test-model-package-group-description-1" + ) + assert ( + resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] + == "test-model-package-group-2" + ) + assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1] + assert ( + resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"] + == "test-model-package-group-description-2" + ) + + +@mock_sagemaker +def test_list_model_package_groups_creation_time_before(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + with freeze_time("2020-01-01 00:00:00"): + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + with freeze_time("2021-01-01 00:00:00"): + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + resp = client.list_model_package_groups(CreationTimeBefore="2020-01-01T02:00:00Z") + + assert len(resp["ModelPackageGroupSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_package_groups_creation_time_after(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + with freeze_time("2020-01-01 00:00:00"): + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + with freeze_time("2021-01-01 00:00:00"): + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + resp = client.list_model_package_groups(CreationTimeAfter="2020-01-02T00:00:00Z") + + assert len(resp["ModelPackageGroupSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_package_groups_name_contains(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + client.create_model_package_group( + ModelPackageGroupName="another-model-package-group", + ModelPackageGroupDescription="another-model-package-group-description", + ) + resp = client.list_model_package_groups(NameContains="test-model-package") + + assert len(resp["ModelPackageGroupSummaryList"]) == 2 + + +@mock_sagemaker +def test_list_model_package_groups_sort_by(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + resp = client.list_model_package_groups(SortBy="CreationTime") + + assert ( + resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"] + == "test-model-package-group-1" + ) + assert ( + resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] + == "test-model-package-group-2" + ) + + +@mock_sagemaker +def test_list_model_package_groups_sort_order(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-1", + ModelPackageGroupDescription="test-model-package-group-description-1", + ) + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group-2", + ModelPackageGroupDescription="test-model-package-group-description-2", + ) + resp = client.list_model_package_groups(SortOrder="Descending") + + assert ( + resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"] + == "test-model-package-group-2" + ) + assert ( + resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] + == "test-model-package-group-1" + ) diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py index 1b9a8d752..c510c2c9c 100644 --- a/tests/test_sagemaker/test_sagemaker_model_packages.py +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -15,11 +15,15 @@ def test_list_model_packages(): client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package( ModelPackageName="test-model-package", - ModelPackageDescription="test-model-package-description", + ModelPackageDescription="test-model-package-description-v1", + ) + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description-v2", ) client.create_model_package( ModelPackageName="test-model-package-2", - ModelPackageDescription="test-model-package-description-2", + ModelPackageDescription="test-model-package-description-v1-2", ) resp = client.list_model_packages() @@ -29,7 +33,7 @@ def test_list_model_packages(): assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0] assert ( resp["ModelPackageSummaryList"][0]["ModelPackageDescription"] - == "test-model-package-description" + == "test-model-package-description-v2" ) assert ( resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2" @@ -37,7 +41,7 @@ def test_list_model_packages(): assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1] assert ( resp["ModelPackageSummaryList"][1]["ModelPackageDescription"] - == "test-model-package-description-2" + == "test-model-package-description-v1-2" ) @@ -128,13 +132,22 @@ def test_list_model_packages_model_package_group_name(): ModelPackageGroupName="test-model-package-group", ) client.create_model_package( - ModelPackageName="test-model-package-2", + ModelPackageName="test-model-package", ModelPackageDescription="test-model-package-description-2", ModelPackageGroupName="test-model-package-group", ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-3", + ModelPackageGroupName="test-model-package-group", + ) + client.create_model_package( + ModelPackageName="test-model-package-without-group", + ModelPackageDescription="test-model-package-description-without-group", + ) resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group") - assert len(resp["ModelPackageSummaryList"]) == 2 + assert len(resp["ModelPackageSummaryList"]) == 3 @mock_sagemaker @@ -222,16 +235,24 @@ def test_create_model_package(): @mock_sagemaker -def test_create_model_package_group(): - client = boto3.client("sagemaker", region_name="us-east-2") - resp = client.create_model_package_group( +def test_create_model_package_in_model_package_group(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + resp_version_1 = client.create_model_package( + ModelPackageName="test-model-package", ModelPackageGroupName="test-model-package-group", - ModelPackageGroupDescription="test-model-package-group-description", - Tags=[ - {"Key": "test-key", "Value": "test-value"}, - ], + ModelPackageDescription="test-model-package-description", + ) + resp_version_2 = client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + ModelPackageDescription="test-model-package-description", ) assert ( - resp["ModelPackageGroupArn"] - == "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group" + resp_version_1["ModelPackageArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1" + ) + assert ( + resp_version_2["ModelPackageArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/2" )