Sagemaker: Add describe_model_package_group (#6899)
This commit is contained in:
parent
f39bb45f2c
commit
d4b8e07be8
@ -6380,7 +6380,7 @@
|
||||
- [ ] describe_model_card_export_job
|
||||
- [ ] describe_model_explainability_job_definition
|
||||
- [X] describe_model_package
|
||||
- [ ] describe_model_package_group
|
||||
- [X] describe_model_package_group
|
||||
- [ ] describe_model_quality_job_definition
|
||||
- [ ] describe_monitoring_schedule
|
||||
- [ ] describe_notebook_instance
|
||||
|
@ -3260,6 +3260,16 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
return model_package_group_summary_list
|
||||
|
||||
def describe_model_package_group(
|
||||
self, model_package_group_name: str
|
||||
) -> ModelPackageGroup:
|
||||
model_package_group = self.model_package_groups.get(model_package_group_name)
|
||||
if model_package_group is None:
|
||||
raise ValidationError(
|
||||
f"Model package group {model_package_group_name} not found"
|
||||
)
|
||||
return model_package_group
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
|
||||
def list_model_packages( # type: ignore[misc]
|
||||
self,
|
||||
|
@ -874,6 +874,15 @@ class SageMakerResponse(BaseResponse):
|
||||
model_package.gen_response_object(),
|
||||
)
|
||||
|
||||
def describe_model_package_group(self) -> str:
|
||||
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||
model_package_group = self.sagemaker_backend.describe_model_package_group(
|
||||
model_package_group_name=model_package_group_name,
|
||||
)
|
||||
return json.dumps(
|
||||
model_package_group.gen_response_object(),
|
||||
)
|
||||
|
||||
def update_model_package(self) -> str:
|
||||
model_package_arn = self._get_param("ModelPackageArn")
|
||||
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||
|
@ -1,8 +1,10 @@
|
||||
"""Unit tests for sagemaker-supported APIs."""
|
||||
from unittest import SkipTest
|
||||
from datetime import datetime
|
||||
|
||||
import boto3
|
||||
from freezegun import freeze_time
|
||||
from dateutil.tz import tzutc # type: ignore
|
||||
|
||||
from moto import mock_sagemaker, settings
|
||||
|
||||
@ -163,3 +165,28 @@ def test_list_model_package_groups_sort_order():
|
||||
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
|
||||
== "test-model-package-group-1"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model_package_group():
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest("Can't freeze time in ServerMode")
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
with freeze_time("2020-01-01 00:00:00"):
|
||||
client.create_model_package_group(
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageGroupDescription="test-model-package-group-description",
|
||||
)
|
||||
resp = client.describe_model_package_group(
|
||||
ModelPackageGroupName="test-model-package-group"
|
||||
)
|
||||
assert resp["ModelPackageGroupName"] == "test-model-package-group"
|
||||
assert (
|
||||
resp["ModelPackageGroupDescription"] == "test-model-package-group-description"
|
||||
)
|
||||
assert (
|
||||
resp["ModelPackageGroupArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package-group/test-model-package-group"
|
||||
)
|
||||
assert resp["ModelPackageGroupStatus"] == "Completed"
|
||||
assert resp["CreationTime"] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=tzutc())
|
||||
|
Loading…
Reference in New Issue
Block a user