Sagemaker: Add update_model_package (#6891)
This commit is contained in:
parent
e8c92a7e65
commit
34bc540f70
@ -6491,7 +6491,7 @@
|
||||
- [ ] update_image_version
|
||||
- [ ] update_inference_experiment
|
||||
- [ ] update_model_card
|
||||
- [ ] update_model_package
|
||||
- [X] update_model_package
|
||||
- [ ] update_monitoring_alert
|
||||
- [ ] update_monitoring_schedule
|
||||
- [ ] update_notebook_instance
|
||||
|
@ -19,10 +19,10 @@ from .utils import (
|
||||
get_pipeline_from_name,
|
||||
get_pipeline_execution_from_arn,
|
||||
get_pipeline_name_from_execution_arn,
|
||||
validate_model_approval_status,
|
||||
)
|
||||
from .utils import load_pipeline_definition_from_s3, arn_formatter
|
||||
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_experiments": {
|
||||
"input_token": "NextToken",
|
||||
@ -988,10 +988,10 @@ class ModelPackage(BaseObject):
|
||||
source_algorithm_specification: Any,
|
||||
validation_specification: Any,
|
||||
certify_for_marketplace: bool,
|
||||
model_approval_status: str,
|
||||
model_approval_status: Optional[str],
|
||||
metadata_properties: Any,
|
||||
model_metrics: Any,
|
||||
approval_description: str,
|
||||
approval_description: Optional[str],
|
||||
customer_metadata_properties: Any,
|
||||
drift_check_baselines: Any,
|
||||
domain: str,
|
||||
@ -1029,24 +1029,23 @@ class ModelPackage(BaseObject):
|
||||
self.inference_specification = inference_specification
|
||||
self.source_algorithm_specification = source_algorithm_specification
|
||||
self.validation_specification = validation_specification
|
||||
self.model_package_status_details = (
|
||||
{
|
||||
"ValidationStatuses": [
|
||||
{
|
||||
"Name": model_package_arn,
|
||||
"Status": "Completed",
|
||||
}
|
||||
],
|
||||
"ImageScanStatuses": [
|
||||
{
|
||||
"Name": model_package_arn,
|
||||
"Status": "Completed",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
self.model_package_status_details = {
|
||||
"ValidationStatuses": [
|
||||
{
|
||||
"Name": model_package_arn,
|
||||
"Status": "Completed",
|
||||
}
|
||||
],
|
||||
"ImageScanStatuses": [
|
||||
{
|
||||
"Name": model_package_arn,
|
||||
"Status": "Completed",
|
||||
}
|
||||
],
|
||||
}
|
||||
self.certify_for_marketplace = certify_for_marketplace
|
||||
self.model_approval_status = model_approval_status
|
||||
self.model_approval_status: Optional[str] = None
|
||||
self.set_model_approval_status(model_approval_status)
|
||||
self.created_by = {
|
||||
"UserProfileArn": fake_user_profile_arn,
|
||||
"UserProfileName": fake_user_profile_name,
|
||||
@ -1054,21 +1053,20 @@ class ModelPackage(BaseObject):
|
||||
}
|
||||
self.metadata_properties = metadata_properties
|
||||
self.model_metrics = model_metrics
|
||||
self.last_modified_time = datetime_now
|
||||
self.last_modified_time: Optional[datetime] = None
|
||||
self.approval_description = approval_description
|
||||
self.customer_metadata_properties = customer_metadata_properties
|
||||
self.drift_check_baselines = drift_check_baselines
|
||||
self.domain = domain
|
||||
self.task = task
|
||||
self.sample_payload_url = sample_payload_url
|
||||
self.additional_inference_specifications = additional_inference_specifications
|
||||
self.additional_inference_specifications: Optional[List[Any]] = None
|
||||
self.add_additional_inference_specifications(
|
||||
additional_inference_specifications
|
||||
)
|
||||
self.tags = tags
|
||||
self.model_package_status = "Completed"
|
||||
self.last_modified_by = {
|
||||
"UserProfileArn": fake_user_profile_arn,
|
||||
"UserProfileName": fake_user_profile_name,
|
||||
"DomainId": fake_domain_id,
|
||||
}
|
||||
self.last_modified_by: Optional[Dict[str, str]] = None
|
||||
self.client_token = client_token
|
||||
|
||||
def gen_response_object(self) -> Dict[str, Any]:
|
||||
@ -1083,10 +1081,290 @@ class ModelPackage(BaseObject):
|
||||
"ModelPackageArn",
|
||||
"ModelPackageDescription",
|
||||
"CreationTime",
|
||||
"InferenceSpecification",
|
||||
"SourceAlgorithmSpecification",
|
||||
"ValidationSpecification",
|
||||
"ModelPackageStatus",
|
||||
"ModelPackageStatusDetails",
|
||||
"CertifyForMarketplace",
|
||||
"ModelApprovalStatus",
|
||||
"CreatedBy",
|
||||
"MetadataProperties",
|
||||
"ModelMetrics",
|
||||
"LastModifiedTime",
|
||||
"LastModifiedBy",
|
||||
"ApprovalDescription",
|
||||
"CustomerMetadataProperties",
|
||||
"DriftCheckBaselines",
|
||||
"Domain",
|
||||
"Task",
|
||||
"SamplePayloadUrl",
|
||||
"AdditionalInferenceSpecifications",
|
||||
"SkipModelValidation",
|
||||
]
|
||||
return {k: v for k, v in response_object.items() if k in response_values}
|
||||
return {
|
||||
k: v
|
||||
for k, v in response_object.items()
|
||||
if k in response_values
|
||||
if v is not None
|
||||
}
|
||||
|
||||
def modifications_done(self) -> None:
|
||||
self.last_modified_time = datetime.now(tzutc())
|
||||
self.last_modified_by = self.created_by
|
||||
|
||||
def set_model_approval_status(self, model_approval_status: Optional[str]) -> None:
|
||||
if model_approval_status is not None:
|
||||
validate_model_approval_status(model_approval_status)
|
||||
self.model_approval_status = model_approval_status
|
||||
|
||||
def remove_customer_metadata_property(
|
||||
self, customer_metadata_properties_to_remove: List[str]
|
||||
) -> None:
|
||||
if customer_metadata_properties_to_remove is not None:
|
||||
for customer_metadata_property in customer_metadata_properties_to_remove:
|
||||
self.customer_metadata_properties.pop(customer_metadata_property, None)
|
||||
|
||||
def add_additional_inference_specifications(
|
||||
self, additional_inference_specifications_to_add: Optional[List[Any]]
|
||||
) -> None:
|
||||
self.validate_additional_inference_specifications(
|
||||
additional_inference_specifications_to_add
|
||||
)
|
||||
if (
|
||||
self.additional_inference_specifications is not None
|
||||
and additional_inference_specifications_to_add is not None
|
||||
):
|
||||
self.additional_inference_specifications.extend(
|
||||
additional_inference_specifications_to_add
|
||||
)
|
||||
else:
|
||||
self.additional_inference_specifications = (
|
||||
additional_inference_specifications_to_add
|
||||
)
|
||||
|
||||
def validate_additional_inference_specifications(
|
||||
self, additional_inference_specifications: Optional[List[Dict[str, Any]]]
|
||||
) -> None:
|
||||
specifications_to_validate = additional_inference_specifications or []
|
||||
for additional_inference_specification in specifications_to_validate:
|
||||
if "SupportedTransformInstanceTypes" in additional_inference_specification:
|
||||
self.validate_supported_transform_instance_types(
|
||||
additional_inference_specification[
|
||||
"SupportedTransformInstanceTypes"
|
||||
]
|
||||
)
|
||||
if (
|
||||
"SupportedRealtimeInferenceInstanceTypes"
|
||||
in additional_inference_specification
|
||||
):
|
||||
self.validate_supported_realtime_inference_instance_types(
|
||||
additional_inference_specification[
|
||||
"SupportedRealtimeInferenceInstanceTypes"
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_supported_transform_instance_types(instance_types: List[str]) -> None:
|
||||
VALID_TRANSFORM_INSTANCE_TYPES = [
|
||||
"ml.m4.xlarge",
|
||||
"ml.m4.2xlarge",
|
||||
"ml.m4.4xlarge",
|
||||
"ml.m4.10xlarge",
|
||||
"ml.m4.16xlarge",
|
||||
"ml.c4.xlarge",
|
||||
"ml.c4.2xlarge",
|
||||
"ml.c4.4xlarge",
|
||||
"ml.c4.8xlarge",
|
||||
"ml.p2.xlarge",
|
||||
"ml.p2.8xlarge",
|
||||
"ml.p2.16xlarge",
|
||||
"ml.p3.2xlarge",
|
||||
"ml.p3.8xlarge",
|
||||
"ml.p3.16xlarge",
|
||||
"ml.c5.xlarge",
|
||||
"ml.c5.2xlarge",
|
||||
"ml.c5.4xlarge",
|
||||
"ml.c5.9xlarge",
|
||||
"ml.c5.18xlarge",
|
||||
"ml.m5.large",
|
||||
"ml.m5.xlarge",
|
||||
"ml.m5.2xlarge",
|
||||
"ml.m5.4xlarge",
|
||||
"ml.m5.12xlarge",
|
||||
"ml.m5.24xlarge",
|
||||
"ml.g4dn.xlarge",
|
||||
"ml.g4dn.2xlarge",
|
||||
"ml.g4dn.4xlarge",
|
||||
"ml.g4dn.8xlarge",
|
||||
"ml.g4dn.12xlarge",
|
||||
"ml.g4dn.16xlarge",
|
||||
]
|
||||
for instance_type in instance_types:
|
||||
if not validators.is_one_of(instance_type, VALID_TRANSFORM_INSTANCE_TYPES):
|
||||
message = f"Value '{instance_type}' at 'SupportedTransformInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_TRANSFORM_INSTANCE_TYPES}"
|
||||
raise ValidationError(message=message)
|
||||
|
||||
@staticmethod
|
||||
def validate_supported_realtime_inference_instance_types(
|
||||
instance_types: List[str],
|
||||
) -> None:
|
||||
VALID_REALTIME_INFERENCE_INSTANCE_TYPES = [
|
||||
"ml.t2.medium",
|
||||
"ml.t2.large",
|
||||
"ml.t2.xlarge",
|
||||
"ml.t2.2xlarge",
|
||||
"ml.m4.xlarge",
|
||||
"ml.m4.2xlarge",
|
||||
"ml.m4.4xlarge",
|
||||
"ml.m4.10xlarge",
|
||||
"ml.m4.16xlarge",
|
||||
"ml.m5.large",
|
||||
"ml.m5.xlarge",
|
||||
"ml.m5.2xlarge",
|
||||
"ml.m5.4xlarge",
|
||||
"ml.m5.12xlarge",
|
||||
"ml.m5.24xlarge",
|
||||
"ml.m5d.large",
|
||||
"ml.m5d.xlarge",
|
||||
"ml.m5d.2xlarge",
|
||||
"ml.m5d.4xlarge",
|
||||
"ml.m5d.12xlarge",
|
||||
"ml.m5d.24xlarge",
|
||||
"ml.c4.large",
|
||||
"ml.c4.xlarge",
|
||||
"ml.c4.2xlarge",
|
||||
"ml.c4.4xlarge",
|
||||
"ml.c4.8xlarge",
|
||||
"ml.p2.xlarge",
|
||||
"ml.p2.8xlarge",
|
||||
"ml.p2.16xlarge",
|
||||
"ml.p3.2xlarge",
|
||||
"ml.p3.8xlarge",
|
||||
"ml.p3.16xlarge",
|
||||
"ml.c5.large",
|
||||
"ml.c5.xlarge",
|
||||
"ml.c5.2xlarge",
|
||||
"ml.c5.4xlarge",
|
||||
"ml.c5.9xlarge",
|
||||
"ml.c5.18xlarge",
|
||||
"ml.c5d.large",
|
||||
"ml.c5d.xlarge",
|
||||
"ml.c5d.2xlarge",
|
||||
"ml.c5d.4xlarge",
|
||||
"ml.c5d.9xlarge",
|
||||
"ml.c5d.18xlarge",
|
||||
"ml.g4dn.xlarge",
|
||||
"ml.g4dn.2xlarge",
|
||||
"ml.g4dn.4xlarge",
|
||||
"ml.g4dn.8xlarge",
|
||||
"ml.g4dn.12xlarge",
|
||||
"ml.g4dn.16xlarge",
|
||||
"ml.r5.large",
|
||||
"ml.r5.xlarge",
|
||||
"ml.r5.2xlarge",
|
||||
"ml.r5.4xlarge",
|
||||
"ml.r5.12xlarge",
|
||||
"ml.r5.24xlarge",
|
||||
"ml.r5d.large",
|
||||
"ml.r5d.xlarge",
|
||||
"ml.r5d.2xlarge",
|
||||
"ml.r5d.4xlarge",
|
||||
"ml.r5d.12xlarge",
|
||||
"ml.r5d.24xlarge",
|
||||
"ml.inf1.xlarge",
|
||||
"ml.inf1.2xlarge",
|
||||
"ml.inf1.6xlarge",
|
||||
"ml.inf1.24xlarge",
|
||||
"ml.c6i.large",
|
||||
"ml.c6i.xlarge",
|
||||
"ml.c6i.2xlarge",
|
||||
"ml.c6i.4xlarge",
|
||||
"ml.c6i.8xlarge",
|
||||
"ml.c6i.12xlarge",
|
||||
"ml.c6i.16xlarge",
|
||||
"ml.c6i.24xlarge",
|
||||
"ml.c6i.32xlarge",
|
||||
"ml.g5.xlarge",
|
||||
"ml.g5.2xlarge",
|
||||
"ml.g5.4xlarge",
|
||||
"ml.g5.8xlarge",
|
||||
"ml.g5.12xlarge",
|
||||
"ml.g5.16xlarge",
|
||||
"ml.g5.24xlarge",
|
||||
"ml.g5.48xlarge",
|
||||
"ml.p4d.24xlarge",
|
||||
"ml.c7g.large",
|
||||
"ml.c7g.xlarge",
|
||||
"ml.c7g.2xlarge",
|
||||
"ml.c7g.4xlarge",
|
||||
"ml.c7g.8xlarge",
|
||||
"ml.c7g.12xlarge",
|
||||
"ml.c7g.16xlarge",
|
||||
"ml.m6g.large",
|
||||
"ml.m6g.xlarge",
|
||||
"ml.m6g.2xlarge",
|
||||
"ml.m6g.4xlarge",
|
||||
"ml.m6g.8xlarge",
|
||||
"ml.m6g.12xlarge",
|
||||
"ml.m6g.16xlarge",
|
||||
"ml.m6gd.large",
|
||||
"ml.m6gd.xlarge",
|
||||
"ml.m6gd.2xlarge",
|
||||
"ml.m6gd.4xlarge",
|
||||
"ml.m6gd.8xlarge",
|
||||
"ml.m6gd.12xlarge",
|
||||
"ml.m6gd.16xlarge",
|
||||
"ml.c6g.large",
|
||||
"ml.c6g.xlarge",
|
||||
"ml.c6g.2xlarge",
|
||||
"ml.c6g.4xlarge",
|
||||
"ml.c6g.8xlarge",
|
||||
"ml.c6g.12xlarge",
|
||||
"ml.c6g.16xlarge",
|
||||
"ml.c6gd.large",
|
||||
"ml.c6gd.xlarge",
|
||||
"ml.c6gd.2xlarge",
|
||||
"ml.c6gd.4xlarge",
|
||||
"ml.c6gd.8xlarge",
|
||||
"ml.c6gd.12xlarge",
|
||||
"ml.c6gd.16xlarge",
|
||||
"ml.c6gn.large",
|
||||
"ml.c6gn.xlarge",
|
||||
"ml.c6gn.2xlarge",
|
||||
"ml.c6gn.4xlarge",
|
||||
"ml.c6gn.8xlarge",
|
||||
"ml.c6gn.12xlarge",
|
||||
"ml.c6gn.16xlarge",
|
||||
"ml.r6g.large",
|
||||
"ml.r6g.xlarge",
|
||||
"ml.r6g.2xlarge",
|
||||
"ml.r6g.4xlarge",
|
||||
"ml.r6g.8xlarge",
|
||||
"ml.r6g.12xlarge",
|
||||
"ml.r6g.16xlarge",
|
||||
"ml.r6gd.large",
|
||||
"ml.r6gd.xlarge",
|
||||
"ml.r6gd.2xlarge",
|
||||
"ml.r6gd.4xlarge",
|
||||
"ml.r6gd.8xlarge",
|
||||
"ml.r6gd.12xlarge",
|
||||
"ml.r6gd.16xlarge",
|
||||
"ml.p4de.24xlarge",
|
||||
"ml.trn1.2xlarge",
|
||||
"ml.trn1.32xlarge",
|
||||
"ml.inf2.xlarge",
|
||||
"ml.inf2.8xlarge",
|
||||
"ml.inf2.24xlarge",
|
||||
"ml.inf2.48xlarge",
|
||||
"ml.p5.48xlarge",
|
||||
]
|
||||
for instance_type in instance_types:
|
||||
if not validators.is_one_of(
|
||||
instance_type, VALID_REALTIME_INFERENCE_INSTANCE_TYPES
|
||||
):
|
||||
message = f"Value '{instance_type}' at 'SupportedRealtimeInferenceInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_REALTIME_INFERENCE_INSTANCE_TYPES}"
|
||||
raise ValidationError(message=message)
|
||||
|
||||
|
||||
class VpcConfig(BaseObject):
|
||||
@ -3054,6 +3332,36 @@ class SageMakerModelBackend(BaseBackend):
|
||||
raise ValidationError(f"Model package {model_package_name} not found")
|
||||
return model_package
|
||||
|
||||
def update_model_package(
|
||||
self,
|
||||
model_package_arn: str,
|
||||
model_approval_status: Optional[str],
|
||||
approval_description: Optional[str],
|
||||
customer_metadata_properties: Optional[Dict[str, str]],
|
||||
customer_metadata_properties_to_remove: List[str],
|
||||
additional_inference_specifications_to_add: Optional[List[Any]],
|
||||
) -> str:
|
||||
model_package_name_mapped = self.model_package_name_mapping.get(
|
||||
model_package_arn, model_package_arn
|
||||
)
|
||||
model_package = self.model_packages.get(model_package_name_mapped)
|
||||
|
||||
if model_package is None:
|
||||
raise ValidationError(f"Model package {model_package_arn} not found")
|
||||
|
||||
model_package.set_model_approval_status(model_approval_status)
|
||||
model_package.approval_description = approval_description
|
||||
model_package.customer_metadata_properties = customer_metadata_properties
|
||||
model_package.remove_customer_metadata_property(
|
||||
customer_metadata_properties_to_remove
|
||||
)
|
||||
model_package.add_additional_inference_specifications(
|
||||
additional_inference_specifications_to_add
|
||||
)
|
||||
model_package.modifications_done()
|
||||
|
||||
return model_package.model_package_arn
|
||||
|
||||
def create_model_package(
|
||||
self,
|
||||
model_package_name: str,
|
||||
@ -3062,9 +3370,9 @@ class SageMakerModelBackend(BaseBackend):
|
||||
inference_specification: Any,
|
||||
validation_specification: Any,
|
||||
source_algorithm_specification: Any,
|
||||
certify_for_marketplace: Any,
|
||||
certify_for_marketplace: bool,
|
||||
tags: Any,
|
||||
model_approval_status: str,
|
||||
model_approval_status: Optional[str],
|
||||
metadata_properties: Any,
|
||||
model_metrics: Any,
|
||||
client_token: Any,
|
||||
@ -3102,7 +3410,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
sample_payload_url=sample_payload_url,
|
||||
additional_inference_specifications=additional_inference_specifications,
|
||||
model_package_version=model_package_version,
|
||||
approval_description=model_approval_status,
|
||||
approval_description=None,
|
||||
region_name=self.region_name,
|
||||
account_id=self.account_id,
|
||||
client_token=client_token,
|
||||
|
@ -874,6 +874,27 @@ class SageMakerResponse(BaseResponse):
|
||||
model_package.gen_response_object(),
|
||||
)
|
||||
|
||||
def update_model_package(self) -> str:
|
||||
model_package_arn = self._get_param("ModelPackageArn")
|
||||
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||
approval_dexcription = self._get_param("ApprovalDescription")
|
||||
customer_metadata_properties = self._get_param("CustomerMetadataProperties")
|
||||
customer_metadata_properties_to_remove = self._get_param(
|
||||
"CustomerMetadataPropertiesToRemove", []
|
||||
)
|
||||
additional_inference_specification_to_add = self._get_param(
|
||||
"AdditionalInferenceSpecificationsToAdd"
|
||||
)
|
||||
updated_model_package_arn = self.sagemaker_backend.update_model_package(
|
||||
model_package_arn=model_package_arn,
|
||||
model_approval_status=model_approval_status,
|
||||
approval_description=approval_dexcription,
|
||||
customer_metadata_properties=customer_metadata_properties,
|
||||
customer_metadata_properties_to_remove=customer_metadata_properties_to_remove,
|
||||
additional_inference_specifications_to_add=additional_inference_specification_to_add,
|
||||
)
|
||||
return json.dumps(dict(ModelPackageArn=updated_model_package_arn))
|
||||
|
||||
def create_model_package(self) -> str:
|
||||
model_package_name = self._get_param("ModelPackageName")
|
||||
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||
@ -881,7 +902,7 @@ class SageMakerResponse(BaseResponse):
|
||||
inference_specification = self._get_param("InferenceSpecification")
|
||||
validation_specification = self._get_param("ValidationSpecification")
|
||||
source_algorithm_specification = self._get_param("SourceAlgorithmSpecification")
|
||||
certify_for_marketplace = self._get_param("CertifyForMarketplace")
|
||||
certify_for_marketplace = self._get_param("CertifyForMarketplace", False)
|
||||
tags = self._get_param("Tags")
|
||||
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||
metadata_properties = self._get_param("MetadataProperties")
|
||||
@ -893,7 +914,7 @@ class SageMakerResponse(BaseResponse):
|
||||
task = self._get_param("Task")
|
||||
sample_payload_url = self._get_param("SamplePayloadUrl")
|
||||
additional_inference_specifications = self._get_param(
|
||||
"AdditionalInferenceSpecifications"
|
||||
"AdditionalInferenceSpecifications", None
|
||||
)
|
||||
model_package_arn = self.sagemaker_backend.create_model_package(
|
||||
model_package_name=model_package_name,
|
||||
|
@ -5,7 +5,6 @@ import json
|
||||
from typing import Any, Dict
|
||||
from .exceptions import ValidationError
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .models import FakePipeline, FakePipelineExecution
|
||||
|
||||
@ -51,3 +50,15 @@ def load_pipeline_definition_from_s3(
|
||||
|
||||
def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
|
||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||
|
||||
|
||||
def validate_model_approval_status(model_approval_status: typing.Optional[str]) -> None:
|
||||
if model_approval_status is not None and model_approval_status not in [
|
||||
"Approved",
|
||||
"Rejected",
|
||||
"PendingManualApproval",
|
||||
]:
|
||||
raise ValidationError(
|
||||
f"Value '{model_approval_status}' at 'modelApprovalStatus' failed to satisfy constraint: "
|
||||
"Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
|
||||
)
|
||||
|
@ -1,13 +1,20 @@
|
||||
"""Unit tests for sagemaker-supported APIs."""
|
||||
from unittest import SkipTest
|
||||
from datetime import datetime
|
||||
from unittest import SkipTest, TestCase
|
||||
|
||||
import boto3
|
||||
from freezegun import freeze_time
|
||||
from dateutil.tz import tzutc # type: ignore
|
||||
|
||||
from moto import mock_sagemaker, settings
|
||||
|
||||
import pytest
|
||||
|
||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||
from moto.sagemaker.exceptions import ValidationError
|
||||
from moto.sagemaker.models import ModelPackage
|
||||
from moto.sagemaker.utils import validate_model_approval_status
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -210,15 +217,194 @@ def test_list_model_packages_sort_order():
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model_package():
|
||||
def test_describe_model_package_default():
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest("Can't freeze time in ServerMode")
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
with freeze_time("2015-01-01 00:00:00"):
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp["ModelPackageName"] == "test-model-package"
|
||||
assert resp["ModelPackageGroupName"] == "test-model-package-group"
|
||||
assert resp["ModelPackageDescription"] == "test-model-package-description"
|
||||
assert (
|
||||
resp["ModelPackageArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
|
||||
)
|
||||
assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc())
|
||||
assert (
|
||||
resp["CreatedBy"]["UserProfileArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
|
||||
)
|
||||
assert resp["CreatedBy"]["UserProfileName"] == "fake-user-profile-name"
|
||||
assert resp["CreatedBy"]["DomainId"] == "fake-domain-id"
|
||||
assert resp["ModelPackageStatus"] == "Completed"
|
||||
assert resp.get("ModelPackageStatusDetails") is not None
|
||||
assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [
|
||||
{
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||
"Status": "Completed",
|
||||
}
|
||||
]
|
||||
assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [
|
||||
{
|
||||
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||
"Status": "Completed",
|
||||
}
|
||||
]
|
||||
assert resp["CertifyForMarketplace"] is False
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model_package_with_create_model_package_arguments():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
ModelApprovalStatus="PendingManualApproval",
|
||||
MetadataProperties={
|
||||
"CommitId": "test-commit-id",
|
||||
"GeneratedBy": "test-user",
|
||||
"ProjectId": "test-project-id",
|
||||
"Repository": "test-repo",
|
||||
},
|
||||
CertifyForMarketplace=True,
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp["ModelApprovalStatus"] == "PendingManualApproval"
|
||||
assert resp.get("ApprovalDescription") is None
|
||||
assert resp["CertifyForMarketplace"] is True
|
||||
assert resp["MetadataProperties"] is not None
|
||||
assert resp["MetadataProperties"]["CommitId"] == "test-commit-id"
|
||||
assert resp["MetadataProperties"]["GeneratedBy"] == "test-user"
|
||||
assert resp["MetadataProperties"]["ProjectId"] == "test-project-id"
|
||||
assert resp["MetadataProperties"]["Repository"] == "test-repo"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_model_package():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
CustomerMetadataProperties={
|
||||
"test-key-to-remove": "test-value-to-remove",
|
||||
},
|
||||
)["ModelPackageArn"]
|
||||
client.update_model_package(
|
||||
ModelPackageArn=model_package_arn,
|
||||
ModelApprovalStatus="Approved",
|
||||
ApprovalDescription="test-approval-description",
|
||||
CustomerMetadataProperties={"test-key": "test-value"},
|
||||
CustomerMetadataPropertiesToRemove=["test-key-to-remove"],
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp["ModelApprovalStatus"] == "Approved"
|
||||
assert resp["ApprovalDescription"] == "test-approval-description"
|
||||
assert resp["CustomerMetadataProperties"] is not None
|
||||
assert resp["CustomerMetadataProperties"]["test-key"] == "test-value"
|
||||
assert resp["CustomerMetadataProperties"].get("test-key-to-remove") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_model_package_given_additional_inference_specifications_to_add():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
ModelPackageDescription="test-model-package-description",
|
||||
)["ModelPackageArn"]
|
||||
additional_inference_specifications_to_add = {
|
||||
"Name": "test-inference-specification-name",
|
||||
"Description": "test-inference-specification-description",
|
||||
"Containers": [
|
||||
{
|
||||
"ContainerHostname": "test-container-hostname",
|
||||
"Image": "test-image",
|
||||
"ImageDigest": "test-image-digest",
|
||||
"ModelDataUrl": "test-model-data-url",
|
||||
"ProductId": "test-product-id",
|
||||
"Environment": {"test-key": "test-value"},
|
||||
"ModelInput": {"DataInputConfig": "test-data-input-config"},
|
||||
"Framework": "test-framework",
|
||||
"FrameworkVersion": "test-framework-version",
|
||||
"NearestModelName": "test-nearest-model-name",
|
||||
},
|
||||
],
|
||||
"SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
|
||||
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.t2.large"],
|
||||
"SupportedContentTypes": [
|
||||
"test-content-type-1",
|
||||
],
|
||||
"SupportedResponseMIMETypes": [
|
||||
"test-response-mime-type-1",
|
||||
],
|
||||
}
|
||||
client.update_model_package(
|
||||
ModelPackageArn=model_package_arn,
|
||||
AdditionalInferenceSpecificationsToAdd=[
|
||||
additional_inference_specifications_to_add
|
||||
],
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp["AdditionalInferenceSpecifications"] is not None
|
||||
TestCase().assertDictEqual(
|
||||
resp["AdditionalInferenceSpecifications"][0],
|
||||
additional_inference_specifications_to_add,
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_model_package_shoudl_update_last_modified_information():
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest("Can't freeze time in ServerMode")
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||
model_package_arn = client.create_model_package(
|
||||
ModelPackageName="test-model-package",
|
||||
ModelPackageGroupName="test-model-package-group",
|
||||
)["ModelPackageArn"]
|
||||
with freeze_time("2020-01-01 12:00:00"):
|
||||
client.update_model_package(
|
||||
ModelPackageArn=model_package_arn,
|
||||
ModelApprovalStatus="Approved",
|
||||
)
|
||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||
assert resp.get("LastModifiedTime") is not None
|
||||
assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
|
||||
assert resp.get("LastModifiedBy") is not None
|
||||
assert (
|
||||
resp["LastModifiedBy"]["UserProfileArn"]
|
||||
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
|
||||
)
|
||||
assert resp["LastModifiedBy"]["UserProfileName"] == "fake-user-profile-name"
|
||||
assert resp["LastModifiedBy"]["DomainId"] == "fake-domain-id"
|
||||
|
||||
|
||||
def test_validate_supported_transform_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
ModelPackage.validate_supported_transform_instance_types(
|
||||
["ml.m4.2xlarge", "not-a-supported-transform-instances-types"]
|
||||
)
|
||||
assert "not-a-supported-transform-instances-types" in str(exc.value)
|
||||
|
||||
|
||||
def test_validate_supported_realtime_inference_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
ModelPackage.validate_supported_realtime_inference_instance_types(
|
||||
["ml.m4.2xlarge", "not-a-supported-realtime-inference-instances-types"]
|
||||
)
|
||||
assert "not-a-supported-realtime-inference-instances-types" in str(exc.value)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -234,6 +420,27 @@ def test_create_model_package():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_approval_status",
|
||||
["Approved", "Rejected", "PendingManualApproval"],
|
||||
)
|
||||
def test_utils_validate_model_approval_status_should_not_raise_error_if_model_approval_status_is_correct(
|
||||
model_approval_status: str,
|
||||
):
|
||||
validate_model_approval_status(model_approval_status)
|
||||
|
||||
|
||||
def test_utils_validate_model_approval_status_should_raise_error_if_model_approval_status_is_incorrect():
|
||||
model_approval_status = "IncorrectStatus"
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
validate_model_approval_status(model_approval_status)
|
||||
assert exc.value.code == 400
|
||||
assert (
|
||||
exc.value.message
|
||||
== "Value 'IncorrectStatus' at 'modelApprovalStatus' failed to satisfy constraint: Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model_package_in_model_package_group():
|
||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||
|
Loading…
Reference in New Issue
Block a user