diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 5a9da6260..a5bf3c6da 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -51,6 +51,13 @@ PAGINATION_MODEL = { "unique_attribute": "Key", "fail_on_invalid_token": True, }, + "list_model_packages": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "ModelPackageArn", + "fail_on_invalid_token": True, + }, } @@ -911,6 +918,152 @@ class Model(BaseObject, CloudFormationModel): sagemaker_backends[account_id][region_name].delete_model(model_name) +class ModelPackageGroup(BaseObject): + def __init__( + self, + model_package_group_name: str, + model_package_group_description: str, + account_id: str, + region_name: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + model_package_group_arn = arn_formatter( + region_name=region_name, + account_id=account_id, + _type="model-package-group", + _id=model_package_group_name, + ) + fake_user_profile_name = "fake-user-profile-name" + fake_domain_id = "fake-domain-id" + fake_user_profile_arn = arn_formatter( + _type="user-profile", + _id=f"{fake_domain_id}/{fake_user_profile_name}", + account_id=account_id, + region_name=region_name, + ) + self.model_package_group_name = model_package_group_name + self.model_package_group_arn = model_package_group_arn + self.model_package_group_description = model_package_group_description + self.creation_time = datetime.now() + self.created_by = { + "UserProfileArn": fake_user_profile_arn, + "UserProfileName": fake_user_profile_name, + "DomainId": fake_domain_id, + } + self.model_package_group_status = "Completed" + self.tags = tags + + +class ModelPackage(BaseObject): + def __init__( + self, + model_package_name: str, + model_package_group_name: Optional[str], + model_package_version: Optional[int], + model_package_description: Optional[str], + inference_specification: Any, + source_algorithm_specification: Any, + validation_specification: Any, + certify_for_marketplace: bool, + model_approval_status: str, + metadata_properties: Any, + model_metrics: Any, + approval_description: str, + customer_metadata_properties: Any, + drift_check_baselines: Any, + domain: str, + task: str, + sample_payload_url: str, + additional_inference_specifications: List[Any], + client_token: str, + region_name: str, + account_id: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + fake_user_profile_name = "fake-user-profile-name" + fake_domain_id = "fake-domain-id" + fake_user_profile_arn = arn_formatter( + _type="user-profile", + _id=f"{fake_domain_id}/{fake_user_profile_name}", + account_id=account_id, + region_name=region_name, + ) + model_package_arn = arn_formatter( + region_name=region_name, + account_id=account_id, + _type="model-package", + _id=model_package_name, + ) + datetime_now = datetime.utcnow() + self.model_package_name = model_package_name + self.model_package_group_name = model_package_group_name + self.model_package_version = model_package_version + self.model_package_arn = model_package_arn + self.model_package_description = model_package_description + self.creation_time = datetime_now + self.inference_specification = inference_specification + self.source_algorithm_specification = source_algorithm_specification + self.validation_specification = validation_specification + self.model_package_status_details = ( + { + "ValidationStatuses": [ + { + "Name": model_package_arn, + "Status": "Completed", + } + ], + "ImageScanStatuses": [ + { + "Name": model_package_arn, + "Status": "Completed", + } + ], + }, + ) + self.certify_for_marketplace = certify_for_marketplace + self.model_approval_status = model_approval_status + self.created_by = { + "UserProfileArn": fake_user_profile_arn, + "UserProfileName": fake_user_profile_name, + "DomainId": fake_domain_id, + } + self.metadata_properties = metadata_properties + self.model_metrics = model_metrics + self.last_modified_time = datetime_now + self.approval_description = approval_description + self.customer_metadata_properties = customer_metadata_properties + self.drift_check_baselines = drift_check_baselines + self.domain = domain + self.task = task + self.sample_payload_url = sample_payload_url + self.additional_inference_specifications = additional_inference_specifications + self.tags = tags + self.model_package_status = "Completed" + self.last_modified_by = { + "UserProfileArn": fake_user_profile_arn, + "UserProfileName": fake_user_profile_name, + "DomainId": fake_domain_id, + } + self.client_token = client_token + + def gen_response_object(self) -> Dict[str, Any]: + response_object = super().gen_response_object() + for k, v in response_object.items(): + if isinstance(v, datetime): + response_object[k] = v.isoformat() + response_values = [ + "ModelPackageName", + "ModelPackageGroupName", + "ModelPackageVersion", + "ModelPackageArn", + "ModelPackageDescription", + "CreationTime", + "ModelPackageStatus", + "ModelApprovalStatus", + ] + return {k: v for k, v in response_object.items() if k in response_values} + + class VpcConfig(BaseObject): def __init__(self, security_group_ids: List[str], subnets: List[str]): self.security_group_ids = security_group_ids @@ -1277,6 +1430,9 @@ class SageMakerModelBackend(BaseBackend): self.notebook_instance_lifecycle_configurations: Dict[ str, FakeSageMakerNotebookInstanceLifecycleConfig ] = {} + self.model_package_groups: Dict[str, ModelPackageGroup] = {} + self.model_packages: Dict[str, ModelPackage] = {} + self.model_package_name_mapping: Dict[str, str] = {} @staticmethod def default_vpc_endpoint_service( @@ -2671,6 +2827,164 @@ class SageMakerModelBackend(BaseBackend): endpoint.endpoint_status = "InService" return endpoint.endpoint_arn + def create_model_package_group( + self, + model_package_group_name: str, + model_package_group_description: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> str: + self.model_package_groups[model_package_group_name] = ModelPackageGroup( + model_package_group_name=model_package_group_name, + model_package_group_description=model_package_group_description, + account_id=self.account_id, + region_name=self.region_name, + tags=tags, + ) + return self.model_package_groups[ + model_package_group_name + ].model_package_group_arn + + def _get_versioned_or_not( + self, model_package_type: Optional[str], model_package_version: Optional[int] + ) -> bool: + if model_package_type == "Versioned": + return model_package_version is not None + elif model_package_type == "Unversioned" or model_package_type is None: + return model_package_version is None + elif model_package_type == "Both": + return True + raise ValueError(f"Invalid model package type: {model_package_type}") + + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] + def list_model_packages( # type: ignore[misc] + self, + creation_time_after: Optional[int], + creation_time_before: Optional[int], + name_contains: Optional[str], + model_approval_status: Optional[str], + model_package_group_name: Optional[str], + model_package_type: Optional[str], + sort_by: Optional[str], + sort_order: Optional[str], + ) -> List[ModelPackage]: + if isinstance(creation_time_before, int): + creation_time_before_datetime = datetime.fromtimestamp(creation_time_before) + if isinstance(creation_time_after, int): + creation_time_after_datetime = datetime.fromtimestamp(creation_time_after) + if model_package_group_name is not None: + model_package_type = "Versioned" + model_package_summary_list = list( + filter( + lambda x: ( + creation_time_after is None + or x.creation_time > creation_time_after_datetime + ) + and ( + creation_time_before is None + or x.creation_time < creation_time_before_datetime + ) + and ( + name_contains is None + or x.model_package_name.find(name_contains) != -1 + ) + and ( + model_approval_status is None + or x.model_approval_status == model_approval_status + ) + and ( + model_package_group_name is None + or x.model_package_group_name == model_package_group_name + ) + and self._get_versioned_or_not( + model_package_type, x.model_package_version + ), + self.model_packages.values(), + ) + ) + model_package_summary_list = list( + sorted( + model_package_summary_list, + key={ + "Name": lambda x: x.model_package_name, + "CreationTime": lambda x: x.creation_time, + None: lambda x: x.creation_time, + }[sort_by], + reverse=sort_order == "Descending", + ) + ) + return model_package_summary_list + + def describe_model_package(self, model_package_name: str) -> ModelPackage: + model_package_name_mapped = self.model_package_name_mapping.get( + model_package_name, model_package_name + ) + model_package = self.model_packages.get(model_package_name_mapped) + if model_package is None: + raise ValidationError(f"Model package {model_package_name} not found") + return model_package + + def create_model_package( + self, + model_package_name: str, + model_package_group_name: Optional[str], + model_package_description: Optional[str], + inference_specification: Any, + validation_specification: Any, + source_algorithm_specification: Any, + certify_for_marketplace: Any, + tags: Any, + model_approval_status: str, + metadata_properties: Any, + model_metrics: Any, + client_token: Any, + customer_metadata_properties: Any, + drift_check_baselines: Any, + domain: Any, + task: Any, + sample_payload_url: Any, + additional_inference_specifications: Any, + ) -> str: + model_package_version = None + if model_package_group_name is not None: + model_packages_for_group = [ + x + for x in self.model_packages.values() + if x.model_package_group_name == model_package_group_name + ] + model_package_version = len(model_packages_for_group) + 1 + model_package = ModelPackage( + model_package_name=model_package_name, + model_package_group_name=model_package_group_name, + model_package_description=model_package_description, + inference_specification=inference_specification, + validation_specification=validation_specification, + source_algorithm_specification=source_algorithm_specification, + certify_for_marketplace=certify_for_marketplace, + tags=tags, + model_approval_status=model_approval_status, + metadata_properties=metadata_properties, + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + drift_check_baselines=drift_check_baselines, + domain=domain, + task=task, + sample_payload_url=sample_payload_url, + additional_inference_specifications=additional_inference_specifications, + model_package_version=model_package_version, + approval_description=model_approval_status, + region_name=self.region_name, + account_id=self.account_id, + client_token=client_token, + ) + self.model_package_name_mapping[ + model_package.model_package_name + ] = model_package.model_package_arn + self.model_package_name_mapping[ + model_package.model_package_arn + ] = model_package.model_package_arn + self.model_packages[model_package.model_package_arn] = model_package + return model_package.model_package_arn + class FakeExperiment(BaseObject): def __init__( diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 6eba11d57..d04184dfc 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -796,3 +796,104 @@ class SageMakerResponse(BaseResponse): desired_weights_and_capacities=desired_weights_and_capacities, ) return 200, {}, json.dumps({"EndpointArn": endpoint_arn}) + + def list_model_packages(self) -> str: + creation_time_after = self._get_param("CreationTimeAfter") + creation_time_before = self._get_param("CreationTimeBefore") + max_results = self._get_param("MaxResults") + name_contains = self._get_param("NameContains") + model_approval_status = self._get_param("ModelApprovalStatus") + model_package_group_name = self._get_param("ModelPackageGroupName") + model_package_type = self._get_param("ModelPackageType", "Unversioned") + next_token = self._get_param("NextToken") + sort_by = self._get_param("SortBy") + sort_order = self._get_param("SortOrder") + ( + model_package_summary_list, + next_token, + ) = self.sagemaker_backend.list_model_packages( + creation_time_after=creation_time_after, + creation_time_before=creation_time_before, + max_results=max_results, + name_contains=name_contains, + model_approval_status=model_approval_status, + model_package_group_name=model_package_group_name, + model_package_type=model_package_type, + next_token=next_token, + sort_by=sort_by, + sort_order=sort_order, + ) + model_package_summary_list_response_object = [ + x.gen_response_object() for x in model_package_summary_list + ] + return json.dumps( + dict( + ModelPackageSummaryList=model_package_summary_list_response_object, + NextToken=next_token, + ) + ) + + def describe_model_package(self) -> str: + model_package_name = self._get_param("ModelPackageName") + model_package = self.sagemaker_backend.describe_model_package( + model_package_name=model_package_name, + ) + return json.dumps( + model_package.gen_response_object(), + ) + + def create_model_package(self) -> str: + model_package_name = self._get_param("ModelPackageName") + model_package_group_name = self._get_param("ModelPackageGroupName") + model_package_description = self._get_param("ModelPackageDescription") + inference_specification = self._get_param("InferenceSpecification") + validation_specification = self._get_param("ValidationSpecification") + source_algorithm_specification = self._get_param("SourceAlgorithmSpecification") + certify_for_marketplace = self._get_param("CertifyForMarketplace") + tags = self._get_param("Tags") + model_approval_status = self._get_param("ModelApprovalStatus") + metadata_properties = self._get_param("MetadataProperties") + model_metrics = self._get_param("ModelMetrics") + client_token = self._get_param("ClientToken") + customer_metadata_properties = self._get_param("CustomerMetadataProperties") + drift_check_baselines = self._get_param("DriftCheckBaselines") + domain = self._get_param("Domain") + task = self._get_param("Task") + sample_payload_url = self._get_param("SamplePayloadUrl") + additional_inference_specifications = self._get_param( + "AdditionalInferenceSpecifications" + ) + model_package_arn = self.sagemaker_backend.create_model_package( + model_package_name=model_package_name, + model_package_group_name=model_package_group_name, + model_package_description=model_package_description, + inference_specification=inference_specification, + validation_specification=validation_specification, + source_algorithm_specification=source_algorithm_specification, + certify_for_marketplace=certify_for_marketplace, + tags=tags, + model_approval_status=model_approval_status, + metadata_properties=metadata_properties, + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + drift_check_baselines=drift_check_baselines, + domain=domain, + task=task, + sample_payload_url=sample_payload_url, + additional_inference_specifications=additional_inference_specifications, + client_token=client_token, + ) + return json.dumps(dict(ModelPackageArn=model_package_arn)) + + def create_model_package_group(self) -> str: + model_package_group_name = self._get_param("ModelPackageGroupName") + model_package_group_description = self._get_param( + "ModelPackageGroupDescription" + ) + tags = self._get_param("Tags") + model_package_group_arn = self.sagemaker_backend.create_model_package_group( + model_package_group_name=model_package_group_name, + model_package_group_description=model_package_group_description, + tags=tags, + ) + return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn)) diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py new file mode 100644 index 000000000..cbb3f61fc --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -0,0 +1,236 @@ +"""Unit tests for sagemaker-supported APIs.""" +import boto3 +from freezegun import freeze_time + +from moto import mock_sagemaker, settings +from unittest import SkipTest + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_sagemaker +def test_list_model_packages(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages() + + assert ( + resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package" + ) + assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0] + assert ( + resp["ModelPackageSummaryList"][0]["ModelPackageDescription"] + == "test-model-package-description" + ) + assert ( + resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2" + ) + assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1] + assert ( + resp["ModelPackageSummaryList"][1]["ModelPackageDescription"] + == "test-model-package-description-2" + ) + + +@mock_sagemaker +def test_list_model_packages_creation_time_before(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + with freeze_time("2020-01-01 00:00:00"): + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + with freeze_time("2021-01-01 00:00:00"): + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages(CreationTimeBefore="2020-01-01T02:00:00Z") + + assert len(resp["ModelPackageSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_packages_creation_time_after(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Can't freeze time in ServerMode") + client = boto3.client("sagemaker", region_name="eu-west-1") + with freeze_time("2020-01-01 00:00:00"): + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + with freeze_time("2021-01-01 00:00:00"): + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages(CreationTimeAfter="2020-01-02T00:00:00Z") + + assert len(resp["ModelPackageSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_packages_name_contains(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + client.create_model_package( + ModelPackageName="another-model-package", + ModelPackageDescription="test-model-package-description-3", + ) + resp = client.list_model_packages(NameContains="test-model-package") + + assert len(resp["ModelPackageSummaryList"]) == 2 + + +@mock_sagemaker +def test_list_model_packages_approval_status(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ModelApprovalStatus="Approved", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ModelApprovalStatus="Rejected", + ) + resp = client.list_model_packages(ModelApprovalStatus="Approved") + + assert len(resp["ModelPackageSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_packages_model_package_group_name(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ModelPackageGroupName="test-model-package-group", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ModelPackageGroupName="test-model-package-group", + ) + resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group") + + assert len(resp["ModelPackageSummaryList"]) == 2 + + +@mock_sagemaker +def test_list_model_packages_model_package_type(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ModelPackageGroupName="test-model-package-group", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages(ModelPackageType="Versioned") + + assert len(resp["ModelPackageSummaryList"]) == 1 + + +@mock_sagemaker +def test_list_model_packages_sort_by(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages(SortBy="CreationTime") + + assert ( + resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package" + ) + assert ( + resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2" + ) + + +@mock_sagemaker +def test_list_model_packages_sort_order(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + client.create_model_package( + ModelPackageName="test-model-package-2", + ModelPackageDescription="test-model-package-description-2", + ) + resp = client.list_model_packages(SortOrder="Descending") + + assert ( + resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package-2" + ) + assert ( + resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package" + ) + + +@mock_sagemaker +def test_describe_model_package(): + client = boto3.client("sagemaker", region_name="eu-west-1") + client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + resp = client.describe_model_package(ModelPackageName="test-model-package") + assert resp["ModelPackageName"] == "test-model-package" + assert resp["ModelPackageDescription"] == "test-model-package-description" + + +@mock_sagemaker +def test_create_model_package(): + client = boto3.client("sagemaker", region_name="eu-west-1") + resp = client.create_model_package( + ModelPackageName="test-model-package", + ModelPackageDescription="test-model-package-description", + ) + assert ( + resp["ModelPackageArn"] + == "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package" + ) + + +@mock_sagemaker +def test_create_model_package_group(): + client = boto3.client("sagemaker", region_name="us-east-2") + resp = client.create_model_package_group( + ModelPackageGroupName="test-model-package-group", + ModelPackageGroupDescription="test-model-package-group-description", + Tags=[ + {"Key": "test-key", "Value": "test-value"}, + ], + ) + assert ( + resp["ModelPackageGroupArn"] + == "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group" + )