Techdebt: Replace sure with regular assertions in sagemaker (#6614)
This commit is contained in:
parent
2f8019052d
commit
56153be9d8
@ -1,8 +1,8 @@
|
||||
import boto3
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import pytest
|
||||
|
||||
from moto import mock_cloudformation, mock_sagemaker
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
@ -53,8 +53,8 @@ def test_sagemaker_cloudformation_create(test_config):
|
||||
provisioned_resource = cf.list_stack_resources(StackName=stack_name)[
|
||||
"StackResourceSummaries"
|
||||
][0]
|
||||
provisioned_resource["LogicalResourceId"].should.equal(test_config.resource_name)
|
||||
len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0)
|
||||
assert provisioned_resource["LogicalResourceId"] == test_config.resource_name
|
||||
assert len(provisioned_resource["PhysicalResourceId"]) > 0
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -87,7 +87,7 @@ def test_sagemaker_cloudformation_get_attr(test_config):
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
|
||||
assert outputs["Arn"] == resource_description[test_config.arn_parameter]
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -122,7 +122,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
|
||||
assert outputs["Arn"] == resource_description[test_config.arn_parameter]
|
||||
|
||||
# Delete the stack and verify resource has also been deleted
|
||||
cf.delete_stack(StackName=stack_name)
|
||||
@ -130,7 +130,7 @@ def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_me
|
||||
getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
ce.value.response["Error"]["Message"].should.contain(error_message)
|
||||
assert error_message in ce.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -160,19 +160,19 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_notebook_name}
|
||||
)
|
||||
initial_instance_type.should.equal(resource_description["InstanceType"])
|
||||
assert initial_instance_type == resource_description["InstanceType"]
|
||||
|
||||
# Update stack and check attributes
|
||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||
outputs = _get_stack_outputs(cf, stack_name)
|
||||
|
||||
updated_notebook_name = outputs["Name"]
|
||||
updated_notebook_name.should.equal(initial_notebook_name)
|
||||
assert updated_notebook_name == initial_notebook_name
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_notebook_name}
|
||||
)
|
||||
updated_instance_type.should.equal(resource_description["InstanceType"])
|
||||
assert updated_instance_type == resource_description["InstanceType"]
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -202,25 +202,21 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_config_name}
|
||||
)
|
||||
len(resource_description["OnCreate"]).should.equal(1)
|
||||
initial_on_create_script.should.equal(
|
||||
resource_description["OnCreate"][0]["Content"]
|
||||
)
|
||||
assert len(resource_description["OnCreate"]) == 1
|
||||
assert initial_on_create_script == resource_description["OnCreate"][0]["Content"]
|
||||
|
||||
# Update stack and check attributes
|
||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||
outputs = _get_stack_outputs(cf, stack_name)
|
||||
|
||||
updated_config_name = outputs["Name"]
|
||||
updated_config_name.should.equal(initial_config_name)
|
||||
assert updated_config_name == initial_config_name
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_config_name}
|
||||
)
|
||||
len(resource_description["OnCreate"]).should.equal(1)
|
||||
updated_on_create_script.should.equal(
|
||||
resource_description["OnCreate"][0]["Content"]
|
||||
)
|
||||
assert len(resource_description["OnCreate"]) == 1
|
||||
assert updated_on_create_script == resource_description["OnCreate"][0]["Content"]
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -251,7 +247,7 @@ def test_sagemaker_cloudformation_model_update():
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_model_name}
|
||||
)
|
||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||
assert resource_description["PrimaryContainer"]["Image"] == (
|
||||
image.format(initial_image_version)
|
||||
)
|
||||
|
||||
@ -260,12 +256,12 @@ def test_sagemaker_cloudformation_model_update():
|
||||
outputs = _get_stack_outputs(cf, stack_name)
|
||||
|
||||
updated_model_name = outputs["Name"]
|
||||
updated_model_name.should_not.equal(initial_model_name)
|
||||
assert updated_model_name != initial_model_name
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_model_name}
|
||||
)
|
||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||
assert resource_description["PrimaryContainer"]["Image"] == (
|
||||
image.format(updated_image_version)
|
||||
)
|
||||
|
||||
@ -300,7 +296,7 @@ def test_sagemaker_cloudformation_endpoint_config_update():
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_endpoint_config_name}
|
||||
)
|
||||
len(resource_description["ProductionVariants"]).should.equal(
|
||||
assert len(resource_description["ProductionVariants"]) == (
|
||||
initial_num_production_variants
|
||||
)
|
||||
|
||||
@ -309,12 +305,12 @@ def test_sagemaker_cloudformation_endpoint_config_update():
|
||||
outputs = _get_stack_outputs(cf, stack_name)
|
||||
|
||||
updated_endpoint_config_name = outputs["Name"]
|
||||
updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name)
|
||||
assert updated_endpoint_config_name != initial_endpoint_config_name
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_endpoint_config_name}
|
||||
)
|
||||
len(resource_description["ProductionVariants"]).should.equal(
|
||||
assert len(resource_description["ProductionVariants"]) == (
|
||||
updated_num_production_variants
|
||||
)
|
||||
|
||||
@ -365,8 +361,8 @@ def test_sagemaker_cloudformation_endpoint_update():
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_endpoint_name}
|
||||
)
|
||||
resource_description["EndpointConfigName"].should.match(
|
||||
initial_endpoint_config_name
|
||||
assert re.match(
|
||||
initial_endpoint_config_name, resource_description["EndpointConfigName"]
|
||||
)
|
||||
|
||||
# Create additional SM resources and update stack
|
||||
@ -393,11 +389,11 @@ def test_sagemaker_cloudformation_endpoint_update():
|
||||
outputs = _get_stack_outputs(cf, stack_name)
|
||||
|
||||
updated_endpoint_name = outputs["Name"]
|
||||
updated_endpoint_name.should.equal(initial_endpoint_name)
|
||||
assert updated_endpoint_name == initial_endpoint_name
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_endpoint_name}
|
||||
)
|
||||
resource_description["EndpointConfigName"].should.match(
|
||||
updated_endpoint_config_name
|
||||
assert re.match(
|
||||
updated_endpoint_config_name, resource_description["EndpointConfigName"]
|
||||
)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
@ -55,18 +55,20 @@ def create_endpoint_config_helper(sagemaker_client, production_variants):
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
|
||||
resp["EndpointConfigArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
|
||||
resp["EndpointConfigArn"],
|
||||
)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["ProductionVariants"].should.equal(production_variants)
|
||||
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
|
||||
assert resp["ProductionVariants"] == production_variants
|
||||
|
||||
|
||||
def test_create_endpoint_config(sagemaker_client):
|
||||
@ -99,15 +101,17 @@ def test_delete_endpoint_config(sagemaker_client):
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
|
||||
resp["EndpointConfigArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$",
|
||||
resp["EndpointConfigArn"],
|
||||
)
|
||||
|
||||
sagemaker_client.delete_endpoint_config(
|
||||
@ -143,7 +147,10 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
assert e.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: ["
|
||||
expected_message = (
|
||||
f"Value '{instance_type}' at 'instanceType' failed to satisfy "
|
||||
"constraint: Member must satisfy enum value set: ["
|
||||
)
|
||||
assert expected_message in e.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
@ -160,7 +167,10 @@ def test_create_endpoint_invalid_memory_size(sagemaker_client):
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
assert e.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: ["
|
||||
expected_message = (
|
||||
f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy "
|
||||
"constraint: Member must satisfy enum value set: ["
|
||||
)
|
||||
assert expected_message in e.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
@ -185,20 +195,20 @@ def test_create_endpoint(sagemaker_client):
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
Tags=GENERIC_TAGS_PARAM,
|
||||
)
|
||||
resp["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
|
||||
)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
|
||||
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
|
||||
assert resp["EndpointStatus"] == "InService"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
|
||||
assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME
|
||||
|
||||
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
|
||||
assert resp["Tags"] == GENERIC_TAGS_PARAM
|
||||
@ -224,7 +234,10 @@ def test_add_tags_endpoint(sagemaker_client):
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
resource_arn = (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":endpoint/{TEST_ENDPOINT_NAME}"
|
||||
)
|
||||
response = sagemaker_client.add_tags(
|
||||
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
|
||||
)
|
||||
@ -239,7 +252,10 @@ def test_delete_tags_endpoint(sagemaker_client):
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
resource_arn = (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":endpoint/{TEST_ENDPOINT_NAME}"
|
||||
)
|
||||
response = sagemaker_client.add_tags(
|
||||
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
|
||||
)
|
||||
@ -262,7 +278,10 @@ def test_list_tags_endpoint(sagemaker_client):
|
||||
for _ in range(80):
|
||||
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
resource_arn = (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":endpoint/{TEST_ENDPOINT_NAME}"
|
||||
)
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
@ -295,29 +314,32 @@ def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
|
||||
},
|
||||
],
|
||||
)
|
||||
response["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$",
|
||||
response["EndpointArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
|
||||
)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
|
||||
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
|
||||
assert resp["EndpointStatus"] == "InService"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME
|
||||
assert (
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert (
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight
|
||||
assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight
|
||||
|
||||
|
||||
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
|
||||
@ -364,39 +386,44 @@ def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
DesiredWeightsAndCapacities=desired_weights_and_capacities,
|
||||
)
|
||||
response["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$",
|
||||
response["EndpointArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"]
|
||||
)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert resp["EndpointName"] == TEST_ENDPOINT_NAME
|
||||
assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME
|
||||
assert resp["EndpointStatus"] == "InService"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant1")
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert resp["ProductionVariants"][0]["VariantName"] == "MyProductionVariant1"
|
||||
assert (
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert (
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight
|
||||
assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight
|
||||
|
||||
resp["ProductionVariants"][1]["VariantName"].should.equal("MyProductionVariant2")
|
||||
resp["ProductionVariants"][1]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert resp["ProductionVariants"][1]["VariantName"] == "MyProductionVariant2"
|
||||
assert (
|
||||
resp["ProductionVariants"][1]["DesiredInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][1]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
assert (
|
||||
resp["ProductionVariants"][1]["CurrentInstanceCount"]
|
||||
== new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][1]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
assert resp["ProductionVariants"][1]["DesiredWeight"] == new_desired_weight
|
||||
assert resp["ProductionVariants"][1]["CurrentWeight"] == new_desired_weight
|
||||
|
||||
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
|
||||
@ -426,13 +453,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_vari
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
|
||||
assert err["Message"] == (
|
||||
f'The variant name(s) "{variant_name}" is/are not present within '
|
||||
f'endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
assert resp == old_resp
|
||||
|
||||
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
|
||||
@ -463,13 +491,14 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endp
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".'
|
||||
assert err["Message"] == (
|
||||
f'Could not find endpoint "arn:aws:sagemaker:us-east-1:'
|
||||
f'{ACCOUNT_ID}:endpoint/{endpoint_name}".'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
assert resp == old_resp
|
||||
|
||||
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
|
||||
@ -502,13 +531,13 @@ def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonuniq
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
assert err["Message"] == (
|
||||
f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
assert resp == old_resp
|
||||
|
||||
|
||||
def _set_up_sagemaker_resources(
|
||||
@ -552,8 +581,9 @@ def _create_endpoint_config(
|
||||
resp = boto_client.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$",
|
||||
resp["EndpointConfigArn"],
|
||||
)
|
||||
|
||||
|
||||
@ -561,6 +591,6 @@ def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
|
||||
resp = boto_client.create_endpoint(
|
||||
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
|
||||
)
|
||||
resp["EndpointArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$", resp["EndpointArn"]
|
||||
)
|
||||
|
@ -1,9 +1,10 @@
|
||||
"""Unit tests for sagemaker-supported APIs."""
|
||||
from unittest import SkipTest
|
||||
|
||||
import boto3
|
||||
from freezegun import freeze_time
|
||||
|
||||
from moto import mock_sagemaker, settings
|
||||
from unittest import SkipTest
|
||||
|
||||
# See our Development Tips on writing tests for hints on how to write good tests:
|
||||
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
|
||||
|
@ -1,10 +1,10 @@
|
||||
import re
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
|
||||
from moto.sagemaker.models import VpcConfig
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
@ -18,7 +18,7 @@ def fixture_sagemaker_client():
|
||||
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
class MySageMakerModel(object):
|
||||
class MySageMakerModel:
|
||||
def __init__(self, name=None, arn=None, container=None, vpc_config=None):
|
||||
self.name = name or TEST_MODEL_NAME
|
||||
self.arn = arn or TEST_ARN
|
||||
@ -41,13 +41,13 @@ def test_describe_model(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
|
||||
assert model.get("ModelName").should.equal(TEST_MODEL_NAME)
|
||||
assert model.get("ModelName") == TEST_MODEL_NAME
|
||||
|
||||
|
||||
def test_describe_model_not_found(sagemaker_client):
|
||||
with pytest.raises(ClientError) as err:
|
||||
sagemaker_client.describe_model(ModelName="unknown")
|
||||
assert err.value.response["Error"]["Message"].should.contain("Could not find model")
|
||||
assert "Could not find model" in err.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
def test_create_model(sagemaker_client):
|
||||
@ -57,8 +57,8 @@ def test_create_model(sagemaker_client):
|
||||
ExecutionRoleArn=TEST_ARN,
|
||||
VpcConfig=vpc_config.response_object,
|
||||
)
|
||||
model["ModelArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$", model["ModelArn"]
|
||||
)
|
||||
|
||||
|
||||
@ -66,25 +66,26 @@ def test_delete_model(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
|
||||
assert len(sagemaker_client.list_models()["Models"]).should.equal(1)
|
||||
assert len(sagemaker_client.list_models()["Models"]) == 1
|
||||
sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
|
||||
assert len(sagemaker_client.list_models()["Models"]).should.equal(0)
|
||||
assert len(sagemaker_client.list_models()["Models"]) == 0
|
||||
|
||||
|
||||
def test_delete_model_not_found(sagemaker_client):
|
||||
with pytest.raises(ClientError) as err:
|
||||
sagemaker_client.delete_model(ModelName="blah")
|
||||
assert err.value.response["Error"]["Code"].should.equal("404")
|
||||
assert err.value.response["Error"]["Code"] == "404"
|
||||
|
||||
|
||||
def test_list_models(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(1)
|
||||
assert models["Models"][0]["ModelName"].should.equal(TEST_MODEL_NAME)
|
||||
assert models["Models"][0]["ModelArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
|
||||
assert len(models["Models"]) == 1
|
||||
assert models["Models"][0]["ModelName"] == TEST_MODEL_NAME
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$",
|
||||
models["Models"][0]["ModelArn"],
|
||||
)
|
||||
|
||||
|
||||
@ -99,12 +100,12 @@ def test_list_models_multiple(sagemaker_client):
|
||||
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
|
||||
test_model_2.save(sagemaker_client)
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(2)
|
||||
assert len(models["Models"]) == 2
|
||||
|
||||
|
||||
def test_list_models_none(sagemaker_client):
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(0)
|
||||
assert len(models["Models"]) == 0
|
||||
|
||||
|
||||
def test_add_tags_to_model(sagemaker_client):
|
||||
|
@ -1,11 +1,11 @@
|
||||
import datetime
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
import pytest
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
FAKE_SUBNET_ID = "subnet-012345678"
|
||||
@ -37,7 +37,10 @@ def _get_notebook_instance_arn(notebook_name):
|
||||
|
||||
|
||||
def _get_notebook_instance_lifecycle_arn(lifecycle_name):
|
||||
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
|
||||
return (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":notebook-instance-lifecycle-configuration/{lifecycle_name}"
|
||||
)
|
||||
|
||||
|
||||
def test_create_notebook_instance_minimal_params(sagemaker_client):
|
||||
@ -130,7 +133,10 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker_client.create_notebook_instance(**args)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = f"Value '{instance_type}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: ["
|
||||
expected_message = (
|
||||
f"Value '{instance_type}' at 'instanceType' failed to satisfy "
|
||||
"constraint: Member must satisfy enum value set: ["
|
||||
)
|
||||
|
||||
assert expected_message in ex.value.response["Error"]["Message"]
|
||||
|
||||
@ -153,7 +159,10 @@ def test_notebook_instance_lifecycle(sagemaker_client):
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = f"Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({notebook_instance_arn})"
|
||||
expected_message = (
|
||||
f"Status (InService) not in ([Stopped, Failed]). Unable to "
|
||||
f"transition to (Deleting) for Notebook Instance ({notebook_instance_arn})"
|
||||
)
|
||||
assert expected_message in ex.value.response["Error"]["Message"]
|
||||
|
||||
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
|
@ -1,13 +1,14 @@
|
||||
from contextlib import contextmanager
|
||||
from moto import mock_sagemaker, settings
|
||||
from time import sleep
|
||||
from datetime import datetime
|
||||
import boto3
|
||||
import botocore
|
||||
import json
|
||||
import pytest
|
||||
from time import sleep
|
||||
from unittest import SkipTest
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker, settings
|
||||
from moto.s3 import mock_s3
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
from moto.sagemaker.exceptions import ValidationError
|
||||
@ -85,7 +86,10 @@ def test_utils_get_pipeline_from_name_not_exists():
|
||||
|
||||
def test_utils_get_pipeline_name_from_execution_arn():
|
||||
expected_pipeline_name = "some-pipeline-name"
|
||||
pipeline_execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{expected_pipeline_name}/execution/abc123def456"
|
||||
pipeline_execution_arn = (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":pipeline/{expected_pipeline_name}/execution/abc123def456"
|
||||
)
|
||||
observed_pipeline_name = get_pipeline_name_from_execution_arn(
|
||||
pipeline_execution_arn=pipeline_execution_arn
|
||||
)
|
||||
@ -266,7 +270,7 @@ def test_describe_pipeline_execution(sagemaker_client):
|
||||
PipelineExecutionArn=response["PipelineExecutionArn"]
|
||||
)
|
||||
observed_pipeline_execution_arn = pipeline_execution_summary["PipelineExecutionArn"]
|
||||
observed_pipeline_execution_arn.should.be.equal(expected_pipeline_execution_arn)
|
||||
assert observed_pipeline_execution_arn == expected_pipeline_execution_arn
|
||||
|
||||
|
||||
def test_load_pipeline_definition_from_s3():
|
||||
@ -292,7 +296,7 @@ def test_load_pipeline_definition_from_s3():
|
||||
},
|
||||
account_id=ACCOUNT_ID,
|
||||
)
|
||||
observed_pipeline_definition.should.equal(pipeline_definition)
|
||||
assert observed_pipeline_definition == pipeline_definition
|
||||
|
||||
|
||||
def test_create_pipeline(sagemaker_client):
|
||||
@ -303,7 +307,7 @@ def test_create_pipeline(sagemaker_client):
|
||||
PipelineDefinition=" ",
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
response["PipelineArn"].should.equal(
|
||||
assert response["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
@ -353,7 +357,7 @@ def test_create_pipeline_duplicate_pipeline_name(sagemaker_client):
|
||||
def test_list_pipelines_none(sagemaker_client):
|
||||
response = sagemaker_client.list_pipelines()
|
||||
assert isinstance(response, dict)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
|
||||
def test_list_pipelines_single(sagemaker_client):
|
||||
@ -368,8 +372,8 @@ def test_list_pipelines_single(sagemaker_client):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
assert len(response["PipelineSummaries"]) == 1
|
||||
assert response["PipelineSummaries"][0]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
@ -390,7 +394,7 @@ def test_list_pipelines_multiple(sagemaker_client):
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
)
|
||||
response["PipelineSummaries"].should.have.length_of(len(fake_pipeline_names))
|
||||
assert len(response["PipelineSummaries"]) == len(fake_pipeline_names)
|
||||
|
||||
|
||||
def test_list_pipelines_sort_name_ascending(sagemaker_client):
|
||||
@ -409,13 +413,13 @@ def test_list_pipelines_sort_name_ascending(sagemaker_client):
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][0]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][-1]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][-1]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][1]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
@ -436,13 +440,13 @@ def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
|
||||
SortBy="CreationTime",
|
||||
SortOrder="Descending",
|
||||
)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][0]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][1]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][2]["PipelineArn"].should.equal(
|
||||
assert response["PipelineSummaries"][2]["PipelineArn"] == (
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
@ -460,7 +464,7 @@ def test_list_pipelines_max_results(sagemaker_client):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(MaxResults=2)
|
||||
response["PipelineSummaries"].should.have.length_of(2)
|
||||
assert len(response["PipelineSummaries"]) == 2
|
||||
|
||||
|
||||
def test_list_pipelines_next_token(sagemaker_client):
|
||||
@ -475,7 +479,7 @@ def test_list_pipelines_next_token(sagemaker_client):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(NextToken="0")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
assert len(response["PipelineSummaries"]) == 1
|
||||
|
||||
|
||||
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
|
||||
@ -491,11 +495,11 @@ def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
|
||||
assert len(response["PipelineSummaries"]) == 1
|
||||
assert response["PipelineSummaries"][0]["PipelineName"] == "APipelineName"
|
||||
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
|
||||
response["PipelineSummaries"].should.have.length_of(3)
|
||||
assert len(response["PipelineSummaries"]) == 3
|
||||
|
||||
|
||||
def test_list_pipelines_created_after(sagemaker_client):
|
||||
@ -512,15 +516,15 @@ def test_list_pipelines_created_after(sagemaker_client):
|
||||
|
||||
created_after_str = "2099-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S")
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_datetime)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
created_after_timestamp = datetime.timestamp(created_after_datetime)
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_timestamp)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
|
||||
def test_list_pipelines_created_before(sagemaker_client):
|
||||
@ -537,15 +541,15 @@ def test_list_pipelines_created_before(sagemaker_client):
|
||||
|
||||
created_before_str = "2000-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S")
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_datetime)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
created_before_timestamp = datetime.timestamp(created_before_datetime)
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_timestamp)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -582,7 +586,7 @@ def test_delete_pipeline_exists(sagemaker_client):
|
||||
assert response["PipelineArn"].endswith(pipeline_name_delete)
|
||||
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix=pipeline_name_delete)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
assert not response["PipelineSummaries"]
|
||||
|
||||
response = sagemaker_client.list_pipelines()
|
||||
pipeline_names_exist = [
|
||||
@ -627,11 +631,11 @@ def test_update_pipeline_no_update(sagemaker_client):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
|
||||
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
|
||||
response["PipelineArn"].should.equal(
|
||||
assert response["PipelineArn"] == (
|
||||
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name)
|
||||
assert response["PipelineSummaries"][0]["PipelineName"] == pipeline_name
|
||||
|
||||
|
||||
def test_update_pipeline_add_attribute(sagemaker_client):
|
||||
@ -645,17 +649,17 @@ def test_update_pipeline_add_attribute(sagemaker_client):
|
||||
}
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(pipeline_name)
|
||||
assert response["PipelineSummaries"][0]["PipelineDisplayName"] == pipeline_name
|
||||
|
||||
_ = sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
PipelineDisplayName=pipeline_display_name_update,
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
|
||||
assert response["PipelineSummaries"][0]["PipelineDisplayName"] == (
|
||||
pipeline_display_name_update
|
||||
)
|
||||
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||
assert len(response["PipelineSummaries"][0]) == 6
|
||||
|
||||
|
||||
def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||
@ -673,8 +677,8 @@ def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||
RoleArn=role_arn_update,
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
|
||||
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||
assert response["PipelineSummaries"][0]["RoleArn"] == role_arn_update
|
||||
assert len(response["PipelineSummaries"][0]) == 6
|
||||
|
||||
|
||||
def test_describe_pipeline_not_exists(sagemaker_client):
|
||||
@ -707,4 +711,4 @@ def test_describe_pipeline_not_exists(sagemaker_client):
|
||||
def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"])
|
||||
response.should.have.length_of(expected_response_length)
|
||||
assert len(response) == expected_response_length
|
||||
|
@ -1,6 +1,8 @@
|
||||
import datetime
|
||||
import re
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import datetime
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
@ -18,7 +20,7 @@ def fixture_sagemaker_client():
|
||||
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
class MyProcessingJobModel(object):
|
||||
class MyProcessingJobModel:
|
||||
def __init__(
|
||||
self,
|
||||
processing_job_name,
|
||||
@ -129,16 +131,18 @@ def test_create_processing_job(sagemaker_client):
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save(sagemaker_client)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
|
||||
resp["ProcessingJobArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_processing_job(
|
||||
ProcessingJobName=FAKE_PROCESSING_JOB_NAME
|
||||
)
|
||||
resp["ProcessingJobName"].should.equal(FAKE_PROCESSING_JOB_NAME)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
|
||||
assert resp["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
|
||||
resp["ProcessingJobArn"],
|
||||
)
|
||||
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
@ -154,15 +158,15 @@ def test_list_processing_jobs(sagemaker_client):
|
||||
)
|
||||
test_processing_job.save(sagemaker_client)
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal(FAKE_PROCESSING_JOB_NAME)
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]) == 1
|
||||
assert (
|
||||
processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobName"]
|
||||
== FAKE_PROCESSING_JOB_NAME
|
||||
)
|
||||
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobArn"
|
||||
].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
|
||||
processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobArn"],
|
||||
)
|
||||
assert processing_jobs.get("NextToken") is None
|
||||
|
||||
@ -182,23 +186,28 @@ def test_list_processing_jobs_multiple(sagemaker_client):
|
||||
)
|
||||
test_processing_job_2.save(sagemaker_client)
|
||||
processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1)
|
||||
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert len(processing_jobs_limit["ProcessingJobSummaries"]) == 1
|
||||
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs.get("NextToken").should.be.none
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]) == 2
|
||||
assert processing_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
def test_list_processing_jobs_none(sagemaker_client):
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]) == 0
|
||||
|
||||
|
||||
def test_list_processing_jobs_should_validate_input(sagemaker_client):
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
expected_error = (
|
||||
f"1 validation errors detected: Value '{junk_status_equals}' at "
|
||||
"'statusEquals' failed to satisfy constraint: Member must satisfy "
|
||||
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
|
||||
"'Failed']"
|
||||
)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
@ -230,10 +239,10 @@ def test_list_processing_jobs_with_name_filters(sagemaker_client):
|
||||
xgboost_processing_jobs = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost"
|
||||
)
|
||||
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
|
||||
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]) == 5
|
||||
|
||||
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2")
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2
|
||||
|
||||
|
||||
def test_list_processing_jobs_paginated(sagemaker_client):
|
||||
@ -247,22 +256,24 @@ def test_list_processing_jobs_paginated(sagemaker_client):
|
||||
xgboost_processing_job_1 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_processing_job_1["ProcessingJobSummaries"][0]["ProcessingJobName"]
|
||||
== "xgboost-0"
|
||||
)
|
||||
assert xgboost_processing_job_1.get("NextToken") is not None
|
||||
|
||||
xgboost_processing_job_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_processing_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_processing_job_next["ProcessingJobSummaries"][0]["ProcessingJobName"]
|
||||
== "xgboost-1"
|
||||
)
|
||||
assert xgboost_processing_job_next.get("NextToken") is not None
|
||||
|
||||
|
||||
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
|
||||
@ -283,28 +294,30 @@ def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
|
||||
vgg_processing_job_1 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=1
|
||||
)
|
||||
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert vgg_processing_job_1.get("NextToken").should_not.be.none
|
||||
assert len(vgg_processing_job_1["ProcessingJobSummaries"]) == 0
|
||||
assert vgg_processing_job_1.get("NextToken") is not None
|
||||
|
||||
vgg_processing_job_6 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=6
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_processing_job_6.get("NextToken").should_not.be.none
|
||||
assert len(vgg_processing_job_6["ProcessingJobSummaries"]) == 1
|
||||
assert (
|
||||
vgg_processing_job_6["ProcessingJobSummaries"][0]["ProcessingJobName"]
|
||||
== "vgg-0"
|
||||
)
|
||||
assert vgg_processing_job_6.get("NextToken") is not None
|
||||
|
||||
vgg_processing_job_10 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=10
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
|
||||
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_processing_job_10.get("NextToken").should.be.none
|
||||
assert len(vgg_processing_job_10["ProcessingJobSummaries"]) == 5
|
||||
assert (
|
||||
vgg_processing_job_10["ProcessingJobSummaries"][-1]["ProcessingJobName"]
|
||||
== "vgg-4"
|
||||
)
|
||||
assert vgg_processing_job_10.get("NextToken") is None
|
||||
|
||||
|
||||
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
|
||||
@ -325,26 +338,24 @@ def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client
|
||||
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2", MaxResults=8
|
||||
)
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs_with_2.get("NextToken").should_not.be.none
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2
|
||||
assert processing_jobs_with_2.get("NextToken") is not None
|
||||
|
||||
processing_jobs_with_2_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]) == 0
|
||||
assert processing_jobs_with_2_next.get("NextToken") is not None
|
||||
|
||||
processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
|
||||
0
|
||||
)
|
||||
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]) == 0
|
||||
assert processing_jobs_with_2_next_next.get("NextToken") is None
|
||||
|
||||
|
||||
def test_add_and_delete_tags_in_training_job(sagemaker_client):
|
||||
|
@ -1,7 +1,6 @@
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
|
||||
@ -94,11 +93,11 @@ def test_search_trial_component_with_experiment_name(sagemaker_client):
|
||||
},
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Code"].should.equal("ValidationException")
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Unknown property name: ExperimentName"
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert (
|
||||
ex.value.response["Error"]["Message"] == "Unknown property name: ExperimentName"
|
||||
)
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400
|
||||
|
||||
|
||||
def _set_up_trial_component(
|
||||
|
@ -1,7 +1,8 @@
|
||||
import datetime
|
||||
import re
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import datetime
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
class MyTrainingJobModel(object):
|
||||
class MyTrainingJobModel:
|
||||
def __init__(
|
||||
self,
|
||||
training_job_name,
|
||||
@ -167,14 +168,16 @@ def test_create_training_job():
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save()
|
||||
resp["TrainingJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$",
|
||||
resp["TrainingJobArn"],
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
|
||||
resp["TrainingJobName"].should.equal(training_job_name)
|
||||
resp["TrainingJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$"
|
||||
assert resp["TrainingJobName"] == training_job_name
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$",
|
||||
resp["TrainingJobArn"],
|
||||
)
|
||||
assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
|
||||
output_data_config["S3OutputPath"]
|
||||
@ -214,8 +217,6 @@ def test_create_training_job():
|
||||
assert "Value" in resp["FinalMetricDataList"][0]
|
||||
assert "Timestamp" in resp["FinalMetricDataList"][0]
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs():
|
||||
@ -225,13 +226,12 @@ def test_list_training_jobs():
|
||||
test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn)
|
||||
test_training_job.save()
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(1)
|
||||
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"].should.equal(
|
||||
name
|
||||
)
|
||||
assert len(training_jobs["TrainingJobSummaries"]) == 1
|
||||
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"] == name
|
||||
|
||||
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$",
|
||||
training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"],
|
||||
)
|
||||
assert training_jobs.get("NextToken") is None
|
||||
|
||||
@ -253,18 +253,18 @@ def test_list_training_jobs_multiple():
|
||||
)
|
||||
test_training_job_2.save()
|
||||
training_jobs_limit = client.list_training_jobs(MaxResults=1)
|
||||
assert len(training_jobs_limit["TrainingJobSummaries"]).should.equal(1)
|
||||
assert len(training_jobs_limit["TrainingJobSummaries"]) == 1
|
||||
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(2)
|
||||
assert training_jobs.get("NextToken").should.be.none
|
||||
assert len(training_jobs["TrainingJobSummaries"]) == 2
|
||||
assert training_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(0)
|
||||
assert len(training_jobs["TrainingJobSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -273,7 +273,12 @@ def test_list_training_jobs_should_validate_input():
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_training_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
expected_error = (
|
||||
f"1 validation errors detected: Value '{junk_status_equals}' at "
|
||||
"'statusEquals' failed to satisfy constraint: Member must satisfy "
|
||||
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
|
||||
"'Failed']"
|
||||
)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
@ -299,10 +304,10 @@ def test_list_training_jobs_with_name_filters():
|
||||
arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}"
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost")
|
||||
assert len(xgboost_training_jobs["TrainingJobSummaries"]).should.equal(5)
|
||||
assert len(xgboost_training_jobs["TrainingJobSummaries"]) == 5
|
||||
|
||||
training_jobs_with_2 = client.list_training_jobs(NameContains="2")
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -315,22 +320,24 @@ def test_list_training_jobs_paginated():
|
||||
xgboost_training_job_1 = client.list_training_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_training_job_1["TrainingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_training_job_1["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_training_job_1.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_training_job_1["TrainingJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_training_job_1["TrainingJobSummaries"][0]["TrainingJobName"]
|
||||
== "xgboost-0"
|
||||
)
|
||||
assert xgboost_training_job_1.get("NextToken") is not None
|
||||
|
||||
xgboost_training_job_next = client.list_training_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_training_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_training_job_next["TrainingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_training_job_next["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_training_job_next.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_training_job_next["TrainingJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_training_job_next["TrainingJobSummaries"][0]["TrainingJobName"]
|
||||
== "xgboost-1"
|
||||
)
|
||||
assert xgboost_training_job_next.get("NextToken") is not None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -346,24 +353,20 @@ def test_list_training_jobs_paginated_with_target_in_middle():
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
|
||||
vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1)
|
||||
assert len(vgg_training_job_1["TrainingJobSummaries"]).should.equal(0)
|
||||
assert vgg_training_job_1.get("NextToken").should_not.be.none
|
||||
assert len(vgg_training_job_1["TrainingJobSummaries"]) == 0
|
||||
assert vgg_training_job_1.get("NextToken") is not None
|
||||
|
||||
vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6)
|
||||
|
||||
assert len(vgg_training_job_6["TrainingJobSummaries"]).should.equal(1)
|
||||
assert vgg_training_job_6["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_training_job_6.get("NextToken").should_not.be.none
|
||||
assert len(vgg_training_job_6["TrainingJobSummaries"]) == 1
|
||||
assert vgg_training_job_6["TrainingJobSummaries"][0]["TrainingJobName"] == "vgg-0"
|
||||
assert vgg_training_job_6.get("NextToken") is not None
|
||||
|
||||
vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10)
|
||||
|
||||
assert len(vgg_training_job_10["TrainingJobSummaries"]).should.equal(5)
|
||||
assert vgg_training_job_10["TrainingJobSummaries"][-1][
|
||||
"TrainingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_training_job_10.get("NextToken").should.be.none
|
||||
assert len(vgg_training_job_10["TrainingJobSummaries"]) == 5
|
||||
assert vgg_training_job_10["TrainingJobSummaries"][-1]["TrainingJobName"] == "vgg-4"
|
||||
assert vgg_training_job_10.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -379,22 +382,22 @@ def test_list_training_jobs_paginated_with_fragmented_targets():
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
|
||||
training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8)
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
|
||||
assert training_jobs_with_2.get("NextToken").should_not.be.none
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2
|
||||
assert training_jobs_with_2.get("NextToken") is not None
|
||||
|
||||
training_jobs_with_2_next = client.list_training_jobs(
|
||||
NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken")
|
||||
)
|
||||
assert len(training_jobs_with_2_next["TrainingJobSummaries"]).should.equal(0)
|
||||
assert training_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
assert len(training_jobs_with_2_next["TrainingJobSummaries"]) == 0
|
||||
assert training_jobs_with_2_next.get("NextToken") is not None
|
||||
|
||||
training_jobs_with_2_next_next = client.list_training_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=training_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0)
|
||||
assert training_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]) == 0
|
||||
assert training_jobs_with_2_next_next.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -447,7 +450,8 @@ def test_describe_unknown_training_job():
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.describe_training_job(TrainingJobName="unknown")
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("ValidationException")
|
||||
err["Message"].should.equal(
|
||||
f"Could not find training job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:training-job/unknown'."
|
||||
assert err["Code"] == "ValidationException"
|
||||
assert err["Message"] == (
|
||||
"Could not find training job 'arn:aws:sagemaker:us-east-1:"
|
||||
f"{ACCOUNT_ID}:training-job/unknown'."
|
||||
)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import datetime
|
||||
import re
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import datetime
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
@ -11,7 +12,7 @@ FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
class MyTransformJobModel(object):
|
||||
class MyTransformJobModel:
|
||||
def __init__(
|
||||
self,
|
||||
transform_job_name,
|
||||
@ -147,23 +148,24 @@ def test_create_transform_job():
|
||||
experiment_config=experiment_config,
|
||||
)
|
||||
resp = job.save()
|
||||
resp["TransformJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$",
|
||||
resp["TransformJobArn"],
|
||||
)
|
||||
resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
|
||||
resp["TransformJobName"].should.equal(transform_job_name)
|
||||
resp["TransformJobStatus"].should.equal("Completed")
|
||||
resp["ModelName"].should.equal(model_name)
|
||||
resp["MaxConcurrentTransforms"].should.equal(1)
|
||||
resp["ModelClientConfig"].should.equal(model_client_config)
|
||||
resp["MaxPayloadInMB"].should.equal(max_payload_in_mb)
|
||||
resp["BatchStrategy"].should.equal("SingleRecord")
|
||||
resp["TransformInput"].should.equal(transform_input)
|
||||
resp["TransformOutput"].should.equal(transform_output)
|
||||
resp["DataCaptureConfig"].should.equal(data_capture_config)
|
||||
resp["TransformResources"].should.equal(transform_resources)
|
||||
resp["DataProcessing"].should.equal(data_processing)
|
||||
resp["ExperimentConfig"].should.equal(experiment_config)
|
||||
assert resp["TransformJobName"] == transform_job_name
|
||||
assert resp["TransformJobStatus"] == "Completed"
|
||||
assert resp["ModelName"] == model_name
|
||||
assert resp["MaxConcurrentTransforms"] == 1
|
||||
assert resp["ModelClientConfig"] == model_client_config
|
||||
assert resp["MaxPayloadInMB"] == max_payload_in_mb
|
||||
assert resp["BatchStrategy"] == "SingleRecord"
|
||||
assert resp["TransformInput"] == transform_input
|
||||
assert resp["TransformOutput"] == transform_output
|
||||
assert resp["DataCaptureConfig"] == data_capture_config
|
||||
assert resp["TransformResources"] == transform_resources
|
||||
assert resp["DataProcessing"] == data_processing
|
||||
assert resp["ExperimentConfig"] == experiment_config
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["TransformStartTime"], datetime.datetime)
|
||||
assert isinstance(resp["TransformEndTime"], datetime.datetime)
|
||||
@ -179,13 +181,12 @@ def test_list_transform_jobs():
|
||||
)
|
||||
test_transform_job.save()
|
||||
transform_jobs = client.list_transform_jobs()
|
||||
assert len(transform_jobs["TransformJobSummaries"]).should.equal(1)
|
||||
assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"].should.equal(
|
||||
name
|
||||
)
|
||||
assert len(transform_jobs["TransformJobSummaries"]) == 1
|
||||
assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"] == name
|
||||
|
||||
assert transform_jobs["TransformJobSummaries"][0]["TransformJobArn"].should.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$"
|
||||
assert re.match(
|
||||
rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$",
|
||||
transform_jobs["TransformJobSummaries"][0]["TransformJobArn"],
|
||||
)
|
||||
assert transform_jobs.get("NextToken") is None
|
||||
|
||||
@ -207,18 +208,18 @@ def test_list_transform_jobs_multiple():
|
||||
)
|
||||
test_transform_job_2.save()
|
||||
transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
|
||||
assert len(transform_jobs_limit["TransformJobSummaries"]).should.equal(1)
|
||||
assert len(transform_jobs_limit["TransformJobSummaries"]) == 1
|
||||
|
||||
transform_jobs = client.list_transform_jobs()
|
||||
assert len(transform_jobs["TransformJobSummaries"]).should.equal(2)
|
||||
assert transform_jobs.get("NextToken").should.be.none
|
||||
assert len(transform_jobs["TransformJobSummaries"]) == 2
|
||||
assert transform_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_transform_jobs_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
transform_jobs = client.list_transform_jobs()
|
||||
assert len(transform_jobs["TransformJobSummaries"]).should.equal(0)
|
||||
assert len(transform_jobs["TransformJobSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -227,7 +228,12 @@ def test_list_transform_jobs_should_validate_input():
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_transform_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
expected_error = (
|
||||
f"1 validation errors detected: Value '{junk_status_equals}' at "
|
||||
"'statusEquals' failed to satisfy constraint: Member must satisfy "
|
||||
"enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
|
||||
"'Failed']"
|
||||
)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
@ -253,10 +259,10 @@ def test_list_transform_jobs_with_name_filters():
|
||||
model_name = f"blah_model-{i}"
|
||||
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
|
||||
xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
|
||||
assert len(xgboost_transform_jobs["TransformJobSummaries"]).should.equal(5)
|
||||
assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5
|
||||
|
||||
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
|
||||
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
|
||||
assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -269,22 +275,24 @@ def test_list_transform_jobs_paginated():
|
||||
xgboost_transform_job_1 = client.list_transform_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_transform_job_1["TransformJobSummaries"]).should.equal(1)
|
||||
assert xgboost_transform_job_1["TransformJobSummaries"][0][
|
||||
"TransformJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_transform_job_1.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_transform_job_1["TransformJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_transform_job_1["TransformJobSummaries"][0]["TransformJobName"]
|
||||
== "xgboost-0"
|
||||
)
|
||||
assert xgboost_transform_job_1.get("NextToken") is not None
|
||||
|
||||
xgboost_transform_job_next = client.list_transform_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_transform_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_transform_job_next["TransformJobSummaries"]).should.equal(1)
|
||||
assert xgboost_transform_job_next["TransformJobSummaries"][0][
|
||||
"TransformJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_transform_job_next.get("NextToken").should_not.be.none
|
||||
assert len(xgboost_transform_job_next["TransformJobSummaries"]) == 1
|
||||
assert (
|
||||
xgboost_transform_job_next["TransformJobSummaries"][0]["TransformJobName"]
|
||||
== "xgboost-1"
|
||||
)
|
||||
assert xgboost_transform_job_next.get("NextToken") is not None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -299,24 +307,24 @@ def test_list_transform_jobs_paginated_with_target_in_middle():
|
||||
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
|
||||
|
||||
vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
|
||||
assert len(vgg_transform_job_1["TransformJobSummaries"]).should.equal(0)
|
||||
assert vgg_transform_job_1.get("NextToken").should_not.be.none
|
||||
assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0
|
||||
assert vgg_transform_job_1.get("NextToken") is not None
|
||||
|
||||
vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
|
||||
|
||||
assert len(vgg_transform_job_6["TransformJobSummaries"]).should.equal(1)
|
||||
assert vgg_transform_job_6["TransformJobSummaries"][0][
|
||||
"TransformJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_transform_job_6.get("NextToken").should_not.be.none
|
||||
assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1
|
||||
assert (
|
||||
vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0"
|
||||
)
|
||||
assert vgg_transform_job_6.get("NextToken") is not None
|
||||
|
||||
vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
|
||||
|
||||
assert len(vgg_transform_job_10["TransformJobSummaries"]).should.equal(5)
|
||||
assert vgg_transform_job_10["TransformJobSummaries"][-1][
|
||||
"TransformJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_transform_job_10.get("NextToken").should.be.none
|
||||
assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5
|
||||
assert (
|
||||
vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4"
|
||||
)
|
||||
assert vgg_transform_job_10.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -331,22 +339,22 @@ def test_list_transform_jobs_paginated_with_fragmented_targets():
|
||||
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
|
||||
|
||||
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
|
||||
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
|
||||
assert transform_jobs_with_2.get("NextToken").should_not.be.none
|
||||
assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
|
||||
assert transform_jobs_with_2.get("NextToken") is not None
|
||||
|
||||
transform_jobs_with_2_next = client.list_transform_jobs(
|
||||
NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken")
|
||||
)
|
||||
assert len(transform_jobs_with_2_next["TransformJobSummaries"]).should.equal(0)
|
||||
assert transform_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0
|
||||
assert transform_jobs_with_2_next.get("NextToken") is not None
|
||||
|
||||
transform_jobs_with_2_next_next = client.list_transform_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=transform_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]).should.equal(0)
|
||||
assert transform_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0
|
||||
assert transform_jobs_with_2_next_next.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -401,7 +409,8 @@ def test_describe_unknown_transform_job():
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.describe_transform_job(TransformJobName="unknown")
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("ValidationException")
|
||||
err["Message"].should.equal(
|
||||
f"Could not find transform job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown'."
|
||||
assert err["Code"] == "ValidationException"
|
||||
assert err["Message"] == (
|
||||
"Could not find transform job 'arn:aws:sagemaker:us-east-1:"
|
||||
f"{ACCOUNT_ID}:transform-job/unknown'."
|
||||
)
|
||||
|
@ -1,9 +1,8 @@
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
@ -27,9 +26,9 @@ def test_create__trial_component():
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||
)
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
assert resp["TrialComponentSummaries"][0]["TrialComponentArn"] == (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
|
||||
|
||||
@ -173,9 +172,9 @@ def test_associate_trial_component():
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
assert resp["TrialComponentArn"] == (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
@ -199,11 +198,12 @@ def test_associate_trial_component():
|
||||
TrialComponentName="does-not-exist", TrialName="does-not-exist"
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Code"].should.equal("ResourceNotFound")
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist."
|
||||
assert ex.value.response["Error"]["Code"] == "ResourceNotFound"
|
||||
assert ex.value.response["Error"]["Message"] == (
|
||||
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
":experiment-trial/does-not-exist' does not exist."
|
||||
)
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -231,9 +231,9 @@ def test_disassociate_trial_component():
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
assert resp["TrialComponentArn"] == (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
|
||||
f":experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
@ -255,9 +255,9 @@ def test_disassociate_trial_component():
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist"
|
||||
assert resp["TrialComponentArn"] == (
|
||||
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:"
|
||||
"experiment-trial-component/does-not-exist"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
|
Loading…
Reference in New Issue
Block a user