diff --git a/tests/test_sagemaker/test_sagemaker_cloudformation.py b/tests/test_sagemaker/test_sagemaker_cloudformation.py index 2119eb749..7f0f0129f 100644 --- a/tests/test_sagemaker/test_sagemaker_cloudformation.py +++ b/tests/test_sagemaker/test_sagemaker_cloudformation.py @@ -1,8 +1,8 @@ -import boto3 +import re -import pytest -import sure # noqa # pylint: disable=unused-import +import boto3 from botocore.exceptions import ClientError +import pytest from moto import mock_cloudformation, mock_sagemaker from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @@ -53,8 +53,8 @@ def test_sagemaker_cloudformation_create(test_config): provisioned_resource = cf.list_stack_resources(StackName=stack_name)[ "StackResourceSummaries" ][0] - provisioned_resource["LogicalResourceId"].should.equal(test_config.resource_name) - len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0) + assert provisioned_resource["LogicalResourceId"] == test_config.resource_name + assert len(provisioned_resource["PhysicalResourceId"]) > 0 @mock_cloudformation @@ -87,7 +87,7 @@ def test_sagemaker_cloudformation_get_attr(test_config): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: outputs["Name"]} ) - outputs["Arn"].should.equal(resource_description[test_config.arn_parameter]) + assert outputs["Arn"] == resource_description[test_config.arn_parameter] @mock_cloudformation @@ -122,7 +122,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: outputs["Name"]} ) - outputs["Arn"].should.equal(resource_description[test_config.arn_parameter]) + assert outputs["Arn"] == resource_description[test_config.arn_parameter] # Delete the stack and verify resource has also been deleted cf.delete_stack(StackName=stack_name) @@ -130,7 +130,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: outputs["Name"]} ) - ce.value.response["Error"]["Message"].should.contain(error_message) + assert error_message in ce.value.response["Error"]["Message"] @mock_cloudformation @@ -160,19 +160,19 @@ def test_sagemaker_cloudformation_notebook_instance_update(): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: initial_notebook_name} ) - initial_instance_type.should.equal(resource_description["InstanceType"]) + assert initial_instance_type == resource_description["InstanceType"] # Update stack and check attributes cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) outputs = _get_stack_outputs(cf, stack_name) updated_notebook_name = outputs["Name"] - updated_notebook_name.should.equal(initial_notebook_name) + assert updated_notebook_name == initial_notebook_name resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: updated_notebook_name} ) - updated_instance_type.should.equal(resource_description["InstanceType"]) + assert updated_instance_type == resource_description["InstanceType"] @mock_cloudformation @@ -202,25 +202,21 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update(): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: initial_config_name} ) - len(resource_description["OnCreate"]).should.equal(1) - initial_on_create_script.should.equal( - resource_description["OnCreate"][0]["Content"] - ) + assert len(resource_description["OnCreate"]) == 1 + assert initial_on_create_script == resource_description["OnCreate"][0]["Content"] # Update stack and check attributes cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) outputs = _get_stack_outputs(cf, stack_name) updated_config_name = outputs["Name"] - updated_config_name.should.equal(initial_config_name) + assert updated_config_name == initial_config_name resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: updated_config_name} ) - len(resource_description["OnCreate"]).should.equal(1) - updated_on_create_script.should.equal( - resource_description["OnCreate"][0]["Content"] - ) + assert len(resource_description["OnCreate"]) == 1 + assert updated_on_create_script == resource_description["OnCreate"][0]["Content"] @mock_cloudformation @@ -251,7 +247,7 @@ def test_sagemaker_cloudformation_model_update(): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: initial_model_name} ) - resource_description["PrimaryContainer"]["Image"].should.equal( + assert resource_description["PrimaryContainer"]["Image"] == ( image.format(initial_image_version) ) @@ -260,12 +256,12 @@ def test_sagemaker_cloudformation_model_update(): outputs = _get_stack_outputs(cf, stack_name) updated_model_name = outputs["Name"] - updated_model_name.should_not.equal(initial_model_name) + assert updated_model_name != initial_model_name resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: updated_model_name} ) - resource_description["PrimaryContainer"]["Image"].should.equal( + assert resource_description["PrimaryContainer"]["Image"] == ( image.format(updated_image_version) ) @@ -300,7 +296,7 @@ def test_sagemaker_cloudformation_endpoint_config_update(): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: initial_endpoint_config_name} ) - len(resource_description["ProductionVariants"]).should.equal( + assert len(resource_description["ProductionVariants"]) == ( initial_num_production_variants ) @@ -309,12 +305,12 @@ def test_sagemaker_cloudformation_endpoint_config_update(): outputs = _get_stack_outputs(cf, stack_name) updated_endpoint_config_name = outputs["Name"] - updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name) + assert updated_endpoint_config_name != initial_endpoint_config_name resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: updated_endpoint_config_name} ) - len(resource_description["ProductionVariants"]).should.equal( + assert len(resource_description["ProductionVariants"]) == ( updated_num_production_variants ) @@ -365,8 +361,8 @@ def test_sagemaker_cloudformation_endpoint_update(): resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: initial_endpoint_name} ) - resource_description["EndpointConfigName"].should.match( - initial_endpoint_config_name + assert re.match( + initial_endpoint_config_name, resource_description["EndpointConfigName"] ) # Create additional SM resources and update stack @@ -393,11 +389,11 @@ def test_sagemaker_cloudformation_endpoint_update(): outputs = _get_stack_outputs(cf, stack_name) updated_endpoint_name = outputs["Name"] - updated_endpoint_name.should.equal(initial_endpoint_name) + assert updated_endpoint_name == initial_endpoint_name resource_description = getattr(sm, test_config.describe_function_name)( **{test_config.name_parameter: updated_endpoint_name} ) - resource_description["EndpointConfigName"].should.match( - updated_endpoint_config_name + assert re.match( + updated_endpoint_config_name, resource_description["EndpointConfigName"] ) diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py b/tests/test_sagemaker/test_sagemaker_endpoint.py index af2ab0e8b..053a3fdbe 100644 --- a/tests/test_sagemaker/test_sagemaker_endpoint.py +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -1,9 +1,9 @@ import datetime +import re import uuid import boto3 from botocore.exceptions import ClientError -import sure # noqa # pylint: disable=unused-import from moto import mock_sagemaker from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @@ -55,18 +55,20 @@ def create_endpoint_config_helper(sagemaker_client, production_variants): EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, ProductionVariants=production_variants, ) - resp["EndpointConfigArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", + resp["EndpointConfigArn"], ) resp = sagemaker_client.describe_endpoint_config( EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME ) - resp["EndpointConfigArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", + resp["EndpointConfigArn"], ) - resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) - resp["ProductionVariants"].should.equal(production_variants) + assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME + assert resp["ProductionVariants"] == production_variants def test_create_endpoint_config(sagemaker_client): @@ -99,15 +101,17 @@ def test_delete_endpoint_config(sagemaker_client): EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, ProductionVariants=TEST_PRODUCTION_VARIANTS, ) - resp["EndpointConfigArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", + resp["EndpointConfigArn"], ) resp = sagemaker_client.describe_endpoint_config( EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME ) - resp["EndpointConfigArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", + resp["EndpointConfigArn"], ) sagemaker_client.delete_endpoint_config( @@ -143,7 +147,10 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client): ProductionVariants=production_variants, ) assert e.value.response["Error"]["Code"] == "ValidationException" - expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [" + expected_message = ( + f"Value '{instance_type}' at 'instanceType' failed to satisfy " + "constraint: Member must satisfy enum value set: [" + ) assert expected_message in e.value.response["Error"]["Message"] @@ -160,7 +167,10 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client): ProductionVariants=production_variants, ) assert e.value.response["Error"]["Code"] == "ValidationException" - expected_message = f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: [" + expected_message = ( + f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy " + "constraint: Member must satisfy enum value set: [" + ) assert expected_message in e.value.response["Error"]["Message"] @@ -185,20 +195,20 @@ def test_create_endpoint(sagemaker_client): EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, Tags=GENERIC_TAGS_PARAM, ) - resp["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) - resp["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] ) - resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) - resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) - resp["EndpointStatus"].should.equal("InService") + assert resp["EndpointName"] == TEST_ENDPOINT_NAME + assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME + assert resp["EndpointStatus"] == "InService" assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) - resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME) + assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"]) assert resp["Tags"] == GENERIC_TAGS_PARAM @@ -224,7 +234,10 @@ def test_add_tags_endpoint(sagemaker_client): sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME ) - resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + resource_arn = ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":endpoint/{TEST_ENDPOINT_NAME}" + ) response = sagemaker_client.add_tags( ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM ) @@ -239,7 +252,10 @@ def test_delete_tags_endpoint(sagemaker_client): sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME ) - resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + resource_arn = ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":endpoint/{TEST_ENDPOINT_NAME}" + ) response = sagemaker_client.add_tags( ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM ) @@ -262,7 +278,10 @@ def test_list_tags_endpoint(sagemaker_client): for _ in range(80): tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"}) - resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + resource_arn = ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":endpoint/{TEST_ENDPOINT_NAME}" + ) response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 @@ -295,29 +314,32 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client): }, ], ) - response["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", + response["EndpointArn"], ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) - resp["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] ) - resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) - resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) - resp["EndpointStatus"].should.equal("InService") + assert resp["EndpointName"] == TEST_ENDPOINT_NAME + assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME + assert resp["EndpointStatus"] == "InService" assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) - resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME) - resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal( - new_desired_instance_count + assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME + assert ( + resp["ProductionVariants"][0]["DesiredInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal( - new_desired_instance_count + assert ( + resp["ProductionVariants"][0]["CurrentInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight) - resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight) + assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight + assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client): @@ -364,39 +386,44 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client): EndpointName=TEST_ENDPOINT_NAME, DesiredWeightsAndCapacities=desired_weights_and_capacities, ) - response["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", + response["EndpointArn"], ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) - resp["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] ) - resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) - resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) - resp["EndpointStatus"].should.equal("InService") + assert resp["EndpointName"] == TEST_ENDPOINT_NAME + assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME + assert resp["EndpointStatus"] == "InService" assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) - resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant1") - resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal( - new_desired_instance_count + assert resp["ProductionVariants"][0]["VariantName"] == "MyProductionVariant1" + assert ( + resp["ProductionVariants"][0]["DesiredInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal( - new_desired_instance_count + assert ( + resp["ProductionVariants"][0]["CurrentInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight) - resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight) + assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight + assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight - resp["ProductionVariants"][1]["VariantName"].should.equal("MyProductionVariant2") - resp["ProductionVariants"][1]["DesiredInstanceCount"].should.equal( - new_desired_instance_count + assert resp["ProductionVariants"][1]["VariantName"] == "MyProductionVariant2" + assert ( + resp["ProductionVariants"][1]["DesiredInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][1]["CurrentInstanceCount"].should.equal( - new_desired_instance_count + assert ( + resp["ProductionVariants"][1]["CurrentInstanceCount"] + == new_desired_instance_count ) - resp["ProductionVariants"][1]["DesiredWeight"].should.equal(new_desired_weight) - resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight) + assert resp["ProductionVariants"][1]["DesiredWeight"] == new_desired_weight + assert resp["ProductionVariants"][1]["CurrentWeight"] == new_desired_weight def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant( @@ -426,13 +453,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari ) err = exc.value.response["Error"] - err["Message"].should.equal( - f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".' + assert err["Message"] == ( + f'The variant name(s) "{variant_name}" is/are not present within ' + f'endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".' ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) del resp["ResponseMetadata"] - resp.should.equal(old_resp) + assert resp == old_resp def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint( @@ -463,13 +491,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp ) err = exc.value.response["Error"] - err["Message"].should.equal( - f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".' + assert err["Message"] == ( + f'Could not find endpoint "arn:aws:sagemaker:us-east-1:' + f'{ACCOUNT_ID}:endpoint/{endpoint_name}".' ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) del resp["ResponseMetadata"] - resp.should.equal(old_resp) + assert resp == old_resp def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant( @@ -502,13 +531,13 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonuniq ) err = exc.value.response["Error"] - err["Message"].should.equal( + assert err["Message"] == ( f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.' ) resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) del resp["ResponseMetadata"] - resp.should.equal(old_resp) + assert resp == old_resp def _set_up_sagemaker_resources( @@ -552,8 +581,9 @@ def _create_endpoint_config( resp = boto_client.create_endpoint_config( EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants ) - resp["EndpointConfigArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$", + resp["EndpointConfigArn"], ) @@ -561,6 +591,6 @@ def _create_endpoint(boto_client, endpoint_name, endpoint_config_name): resp = boto_client.create_endpoint( EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name ) - resp["EndpointArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$", resp["EndpointArn"] ) diff --git a/tests/test_sagemaker/test_sagemaker_model_packages.py b/tests/test_sagemaker/test_sagemaker_model_packages.py index cbb3f61fc..1b9a8d752 100644 --- a/tests/test_sagemaker/test_sagemaker_model_packages.py +++ b/tests/test_sagemaker/test_sagemaker_model_packages.py @@ -1,9 +1,10 @@ """Unit tests for sagemaker-supported APIs.""" +from unittest import SkipTest + import boto3 from freezegun import freeze_time from moto import mock_sagemaker, settings -from unittest import SkipTest # See our Development Tips on writing tests for hints on how to write good tests: # http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html diff --git a/tests/test_sagemaker/test_sagemaker_models.py b/tests/test_sagemaker/test_sagemaker_models.py index 32f9212c7..d7c3c829c 100644 --- a/tests/test_sagemaker/test_sagemaker_models.py +++ b/tests/test_sagemaker/test_sagemaker_models.py @@ -1,10 +1,10 @@ +import re + import boto3 from botocore.exceptions import ClientError import pytest + from moto import mock_sagemaker - -import sure # noqa # pylint: disable=unused-import - from moto.sagemaker.models import VpcConfig TEST_REGION_NAME = "us-east-1" @@ -18,7 +18,7 @@ def fixture_sagemaker_client(): yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) -class MySageMakerModel(object): +class MySageMakerModel: def __init__(self, name=None, arn=None, container=None, vpc_config=None): self.name = name or TEST_MODEL_NAME self.arn = arn or TEST_ARN @@ -41,13 +41,13 @@ def test_describe_model(sagemaker_client): test_model = MySageMakerModel() test_model.save(sagemaker_client) model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME) - assert model.get("ModelName").should.equal(TEST_MODEL_NAME) + assert model.get("ModelName") == TEST_MODEL_NAME def test_describe_model_not_found(sagemaker_client): with pytest.raises(ClientError) as err: sagemaker_client.describe_model(ModelName="unknown") - assert err.value.response["Error"]["Message"].should.contain("Could not find model") + assert "Could not find model" in err.value.response["Error"]["Message"] def test_create_model(sagemaker_client): @@ -57,8 +57,8 @@ def test_create_model(sagemaker_client): ExecutionRoleArn=TEST_ARN, VpcConfig=vpc_config.response_object, ) - model["ModelArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$", model["ModelArn"] ) @@ -66,25 +66,26 @@ def test_delete_model(sagemaker_client): test_model = MySageMakerModel() test_model.save(sagemaker_client) - assert len(sagemaker_client.list_models()["Models"]).should.equal(1) + assert len(sagemaker_client.list_models()["Models"]) == 1 sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME) - assert len(sagemaker_client.list_models()["Models"]).should.equal(0) + assert len(sagemaker_client.list_models()["Models"]) == 0 def test_delete_model_not_found(sagemaker_client): with pytest.raises(ClientError) as err: sagemaker_client.delete_model(ModelName="blah") - assert err.value.response["Error"]["Code"].should.equal("404") + assert err.value.response["Error"]["Code"] == "404" def test_list_models(sagemaker_client): test_model = MySageMakerModel() test_model.save(sagemaker_client) models = sagemaker_client.list_models() - assert len(models["Models"]).should.equal(1) - assert models["Models"][0]["ModelName"].should.equal(TEST_MODEL_NAME) - assert models["Models"][0]["ModelArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$" + assert len(models["Models"]) == 1 + assert models["Models"][0]["ModelName"] == TEST_MODEL_NAME + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$", + models["Models"][0]["ModelArn"], ) @@ -99,12 +100,12 @@ def test_list_models_multiple(sagemaker_client): test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2) test_model_2.save(sagemaker_client) models = sagemaker_client.list_models() - assert len(models["Models"]).should.equal(2) + assert len(models["Models"]) == 2 def test_list_models_none(sagemaker_client): models = sagemaker_client.list_models() - assert len(models["Models"]).should.equal(0) + assert len(models["Models"]) == 0 def test_add_tags_to_model(sagemaker_client): diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py index 11c16311d..01556ef42 100644 --- a/tests/test_sagemaker/test_sagemaker_notebooks.py +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ -1,11 +1,11 @@ import datetime + import boto3 from botocore.exceptions import ClientError -import sure # noqa # pylint: disable=unused-import +import pytest from moto import mock_sagemaker from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID -import pytest TEST_REGION_NAME = "us-east-1" FAKE_SUBNET_ID = "subnet-012345678" @@ -37,7 +37,10 @@ def _get_notebook_instance_arn(notebook_name): def _get_notebook_instance_lifecycle_arn(lifecycle_name): - return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}" + return ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":notebook-instance-lifecycle-configuration/{lifecycle_name}" + ) def test_create_notebook_instance_minimal_params(sagemaker_client): @@ -130,7 +133,10 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client): with pytest.raises(ClientError) as ex: sagemaker_client.create_notebook_instance(**args) assert ex.value.response["Error"]["Code"] == "ValidationException" - expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [" + expected_message = ( + f"Value '{instance_type}' at 'instanceType' failed to satisfy " + "constraint: Member must satisfy enum value set: [" + ) assert expected_message in ex.value.response["Error"]["Message"] @@ -153,7 +159,10 @@ def test_notebook_instance_lifecycle(sagemaker_client): with pytest.raises(ClientError) as ex: sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert ex.value.response["Error"]["Code"] == "ValidationException" - expected_message = f"Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({notebook_instance_arn})" + expected_message = ( + f"Status (InService) not in ([Stopped, Failed]). Unable to " + f"transition to (Deleting) for Notebook Instance ({notebook_instance_arn})" + ) assert expected_message in ex.value.response["Error"]["Message"] sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) diff --git a/tests/test_sagemaker/test_sagemaker_pipeline.py b/tests/test_sagemaker/test_sagemaker_pipeline.py index c17af6d6b..3183b1c6c 100644 --- a/tests/test_sagemaker/test_sagemaker_pipeline.py +++ b/tests/test_sagemaker/test_sagemaker_pipeline.py @@ -1,13 +1,14 @@ from contextlib import contextmanager -from moto import mock_sagemaker, settings -from time import sleep from datetime import datetime -import boto3 -import botocore import json -import pytest +from time import sleep from unittest import SkipTest +import boto3 +import botocore +import pytest + +from moto import mock_sagemaker, settings from moto.s3 import mock_s3 from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.sagemaker.exceptions import ValidationError @@ -85,7 +86,10 @@ def test_utils_get_pipeline_from_name_not_exists(): def test_utils_get_pipeline_name_from_execution_arn(): expected_pipeline_name = "some-pipeline-name" - pipeline_execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{expected_pipeline_name}/execution/abc123def456" + pipeline_execution_arn = ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":pipeline/{expected_pipeline_name}/execution/abc123def456" + ) observed_pipeline_name = get_pipeline_name_from_execution_arn( pipeline_execution_arn=pipeline_execution_arn ) @@ -266,7 +270,7 @@ def test_describe_pipeline_execution(sagemaker_client): PipelineExecutionArn=response["PipelineExecutionArn"] ) observed_pipeline_execution_arn = pipeline_execution_summary["PipelineExecutionArn"] - observed_pipeline_execution_arn.should.be.equal(expected_pipeline_execution_arn) + assert observed_pipeline_execution_arn == expected_pipeline_execution_arn def test_load_pipeline_definition_from_s3(): @@ -292,7 +296,7 @@ def test_load_pipeline_definition_from_s3(): }, account_id=ACCOUNT_ID, ) - observed_pipeline_definition.should.equal(pipeline_definition) + assert observed_pipeline_definition == pipeline_definition def test_create_pipeline(sagemaker_client): @@ -303,7 +307,7 @@ def test_create_pipeline(sagemaker_client): PipelineDefinition=" ", ) assert isinstance(response, dict) - response["PipelineArn"].should.equal( + assert response["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_name, ACCOUNT_ID, TEST_REGION_NAME) ) @@ -353,7 +357,7 @@ def test_create_pipeline_duplicate_pipeline_name(sagemaker_client): def test_list_pipelines_none(sagemaker_client): response = sagemaker_client.list_pipelines() assert isinstance(response, dict) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] def test_list_pipelines_single(sagemaker_client): @@ -368,8 +372,8 @@ def test_list_pipelines_single(sagemaker_client): _ = create_sagemaker_pipelines(sagemaker_client, pipelines) response = sagemaker_client.list_pipelines() - response["PipelineSummaries"].should.have.length_of(1) - response["PipelineSummaries"][0]["PipelineArn"].should.equal( + assert len(response["PipelineSummaries"]) == 1 + assert response["PipelineSummaries"][0]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME) ) @@ -390,7 +394,7 @@ def test_list_pipelines_multiple(sagemaker_client): SortBy="Name", SortOrder="Ascending", ) - response["PipelineSummaries"].should.have.length_of(len(fake_pipeline_names)) + assert len(response["PipelineSummaries"]) == len(fake_pipeline_names) def test_list_pipelines_sort_name_ascending(sagemaker_client): @@ -409,13 +413,13 @@ def test_list_pipelines_sort_name_ascending(sagemaker_client): SortBy="Name", SortOrder="Ascending", ) - response["PipelineSummaries"][0]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][0]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME) ) - response["PipelineSummaries"][-1]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][-1]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME) ) - response["PipelineSummaries"][1]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][1]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME) ) @@ -436,13 +440,13 @@ def test_list_pipelines_sort_creation_time_descending(sagemaker_client): SortBy="CreationTime", SortOrder="Descending", ) - response["PipelineSummaries"][0]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][0]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME) ) - response["PipelineSummaries"][1]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][1]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME) ) - response["PipelineSummaries"][2]["PipelineArn"].should.equal( + assert response["PipelineSummaries"][2]["PipelineArn"] == ( arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME) ) @@ -460,7 +464,7 @@ def test_list_pipelines_max_results(sagemaker_client): _ = create_sagemaker_pipelines(sagemaker_client, pipelines) response = sagemaker_client.list_pipelines(MaxResults=2) - response["PipelineSummaries"].should.have.length_of(2) + assert len(response["PipelineSummaries"]) == 2 def test_list_pipelines_next_token(sagemaker_client): @@ -475,7 +479,7 @@ def test_list_pipelines_next_token(sagemaker_client): _ = create_sagemaker_pipelines(sagemaker_client, pipelines) response = sagemaker_client.list_pipelines(NextToken="0") - response["PipelineSummaries"].should.have.length_of(1) + assert len(response["PipelineSummaries"]) == 1 def test_list_pipelines_pipeline_name_prefix(sagemaker_client): @@ -491,11 +495,11 @@ def test_list_pipelines_pipeline_name_prefix(sagemaker_client): _ = create_sagemaker_pipelines(sagemaker_client, pipelines) response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe") - response["PipelineSummaries"].should.have.length_of(1) - response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName") + assert len(response["PipelineSummaries"]) == 1 + assert response["PipelineSummaries"][0]["PipelineName"] == "APipelineName" response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline") - response["PipelineSummaries"].should.have.length_of(3) + assert len(response["PipelineSummaries"]) == 3 def test_list_pipelines_created_after(sagemaker_client): @@ -512,15 +516,15 @@ def test_list_pipelines_created_after(sagemaker_client): created_after_str = "2099-12-31 23:59:59" response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S") response = sagemaker_client.list_pipelines(CreatedAfter=created_after_datetime) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] created_after_timestamp = datetime.timestamp(created_after_datetime) response = sagemaker_client.list_pipelines(CreatedAfter=created_after_timestamp) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] def test_list_pipelines_created_before(sagemaker_client): @@ -537,15 +541,15 @@ def test_list_pipelines_created_before(sagemaker_client): created_before_str = "2000-12-31 23:59:59" response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S") response = sagemaker_client.list_pipelines(CreatedBefore=created_before_datetime) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] created_before_timestamp = datetime.timestamp(created_before_datetime) response = sagemaker_client.list_pipelines(CreatedBefore=created_before_timestamp) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] @pytest.mark.parametrize( @@ -582,7 +586,7 @@ def test_delete_pipeline_exists(sagemaker_client): assert response["PipelineArn"].endswith(pipeline_name_delete) response = sagemaker_client.list_pipelines(PipelineNamePrefix=pipeline_name_delete) - assert response["PipelineSummaries"].should.be.empty + assert not response["PipelineSummaries"] response = sagemaker_client.list_pipelines() pipeline_names_exist = [ @@ -627,11 +631,11 @@ def test_update_pipeline_no_update(sagemaker_client): _ = create_sagemaker_pipelines(sagemaker_client, [pipeline]) response = sagemaker_client.update_pipeline(PipelineName=pipeline_name) - response["PipelineArn"].should.equal( + assert response["PipelineArn"] == ( arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME) ) response = sagemaker_client.list_pipelines() - response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name) + assert response["PipelineSummaries"][0]["PipelineName"] == pipeline_name def test_update_pipeline_add_attribute(sagemaker_client): @@ -645,17 +649,17 @@ def test_update_pipeline_add_attribute(sagemaker_client): } _ = create_sagemaker_pipelines(sagemaker_client, [pipeline]) response = sagemaker_client.list_pipelines() - response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(pipeline_name) + assert response["PipelineSummaries"][0]["PipelineDisplayName"] == pipeline_name _ = sagemaker_client.update_pipeline( PipelineName=pipeline_name, PipelineDisplayName=pipeline_display_name_update, ) response = sagemaker_client.list_pipelines() - response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal( + assert response["PipelineSummaries"][0]["PipelineDisplayName"] == ( pipeline_display_name_update ) - response["PipelineSummaries"][0].should.have.length_of(6) + assert len(response["PipelineSummaries"][0]) == 6 def test_update_pipeline_update_change_attribute(sagemaker_client): @@ -673,8 +677,8 @@ def test_update_pipeline_update_change_attribute(sagemaker_client): RoleArn=role_arn_update, ) response = sagemaker_client.list_pipelines() - response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update) - response["PipelineSummaries"][0].should.have.length_of(6) + assert response["PipelineSummaries"][0]["RoleArn"] == role_arn_update + assert len(response["PipelineSummaries"][0]) == 6 def test_describe_pipeline_not_exists(sagemaker_client): @@ -707,4 +711,4 @@ def test_describe_pipeline_not_exists(sagemaker_client): def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length): _ = create_sagemaker_pipelines(sagemaker_client, [pipeline]) response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"]) - response.should.have.length_of(expected_response_length) + assert len(response) == expected_response_length diff --git a/tests/test_sagemaker/test_sagemaker_processing.py b/tests/test_sagemaker/test_sagemaker_processing.py index f9f500073..fefac10f4 100644 --- a/tests/test_sagemaker/test_sagemaker_processing.py +++ b/tests/test_sagemaker/test_sagemaker_processing.py @@ -1,6 +1,8 @@ +import datetime +import re + import boto3 from botocore.exceptions import ClientError -import datetime import pytest from moto import mock_sagemaker @@ -18,7 +20,7 @@ def fixture_sagemaker_client(): yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) -class MyProcessingJobModel(object): +class MyProcessingJobModel: def __init__( self, processing_job_name, @@ -129,16 +131,18 @@ def test_create_processing_job(sagemaker_client): stopping_condition=stopping_condition, ) resp = job.save(sagemaker_client) - resp["ProcessingJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", + resp["ProcessingJobArn"], ) resp = sagemaker_client.describe_processing_job( ProcessingJobName=FAKE_PROCESSING_JOB_NAME ) - resp["ProcessingJobName"].should.equal(FAKE_PROCESSING_JOB_NAME) - resp["ProcessingJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$" + assert resp["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", + resp["ProcessingJobArn"], ) assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"] assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"] @@ -154,15 +158,15 @@ def test_list_processing_jobs(sagemaker_client): ) test_processing_job.save(sagemaker_client) processing_jobs = sagemaker_client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1) - assert processing_jobs["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal(FAKE_PROCESSING_JOB_NAME) + assert len(processing_jobs["ProcessingJobSummaries"]) == 1 + assert ( + processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobName"] + == FAKE_PROCESSING_JOB_NAME + ) - assert processing_jobs["ProcessingJobSummaries"][0][ - "ProcessingJobArn" - ].should.match( - rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", + processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobArn"], ) assert processing_jobs.get("NextToken") is None @@ -182,23 +186,28 @@ def test_list_processing_jobs_multiple(sagemaker_client): ) test_processing_job_2.save(sagemaker_client) processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1) - assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1) + assert len(processing_jobs_limit["ProcessingJobSummaries"]) == 1 processing_jobs = sagemaker_client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2) - assert processing_jobs.get("NextToken").should.be.none + assert len(processing_jobs["ProcessingJobSummaries"]) == 2 + assert processing_jobs.get("NextToken") is None def test_list_processing_jobs_none(sagemaker_client): processing_jobs = sagemaker_client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0) + assert len(processing_jobs["ProcessingJobSummaries"]) == 0 def test_list_processing_jobs_should_validate_input(sagemaker_client): junk_status_equals = "blah" with pytest.raises(ClientError) as ex: sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals) - expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + expected_error = ( + f"1 validation errors detected: Value '{junk_status_equals}' at " + "'statusEquals' failed to satisfy constraint: Member must satisfy " + "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " + "'Failed']" + ) assert ex.value.response["Error"]["Code"] == "ValidationException" assert ex.value.response["Error"]["Message"] == expected_error @@ -230,10 +239,10 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client): xgboost_processing_jobs = sagemaker_client.list_processing_jobs( NameContains="xgboost" ) - assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5) + assert len(xgboost_processing_jobs["ProcessingJobSummaries"]) == 5 processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2") - assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) + assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2 def test_list_processing_jobs_paginated(sagemaker_client): @@ -247,22 +256,24 @@ def test_list_processing_jobs_paginated(sagemaker_client): xgboost_processing_job_1 = sagemaker_client.list_processing_jobs( NameContains="xgboost", MaxResults=1 ) - assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1) - assert xgboost_processing_job_1["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("xgboost-0") - assert xgboost_processing_job_1.get("NextToken").should_not.be.none + assert len(xgboost_processing_job_1["ProcessingJobSummaries"]) == 1 + assert ( + xgboost_processing_job_1["ProcessingJobSummaries"][0]["ProcessingJobName"] + == "xgboost-0" + ) + assert xgboost_processing_job_1.get("NextToken") is not None xgboost_processing_job_next = sagemaker_client.list_processing_jobs( NameContains="xgboost", MaxResults=1, NextToken=xgboost_processing_job_1.get("NextToken"), ) - assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1) - assert xgboost_processing_job_next["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("xgboost-1") - assert xgboost_processing_job_next.get("NextToken").should_not.be.none + assert len(xgboost_processing_job_next["ProcessingJobSummaries"]) == 1 + assert ( + xgboost_processing_job_next["ProcessingJobSummaries"][0]["ProcessingJobName"] + == "xgboost-1" + ) + assert xgboost_processing_job_next.get("NextToken") is not None def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client): @@ -283,28 +294,30 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client): vgg_processing_job_1 = sagemaker_client.list_processing_jobs( NameContains="vgg", MaxResults=1 ) - assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0) - assert vgg_processing_job_1.get("NextToken").should_not.be.none + assert len(vgg_processing_job_1["ProcessingJobSummaries"]) == 0 + assert vgg_processing_job_1.get("NextToken") is not None vgg_processing_job_6 = sagemaker_client.list_processing_jobs( NameContains="vgg", MaxResults=6 ) - assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1) - assert vgg_processing_job_6["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("vgg-0") - assert vgg_processing_job_6.get("NextToken").should_not.be.none + assert len(vgg_processing_job_6["ProcessingJobSummaries"]) == 1 + assert ( + vgg_processing_job_6["ProcessingJobSummaries"][0]["ProcessingJobName"] + == "vgg-0" + ) + assert vgg_processing_job_6.get("NextToken") is not None vgg_processing_job_10 = sagemaker_client.list_processing_jobs( NameContains="vgg", MaxResults=10 ) - assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5) - assert vgg_processing_job_10["ProcessingJobSummaries"][-1][ - "ProcessingJobName" - ].should.equal("vgg-4") - assert vgg_processing_job_10.get("NextToken").should.be.none + assert len(vgg_processing_job_10["ProcessingJobSummaries"]) == 5 + assert ( + vgg_processing_job_10["ProcessingJobSummaries"][-1]["ProcessingJobName"] + == "vgg-4" + ) + assert vgg_processing_job_10.get("NextToken") is None def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client): @@ -325,26 +338,24 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client processing_jobs_with_2 = sagemaker_client.list_processing_jobs( NameContains="2", MaxResults=8 ) - assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) - assert processing_jobs_with_2.get("NextToken").should_not.be.none + assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2 + assert processing_jobs_with_2.get("NextToken") is not None processing_jobs_with_2_next = sagemaker_client.list_processing_jobs( NameContains="2", MaxResults=1, NextToken=processing_jobs_with_2.get("NextToken"), ) - assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0) - assert processing_jobs_with_2_next.get("NextToken").should_not.be.none + assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]) == 0 + assert processing_jobs_with_2_next.get("NextToken") is not None processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs( NameContains="2", MaxResults=1, NextToken=processing_jobs_with_2_next.get("NextToken"), ) - assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal( - 0 - ) - assert processing_jobs_with_2_next_next.get("NextToken").should.be.none + assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]) == 0 + assert processing_jobs_with_2_next_next.get("NextToken") is None def test_add_and_delete_tags_in_training_job(sagemaker_client): diff --git a/tests/test_sagemaker/test_sagemaker_search.py b/tests/test_sagemaker/test_sagemaker_search.py index 4fb7d2457..20d2bcc18 100644 --- a/tests/test_sagemaker/test_sagemaker_search.py +++ b/tests/test_sagemaker/test_sagemaker_search.py @@ -1,7 +1,6 @@ import boto3 -import pytest - from botocore.exceptions import ClientError +import pytest from moto import mock_sagemaker @@ -94,11 +93,11 @@ def test_search_trial_component_with_experiment_name(sagemaker_client): }, ) - ex.value.response["Error"]["Code"].should.equal("ValidationException") - ex.value.response["Error"]["Message"].should.equal( - "Unknown property name: ExperimentName" + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ( + ex.value.response["Error"]["Message"] == "Unknown property name: ExperimentName" ) - ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400 def _set_up_trial_component( diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py index 5a130e30f..b3be6df8d 100644 --- a/tests/test_sagemaker/test_sagemaker_training.py +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -1,7 +1,8 @@ +import datetime +import re + import boto3 from botocore.exceptions import ClientError -import datetime -import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_sagemaker @@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" TEST_REGION_NAME = "us-east-1" -class MyTrainingJobModel(object): +class MyTrainingJobModel: def __init__( self, training_job_name, @@ -167,14 +168,16 @@ def test_create_training_job(): stopping_condition=stopping_condition, ) resp = job.save() - resp["TrainingJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$", + resp["TrainingJobArn"], ) resp = sagemaker.describe_training_job(TrainingJobName=training_job_name) - resp["TrainingJobName"].should.equal(training_job_name) - resp["TrainingJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$" + assert resp["TrainingJobName"] == training_job_name + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$", + resp["TrainingJobArn"], ) assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith( output_data_config["S3OutputPath"] @@ -214,8 +217,6 @@ def test_create_training_job(): assert "Value" in resp["FinalMetricDataList"][0] assert "Timestamp" in resp["FinalMetricDataList"][0] - pass - @mock_sagemaker def test_list_training_jobs(): @@ -225,13 +226,12 @@ def test_list_training_jobs(): test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn) test_training_job.save() training_jobs = client.list_training_jobs() - assert len(training_jobs["TrainingJobSummaries"]).should.equal(1) - assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"].should.equal( - name - ) + assert len(training_jobs["TrainingJobSummaries"]) == 1 + assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"] == name - assert training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$", + training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"], ) assert training_jobs.get("NextToken") is None @@ -253,18 +253,18 @@ def test_list_training_jobs_multiple(): ) test_training_job_2.save() training_jobs_limit = client.list_training_jobs(MaxResults=1) - assert len(training_jobs_limit["TrainingJobSummaries"]).should.equal(1) + assert len(training_jobs_limit["TrainingJobSummaries"]) == 1 training_jobs = client.list_training_jobs() - assert len(training_jobs["TrainingJobSummaries"]).should.equal(2) - assert training_jobs.get("NextToken").should.be.none + assert len(training_jobs["TrainingJobSummaries"]) == 2 + assert training_jobs.get("NextToken") is None @mock_sagemaker def test_list_training_jobs_none(): client = boto3.client("sagemaker", region_name="us-east-1") training_jobs = client.list_training_jobs() - assert len(training_jobs["TrainingJobSummaries"]).should.equal(0) + assert len(training_jobs["TrainingJobSummaries"]) == 0 @mock_sagemaker @@ -273,7 +273,12 @@ def test_list_training_jobs_should_validate_input(): junk_status_equals = "blah" with pytest.raises(ClientError) as ex: client.list_training_jobs(StatusEquals=junk_status_equals) - expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + expected_error = ( + f"1 validation errors detected: Value '{junk_status_equals}' at " + "'statusEquals' failed to satisfy constraint: Member must satisfy " + "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " + "'Failed']" + ) assert ex.value.response["Error"]["Code"] == "ValidationException" assert ex.value.response["Error"]["Message"] == expected_error @@ -299,10 +304,10 @@ def test_list_training_jobs_with_name_filters(): arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" MyTrainingJobModel(training_job_name=name, role_arn=arn).save() xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost") - assert len(xgboost_training_jobs["TrainingJobSummaries"]).should.equal(5) + assert len(xgboost_training_jobs["TrainingJobSummaries"]) == 5 training_jobs_with_2 = client.list_training_jobs(NameContains="2") - assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2) + assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2 @mock_sagemaker @@ -315,22 +320,24 @@ def test_list_training_jobs_paginated(): xgboost_training_job_1 = client.list_training_jobs( NameContains="xgboost", MaxResults=1 ) - assert len(xgboost_training_job_1["TrainingJobSummaries"]).should.equal(1) - assert xgboost_training_job_1["TrainingJobSummaries"][0][ - "TrainingJobName" - ].should.equal("xgboost-0") - assert xgboost_training_job_1.get("NextToken").should_not.be.none + assert len(xgboost_training_job_1["TrainingJobSummaries"]) == 1 + assert ( + xgboost_training_job_1["TrainingJobSummaries"][0]["TrainingJobName"] + == "xgboost-0" + ) + assert xgboost_training_job_1.get("NextToken") is not None xgboost_training_job_next = client.list_training_jobs( NameContains="xgboost", MaxResults=1, NextToken=xgboost_training_job_1.get("NextToken"), ) - assert len(xgboost_training_job_next["TrainingJobSummaries"]).should.equal(1) - assert xgboost_training_job_next["TrainingJobSummaries"][0][ - "TrainingJobName" - ].should.equal("xgboost-1") - assert xgboost_training_job_next.get("NextToken").should_not.be.none + assert len(xgboost_training_job_next["TrainingJobSummaries"]) == 1 + assert ( + xgboost_training_job_next["TrainingJobSummaries"][0]["TrainingJobName"] + == "xgboost-1" + ) + assert xgboost_training_job_next.get("NextToken") is not None @mock_sagemaker @@ -346,24 +353,20 @@ def test_list_training_jobs_paginated_with_target_in_middle(): MyTrainingJobModel(training_job_name=name, role_arn=arn).save() vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1) - assert len(vgg_training_job_1["TrainingJobSummaries"]).should.equal(0) - assert vgg_training_job_1.get("NextToken").should_not.be.none + assert len(vgg_training_job_1["TrainingJobSummaries"]) == 0 + assert vgg_training_job_1.get("NextToken") is not None vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6) - assert len(vgg_training_job_6["TrainingJobSummaries"]).should.equal(1) - assert vgg_training_job_6["TrainingJobSummaries"][0][ - "TrainingJobName" - ].should.equal("vgg-0") - assert vgg_training_job_6.get("NextToken").should_not.be.none + assert len(vgg_training_job_6["TrainingJobSummaries"]) == 1 + assert vgg_training_job_6["TrainingJobSummaries"][0]["TrainingJobName"] == "vgg-0" + assert vgg_training_job_6.get("NextToken") is not None vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10) - assert len(vgg_training_job_10["TrainingJobSummaries"]).should.equal(5) - assert vgg_training_job_10["TrainingJobSummaries"][-1][ - "TrainingJobName" - ].should.equal("vgg-4") - assert vgg_training_job_10.get("NextToken").should.be.none + assert len(vgg_training_job_10["TrainingJobSummaries"]) == 5 + assert vgg_training_job_10["TrainingJobSummaries"][-1]["TrainingJobName"] == "vgg-4" + assert vgg_training_job_10.get("NextToken") is None @mock_sagemaker @@ -379,22 +382,22 @@ def test_list_training_jobs_paginated_with_fragmented_targets(): MyTrainingJobModel(training_job_name=name, role_arn=arn).save() training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8) - assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2) - assert training_jobs_with_2.get("NextToken").should_not.be.none + assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2 + assert training_jobs_with_2.get("NextToken") is not None training_jobs_with_2_next = client.list_training_jobs( NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken") ) - assert len(training_jobs_with_2_next["TrainingJobSummaries"]).should.equal(0) - assert training_jobs_with_2_next.get("NextToken").should_not.be.none + assert len(training_jobs_with_2_next["TrainingJobSummaries"]) == 0 + assert training_jobs_with_2_next.get("NextToken") is not None training_jobs_with_2_next_next = client.list_training_jobs( NameContains="2", MaxResults=1, NextToken=training_jobs_with_2_next.get("NextToken"), ) - assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0) - assert training_jobs_with_2_next_next.get("NextToken").should.be.none + assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]) == 0 + assert training_jobs_with_2_next_next.get("NextToken") is None @mock_sagemaker @@ -447,7 +450,8 @@ def test_describe_unknown_training_job(): with pytest.raises(ClientError) as exc: client.describe_training_job(TrainingJobName="unknown") err = exc.value.response["Error"] - err["Code"].should.equal("ValidationException") - err["Message"].should.equal( - f"Could not find training job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:training-job/unknown'." + assert err["Code"] == "ValidationException" + assert err["Message"] == ( + "Could not find training job 'arn:aws:sagemaker:us-east-1:" + f"{ACCOUNT_ID}:training-job/unknown'." ) diff --git a/tests/test_sagemaker/test_sagemaker_transform.py b/tests/test_sagemaker/test_sagemaker_transform.py index 125681e9d..a4459d46b 100644 --- a/tests/test_sagemaker/test_sagemaker_transform.py +++ b/tests/test_sagemaker/test_sagemaker_transform.py @@ -1,7 +1,8 @@ +import datetime +import re + import boto3 from botocore.exceptions import ClientError -import datetime -import sure # noqa # pylint: disable=unused-import import pytest from moto import mock_sagemaker @@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" TEST_REGION_NAME = "us-east-1" -class MyTransformJobModel(object): +class MyTransformJobModel: def __init__( self, transform_job_name, @@ -147,23 +148,24 @@ def test_create_transform_job(): experiment_config=experiment_config, ) resp = job.save() - resp["TransformJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$", + resp["TransformJobArn"], ) resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name) - resp["TransformJobName"].should.equal(transform_job_name) - resp["TransformJobStatus"].should.equal("Completed") - resp["ModelName"].should.equal(model_name) - resp["MaxConcurrentTransforms"].should.equal(1) - resp["ModelClientConfig"].should.equal(model_client_config) - resp["MaxPayloadInMB"].should.equal(max_payload_in_mb) - resp["BatchStrategy"].should.equal("SingleRecord") - resp["TransformInput"].should.equal(transform_input) - resp["TransformOutput"].should.equal(transform_output) - resp["DataCaptureConfig"].should.equal(data_capture_config) - resp["TransformResources"].should.equal(transform_resources) - resp["DataProcessing"].should.equal(data_processing) - resp["ExperimentConfig"].should.equal(experiment_config) + assert resp["TransformJobName"] == transform_job_name + assert resp["TransformJobStatus"] == "Completed" + assert resp["ModelName"] == model_name + assert resp["MaxConcurrentTransforms"] == 1 + assert resp["ModelClientConfig"] == model_client_config + assert resp["MaxPayloadInMB"] == max_payload_in_mb + assert resp["BatchStrategy"] == "SingleRecord" + assert resp["TransformInput"] == transform_input + assert resp["TransformOutput"] == transform_output + assert resp["DataCaptureConfig"] == data_capture_config + assert resp["TransformResources"] == transform_resources + assert resp["DataProcessing"] == data_processing + assert resp["ExperimentConfig"] == experiment_config assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["TransformStartTime"], datetime.datetime) assert isinstance(resp["TransformEndTime"], datetime.datetime) @@ -179,13 +181,12 @@ def test_list_transform_jobs(): ) test_transform_job.save() transform_jobs = client.list_transform_jobs() - assert len(transform_jobs["TransformJobSummaries"]).should.equal(1) - assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"].should.equal( - name - ) + assert len(transform_jobs["TransformJobSummaries"]) == 1 + assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"] == name - assert transform_jobs["TransformJobSummaries"][0]["TransformJobArn"].should.match( - rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$" + assert re.match( + rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$", + transform_jobs["TransformJobSummaries"][0]["TransformJobArn"], ) assert transform_jobs.get("NextToken") is None @@ -207,18 +208,18 @@ def test_list_transform_jobs_multiple(): ) test_transform_job_2.save() transform_jobs_limit = client.list_transform_jobs(MaxResults=1) - assert len(transform_jobs_limit["TransformJobSummaries"]).should.equal(1) + assert len(transform_jobs_limit["TransformJobSummaries"]) == 1 transform_jobs = client.list_transform_jobs() - assert len(transform_jobs["TransformJobSummaries"]).should.equal(2) - assert transform_jobs.get("NextToken").should.be.none + assert len(transform_jobs["TransformJobSummaries"]) == 2 + assert transform_jobs.get("NextToken") is None @mock_sagemaker def test_list_transform_jobs_none(): client = boto3.client("sagemaker", region_name="us-east-1") transform_jobs = client.list_transform_jobs() - assert len(transform_jobs["TransformJobSummaries"]).should.equal(0) + assert len(transform_jobs["TransformJobSummaries"]) == 0 @mock_sagemaker @@ -227,7 +228,12 @@ def test_list_transform_jobs_should_validate_input(): junk_status_equals = "blah" with pytest.raises(ClientError) as ex: client.list_transform_jobs(StatusEquals=junk_status_equals) - expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + expected_error = ( + f"1 validation errors detected: Value '{junk_status_equals}' at " + "'statusEquals' failed to satisfy constraint: Member must satisfy " + "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " + "'Failed']" + ) assert ex.value.response["Error"]["Code"] == "ValidationException" assert ex.value.response["Error"]["Message"] == expected_error @@ -253,10 +259,10 @@ def test_list_transform_jobs_with_name_filters(): model_name = f"blah_model-{i}" MyTransformJobModel(transform_job_name=name, model_name=model_name).save() xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost") - assert len(xgboost_transform_jobs["TransformJobSummaries"]).should.equal(5) + assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5 transform_jobs_with_2 = client.list_transform_jobs(NameContains="2") - assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2) + assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2 @mock_sagemaker @@ -269,22 +275,24 @@ def test_list_transform_jobs_paginated(): xgboost_transform_job_1 = client.list_transform_jobs( NameContains="xgboost", MaxResults=1 ) - assert len(xgboost_transform_job_1["TransformJobSummaries"]).should.equal(1) - assert xgboost_transform_job_1["TransformJobSummaries"][0][ - "TransformJobName" - ].should.equal("xgboost-0") - assert xgboost_transform_job_1.get("NextToken").should_not.be.none + assert len(xgboost_transform_job_1["TransformJobSummaries"]) == 1 + assert ( + xgboost_transform_job_1["TransformJobSummaries"][0]["TransformJobName"] + == "xgboost-0" + ) + assert xgboost_transform_job_1.get("NextToken") is not None xgboost_transform_job_next = client.list_transform_jobs( NameContains="xgboost", MaxResults=1, NextToken=xgboost_transform_job_1.get("NextToken"), ) - assert len(xgboost_transform_job_next["TransformJobSummaries"]).should.equal(1) - assert xgboost_transform_job_next["TransformJobSummaries"][0][ - "TransformJobName" - ].should.equal("xgboost-1") - assert xgboost_transform_job_next.get("NextToken").should_not.be.none + assert len(xgboost_transform_job_next["TransformJobSummaries"]) == 1 + assert ( + xgboost_transform_job_next["TransformJobSummaries"][0]["TransformJobName"] + == "xgboost-1" + ) + assert xgboost_transform_job_next.get("NextToken") is not None @mock_sagemaker @@ -299,24 +307,24 @@ def test_list_transform_jobs_paginated_with_target_in_middle(): MyTransformJobModel(transform_job_name=name, model_name=model_name).save() vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1) - assert len(vgg_transform_job_1["TransformJobSummaries"]).should.equal(0) - assert vgg_transform_job_1.get("NextToken").should_not.be.none + assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0 + assert vgg_transform_job_1.get("NextToken") is not None vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6) - assert len(vgg_transform_job_6["TransformJobSummaries"]).should.equal(1) - assert vgg_transform_job_6["TransformJobSummaries"][0][ - "TransformJobName" - ].should.equal("vgg-0") - assert vgg_transform_job_6.get("NextToken").should_not.be.none + assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1 + assert ( + vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0" + ) + assert vgg_transform_job_6.get("NextToken") is not None vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10) - assert len(vgg_transform_job_10["TransformJobSummaries"]).should.equal(5) - assert vgg_transform_job_10["TransformJobSummaries"][-1][ - "TransformJobName" - ].should.equal("vgg-4") - assert vgg_transform_job_10.get("NextToken").should.be.none + assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5 + assert ( + vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4" + ) + assert vgg_transform_job_10.get("NextToken") is None @mock_sagemaker @@ -331,22 +339,22 @@ def test_list_transform_jobs_paginated_with_fragmented_targets(): MyTransformJobModel(transform_job_name=name, model_name=model_name).save() transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8) - assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2) - assert transform_jobs_with_2.get("NextToken").should_not.be.none + assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2 + assert transform_jobs_with_2.get("NextToken") is not None transform_jobs_with_2_next = client.list_transform_jobs( NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken") ) - assert len(transform_jobs_with_2_next["TransformJobSummaries"]).should.equal(0) - assert transform_jobs_with_2_next.get("NextToken").should_not.be.none + assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0 + assert transform_jobs_with_2_next.get("NextToken") is not None transform_jobs_with_2_next_next = client.list_transform_jobs( NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2_next.get("NextToken"), ) - assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]).should.equal(0) - assert transform_jobs_with_2_next_next.get("NextToken").should.be.none + assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0 + assert transform_jobs_with_2_next_next.get("NextToken") is None @mock_sagemaker @@ -401,7 +409,8 @@ def test_describe_unknown_transform_job(): with pytest.raises(ClientError) as exc: client.describe_transform_job(TransformJobName="unknown") err = exc.value.response["Error"] - err["Code"].should.equal("ValidationException") - err["Message"].should.equal( - f"Could not find transform job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown'." + assert err["Code"] == "ValidationException" + assert err["Message"] == ( + "Could not find transform job 'arn:aws:sagemaker:us-east-1:" + f"{ACCOUNT_ID}:transform-job/unknown'." ) diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py index 0affddc3c..f460ee54f 100644 --- a/tests/test_sagemaker/test_sagemaker_trial_component.py +++ b/tests/test_sagemaker/test_sagemaker_trial_component.py @@ -1,9 +1,8 @@ import uuid import boto3 -import pytest - from botocore.exceptions import ClientError +import pytest from moto import mock_sagemaker from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID @@ -27,9 +26,9 @@ def test_create__trial_component(): assert ( resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name ) - assert ( - resp["TrialComponentSummaries"][0]["TrialComponentArn"] - == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + assert resp["TrialComponentSummaries"][0]["TrialComponentArn"] == ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":experiment-trial-component/{trial_component_name}" ) @@ -173,9 +172,9 @@ def test_associate_trial_component(): ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - assert ( - resp["TrialComponentArn"] - == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + assert resp["TrialComponentArn"] == ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":experiment-trial-component/{trial_component_name}" ) assert ( resp["TrialArn"] @@ -199,11 +198,12 @@ def test_associate_trial_component(): TrialComponentName="does-not-exist", TrialName="does-not-exist" ) - ex.value.response["Error"]["Code"].should.equal("ResourceNotFound") - ex.value.response["Error"]["Message"].should.equal( - f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist." + assert ex.value.response["Error"]["Code"] == "ResourceNotFound" + assert ex.value.response["Error"]["Message"] == ( + f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + ":experiment-trial/does-not-exist' does not exist." ) - ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400 @mock_sagemaker @@ -231,9 +231,9 @@ def test_disassociate_trial_component(): ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - assert ( - resp["TrialComponentArn"] - == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}" + assert resp["TrialComponentArn"] == ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" + f":experiment-trial-component/{trial_component_name}" ) assert ( resp["TrialArn"] @@ -255,9 +255,9 @@ def test_disassociate_trial_component(): ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - assert ( - resp["TrialComponentArn"] - == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist" + assert resp["TrialComponentArn"] == ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:" + "experiment-trial-component/does-not-exist" ) assert ( resp["TrialArn"]