feat: Add sagemaker list_model_package_groups and fix versionned model packages (#6847)
This commit is contained in:
parent
0cbde05795
commit
c1bbae3604
@ -6413,7 +6413,7 @@
|
|||||||
- [ ] list_model_cards
|
- [ ] list_model_cards
|
||||||
- [ ] list_model_explainability_job_definitions
|
- [ ] list_model_explainability_job_definitions
|
||||||
- [ ] list_model_metadata
|
- [ ] list_model_metadata
|
||||||
- [ ] list_model_package_groups
|
- [X] list_model_package_groups
|
||||||
- [X] list_model_packages
|
- [X] list_model_packages
|
||||||
- [ ] list_model_quality_job_definitions
|
- [ ] list_model_quality_job_definitions
|
||||||
- [X] list_models
|
- [X] list_models
|
||||||
|
@ -48,6 +48,12 @@ PAGINATION_MODEL = {
|
|||||||
"limit_default": 50,
|
"limit_default": 50,
|
||||||
"unique_attribute": "Key",
|
"unique_attribute": "Key",
|
||||||
},
|
},
|
||||||
|
"list_model_package_groups": {
|
||||||
|
"input_token": "next_token",
|
||||||
|
"limit_key": "max_results",
|
||||||
|
"limit_default": 100,
|
||||||
|
"unique_attribute": "ModelPackageGroupArn",
|
||||||
|
},
|
||||||
"list_model_packages": {
|
"list_model_packages": {
|
||||||
"input_token": "next_token",
|
"input_token": "next_token",
|
||||||
"limit_key": "max_results",
|
"limit_key": "max_results",
|
||||||
@ -943,10 +949,11 @@ class ModelPackageGroup(BaseObject):
|
|||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
)
|
)
|
||||||
|
datetime_now = datetime.now(tzutc())
|
||||||
self.model_package_group_name = model_package_group_name
|
self.model_package_group_name = model_package_group_name
|
||||||
self.model_package_group_arn = model_package_group_arn
|
self.model_package_group_arn = model_package_group_arn
|
||||||
self.model_package_group_description = model_package_group_description
|
self.model_package_group_description = model_package_group_description
|
||||||
self.creation_time = datetime.now()
|
self.creation_time = datetime_now
|
||||||
self.created_by = {
|
self.created_by = {
|
||||||
"UserProfileArn": fake_user_profile_arn,
|
"UserProfileArn": fake_user_profile_arn,
|
||||||
"UserProfileName": fake_user_profile_name,
|
"UserProfileName": fake_user_profile_name,
|
||||||
@ -955,6 +962,20 @@ class ModelPackageGroup(BaseObject):
|
|||||||
self.model_package_group_status = "Completed"
|
self.model_package_group_status = "Completed"
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
|
|
||||||
|
def gen_response_object(self) -> Dict[str, Any]:
|
||||||
|
response_object = super().gen_response_object()
|
||||||
|
for k, v in response_object.items():
|
||||||
|
if isinstance(v, datetime):
|
||||||
|
response_object[k] = v.isoformat()
|
||||||
|
response_values = [
|
||||||
|
"ModelPackageGroupName",
|
||||||
|
"ModelPackageGroupArn",
|
||||||
|
"ModelPackageGroupDescription",
|
||||||
|
"CreationTime",
|
||||||
|
"ModelPackageGroupStatus",
|
||||||
|
]
|
||||||
|
return {k: v for k, v in response_object.items() if k in response_values}
|
||||||
|
|
||||||
|
|
||||||
class ModelPackage(BaseObject):
|
class ModelPackage(BaseObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -994,7 +1015,9 @@ class ModelPackage(BaseObject):
|
|||||||
region_name=region_name,
|
region_name=region_name,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
_type="model-package",
|
_type="model-package",
|
||||||
_id=model_package_name,
|
_id=f"{model_package_name}/{model_package_version}"
|
||||||
|
if model_package_version
|
||||||
|
else model_package_name,
|
||||||
)
|
)
|
||||||
datetime_now = datetime.now(tzutc())
|
datetime_now = datetime.now(tzutc())
|
||||||
self.model_package_name = model_package_name
|
self.model_package_name = model_package_name
|
||||||
@ -2912,6 +2935,53 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
return True
|
return True
|
||||||
raise ValueError(f"Invalid model package type: {model_package_type}")
|
raise ValueError(f"Invalid model package type: {model_package_type}")
|
||||||
|
|
||||||
|
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
|
||||||
|
def list_model_package_groups( # type: ignore[misc]
|
||||||
|
self,
|
||||||
|
creation_time_after: Optional[int],
|
||||||
|
creation_time_before: Optional[int],
|
||||||
|
name_contains: Optional[str],
|
||||||
|
sort_by: Optional[str],
|
||||||
|
sort_order: Optional[str],
|
||||||
|
) -> List[ModelPackageGroup]:
|
||||||
|
if isinstance(creation_time_before, int):
|
||||||
|
creation_time_before_datetime = datetime.fromtimestamp(
|
||||||
|
creation_time_before, tz=tzutc()
|
||||||
|
)
|
||||||
|
if isinstance(creation_time_after, int):
|
||||||
|
creation_time_after_datetime = datetime.fromtimestamp(
|
||||||
|
creation_time_after, tz=tzutc()
|
||||||
|
)
|
||||||
|
model_package_group_summary_list = list(
|
||||||
|
filter(
|
||||||
|
lambda x: (
|
||||||
|
creation_time_after is None
|
||||||
|
or x.creation_time > creation_time_after_datetime
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
creation_time_before is None
|
||||||
|
or x.creation_time < creation_time_before_datetime
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
name_contains is None
|
||||||
|
or x.model_package_group_name.find(name_contains) != -1
|
||||||
|
),
|
||||||
|
self.model_package_groups.values(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model_package_group_summary_list = list(
|
||||||
|
sorted(
|
||||||
|
model_package_group_summary_list,
|
||||||
|
key={
|
||||||
|
"Name": lambda x: x.model_package_group_name,
|
||||||
|
"CreationTime": lambda x: x.creation_time,
|
||||||
|
None: lambda x: x.creation_time,
|
||||||
|
}[sort_by],
|
||||||
|
reverse=sort_order == "Descending",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return model_package_group_summary_list
|
||||||
|
|
||||||
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
|
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
|
||||||
def list_model_packages( # type: ignore[misc]
|
def list_model_packages( # type: ignore[misc]
|
||||||
self,
|
self,
|
||||||
|
@ -799,6 +799,36 @@ class SageMakerResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
|
return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
|
||||||
|
|
||||||
|
def list_model_package_groups(self) -> str:
|
||||||
|
creation_time_after = self._get_param("CreationTimeAfter")
|
||||||
|
creation_time_before = self._get_param("CreationTimeBefore")
|
||||||
|
max_results = self._get_param("MaxResults")
|
||||||
|
name_contains = self._get_param("NameContains")
|
||||||
|
next_token = self._get_param("NextToken")
|
||||||
|
sort_by = self._get_param("SortBy")
|
||||||
|
sort_order = self._get_param("SortOrder")
|
||||||
|
(
|
||||||
|
model_package_group_summary_list,
|
||||||
|
next_token,
|
||||||
|
) = self.sagemaker_backend.list_model_package_groups(
|
||||||
|
creation_time_after=creation_time_after,
|
||||||
|
creation_time_before=creation_time_before,
|
||||||
|
max_results=max_results,
|
||||||
|
name_contains=name_contains,
|
||||||
|
next_token=next_token,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
)
|
||||||
|
model_package_group_summary_list_response_object = [
|
||||||
|
x.gen_response_object() for x in model_package_group_summary_list
|
||||||
|
]
|
||||||
|
return json.dumps(
|
||||||
|
dict(
|
||||||
|
ModelPackageGroupSummaryList=model_package_group_summary_list_response_object,
|
||||||
|
NextToken=next_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def list_model_packages(self) -> str:
|
def list_model_packages(self) -> str:
|
||||||
creation_time_after = self._get_param("CreationTimeAfter")
|
creation_time_after = self._get_param("CreationTimeAfter")
|
||||||
creation_time_before = self._get_param("CreationTimeBefore")
|
creation_time_before = self._get_param("CreationTimeBefore")
|
||||||
|
165
tests/test_sagemaker/test_sagemaker_model_package_groups.py
Normal file
165
tests/test_sagemaker/test_sagemaker_model_package_groups.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
"""Unit tests for sagemaker-supported APIs."""
|
||||||
|
from unittest import SkipTest
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
from moto import mock_sagemaker, settings
|
||||||
|
|
||||||
|
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||||
|
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create_model_package_group():
|
||||||
|
client = boto3.client("sagemaker", region_name="us-east-2")
|
||||||
|
resp = client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description",
|
||||||
|
Tags=[
|
||||||
|
{"Key": "test-key", "Value": "test-value"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupArn"]
|
||||||
|
== "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-1"
|
||||||
|
)
|
||||||
|
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0]
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"]
|
||||||
|
== "test-model-package-group-description-1"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-2"
|
||||||
|
)
|
||||||
|
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1]
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"]
|
||||||
|
== "test-model-package-group-description-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups_creation_time_before():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
with freeze_time("2020-01-01 00:00:00"):
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
with freeze_time("2021-01-01 00:00:00"):
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups(CreationTimeBefore="2020-01-01T02:00:00Z")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageGroupSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups_creation_time_after():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
with freeze_time("2020-01-01 00:00:00"):
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
with freeze_time("2021-01-01 00:00:00"):
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups(CreationTimeAfter="2020-01-02T00:00:00Z")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageGroupSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups_name_contains():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="another-model-package-group",
|
||||||
|
ModelPackageGroupDescription="another-model-package-group-description",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups(NameContains="test-model-package")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageGroupSummaryList"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups_sort_by():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups(SortBy="CreationTime")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-1"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_package_groups_sort_order():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-1",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-1",
|
||||||
|
)
|
||||||
|
client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group-2",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_package_groups(SortOrder="Descending")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-2"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
|
||||||
|
== "test-model-package-group-1"
|
||||||
|
)
|
@ -15,11 +15,15 @@ def test_list_model_packages():
|
|||||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
client.create_model_package(
|
client.create_model_package(
|
||||||
ModelPackageName="test-model-package",
|
ModelPackageName="test-model-package",
|
||||||
ModelPackageDescription="test-model-package-description",
|
ModelPackageDescription="test-model-package-description-v1",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description-v2",
|
||||||
)
|
)
|
||||||
client.create_model_package(
|
client.create_model_package(
|
||||||
ModelPackageName="test-model-package-2",
|
ModelPackageName="test-model-package-2",
|
||||||
ModelPackageDescription="test-model-package-description-2",
|
ModelPackageDescription="test-model-package-description-v1-2",
|
||||||
)
|
)
|
||||||
resp = client.list_model_packages()
|
resp = client.list_model_packages()
|
||||||
|
|
||||||
@ -29,7 +33,7 @@ def test_list_model_packages():
|
|||||||
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0]
|
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0]
|
||||||
assert (
|
assert (
|
||||||
resp["ModelPackageSummaryList"][0]["ModelPackageDescription"]
|
resp["ModelPackageSummaryList"][0]["ModelPackageDescription"]
|
||||||
== "test-model-package-description"
|
== "test-model-package-description-v2"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
|
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
|
||||||
@ -37,7 +41,7 @@ def test_list_model_packages():
|
|||||||
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1]
|
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1]
|
||||||
assert (
|
assert (
|
||||||
resp["ModelPackageSummaryList"][1]["ModelPackageDescription"]
|
resp["ModelPackageSummaryList"][1]["ModelPackageDescription"]
|
||||||
== "test-model-package-description-2"
|
== "test-model-package-description-v1-2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -128,13 +132,22 @@ def test_list_model_packages_model_package_group_name():
|
|||||||
ModelPackageGroupName="test-model-package-group",
|
ModelPackageGroupName="test-model-package-group",
|
||||||
)
|
)
|
||||||
client.create_model_package(
|
client.create_model_package(
|
||||||
ModelPackageName="test-model-package-2",
|
ModelPackageName="test-model-package",
|
||||||
ModelPackageDescription="test-model-package-description-2",
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
ModelPackageGroupName="test-model-package-group",
|
ModelPackageGroupName="test-model-package-group",
|
||||||
)
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-3",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-without-group",
|
||||||
|
ModelPackageDescription="test-model-package-description-without-group",
|
||||||
|
)
|
||||||
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
|
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
|
||||||
|
|
||||||
assert len(resp["ModelPackageSummaryList"]) == 2
|
assert len(resp["ModelPackageSummaryList"]) == 3
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
@ -222,16 +235,24 @@ def test_create_model_package():
|
|||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
def test_create_model_package_group():
|
def test_create_model_package_in_model_package_group():
|
||||||
client = boto3.client("sagemaker", region_name="us-east-2")
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
resp = client.create_model_package_group(
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
|
resp_version_1 = client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
ModelPackageGroupName="test-model-package-group",
|
ModelPackageGroupName="test-model-package-group",
|
||||||
ModelPackageGroupDescription="test-model-package-group-description",
|
ModelPackageDescription="test-model-package-description",
|
||||||
Tags=[
|
)
|
||||||
{"Key": "test-key", "Value": "test-value"},
|
resp_version_2 = client.create_model_package(
|
||||||
],
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
resp["ModelPackageGroupArn"]
|
resp_version_1["ModelPackageArn"]
|
||||||
== "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group"
|
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp_version_2["ModelPackageArn"]
|
||||||
|
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/2"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user