diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index d4c932d3e..896dd3b66 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -3,7 +3,7 @@ import os import random import string from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union, cast from dateutil.tz import tzutc @@ -1004,6 +1004,7 @@ class ModelPackage(BaseObject): client_token: str, region_name: str, account_id: str, + model_package_type: str, tags: Optional[List[Dict[str, str]]] = None, ) -> None: fake_user_profile_name = "fake-user-profile-name" @@ -1032,6 +1033,7 @@ class ModelPackage(BaseObject): self.inference_specification = inference_specification self.source_algorithm_specification = source_algorithm_specification self.validation_specification = validation_specification + self.model_package_type = model_package_type self.model_package_status_details = { "ValidationStatuses": [ { @@ -1105,6 +1107,10 @@ class ModelPackage(BaseObject): "AdditionalInferenceSpecifications", "SkipModelValidation", ] + if self.model_package_type == "Versioned": + del response_object["ModelPackageName"] + elif self.model_package_type == "Unversioned": + del response_object["ModelPackageGroupName"] return { k: v for k, v in response_object.items() @@ -3295,6 +3301,8 @@ class SageMakerModelBackend(BaseBackend): ) if model_package_group_name is not None: model_package_type = "Versioned" + if model_package_group_name.startswith("arn:aws"): + model_package_group_name = model_package_group_name.split("/")[-1] model_package_summary_list = list( filter( lambda x: ( @@ -3377,7 +3385,7 @@ class SageMakerModelBackend(BaseBackend): def create_model_package( self, - model_package_name: str, + model_package_name: Optional[str], model_package_group_name: Optional[str], model_package_description: Optional[str], inference_specification: Any, @@ -3397,15 +3405,32 @@ class SageMakerModelBackend(BaseBackend): additional_inference_specifications: Any, ) -> str: model_package_version = None - if model_package_group_name is not None: + if model_package_group_name and model_package_name: + raise AWSValidationException( + "An error occurred (ValidationException) when calling the CreateModelPackage operation: Both ModelPackageName and ModelPackageGroupName are provided in the input. Cannot determine which one to use." + ) + elif not model_package_group_name and not model_package_name: + raise AWSValidationException( + "An error ocurred (ValidationException) when calling the CreateModelPackag operation: Missing ARN." + ) + elif model_package_group_name: + model_package_type = "Versioned" + model_package_name = model_package_group_name model_packages_for_group = [ x for x in self.model_packages.values() if x.model_package_group_name == model_package_group_name ] + if model_package_group_name not in self.model_package_groups: + raise AWSValidationException( + "An error ocurred (ValidationException) when calling the CreateModelPackage operation: Model Package Group does not exist." + ) model_package_version = len(model_packages_for_group) + 1 + else: + model_package_type = "Unversioned" + model_package = ModelPackage( - model_package_name=model_package_name, + model_package_name=cast(str, model_package_name), model_package_group_name=model_package_group_name, model_package_description=model_package_description, inference_specification=inference_specification, @@ -3427,6 +3452,7 @@ class SageMakerModelBackend(BaseBackend): region_name=self.region_name, account_id=self.account_id, client_token=client_token, + model_package_type=model_package_type, ) self.model_package_name_mapping[ model_package.model_package_name diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py index 1a8f303b6..5d6b0d240 100644 --- a/tests/test_sagemaker/test_sagemaker_model_packages.py +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -4,6 +4,7 @@ from unittest import SkipTest, TestCase import boto3 import pytest +from botocore.exceptions import ClientError from dateutil.tz import tzutc # type: ignore from freezegun import freeze_time @@ -133,19 +134,17 @@ def test_list_model_packages_approval_status(): def test_list_model_packages_model_package_group_name(): client = boto3.client("sagemaker", region_name="eu-west-1") group1 = "test-model-package-group" + client.create_model_package_group( + ModelPackageGroupName=group1, + ModelPackageGroupDescription="test-model-package-description", + ) client.create_model_package( - ModelPackageName="test-model-package", - ModelPackageDescription="test-model-package-description", ModelPackageGroupName=group1, ) client.create_model_package( - ModelPackageName="test-model-package", - ModelPackageDescription="test-model-package-description-2", ModelPackageGroupName=group1, ) client.create_model_package( - ModelPackageName="test-model-package-2", - ModelPackageDescription="test-model-package-description-3", ModelPackageGroupName=group1, ) client.create_model_package( @@ -170,8 +169,8 @@ def test_list_model_packages_model_package_group_name(): @mock_sagemaker def test_list_model_packages_model_package_type(): client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package_group(ModelPackageGroupName="test-model-package-group") client.create_model_package( - ModelPackageName="test-model-package", ModelPackageDescription="test-model-package-description", ModelPackageGroupName="test-model-package-group", ) @@ -233,18 +232,16 @@ def test_describe_model_package_default(): client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package_group(ModelPackageGroupName="test-model-package-group") with freeze_time("2015-01-01 00:00:00"): - client.create_model_package( - ModelPackageName="test-model-package", + model_package_arn = client.create_model_package( ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", - ) - resp = client.describe_model_package(ModelPackageName="test-model-package") - assert resp["ModelPackageName"] == "test-model-package" + )["ModelPackageArn"] + resp = client.describe_model_package(ModelPackageName=model_package_arn) assert resp["ModelPackageGroupName"] == "test-model-package-group" assert resp["ModelPackageDescription"] == "test-model-package-description" assert ( resp["ModelPackageArn"] - == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1" + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1" ) assert resp["CreationTime"] == datetime(2015, 1, 1, 0, 0, 0, tzinfo=tzutc()) assert ( @@ -257,13 +254,13 @@ def test_describe_model_package_default(): assert resp.get("ModelPackageStatusDetails") is not None assert resp["ModelPackageStatusDetails"]["ValidationStatuses"] == [ { - "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1", + "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1", "Status": "Completed", } ] assert resp["ModelPackageStatusDetails"]["ImageScanStatuses"] == [ { - "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1", + "Name": "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1", "Status": "Completed", } ] @@ -273,10 +270,8 @@ def test_describe_model_package_default(): @mock_sagemaker def test_describe_model_package_with_create_model_package_arguments(): client = boto3.client("sagemaker", region_name="eu-west-1") - client.create_model_package_group(ModelPackageGroupName="test-model-package-group") client.create_model_package( ModelPackageName="test-model-package", - ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", ModelApprovalStatus="PendingManualApproval", MetadataProperties={ @@ -303,7 +298,6 @@ def test_update_model_package(): client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package_group(ModelPackageGroupName="test-model-package-group") model_package_arn = client.create_model_package( - ModelPackageName="test-model-package", ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", CustomerMetadataProperties={ @@ -317,7 +311,7 @@ def test_update_model_package(): CustomerMetadataProperties={"test-key": "test-value"}, CustomerMetadataPropertiesToRemove=["test-key-to-remove"], ) - resp = client.describe_model_package(ModelPackageName="test-model-package") + resp = client.describe_model_package(ModelPackageName=model_package_arn) assert resp["ModelApprovalStatus"] == "Approved" assert resp["ApprovalDescription"] == "test-approval-description" assert resp["CustomerMetadataProperties"] is not None @@ -330,7 +324,6 @@ def test_update_model_package_given_additional_inference_specifications_to_add() client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package_group(ModelPackageGroupName="test-model-package-group") model_package_arn = client.create_model_package( - ModelPackageName="test-model-package", ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", )["ModelPackageArn"] @@ -366,7 +359,7 @@ def test_update_model_package_given_additional_inference_specifications_to_add() additional_inference_specifications_to_add ], ) - resp = client.describe_model_package(ModelPackageName="test-model-package") + resp = client.describe_model_package(ModelPackageName=model_package_arn) assert resp["AdditionalInferenceSpecifications"] is not None TestCase().assertDictEqual( resp["AdditionalInferenceSpecifications"][0], @@ -375,13 +368,12 @@ def test_update_model_package_given_additional_inference_specifications_to_add() @mock_sagemaker -def test_update_model_package_shoudl_update_last_modified_information(): +def test_update_model_package_should_update_last_modified_information(): if settings.TEST_SERVER_MODE: raise SkipTest("Can't freeze time in ServerMode") client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package_group(ModelPackageGroupName="test-model-package-group") model_package_arn = client.create_model_package( - ModelPackageName="test-model-package", ModelPackageGroupName="test-model-package-group", )["ModelPackageArn"] with freeze_time("2020-01-01 12:00:00"): @@ -389,7 +381,7 @@ def test_update_model_package_shoudl_update_last_modified_information(): ModelPackageArn=model_package_arn, ModelApprovalStatus="Approved", ) - resp = client.describe_model_package(ModelPackageName="test-model-package") + resp = client.describe_model_package(ModelPackageName=model_package_arn) assert resp.get("LastModifiedTime") is not None assert resp["LastModifiedTime"] == datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc()) assert resp.get("LastModifiedBy") is not None @@ -430,6 +422,42 @@ def test_create_model_package(): ) +@mock_sagemaker +def test_create_model_package_should_raise_error_when_model_package_group_name_and_model_package_group_name_are_not_provided(): + client = boto3.client("sagemaker", region_name="eu-west-1") + with pytest.raises(ClientError) as exc: + _ = client.create_model_package( + ModelPackageDescription="test-model-package-description", + ) + assert "Missing ARN." in str(exc.value) + + +@mock_sagemaker +def test_create_model_package_should_raise_error_when_package_name_and_group_are_provided(): + client = boto3.client("sagemaker", region_name="eu-west-1") + with pytest.raises(ClientError) as exc: + _ = client.create_model_package( + ModelPackageGroupName="TestModelPackageGroup", + ModelPackageName="TestModelPackage", + ModelPackageDescription="test-model-package-description", + ) + assert ( + "Both ModelPackageName and ModelPackageGroupName are provided in the input." + in str(exc.value) + ) + + +@mock_sagemaker +def test_create_model_package_should_raise_error_when_model_package_group_provided_not_exist(): + client = boto3.client("sagemaker", region_name="eu-west-1") + with pytest.raises(ClientError) as exc: + _ = client.create_model_package( + ModelPackageGroupName="TestModelPackageGroupNotExist", + ModelPackageDescription="test-model-package-description", + ) + assert "Model Package Group does not exist." in str(exc.value) + + @pytest.mark.parametrize( "model_approval_status", ["Approved", "Rejected", "PendingManualApproval"], @@ -456,20 +484,18 @@ def test_create_model_package_in_model_package_group(): client = boto3.client("sagemaker", region_name="eu-west-1") client.create_model_package_group(ModelPackageGroupName="test-model-package-group") resp_version_1 = client.create_model_package( - ModelPackageName="TestModelPackage", ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", ) resp_version_2 = client.create_model_package( - ModelPackageName="TestModelPackage", ModelPackageGroupName="test-model-package-group", ModelPackageDescription="test-model-package-description", ) assert ( resp_version_1["ModelPackageArn"] - == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/1" + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/1" ) assert ( resp_version_2["ModelPackageArn"] - == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/testmodelpackage/2" + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package-group/2" )