moto/tests/test_sagemaker/test_sagemaker_model_packages.py

259 lines
9.3 KiB
Python
Raw Normal View History

"""Unit tests for sagemaker-supported APIs."""
from unittest import SkipTest
import boto3
from freezegun import freeze_time
from moto import mock_sagemaker, settings
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@mock_sagemaker
def test_list_model_packages():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-v1",
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-v2",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-v1-2",
)
resp = client.list_model_packages()
assert (
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package"
)
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][0]
assert (
resp["ModelPackageSummaryList"][0]["ModelPackageDescription"]
== "test-model-package-description-v2"
)
assert (
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
)
assert "ModelPackageDescription" in resp["ModelPackageSummaryList"][1]
assert (
resp["ModelPackageSummaryList"][1]["ModelPackageDescription"]
== "test-model-package-description-v1-2"
)
@mock_sagemaker
def test_list_model_packages_creation_time_before():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
with freeze_time("2020-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
with freeze_time("2021-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
resp = client.list_model_packages(CreationTimeBefore="2020-01-01T02:00:00Z")
assert len(resp["ModelPackageSummaryList"]) == 1
@mock_sagemaker
def test_list_model_packages_creation_time_after():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't freeze time in ServerMode")
client = boto3.client("sagemaker", region_name="eu-west-1")
with freeze_time("2020-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
with freeze_time("2021-01-01 00:00:00"):
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
resp = client.list_model_packages(CreationTimeAfter="2020-01-02T00:00:00Z")
assert len(resp["ModelPackageSummaryList"]) == 1
@mock_sagemaker
def test_list_model_packages_name_contains():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
client.create_model_package(
ModelPackageName="another-model-package",
ModelPackageDescription="test-model-package-description-3",
)
resp = client.list_model_packages(NameContains="test-model-package")
assert len(resp["ModelPackageSummaryList"]) == 2
@mock_sagemaker
def test_list_model_packages_approval_status():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelApprovalStatus="Approved",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
ModelApprovalStatus="Rejected",
)
resp = client.list_model_packages(ModelApprovalStatus="Approved")
assert len(resp["ModelPackageSummaryList"]) == 1
@mock_sagemaker
def test_list_model_packages_model_package_group_name():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group",
)
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description-2",
ModelPackageGroupName="test-model-package-group",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-3",
ModelPackageGroupName="test-model-package-group",
)
client.create_model_package(
ModelPackageName="test-model-package-without-group",
ModelPackageDescription="test-model-package-description-without-group",
)
resp = client.list_model_packages(ModelPackageGroupName="test-model-package-group")
assert len(resp["ModelPackageSummaryList"]) == 3
@mock_sagemaker
def test_list_model_packages_model_package_type():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
ModelPackageGroupName="test-model-package-group",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
resp = client.list_model_packages(ModelPackageType="Versioned")
assert len(resp["ModelPackageSummaryList"]) == 1
@mock_sagemaker
def test_list_model_packages_sort_by():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
resp = client.list_model_packages(SortBy="CreationTime")
assert (
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package"
)
assert (
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package-2"
)
@mock_sagemaker
def test_list_model_packages_sort_order():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
client.create_model_package(
ModelPackageName="test-model-package-2",
ModelPackageDescription="test-model-package-description-2",
)
resp = client.list_model_packages(SortOrder="Descending")
assert (
resp["ModelPackageSummaryList"][0]["ModelPackageName"] == "test-model-package-2"
)
assert (
resp["ModelPackageSummaryList"][1]["ModelPackageName"] == "test-model-package"
)
@mock_sagemaker
def test_describe_model_package():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
resp = client.describe_model_package(ModelPackageName="test-model-package")
assert resp["ModelPackageName"] == "test-model-package"
assert resp["ModelPackageDescription"] == "test-model-package-description"
@mock_sagemaker
def test_create_model_package():
client = boto3.client("sagemaker", region_name="eu-west-1")
resp = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageDescription="test-model-package-description",
)
assert (
resp["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package"
)
@mock_sagemaker
def test_create_model_package_in_model_package_group():
client = boto3.client("sagemaker", region_name="eu-west-1")
client.create_model_package_group(ModelPackageGroupName="test-model-package-group")
resp_version_1 = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
resp_version_2 = client.create_model_package(
ModelPackageName="test-model-package",
ModelPackageGroupName="test-model-package-group",
ModelPackageDescription="test-model-package-description",
)
assert (
resp_version_1["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/1"
)
assert (
resp_version_2["ModelPackageArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:model-package/test-model-package/2"
)