feat: Add sagemaker list_model_package_groups and fix versionned model packages (#6847)

This commit is contained in:
HALLOUARD 2023-09-26 10:43:27 +02:00 committed by GitHub
parent 0cbde05795
commit c1bbae3604
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 304 additions and 18 deletions

View File

@ -6413,7 +6413,7 @@
- [ ] list_model_cards - [ ] list_model_cards
- [ ] list_model_explainability_job_definitions - [ ] list_model_explainability_job_definitions
- [ ] list_model_metadata - [ ] list_model_metadata
- [ ] list_model_package_groups - [X] list_model_package_groups
- [X] list_model_packages - [X] list_model_packages
- [ ] list_model_quality_job_definitions - [ ] list_model_quality_job_definitions
- [X] list_models - [X] list_models

View File

@ -48,6 +48,12 @@ PAGINATION_MODEL = {
"limit_default": 50, "limit_default": 50,
"unique_attribute": "Key", "unique_attribute": "Key",
}, },
"list_model_package_groups": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageGroupArn",
},
"list_model_packages": { "list_model_packages": {
"input_token": "next_token", "input_token": "next_token",
"limit_key": "max_results", "limit_key": "max_results",
@ -943,10 +949,11 @@ class ModelPackageGroup(BaseObject):
account_id=account_id, account_id=account_id,
region_name=region_name, region_name=region_name,
) )
datetime_now = datetime.now(tzutc())
self.model_package_group_name = model_package_group_name self.model_package_group_name = model_package_group_name
self.model_package_group_arn = model_package_group_arn self.model_package_group_arn = model_package_group_arn
self.model_package_group_description = model_package_group_description self.model_package_group_description = model_package_group_description
self.creation_time = datetime.now() self.creation_time = datetime_now
self.created_by = { self.created_by = {
"UserProfileArn": fake_user_profile_arn, "UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name, "UserProfileName": fake_user_profile_name,
@ -955,6 +962,20 @@ class ModelPackageGroup(BaseObject):
self.model_package_group_status = "Completed" self.model_package_group_status = "Completed"
self.tags = tags self.tags = tags
def gen_response_object(self) -> Dict[str, Any]:
response_object = super().gen_response_object()
for k, v in response_object.items():
if isinstance(v, datetime):
response_object[k] = v.isoformat()
response_values = [
"ModelPackageGroupName",
"ModelPackageGroupArn",
"ModelPackageGroupDescription",
"CreationTime",
"ModelPackageGroupStatus",
]
return {k: v for k, v in response_object.items() if k in response_values}
class ModelPackage(BaseObject): class ModelPackage(BaseObject):
def __init__( def __init__(
@ -994,7 +1015,9 @@ class ModelPackage(BaseObject):
region_name=region_name, region_name=region_name,
account_id=account_id, account_id=account_id,
_type="model-package", _type="model-package",
_id=model_package_name, _id=f"{model_package_name}/{model_package_version}"
if model_package_version
else model_package_name,
) )
datetime_now = datetime.now(tzutc()) datetime_now = datetime.now(tzutc())
self.model_package_name = model_package_name self.model_package_name = model_package_name
@ -2912,6 +2935,53 @@ class SageMakerModelBackend(BaseBackend):
return True return True
raise ValueError(f"Invalid model package type: {model_package_type}") raise ValueError(f"Invalid model package type: {model_package_type}")
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_model_package_groups( # type: ignore[misc]
self,
creation_time_after: Optional[int],
creation_time_before: Optional[int],
name_contains: Optional[str],
sort_by: Optional[str],
sort_order: Optional[str],
) -> List[ModelPackageGroup]:
if isinstance(creation_time_before, int):
creation_time_before_datetime = datetime.fromtimestamp(
creation_time_before, tz=tzutc()
)
if isinstance(creation_time_after, int):
creation_time_after_datetime = datetime.fromtimestamp(
creation_time_after, tz=tzutc()
)
model_package_group_summary_list = list(
filter(
lambda x: (
creation_time_after is None
or x.creation_time > creation_time_after_datetime
)
and (
creation_time_before is None
or x.creation_time < creation_time_before_datetime
)
and (
name_contains is None
or x.model_package_group_name.find(name_contains) != -1
),
self.model_package_groups.values(),
)
)
model_package_group_summary_list = list(
sorted(
model_package_group_summary_list,
key={
"Name": lambda x: x.model_package_group_name,
"CreationTime": lambda x: x.creation_time,
None: lambda x: x.creation_time,
}[sort_by],
reverse=sort_order == "Descending",
)
)
return model_package_group_summary_list
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_model_packages( # type: ignore[misc] def list_model_packages( # type: ignore[misc]
self, self,

View File

@ -799,6 +799,36 @@ class SageMakerResponse(BaseResponse):
) )
return 200, {}, json.dumps({"EndpointArn": endpoint_arn}) return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
def list_model_package_groups(self) -> str:
creation_time_after = self._get_param("CreationTimeAfter")
creation_time_before = self._get_param("CreationTimeBefore")
max_results = self._get_param("MaxResults")
name_contains = self._get_param("NameContains")
next_token = self._get_param("NextToken")
sort_by = self._get_param("SortBy")
sort_order = self._get_param("SortOrder")
(
model_package_group_summary_list,
next_token,
) = self.sagemaker_backend.list_model_package_groups(
creation_time_after=creation_time_after,
creation_time_before=creation_time_before,
max_results=max_results,
name_contains=name_contains,
next_token=next_token,
sort_by=sort_by,
sort_order=sort_order,
)
model_package_group_summary_list_response_object = [
x.gen_response_object() for x in model_package_group_summary_list
]
return json.dumps(
dict(
ModelPackageGroupSummaryList=model_package_group_summary_list_response_object,
NextToken=next_token,
)
)
def list_model_packages(self) -> str: def list_model_packages(self) -> str:
creation_time_after = self._get_param("CreationTimeAfter") creation_time_after = self._get_param("CreationTimeAfter")
creation_time_before = self._get_param("CreationTimeBefore") creation_time_before = self._get_param("CreationTimeBefore")

View File

@ -0,0 +1,165 @@
"""Unit tests for sagemaker-supported APIs."""
from unittest import SkipTest
import boto3
from freezegun import freeze_time
from moto import mock_sagemaker, settings
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@mock_sagemaker
def test_create_model_package_group():
client = boto3.client("sagemaker", region_name="us-east-2")
resp = client.create_model_package_group(
ModelPackageGroupName="test-model-package-group",
ModelPackageGroupDescription="test-model-package-group-description",
Tags=[
{"Key": "test-key", "Value": "test-value"},
],
)
assert (
resp["ModelPackageGroupArn"]
== "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group"
)
@mock_sagemaker
def test_list_model_package_groups():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups()
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
== "test-model-package-group-1"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][0]
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupDescription"]
== "test-model-package-group-description-1"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
assert "ModelPackageGroupDescription" in resp["ModelPackageGroupSummaryList"][1]
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupDescription"]
== "test-model-package-group-description-2"
)
@mock_sagemaker
def test_list_model_package_groups_creation_time_before():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
with freeze_time("2020-01-01 00:00:00"):
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
with freeze_time("2021-01-01 00:00:00"):
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups(CreationTimeBefore="2020-01-01T02:00:00Z")
assert len(resp["ModelPackageGroupSummaryList"]) == 1
@mock_sagemaker
def test_list_model_package_groups_creation_time_after():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
with freeze_time("2020-01-01 00:00:00"):
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
with freeze_time("2021-01-01 00:00:00"):
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups(CreationTimeAfter="2020-01-02T00:00:00Z")
assert len(resp["ModelPackageGroupSummaryList"]) == 1
@mock_sagemaker
def test_list_model_package_groups_name_contains():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
client.create_model_package_group(
ModelPackageGroupName="another-model-package-group",
ModelPackageGroupDescription="another-model-package-group-description",
)
resp = client.list_model_package_groups(NameContains="test-model-package")
assert len(resp["ModelPackageGroupSummaryList"]) == 2
@mock_sagemaker
def test_list_model_package_groups_sort_by():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups(SortBy="CreationTime")
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
== "test-model-package-group-1"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
@mock_sagemaker
def test_list_model_package_groups_sort_order():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-1",
ModelPackageGroupDescription="test-model-package-group-description-1",
)
client.create_model_package_group(
ModelPackageGroupName="test-model-package-group-2",
ModelPackageGroupDescription="test-model-package-group-description-2",
)
resp = client.list_model_package_groups(SortOrder="Descending")
assert (
resp["ModelPackageGroupSummaryList"][0]["ModelPackageGroupName"]
== "test-model-package-group-2"
)
assert (
resp["ModelPackageGroupSummaryList"][1]["ModelPackageGroupName"]
== "test-model-package-group-1"
)

