Sagemaker model: support versioned models (#7165)

This commit is contained in:
Guilherme de Amorim 2024-01-01 11:25:37 -03:00 committed by GitHub
parent c60b3f03ab
commit 44dcd14088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 32 deletions

View File

@ -3,7 +3,7 @@ import os
import random
import string
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union, cast
from dateutil.tz import tzutc
@ -1004,6 +1004,7 @@ class ModelPackage(BaseObject):
client_token: str,
region_name: str,
account_id: str,
model_package_type: str,
tags: Optional[List[Dict[str, str]]] = None,
) -> None:
fake_user_profile_name = "fake-user-profile-name"
@ -1032,6 +1033,7 @@ class ModelPackage(BaseObject):
self.inference_specification = inference_specification
self.source_algorithm_specification = source_algorithm_specification
self.validation_specification = validation_specification
self.model_package_type = model_package_type
self.model_package_status_details = {
"ValidationStatuses": [
{
@ -1105,6 +1107,10 @@ class ModelPackage(BaseObject):
"AdditionalInferenceSpecifications",
"SkipModelValidation",
]
if self.model_package_type == "Versioned":
del response_object["ModelPackageName"]
elif self.model_package_type == "Unversioned":
del response_object["ModelPackageGroupName"]
return {
k: v
for k, v in response_object.items()
@ -3295,6 +3301,8 @@ class SageMakerModelBackend(BaseBackend):
)
if model_package_group_name is not None:
model_package_type = "Versioned"
if model_package_group_name.startswith("arn:aws"):
model_package_group_name = model_package_group_name.split("/")[-1]
model_package_summary_list = list(
filter(
lambda x: (
@ -3377,7 +3385,7 @@ class SageMakerModelBackend(BaseBackend):
def create_model_package(
self,
model_package_name: str,
model_package_name: Optional[str],
model_package_group_name: Optional[str],
model_package_description: Optional[str],
inference_specification: Any,
@ -3397,15 +3405,32 @@ class SageMakerModelBackend(BaseBackend):
additional_inference_specifications: Any,
) -> str:
model_package_version = None
if model_package_group_name is not None:
if model_package_group_name and model_package_name:
raise AWSValidationException(
"An error occurred (ValidationException) when calling the CreateModelPackage operation: Both ModelPackageName and ModelPackageGroupName are provided in the input. Cannot determine which one to use."
)
elif not model_package_group_name and not model_package_name:
raise AWSValidationException(
"An error ocurred (ValidationException) when calling the CreateModelPackag operation: Missing ARN."
)
elif model_package_group_name:
model_package_type = "Versioned"
model_package_name = model_package_group_name
model_packages_for_group = [
x
for x in self.model_packages.values()
if x.model_package_group_name == model_package_group_name
]
if model_package_group_name not in self.model_package_groups:
raise AWSValidationException(
"An error ocurred (ValidationException) when calling the CreateModelPackage operation: Model Package Group does not exist."
)
model_package_version = len(model_packages_for_group) + 1
else:
model_package_type = "Unversioned"
model_package = ModelPackage(
model_package_name=model_package_name,
model_package_name=cast(str, model_package_name),
model_package_group_name=model_package_group_name,
model_package_description=model_package_description,
inference_specification=inference_specification,
@ -3427,6 +3452,7 @@ class SageMakerModelBackend(BaseBackend):
region_name=self.region_name,
account_id=self.account_id,
client_token=client_token,
model_package_type=model_package_type,
)
self.model_package_name_mapping[
model_package.model_package_name

View File

@ -4,6 +4,7 @@ from unittest import SkipTest, TestCase
import boto3
import pytest
from botocore.exceptions import ClientError
from dateutil.tz import tzutc # type: ignore
from freezegun import freeze_time
@ -133,19 +134,17 @@ def test_list_model_packages_approval_status():
def test_list_model_packages_model_package_group_name():
client = boto3.client("sagemaker", region_name="eu-west-1")
group1 = "test-model-package-group"
client.create_model_package_group(
ModelPackageGroupName=group1,
ModelPackageGroupDescription="test-model-package-description",
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName=group1,
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName=group1,
)
client.create_model_package(
@ -170,8 +169,8 @@ def test_list_model_packages_model_package_group_name():
@mock_sagemaker
def test_list_model_packages_model_package_type():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group",
)
@ -233,18 +232,16 @@ def test_describe_model_package_default():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
with freeze_time("2015-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package",
model_package_arn = client.create_model_package(
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["ModelPackageName"] == "test-model-package"
)["ModelPackageArn"]
resp = client.describe_model_package(ModelPackageName=model_package_arn)
assert resp["ModelPackageGroupName"] == "test-model-package-group"
assert resp["ModelPackageDescription"] == "test-model-package-description"
assert (
resp["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1"
)
assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc())
assert (
@ -257,13 +254,13 @@ def test_describe_model_package_default():
assert resp.get("ModelPackageStatusDetails") is not None
assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [
{
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1",
"Status": "Completed",
}
]
assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [
{
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1",
"Status": "Completed",
}
]
@ -273,10 +270,8 @@ def test_describe_model_package_default():
@mock_sagemaker
def test_describe_model_package_with_create_model_package_arguments():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
ModelApprovalStatus="PendingManualApproval",
MetadataProperties={
@ -303,7 +298,6 @@ def test_update_model_package():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
CustomerMetadataProperties={
@ -317,7 +311,7 @@ def test_update_model_package():
CustomerMetadataProperties={"test-key": "test-value"},
CustomerMetadataPropertiesToRemove=["test-key-to-remove"],
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
resp = client.describe_model_package(ModelPackageName=model_package_arn)
assert resp["ModelApprovalStatus"] == "Approved"
assert resp["ApprovalDescription"] == "test-approval-description"
assert resp["CustomerMetadataProperties"] is not None
@ -330,7 +324,6 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)["ModelPackageArn"]
@ -366,7 +359,7 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
additional_inference_specifications_to_add
],
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
resp = client.describe_model_package(ModelPackageName=model_package_arn)
assert resp["AdditionalInferenceSpecifications"] is not None
TestCase().assertDictEqual(
resp["AdditionalInferenceSpecifications"][0],
@ -375,13 +368,12 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
@mock_sagemaker
def test_update_model_package_shoudl_update_last_modified_information():
def test_update_model_package_should_update_last_modified_information():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
)["ModelPackageArn"]
with freeze_time("2020-01-01 12:00:00"):
@ -389,7 +381,7 @@ def test_update_model_package_shoudl_update_last_modified_information():
ModelPackageArn=model_package_arn,
ModelApprovalStatus="Approved",
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
resp = client.describe_model_package(ModelPackageName=model_package_arn)
assert resp.get("LastModifiedTime") is not None
assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
assert resp.get("LastModifiedBy") is not None
@ -430,6 +422,42 @@ def test_create_model_package():
)
@mock_sagemaker
def test_create_model_package_should_raise_error_when_model_package_group_name_and_model_package_group_name_are_not_provided():
client = boto3.client("sagemaker", region_name="eu-west-1")
with pytest.raises(ClientError) as exc:
_ = client.create_model_package(
ModelPackageDescription="test-model-package-description",
)
assert "Missing ARN." in str(exc.value)
@mock_sagemaker
def test_create_model_package_should_raise_error_when_package_name_and_group_are_provided():
client = boto3.client("sagemaker", region_name="eu-west-1")
with pytest.raises(ClientError) as exc:
_ = client.create_model_package(
ModelPackageGroupName="TestModelPackageGroup",
ModelPackageName="TestModelPackage",
ModelPackageDescription="test-model-package-description",
)
assert (
"Both ModelPackageName and ModelPackageGroupName are provided in the input."
in str(exc.value)
)
@mock_sagemaker
def test_create_model_package_should_raise_error_when_model_package_group_provided_not_exist():
client = boto3.client("sagemaker", region_name="eu-west-1")
with pytest.raises(ClientError) as exc:
_ = client.create_model_package(
ModelPackageGroupName="TestModelPackageGroupNotExist",
ModelPackageDescription="test-model-package-description",
)
assert "Model Package Group does not exist." in str(exc.value)
@pytest.mark.parametrize(
"model_approval_status",
["Approved", "Rejected", "PendingManualApproval"],
@ -456,20 +484,18 @@ def test_create_model_package_in_model_package_group():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
resp_version_1 = client.create_model_package(
ModelPackageName="TestModelPackage",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
resp_version_2 = client.create_model_package(
ModelPackageName="TestModelPackage",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
assert (
resp_version_1["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/1"
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1"
)
assert (
resp_version_2["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/2"
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/2"
)