Add CloudFormation support for SageMaker Endpoint Configs and Endpoints (#3863)

* Create SageMaker EndpointConfig with CloudFormation

Implement attributes for SM Endpoint Configs with CloudFormation

Delete SM Endpoint Configs with CloudFormation

Update SM Endpoint Configs with CloudFormation

* Fix typos in SM CF Model update test and refactor helper function for CF stack outputs

* Fixup weird commas in SM CF Test Configs from using black

* Create SageMaker Endpoints with CloudFormation

* Fix typos in SM CF update tests
This commit is contained in:
Zach Churchill 2021-04-17 08:49:46 -04:00 committed by GitHub
parent f6dda54a6c
commit 9b3e932822
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 450 additions and 54 deletions

View File

@ -140,7 +140,7 @@ class FakeTrainingJob(BaseObject):
) )
class FakeEndpoint(BaseObject): class FakeEndpoint(BaseObject, CloudFormationModel):
def __init__( def __init__(
self, self,
region_name, region_name,
@ -184,8 +184,70 @@ class FakeEndpoint(BaseObject):
+ endpoint_name + endpoint_name
) )
@property
def physical_resource_id(self):
return self.endpoint_arn
class FakeEndpointConfig(BaseObject): def get_cfn_attribute(self, attribute_name):
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "EndpointName":
return self.endpoint_name
raise UnformattedGetAttTemplateException()
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html
return "AWS::SageMaker::Endpoint"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
sagemaker_backend = sagemaker_backends[region_name]
# Get required properties from provided CloudFormation template
properties = cloudformation_json["Properties"]
endpoint_config_name = properties["EndpointConfigName"]
endpoint = sagemaker_backend.create_endpoint(
endpoint_name=resource_name,
endpoint_config_name=endpoint_config_name,
tags=properties.get("Tags", []),
)
return endpoint
@classmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name,
):
# Changes to the Endpoint will not change resource name
cls.delete_from_cloudformation_json(
original_resource.endpoint_arn, cloudformation_json, region_name
)
new_resource = cls.create_from_cloudformation_json(
original_resource.endpoint_name, cloudformation_json, region_name
)
return new_resource
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
# Get actual name because resource_name actually provides the ARN
# since the Physical Resource ID is the ARN despite SageMaker
# using the name for most of its operations.
endpoint_name = resource_name.split("/")[-1]
sagemaker_backends[region_name].delete_endpoint(endpoint_name)
class FakeEndpointConfig(BaseObject, CloudFormationModel):
def __init__( def __init__(
self, self,
region_name, region_name,
@ -308,6 +370,70 @@ class FakeEndpointConfig(BaseObject):
+ model_name + model_name
) )
@property
def physical_resource_id(self):
return self.endpoint_config_arn
def get_cfn_attribute(self, attribute_name):
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "EndpointConfigName":
return self.endpoint_config_name
raise UnformattedGetAttTemplateException()
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html
return "AWS::SageMaker::EndpointConfig"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
sagemaker_backend = sagemaker_backends[region_name]
# Get required properties from provided CloudFormation template
properties = cloudformation_json["Properties"]
production_variants = properties["ProductionVariants"]
endpoint_config = sagemaker_backend.create_endpoint_config(
endpoint_config_name=resource_name,
production_variants=production_variants,
data_capture_config=properties.get("DataCaptureConfig", {}),
kms_key_id=properties.get("KmsKeyId"),
tags=properties.get("Tags", []),
)
return endpoint_config
@classmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name,
):
# Most changes to the endpoint config will change resource name for EndpointConfigs
cls.delete_from_cloudformation_json(
original_resource.endpoint_config_arn, cloudformation_json, region_name
)
new_resource = cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
return new_resource
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
# Get actual name because resource_name actually provides the ARN
# since the Physical Resource ID is the ARN despite SageMaker
# using the name for most of its operations.
endpoint_config_name = resource_name.split("/")[-1]
sagemaker_backends[region_name].delete_endpoint_config(endpoint_config_name)
class Model(BaseObject, CloudFormationModel): class Model(BaseObject, CloudFormationModel):
def __init__( def __init__(

View File

@ -42,6 +42,14 @@ class TestConfig:
def get_cloudformation_template(self, include_outputs=True, **kwargs): def get_cloudformation_template(self, include_outputs=True, **kwargs):
pass pass
def run_setup_procedure(self, sagemaker_client):
"""Provides a method to set up resources with a SageMaker client.
Note: This procedure should be called while within a `mock_sagemaker`
context so that no actual resources are created with the sagemaker_client.
"""
pass
class NotebookInstanceTestConfig(TestConfig): class NotebookInstanceTestConfig(TestConfig):
"""Test configuration for SageMaker Notebook Instances.""" """Test configuration for SageMaker Notebook Instances."""
@ -186,6 +194,131 @@ class ModelTestConfig(TestConfig):
if include_outputs: if include_outputs:
template["Outputs"] = { template["Outputs"] = {
"Arn": {"Value": {"Ref": self.resource_name}}, "Arn": {"Value": {"Ref": self.resource_name}},
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}}, "Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"]}},
} }
return json.dumps(template) return json.dumps(template)
class EndpointConfigTestConfig(TestConfig):
"""Test configuration for SageMaker Endpoint Configs."""
@property
def resource_name(self):
return "TestEndpointConfig"
@property
def describe_function_name(self):
return "describe_endpoint_config"
@property
def name_parameter(self):
return "EndpointConfigName"
@property
def arn_parameter(self):
return "EndpointConfigArn"
def get_cloudformation_template(self, include_outputs=True, **kwargs):
num_production_variants = kwargs.get("num_production_variants", 1)
production_variants = [
{
"InitialInstanceCount": 1,
"InitialVariantWeight": 1,
"InstanceType": "ml.c4.xlarge",
"ModelName": self.resource_name,
"VariantName": "variant-name-{}".format(i),
}
for i in range(num_production_variants)
]
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
self.resource_name: {
"Type": "AWS::SageMaker::EndpointConfig",
"Properties": {"ProductionVariants": production_variants},
},
},
}
if include_outputs:
template["Outputs"] = {
"Arn": {"Value": {"Ref": self.resource_name}},
"Name": {
"Value": {"Fn::GetAtt": [self.resource_name, "EndpointConfigName"]}
},
}
return json.dumps(template)
def run_setup_procedure(self, sagemaker_client):
"""Adds Model that can be referenced in the CloudFormation template."""
sagemaker_client.create_model(
ModelName=self.resource_name,
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
PrimaryContainer={
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
},
)
class EndpointTestConfig(TestConfig):
"""Test configuration for SageMaker Endpoints."""
@property
def resource_name(self):
return "TestEndpoint"
@property
def describe_function_name(self):
return "describe_endpoint"
@property
def name_parameter(self):
return "EndpointName"
@property
def arn_parameter(self):
return "EndpointArn"
def get_cloudformation_template(self, include_outputs=True, **kwargs):
endpoint_config_name = kwargs.get("endpoint_config_name", self.resource_name)
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
self.resource_name: {
"Type": "AWS::SageMaker::Endpoint",
"Properties": {"EndpointConfigName": endpoint_config_name},
},
},
}
if include_outputs:
template["Outputs"] = {
"Arn": {"Value": {"Ref": self.resource_name}},
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "EndpointName"]}},
}
return json.dumps(template)
def run_setup_procedure(self, sagemaker_client):
"""Adds Model and Endpoint Config that can be referenced in the CloudFormation template."""
sagemaker_client.create_model(
ModelName=self.resource_name,
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
PrimaryContainer={
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
},
)
sagemaker_client.create_endpoint_config(
EndpointConfigName=self.resource_name,
ProductionVariants=[
{
"InitialInstanceCount": 1,
"InitialVariantWeight": 1,
"InstanceType": "ml.c4.xlarge",
"ModelName": self.resource_name,
"VariantName": "variant-name-1",
},
],
)

