Sagemaker: Add create_model_package_group (#6454)
This commit is contained in:
parent
9222b81825
commit
8cc5155cc6
@ -51,6 +51,13 @@ PAGINATION_MODEL = {
|
|||||||
"unique_attribute": "Key",
|
"unique_attribute": "Key",
|
||||||
"fail_on_invalid_token": True,
|
"fail_on_invalid_token": True,
|
||||||
},
|
},
|
||||||
|
"list_model_packages": {
|
||||||
|
"input_token": "next_token",
|
||||||
|
"limit_key": "max_results",
|
||||||
|
"limit_default": 100,
|
||||||
|
"unique_attribute": "ModelPackageArn",
|
||||||
|
"fail_on_invalid_token": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -911,6 +918,152 @@ class Model(BaseObject, CloudFormationModel):
|
|||||||
sagemaker_backends[account_id][region_name].delete_model(model_name)
|
sagemaker_backends[account_id][region_name].delete_model(model_name)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPackageGroup(BaseObject):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_package_group_name: str,
|
||||||
|
model_package_group_description: str,
|
||||||
|
account_id: str,
|
||||||
|
region_name: str,
|
||||||
|
tags: Optional[List[Dict[str, str]]] = None,
|
||||||
|
) -> None:
|
||||||
|
model_package_group_arn = arn_formatter(
|
||||||
|
region_name=region_name,
|
||||||
|
account_id=account_id,
|
||||||
|
_type="model-package-group",
|
||||||
|
_id=model_package_group_name,
|
||||||
|
)
|
||||||
|
fake_user_profile_name = "fake-user-profile-name"
|
||||||
|
fake_domain_id = "fake-domain-id"
|
||||||
|
fake_user_profile_arn = arn_formatter(
|
||||||
|
_type="user-profile",
|
||||||
|
_id=f"{fake_domain_id}/{fake_user_profile_name}",
|
||||||
|
account_id=account_id,
|
||||||
|
region_name=region_name,
|
||||||
|
)
|
||||||
|
self.model_package_group_name = model_package_group_name
|
||||||
|
self.model_package_group_arn = model_package_group_arn
|
||||||
|
self.model_package_group_description = model_package_group_description
|
||||||
|
self.creation_time = datetime.now()
|
||||||
|
self.created_by = {
|
||||||
|
"UserProfileArn": fake_user_profile_arn,
|
||||||
|
"UserProfileName": fake_user_profile_name,
|
||||||
|
"DomainId": fake_domain_id,
|
||||||
|
}
|
||||||
|
self.model_package_group_status = "Completed"
|
||||||
|
self.tags = tags
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPackage(BaseObject):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_package_name: str,
|
||||||
|
model_package_group_name: Optional[str],
|
||||||
|
model_package_version: Optional[int],
|
||||||
|
model_package_description: Optional[str],
|
||||||
|
inference_specification: Any,
|
||||||
|
source_algorithm_specification: Any,
|
||||||
|
validation_specification: Any,
|
||||||
|
certify_for_marketplace: bool,
|
||||||
|
model_approval_status: str,
|
||||||
|
metadata_properties: Any,
|
||||||
|
model_metrics: Any,
|
||||||
|
approval_description: str,
|
||||||
|
customer_metadata_properties: Any,
|
||||||
|
drift_check_baselines: Any,
|
||||||
|
domain: str,
|
||||||
|
task: str,
|
||||||
|
sample_payload_url: str,
|
||||||
|
additional_inference_specifications: List[Any],
|
||||||
|
client_token: str,
|
||||||
|
region_name: str,
|
||||||
|
account_id: str,
|
||||||
|
tags: Optional[List[Dict[str, str]]] = None,
|
||||||
|
) -> None:
|
||||||
|
fake_user_profile_name = "fake-user-profile-name"
|
||||||
|
fake_domain_id = "fake-domain-id"
|
||||||
|
fake_user_profile_arn = arn_formatter(
|
||||||
|
_type="user-profile",
|
||||||
|
_id=f"{fake_domain_id}/{fake_user_profile_name}",
|
||||||
|
account_id=account_id,
|
||||||
|
region_name=region_name,
|
||||||
|
)
|
||||||
|
model_package_arn = arn_formatter(
|
||||||
|
region_name=region_name,
|
||||||
|
account_id=account_id,
|
||||||
|
_type="model-package",
|
||||||
|
_id=model_package_name,
|
||||||
|
)
|
||||||
|
datetime_now = datetime.utcnow()
|
||||||
|
self.model_package_name = model_package_name
|
||||||
|
self.model_package_group_name = model_package_group_name
|
||||||
|
self.model_package_version = model_package_version
|
||||||
|
self.model_package_arn = model_package_arn
|
||||||
|
self.model_package_description = model_package_description
|
||||||
|
self.creation_time = datetime_now
|
||||||
|
self.inference_specification = inference_specification
|
||||||
|
self.source_algorithm_specification = source_algorithm_specification
|
||||||
|
self.validation_specification = validation_specification
|
||||||
|
self.model_package_status_details = (
|
||||||
|
{
|
||||||
|
"ValidationStatuses": [
|
||||||
|
{
|
||||||
|
"Name": model_package_arn,
|
||||||
|
"Status": "Completed",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"ImageScanStatuses": [
|
||||||
|
{
|
||||||
|
"Name": model_package_arn,
|
||||||
|
"Status": "Completed",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.certify_for_marketplace = certify_for_marketplace
|
||||||
|
self.model_approval_status = model_approval_status
|
||||||
|
self.created_by = {
|
||||||
|
"UserProfileArn": fake_user_profile_arn,
|
||||||
|
"UserProfileName": fake_user_profile_name,
|
||||||
|
"DomainId": fake_domain_id,
|
||||||
|
}
|
||||||
|
self.metadata_properties = metadata_properties
|
||||||
|
self.model_metrics = model_metrics
|
||||||
|
self.last_modified_time = datetime_now
|
||||||
|
self.approval_description = approval_description
|
||||||
|
self.customer_metadata_properties = customer_metadata_properties
|
||||||
|
self.drift_check_baselines = drift_check_baselines
|
||||||
|
self.domain = domain
|
||||||
|
self.task = task
|
||||||
|
self.sample_payload_url = sample_payload_url
|
||||||
|
self.additional_inference_specifications = additional_inference_specifications
|
||||||
|
self.tags = tags
|
||||||
|
self.model_package_status = "Completed"
|
||||||
|
self.last_modified_by = {
|
||||||
|
"UserProfileArn": fake_user_profile_arn,
|
||||||
|
"UserProfileName": fake_user_profile_name,
|
||||||
|
"DomainId": fake_domain_id,
|
||||||
|
}
|
||||||
|
self.client_token = client_token
|
||||||
|
|
||||||
|
def gen_response_object(self) -> Dict[str, Any]:
|
||||||
|
response_object = super().gen_response_object()
|
||||||
|
for k, v in response_object.items():
|
||||||
|
if isinstance(v, datetime):
|
||||||
|
response_object[k] = v.isoformat()
|
||||||
|
response_values = [
|
||||||
|
"ModelPackageName",
|
||||||
|
"ModelPackageGroupName",
|
||||||
|
"ModelPackageVersion",
|
||||||
|
"ModelPackageArn",
|
||||||
|
"ModelPackageDescription",
|
||||||
|
"CreationTime",
|
||||||
|
"ModelPackageStatus",
|
||||||
|
"ModelApprovalStatus",
|
||||||
|
]
|
||||||
|
return {k: v for k, v in response_object.items() if k in response_values}
|
||||||
|
|
||||||
|
|
||||||
class VpcConfig(BaseObject):
|
class VpcConfig(BaseObject):
|
||||||
def __init__(self, security_group_ids: List[str], subnets: List[str]):
|
def __init__(self, security_group_ids: List[str], subnets: List[str]):
|
||||||
self.security_group_ids = security_group_ids
|
self.security_group_ids = security_group_ids
|
||||||
@ -1277,6 +1430,9 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
self.notebook_instance_lifecycle_configurations: Dict[
|
self.notebook_instance_lifecycle_configurations: Dict[
|
||||||
str, FakeSageMakerNotebookInstanceLifecycleConfig
|
str, FakeSageMakerNotebookInstanceLifecycleConfig
|
||||||
] = {}
|
] = {}
|
||||||
|
self.model_package_groups: Dict[str, ModelPackageGroup] = {}
|
||||||
|
self.model_packages: Dict[str, ModelPackage] = {}
|
||||||
|
self.model_package_name_mapping: Dict[str, str] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(
|
def default_vpc_endpoint_service(
|
||||||
@ -2671,6 +2827,164 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
endpoint.endpoint_status = "InService"
|
endpoint.endpoint_status = "InService"
|
||||||
return endpoint.endpoint_arn
|
return endpoint.endpoint_arn
|
||||||
|
|
||||||
|
def create_model_package_group(
|
||||||
|
self,
|
||||||
|
model_package_group_name: str,
|
||||||
|
model_package_group_description: str,
|
||||||
|
tags: Optional[List[Dict[str, str]]] = None,
|
||||||
|
) -> str:
|
||||||
|
self.model_package_groups[model_package_group_name] = ModelPackageGroup(
|
||||||
|
model_package_group_name=model_package_group_name,
|
||||||
|
model_package_group_description=model_package_group_description,
|
||||||
|
account_id=self.account_id,
|
||||||
|
region_name=self.region_name,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
return self.model_package_groups[
|
||||||
|
model_package_group_name
|
||||||
|
].model_package_group_arn
|
||||||
|
|
||||||
|
def _get_versioned_or_not(
|
||||||
|
self, model_package_type: Optional[str], model_package_version: Optional[int]
|
||||||
|
) -> bool:
|
||||||
|
if model_package_type == "Versioned":
|
||||||
|
return model_package_version is not None
|
||||||
|
elif model_package_type == "Unversioned" or model_package_type is None:
|
||||||
|
return model_package_version is None
|
||||||
|
elif model_package_type == "Both":
|
||||||
|
return True
|
||||||
|
raise ValueError(f"Invalid model package type: {model_package_type}")
|
||||||
|
|
||||||
|
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
|
||||||
|
def list_model_packages( # type: ignore[misc]
|
||||||
|
self,
|
||||||
|
creation_time_after: Optional[int],
|
||||||
|
creation_time_before: Optional[int],
|
||||||
|
name_contains: Optional[str],
|
||||||
|
model_approval_status: Optional[str],
|
||||||
|
model_package_group_name: Optional[str],
|
||||||
|
model_package_type: Optional[str],
|
||||||
|
sort_by: Optional[str],
|
||||||
|
sort_order: Optional[str],
|
||||||
|
) -> List[ModelPackage]:
|
||||||
|
if isinstance(creation_time_before, int):
|
||||||
|
creation_time_before_datetime = datetime.fromtimestamp(creation_time_before)
|
||||||
|
if isinstance(creation_time_after, int):
|
||||||
|
creation_time_after_datetime = datetime.fromtimestamp(creation_time_after)
|
||||||
|
if model_package_group_name is not None:
|
||||||
|
model_package_type = "Versioned"
|
||||||
|
model_package_summary_list = list(
|
||||||
|
filter(
|
||||||
|
lambda x: (
|
||||||
|
creation_time_after is None
|
||||||
|
or x.creation_time > creation_time_after_datetime
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
creation_time_before is None
|
||||||
|
or x.creation_time < creation_time_before_datetime
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
name_contains is None
|
||||||
|
or x.model_package_name.find(name_contains) != -1
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
model_approval_status is None
|
||||||
|
or x.model_approval_status == model_approval_status
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
model_package_group_name is None
|
||||||
|
or x.model_package_group_name == model_package_group_name
|
||||||
|
)
|
||||||
|
and self._get_versioned_or_not(
|
||||||
|
model_package_type, x.model_package_version
|
||||||
|
),
|
||||||
|
self.model_packages.values(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model_package_summary_list = list(
|
||||||
|
sorted(
|
||||||
|
model_package_summary_list,
|
||||||
|
key={
|
||||||
|
"Name": lambda x: x.model_package_name,
|
||||||
|
"CreationTime": lambda x: x.creation_time,
|
||||||
|
None: lambda x: x.creation_time,
|
||||||
|
}[sort_by],
|
||||||
|
reverse=sort_order == "Descending",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return model_package_summary_list
|
||||||
|
|
||||||
|
def describe_model_package(self, model_package_name: str) -> ModelPackage:
|
||||||
|
model_package_name_mapped = self.model_package_name_mapping.get(
|
||||||
|
model_package_name, model_package_name
|
||||||
|
)
|
||||||
|
model_package = self.model_packages.get(model_package_name_mapped)
|
||||||
|
if model_package is None:
|
||||||
|
raise ValidationError(f"Model package {model_package_name} not found")
|
||||||
|
return model_package
|
||||||
|
|
||||||
|
def create_model_package(
|
||||||
|
self,
|
||||||
|
model_package_name: str,
|
||||||
|
model_package_group_name: Optional[str],
|
||||||
|
model_package_description: Optional[str],
|
||||||
|
inference_specification: Any,
|
||||||
|
validation_specification: Any,
|
||||||
|
source_algorithm_specification: Any,
|
||||||
|
certify_for_marketplace: Any,
|
||||||
|
tags: Any,
|
||||||
|
model_approval_status: str,
|
||||||
|
metadata_properties: Any,
|
||||||
|
model_metrics: Any,
|
||||||
|
client_token: Any,
|
||||||
|
customer_metadata_properties: Any,
|
||||||
|
drift_check_baselines: Any,
|
||||||
|
domain: Any,
|
||||||
|
task: Any,
|
||||||
|
sample_payload_url: Any,
|
||||||
|
additional_inference_specifications: Any,
|
||||||
|
) -> str:
|
||||||
|
model_package_version = None
|
||||||
|
if model_package_group_name is not None:
|
||||||
|
model_packages_for_group = [
|
||||||
|
x
|
||||||
|
for x in self.model_packages.values()
|
||||||
|
if x.model_package_group_name == model_package_group_name
|
||||||
|
]
|
||||||
|
model_package_version = len(model_packages_for_group) + 1
|
||||||
|
model_package = ModelPackage(
|
||||||
|
model_package_name=model_package_name,
|
||||||
|
model_package_group_name=model_package_group_name,
|
||||||
|
model_package_description=model_package_description,
|
||||||
|
inference_specification=inference_specification,
|
||||||
|
validation_specification=validation_specification,
|
||||||
|
source_algorithm_specification=source_algorithm_specification,
|
||||||
|
certify_for_marketplace=certify_for_marketplace,
|
||||||
|
tags=tags,
|
||||||
|
model_approval_status=model_approval_status,
|
||||||
|
metadata_properties=metadata_properties,
|
||||||
|
model_metrics=model_metrics,
|
||||||
|
customer_metadata_properties=customer_metadata_properties,
|
||||||
|
drift_check_baselines=drift_check_baselines,
|
||||||
|
domain=domain,
|
||||||
|
task=task,
|
||||||
|
sample_payload_url=sample_payload_url,
|
||||||
|
additional_inference_specifications=additional_inference_specifications,
|
||||||
|
model_package_version=model_package_version,
|
||||||
|
approval_description=model_approval_status,
|
||||||
|
region_name=self.region_name,
|
||||||
|
account_id=self.account_id,
|
||||||
|
client_token=client_token,
|
||||||
|
)
|
||||||
|
self.model_package_name_mapping[
|
||||||
|
model_package.model_package_name
|
||||||
|
] = model_package.model_package_arn
|
||||||
|
self.model_package_name_mapping[
|
||||||
|
model_package.model_package_arn
|
||||||
|
] = model_package.model_package_arn
|
||||||
|
self.model_packages[model_package.model_package_arn] = model_package
|
||||||
|
return model_package.model_package_arn
|
||||||
|
|
||||||
|
|
||||||
class FakeExperiment(BaseObject):
|
class FakeExperiment(BaseObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -796,3 +796,104 @@ class SageMakerResponse(BaseResponse):
|
|||||||
desired_weights_and_capacities=desired_weights_and_capacities,
|
desired_weights_and_capacities=desired_weights_and_capacities,
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
|
return 200, {}, json.dumps({"EndpointArn": endpoint_arn})
|
||||||
|
|
||||||
|
def list_model_packages(self) -> str:
|
||||||
|
creation_time_after = self._get_param("CreationTimeAfter")
|
||||||
|
creation_time_before = self._get_param("CreationTimeBefore")
|
||||||
|
max_results = self._get_param("MaxResults")
|
||||||
|
name_contains = self._get_param("NameContains")
|
||||||
|
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||||
|
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||||
|
model_package_type = self._get_param("ModelPackageType", "Unversioned")
|
||||||
|
next_token = self._get_param("NextToken")
|
||||||
|
sort_by = self._get_param("SortBy")
|
||||||
|
sort_order = self._get_param("SortOrder")
|
||||||
|
(
|
||||||
|
model_package_summary_list,
|
||||||
|
next_token,
|
||||||
|
) = self.sagemaker_backend.list_model_packages(
|
||||||
|
creation_time_after=creation_time_after,
|
||||||
|
creation_time_before=creation_time_before,
|
||||||
|
max_results=max_results,
|
||||||
|
name_contains=name_contains,
|
||||||
|
model_approval_status=model_approval_status,
|
||||||
|
model_package_group_name=model_package_group_name,
|
||||||
|
model_package_type=model_package_type,
|
||||||
|
next_token=next_token,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
)
|
||||||
|
model_package_summary_list_response_object = [
|
||||||
|
x.gen_response_object() for x in model_package_summary_list
|
||||||
|
]
|
||||||
|
return json.dumps(
|
||||||
|
dict(
|
||||||
|
ModelPackageSummaryList=model_package_summary_list_response_object,
|
||||||
|
NextToken=next_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def describe_model_package(self) -> str:
|
||||||
|
model_package_name = self._get_param("ModelPackageName")
|
||||||
|
model_package = self.sagemaker_backend.describe_model_package(
|
||||||
|
model_package_name=model_package_name,
|
||||||
|
)
|
||||||
|
return json.dumps(
|
||||||
|
model_package.gen_response_object(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_model_package(self) -> str:
|
||||||
|
model_package_name = self._get_param("ModelPackageName")
|
||||||
|
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||||
|
model_package_description = self._get_param("ModelPackageDescription")
|
||||||
|
inference_specification = self._get_param("InferenceSpecification")
|
||||||
|
validation_specification = self._get_param("ValidationSpecification")
|
||||||
|
source_algorithm_specification = self._get_param("SourceAlgorithmSpecification")
|
||||||
|
certify_for_marketplace = self._get_param("CertifyForMarketplace")
|
||||||
|
tags = self._get_param("Tags")
|
||||||
|
model_approval_status = self._get_param("ModelApprovalStatus")
|
||||||
|
metadata_properties = self._get_param("MetadataProperties")
|
||||||
|
model_metrics = self._get_param("ModelMetrics")
|
||||||
|
client_token = self._get_param("ClientToken")
|
||||||
|
customer_metadata_properties = self._get_param("CustomerMetadataProperties")
|
||||||
|
drift_check_baselines = self._get_param("DriftCheckBaselines")
|
||||||
|
domain = self._get_param("Domain")
|
||||||
|
task = self._get_param("Task")
|
||||||
|
sample_payload_url = self._get_param("SamplePayloadUrl")
|
||||||
|
additional_inference_specifications = self._get_param(
|
||||||
|
"AdditionalInferenceSpecifications"
|
||||||
|
)
|
||||||
|
model_package_arn = self.sagemaker_backend.create_model_package(
|
||||||
|
model_package_name=model_package_name,
|
||||||
|
model_package_group_name=model_package_group_name,
|
||||||
|
model_package_description=model_package_description,
|
||||||
|
inference_specification=inference_specification,
|
||||||
|
validation_specification=validation_specification,
|
||||||
|
source_algorithm_specification=source_algorithm_specification,
|
||||||
|
certify_for_marketplace=certify_for_marketplace,
|
||||||
|
tags=tags,
|
||||||
|
model_approval_status=model_approval_status,
|
||||||
|
metadata_properties=metadata_properties,
|
||||||
|
model_metrics=model_metrics,
|
||||||
|
customer_metadata_properties=customer_metadata_properties,
|
||||||
|
drift_check_baselines=drift_check_baselines,
|
||||||
|
domain=domain,
|
||||||
|
task=task,
|
||||||
|
sample_payload_url=sample_payload_url,
|
||||||
|
additional_inference_specifications=additional_inference_specifications,
|
||||||
|
client_token=client_token,
|
||||||
|
)
|
||||||
|
return json.dumps(dict(ModelPackageArn=model_package_arn))
|
||||||
|
|
||||||
|
def create_model_package_group(self) -> str:
|
||||||
|
model_package_group_name = self._get_param("ModelPackageGroupName")
|
||||||
|
model_package_group_description = self._get_param(
|
||||||
|
"ModelPackageGroupDescription"
|
||||||
|
)
|
||||||
|
tags = self._get_param("Tags")
|
||||||
|
model_package_group_arn = self.sagemaker_backend.create_model_package_group(
|
||||||
|
model_package_group_name=model_package_group_name,
|
||||||
|
model_package_group_description=model_package_group_description,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
return json.dumps(dict(ModelPackageGroupArn=model_package_group_arn))
|
||||||
|
236
tests/test_sagemaker/test_sagemaker_model_packages.py
Normal file
236
tests/test_sagemaker/test_sagemaker_model_packages.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
"""Unit tests for sagemaker-supported APIs."""
|
||||||
|
import boto3
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
from moto import mock_sagemaker, settings
|
||||||
|
from unittest import SkipTest
|
||||||
|
|
||||||
|
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||||
|
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package"
|
||||||
|
)
|
||||||
|
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0]
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][0]["ModelPackageDescription"]
|
||||||
|
== "test-model-package-description"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
|
||||||
|
)
|
||||||
|
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1]
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][1]["ModelPackageDescription"]
|
||||||
|
== "test-model-package-description-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_creation_time_before():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
with freeze_time("2020-01-01 00:00:00"):
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
with freeze_time("2021-01-01 00:00:00"):
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(CreationTimeBefore="2020-01-01T02:00:00Z")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_creation_time_after():
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest("Can't freeze time in ServerMode")
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
with freeze_time("2020-01-01 00:00:00"):
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
with freeze_time("2021-01-01 00:00:00"):
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(CreationTimeAfter="2020-01-02T00:00:00Z")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_name_contains():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="another-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description-3",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(NameContains="test-model-package")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_approval_status():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
ModelApprovalStatus="Approved",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
ModelApprovalStatus="Rejected",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(ModelApprovalStatus="Approved")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_model_package_group_name():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_model_package_type():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(ModelPackageType="Versioned")
|
||||||
|
|
||||||
|
assert len(resp["ModelPackageSummaryList"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_sort_by():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(SortBy="CreationTime")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_list_model_packages_sort_order():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package-2",
|
||||||
|
ModelPackageDescription="test-model-package-description-2",
|
||||||
|
)
|
||||||
|
resp = client.list_model_packages(SortOrder="Descending")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package-2"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_describe_model_package():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
resp = client.describe_model_package(ModelPackageName="test-model-package")
|
||||||
|
assert resp["ModelPackageName"] == "test-model-package"
|
||||||
|
assert resp["ModelPackageDescription"] == "test-model-package-description"
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create_model_package():
|
||||||
|
client = boto3.client("sagemaker", region_name="eu-west-1")
|
||||||
|
resp = client.create_model_package(
|
||||||
|
ModelPackageName="test-model-package",
|
||||||
|
ModelPackageDescription="test-model-package-description",
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageArn"]
|
||||||
|
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_create_model_package_group():
|
||||||
|
client = boto3.client("sagemaker", region_name="us-east-2")
|
||||||
|
resp = client.create_model_package_group(
|
||||||
|
ModelPackageGroupName="test-model-package-group",
|
||||||
|
ModelPackageGroupDescription="test-model-package-group-description",
|
||||||
|
Tags=[
|
||||||
|
{"Key": "test-key", "Value": "test-value"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
resp["ModelPackageGroupArn"]
|
||||||
|
== "arn:aws:sagemaker:us-east-2:123456789012:model-package-group/test-model-package-group"
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user