# moto/tests/test_sagemaker/test_sagemaker_cloudformation.py
#
# 404 lines
# 15 KiB
# Python

import boto3
import pytest
import sure # noqa # pylint: disable=unused-import
from botocore.exceptions import ClientError
from moto import mock_cloudformation, mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from .cloudformation_test_configs import (
NotebookInstanceTestConfig,
NotebookInstanceLifecycleConfigTestConfig,
ModelTestConfig,
EndpointConfigTestConfig,
EndpointTestConfig,
)
def _get_stack_outputs(cf_client, stack_name):
"""Returns the outputs for the first entry in describe_stacks."""
stack_description = cf_client.describe_stacks(StackName=stack_name)["Stacks"][0]
return {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
@mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize(
    "test_config",
    [
        NotebookInstanceTestConfig(),
        NotebookInstanceLifecycleConfigTestConfig(),
        ModelTestConfig(),
        EndpointConfigTestConfig(),
        EndpointTestConfig(),
    ],
)
def test_sagemaker_cloudformation_create(test_config):
    """Create a stack from the config's template and inspect its first resource.

    The template is rendered without outputs. The first entry returned by
    ``list_stack_resources`` must use the config's resource name as its
    logical ID and have a non-empty physical ID.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")

    # Let the config create whatever SageMaker resources it needs up front
    test_config.run_setup_procedure(sagemaker)

    name_of_stack = f"{test_config.resource_name}_stack"
    template_body = test_config.get_cloudformation_template(include_outputs=False)
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_body)

    listing = cloudformation.list_stack_resources(StackName=name_of_stack)
    first_resource = listing["StackResourceSummaries"][0]

    first_resource["LogicalResourceId"].should.equal(test_config.resource_name)
    len(first_resource["PhysicalResourceId"]).should.be.greater_than(0)
@mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize(
    "test_config",
    [
        NotebookInstanceTestConfig(),
        NotebookInstanceLifecycleConfigTestConfig(),
        ModelTestConfig(),
        EndpointConfigTestConfig(),
        EndpointTestConfig(),
    ],
)
def test_sagemaker_cloudformation_get_attr(test_config):
    """Check that the stack's ``Arn`` output equals the ARN SageMaker reports.

    The resource is looked up through the config's describe function using
    the stack's ``Name`` output.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")

    # Let the config create whatever SageMaker resources it needs up front
    test_config.run_setup_procedure(sagemaker)

    # Build the stack, then read back its outputs
    name_of_stack = f"{test_config.resource_name}_stack"
    template_body = test_config.get_cloudformation_template()
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_body)
    stack_outputs = _get_stack_outputs(cloudformation, name_of_stack)

    # Describe the resource by its output name and compare ARNs
    describe = getattr(sagemaker, test_config.describe_function_name)
    described = describe(**{test_config.name_parameter: stack_outputs["Name"]})
    stack_outputs["Arn"].should.equal(described[test_config.arn_parameter])
@mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize(
    "test_config,error_message",
    [
        (NotebookInstanceTestConfig(), "RecordNotFound"),
        (
            NotebookInstanceLifecycleConfigTestConfig(),
            "Notebook Instance Lifecycle Config does not exist",
        ),
        (ModelTestConfig(), "Could not find model"),
        (EndpointConfigTestConfig(), "Could not find endpoint configuration"),
        (EndpointTestConfig(), "Could not find endpoint"),
    ],
)
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
    """Check that deleting the stack also removes the SageMaker resource.

    After ``delete_stack`` the describe call must raise ``ClientError`` with
    an error message containing ``error_message``.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")

    # Let the config create whatever SageMaker resources it needs up front
    test_config.run_setup_procedure(sagemaker)

    # Build the stack and confirm the resource exists
    name_of_stack = f"{test_config.resource_name}_stack"
    template_body = test_config.get_cloudformation_template()
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_body)
    stack_outputs = _get_stack_outputs(cloudformation, name_of_stack)

    def describe_resource():
        return getattr(sagemaker, test_config.describe_function_name)(
            **{test_config.name_parameter: stack_outputs["Name"]}
        )

    described = describe_resource()
    stack_outputs["Arn"].should.equal(described[test_config.arn_parameter])

    # Remove the stack; describing the resource must now fail
    cloudformation.delete_stack(StackName=name_of_stack)
    with pytest.raises(ClientError) as ce:
        describe_resource()
    ce.value.response["Error"]["Message"].should.contain(error_message)
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_notebook_instance_update():
    """Update a notebook instance stack to a different instance type.

    The ``Name`` output must stay the same across the update while the
    described ``InstanceType`` switches to the new value.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")
    test_config = NotebookInstanceTestConfig()

    def describe(name):
        return getattr(sagemaker, test_config.describe_function_name)(
            **{test_config.name_parameter: name}
        )

    # Render one template per instance type
    name_of_stack = f"{test_config.resource_name}_stack"
    initial_instance_type = "ml.c4.xlarge"
    updated_instance_type = "ml.c4.4xlarge"
    template_before = test_config.get_cloudformation_template(
        instance_type=initial_instance_type
    )
    template_after = test_config.get_cloudformation_template(
        instance_type=updated_instance_type
    )

    # Initial stack: the instance type must match the first template
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_before)
    name_before = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    initial_instance_type.should.equal(describe(name_before)["InstanceType"])

    # Updated stack: same name, new instance type
    cloudformation.update_stack(StackName=name_of_stack, TemplateBody=template_after)
    name_after = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    name_after.should.equal(name_before)
    updated_instance_type.should.equal(describe(name_after)["InstanceType"])
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
    """Update a lifecycle config stack with a different OnCreate script.

    The ``Name`` output must stay the same across the update, and the single
    ``OnCreate`` entry's ``Content`` must reflect the new script.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")
    test_config = NotebookInstanceLifecycleConfigTestConfig()

    def describe(name):
        return getattr(sagemaker, test_config.describe_function_name)(
            **{test_config.name_parameter: name}
        )

    # Render one template per OnCreate script
    name_of_stack = f"{test_config.resource_name}_stack"
    initial_on_create_script = "echo Hello World"
    updated_on_create_script = "echo Goodbye World"
    template_before = test_config.get_cloudformation_template(
        on_create=initial_on_create_script
    )
    template_after = test_config.get_cloudformation_template(
        on_create=updated_on_create_script
    )

    # Initial stack: exactly one OnCreate entry holding the first script
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_before)
    name_before = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    described = describe(name_before)
    len(described["OnCreate"]).should.equal(1)
    initial_on_create_script.should.equal(described["OnCreate"][0]["Content"])

    # Updated stack: same name, one OnCreate entry holding the second script
    cloudformation.update_stack(StackName=name_of_stack, TemplateBody=template_after)
    name_after = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    name_after.should.equal(name_before)
    described = describe(name_after)
    len(described["OnCreate"]).should.equal(1)
    updated_on_create_script.should.equal(described["OnCreate"][0]["Content"])
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_model_update():
    """Update a model stack to a different container image version.

    The test expects the ``Name`` output to change across the update and the
    described primary container image to match the updated template.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")
    test_config = ModelTestConfig()

    def describe(name):
        return getattr(sagemaker, test_config.describe_function_name)(
            **{test_config.name_parameter: name}
        )

    # Render one template per image version
    name_of_stack = f"{test_config.resource_name}_stack"
    image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
    image_before = image.format(1)
    image_after = image.format(2)
    template_before = test_config.get_cloudformation_template(image=image_before)
    template_after = test_config.get_cloudformation_template(image=image_after)

    # Initial stack: the primary container uses the first image
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_before)
    name_before = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    describe(name_before)["PrimaryContainer"]["Image"].should.equal(image_before)

    # Updated stack: a different name, and the second image
    cloudformation.update_stack(StackName=name_of_stack, TemplateBody=template_after)
    name_after = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    name_after.should_not.equal(name_before)
    describe(name_after)["PrimaryContainer"]["Image"].should.equal(image_after)
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_endpoint_config_update():
    """Update an endpoint config stack to a different production variant count.

    The test expects the ``Name`` output to change across the update and the
    number of described ``ProductionVariants`` to match the updated template.
    """
    cloudformation = boto3.client("cloudformation", region_name="us-east-1")
    sagemaker = boto3.client("sagemaker", region_name="us-east-1")
    test_config = EndpointConfigTestConfig()

    # Let the config create whatever SageMaker resources it needs up front
    test_config.run_setup_procedure(sagemaker)

    def describe(name):
        return getattr(sagemaker, test_config.describe_function_name)(
            **{test_config.name_parameter: name}
        )

    # Render one template per production variant count
    name_of_stack = f"{test_config.resource_name}_stack"
    variants_before = 1
    variants_after = 2
    template_before = test_config.get_cloudformation_template(
        num_production_variants=variants_before
    )
    template_after = test_config.get_cloudformation_template(
        num_production_variants=variants_after
    )

    # Initial stack: variant count matches the first template
    cloudformation.create_stack(StackName=name_of_stack, TemplateBody=template_before)
    name_before = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    len(describe(name_before)["ProductionVariants"]).should.equal(variants_before)

    # Updated stack: a different name, and the second variant count
    cloudformation.update_stack(StackName=name_of_stack, TemplateBody=template_after)
    name_after = _get_stack_outputs(cloudformation, name_of_stack)["Name"]
    name_after.should_not.equal(name_before)
    len(describe(name_after)["ProductionVariants"]).should.equal(variants_after)
