Sagemaker: Add update_model_package (#6891)

This commit is contained in:
HALLOUARD 2023-10-08 22:16:50 +02:00 committed by GitHub
parent e8c92a7e65
commit 34bc540f70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 588 additions and 41 deletions

View File

@ -6491,7 +6491,7 @@
- [ ] update_image_version
- [ ] update_inference_experiment
- [ ] update_model_card
- [ ] update_model_package
- [X] update_model_package
- [ ] update_monitoring_alert
- [ ] update_monitoring_schedule
- [ ] update_notebook_instance

View File

@ -19,10 +19,10 @@ from .utils import (
get_pipeline_from_name,
get_pipeline_execution_from_arn,
get_pipeline_name_from_execution_arn,
validate_model_approval_status,
)
from .utils import load_pipeline_definition_from_s3, arn_formatter
PAGINATION_MODEL = {
"list_experiments": {
"input_token": "NextToken",
@ -988,10 +988,10 @@ class ModelPackage(BaseObject):
source_algorithm_specification: Any,
validation_specification: Any,
certify_for_marketplace: bool,
model_approval_status: str,
model_approval_status: Optional[str],
metadata_properties: Any,
model_metrics: Any,
approval_description: str,
approval_description: Optional[str],
customer_metadata_properties: Any,
drift_check_baselines: Any,
domain: str,
@ -1029,24 +1029,23 @@ class ModelPackage(BaseObject):
self.inference_specification = inference_specification
self.source_algorithm_specification = source_algorithm_specification
self.validation_specification = validation_specification
self.model_package_status_details = (
{
"ValidationStatuses": [
{
"Name": model_package_arn,
"Status": "Completed",
}
],
"ImageScanStatuses": [
{
"Name": model_package_arn,
"Status": "Completed",
}
],
},
)
self.model_package_status_details = {
"ValidationStatuses": [
{
"Name": model_package_arn,
"Status": "Completed",
}
],
"ImageScanStatuses": [
{
"Name": model_package_arn,
"Status": "Completed",
}
],
}
self.certify_for_marketplace = certify_for_marketplace
self.model_approval_status = model_approval_status
self.model_approval_status: Optional[str] = None
self.set_model_approval_status(model_approval_status)
self.created_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
@ -1054,21 +1053,20 @@ class ModelPackage(BaseObject):
}
self.metadata_properties = metadata_properties
self.model_metrics = model_metrics
self.last_modified_time = datetime_now
self.last_modified_time: Optional[datetime] = None
self.approval_description = approval_description
self.customer_metadata_properties = customer_metadata_properties
self.drift_check_baselines = drift_check_baselines
self.domain = domain
self.task = task
self.sample_payload_url = sample_payload_url
self.additional_inference_specifications = additional_inference_specifications
self.additional_inference_specifications: Optional[List[Any]] = None
self.add_additional_inference_specifications(
additional_inference_specifications
)
self.tags = tags
self.model_package_status = "Completed"
self.last_modified_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
"DomainId": fake_domain_id,
}
self.last_modified_by: Optional[Dict[str, str]] = None
self.client_token = client_token
def gen_response_object(self) -> Dict[str, Any]:
@ -1083,10 +1081,290 @@ class ModelPackage(BaseObject):
"ModelPackageArn",
"ModelPackageDescription",
"CreationTime",
"InferenceSpecification",
"SourceAlgorithmSpecification",
"ValidationSpecification",
"ModelPackageStatus",
"ModelPackageStatusDetails",
"CertifyForMarketplace",
"ModelApprovalStatus",
"CreatedBy",
"MetadataProperties",
"ModelMetrics",
"LastModifiedTime",
"LastModifiedBy",
"ApprovalDescription",
"CustomerMetadataProperties",
"DriftCheckBaselines",
"Domain",
"Task",
"SamplePayloadUrl",
"AdditionalInferenceSpecifications",
"SkipModelValidation",
]
return {k: v for k, v in response_object.items() if k in response_values}
return {
k: v
for k, v in response_object.items()
if k in response_values
if v is not None
}
def modifications_done(self) -> None:
self.last_modified_time = datetime.now(tzutc())
self.last_modified_by = self.created_by
def set_model_approval_status(self, model_approval_status: Optional[str]) -> None:
if model_approval_status is not None:
validate_model_approval_status(model_approval_status)
self.model_approval_status = model_approval_status
def remove_customer_metadata_property(
self, customer_metadata_properties_to_remove: List[str]
) -> None:
if customer_metadata_properties_to_remove is not None:
for customer_metadata_property in customer_metadata_properties_to_remove:
self.customer_metadata_properties.pop(customer_metadata_property, None)
def add_additional_inference_specifications(
self, additional_inference_specifications_to_add: Optional[List[Any]]
) -> None:
self.validate_additional_inference_specifications(
additional_inference_specifications_to_add
)
if (
self.additional_inference_specifications is not None
and additional_inference_specifications_to_add is not None
):
self.additional_inference_specifications.extend(
additional_inference_specifications_to_add
)
else:
self.additional_inference_specifications = (
additional_inference_specifications_to_add
)
def validate_additional_inference_specifications(
self, additional_inference_specifications: Optional[List[Dict[str, Any]]]
) -> None:
specifications_to_validate = additional_inference_specifications or []
for additional_inference_specification in specifications_to_validate:
if "SupportedTransformInstanceTypes" in additional_inference_specification:
self.validate_supported_transform_instance_types(
additional_inference_specification[
"SupportedTransformInstanceTypes"
]
)
if (
"SupportedRealtimeInferenceInstanceTypes"
in additional_inference_specification
):
self.validate_supported_realtime_inference_instance_types(
additional_inference_specification[
"SupportedRealtimeInferenceInstanceTypes"
]
)
@staticmethod
def validate_supported_transform_instance_types(instance_types: List[str]) -> None:
VALID_TRANSFORM_INSTANCE_TYPES = [
"ml.m4.xlarge",
"ml.m4.2xlarge",
"ml.m4.4xlarge",
"ml.m4.10xlarge",
"ml.m4.16xlarge",
"ml.c4.xlarge",
"ml.c4.2xlarge",
"ml.c4.4xlarge",
"ml.c4.8xlarge",
"ml.p2.xlarge",
"ml.p2.8xlarge",
"ml.p2.16xlarge",
"ml.p3.2xlarge",
"ml.p3.8xlarge",
"ml.p3.16xlarge",
"ml.c5.xlarge",
"ml.c5.2xlarge",
"ml.c5.4xlarge",
"ml.c5.9xlarge",
"ml.c5.18xlarge",
"ml.m5.large",
"ml.m5.xlarge",
"ml.m5.2xlarge",
"ml.m5.4xlarge",
"ml.m5.12xlarge",
"ml.m5.24xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.2xlarge",
"ml.g4dn.4xlarge",
"ml.g4dn.8xlarge",
"ml.g4dn.12xlarge",
"ml.g4dn.16xlarge",
]
for instance_type in instance_types:
if not validators.is_one_of(instance_type, VALID_TRANSFORM_INSTANCE_TYPES):
message = f"Value '{instance_type}' at 'SupportedTransformInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_TRANSFORM_INSTANCE_TYPES}"
raise ValidationError(message=message)
@staticmethod
def validate_supported_realtime_inference_instance_types(
instance_types: List[str],
) -> None:
VALID_REALTIME_INFERENCE_INSTANCE_TYPES = [
"ml.t2.medium",
"ml.t2.large",
"ml.t2.xlarge",
"ml.t2.2xlarge",
"ml.m4.xlarge",
"ml.m4.2xlarge",
"ml.m4.4xlarge",
"ml.m4.10xlarge",
"ml.m4.16xlarge",
"ml.m5.large",
"ml.m5.xlarge",
"ml.m5.2xlarge",
"ml.m5.4xlarge",
"ml.m5.12xlarge",
"ml.m5.24xlarge",
"ml.m5d.large",
"ml.m5d.xlarge",
"ml.m5d.2xlarge",
"ml.m5d.4xlarge",
"ml.m5d.12xlarge",
"ml.m5d.24xlarge",
"ml.c4.large",
"ml.c4.xlarge",
"ml.c4.2xlarge",
"ml.c4.4xlarge",
"ml.c4.8xlarge",
"ml.p2.xlarge",
"ml.p2.8xlarge",
"ml.p2.16xlarge",
"ml.p3.2xlarge",
"ml.p3.8xlarge",
"ml.p3.16xlarge",
"ml.c5.large",
"ml.c5.xlarge",
"ml.c5.2xlarge",
"ml.c5.4xlarge",
"ml.c5.9xlarge",
"ml.c5.18xlarge",
"ml.c5d.large",
"ml.c5d.xlarge",
"ml.c5d.2xlarge",
"ml.c5d.4xlarge",
"ml.c5d.9xlarge",
"ml.c5d.18xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.2xlarge",
"ml.g4dn.4xlarge",
"ml.g4dn.8xlarge",
"ml.g4dn.12xlarge",
"ml.g4dn.16xlarge",
"ml.r5.large",
"ml.r5.xlarge",
"ml.r5.2xlarge",
"ml.r5.4xlarge",
"ml.r5.12xlarge",
"ml.r5.24xlarge",
"ml.r5d.large",
"ml.r5d.xlarge",
"ml.r5d.2xlarge",
"ml.r5d.4xlarge",
"ml.r5d.12xlarge",
"ml.r5d.24xlarge",
"ml.inf1.xlarge",
"ml.inf1.2xlarge",
"ml.inf1.6xlarge",
"ml.inf1.24xlarge",
"ml.c6i.large",
"ml.c6i.xlarge",
"ml.c6i.2xlarge",
"ml.c6i.4xlarge",
"ml.c6i.8xlarge",
"ml.c6i.12xlarge",
"ml.c6i.16xlarge",
"ml.c6i.24xlarge",
"ml.c6i.32xlarge",
"ml.g5.xlarge",
"ml.g5.2xlarge",
"ml.g5.4xlarge",
"ml.g5.8xlarge",
"ml.g5.12xlarge",
"ml.g5.16xlarge",
"ml.g5.24xlarge",
"ml.g5.48xlarge",
"ml.p4d.24xlarge",
"ml.c7g.large",
"ml.c7g.xlarge",
"ml.c7g.2xlarge",
"ml.c7g.4xlarge",
"ml.c7g.8xlarge",
"ml.c7g.12xlarge",
"ml.c7g.16xlarge",
"ml.m6g.large",
"ml.m6g.xlarge",
"ml.m6g.2xlarge",
"ml.m6g.4xlarge",
"ml.m6g.8xlarge",
"ml.m6g.12xlarge",
"ml.m6g.16xlarge",
"ml.m6gd.large",
"ml.m6gd.xlarge",
"ml.m6gd.2xlarge",
"ml.m6gd.4xlarge",
"ml.m6gd.8xlarge",
"ml.m6gd.12xlarge",
"ml.m6gd.16xlarge",
"ml.c6g.large",
"ml.c6g.xlarge",
"ml.c6g.2xlarge",
"ml.c6g.4xlarge",
"ml.c6g.8xlarge",
"ml.c6g.12xlarge",
"ml.c6g.16xlarge",
"ml.c6gd.large",
"ml.c6gd.xlarge",
"ml.c6gd.2xlarge",
"ml.c6gd.4xlarge",
"ml.c6gd.8xlarge",
"ml.c6gd.12xlarge",
"ml.c6gd.16xlarge",
"ml.c6gn.large",
"ml.c6gn.xlarge",
"ml.c6gn.2xlarge",
"ml.c6gn.4xlarge",
"ml.c6gn.8xlarge",
"ml.c6gn.12xlarge",
"ml.c6gn.16xlarge",
"ml.r6g.large",
"ml.r6g.xlarge",
"ml.r6g.2xlarge",
"ml.r6g.4xlarge",
"ml.r6g.8xlarge",
"ml.r6g.12xlarge",
"ml.r6g.16xlarge",
"ml.r6gd.large",
"ml.r6gd.xlarge",
"ml.r6gd.2xlarge",
"ml.r6gd.4xlarge",
"ml.r6gd.8xlarge",
"ml.r6gd.12xlarge",
"ml.r6gd.16xlarge",
"ml.p4de.24xlarge",
"ml.trn1.2xlarge",
"ml.trn1.32xlarge",
"ml.inf2.xlarge",
"ml.inf2.8xlarge",
"ml.inf2.24xlarge",
"ml.inf2.48xlarge",
"ml.p5.48xlarge",
]
for instance_type in instance_types:
if not validators.is_one_of(
instance_type, VALID_REALTIME_INFERENCE_INSTANCE_TYPES
):
message = f"Value '{instance_type}' at 'SupportedRealtimeInferenceInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_REALTIME_INFERENCE_INSTANCE_TYPES}"
raise ValidationError(message=message)
class VpcConfig(BaseObject):
@ -3054,6 +3332,36 @@ class SageMakerModelBackend(BaseBackend):
raise ValidationError(f"Model package {model_package_name} not found")
return model_package
def update_model_package(
self,
model_package_arn: str,
model_approval_status: Optional[str],
approval_description: Optional[str],
customer_metadata_properties: Optional[Dict[str, str]],
customer_metadata_properties_to_remove: List[str],
additional_inference_specifications_to_add: Optional[List[Any]],
) -> str:
model_package_name_mapped = self.model_package_name_mapping.get(
model_package_arn, model_package_arn
)
model_package = self.model_packages.get(model_package_name_mapped)
if model_package is None:
raise ValidationError(f"Model package {model_package_arn} not found")
model_package.set_model_approval_status(model_approval_status)
model_package.approval_description = approval_description
model_package.customer_metadata_properties = customer_metadata_properties
model_package.remove_customer_metadata_property(
customer_metadata_properties_to_remove
)
model_package.add_additional_inference_specifications(
additional_inference_specifications_to_add
)
model_package.modifications_done()
return model_package.model_package_arn
def create_model_package(
self,
model_package_name: str,
@ -3062,9 +3370,9 @@ class SageMakerModelBackend(BaseBackend):
inference_specification: Any,
validation_specification: Any,
source_algorithm_specification: Any,
certify_for_marketplace: Any,
certify_for_marketplace: bool,
tags: Any,
model_approval_status: str,
model_approval_status: Optional[str],
metadata_properties: Any,
model_metrics: Any,
client_token: Any,
@ -3102,7 +3410,7 @@ class SageMakerModelBackend(BaseBackend):
sample_payload_url=sample_payload_url,
additional_inference_specifications=additional_inference_specifications,
model_package_version=model_package_version,
approval_description=model_approval_status,
approval_description=None,
region_name=self.region_name,
account_id=self.account_id,
client_token=client_token,

View File

@ -874,6 +874,27 @@ class SageMakerResponse(BaseResponse):
model_package.gen_response_object(),
)
def update_model_package(self) -> str:
model_package_arn = self._get_param("ModelPackageArn")
model_approval_status = self._get_param("ModelApprovalStatus")
approval_dexcription = self._get_param("ApprovalDescription")
customer_metadata_properties = self._get_param("CustomerMetadataProperties")
customer_metadata_properties_to_remove = self._get_param(
"CustomerMetadataPropertiesToRemove", []
)
additional_inference_specification_to_add = self._get_param(
"AdditionalInferenceSpecificationsToAdd"
)
updated_model_package_arn = self.sagemaker_backend.update_model_package(
model_package_arn=model_package_arn,
model_approval_status=model_approval_status,
approval_description=approval_dexcription,
customer_metadata_properties=customer_metadata_properties,
customer_metadata_properties_to_remove=customer_metadata_properties_to_remove,
additional_inference_specifications_to_add=additional_inference_specification_to_add,
)
return json.dumps(dict(ModelPackageArn=updated_model_package_arn))
def create_model_package(self) -> str:
model_package_name = self._get_param("ModelPackageName")
model_package_group_name = self._get_param("ModelPackageGroupName")
@ -881,7 +902,7 @@ class SageMakerResponse(BaseResponse):
inference_specification = self._get_param("InferenceSpecification")
validation_specification = self._get_param("ValidationSpecification")
source_algorithm_specification = self._get_param("SourceAlgorithmSpecification")
certify_for_marketplace = self._get_param("CertifyForMarketplace")
certify_for_marketplace = self._get_param("CertifyForMarketplace", False)
tags = self._get_param("Tags")
model_approval_status = self._get_param("ModelApprovalStatus")
metadata_properties = self._get_param("MetadataProperties")
@ -893,7 +914,7 @@ class SageMakerResponse(BaseResponse):
task = self._get_param("Task")
sample_payload_url = self._get_param("SamplePayloadUrl")
additional_inference_specifications = self._get_param(
"AdditionalInferenceSpecifications"
"AdditionalInferenceSpecifications", None
)
model_package_arn = self.sagemaker_backend.create_model_package(
model_package_name=model_package_name,

View File

@ -5,7 +5,6 @@ import json
from typing import Any, Dict
from .exceptions import ValidationError
if typing.TYPE_CHECKING:
from .models import FakePipeline, FakePipelineExecution
@ -51,3 +50,15 @@ def load_pipeline_definition_from_s3(
def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
def validate_model_approval_status(model_approval_status: typing.Optional[str]) -> None:
if model_approval_status is not None and model_approval_status not in [
"Approved",
"Rejected",
"PendingManualApproval",
]:
raise ValidationError(
f"Value '{model_approval_status}' at 'modelApprovalStatus' failed to satisfy constraint: "
"Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
)

View File

@ -1,13 +1,20 @@
"""Unit tests for sagemaker-supported APIs."""
from unittest import SkipTest
from datetime import datetime
from unittest import SkipTest, TestCase
import boto3
from freezegun import freeze_time
from dateutil.tz import tzutc # type: ignore
from moto import mock_sagemaker, settings
import pytest
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
from moto.sagemaker.exceptions import ValidationError
from moto.sagemaker.models import ModelPackage
from moto.sagemaker.utils import validate_model_approval_status
@mock_sagemaker
@ -210,15 +217,194 @@ def test_list_model_packages_sort_order():
@mock_sagemaker
def test_describe_model_package():
def test_describe_model_package_default():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
with freeze_time("2015-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["ModelPackageName"] == "test-model-package"
assert resp["ModelPackageGroupName"] == "test-model-package-group"
assert resp["ModelPackageDescription"] == "test-model-package-description"
assert (
resp["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
)
assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc())
assert (
resp["CreatedBy"]["UserProfileArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
)
assert resp["CreatedBy"]["UserProfileName"] == "fake-user-profile-name"
assert resp["CreatedBy"]["DomainId"] == "fake-domain-id"
assert resp["ModelPackageStatus"] == "Completed"
assert resp.get("ModelPackageStatusDetails") is not None
assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [
{
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
"Status": "Completed",
}
]
assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [
{
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
"Status": "Completed",
}
]
assert resp["CertifyForMarketplace"] is False
@mock_sagemaker
def test_describe_model_package_with_create_model_package_arguments():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
ModelApprovalStatus="PendingManualApproval",
MetadataProperties={
"CommitId": "test-commit-id",
"GeneratedBy": "test-user",
"ProjectId": "test-project-id",
"Repository": "test-repo",
},
CertifyForMarketplace=True,
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["ModelApprovalStatus"] == "PendingManualApproval"
assert resp.get("ApprovalDescription") is None
assert resp["CertifyForMarketplace"] is True
assert resp["MetadataProperties"] is not None
assert resp["MetadataProperties"]["CommitId"] == "test-commit-id"
assert resp["MetadataProperties"]["GeneratedBy"] == "test-user"
assert resp["MetadataProperties"]["ProjectId"] == "test-project-id"
assert resp["MetadataProperties"]["Repository"] == "test-repo"
@mock_sagemaker
def test_update_model_package():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
CustomerMetadataProperties={
"test-key-to-remove": "test-value-to-remove",
},
)["ModelPackageArn"]
client.update_model_package(
ModelPackageArn=model_package_arn,
ModelApprovalStatus="Approved",
ApprovalDescription="test-approval-description",
CustomerMetadataProperties={"test-key": "test-value"},
CustomerMetadataPropertiesToRemove=["test-key-to-remove"],
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["ModelApprovalStatus"] == "Approved"
assert resp["ApprovalDescription"] == "test-approval-description"
assert resp["CustomerMetadataProperties"] is not None
assert resp["CustomerMetadataProperties"]["test-key"] == "test-value"
assert resp["CustomerMetadataProperties"].get("test-key-to-remove") is None
@mock_sagemaker
def test_update_model_package_given_additional_inference_specifications_to_add():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)["ModelPackageArn"]
additional_inference_specifications_to_add = {
"Name": "test-inference-specification-name",
"Description": "test-inference-specification-description",
"Containers": [
{
"ContainerHostname": "test-container-hostname",
"Image": "test-image",
"ImageDigest": "test-image-digest",
"ModelDataUrl": "test-model-data-url",
"ProductId": "test-product-id",
"Environment": {"test-key": "test-value"},
"ModelInput": {"DataInputConfig": "test-data-input-config"},
"Framework": "test-framework",
"FrameworkVersion": "test-framework-version",
"NearestModelName": "test-nearest-model-name",
},
],
"SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.t2.large"],
"SupportedContentTypes": [
"test-content-type-1",
],
"SupportedResponseMIMETypes": [
"test-response-mime-type-1",
],
}
client.update_model_package(
ModelPackageArn=model_package_arn,
AdditionalInferenceSpecificationsToAdd=[
additional_inference_specifications_to_add
],
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["AdditionalInferenceSpecifications"] is not None
TestCase().assertDictEqual(
resp["AdditionalInferenceSpecifications"][0],
additional_inference_specifications_to_add,
)
@mock_sagemaker
def test_update_model_package_shoudl_update_last_modified_information():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
model_package_arn = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
)["ModelPackageArn"]
with freeze_time("2020-01-01 12:00:00"):
client.update_model_package(
ModelPackageArn=model_package_arn,
ModelApprovalStatus="Approved",
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp.get("LastModifiedTime") is not None
assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
assert resp.get("LastModifiedBy") is not None
assert (
resp["LastModifiedBy"]["UserProfileArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
)
assert resp["LastModifiedBy"]["UserProfileName"] == "fake-user-profile-name"
assert resp["LastModifiedBy"]["DomainId"] == "fake-domain-id"
def test_validate_supported_transform_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
with pytest.raises(ValidationError) as exc:
ModelPackage.validate_supported_transform_instance_types(
["ml.m4.2xlarge", "not-a-supported-transform-instances-types"]
)
assert "not-a-supported-transform-instances-types" in str(exc.value)
def test_validate_supported_realtime_inference_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
with pytest.raises(ValidationError) as exc:
ModelPackage.validate_supported_realtime_inference_instance_types(
["ml.m4.2xlarge", "not-a-supported-realtime-inference-instances-types"]
)
assert "not-a-supported-realtime-inference-instances-types" in str(exc.value)
@mock_sagemaker
@ -234,6 +420,27 @@ def test_create_model_package():
)
@pytest.mark.parametrize(
"model_approval_status",
["Approved", "Rejected", "PendingManualApproval"],
)
def test_utils_validate_model_approval_status_should_not_raise_error_if_model_approval_status_is_correct(
model_approval_status: str,
):
validate_model_approval_status(model_approval_status)
def test_utils_validate_model_approval_status_should_raise_error_if_model_approval_status_is_incorrect():
model_approval_status = "IncorrectStatus"
with pytest.raises(ValidationError) as exc:
validate_model_approval_status(model_approval_status)
assert exc.value.code == 400
assert (
exc.value.message
== "Value 'IncorrectStatus' at 'modelApprovalStatus' failed to satisfy constraint: Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
)
@mock_sagemaker
def test_create_model_package_in_model_package_group():
client = boto3.client("sagemaker", region_name="eu-west-1")