From 3ae1b62590b15c632bf33c3d1b3bbd1cbe50b1cc Mon Sep 17 00:00:00 2001 From: Guilherme de Amorim Date: Thu, 4 Jan 2024 16:13:30 -0300 Subject: [PATCH] SageMaker: model-package-group supports list_tags, add_tags and remove_tags (#7183) Co-authored-by: Guilherme --- moto/sagemaker/models.py | 3 +- .../test_sagemaker_model_package_groups.py | 61 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 896dd3b66..06f5896cc 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1896,6 +1896,7 @@ class SageMakerModelBackend(BaseBackend): "experiment-trial-component": self.trial_components, "processing-job": self.processing_jobs, "pipeline": self.pipelines, + "model-package-group": self.model_package_groups, } target_resource, target_name = arn.split(":")[-1].split("/") try: @@ -3205,7 +3206,7 @@ class SageMakerModelBackend(BaseBackend): model_package_group_description=model_package_group_description, account_id=self.account_id, region_name=self.region_name, - tags=tags, + tags=tags or [], ) return self.model_package_groups[ model_package_group_name diff --git a/tests/test_sagemaker/test_sagemaker_model_package_groups.py b/tests/test_sagemaker/test_sagemaker_model_package_groups.py index 9092ed9a5..15ba43a8a 100644 --- a/tests/test_sagemaker/test_sagemaker_model_package_groups.py +++ b/tests/test_sagemaker/test_sagemaker_model_package_groups.py @@ -1,4 +1,5 @@ """Unit tests for sagemaker-supported APIs.""" +import uuid from datetime import datetime from unittest import SkipTest @@ -190,3 +191,63 @@ def test_describe_model_package_group(): ) assert resp["ModelPackageGroupStatus"] == "Completed" assert resp["CreationTime"] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=tzutc()) + + +@mock_sagemaker +def test_list_tags_model_package_group(): + region_name = "eu-west-1" + model_package_group_name = "test-model-package-group" + client = boto3.client("sagemaker", region_name=region_name) + client.create_model_package_group( + ModelPackageGroupName=model_package_group_name, + ModelPackageGroupDescription="test-model-package-group-description", + ) + + tags = [] + for _ in range(80): + tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"}) + + resource_arn = ( + f"arn:aws:sagemaker:{region_name}:123456789012" + f":model-package-group/{model_package_group_name}" + ) + _ = client.add_tags(ResourceArn=resource_arn, Tags=tags) + + paginator = client.get_paginator("list_tags") + response_iterator = paginator.paginate(ResourceArn=resource_arn) + tags_from_paginator = [] + for response in response_iterator: + tags_from_paginator.extend(response["Tags"]) + + assert tags_from_paginator == tags + + +@mock_sagemaker +def test_delete_tags_model_package_group(): + region_name = "eu-west-1" + model_package_group_name = "test-model-package-group" + client = boto3.client("sagemaker", region_name=region_name) + client.create_model_package_group( + ModelPackageGroupName=model_package_group_name, + ModelPackageGroupDescription="test-model-package-group-description", + ) + + tags = [] + for _ in range(80): + tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"}) + + resource_arn = ( + f"arn:aws:sagemaker:{region_name}:123456789012" + f":model-package-group/{model_package_group_name}" + ) + _ = client.add_tags(ResourceArn=resource_arn, Tags=tags) + + delete_tag_keys = [tag["Key"] for tag in tags[:20]] + _ = client.delete_tags(ResourceArn=resource_arn, TagKeys=delete_tag_keys) + + paginator = client.get_paginator("list_tags") + response_iterator = paginator.paginate(ResourceArn=resource_arn) + remaining_tags = [] + for response in response_iterator: + remaining_tags.extend(response["Tags"]) + assert remaining_tags == tags[20:]