def _create_model_and_endpoint_config(sm, name):
    """Create a model and a one-variant endpoint config, both named ``name``.

    Args:
        sm: SageMaker client used for ``create_model`` and
            ``create_endpoint_config``.
        name: Used as the model name, as the endpoint config name, and as the
            ``ModelName`` referenced by the config's single production variant.
    """
    sm.create_model(
        ModelName=name,
        ExecutionRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole",
        PrimaryContainer={
            "Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
        },
    )
    sm.create_endpoint_config(
        EndpointConfigName=name,
        ProductionVariants=[
            {
                "InitialInstanceCount": 1,
                "InitialVariantWeight": 1,
                "InstanceType": "ml.c4.xlarge",
                "ModelName": name,
                "VariantName": "variant-name-1",
            },
        ],
    )


@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_endpoint_update():
    """Update an endpoint stack to point at a different endpoint config.

    The ``Name`` output must stay the same across the update while the
    described ``EndpointConfigName`` switches to the new config name.
    """
    cf = boto3.client("cloudformation", region_name="us-east-1")
    sm = boto3.client("sagemaker", region_name="us-east-1")
    test_config = EndpointTestConfig()

    def describe(name):
        return getattr(sm, test_config.describe_function_name)(
            **{test_config.name_parameter: name}
        )

    # Set up template for stack with two different endpoint config names
    stack_name = f"{test_config.resource_name}_stack"
    initial_endpoint_config_name = test_config.resource_name
    updated_endpoint_config_name = "updated-endpoint-config-name"
    initial_template_json = test_config.get_cloudformation_template(
        endpoint_config_name=initial_endpoint_config_name
    )
    updated_template_json = test_config.get_cloudformation_template(
        endpoint_config_name=updated_endpoint_config_name
    )

    # Create SM resources and stack with initial template and check attributes
    _create_model_and_endpoint_config(sm, initial_endpoint_config_name)
    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
    initial_endpoint_name = _get_stack_outputs(cf, stack_name)["Name"]
    # Exact comparison: a regex-based match would also accept a name that
    # merely contains the expected one.
    describe(initial_endpoint_name)["EndpointConfigName"].should.equal(
        initial_endpoint_config_name
    )

    # Create additional SM resources and update stack
    _create_model_and_endpoint_config(sm, updated_endpoint_config_name)
    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
    updated_endpoint_name = _get_stack_outputs(cf, stack_name)["Name"]
    updated_endpoint_name.should.equal(initial_endpoint_name)
    describe(updated_endpoint_name)["EndpointConfigName"].should.equal(
        updated_endpoint_config_name
    )