View File

@ -5,25 +5,44 @@ import sure # noqa
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_cloudformation, mock_sagemaker from moto import mock_cloudformation, mock_sagemaker
from moto.sts.models import ACCOUNT_ID
from .cloudformation_test_configs import ( from .cloudformation_test_configs import (
NotebookInstanceTestConfig, NotebookInstanceTestConfig,
NotebookInstanceLifecycleConfigTestConfig, NotebookInstanceLifecycleConfigTestConfig,
ModelTestConfig, ModelTestConfig,
EndpointConfigTestConfig,
EndpointTestConfig,
) )
def _get_stack_outputs(cf_client, stack_name):
"""Returns the outputs for the first entry in describe_stacks."""
stack_description = cf_client.describe_stacks(StackName=stack_name)["Stacks"][0]
return {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
@mock_cloudformation @mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_config", "test_config",
[ [
NotebookInstanceTestConfig(), NotebookInstanceTestConfig(),
NotebookInstanceLifecycleConfigTestConfig(), NotebookInstanceLifecycleConfigTestConfig(),
ModelTestConfig(), ModelTestConfig(),
EndpointConfigTestConfig(),
EndpointTestConfig(),
], ],
) )
def test_sagemaker_cloudformation_create(test_config): def test_sagemaker_cloudformation_create(test_config):
cf = boto3.client("cloudformation", region_name="us-east-1") cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
# Utilize test configuration to set-up any mock SageMaker resources
test_config.run_setup_procedure(sm)
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack( cf.create_stack(
@ -46,22 +65,23 @@ def test_sagemaker_cloudformation_create(test_config):
NotebookInstanceTestConfig(), NotebookInstanceTestConfig(),
NotebookInstanceLifecycleConfigTestConfig(), NotebookInstanceLifecycleConfigTestConfig(),
ModelTestConfig(), ModelTestConfig(),
EndpointConfigTestConfig(),
EndpointTestConfig(),
], ],
) )
def test_sagemaker_cloudformation_get_attr(test_config): def test_sagemaker_cloudformation_get_attr(test_config):
cf = boto3.client("cloudformation", region_name="us-east-1") cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1")
# Utilize test configuration to set-up any mock SageMaker resources
test_config.run_setup_procedure(sm)
# Create stack and get description for output values # Create stack and get description for output values
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack( cf.create_stack(
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template() StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
) )
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
# Using the describe function, ensure output ARN matches resource ARN # Using the describe function, ensure output ARN matches resource ARN
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
@ -81,22 +101,24 @@ def test_sagemaker_cloudformation_get_attr(test_config):
"Notebook Instance Lifecycle Config does not exist", "Notebook Instance Lifecycle Config does not exist",
), ),
(ModelTestConfig(), "Could not find model"), (ModelTestConfig(), "Could not find model"),
(EndpointConfigTestConfig(), "Could not find endpoint configuration"),
(EndpointTestConfig(), "Could not find endpoint"),
], ],
) )
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message): def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
cf = boto3.client("cloudformation", region_name="us-east-1") cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1")
# Utilize test configuration to set-up any mock SageMaker resources
test_config.run_setup_procedure(sm)
# Create stack and verify existence # Create stack and verify existence
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack( cf.create_stack(
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template() StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
) )
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]} **{test_config.name_parameter: outputs["Name"]}
) )
@ -119,7 +141,7 @@ def test_sagemaker_cloudformation_notebook_instance_update():
test_config = NotebookInstanceTestConfig() test_config = NotebookInstanceTestConfig()
# Set up template for stack with initial and update instance types # Set up template for stack with two different instance types
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
initial_instance_type = "ml.c4.xlarge" initial_instance_type = "ml.c4.xlarge"
updated_instance_type = "ml.c4.4xlarge" updated_instance_type = "ml.c4.4xlarge"
@ -132,24 +154,18 @@ def test_sagemaker_cloudformation_notebook_instance_update():
# Create stack with initial template and check attributes # Create stack with initial template and check attributes
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
initial_notebook_name = outputs["Name"] initial_notebook_name = outputs["Name"]
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_notebook_name} **{test_config.name_parameter: initial_notebook_name}
) )
initial_instance_type.should.equal(resource_description["InstanceType"]) initial_instance_type.should.equal(resource_description["InstanceType"])
# Update stack with new instance type and check attributes # Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
updated_notebook_name = outputs["Name"] updated_notebook_name = outputs["Name"]
updated_notebook_name.should.equal(initial_notebook_name) updated_notebook_name.should.equal(initial_notebook_name)
@ -167,7 +183,7 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
test_config = NotebookInstanceLifecycleConfigTestConfig() test_config = NotebookInstanceLifecycleConfigTestConfig()
# Set up template for stack with initial and update instance types # Set up template for stack with two different OnCreate scripts
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
initial_on_create_script = "echo Hello World" initial_on_create_script = "echo Hello World"
updated_on_create_script = "echo Goodbye World" updated_on_create_script = "echo Goodbye World"
@ -180,11 +196,8 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
# Create stack with initial template and check attributes # Create stack with initial template and check attributes
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
initial_config_name = outputs["Name"] initial_config_name = outputs["Name"]
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_config_name} **{test_config.name_parameter: initial_config_name}
@ -194,13 +207,10 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
resource_description["OnCreate"][0]["Content"] resource_description["OnCreate"][0]["Content"]
) )
# Update stack with new instance type and check attributes # Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
updated_config_name = outputs["Name"] updated_config_name = outputs["Name"]
updated_config_name.should.equal(initial_config_name) updated_config_name.should.equal(initial_config_name)
@ -221,7 +231,7 @@ def test_sagemaker_cloudformation_model_update():
test_config = ModelTestConfig() test_config = ModelTestConfig()
# Set up template for stack with initial and update instance types # Set up template for stack with two different image versions
stack_name = "{}_stack".format(test_config.resource_name) stack_name = "{}_stack".format(test_config.resource_name)
image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}" image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
initial_image_version = 1 initial_image_version = 1
@ -235,32 +245,159 @@ def test_sagemaker_cloudformation_model_update():
# Create stack with initial template and check attributes # Create stack with initial template and check attributes
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"] initial_model_name = outputs["Name"]
for output in stack_description["Outputs"]
}
inital_model_name = outputs["Name"]
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: inital_model_name} **{test_config.name_parameter: initial_model_name}
) )
resource_description["PrimaryContainer"]["Image"].should.equal( resource_description["PrimaryContainer"]["Image"].should.equal(
image.format(initial_image_version) image.format(initial_image_version)
) )
# Update stack with new instance type and check attributes # Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = _get_stack_outputs(cf, stack_name)
outputs = {
output["OutputKey"]: output["OutputValue"] updated_model_name = outputs["Name"]
for output in stack_description["Outputs"] updated_model_name.should_not.equal(initial_model_name)
}
updated_notebook_name = outputs["Name"]
updated_notebook_name.should_not.equal(inital_model_name)
resource_description = getattr(sm, test_config.describe_function_name)( resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_notebook_name} **{test_config.name_parameter: updated_model_name}
) )
resource_description["PrimaryContainer"]["Image"].should.equal( resource_description["PrimaryContainer"]["Image"].should.equal(
image.format(updated_image_version) image.format(updated_image_version)
) )
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_endpoint_config_update():
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
test_config = EndpointConfigTestConfig()
# Utilize test configuration to set-up any mock SageMaker resources
test_config.run_setup_procedure(sm)
# Set up template for stack with two different production variant counts
stack_name = "{}_stack".format(test_config.resource_name)
initial_num_production_variants = 1
updated_num_production_variants = 2
initial_template_json = test_config.get_cloudformation_template(
num_production_variants=initial_num_production_variants
)
updated_template_json = test_config.get_cloudformation_template(
num_production_variants=updated_num_production_variants
)
# Create stack with initial template and check attributes
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
outputs = _get_stack_outputs(cf, stack_name)
initial_endpoint_config_name = outputs["Name"]
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_endpoint_config_name}
)
len(resource_description["ProductionVariants"]).should.equal(
initial_num_production_variants
)
# Update stack and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
outputs = _get_stack_outputs(cf, stack_name)
updated_endpoint_config_name = outputs["Name"]
updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name)
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_endpoint_config_name}
)
len(resource_description["ProductionVariants"]).should.equal(
updated_num_production_variants
)
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_endpoint_update():
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
test_config = EndpointTestConfig()
# Set up template for stack with two different endpoint config names
stack_name = "{}_stack".format(test_config.resource_name)
initial_endpoint_config_name = test_config.resource_name
updated_endpoint_config_name = "updated-endpoint-config-name"
initial_template_json = test_config.get_cloudformation_template(
endpoint_config_name=initial_endpoint_config_name
)
updated_template_json = test_config.get_cloudformation_template(
endpoint_config_name=updated_endpoint_config_name
)
# Create SM resources and stack with initial template and check attributes
sm.create_model(
ModelName=initial_endpoint_config_name,
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
PrimaryContainer={
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
},
)
sm.create_endpoint_config(
EndpointConfigName=initial_endpoint_config_name,
ProductionVariants=[
{
"InitialInstanceCount": 1,
"InitialVariantWeight": 1,
"InstanceType": "ml.c4.xlarge",
"ModelName": initial_endpoint_config_name,
"VariantName": "variant-name-1",
},
],
)
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
outputs = _get_stack_outputs(cf, stack_name)
initial_endpoint_name = outputs["Name"]
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_endpoint_name}
)
resource_description["EndpointConfigName"].should.match(
initial_endpoint_config_name
)
# Create additional SM resources and update stack
sm.create_model(
ModelName=updated_endpoint_config_name,
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
PrimaryContainer={
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
},
)
sm.create_endpoint_config(
EndpointConfigName=updated_endpoint_config_name,
ProductionVariants=[
{
"InitialInstanceCount": 1,
"InitialVariantWeight": 1,
"InstanceType": "ml.c4.xlarge",
"ModelName": updated_endpoint_config_name,
"VariantName": "variant-name-1",
},
],
)
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
outputs = _get_stack_outputs(cf, stack_name)
updated_endpoint_name = outputs["Name"]
updated_endpoint_name.should.equal(initial_endpoint_name)
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_endpoint_name}
)
resource_description["EndpointConfigName"].should.match(
updated_endpoint_config_name
)