diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 34e5703b0..c9302c490 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6380,7 +6380,7 @@ - [ ] describe_model_card_export_job - [ ] describe_model_explainability_job_definition - [X] describe_model_package -- [ ] describe_model_package_group +- [X] describe_model_package_group - [ ] describe_model_quality_job_definition - [ ] describe_monitoring_schedule - [ ] describe_notebook_instance diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 8a36e33ee..208c0f09c 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -3260,6 +3260,16 @@ class SageMakerModelBackend(BaseBackend): ) return model_package_group_summary_list + def describe_model_package_group( + self, model_package_group_name: str + ) -> ModelPackageGroup: + model_package_group = self.model_package_groups.get(model_package_group_name) + if model_package_group is None: + raise ValidationError( + f"Model package group {model_package_group_name} not found" + ) + return model_package_group + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] def list_model_packages( # type: ignore[misc] self, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index acddec2c8..f490296d7 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -874,6 +874,15 @@ class SageMakerResponse(BaseResponse): model_package.gen_response_object(), ) + def describe_model_package_group(self) -> str: + model_package_group_name = self._get_param("ModelPackageGroupName") + model_package_group = self.sagemaker_backend.describe_model_package_group( + model_package_group_name=model_package_group_name, + ) + return json.dumps( + model_package_group.gen_response_object(), + ) + def update_model_package(self) -> str: model_package_arn = self._get_param("ModelPackageArn") model_approval_status = self._get_param("ModelApprovalStatus") diff --git a/tests/test_sagemaker/test_sagemaker_model_package_groups.py b/tests/test_sagemaker/test_sagemaker_model_package_groups.py index 49a1d99fb..147975064 100644 --- a/tests/test_sagemaker/test_sagemaker_model_package_groups.py +++ b/tests/test_sagemaker/test_sagemaker_model_package_groups.py @@ -1,8 +1,10 @@ """Unit tests for sagemaker-supported APIs.""" from unittest import SkipTest +from datetime import datetime import boto3 from freezegun import freeze_time +from dateutil.tz import tzutc # type: ignore from moto import mock_sagemaker, settings @@ -163,3 +165,28 @@ def test_list_model_package_groups_sort_order(): resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] == "test-model-package-group-1" ) + + +@mock_sagemaker +def test_describe_model_package_group(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + with freeze_time("2020-01-01 00:00:00"): + client.create_model_package_group( + ModelPackageGroupName="test-model-package-group", + ModelPackageGroupDescription="test-model-package-group-description", + ) + resp = client.describe_model_package_group( + ModelPackageGroupName="test-model-package-group" + ) + assert resp["ModelPackageGroupName"] == "test-model-package-group" + assert ( + resp["ModelPackageGroupDescription"] == "test-model-package-group-description" + ) + assert ( + resp["ModelPackageGroupArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package-group/test-model-package-group" + ) + assert resp["ModelPackageGroupStatus"] == "Completed" + assert resp["CreationTime"] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=tzutc())