Sagemaker: Add describe_model_package_group (#6899)

This commit is contained in:
Michael French 2023-10-11 05:34:33 -05:00 committed by GitHub
parent f39bb45f2c
commit d4b8e07be8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 1 deletions

View File

@ -6380,7 +6380,7 @@
- [ ] describe_model_card_export_job - [ ] describe_model_card_export_job
- [ ] describe_model_explainability_job_definition - [ ] describe_model_explainability_job_definition
- [X] describe_model_package - [X] describe_model_package
- [ ] describe_model_package_group - [X] describe_model_package_group
- [ ] describe_model_quality_job_definition - [ ] describe_model_quality_job_definition
- [ ] describe_monitoring_schedule - [ ] describe_monitoring_schedule
- [ ] describe_notebook_instance - [ ] describe_notebook_instance

View File

@ -3260,6 +3260,16 @@ class SageMakerModelBackend(BaseBackend):
) )
return model_package_group_summary_list return model_package_group_summary_list
def describe_model_package_group(
self, model_package_group_name: str
) -> ModelPackageGroup:
model_package_group = self.model_package_groups.get(model_package_group_name)
if model_package_group is None:
raise ValidationError(
f"Model package group {model_package_group_name} not found"
)
return model_package_group
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_model_packages( # type: ignore[misc] def list_model_packages( # type: ignore[misc]
self, self,

View File

@ -874,6 +874,15 @@ class SageMakerResponse(BaseResponse):
model_package.gen_response_object(), model_package.gen_response_object(),
) )
def describe_model_package_group(self) -> str:
model_package_group_name = self._get_param("ModelPackageGroupName")
model_package_group = self.sagemaker_backend.describe_model_package_group(
model_package_group_name=model_package_group_name,
)
return json.dumps(
model_package_group.gen_response_object(),
)
def update_model_package(self) -> str: def update_model_package(self) -> str:
model_package_arn = self._get_param("ModelPackageArn") model_package_arn = self._get_param("ModelPackageArn")
model_approval_status = self._get_param("ModelApprovalStatus") model_approval_status = self._get_param("ModelApprovalStatus")

View File

@ -1,8 +1,10 @@
"""Unit tests for sagemaker-supported APIs.""" """Unit tests for sagemaker-supported APIs."""
from unittest import SkipTest from unittest import SkipTest
from datetime import datetime
import boto3 import boto3
from freezegun import freeze_time from freezegun import freeze_time
from dateutil.tz import tzutc # type: ignore
from moto import mock_sagemaker, settings from moto import mock_sagemaker, settings
@ -163,3 +165,28 @@ def test_list_model_package_groups_sort_order():
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"] resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-1" == "test-model-package-group-1"
) )
@mock_sagemaker
def test_describe_model_package_group():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
with freeze_time("2020-01-01 00:00:00"):
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupDescription="test-model-package-group-description",
)
resp = client.describe_model_package_group(
ModelPackageGroupName="test-model-package-group"
)
assert resp["ModelPackageGroupName"] == "test-model-package-group"
assert (
resp["ModelPackageGroupDescription"] == "test-model-package-group-description"
)
assert (
resp["ModelPackageGroupArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package-group/test-model-package-group"
)
assert resp["ModelPackageGroupStatus"] == "Completed"
assert resp["CreationTime"] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=tzutc())