Techdebt: Replace sure with regular assertions in sagemaker (#6614)

This commit is contained in:
kbalk 2023-08-08 06:06:51 -04:00 committed by GitHub
parent 2f8019052d
commit 56153be9d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 420 additions and 356 deletions

View File

@ -1,8 +1,8 @@
import boto3
import re
import pytest
import sure # noqa # pylint: disable=unused-import
import boto3
from botocore.exceptions import ClientError
import pytest
from moto import mock_cloudformation, mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
@ -53,8 +53,8 @@ def test_sagemaker_cloudformation_create(test_config):
provisioned_resource = cf.list_stack_resources(StackName=stack_name)[
"StackResourceSummaries"
][0]
provisioned_resource["LogicalResourceId"].should.equal(test_config.resource_name)
len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0)
assert provisioned_resource["LogicalResourceId"] == test_config.resource_name
assert len(provisioned_resource["PhysicalResourceId"]) > 0
@mock_cloudformation
@ -87,7 +87,7 @@ def test_sagemaker_cloudformation_get_attr(test_config):
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
assert outputs["Arn"] == resource_description[test_config.arn_parameter]
@mock_cloudformation
@ -122,7 +122,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
assert outputs["Arn"] == resource_description[test_config.arn_parameter]
# Delete the stack and verify resource has also been deleted
cf.delete_stack(StackName=stack_name)
@ -130,7 +130,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me
getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
ce.value.response["Error"]["Message"].should.contain(error_message)
assert error_message in ce.value.response["Error"]["Message"]
@mock_cloudformation
@ -160,19 +160,19 @@ def test_sagemaker_cloudformation_notebook_instance_update():
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_notebook_name}
)
initial_instance_type.should.equal(resource_description["InstanceType"])
assert initial_instance_type == resource_description["InstanceType"]
# Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
outputs = _get_stack_outputs(cf, stack_name)
updated_notebook_name = outputs["Name"]
updated_notebook_name.should.equal(initial_notebook_name)
assert updated_notebook_name == initial_notebook_name
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_notebook_name}
)
updated_instance_type.should.equal(resource_description["InstanceType"])
assert updated_instance_type == resource_description["InstanceType"]
@mock_cloudformation
@ -202,25 +202,21 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_config_name}
)
len(resource_description["OnCreate"]).should.equal(1)
initial_on_create_script.should.equal(
resource_description["OnCreate"][0]["Content"]
)
assert len(resource_description["OnCreate"]) == 1
assert initial_on_create_script == resource_description["OnCreate"][0]["Content"]
# Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
outputs = _get_stack_outputs(cf, stack_name)
updated_config_name = outputs["Name"]
updated_config_name.should.equal(initial_config_name)
assert updated_config_name == initial_config_name
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_config_name}
)
len(resource_description["OnCreate"]).should.equal(1)
updated_on_create_script.should.equal(
resource_description["OnCreate"][0]["Content"]
)
assert len(resource_description["OnCreate"]) == 1
assert updated_on_create_script == resource_description["OnCreate"][0]["Content"]
@mock_cloudformation
@ -251,7 +247,7 @@ def test_sagemaker_cloudformation_model_update():
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_model_name}
)
resource_description["PrimaryContainer"]["Image"].should.equal(
assert resource_description["PrimaryContainer"]["Image"] == (
image.format(initial_image_version)
)
@ -260,12 +256,12 @@ def test_sagemaker_cloudformation_model_update():
outputs = _get_stack_outputs(cf, stack_name)
updated_model_name = outputs["Name"]
updated_model_name.should_not.equal(initial_model_name)
assert updated_model_name != initial_model_name
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_model_name}
)
resource_description["PrimaryContainer"]["Image"].should.equal(
assert resource_description["PrimaryContainer"]["Image"] == (
image.format(updated_image_version)
)
@ -300,7 +296,7 @@ def test_sagemaker_cloudformation_endpoint_config_update():
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_endpoint_config_name}
)
len(resource_description["ProductionVariants"]).should.equal(
assert len(resource_description["ProductionVariants"]) == (
initial_num_production_variants
)
@ -309,12 +305,12 @@ def test_sagemaker_cloudformation_endpoint_config_update():
outputs = _get_stack_outputs(cf, stack_name)
updated_endpoint_config_name = outputs["Name"]
updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name)
assert updated_endpoint_config_name != initial_endpoint_config_name
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_endpoint_config_name}
)
len(resource_description["ProductionVariants"]).should.equal(
assert len(resource_description["ProductionVariants"]) == (
updated_num_production_variants
)
@ -365,8 +361,8 @@ def test_sagemaker_cloudformation_endpoint_update():
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_endpoint_name}
)
resource_description["EndpointConfigName"].should.match(
initial_endpoint_config_name
assert re.match(
initial_endpoint_config_name, resource_description["EndpointConfigName"]
)
# Create additional SM resources and update stack
@ -393,11 +389,11 @@ def test_sagemaker_cloudformation_endpoint_update():
outputs = _get_stack_outputs(cf, stack_name)
updated_endpoint_name = outputs["Name"]
updated_endpoint_name.should.equal(initial_endpoint_name)
assert updated_endpoint_name == initial_endpoint_name
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_endpoint_name}
)
resource_description["EndpointConfigName"].should.match(
updated_endpoint_config_name
assert re.match(
updated_endpoint_config_name, resource_description["EndpointConfigName"]
)

View File

@ -1,9 +1,9 @@
import datetime
import re
import uuid
import boto3
from botocore.exceptions import ClientError
import sure # noqa # pylint: disable=unused-import
from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
@ -55,18 +55,20 @@ def create_endpoint_config_helper(sagemaker_client, production_variants):
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=production_variants,
)
resp["EndpointConfigArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
resp["EndpointConfigArn"],
)
resp = sagemaker_client.describe_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
resp["EndpointConfigArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
resp["EndpointConfigArn"],
)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["ProductionVariants"].should.equal(production_variants)
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
assert resp["ProductionVariants"] == production_variants
def test_create_endpoint_config(sagemaker_client):
@ -99,15 +101,17 @@ def test_delete_endpoint_config(sagemaker_client):
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=TEST_PRODUCTION_VARIANTS,
)
resp["EndpointConfigArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
resp["EndpointConfigArn"],
)
resp = sagemaker_client.describe_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
resp["EndpointConfigArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
resp["EndpointConfigArn"],
)
sagemaker_client.delete_endpoint_config(
@ -143,7 +147,10 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
ProductionVariants=production_variants,
)
assert e.value.response["Error"]["Code"] == "ValidationException"
expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: ["
expected_message = (
f"Value '{instance_type}' at 'instanceType' failed to satisfy "
"constraint: Member must satisfy enum value set: ["
)
assert expected_message in e.value.response["Error"]["Message"]
@ -160,7 +167,10 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client):
ProductionVariants=production_variants,
)
assert e.value.response["Error"]["Code"] == "ValidationException"
expected_message = f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: ["
expected_message = (
f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy "
"constraint: Member must satisfy enum value set: ["
)
assert expected_message in e.value.response["Error"]["Message"]
@ -185,20 +195,20 @@ def test_create_endpoint(sagemaker_client):
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
Tags=GENERIC_TAGS_PARAM,
)
resp["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
resp["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
)
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["EndpointStatus"].should.equal("InService")
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
assert resp["EndpointStatus"] == "InService"
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
assert resp["Tags"] == GENERIC_TAGS_PARAM
@ -224,7 +234,10 @@ def test_add_tags_endpoint(sagemaker_client):
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
resource_arn = (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":endpoint/{TEST_ENDPOINT_NAME}"
)
response = sagemaker_client.add_tags(
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
)
@ -239,7 +252,10 @@ def test_delete_tags_endpoint(sagemaker_client):
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
resource_arn = (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":endpoint/{TEST_ENDPOINT_NAME}"
)
response = sagemaker_client.add_tags(
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
)
@ -262,7 +278,10 @@ def test_list_tags_endpoint(sagemaker_client):
for _ in range(80):
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
resource_arn = (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":endpoint/{TEST_ENDPOINT_NAME}"
)
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
@ -295,29 +314,32 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
},
],
)
response["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$",
response["EndpointArn"],
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
resp["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
)
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["EndpointStatus"].should.equal("InService")
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
assert resp["EndpointStatus"] == "InService"
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
new_desired_instance_count
assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME
assert (
resp["ProductionVariants"][0]["DesiredInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
new_desired_instance_count
assert (
resp["ProductionVariants"][0]["CurrentInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight
assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
@ -364,39 +386,44 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
EndpointName=TEST_ENDPOINT_NAME,
DesiredWeightsAndCapacities=desired_weights_and_capacities,
)
response["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$",
response["EndpointArn"],
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
resp["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
)
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["EndpointStatus"].should.equal("InService")
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
assert resp["EndpointStatus"] == "InService"
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant1")
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
new_desired_instance_count
assert resp["ProductionVariants"][0]["VariantName"] == "MyProductionVariant1"
assert (
resp["ProductionVariants"][0]["DesiredInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
new_desired_instance_count
assert (
resp["ProductionVariants"][0]["CurrentInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight
assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight
resp["ProductionVariants"][1]["VariantName"].should.equal("MyProductionVariant2")
resp["ProductionVariants"][1]["DesiredInstanceCount"].should.equal(
new_desired_instance_count
assert resp["ProductionVariants"][1]["VariantName"] == "MyProductionVariant2"
assert (
resp["ProductionVariants"][1]["DesiredInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][1]["CurrentInstanceCount"].should.equal(
new_desired_instance_count
assert (
resp["ProductionVariants"][1]["CurrentInstanceCount"]
== new_desired_instance_count
)
resp["ProductionVariants"][1]["DesiredWeight"].should.equal(new_desired_weight)
resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)
assert resp["ProductionVariants"][1]["DesiredWeight"] == new_desired_weight
assert resp["ProductionVariants"][1]["CurrentWeight"] == new_desired_weight
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
@ -426,13 +453,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari
)
err = exc.value.response["Error"]
err["Message"].should.equal(
f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
assert err["Message"] == (
f'The variant name(s) "{variant_name}" is/are not present within '
f'endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
del resp["ResponseMetadata"]
resp.should.equal(old_resp)
assert resp == old_resp
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
@ -463,13 +491,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp
)
err = exc.value.response["Error"]
err["Message"].should.equal(
f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".'
assert err["Message"] == (
f'Could not find endpoint "arn:aws:sagemaker:us-east-1:'
f'{ACCOUNT_ID}:endpoint/{endpoint_name}".'
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
del resp["ResponseMetadata"]
resp.should.equal(old_resp)
assert resp == old_resp
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
@ -502,13 +531,13 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonuniq
)
err = exc.value.response["Error"]
err["Message"].should.equal(
assert err["Message"] == (
f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.'
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
del resp["ResponseMetadata"]
resp.should.equal(old_resp)
assert resp == old_resp
def _set_up_sagemaker_resources(
@ -552,8 +581,9 @@ def _create_endpoint_config(
resp = boto_client.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)
resp["EndpointConfigArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$",
resp["EndpointConfigArn"],
)
@ -561,6 +591,6 @@ def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
resp = boto_client.create_endpoint(
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
resp["EndpointArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$", resp["EndpointArn"]
)

View File

@ -1,9 +1,10 @@
"""Unit tests for sagemaker-supported APIs."""
from unittest import SkipTest
import boto3
from freezegun import freeze_time
from moto import mock_sagemaker, settings
from unittest import SkipTest
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html

View File

@ -1,10 +1,10 @@
import re
import boto3
from botocore.exceptions import ClientError
import pytest
from moto import mock_sagemaker
import sure # noqa # pylint: disable=unused-import
from moto.sagemaker.models import VpcConfig
TEST_REGION_NAME = "us-east-1"
@ -18,7 +18,7 @@ def fixture_sagemaker_client():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MySageMakerModel(object):
class MySageMakerModel:
def __init__(self, name=None, arn=None, container=None, vpc_config=None):
self.name = name or TEST_MODEL_NAME
self.arn = arn or TEST_ARN
@ -41,13 +41,13 @@ def test_describe_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
assert model.get("ModelName").should.equal(TEST_MODEL_NAME)
assert model.get("ModelName") == TEST_MODEL_NAME
def test_describe_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.describe_model(ModelName="unknown")
assert err.value.response["Error"]["Message"].should.contain("Could not find model")
assert "Could not find model" in err.value.response["Error"]["Message"]
def test_create_model(sagemaker_client):
@ -57,8 +57,8 @@ def test_create_model(sagemaker_client):
ExecutionRoleArn=TEST_ARN,
VpcConfig=vpc_config.response_object,
)
model["ModelArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$", model["ModelArn"]
)
@ -66,25 +66,26 @@ def test_delete_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
assert len(sagemaker_client.list_models()["Models"]).should.equal(1)
assert len(sagemaker_client.list_models()["Models"]) == 1
sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
assert len(sagemaker_client.list_models()["Models"]).should.equal(0)
assert len(sagemaker_client.list_models()["Models"]) == 0
def test_delete_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.delete_model(ModelName="blah")
assert err.value.response["Error"]["Code"].should.equal("404")
assert err.value.response["Error"]["Code"] == "404"
def test_list_models(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
models = sagemaker_client.list_models()
assert len(models["Models"]).should.equal(1)
assert models["Models"][0]["ModelName"].should.equal(TEST_MODEL_NAME)
assert models["Models"][0]["ModelArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
assert len(models["Models"]) == 1
assert models["Models"][0]["ModelName"] == TEST_MODEL_NAME
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$",
models["Models"][0]["ModelArn"],
)
@ -99,12 +100,12 @@ def test_list_models_multiple(sagemaker_client):
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
test_model_2.save(sagemaker_client)
models = sagemaker_client.list_models()
assert len(models["Models"]).should.equal(2)
assert len(models["Models"]) == 2
def test_list_models_none(sagemaker_client):
models = sagemaker_client.list_models()
assert len(models["Models"]).should.equal(0)
assert len(models["Models"]) == 0
def test_add_tags_to_model(sagemaker_client):

View File

@ -1,11 +1,11 @@
import datetime
import boto3
from botocore.exceptions import ClientError
import sure # noqa # pylint: disable=unused-import
import pytest
from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
import pytest
TEST_REGION_NAME = "us-east-1"
FAKE_SUBNET_ID = "subnet-012345678"
@ -37,7 +37,10 @@ def _get_notebook_instance_arn(notebook_name):
def _get_notebook_instance_lifecycle_arn(lifecycle_name):
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
return (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":notebook-instance-lifecycle-configuration/{lifecycle_name}"
)
def test_create_notebook_instance_minimal_params(sagemaker_client):
@ -130,7 +133,10 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
with pytest.raises(ClientError) as ex:
sagemaker_client.create_notebook_instance(**args)
assert ex.value.response["Error"]["Code"] == "ValidationException"
expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: ["
expected_message = (
f"Value '{instance_type}' at 'instanceType' failed to satisfy "
"constraint: Member must satisfy enum value set: ["
)
assert expected_message in ex.value.response["Error"]["Message"]
@ -153,7 +159,10 @@ def test_notebook_instance_lifecycle(sagemaker_client):
with pytest.raises(ClientError) as ex:
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert ex.value.response["Error"]["Code"] == "ValidationException"
expected_message = f"Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({notebook_instance_arn})"
expected_message = (
f"Status (InService) not in ([Stopped, Failed]). Unable to "
f"transition to (Deleting) for Notebook Instance ({notebook_instance_arn})"
)
assert expected_message in ex.value.response["Error"]["Message"]
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)

View File

@ -1,13 +1,14 @@
from contextlib import contextmanager
from moto import mock_sagemaker, settings
from time import sleep
from datetime import datetime
import boto3
import botocore
import json
import pytest
from time import sleep
from unittest import SkipTest
import boto3
import botocore
import pytest
from moto import mock_sagemaker, settings
from moto.s3 import mock_s3
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from moto.sagemaker.exceptions import ValidationError
@ -85,7 +86,10 @@ def test_utils_get_pipeline_from_name_not_exists():
def test_utils_get_pipeline_name_from_execution_arn():
expected_pipeline_name = "some-pipeline-name"
pipeline_execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{expected_pipeline_name}/execution/abc123def456"
pipeline_execution_arn = (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":pipeline/{expected_pipeline_name}/execution/abc123def456"
)
observed_pipeline_name = get_pipeline_name_from_execution_arn(
pipeline_execution_arn=pipeline_execution_arn
)
@ -266,7 +270,7 @@ def test_describe_pipeline_execution(sagemaker_client):
PipelineExecutionArn=response["PipelineExecutionArn"]
)
observed_pipeline_execution_arn = pipeline_execution_summary["PipelineExecutionArn"]
observed_pipeline_execution_arn.should.be.equal(expected_pipeline_execution_arn)
assert observed_pipeline_execution_arn == expected_pipeline_execution_arn
def test_load_pipeline_definition_from_s3():
@ -292,7 +296,7 @@ def test_load_pipeline_definition_from_s3():
},
account_id=ACCOUNT_ID,
)
observed_pipeline_definition.should.equal(pipeline_definition)
assert observed_pipeline_definition == pipeline_definition
def test_create_pipeline(sagemaker_client):
@ -303,7 +307,7 @@ def test_create_pipeline(sagemaker_client):
PipelineDefinition=" ",
)
assert isinstance(response, dict)
response["PipelineArn"].should.equal(
assert response["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
)
@ -353,7 +357,7 @@ def test_create_pipeline_duplicate_pipeline_name(sagemaker_client):
def test_list_pipelines_none(sagemaker_client):
response = sagemaker_client.list_pipelines()
assert isinstance(response, dict)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
def test_list_pipelines_single(sagemaker_client):
@ -368,8 +372,8 @@ def test_list_pipelines_single(sagemaker_client):
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"].should.have.length_of(1)
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
assert len(response["PipelineSummaries"]) == 1
assert response["PipelineSummaries"][0]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
)
@ -390,7 +394,7 @@ def test_list_pipelines_multiple(sagemaker_client):
SortBy="Name",
SortOrder="Ascending",
)
response["PipelineSummaries"].should.have.length_of(len(fake_pipeline_names))
assert len(response["PipelineSummaries"]) == len(fake_pipeline_names)
def test_list_pipelines_sort_name_ascending(sagemaker_client):
@ -409,13 +413,13 @@ def test_list_pipelines_sort_name_ascending(sagemaker_client):
SortBy="Name",
SortOrder="Ascending",
)
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][0]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
)
response["PipelineSummaries"][-1]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][-1]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
)
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][1]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
)
@ -436,13 +440,13 @@ def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
SortBy="CreationTime",
SortOrder="Descending",
)
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][0]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
)
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][1]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
)
response["PipelineSummaries"][2]["PipelineArn"].should.equal(
assert response["PipelineSummaries"][2]["PipelineArn"] == (
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
)
@ -460,7 +464,7 @@ def test_list_pipelines_max_results(sagemaker_client):
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(MaxResults=2)
response["PipelineSummaries"].should.have.length_of(2)
assert len(response["PipelineSummaries"]) == 2
def test_list_pipelines_next_token(sagemaker_client):
@ -475,7 +479,7 @@ def test_list_pipelines_next_token(sagemaker_client):
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(NextToken="0")
response["PipelineSummaries"].should.have.length_of(1)
assert len(response["PipelineSummaries"]) == 1
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
@ -491,11 +495,11 @@ def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
response["PipelineSummaries"].should.have.length_of(1)
response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
assert len(response["PipelineSummaries"]) == 1
assert response["PipelineSummaries"][0]["PipelineName"] == "APipelineName"
response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
response["PipelineSummaries"].should.have.length_of(3)
assert len(response["PipelineSummaries"]) == 3
def test_list_pipelines_created_after(sagemaker_client):
@ -512,15 +516,15 @@ def test_list_pipelines_created_after(sagemaker_client):
created_after_str = "2099-12-31 23:59:59"
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S")
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_datetime)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
created_after_timestamp = datetime.timestamp(created_after_datetime)
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_timestamp)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
def test_list_pipelines_created_before(sagemaker_client):
@ -537,15 +541,15 @@ def test_list_pipelines_created_before(sagemaker_client):
created_before_str = "2000-12-31 23:59:59"
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S")
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_datetime)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
created_before_timestamp = datetime.timestamp(created_before_datetime)
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_timestamp)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
@pytest.mark.parametrize(
@ -582,7 +586,7 @@ def test_delete_pipeline_exists(sagemaker_client):
assert response["PipelineArn"].endswith(pipeline_name_delete)
response = sagemaker_client.list_pipelines(PipelineNamePrefix=pipeline_name_delete)
assert response["PipelineSummaries"].should.be.empty
assert not response["PipelineSummaries"]
response = sagemaker_client.list_pipelines()
pipeline_names_exist = [
@ -627,11 +631,11 @@ def test_update_pipeline_no_update(sagemaker_client):
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
response["PipelineArn"].should.equal(
assert response["PipelineArn"] == (
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name)
assert response["PipelineSummaries"][0]["PipelineName"] == pipeline_name
def test_update_pipeline_add_attribute(sagemaker_client):
@ -645,17 +649,17 @@ def test_update_pipeline_add_attribute(sagemaker_client):
}
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(pipeline_name)
assert response["PipelineSummaries"][0]["PipelineDisplayName"] == pipeline_name
_ = sagemaker_client.update_pipeline(
PipelineName=pipeline_name,
PipelineDisplayName=pipeline_display_name_update,
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
assert response["PipelineSummaries"][0]["PipelineDisplayName"] == (
pipeline_display_name_update
)
response["PipelineSummaries"][0].should.have.length_of(6)
assert len(response["PipelineSummaries"][0]) == 6
def test_update_pipeline_update_change_attribute(sagemaker_client):
@ -673,8 +677,8 @@ def test_update_pipeline_update_change_attribute(sagemaker_client):
RoleArn=role_arn_update,
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
response["PipelineSummaries"][0].should.have.length_of(6)
assert response["PipelineSummaries"][0]["RoleArn"] == role_arn_update
assert len(response["PipelineSummaries"][0]) == 6
def test_describe_pipeline_not_exists(sagemaker_client):
@ -707,4 +711,4 @@ def test_describe_pipeline_not_exists(sagemaker_client):
def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length):
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"])
response.should.have.length_of(expected_response_length)
assert len(response) == expected_response_length

View File

@ -1,6 +1,8 @@
import datetime
import re
import boto3
from botocore.exceptions import ClientError
import datetime
import pytest
from moto import mock_sagemaker
@ -18,7 +20,7 @@ def fixture_sagemaker_client():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MyProcessingJobModel(object):
class MyProcessingJobModel:
def __init__(
self,
processing_job_name,
@ -129,16 +131,18 @@ def test_create_processing_job(sagemaker_client):
stopping_condition=stopping_condition,
)
resp = job.save(sagemaker_client)
resp["ProcessingJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
resp["ProcessingJobArn"],
)
resp = sagemaker_client.describe_processing_job(
ProcessingJobName=FAKE_PROCESSING_JOB_NAME
)
resp["ProcessingJobName"].should.equal(FAKE_PROCESSING_JOB_NAME)
resp["ProcessingJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
assert resp["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
resp["ProcessingJobArn"],
)
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
@ -154,15 +158,15 @@ def test_list_processing_jobs(sagemaker_client):
)
test_processing_job.save(sagemaker_client)
processing_jobs = sagemaker_client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
assert processing_jobs["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal(FAKE_PROCESSING_JOB_NAME)
assert len(processing_jobs["ProcessingJobSummaries"]) == 1
assert (
processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobName"]
== FAKE_PROCESSING_JOB_NAME
)
assert processing_jobs["ProcessingJobSummaries"][0][
"ProcessingJobArn"
].should.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobArn"],
)
assert processing_jobs.get("NextToken") is None
@ -182,23 +186,28 @@ def test_list_processing_jobs_multiple(sagemaker_client):
)
test_processing_job_2.save(sagemaker_client)
processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1)
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
assert len(processing_jobs_limit["ProcessingJobSummaries"]) == 1
processing_jobs = sagemaker_client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
assert processing_jobs.get("NextToken").should.be.none
assert len(processing_jobs["ProcessingJobSummaries"]) == 2
assert processing_jobs.get("NextToken") is None
def test_list_processing_jobs_none(sagemaker_client):
processing_jobs = sagemaker_client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
assert len(processing_jobs["ProcessingJobSummaries"]) == 0
def test_list_processing_jobs_should_validate_input(sagemaker_client):
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals)
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
expected_error = (
f"1 validation errors detected: Value '{junk_status_equals}' at "
"'statusEquals' failed to satisfy constraint: Member must satisfy "
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
"'Failed']"
)
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert ex.value.response["Error"]["Message"] == expected_error
@ -230,10 +239,10 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client):
xgboost_processing_jobs = sagemaker_client.list_processing_jobs(
NameContains="xgboost"
)
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]) == 5
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2")
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2
def test_list_processing_jobs_paginated(sagemaker_client):
@ -247,22 +256,24 @@ def test_list_processing_jobs_paginated(sagemaker_client):
xgboost_processing_job_1 = sagemaker_client.list_processing_jobs(
NameContains="xgboost", MaxResults=1
)
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("xgboost-0")
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]) == 1
assert (
xgboost_processing_job_1["ProcessingJobSummaries"][0]["ProcessingJobName"]
== "xgboost-0"
)
assert xgboost_processing_job_1.get("NextToken") is not None
xgboost_processing_job_next = sagemaker_client.list_processing_jobs(
NameContains="xgboost",
MaxResults=1,
NextToken=xgboost_processing_job_1.get("NextToken"),
)
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("xgboost-1")
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]) == 1
assert (
xgboost_processing_job_next["ProcessingJobSummaries"][0]["ProcessingJobName"]
== "xgboost-1"
)
assert xgboost_processing_job_next.get("NextToken") is not None
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
@ -283,28 +294,30 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
vgg_processing_job_1 = sagemaker_client.list_processing_jobs(
NameContains="vgg", MaxResults=1
)
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
assert vgg_processing_job_1.get("NextToken").should_not.be.none
assert len(vgg_processing_job_1["ProcessingJobSummaries"]) == 0
assert vgg_processing_job_1.get("NextToken") is not None
vgg_processing_job_6 = sagemaker_client.list_processing_jobs(
NameContains="vgg", MaxResults=6
)
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("vgg-0")
assert vgg_processing_job_6.get("NextToken").should_not.be.none
assert len(vgg_processing_job_6["ProcessingJobSummaries"]) == 1
assert (
vgg_processing_job_6["ProcessingJobSummaries"][0]["ProcessingJobName"]
== "vgg-0"
)
assert vgg_processing_job_6.get("NextToken") is not None
vgg_processing_job_10 = sagemaker_client.list_processing_jobs(
NameContains="vgg", MaxResults=10
)
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
"ProcessingJobName"
].should.equal("vgg-4")
assert vgg_processing_job_10.get("NextToken").should.be.none
assert len(vgg_processing_job_10["ProcessingJobSummaries"]) == 5
assert (
vgg_processing_job_10["ProcessingJobSummaries"][-1]["ProcessingJobName"]
== "vgg-4"
)
assert vgg_processing_job_10.get("NextToken") is None
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
@ -325,26 +338,24 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(
NameContains="2", MaxResults=8
)
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
assert processing_jobs_with_2.get("NextToken").should_not.be.none
assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2
assert processing_jobs_with_2.get("NextToken") is not None
processing_jobs_with_2_next = sagemaker_client.list_processing_jobs(
NameContains="2",
MaxResults=1,
NextToken=processing_jobs_with_2.get("NextToken"),
)
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]) == 0
assert processing_jobs_with_2_next.get("NextToken") is not None
processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs(
NameContains="2",
MaxResults=1,
NextToken=processing_jobs_with_2_next.get("NextToken"),
)
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
0
)
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]) == 0
assert processing_jobs_with_2_next_next.get("NextToken") is None
def test_add_and_delete_tags_in_training_job(sagemaker_client):