View File

@ -15,11 +15,15 @@ def test_list_model_packages():
client = boto3.client("sagemaker", region_name="eu-west-1") client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package", ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description", ModelPackageDescription="test-model-package-description-v1",
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-v2",
) )
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package-2", ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2", ModelPackageDescription="test-model-package-description-v1-2",
) )
resp = client.list_model_packages() resp = client.list_model_packages()
@ -29,7 +33,7 @@ def test_list_model_packages():
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0] assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0]
assert ( assert (
resp["ModelPackageSummaryList"][0]["ModelPackageDescription"] resp["ModelPackageSummaryList"][0]["ModelPackageDescription"]
== "test-model-package-description" == "test-model-package-description-v2"
) )
assert ( assert (
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2" resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
@ -37,7 +41,7 @@ def test_list_model_packages():
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1] assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1]
assert ( assert (
resp["ModelPackageSummaryList"][1]["ModelPackageDescription"] resp["ModelPackageSummaryList"][1]["ModelPackageDescription"]
== "test-model-package-description-2" == "test-model-package-description-v1-2"
) )
@ -128,13 +132,22 @@ def test_list_model_packages_model_package_group_name():
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName="test-model-package-group",
) )
client.create_model_package( client.create_model_package(
ModelPackageName="test-model-package-2", ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2", ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName="test-model-package-group",
) )
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName="test-model-package-group",
)
client.create_model_package(
ModelPackageName="test-model-package-without-group",
ModelPackageDescription="test-model-package-description-without-group",
)
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group") resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
assert len(resp["ModelPackageSummaryList"]) == 2 assert len(resp["ModelPackageSummaryList"]) == 3
@mock_sagemaker @mock_sagemaker
@ -222,16 +235,24 @@ def test_create_model_package():
@mock_sagemaker @mock_sagemaker
def test_create_model_package_group(): def test_create_model_package_in_model_package_group():
client = boto3.client("sagemaker", region_name="us-east-2") client = boto3.client("sagemaker", region_name="eu-west-1")
resp = client.create_model_package_group( client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
resp_version_1 = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group", ModelPackageGroupName="test-model-package-group",
ModelPackageGroupDescription="test-model-package-group-description", ModelPackageDescription="test-model-package-description",
Tags=[ )
{"Key": "test-key", "Value": "test-value"}, resp_version_2 = client.create_model_package(
], ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
) )
assert ( assert (
resp["ModelPackageGroupArn"] resp_version_1["ModelPackageArn"]
== "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group" == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
)
assert (
resp_version_2["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/2"
) )