Sagemaker model: support versioned models (#7165)
This commit is contained in:
parent
c60b3f03ab
commit
44dcd14088
@ -3,7 +3,7 @@ import os
|
||||
import random
|
||||
import string
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union, cast
|
||||
|
||||
from dateutil.tz import tzutc
|
||||
|
||||
@ -1004,6 +1004,7 @@ class ModelPackage(BaseObject):
|
||||
client_token: str,
|
||||
region_name: str,
|
||||
account_id: str,
|
||||
model_package_type: str,
|
||||
tags: Optional[List[Dict[str, str]]] = None,
|
||||
) -> None:
|
||||
fake_user_profile_name = "fake-user-profile-name"
|
||||
@ -1032,6 +1033,7 @@ class ModelPackage(BaseObject):
|
||||
self.inference_specification = inference_specification
|
||||
self.source_algorithm_specification = source_algorithm_specification
|
||||
self.validation_specification = validation_specification
|
||||
self.model_package_type = model_package_type
|
||||
self.model_package_status_details = {
|
||||
"ValidationStatuses": [
|
||||
{
|
||||
@ -1105,6 +1107,10 @@ class ModelPackage(BaseObject):
|
||||
"AdditionalInferenceSpecifications",
|
||||
"SkipModelValidation",
|
||||
]
|
||||
if self.model_package_type == "Versioned":
|
||||
del response_object["ModelPackageName"]
|
||||
elif self.model_package_type == "Unversioned":
|
||||
del response_object["ModelPackageGroupName"]
|
||||
return {
|
||||
k: v
|
||||
for k, v in response_object.items()
|
||||
@ -3295,6 +3301,8 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
if model_package_group_name is not None:
|
||||
model_package_type = "Versioned"
|
||||
if model_package_group_name.startswith("arn:aws"):
|
||||
model_package_group_name = model_package_group_name.split("/")[-1]
|
||||
model_package_summary_list = list(
|
||||
filter(
|
||||
lambda x: (
|
||||
@ -3377,7 +3385,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
|
||||
def create_model_package(
|
||||
self,
|
||||
model_package_name: str,
|
||||
model_package_name: Optional[str],
|
||||
model_package_group_name: Optional[str],
|
||||
model_package_description: Optional[str],
|
||||
inference_specification: Any,
|
||||
@ -3397,15 +3405,32 @@ class SageMakerModelBackend(BaseBackend):
|
||||
additional_inference_specifications: Any,
|
||||
) -> str:
|
||||
model_package_version = None
|
||||
if model_package_group_name is not None:
|
||||
if model_package_group_name and model_package_name:
|
||||
raise AWSValidationException(
|
||||
"An error occurred (ValidationException) when calling the CreateModelPackage operation: Both ModelPackageName and ModelPackageGroupName are provided in the input. Cannot determine which one to use."
|
||||
)
|
||||
elif not model_package_group_name and not model_package_name:
|
||||
raise AWSValidationException(
|
||||
"An error ocurred (ValidationException) when calling the CreateModelPackag operation: Missing ARN."
|
||||
)
|
||||
elif model_package_group_name:
|
||||
model_package_type = "Versioned"
|
||||
model_package_name = model_package_group_name
|
||||
model_packages_for_group = [
|
||||
x
|
||||
for x in self.model_packages.values()
|
||||
if x.model_package_group_name == model_package_group_name
|
||||
]
|
||||
if model_package_group_name not in self.model_package_groups:
|
||||
raise AWSValidationException(
|
||||
"An error ocurred (ValidationException) when calling the CreateModelPackage operation: Model Package Group does not exist."
|
||||
)
|
||||
model_package_version = len(model_packages_for_group) + 1
|
||||
else:
|
||||
model_package_type = "Unversioned"
|
||||
|
||||
model_package = ModelPackage(
|
||||
model_package_name=model_package_name,
|
||||
model_package_name=cast(str, model_package_name),
|
||||
model_package_group_name=model_package_group_name,
|
||||
model_package_description=model_package_description,
|
||||
inference_specification=inference_specification,
|
||||
@ -3427,6 +3452,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
region_name=self.region_name,
|
||||
account_id=self.account_id,
|
||||
client_token=client_token,
|
||||
model_package_type=model_package_type,
|
||||
)
|
||||
self.model_package_name_mapping[
|
||||
model_package.model_package_name
|
||||
|
@ -4,6 +4,7 @@ from unittest import SkipTest, TestCase
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from dateutil.tz import tzutc # type: ignore
|
||||
from freezegun import freeze_time
|
||||
|
||||
@ -133,19 +134,17 @@ def test_list_model_packages_approval_status():
|
||||
def test_list_model_packages_model_package_group_name():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
group1 = "test-model-package-group"
|
||||
client.create_model_package_group(
|
||||
ModelPackageGroupName=group1,
|
||||
ModelPackageGroupDescription="test-model-package-description",
|
||||
)
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
ModelPackageGroupName=group1,
|
||||
)
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageDescription="test-model-package-description-2",
|
||||
ModelPackageGroupName=group1,
|
||||
)
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package-2",
|
||||
ModelPackageDescription="test-model-package-description-3",
|
||||
ModelPackageGroupName=group1,
|
||||
)
|
||||
client.create_model_package(
|
||||
@ -170,8 +169,8 @@ def test_list_model_packages_model_package_group_name():
|
||||
@mock_sagemaker
|
||||
def test_list_model_packages_model_package_type():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
)
|
||||
@ -233,18 +232,16 @@ def test_describe_model_package_default():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
with freeze_time("2015-01-01 00:00:00"):
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp["ModelPackageName"] == "test-model-package"
|
||||
)["ModelPackageArn"]
|
||||
resp = client.describe_model_package(ModelPackageName=model_package_arn)
|
||||
assert resp["ModelPackageGroupName"] == "test-model-package-group"
|
||||
assert resp["ModelPackageDescription"] == "test-model-package-description"
|
||||
assert (
|
||||
resp["ModelPackageArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1"
|
||||
)
|
||||
assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc())
|
||||
assert (
|
||||
@ -257,13 +254,13 @@ def test_describe_model_package_default():
|
||||
assert resp.get("ModelPackageStatusDetails") is not None
|
||||
assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [
|
||||
{
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1",
|
||||
"Status": "Completed",
|
||||
}
|
||||
]
|
||||
assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [
|
||||
{
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1",
|
||||
"Status": "Completed",
|
||||
}
|
||||
]
|
||||
@ -273,10 +270,8 @@ def test_describe_model_package_default():
|
||||
@mock_sagemaker
|
||||
def test_describe_model_package_with_create_model_package_arguments():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
ModelApprovalStatus="PendingManualApproval",
|
||||
MetadataProperties={
|
||||
@ -303,7 +298,6 @@ def test_update_model_package():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
CustomerMetadataProperties={
|
||||
@ -317,7 +311,7 @@ def test_update_model_package():
|
||||
CustomerMetadataProperties={"test-key": "test-value"},
|
||||
CustomerMetadataPropertiesToRemove=["test-key-to-remove"],
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
resp = client.describe_model_package(ModelPackageName=model_package_arn)
|
||||
assert resp["ModelApprovalStatus"] == "Approved"
|
||||
assert resp["ApprovalDescription"] == "test-approval-description"
|
||||
assert resp["CustomerMetadataProperties"] is not None
|
||||
@ -330,7 +324,6 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)["ModelPackageArn"]
|
||||
@ -366,7 +359,7 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
|
||||
additional_inference_specifications_to_add
|
||||
],
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
resp = client.describe_model_package(ModelPackageName=model_package_arn)
|
||||
assert resp["AdditionalInferenceSpecifications"] is not None
|
||||
TestCase().assertDictEqual(
|
||||
resp["AdditionalInferenceSpecifications"][0],
|
||||
@ -375,13 +368,12 @@ def test_update_model_package_given_additional_inference_specifications_to_add()
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_model_package_shoudl_update_last_modified_information():
|
||||
def test_update_model_package_should_update_last_modified_information():
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest("Can't freeze time in ServerMode")
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
)["ModelPackageArn"]
|
||||
with freeze_time("2020-01-01 12:00:00"):
|
||||
@ -389,7 +381,7 @@ def test_update_model_package_shoudl_update_last_modified_information():
|
||||
ModelPackageArn=model_package_arn,
|
||||
ModelApprovalStatus="Approved",
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
resp = client.describe_model_package(ModelPackageName=model_package_arn)
|
||||
assert resp.get("LastModifiedTime") is not None
|
||||
assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
|
||||
assert resp.get("LastModifiedBy") is not None
|
||||
@ -430,6 +422,42 @@ def test_create_model_package():
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model_package_should_raise_error_when_model_package_group_name_and_model_package_group_name_are_not_provided():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
with pytest.raises(ClientError) as exc:
|
||||
_ = client.create_model_package(
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
assert "Missing ARN." in str(exc.value)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model_package_should_raise_error_when_package_name_and_group_are_provided():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
with pytest.raises(ClientError) as exc:
|
||||
_ = client.create_model_package(
|
||||
ModelPackageGroupName="TestModelPackageGroup",
|
||||
ModelPackageName="TestModelPackage",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
assert (
|
||||
"Both ModelPackageName and ModelPackageGroupName are provided in the input."
|
||||
in str(exc.value)
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model_package_should_raise_error_when_model_package_group_provided_not_exist():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
with pytest.raises(ClientError) as exc:
|
||||
_ = client.create_model_package(
|
||||
ModelPackageGroupName="TestModelPackageGroupNotExist",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
assert "Model Package Group does not exist." in str(exc.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_approval_status",
|
||||
["Approved", "Rejected", "PendingManualApproval"],
|
||||
@ -456,20 +484,18 @@ def test_create_model_package_in_model_package_group():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
resp_version_1 = client.create_model_package(
|
||||
ModelPackageName="TestModelPackage",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
resp_version_2 = client.create_model_package(
|
||||
ModelPackageName="TestModelPackage",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
assert (
|
||||
resp_version_1["ModelPackageArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/1"
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1"
|
||||
)
|
||||
assert (
|
||||
resp_version_2["ModelPackageArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/2"
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/2"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user