View File

@ -1,7 +1,6 @@
import boto3
import pytest
from botocore.exceptions import ClientError
import pytest
from moto import mock_sagemaker
@ -94,11 +93,11 @@ def test_search_trial_component_with_experiment_name(sagemaker_client):
},
)
ex.value.response["Error"]["Code"].should.equal("ValidationException")
ex.value.response["Error"]["Message"].should.equal(
"Unknown property name: ExperimentName"
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert (
ex.value.response["Error"]["Message"] == "Unknown property name: ExperimentName"
)
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400
def _set_up_trial_component(

View File

@ -1,7 +1,8 @@
import datetime
import re
import boto3
from botocore.exceptions import ClientError
import datetime
import sure # noqa # pylint: disable=unused-import
import pytest
from moto import mock_sagemaker
@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"
class MyTrainingJobModel(object):
class MyTrainingJobModel:
def __init__(
self,
training_job_name,
@ -167,14 +168,16 @@ def test_create_training_job():
stopping_condition=stopping_condition,
)
resp = job.save()
resp["TrainingJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$",
resp["TrainingJobArn"],
)
resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
resp["TrainingJobName"].should.equal(training_job_name)
resp["TrainingJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$"
assert resp["TrainingJobName"] == training_job_name
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$",
resp["TrainingJobArn"],
)
assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
output_data_config["S3OutputPath"]
@ -214,8 +217,6 @@ def test_create_training_job():
assert "Value" in resp["FinalMetricDataList"][0]
assert "Timestamp" in resp["FinalMetricDataList"][0]
pass
@mock_sagemaker
def test_list_training_jobs():
@ -225,13 +226,12 @@ def test_list_training_jobs():
test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn)
test_training_job.save()
training_jobs = client.list_training_jobs()
assert len(training_jobs["TrainingJobSummaries"]).should.equal(1)
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"].should.equal(
name
)
assert len(training_jobs["TrainingJobSummaries"]) == 1
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"] == name
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$",
training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"],
)
assert training_jobs.get("NextToken") is None
@ -253,18 +253,18 @@ def test_list_training_jobs_multiple():
)
test_training_job_2.save()
training_jobs_limit = client.list_training_jobs(MaxResults=1)
assert len(training_jobs_limit["TrainingJobSummaries"]).should.equal(1)
assert len(training_jobs_limit["TrainingJobSummaries"]) == 1
training_jobs = client.list_training_jobs()
assert len(training_jobs["TrainingJobSummaries"]).should.equal(2)
assert training_jobs.get("NextToken").should.be.none
assert len(training_jobs["TrainingJobSummaries"]) == 2
assert training_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_training_jobs_none():
client = boto3.client("sagemaker", region_name="us-east-1")
training_jobs = client.list_training_jobs()
assert len(training_jobs["TrainingJobSummaries"]).should.equal(0)
assert len(training_jobs["TrainingJobSummaries"]) == 0
@mock_sagemaker
@ -273,7 +273,12 @@ def test_list_training_jobs_should_validate_input():
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
client.list_training_jobs(StatusEquals=junk_status_equals)
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
expected_error = (
f"1 validation errors detected: Value '{junk_status_equals}' at "
"'statusEquals' failed to satisfy constraint: Member must satisfy "
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
"'Failed']"
)
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert ex.value.response["Error"]["Message"] == expected_error
@ -299,10 +304,10 @@ def test_list_training_jobs_with_name_filters():
arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}"
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost")
assert len(xgboost_training_jobs["TrainingJobSummaries"]).should.equal(5)
assert len(xgboost_training_jobs["TrainingJobSummaries"]) == 5
training_jobs_with_2 = client.list_training_jobs(NameContains="2")
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2
@mock_sagemaker
@ -315,22 +320,24 @@ def test_list_training_jobs_paginated():
xgboost_training_job_1 = client.list_training_jobs(
NameContains="xgboost", MaxResults=1
)
assert len(xgboost_training_job_1["TrainingJobSummaries"]).should.equal(1)
assert xgboost_training_job_1["TrainingJobSummaries"][0][
"TrainingJobName"
].should.equal("xgboost-0")
assert xgboost_training_job_1.get("NextToken").should_not.be.none
assert len(xgboost_training_job_1["TrainingJobSummaries"]) == 1
assert (
xgboost_training_job_1["TrainingJobSummaries"][0]["TrainingJobName"]
== "xgboost-0"
)
assert xgboost_training_job_1.get("NextToken") is not None
xgboost_training_job_next = client.list_training_jobs(
NameContains="xgboost",
MaxResults=1,
NextToken=xgboost_training_job_1.get("NextToken"),
)
assert len(xgboost_training_job_next["TrainingJobSummaries"]).should.equal(1)
assert xgboost_training_job_next["TrainingJobSummaries"][0][
"TrainingJobName"
].should.equal("xgboost-1")
assert xgboost_training_job_next.get("NextToken").should_not.be.none
assert len(xgboost_training_job_next["TrainingJobSummaries"]) == 1
assert (
xgboost_training_job_next["TrainingJobSummaries"][0]["TrainingJobName"]
== "xgboost-1"
)
assert xgboost_training_job_next.get("NextToken") is not None
@mock_sagemaker
@ -346,24 +353,20 @@ def test_list_training_jobs_paginated_with_target_in_middle():
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1)
assert len(vgg_training_job_1["TrainingJobSummaries"]).should.equal(0)
assert vgg_training_job_1.get("NextToken").should_not.be.none
assert len(vgg_training_job_1["TrainingJobSummaries"]) == 0
assert vgg_training_job_1.get("NextToken") is not None
vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6)
assert len(vgg_training_job_6["TrainingJobSummaries"]).should.equal(1)
assert vgg_training_job_6["TrainingJobSummaries"][0][
"TrainingJobName"
].should.equal("vgg-0")
assert vgg_training_job_6.get("NextToken").should_not.be.none
assert len(vgg_training_job_6["TrainingJobSummaries"]) == 1
assert vgg_training_job_6["TrainingJobSummaries"][0]["TrainingJobName"] == "vgg-0"
assert vgg_training_job_6.get("NextToken") is not None
vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10)
assert len(vgg_training_job_10["TrainingJobSummaries"]).should.equal(5)
assert vgg_training_job_10["TrainingJobSummaries"][-1][
"TrainingJobName"
].should.equal("vgg-4")
assert vgg_training_job_10.get("NextToken").should.be.none
assert len(vgg_training_job_10["TrainingJobSummaries"]) == 5
assert vgg_training_job_10["TrainingJobSummaries"][-1]["TrainingJobName"] == "vgg-4"
assert vgg_training_job_10.get("NextToken") is None
@mock_sagemaker
@ -379,22 +382,22 @@ def test_list_training_jobs_paginated_with_fragmented_targets():
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8)
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
assert training_jobs_with_2.get("NextToken").should_not.be.none
assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2
assert training_jobs_with_2.get("NextToken") is not None
training_jobs_with_2_next = client.list_training_jobs(
NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken")
)
assert len(training_jobs_with_2_next["TrainingJobSummaries"]).should.equal(0)
assert training_jobs_with_2_next.get("NextToken").should_not.be.none
assert len(training_jobs_with_2_next["TrainingJobSummaries"]) == 0
assert training_jobs_with_2_next.get("NextToken") is not None
training_jobs_with_2_next_next = client.list_training_jobs(
NameContains="2",
MaxResults=1,
NextToken=training_jobs_with_2_next.get("NextToken"),
)
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0)
assert training_jobs_with_2_next_next.get("NextToken").should.be.none
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]) == 0
assert training_jobs_with_2_next_next.get("NextToken") is None
@mock_sagemaker
@ -447,7 +450,8 @@ def test_describe_unknown_training_job():
with pytest.raises(ClientError) as exc:
client.describe_training_job(TrainingJobName="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
f"Could not find training job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:training-job/unknown'."
assert err["Code"] == "ValidationException"
assert err["Message"] == (
"Could not find training job 'arn:aws:sagemaker:us-east-1:"
f"{ACCOUNT_ID}:training-job/unknown'."
)

View File

@ -1,7 +1,8 @@
import datetime
import re
import boto3
from botocore.exceptions import ClientError
import datetime
import sure # noqa # pylint: disable=unused-import
import pytest
from moto import mock_sagemaker
@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"
class MyTransformJobModel(object):
class MyTransformJobModel:
def __init__(
self,
transform_job_name,
@ -147,23 +148,24 @@ def test_create_transform_job():
experiment_config=experiment_config,
)
resp = job.save()
resp["TransformJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$",
resp["TransformJobArn"],
)
resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
resp["TransformJobName"].should.equal(transform_job_name)
resp["TransformJobStatus"].should.equal("Completed")
resp["ModelName"].should.equal(model_name)
resp["MaxConcurrentTransforms"].should.equal(1)
resp["ModelClientConfig"].should.equal(model_client_config)
resp["MaxPayloadInMB"].should.equal(max_payload_in_mb)
resp["BatchStrategy"].should.equal("SingleRecord")
resp["TransformInput"].should.equal(transform_input)
resp["TransformOutput"].should.equal(transform_output)
resp["DataCaptureConfig"].should.equal(data_capture_config)
resp["TransformResources"].should.equal(transform_resources)
resp["DataProcessing"].should.equal(data_processing)
resp["ExperimentConfig"].should.equal(experiment_config)
assert resp["TransformJobName"] == transform_job_name
assert resp["TransformJobStatus"] == "Completed"
assert resp["ModelName"] == model_name
assert resp["MaxConcurrentTransforms"] == 1
assert resp["ModelClientConfig"] == model_client_config
assert resp["MaxPayloadInMB"] == max_payload_in_mb
assert resp["BatchStrategy"] == "SingleRecord"
assert resp["TransformInput"] == transform_input
assert resp["TransformOutput"] == transform_output
assert resp["DataCaptureConfig"] == data_capture_config
assert resp["TransformResources"] == transform_resources
assert resp["DataProcessing"] == data_processing
assert resp["ExperimentConfig"] == experiment_config
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["TransformStartTime"], datetime.datetime)
assert isinstance(resp["TransformEndTime"], datetime.datetime)
@ -179,13 +181,12 @@ def test_list_transform_jobs():
)
test_transform_job.save()
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(1)
assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"].should.equal(
name
)
assert len(transform_jobs["TransformJobSummaries"]) == 1
assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"] == name
assert transform_jobs["TransformJobSummaries"][0]["TransformJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$"
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$",
transform_jobs["TransformJobSummaries"][0]["TransformJobArn"],
)
assert transform_jobs.get("NextToken") is None
@ -207,18 +208,18 @@ def test_list_transform_jobs_multiple():
)
test_transform_job_2.save()
transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
assert len(transform_jobs_limit["TransformJobSummaries"]).should.equal(1)
assert len(transform_jobs_limit["TransformJobSummaries"]) == 1
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(2)
assert transform_jobs.get("NextToken").should.be.none
assert len(transform_jobs["TransformJobSummaries"]) == 2
assert transform_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_transform_jobs_none():
client = boto3.client("sagemaker", region_name="us-east-1")
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(0)
assert len(transform_jobs["TransformJobSummaries"]) == 0
@mock_sagemaker
@ -227,7 +228,12 @@ def test_list_transform_jobs_should_validate_input():
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
client.list_transform_jobs(StatusEquals=junk_status_equals)
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
expected_error = (
f"1 validation errors detected: Value '{junk_status_equals}' at "
"'statusEquals' failed to satisfy constraint: Member must satisfy "
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
"'Failed']"
)
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert ex.value.response["Error"]["Message"] == expected_error
@ -253,10 +259,10 @@ def test_list_transform_jobs_with_name_filters():
model_name = f"blah_model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
assert len(xgboost_transform_jobs["TransformJobSummaries"]).should.equal(5)
assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
@mock_sagemaker
@ -269,22 +275,24 @@ def test_list_transform_jobs_paginated():
xgboost_transform_job_1 = client.list_transform_jobs(
NameContains="xgboost", MaxResults=1
)
assert len(xgboost_transform_job_1["TransformJobSummaries"]).should.equal(1)
assert xgboost_transform_job_1["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("xgboost-0")
assert xgboost_transform_job_1.get("NextToken").should_not.be.none
assert len(xgboost_transform_job_1["TransformJobSummaries"]) == 1
assert (
xgboost_transform_job_1["TransformJobSummaries"][0]["TransformJobName"]
== "xgboost-0"
)
assert xgboost_transform_job_1.get("NextToken") is not None
xgboost_transform_job_next = client.list_transform_jobs(
NameContains="xgboost",
MaxResults=1,
NextToken=xgboost_transform_job_1.get("NextToken"),
)
assert len(xgboost_transform_job_next["TransformJobSummaries"]).should.equal(1)
assert xgboost_transform_job_next["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("xgboost-1")
assert xgboost_transform_job_next.get("NextToken").should_not.be.none
assert len(xgboost_transform_job_next["TransformJobSummaries"]) == 1
assert (
xgboost_transform_job_next["TransformJobSummaries"][0]["TransformJobName"]
== "xgboost-1"
)
assert xgboost_transform_job_next.get("NextToken") is not None
@mock_sagemaker
@ -299,24 +307,24 @@ def test_list_transform_jobs_paginated_with_target_in_middle():
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
assert len(vgg_transform_job_1["TransformJobSummaries"]).should.equal(0)
assert vgg_transform_job_1.get("NextToken").should_not.be.none
assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0
assert vgg_transform_job_1.get("NextToken") is not None
vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
assert len(vgg_transform_job_6["TransformJobSummaries"]).should.equal(1)
assert vgg_transform_job_6["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("vgg-0")
assert vgg_transform_job_6.get("NextToken").should_not.be.none
assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1
assert (
vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0"
)
assert vgg_transform_job_6.get("NextToken") is not None
vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
assert len(vgg_transform_job_10["TransformJobSummaries"]).should.equal(5)
assert vgg_transform_job_10["TransformJobSummaries"][-1][
"TransformJobName"
].should.equal("vgg-4")
assert vgg_transform_job_10.get("NextToken").should.be.none
assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5
assert (
vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4"
)
assert vgg_transform_job_10.get("NextToken") is None
@mock_sagemaker
@ -331,22 +339,22 @@ def test_list_transform_jobs_paginated_with_fragmented_targets():
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
assert transform_jobs_with_2.get("NextToken").should_not.be.none
assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
assert transform_jobs_with_2.get("NextToken") is not None
transform_jobs_with_2_next = client.list_transform_jobs(
NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken")
)
assert len(transform_jobs_with_2_next["TransformJobSummaries"]).should.equal(0)
assert transform_jobs_with_2_next.get("NextToken").should_not.be.none
assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0
assert transform_jobs_with_2_next.get("NextToken") is not None
transform_jobs_with_2_next_next = client.list_transform_jobs(
NameContains="2",
MaxResults=1,
NextToken=transform_jobs_with_2_next.get("NextToken"),
)
assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]).should.equal(0)
assert transform_jobs_with_2_next_next.get("NextToken").should.be.none
assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0
assert transform_jobs_with_2_next_next.get("NextToken") is None
@mock_sagemaker
@ -401,7 +409,8 @@ def test_describe_unknown_transform_job():
with pytest.raises(ClientError) as exc:
client.describe_transform_job(TransformJobName="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
f"Could not find transform job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown'."
assert err["Code"] == "ValidationException"
assert err["Message"] == (
"Could not find transform job 'arn:aws:sagemaker:us-east-1:"
f"{ACCOUNT_ID}:transform-job/unknown'."
)

View File

@ -1,9 +1,8 @@
import uuid
import boto3
import pytest
from botocore.exceptions import ClientError
import pytest
from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
@ -27,9 +26,9 @@ def test_create__trial_component():
assert (
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
)
assert (
resp["TrialComponentSummaries"][0]["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
assert resp["TrialComponentSummaries"][0]["TrialComponentArn"] == (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":experiment-trial-component/{trial_component_name}"
)
@ -173,9 +172,9 @@ def test_associate_trial_component():
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
assert resp["TrialComponentArn"] == (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":experiment-trial-component/{trial_component_name}"
)
assert (
resp["TrialArn"]
@ -199,11 +198,12 @@ def test_associate_trial_component():
TrialComponentName="does-not-exist", TrialName="does-not-exist"
)
ex.value.response["Error"]["Code"].should.equal("ResourceNotFound")
ex.value.response["Error"]["Message"].should.equal(
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist."
assert ex.value.response["Error"]["Code"] == "ResourceNotFound"
assert ex.value.response["Error"]["Message"] == (
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
":experiment-trial/does-not-exist' does not exist."
)
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400
@mock_sagemaker
@ -231,9 +231,9 @@ def test_disassociate_trial_component():
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
assert resp["TrialComponentArn"] == (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
f":experiment-trial-component/{trial_component_name}"
)
assert (
resp["TrialArn"]
@ -255,9 +255,9 @@ def test_disassociate_trial_component():
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist"
assert resp["TrialComponentArn"] == (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:"
"experiment-trial-component/does-not-exist"
)
assert (
resp["TrialArn"]