Sagemaker: Add update_model_package (#6891)
This commit is contained in:
parent
e8c92a7e65
commit
34bc540f70
@ -6491,7 +6491,7 @@
|
|||||||
- [ ] update_image_version
|
- [ ] update_image_version
|
||||||
- [ ] update_inference_experiment
|
- [ ] update_inference_experiment
|
||||||
- [ ] update_model_card
|
- [ ] update_model_card
|
||||||
- [ ] update_model_package
|
- [X] update_model_package
|
||||||
- [ ] update_monitoring_alert
|
- [ ] update_monitoring_alert
|
||||||
- [ ] update_monitoring_schedule
|
- [ ] update_monitoring_schedule
|
||||||
- [ ] update_notebook_instance
|
- [ ] update_notebook_instance
|
||||||
|
@ -19,10 +19,10 @@ from .utils import (
|
|||||||
get_pipeline_from_name,
|
get_pipeline_from_name,
|
||||||
get_pipeline_execution_from_arn,
|
get_pipeline_execution_from_arn,
|
||||||
get_pipeline_name_from_execution_arn,
|
get_pipeline_name_from_execution_arn,
|
||||||
|
validate_model_approval_status,
|
||||||
)
|
)
|
||||||
from .utils import load_pipeline_definition_from_s3, arn_formatter
|
from .utils import load_pipeline_definition_from_s3, arn_formatter
|
||||||
|
|
||||||
|
|
||||||
PAGINATION_MODEL = {
|
PAGINATION_MODEL = {
|
||||||
"list_experiments": {
|
"list_experiments": {
|
||||||
"input_token": "NextToken",
|
"input_token": "NextToken",
|
||||||
@ -988,10 +988,10 @@ class ModelPackage(BaseObject):
|
|||||||
source_algorithm_specification: Any,
|
source_algorithm_specification: Any,
|
||||||
validation_specification: Any,
|
validation_specification: Any,
|
||||||
certify_for_marketplace: bool,
|
certify_for_marketplace: bool,
|
||||||
model_approval_status: str,
|
model_approval_status: Optional[str],
|
||||||
metadata_properties: Any,
|
metadata_properties: Any,
|
||||||
model_metrics: Any,
|
model_metrics: Any,
|
||||||
approval_description: str,
|
approval_description: Optional[str],
|
||||||
customer_metadata_properties: Any,
|
customer_metadata_properties: Any,
|
||||||
drift_check_baselines: Any,
|
drift_check_baselines: Any,
|
||||||
domain: str,
|
domain: str,
|
||||||
@ -1029,24 +1029,23 @@ class ModelPackage(BaseObject):
|
|||||||
self.inference_specification = inference_specification
|
self.inference_specification = inference_specification
|
||||||
self.source_algorithm_specification = source_algorithm_specification
|
self.source_algorithm_specification = source_algorithm_specification
|
||||||
self.validation_specification = validation_specification
|
self.validation_specification = validation_specification
|
||||||
self.model_package_status_details = (
|
self.model_package_status_details = {
|
||||||
{
|
"ValidationStatuses": [
|
||||||
"ValidationStatuses": [
|
{
|
||||||
{
|
"Name": model_package_arn,
|
||||||
"Name": model_package_arn,
|
"Status": "Completed",
|
||||||
"Status": "Completed",
|
}
|
||||||
}
|
],
|
||||||
],
|
"ImageScanStatuses": [
|
||||||
"ImageScanStatuses": [
|
{
|
||||||
{
|
"Name": model_package_arn,
|
||||||
"Name": model_package_arn,
|
"Status": "Completed",
|
||||||
"Status": "Completed",
|
}
|
||||||
}
|
],
|
||||||
],
|
}
|
||||||
},
|
|
||||||
)
|
|
||||||
self.certify_for_marketplace = certify_for_marketplace
|
self.certify_for_marketplace = certify_for_marketplace
|
||||||
self.model_approval_status = model_approval_status
|
self.model_approval_status: Optional[str] = None
|
||||||
|
self.set_model_approval_status(model_approval_status)
|
||||||
self.created_by = {
|
self.created_by = {
|
||||||
"UserProfileArn": fake_user_profile_arn,
|
"UserProfileArn": fake_user_profile_arn,
|
||||||
"UserProfileName": fake_user_profile_name,
|
"UserProfileName": fake_user_profile_name,
|
||||||
@ -1054,21 +1053,20 @@ class ModelPackage(BaseObject):
|
|||||||
}
|
}
|
||||||
self.metadata_properties = metadata_properties
|
self.metadata_properties = metadata_properties
|
||||||
self.model_metrics = model_metrics
|
self.model_metrics = model_metrics
|
||||||
self.last_modified_time = datetime_now
|
self.last_modified_time: Optional[datetime] = None
|
||||||
self.approval_description = approval_description
|
self.approval_description = approval_description
|
||||||
self.customer_metadata_properties = customer_metadata_properties
|
self.customer_metadata_properties = customer_metadata_properties
|
||||||
self.drift_check_baselines = drift_check_baselines
|
self.drift_check_baselines = drift_check_baselines
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.task = task
|
self.task = task
|
||||||
self.sample_payload_url = sample_payload_url
|
self.sample_payload_url = sample_payload_url
|
||||||
self.additional_inference_specifications = additional_inference_specifications
|
self.additional_inference_specifications: Optional[List[Any]] = None
|
||||||
|
self.add_additional_inference_specifications(
|
||||||
|
additional_inference_specifications
|
||||||
|
)
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
self.model_package_status = "Completed"
|
self.model_package_status = "Completed"
|
||||||
self.last_modified_by = {
|
self.last_modified_by: Optional[Dict[str, str]] = None
|
||||||
"UserProfileArn": fake_user_profile_arn,
|
|
||||||
"UserProfileName": fake_user_profile_name,
|
|
||||||
"DomainId": fake_domain_id,
|
|
||||||
}
|
|
||||||
self.client_token = client_token
|
self.client_token = client_token
|
||||||
|
|
||||||
def gen_response_object(self) -> Dict[str, Any]:
|
def gen_response_object(self) -> Dict[str, Any]:
|
||||||
@ -1083,10 +1081,290 @@ class ModelPackage(BaseObject):
|
|||||||
"ModelPackageArn",
|
"ModelPackageArn",
|
||||||
"ModelPackageDescription",
|
"ModelPackageDescription",
|
||||||
"CreationTime",
|
"CreationTime",
|
||||||
|
"InferenceSpecification",
|
||||||
|
"SourceAlgorithmSpecification",
|
||||||
|
"ValidationSpecification",
|
||||||
"ModelPackageStatus",
|
"ModelPackageStatus",
|
||||||
|
"ModelPackageStatusDetails",
|
||||||
|
"CertifyForMarketplace",
|
||||||
"ModelApprovalStatus",
|
"ModelApprovalStatus",
|
||||||
|
"CreatedBy",
|
||||||
|
"MetadataProperties",
|
||||||
|
"ModelMetrics",
|
||||||
|
"LastModifiedTime",
|
||||||
|
"LastModifiedBy",
|
||||||
|
"ApprovalDescription",
|
||||||
|
"CustomerMetadataProperties",
|
||||||
|
"DriftCheckBaselines",
|
||||||
|
"Domain",
|
||||||
|
"Task",
|
||||||
|
"SamplePayloadUrl",
|
||||||
|
"AdditionalInferenceSpecifications",
|
||||||
|
"SkipModelValidation",
|
||||||
]
|
]
|
||||||
return {k: v for k, v in response_object.items() if k in response_values}
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in response_object.items()
|
||||||
|
if k in response_values
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def modifications_done(self) -> None:
|
||||||
|
self.last_modified_time = datetime.now(tzutc())
|
||||||
|
self.last_modified_by = self.created_by
|
||||||
|
|
||||||
|
def set_model_approval_status(self, model_approval_status: Optional[str]) -> None:
|
||||||
|
if model_approval_status is not None:
|
||||||
|
validate_model_approval_status(model_approval_status)
|
||||||
|
self.model_approval_status = model_approval_status
|
||||||
|
|
||||||
|
def remove_customer_metadata_property(
|
||||||
|
self, customer_metadata_properties_to_remove: List[str]
|
||||||
|
) -> None:
|
||||||
|
if customer_metadata_properties_to_remove is not None:
|
||||||
|
for customer_metadata_property in customer_metadata_properties_to_remove:
|
||||||
|
self.customer_metadata_properties.pop(customer_metadata_property, None)
|
||||||
|
|
||||||
|
def add_additional_inference_specifications(
|
||||||
|
self, additional_inference_specifications_to_add: Optional[List[Any]]
|
||||||
|
) -> None:
|
||||||
|
self.validate_additional_inference_specifications(
|
||||||
|
additional_inference_specifications_to_add
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.additional_inference_specifications is not None
|
||||||
|
and additional_inference_specifications_to_add is not None
|
||||||
|
):
|
||||||
|
self.additional_inference_specifications.extend(
|
||||||
|
additional_inference_specifications_to_add
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.additional_inference_specifications = (
|
||||||
|
additional_inference_specifications_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_additional_inference_specifications(
|
||||||
|
self, additional_inference_specifications: Optional[List[Dict[str, Any]]]
|
||||||
|
) -> None:
|
||||||
|
specifications_to_validate = additional_inference_specifications or []
|
||||||
|
for additional_inference_specification in specifications_to_validate:
|
||||||
|
if "SupportedTransformInstanceTypes" in additional_inference_specification:
|
||||||
|
self.validate_supported_transform_instance_types(
|
||||||
|
additional_inference_specification[
|
||||||
|
"SupportedTransformInstanceTypes"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"SupportedRealtimeInferenceInstanceTypes"
|
||||||
|
in additional_inference_specification
|
||||||
|
):
|
||||||
|
self.validate_supported_realtime_inference_instance_types(
|
||||||
|
additional_inference_specification[
|
||||||
|
"SupportedRealtimeInferenceInstanceTypes"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_supported_transform_instance_types(instance_types: List[str]) -> None:
|
||||||
|
VALID_TRANSFORM_INSTANCE_TYPES = [
|
||||||
|
"ml.m4.xlarge",
|
||||||
|
"ml.m4.2xlarge",
|
||||||
|
"ml.m4.4xlarge",
|
||||||
|
"ml.m4.10xlarge",
|
||||||
|
"ml.m4.16xlarge",
|
||||||
|
"ml.c4.xlarge",
|
||||||
|
"ml.c4.2xlarge",
|
||||||
|
"ml.c4.4xlarge",
|
||||||
|
"ml.c4.8xlarge",
|
||||||
|
"ml.p2.xlarge",
|
||||||
|
"ml.p2.8xlarge",
|
||||||
|
"ml.p2.16xlarge",
|
||||||
|
"ml.p3.2xlarge",
|
||||||
|
"ml.p3.8xlarge",
|
||||||
|
"ml.p3.16xlarge",
|
||||||
|
"ml.c5.xlarge",
|
||||||
|
"ml.c5.2xlarge",
|
||||||
|
"ml.c5.4xlarge",
|
||||||
|
"ml.c5.9xlarge",
|
||||||
|
"ml.c5.18xlarge",
|
||||||
|
"ml.m5.large",
|
||||||
|
"ml.m5.xlarge",
|
||||||
|
"ml.m5.2xlarge",
|
||||||
|
"ml.m5.4xlarge",
|
||||||
|
"ml.m5.12xlarge",
|
||||||
|
"ml.m5.24xlarge",
|
||||||
|
"ml.g4dn.xlarge",
|
||||||
|
"ml.g4dn.2xlarge",
|
||||||
|
"ml.g4dn.4xlarge",
|
||||||
|
"ml.g4dn.8xlarge",
|
||||||
|
"ml.g4dn.12xlarge",
|
||||||
|
"ml.g4dn.16xlarge",
|
||||||
|
]
|
||||||
|
for instance_type in instance_types:
|
||||||
|
if not validators.is_one_of(instance_type, VALID_TRANSFORM_INSTANCE_TYPES):
|
||||||
|
message = f"Value '{instance_type}' at 'SupportedTransformInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_TRANSFORM_INSTANCE_TYPES}"
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_supported_realtime_inference_instance_types(
|
||||||
|
instance_types: List[str],
|
||||||
|
) -> None:
|
||||||
|
VALID_REALTIME_INFERENCE_INSTANCE_TYPES = [
|
||||||
|
"ml.t2.medium",
|
||||||
|
"ml.t2.large",
|
||||||
|
"ml.t2.xlarge",
|
||||||
|
"ml.t2.2xlarge",
|
||||||
|
"ml.m4.xlarge",
|
||||||
|
"ml.m4.2xlarge",
|
||||||
|
"ml.m4.4xlarge",
|
||||||
|
"ml.m4.10xlarge",
|
||||||
|
"ml.m4.16xlarge",
|
||||||
|
"ml.m5.large",
|
||||||
|
"ml.m5.xlarge",
|
||||||
|
"ml.m5.2xlarge",
|
||||||
|
"ml.m5.4xlarge",
|
||||||
|
"ml.m5.12xlarge",
|
||||||
|
"ml.m5.24xlarge",
|
||||||
|
"ml.m5d.large",
|
||||||
|
"ml.m5d.xlarge",
|
||||||
|
"ml.m5d.2xlarge",
|
||||||
|
"ml.m5d.4xlarge",
|
||||||
|
"ml.m5d.12xlarge",
|
||||||
|
"ml.m5d.24xlarge",
|
||||||
|
"ml.c4.large",
|
||||||
|
"ml.c4.xlarge",
|
||||||
|
"ml.c4.2xlarge",
|
||||||
|
"ml.c4.4xlarge",
|
||||||
|
"ml.c4.8xlarge",
|
||||||
|
"ml.p2.xlarge",
|
||||||
|
"ml.p2.8xlarge",
|
||||||
|
"ml.p2.16xlarge",
|
||||||
|
"ml.p3.2xlarge",
|
||||||
|
"ml.p3.8xlarge",
|
||||||
|
"ml.p3.16xlarge",
|
||||||
|
"ml.c5.large",
|
||||||
|
"ml.c5.xlarge",
|
||||||
|
"ml.c5.2xlarge",
|
||||||
|
"ml.c5.4xlarge",
|
||||||
|
"ml.c5.9xlarge",
|
||||||
|
"ml.c5.18xlarge",
|
||||||
|
"ml.c5d.large",
|
||||||
|
"ml.c5d.xlarge",
|
||||||
|
"ml.c5d.2xlarge",
|
||||||
|
"ml.c5d.4xlarge",
|
||||||
|
"ml.c5d.9xlarge",
|
||||||
|
"ml.c5d.18xlarge",
|
||||||
|
"ml.g4dn.xlarge",
|
||||||
|
"ml.g4dn.2xlarge",
|
||||||
|
"ml.g4dn.4xlarge",
|
||||||
|
"ml.g4dn.8xlarge",
|
||||||
|
"ml.g4dn.12xlarge",
|
||||||
|
"ml.g4dn.16xlarge",
|
||||||
|
"ml.r5.large",
|
||||||
|
"ml.r5.xlarge",
|
||||||
|
"ml.r5.2xlarge",
|
||||||
|
"ml.r5.4xlarge",
|
||||||
|
"ml.r5.12xlarge",
|
||||||
|
"ml.r5.24xlarge",
|
||||||
|
"ml.r5d.large",
|
||||||
|
"ml.r5d.xlarge",
|
||||||
|
"ml.r5d.2xlarge",
|
||||||
|
"ml.r5d.4xlarge",
|
||||||
|
"ml.r5d.12xlarge",
|
||||||
|
"ml.r5d.24xlarge",
|
||||||
|
"ml.inf1.xlarge",
|
||||||
|
"ml.inf1.2xlarge",
|
||||||
|
"ml.inf1.6xlarge",
|
||||||
|
"ml.inf1.24xlarge",
|
||||||
|
"ml.c6i.large",
|
||||||
|
"ml.c6i.xlarge",
|
||||||
|
"ml.c6i.2xlarge",
|
||||||
|
"ml.c6i.4xlarge",
|
||||||
|
"ml.c6i.8xlarge",
|
||||||
|
"ml.c6i.12xlarge",
|
||||||
|
"ml.c6i.16xlarge",
|
||||||
|
"ml.c6i.24xlarge",
|
||||||
|
"ml.c6i.32xlarge",
|
||||||
|
"ml.g5.xlarge",
|
||||||
|
"ml.g5.2xlarge",
|
||||||
|
"ml.g5.4xlarge",
|
||||||
|
"ml.g5.8xlarge",
|
||||||
|
"ml.g5.12xlarge",
|
||||||
|
"ml.g5.16xlarge",
|
||||||
|
"ml.g5.24xlarge",
|
||||||
|
"ml.g5.48xlarge",
|
||||||
|
"ml.p4d.24xlarge",
|
||||||
|
"ml.c7g.large",
|
||||||
|
"ml.c7g.xlarge",
|
||||||
|
"ml.c7g.2xlarge",
|
||||||
|
"ml.c7g.4xlarge",
|
||||||
|
"ml.c7g.8xlarge",
|
||||||
|
"ml.c7g.12xlarge",
|
||||||
|
"ml.c7g.16xlarge",
|
||||||
|
"ml.m6g.large",
|
||||||
|
"ml.m6g.xlarge",
|
||||||
|
"ml.m6g.2xlarge",
|
||||||
|
"ml.m6g.4xlarge",
|
||||||
|
"ml.m6g.8xlarge",
|
||||||
|
"ml.m6g.12xlarge",
|
||||||
|
"ml.m6g.16xlarge",
|
||||||
|
"ml.m6gd.large",
|
||||||
|
"ml.m6gd.xlarge",
|
||||||
|
"ml.m6gd.2xlarge",
|
||||||
|
"ml.m6gd.4xlarge",
|
||||||
|
"ml.m6gd.8xlarge",
|
||||||
|
"ml.m6gd.12xlarge",
|
||||||
|
"ml.m6gd.16xlarge",
|
||||||
|
"ml.c6g.large",
|
||||||
|
"ml.c6g.xlarge",
|
||||||
|
"ml.c6g.2xlarge",
|
||||||
|
"ml.c6g.4xlarge",
|
||||||
|
"ml.c6g.8xlarge",
|
||||||
|
"ml.c6g.12xlarge",
|
||||||
|
"ml.c6g.16xlarge",
|
||||||
|
"ml.c6gd.large",
|
||||||
|
"ml.c6gd.xlarge",
|
||||||
|
"ml.c6gd.2xlarge",
|
||||||
|
"ml.c6gd.4xlarge",
|
||||||
|
"ml.c6gd.8xlarge",
|
||||||
|
"ml.c6gd.12xlarge",
|
||||||
|
"ml.c6gd.16xlarge",
|
||||||
|
"ml.c6gn.large",
|
||||||
|
"ml.c6gn.xlarge",
|
||||||
|
"ml.c6gn.2xlarge",
|
||||||
|
"ml.c6gn.4xlarge",
|
||||||
|
"ml.c6gn.8xlarge",
|
||||||
|
"ml.c6gn.12xlarge",
|
||||||
|
"ml.c6gn.16xlarge",
|
||||||
|
"ml.r6g.large",
|
||||||
|
"ml.r6g.xlarge",
|
||||||
|
"ml.r6g.2xlarge",
|
||||||
|
"ml.r6g.4xlarge",
|
||||||
|
"ml.r6g.8xlarge",
|
||||||
|
"ml.r6g.12xlarge",
|
||||||
|
"ml.r6g.16xlarge",
|
||||||
|
"ml.r6gd.large",
|
||||||
|
"ml.r6gd.xlarge",
|
||||||
|
"ml.r6gd.2xlarge",
|
||||||
|
"ml.r6gd.4xlarge",
|
||||||
|
"ml.r6gd.8xlarge",
|
||||||
|
"ml.r6gd.12xlarge",
|
||||||
|
"ml.r6gd.16xlarge",
|
||||||
|
"ml.p4de.24xlarge",
|
||||||
|
"ml.trn1.2xlarge",
|
||||||
|
"ml.trn1.32xlarge",
|
||||||
|
"ml.inf2.xlarge",
|
||||||
|
"ml.inf2.8xlarge",
|
||||||
|
"ml.inf2.24xlarge",
|
||||||
|
"ml.inf2.48xlarge",
|
||||||
|
"ml.p5.48xlarge",
|
||||||
|
]
|
||||||
|
for instance_type in instance_types:
|
||||||
|
if not validators.is_one_of(
|
||||||
|
instance_type, VALID_REALTIME_INFERENCE_INSTANCE_TYPES
|
||||||
|
):
|
||||||
|
message = f"Value '{instance_type}' at 'SupportedRealtimeInferenceInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_REALTIME_INFERENCE_INSTANCE_TYPES}"
|
||||||
|
raise ValidationError(message=message)
|
||||||
|
|
||||||
|
|
||||||
class VpcConfig(BaseObject):
|
class VpcConfig(BaseObject):
|
||||||
@ -3054,6 +3332,36 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
raise ValidationError(f"Model package {model_package_name} not found")
|
raise ValidationError(f"Model package {model_package_name} not found")
|
||||||
return model_package
|
return model_package
|
||||||
|
|
||||||
|
def update_model_package(
|
||||||
|
self,
|
||||||
|
model_package_arn: str,
|
||||||
|
model_approval_status: Optional[str],
|
||||||
|
approval_description: Optional[str],
|
||||||
|
customer_metadata_properties: Optional[Dict[str, str]],
|
||||||
|
customer_metadata_properties_to_remove: List[str],
|
||||||
|
additional_inference_specifications_to_add: Optional[List[Any]],
|
||||||
|
) -> str:
|
||||||
|
model_package_name_mapped = self.model_package_name_mapping.get(
|
||||||
|
model_package_arn, model_package_arn
|
||||||
|
)
|
||||||
|
model_package = self.model_packages.get(model_package_name_mapped)
|
||||||
|
|
||||||
|
if model_package is None:
|
||||||
|
raise ValidationError(f"Model package {model_package_arn} not found")
|
||||||
|
|
||||||
|
model_package.set_model_approval_status(model_approval_status)
|
||||||
|
model_package.approval_description = approval_description
|
||||||
|
model_package.customer_metadata_properties = customer_metadata_properties
|
||||||
|
model_package.remove_customer_metadata_property(
|
||||||
|
customer_metadata_properties_to_remove
|
||||||
|
)
|
||||||
|
model_package.add_additional_inference_specifications(
|
||||||
|
additional_inference_specifications_to_add
|
||||||
|
)
|
||||||
|
model_package.modifications_done()
|
||||||
|
|
||||||
|
return model_package.model_package_arn
|
||||||
|
|
||||||
def create_model_package(
|
def create_model_package(
|
||||||
self,
|
self,
|
||||||
model_package_name: str,
|
model_package_name: str,
|
||||||
@ -3062,9 +3370,9 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
inference_specification: Any,
|
inference_specification: Any,
|
||||||
validation_specification: Any,
|
validation_specification: Any,
|
||||||
source_algorithm_specification: Any,
|
source_algorithm_specification: Any,
|
||||||
certify_for_marketplace: Any,
|
certify_for_marketplace: bool,
|
||||||
tags: Any,
|
tags: Any,
|
||||||
model_approval_status: str,
|
model_approval_status: Optional[str],
|
||||||
metadata_properties: Any,
|
metadata_properties: Any,
|
||||||
model_metrics: Any,
|
model_metrics: Any,
|
||||||
client_token: Any,
|
client_token: Any,
|
||||||
@ -3102,7 +3410,7 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
sample_payload_url=sample_payload_url,
|
sample_payload_url=sample_payload_url,
|
||||||
additional_inference_specifications=additional_inference_specifications,
|
additional_inference_specifications=additional_inference_specifications,
|
||||||
model_package_version=model_package_version,
|
model_package_version=model_package_version,
|
||||||
approval_description=model_approval_status,
|
approval_description=None,
|
||||||
region_name=self.region_name,
|
region_name=self.region_name,
|
||||||
account_id=self.account_id,
|
account_id=self.account_id,
|
||||||
client_token=client_token,
|
client_token=client_token,
|
||||||
|
@ -874,6 +874,27 @@ class SageMakerResponse(BaseResponse):
|
|||||||
model_package.gen_response_object(),
|
model_package.gen_response_object(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_model_package(self) -> str:
|
||||||
|
model_package_arn = self._get_param("ModelPackageArn")
|
||||||
|
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||||
|
approval_dexcription = self._get_param("ApprovalDescription")
|
||||||
|
customer_metadata_properties = self._get_param("CustomerMetadataProperties")
|
||||||
|
customer_metadata_properties_to_remove = self._get_param(
|
||||||
|
"CustomerMetadataPropertiesToRemove", []
|
||||||
|
)
|
||||||
|
additional_inference_specification_to_add = self._get_param(
|
||||||
|
"AdditionalInferenceSpecificationsToAdd"
|
||||||
|
)
|
||||||
|
updated_model_package_arn = self.sagemaker_backend.update_model_package(
|
||||||
|
model_package_arn=model_package_arn,
|
||||||
|
model_approval_status=model_approval_status,
|
||||||
|
approval_description=approval_dexcription,
|
||||||
|
customer_metadata_properties=customer_metadata_properties,
|
||||||
|
customer_metadata_properties_to_remove=customer_metadata_properties_to_remove,
|
||||||
|
additional_inference_specifications_to_add=additional_inference_specification_to_add,
|
||||||
|
)
|
||||||
|
return json.dumps(dict(ModelPackageArn=updated_model_package_arn))
|
||||||
|
|
||||||
def create_model_package(self) -> str:
|
def create_model_package(self) -> str:
|
||||||
model_package_name = self._get_param("ModelPackageName")
|
model_package_name = self._get_param("ModelPackageName")
|
||||||
model_package_group_name = self._get_param("ModelPackageGroupName")
|
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||||
@ -881,7 +902,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
inference_specification = self._get_param("InferenceSpecification")
|
inference_specification = self._get_param("InferenceSpecification")
|
||||||
validation_specification = self._get_param("ValidationSpecification")
|
validation_specification = self._get_param("ValidationSpecification")
|
||||||
source_algorithm_specification = self._get_param("SourceAlgorithmSpecification")
|
source_algorithm_specification = self._get_param("SourceAlgorithmSpecification")
|
||||||
certify_for_marketplace = self._get_param("CertifyForMarketplace")
|
certify_for_marketplace = self._get_param("CertifyForMarketplace", False)
|
||||||
tags = self._get_param("Tags")
|
tags = self._get_param("Tags")
|
||||||
model_approval_status = self._get_param("ModelApprovalStatus")
|
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||||
metadata_properties = self._get_param("MetadataProperties")
|
metadata_properties = self._get_param("MetadataProperties")
|
||||||
@ -893,7 +914,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
task = self._get_param("Task")
|
task = self._get_param("Task")
|
||||||
sample_payload_url = self._get_param("SamplePayloadUrl")
|
sample_payload_url = self._get_param("SamplePayloadUrl")
|
||||||
additional_inference_specifications = self._get_param(
|
additional_inference_specifications = self._get_param(
|
||||||
"AdditionalInferenceSpecifications"
|
"AdditionalInferenceSpecifications", None
|
||||||
)
|
)
|
||||||
model_package_arn = self.sagemaker_backend.create_model_package(
|
model_package_arn = self.sagemaker_backend.create_model_package(
|
||||||
model_package_name=model_package_name,
|
model_package_name=model_package_name,
|
||||||
|
@ -5,7 +5,6 @@ import json
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from .exceptions import ValidationError
|
from .exceptions import ValidationError
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from .models import FakePipeline, FakePipelineExecution
|
from .models import FakePipeline, FakePipelineExecution
|
||||||
|
|
||||||
@ -51,3 +50,15 @@ def load_pipeline_definition_from_s3(
|
|||||||
|
|
||||||
def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
|
def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
|
||||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_approval_status(model_approval_status: typing.Optional[str]) -> None:
|
||||||
|
if model_approval_status is not None and model_approval_status not in [
|
||||||
|
"Approved",
|
||||||
|
"Rejected",
|
||||||
|
"PendingManualApproval",
|
||||||
|
]:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Value '{model_approval_status}' at 'modelApprovalStatus' failed to satisfy constraint: "
|
||||||
|
"Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
|
||||||
|
)
|
||||||
|
@ -1,13 +1,20 @@
|
|||||||
"""Unit tests for sagemaker-supported APIs."""
|
"""Unit tests for sagemaker-supported APIs."""
|
||||||
from unittest import SkipTest
|
from datetime import datetime
|
||||||
|
from unittest import SkipTest, TestCase
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from dateutil.tz import tzutc # type: ignore
|
||||||
|
|
||||||
from moto import mock_sagemaker, settings
|
from moto import mock_sagemaker, settings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||||
|
from moto.sagemaker.exceptions import ValidationError
|
||||||
|
from moto.sagemaker.models import ModelPackage
|
||||||
|
from moto.sagemaker.utils import validate_model_approval_status
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
@ -210,15 +217,194 @@ def test_list_model_packages_sort_order():
|
|||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
def test_describe_model_package():
|
def test_describe_model_package_default():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
client.create_model_package(
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
ModelPackageName="test-model-package",
|
with freeze_time("2015-01-01 00:00:00"):
|
||||||
ModelPackageDescription="test-model-package-description",
|
client.create_model_package(
|
||||||
)
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
assert resp["ModelPackageName"] == "test-model-package"
|
assert resp["ModelPackageName"] == "test-model-package"
|
||||||
|
assert resp["ModelPackageGroupName"] == "test-model-package-group"
|
||||||
assert resp["ModelPackageDescription"] == "test-model-package-description"
|
assert resp["ModelPackageDescription"] == "test-model-package-description"
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageArn"]
|
||||||
|
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
|
||||||
|
)
|
||||||
|
assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc())
|
||||||
|
assert (
|
||||||
|
resp["CreatedBy"]["UserProfileArn"]
|
||||||
|
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
|
||||||
|
)
|
||||||
|
assert resp["CreatedBy"]["UserProfileName"] == "fake-user-profile-name"
|
||||||
|
assert resp["CreatedBy"]["DomainId"] == "fake-domain-id"
|
||||||
|
assert resp["ModelPackageStatus"] == "Completed"
|
||||||
|
assert resp.get("ModelPackageStatusDetails") is not None
|
||||||
|
assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [
|
||||||
|
{
|
||||||
|
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||||
|
"Status": "Completed",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [
|
||||||
|
{
|
||||||
|
"Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1",
|
||||||
|
"Status": "Completed",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert resp["CertifyForMarketplace"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_describe_model_package_with_create_model_package_arguments():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
ModelApprovalStatus="PendingManualApproval",
|
||||||
|
MetadataProperties={
|
||||||
|
"CommitId": "test-commit-id",
|
||||||
|
"GeneratedBy": "test-user",
|
||||||
|
"ProjectId": "test-project-id",
|
||||||
|
"Repository": "test-repo",
|
||||||
|
},
|
||||||
|
CertifyForMarketplace=True,
|
||||||
|
)
|
||||||
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
|
assert resp["ModelApprovalStatus"] == "PendingManualApproval"
|
||||||
|
assert resp.get("ApprovalDescription") is None
|
||||||
|
assert resp["CertifyForMarketplace"] is True
|
||||||
|
assert resp["MetadataProperties"] is not None
|
||||||
|
assert resp["MetadataProperties"]["CommitId"] == "test-commit-id"
|
||||||
|
assert resp["MetadataProperties"]["GeneratedBy"] == "test-user"
|
||||||
|
assert resp["MetadataProperties"]["ProjectId"] == "test-project-id"
|
||||||
|
assert resp["MetadataProperties"]["Repository"] == "test-repo"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_update_model_package():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
|
model_package_arn = client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
CustomerMetadataProperties={
|
||||||
|
"test-key-to-remove": "test-value-to-remove",
|
||||||
|
},
|
||||||
|
)["ModelPackageArn"]
|
||||||
|
client.update_model_package(
|
||||||
|
ModelPackageArn=model_package_arn,
|
||||||
|
ModelApprovalStatus="Approved",
|
||||||
|
ApprovalDescription="test-approval-description",
|
||||||
|
CustomerMetadataProperties={"test-key": "test-value"},
|
||||||
|
CustomerMetadataPropertiesToRemove=["test-key-to-remove"],
|
||||||
|
)
|
||||||
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
|
assert resp["ModelApprovalStatus"] == "Approved"
|
||||||
|
assert resp["ApprovalDescription"] == "test-approval-description"
|
||||||
|
assert resp["CustomerMetadataProperties"] is not None
|
||||||
|
assert resp["CustomerMetadataProperties"]["test-key"] == "test-value"
|
||||||
|
assert resp["CustomerMetadataProperties"].get("test-key-to-remove") is None
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_update_model_package_given_additional_inference_specifications_to_add():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
|
model_package_arn = client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)["ModelPackageArn"]
|
||||||
|
additional_inference_specifications_to_add = {
|
||||||
|
"Name": "test-inference-specification-name",
|
||||||
|
"Description": "test-inference-specification-description",
|
||||||
|
"Containers": [
|
||||||
|
{
|
||||||
|
"ContainerHostname": "test-container-hostname",
|
||||||
|
"Image": "test-image",
|
||||||
|
"ImageDigest": "test-image-digest",
|
||||||
|
"ModelDataUrl": "test-model-data-url",
|
||||||
|
"ProductId": "test-product-id",
|
||||||
|
"Environment": {"test-key": "test-value"},
|
||||||
|
"ModelInput": {"DataInputConfig": "test-data-input-config"},
|
||||||
|
"Framework": "test-framework",
|
||||||
|
"FrameworkVersion": "test-framework-version",
|
||||||
|
"NearestModelName": "test-nearest-model-name",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
|
||||||
|
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.t2.large"],
|
||||||
|
"SupportedContentTypes": [
|
||||||
|
"test-content-type-1",
|
||||||
|
],
|
||||||
|
"SupportedResponseMIMETypes": [
|
||||||
|
"test-response-mime-type-1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
client.update_model_package(
|
||||||
|
ModelPackageArn=model_package_arn,
|
||||||
|
AdditionalInferenceSpecificationsToAdd=[
|
||||||
|
additional_inference_specifications_to_add
|
||||||
|
],
|
||||||
|
)
|
||||||
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
|
assert resp["AdditionalInferenceSpecifications"] is not None
|
||||||
|
TestCase().assertDictEqual(
|
||||||
|
resp["AdditionalInferenceSpecifications"][0],
|
||||||
|
additional_inference_specifications_to_add,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_update_model_package_shoudl_update_last_modified_information():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
|
||||||
|
model_package_arn = client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
)["ModelPackageArn"]
|
||||||
|
with freeze_time("2020-01-01 12:00:00"):
|
||||||
|
client.update_model_package(
|
||||||
|
ModelPackageArn=model_package_arn,
|
||||||
|
ModelApprovalStatus="Approved",
|
||||||
|
)
|
||||||
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
|
assert resp.get("LastModifiedTime") is not None
|
||||||
|
assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
|
||||||
|
assert resp.get("LastModifiedBy") is not None
|
||||||
|
assert (
|
||||||
|
resp["LastModifiedBy"]["UserProfileArn"]
|
||||||
|
== "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name"
|
||||||
|
)
|
||||||
|
assert resp["LastModifiedBy"]["UserProfileName"] == "fake-user-profile-name"
|
||||||
|
assert resp["LastModifiedBy"]["DomainId"] == "fake-domain-id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_supported_transform_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ModelPackage.validate_supported_transform_instance_types(
|
||||||
|
["ml.m4.2xlarge", "not-a-supported-transform-instances-types"]
|
||||||
|
)
|
||||||
|
assert "not-a-supported-transform-instances-types" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_supported_realtime_inference_instance_types_should_raise_error_for_wrong_supported_transform_instance_types():
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ModelPackage.validate_supported_realtime_inference_instance_types(
|
||||||
|
["ml.m4.2xlarge", "not-a-supported-realtime-inference-instances-types"]
|
||||||
|
)
|
||||||
|
assert "not-a-supported-realtime-inference-instances-types" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
@ -234,6 +420,27 @@ def test_create_model_package():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_approval_status",
|
||||||
|
["Approved", "Rejected", "PendingManualApproval"],
|
||||||
|
)
|
||||||
|
def test_utils_validate_model_approval_status_should_not_raise_error_if_model_approval_status_is_correct(
|
||||||
|
model_approval_status: str,
|
||||||
|
):
|
||||||
|
validate_model_approval_status(model_approval_status)
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils_validate_model_approval_status_should_raise_error_if_model_approval_status_is_incorrect():
|
||||||
|
model_approval_status = "IncorrectStatus"
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
validate_model_approval_status(model_approval_status)
|
||||||
|
assert exc.value.code == 400
|
||||||
|
assert (
|
||||||
|
exc.value.message
|
||||||
|
== "Value 'IncorrectStatus' at 'modelApprovalStatus' failed to satisfy constraint: Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
def test_create_model_package_in_model_package_group():
|
def test_create_model_package_in_model_package_group():
|
||||||
client = boto3.client("sagemaker", region_name="eu-west-1")
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
Loading…
Reference in New Issue
Block a user