From 34bc540f70a99a2870a0c9984b9c5dc9546ec74e Mon Sep 17 00:00:00 2001 From: HALLOUARD <57447861+YHallouard@users.noreply.github.com> Date: Sun, 8 Oct 2023 22:16:50 +0200 Subject: [PATCH] Sagemaker: Add update_model_package (#6891) --- IMPLEMENTATION_COVERAGE.md | 2 +- moto/sagemaker/models.py | 370 ++++++++++++++++-- moto/sagemaker/responses.py | 25 +- moto/sagemaker/utils.py | 13 +- .../test_sagemaker_model_packages.py | 219 ++++++++++- 5 files changed, 588 insertions(+), 41 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 053a5dc86..086121c49 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6491,7 +6491,7 @@ - [ ] update_image_version - [ ] update_inference_experiment - [ ] update_model_card -- [ ] update_model_package +- [X] update_model_package - [ ] update_monitoring_alert - [ ] update_monitoring_schedule - [ ] update_notebook_instance diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index c1e94056e..8a36e33ee 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -19,10 +19,10 @@ from .utils import ( get_pipeline_from_name, get_pipeline_execution_from_arn, get_pipeline_name_from_execution_arn, + validate_model_approval_status, ) from .utils import load_pipeline_definition_from_s3, arn_formatter - PAGINATION_MODEL = { "list_experiments": { "input_token": "NextToken", @@ -988,10 +988,10 @@ class ModelPackage(BaseObject): source_algorithm_specification: Any, validation_specification: Any, certify_for_marketplace: bool, - model_approval_status: str, + model_approval_status: Optional[str], metadata_properties: Any, model_metrics: Any, - approval_description: str, + approval_description: Optional[str], customer_metadata_properties: Any, drift_check_baselines: Any, domain: str, @@ -1029,24 +1029,23 @@ class ModelPackage(BaseObject): self.inference_specification = inference_specification self.source_algorithm_specification = source_algorithm_specification self.validation_specification = validation_specification - self.model_package_status_details = ( - { - "ValidationStatuses": [ - { - "Name": model_package_arn, - "Status": "Completed", - } - ], - "ImageScanStatuses": [ - { - "Name": model_package_arn, - "Status": "Completed", - } - ], - }, - ) + self.model_package_status_details = { + "ValidationStatuses": [ + { + "Name": model_package_arn, + "Status": "Completed", + } + ], + "ImageScanStatuses": [ + { + "Name": model_package_arn, + "Status": "Completed", + } + ], + } self.certify_for_marketplace = certify_for_marketplace - self.model_approval_status = model_approval_status + self.model_approval_status: Optional[str] = None + self.set_model_approval_status(model_approval_status) self.created_by = { "UserProfileArn": fake_user_profile_arn, "UserProfileName": fake_user_profile_name, @@ -1054,21 +1053,20 @@ class ModelPackage(BaseObject): } self.metadata_properties = metadata_properties self.model_metrics = model_metrics - self.last_modified_time = datetime_now + self.last_modified_time: Optional[datetime] = None self.approval_description = approval_description self.customer_metadata_properties = customer_metadata_properties self.drift_check_baselines = drift_check_baselines self.domain = domain self.task = task self.sample_payload_url = sample_payload_url - self.additional_inference_specifications = additional_inference_specifications + self.additional_inference_specifications: Optional[List[Any]] = None + self.add_additional_inference_specifications( + additional_inference_specifications + ) self.tags = tags self.model_package_status = "Completed" - self.last_modified_by = { - "UserProfileArn": fake_user_profile_arn, - "UserProfileName": fake_user_profile_name, - "DomainId": fake_domain_id, - } + self.last_modified_by: Optional[Dict[str, str]] = None self.client_token = client_token def gen_response_object(self) -> Dict[str, Any]: @@ -1083,10 +1081,290 @@ class ModelPackage(BaseObject): "ModelPackageArn", "ModelPackageDescription", "CreationTime", + "InferenceSpecification", + "SourceAlgorithmSpecification", + "ValidationSpecification", "ModelPackageStatus", + "ModelPackageStatusDetails", + "CertifyForMarketplace", "ModelApprovalStatus", + "CreatedBy", + "MetadataProperties", + "ModelMetrics", + "LastModifiedTime", + "LastModifiedBy", + "ApprovalDescription", + "CustomerMetadataProperties", + "DriftCheckBaselines", + "Domain", + "Task", + "SamplePayloadUrl", + "AdditionalInferenceSpecifications", + "SkipModelValidation", ] - return {k: v for k, v in response_object.items() if k in response_values} + return { + k: v + for k, v in response_object.items() + if k in response_values + if v is not None + } + + def modifications_done(self) -> None: + self.last_modified_time = datetime.now(tzutc()) + self.last_modified_by = self.created_by + + def set_model_approval_status(self, model_approval_status: Optional[str]) -> None: + if model_approval_status is not None: + validate_model_approval_status(model_approval_status) + self.model_approval_status = model_approval_status + + def remove_customer_metadata_property( + self, customer_metadata_properties_to_remove: List[str] + ) -> None: + if customer_metadata_properties_to_remove is not None: + for customer_metadata_property in customer_metadata_properties_to_remove: + self.customer_metadata_properties.pop(customer_metadata_property, None) + + def add_additional_inference_specifications( + self, additional_inference_specifications_to_add: Optional[List[Any]] + ) -> None: + self.validate_additional_inference_specifications( + additional_inference_specifications_to_add + ) + if ( + self.additional_inference_specifications is not None + and additional_inference_specifications_to_add is not None + ): + self.additional_inference_specifications.extend( + additional_inference_specifications_to_add + ) + else: + self.additional_inference_specifications = ( + additional_inference_specifications_to_add + ) + + def validate_additional_inference_specifications( + self, additional_inference_specifications: Optional[List[Dict[str, Any]]] + ) -> None: + specifications_to_validate = additional_inference_specifications or [] + for additional_inference_specification in specifications_to_validate: + if "SupportedTransformInstanceTypes" in additional_inference_specification: + self.validate_supported_transform_instance_types( + additional_inference_specification[ + "SupportedTransformInstanceTypes" + ] + ) + if ( + "SupportedRealtimeInferenceInstanceTypes" + in additional_inference_specification + ): + self.validate_supported_realtime_inference_instance_types( + additional_inference_specification[ + "SupportedRealtimeInferenceInstanceTypes" + ] + ) + + @staticmethod + def validate_supported_transform_instance_types(instance_types: List[str]) -> None: + VALID_TRANSFORM_INSTANCE_TYPES = [ + "ml.m4.xlarge", + "ml.m4.2xlarge", + "ml.m4.4xlarge", + "ml.m4.10xlarge", + "ml.m4.16xlarge", + "ml.c4.xlarge", + "ml.c4.2xlarge", + "ml.c4.4xlarge", + "ml.c4.8xlarge", + "ml.p2.xlarge", + "ml.p2.8xlarge", + "ml.p2.16xlarge", + "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g4dn.16xlarge", + ] + for instance_type in instance_types: + if not validators.is_one_of(instance_type, VALID_TRANSFORM_INSTANCE_TYPES): + message = f"Value '{instance_type}' at 'SupportedTransformInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_TRANSFORM_INSTANCE_TYPES}" + raise ValidationError(message=message) + + @staticmethod + def validate_supported_realtime_inference_instance_types( + instance_types: List[str], + ) -> None: + VALID_REALTIME_INFERENCE_INSTANCE_TYPES = [ + "ml.t2.medium", + "ml.t2.large", + "ml.t2.xlarge", + "ml.t2.2xlarge", + "ml.m4.xlarge", + "ml.m4.2xlarge", + "ml.m4.4xlarge", + "ml.m4.10xlarge", + "ml.m4.16xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.m5d.large", + "ml.m5d.xlarge", + "ml.m5d.2xlarge", + "ml.m5d.4xlarge", + "ml.m5d.12xlarge", + "ml.m5d.24xlarge", + "ml.c4.large", + "ml.c4.xlarge", + "ml.c4.2xlarge", + "ml.c4.4xlarge", + "ml.c4.8xlarge", + "ml.p2.xlarge", + "ml.p2.8xlarge", + "ml.p2.16xlarge", + "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.c5.large", + "ml.c5.xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.c5d.large", + "ml.c5d.xlarge", + "ml.c5d.2xlarge", + "ml.c5d.4xlarge", + "ml.c5d.9xlarge", + "ml.c5d.18xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g4dn.16xlarge", + "ml.r5.large", + "ml.r5.xlarge", + "ml.r5.2xlarge", + "ml.r5.4xlarge", + "ml.r5.12xlarge", + "ml.r5.24xlarge", + "ml.r5d.large", + "ml.r5d.xlarge", + "ml.r5d.2xlarge", + "ml.r5d.4xlarge", + "ml.r5d.12xlarge", + "ml.r5d.24xlarge", + "ml.inf1.xlarge", + "ml.inf1.2xlarge", + "ml.inf1.6xlarge", + "ml.inf1.24xlarge", + "ml.c6i.large", + "ml.c6i.xlarge", + "ml.c6i.2xlarge", + "ml.c6i.4xlarge", + "ml.c6i.8xlarge", + "ml.c6i.12xlarge", + "ml.c6i.16xlarge", + "ml.c6i.24xlarge", + "ml.c6i.32xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.16xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + "ml.c7g.large", + "ml.c7g.xlarge", + "ml.c7g.2xlarge", + "ml.c7g.4xlarge", + "ml.c7g.8xlarge", + "ml.c7g.12xlarge", + "ml.c7g.16xlarge", + "ml.m6g.large", + "ml.m6g.xlarge", + "ml.m6g.2xlarge", + "ml.m6g.4xlarge", + "ml.m6g.8xlarge", + "ml.m6g.12xlarge", + "ml.m6g.16xlarge", + "ml.m6gd.large", + "ml.m6gd.xlarge", + "ml.m6gd.2xlarge", + "ml.m6gd.4xlarge", + "ml.m6gd.8xlarge", + "ml.m6gd.12xlarge", + "ml.m6gd.16xlarge", + "ml.c6g.large", + "ml.c6g.xlarge", + "ml.c6g.2xlarge", + "ml.c6g.4xlarge", + "ml.c6g.8xlarge", + "ml.c6g.12xlarge", + "ml.c6g.16xlarge", + "ml.c6gd.large", + "ml.c6gd.xlarge", + "ml.c6gd.2xlarge", + "ml.c6gd.4xlarge", + "ml.c6gd.8xlarge", + "ml.c6gd.12xlarge", + "ml.c6gd.16xlarge", + "ml.c6gn.large", + "ml.c6gn.xlarge", + "ml.c6gn.2xlarge", + "ml.c6gn.4xlarge", + "ml.c6gn.8xlarge", + "ml.c6gn.12xlarge", + "ml.c6gn.16xlarge", + "ml.r6g.large", + "ml.r6g.xlarge", + "ml.r6g.2xlarge", + "ml.r6g.4xlarge", + "ml.r6g.8xlarge", + "ml.r6g.12xlarge", + "ml.r6g.16xlarge", + "ml.r6gd.large", + "ml.r6gd.xlarge", + "ml.r6gd.2xlarge", + "ml.r6gd.4xlarge", + "ml.r6gd.8xlarge", + "ml.r6gd.12xlarge", + "ml.r6gd.16xlarge", + "ml.p4de.24xlarge", + "ml.trn1.2xlarge", + "ml.trn1.32xlarge", + "ml.inf2.xlarge", + "ml.inf2.8xlarge", + "ml.inf2.24xlarge", + "ml.inf2.48xlarge", + "ml.p5.48xlarge", + ] + for instance_type in instance_types: + if not validators.is_one_of( + instance_type, VALID_REALTIME_INFERENCE_INSTANCE_TYPES + ): + message = f"Value '{instance_type}' at 'SupportedRealtimeInferenceInstanceTypes' failed to satisfy constraint: Member must satisfy enum value set: {VALID_REALTIME_INFERENCE_INSTANCE_TYPES}" + raise ValidationError(message=message) class VpcConfig(BaseObject): @@ -3054,6 +3332,36 @@ class SageMakerModelBackend(BaseBackend): raise ValidationError(f"Model package {model_package_name} not found") return model_package + def update_model_package( + self, + model_package_arn: str, + model_approval_status: Optional[str], + approval_description: Optional[str], + customer_metadata_properties: Optional[Dict[str, str]], + customer_metadata_properties_to_remove: List[str], + additional_inference_specifications_to_add: Optional[List[Any]], + ) -> str: + model_package_name_mapped = self.model_package_name_mapping.get( + model_package_arn, model_package_arn + ) + model_package = self.model_packages.get(model_package_name_mapped) + + if model_package is None: + raise ValidationError(f"Model package {model_package_arn} not found") + + model_package.set_model_approval_status(model_approval_status) + model_package.approval_description = approval_description + model_package.customer_metadata_properties = customer_metadata_properties + model_package.remove_customer_metadata_property( + customer_metadata_properties_to_remove + ) + model_package.add_additional_inference_specifications( + additional_inference_specifications_to_add + ) + model_package.modifications_done() + + return model_package.model_package_arn + def create_model_package( self, model_package_name: str, @@ -3062,9 +3370,9 @@ class SageMakerModelBackend(BaseBackend): inference_specification: Any, validation_specification: Any, source_algorithm_specification: Any, - certify_for_marketplace: Any, + certify_for_marketplace: bool, tags: Any, - model_approval_status: str, + model_approval_status: Optional[str], metadata_properties: Any, model_metrics: Any, client_token: Any, @@ -3102,7 +3410,7 @@ class SageMakerModelBackend(BaseBackend): sample_payload_url=sample_payload_url, additional_inference_specifications=additional_inference_specifications, model_package_version=model_package_version, - approval_description=model_approval_status, + approval_description=None, region_name=self.region_name, account_id=self.account_id, client_token=client_token, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index ec83c10e2..acddec2c8 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -874,6 +874,27 @@ class SageMakerResponse(BaseResponse): model_package.gen_response_object(), ) + def update_model_package(self) -> str: + model_package_arn = self._get_param("ModelPackageArn") + model_approval_status = self._get_param("ModelApprovalStatus") + approval_dexcription = self._get_param("ApprovalDescription") + customer_metadata_properties = self._get_param("CustomerMetadataProperties") + customer_metadata_properties_to_remove = self._get_param( + "CustomerMetadataPropertiesToRemove", [] + ) + additional_inference_specification_to_add = self._get_param( + "AdditionalInferenceSpecificationsToAdd" + ) + updated_model_package_arn = self.sagemaker_backend.update_model_package( + model_package_arn=model_package_arn, + model_approval_status=model_approval_status, + approval_description=approval_dexcription, + customer_metadata_properties=customer_metadata_properties, + customer_metadata_properties_to_remove=customer_metadata_properties_to_remove, + additional_inference_specifications_to_add=additional_inference_specification_to_add, + ) + return json.dumps(dict(ModelPackageArn=updated_model_package_arn)) + def create_model_package(self) -> str: model_package_name = self._get_param("ModelPackageName") model_package_group_name = self._get_param("ModelPackageGroupName") @@ -881,7 +902,7 @@ class SageMakerResponse(BaseResponse): inference_specification = self._get_param("InferenceSpecification") validation_specification = self._get_param("ValidationSpecification") source_algorithm_specification = self._get_param("SourceAlgorithmSpecification") - certify_for_marketplace = self._get_param("CertifyForMarketplace") + certify_for_marketplace = self._get_param("CertifyForMarketplace", False) tags = self._get_param("Tags") model_approval_status = self._get_param("ModelApprovalStatus") metadata_properties = self._get_param("MetadataProperties") @@ -893,7 +914,7 @@ class SageMakerResponse(BaseResponse): task = self._get_param("Task") sample_payload_url = self._get_param("SamplePayloadUrl") additional_inference_specifications = self._get_param( - "AdditionalInferenceSpecifications" + "AdditionalInferenceSpecifications", None ) model_package_arn = self.sagemaker_backend.create_model_package( model_package_name=model_package_name, diff --git a/moto/sagemaker/utils.py b/moto/sagemaker/utils.py index 6b6b9175e..f10f85108 100644 --- a/moto/sagemaker/utils.py +++ b/moto/sagemaker/utils.py @@ -5,7 +5,6 @@ import json from typing import Any, Dict from .exceptions import ValidationError - if typing.TYPE_CHECKING: from .models import FakePipeline, FakePipelineExecution @@ -51,3 +50,15 @@ def load_pipeline_definition_from_s3( def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str: return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}" + + +def validate_model_approval_status(model_approval_status: typing.Optional[str]) -> None: + if model_approval_status is not None and model_approval_status not in [ + "Approved", + "Rejected", + "PendingManualApproval", + ]: + raise ValidationError( + f"Value '{model_approval_status}' at 'modelApprovalStatus' failed to satisfy constraint: " + "Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]" + ) diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py index c510c2c9c..617d5a27b 100644 --- a/tests/test_sagemaker/test_sagemaker_model_packages.py +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -1,13 +1,20 @@ """Unit tests for sagemaker-supported APIs.""" -from unittest import SkipTest +from datetime import datetime +from unittest import SkipTest, TestCase import boto3 from freezegun import freeze_time +from dateutil.tz import tzutc # type: ignore from moto import mock_sagemaker, settings +import pytest + # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html +from moto.sagemaker.exceptions import ValidationError +from moto.sagemaker.models import ModelPackage +from moto.sagemaker.utils import validate_model_approval_status @mock_sagemaker @@ -210,15 +217,194 @@ def test_list_model_packages_sort_order(): @mock_sagemaker -def test_describe_model_package(): +def test_describe_model_package_default(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") client = boto3.client("sagemaker", region_name="eu-west-1") - client.create_model_package( - ModelPackageName="test-model-package", - ModelPackageDescription="test-model-package-description", - ) + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + with freeze_time("2015-01-01 00:00:00"): + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + ModelPackageDescription="test-model-package-description", + ) resp = client.describe_model_package(ModelPackageName="test-model-package") assert resp["ModelPackageName"] == "test-model-package" + assert resp["ModelPackageGroupName"] == "test-model-package-group" assert resp["ModelPackageDescription"] == "test-model-package-description" + assert ( + resp["ModelPackageArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1" + ) + assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc()) + assert ( + resp["CreatedBy"]["UserProfileArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name" + ) + assert resp["CreatedBy"]["UserProfileName"] == "fake-user-profile-name" + assert resp["CreatedBy"]["DomainId"] == "fake-domain-id" + assert resp["ModelPackageStatus"] == "Completed" + assert resp.get("ModelPackageStatusDetails") is not None + assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [ + { + "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1", + "Status": "Completed", + } + ] + assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [ + { + "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1", + "Status": "Completed", + } + ] + assert resp["CertifyForMarketplace"] is False + + +@mock_sagemaker +def test_describe_model_package_with_create_model_package_arguments(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + ModelPackageDescription="test-model-package-description", + ModelApprovalStatus="PendingManualApproval", + MetadataProperties={ + "CommitId": "test-commit-id", + "GeneratedBy": "test-user", + "ProjectId": "test-project-id", + "Repository": "test-repo", + }, + CertifyForMarketplace=True, + ) + resp = client.describe_model_package(ModelPackageName="test-model-package") + assert resp["ModelApprovalStatus"] == "PendingManualApproval" + assert resp.get("ApprovalDescription") is None + assert resp["CertifyForMarketplace"] is True + assert resp["MetadataProperties"] is not None + assert resp["MetadataProperties"]["CommitId"] == "test-commit-id" + assert resp["MetadataProperties"]["GeneratedBy"] == "test-user" + assert resp["MetadataProperties"]["ProjectId"] == "test-project-id" + assert resp["MetadataProperties"]["Repository"] == "test-repo" + + +@mock_sagemaker +def test_update_model_package(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + model_package_arn = client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + ModelPackageDescription="test-model-package-description", + CustomerMetadataProperties={ + "test-key-to-remove": "test-value-to-remove", + }, + )["ModelPackageArn"] + client.update_model_package( + ModelPackageArn=model_package_arn, + ModelApprovalStatus="Approved", + ApprovalDescription="test-approval-description", + CustomerMetadataProperties={"test-key": "test-value"}, + CustomerMetadataPropertiesToRemove=["test-key-to-remove"], + ) + resp = client.describe_model_package(ModelPackageName="test-model-package") + assert resp["ModelApprovalStatus"] == "Approved" + assert resp["ApprovalDescription"] == "test-approval-description" + assert resp["CustomerMetadataProperties"] is not None + assert resp["CustomerMetadataProperties"]["test-key"] == "test-value" + assert resp["CustomerMetadataProperties"].get("test-key-to-remove") is None + + +@mock_sagemaker +def test_update_model_package_given_additional_inference_specifications_to_add(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + model_package_arn = client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + ModelPackageDescription="test-model-package-description", + )["ModelPackageArn"] + additional_inference_specifications_to_add = { + "Name": "test-inference-specification-name", + "Description": "test-inference-specification-description", + "Containers": [ + { + "ContainerHostname": "test-container-hostname", + "Image": "test-image", + "ImageDigest": "test-image-digest", + "ModelDataUrl": "test-model-data-url", + "ProductId": "test-product-id", + "Environment": {"test-key": "test-value"}, + "ModelInput": {"DataInputConfig": "test-data-input-config"}, + "Framework": "test-framework", + "FrameworkVersion": "test-framework-version", + "NearestModelName": "test-nearest-model-name", + }, + ], + "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + "SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.t2.large"], + "SupportedContentTypes": [ + "test-content-type-1", + ], + "SupportedResponseMIMETypes": [ + "test-response-mime-type-1", + ], + } + client.update_model_package( + ModelPackageArn=model_package_arn, + AdditionalInferenceSpecificationsToAdd=[ + additional_inference_specifications_to_add + ], + ) + resp = client.describe_model_package(ModelPackageName="test-model-package") + assert resp["AdditionalInferenceSpecifications"] is not None + TestCase().assertDictEqual( + resp["AdditionalInferenceSpecifications"][0], + additional_inference_specifications_to_add, + ) + + +@mock_sagemaker +def test_update_model_package_shoudl_update_last_modified_information(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") + model_package_arn = client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageGroupName="test-model-package-group", + )["ModelPackageArn"] + with freeze_time("2020-01-01 12:00:00"): + client.update_model_package( + ModelPackageArn=model_package_arn, + ModelApprovalStatus="Approved", + ) + resp = client.describe_model_package(ModelPackageName="test-model-package") + assert resp.get("LastModifiedTime") is not None + assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc()) + assert resp.get("LastModifiedBy") is not None + assert ( + resp["LastModifiedBy"]["UserProfileArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name" + ) + assert resp["LastModifiedBy"]["UserProfileName"] == "fake-user-profile-name" + assert resp["LastModifiedBy"]["DomainId"] == "fake-domain-id" + + +def test_validate_supported_transform_instance_types_should_raise_error_for_wrong_supported_transform_instance_types(): + with pytest.raises(ValidationError) as exc: + ModelPackage.validate_supported_transform_instance_types( + ["ml.m4.2xlarge", "not-a-supported-transform-instances-types"] + ) + assert "not-a-supported-transform-instances-types" in str(exc.value) + + +def test_validate_supported_realtime_inference_instance_types_should_raise_error_for_wrong_supported_transform_instance_types(): + with pytest.raises(ValidationError) as exc: + ModelPackage.validate_supported_realtime_inference_instance_types( + ["ml.m4.2xlarge", "not-a-supported-realtime-inference-instances-types"] + ) + assert "not-a-supported-realtime-inference-instances-types" in str(exc.value) @mock_sagemaker @@ -234,6 +420,27 @@ def test_create_model_package(): ) +@pytest.mark.parametrize( + "model_approval_status", + ["Approved", "Rejected", "PendingManualApproval"], +) +def test_utils_validate_model_approval_status_should_not_raise_error_if_model_approval_status_is_correct( + model_approval_status: str, +): + validate_model_approval_status(model_approval_status) + + +def test_utils_validate_model_approval_status_should_raise_error_if_model_approval_status_is_incorrect(): + model_approval_status = "IncorrectStatus" + with pytest.raises(ValidationError) as exc: + validate_model_approval_status(model_approval_status) + assert exc.value.code == 400 + assert ( + exc.value.message + == "Value 'IncorrectStatus' at 'modelApprovalStatus' failed to satisfy constraint: Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]" + ) + + @mock_sagemaker def test_create_model_package_in_model_package_group(): client = boto3.client("sagemaker", region_name="eu-west-1")