Add CloudFormation support for SageMaker Endpoint Configs and Endpoints (#3863)
* Create SageMaker EndpointConfig with CloudFormation Implement attributes for SM Endpoint Configs with CloudFormation Delete SM Endpoint Configs with CloudFormation Update SM Endpoint Configs with CloudFormation * Fix typos in SM CF Model update test and refactor helper function for CF stack outputs * Fixup weird commas in SM CF Test Configs from using black * Create SageMaker Endpoints with CloudFormation * Fix typos in SM CF update tests
This commit is contained in:
parent
f6dda54a6c
commit
9b3e932822
@ -140,7 +140,7 @@ class FakeTrainingJob(BaseObject):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FakeEndpoint(BaseObject):
|
class FakeEndpoint(BaseObject, CloudFormationModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
region_name,
|
region_name,
|
||||||
@ -184,8 +184,70 @@ class FakeEndpoint(BaseObject):
|
|||||||
+ endpoint_name
|
+ endpoint_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def physical_resource_id(self):
|
||||||
|
return self.endpoint_arn
|
||||||
|
|
||||||
class FakeEndpointConfig(BaseObject):
|
def get_cfn_attribute(self, attribute_name):
|
||||||
|
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values
|
||||||
|
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||||
|
|
||||||
|
if attribute_name == "EndpointName":
|
||||||
|
return self.endpoint_name
|
||||||
|
raise UnformattedGetAttTemplateException()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cloudformation_name_type():
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cloudformation_type():
|
||||||
|
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html
|
||||||
|
return "AWS::SageMaker::Endpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_cloudformation_json(
|
||||||
|
cls, resource_name, cloudformation_json, region_name
|
||||||
|
):
|
||||||
|
sagemaker_backend = sagemaker_backends[region_name]
|
||||||
|
|
||||||
|
# Get required properties from provided CloudFormation template
|
||||||
|
properties = cloudformation_json["Properties"]
|
||||||
|
endpoint_config_name = properties["EndpointConfigName"]
|
||||||
|
|
||||||
|
endpoint = sagemaker_backend.create_endpoint(
|
||||||
|
endpoint_name=resource_name,
|
||||||
|
endpoint_config_name=endpoint_config_name,
|
||||||
|
tags=properties.get("Tags", []),
|
||||||
|
)
|
||||||
|
return endpoint
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_from_cloudformation_json(
|
||||||
|
cls, original_resource, new_resource_name, cloudformation_json, region_name,
|
||||||
|
):
|
||||||
|
# Changes to the Endpoint will not change resource name
|
||||||
|
cls.delete_from_cloudformation_json(
|
||||||
|
original_resource.endpoint_arn, cloudformation_json, region_name
|
||||||
|
)
|
||||||
|
new_resource = cls.create_from_cloudformation_json(
|
||||||
|
original_resource.endpoint_name, cloudformation_json, region_name
|
||||||
|
)
|
||||||
|
return new_resource
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_from_cloudformation_json(
|
||||||
|
cls, resource_name, cloudformation_json, region_name
|
||||||
|
):
|
||||||
|
# Get actual name because resource_name actually provides the ARN
|
||||||
|
# since the Physical Resource ID is the ARN despite SageMaker
|
||||||
|
# using the name for most of its operations.
|
||||||
|
endpoint_name = resource_name.split("/")[-1]
|
||||||
|
|
||||||
|
sagemaker_backends[region_name].delete_endpoint(endpoint_name)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEndpointConfig(BaseObject, CloudFormationModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
region_name,
|
region_name,
|
||||||
@ -308,6 +370,70 @@ class FakeEndpointConfig(BaseObject):
|
|||||||
+ model_name
|
+ model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def physical_resource_id(self):
|
||||||
|
return self.endpoint_config_arn
|
||||||
|
|
||||||
|
def get_cfn_attribute(self, attribute_name):
|
||||||
|
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values
|
||||||
|
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||||
|
|
||||||
|
if attribute_name == "EndpointConfigName":
|
||||||
|
return self.endpoint_config_name
|
||||||
|
raise UnformattedGetAttTemplateException()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cloudformation_name_type():
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cloudformation_type():
|
||||||
|
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html
|
||||||
|
return "AWS::SageMaker::EndpointConfig"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_cloudformation_json(
|
||||||
|
cls, resource_name, cloudformation_json, region_name
|
||||||
|
):
|
||||||
|
sagemaker_backend = sagemaker_backends[region_name]
|
||||||
|
|
||||||
|
# Get required properties from provided CloudFormation template
|
||||||
|
properties = cloudformation_json["Properties"]
|
||||||
|
production_variants = properties["ProductionVariants"]
|
||||||
|
|
||||||
|
endpoint_config = sagemaker_backend.create_endpoint_config(
|
||||||
|
endpoint_config_name=resource_name,
|
||||||
|
production_variants=production_variants,
|
||||||
|
data_capture_config=properties.get("DataCaptureConfig", {}),
|
||||||
|
kms_key_id=properties.get("KmsKeyId"),
|
||||||
|
tags=properties.get("Tags", []),
|
||||||
|
)
|
||||||
|
return endpoint_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_from_cloudformation_json(
|
||||||
|
cls, original_resource, new_resource_name, cloudformation_json, region_name,
|
||||||
|
):
|
||||||
|
# Most changes to the endpoint config will change resource name for EndpointConfigs
|
||||||
|
cls.delete_from_cloudformation_json(
|
||||||
|
original_resource.endpoint_config_arn, cloudformation_json, region_name
|
||||||
|
)
|
||||||
|
new_resource = cls.create_from_cloudformation_json(
|
||||||
|
new_resource_name, cloudformation_json, region_name
|
||||||
|
)
|
||||||
|
return new_resource
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_from_cloudformation_json(
|
||||||
|
cls, resource_name, cloudformation_json, region_name
|
||||||
|
):
|
||||||
|
# Get actual name because resource_name actually provides the ARN
|
||||||
|
# since the Physical Resource ID is the ARN despite SageMaker
|
||||||
|
# using the name for most of its operations.
|
||||||
|
endpoint_config_name = resource_name.split("/")[-1]
|
||||||
|
|
||||||
|
sagemaker_backends[region_name].delete_endpoint_config(endpoint_config_name)
|
||||||
|
|
||||||
|
|
||||||
class Model(BaseObject, CloudFormationModel):
|
class Model(BaseObject, CloudFormationModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -42,6 +42,14 @@ class TestConfig:
|
|||||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def run_setup_procedure(self, sagemaker_client):
|
||||||
|
"""Provides a method to set up resources with a SageMaker client.
|
||||||
|
|
||||||
|
Note: This procedure should be called while within a `mock_sagemaker`
|
||||||
|
context so that no actual resources are created with the sagemaker_client.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NotebookInstanceTestConfig(TestConfig):
|
class NotebookInstanceTestConfig(TestConfig):
|
||||||
"""Test configuration for SageMaker Notebook Instances."""
|
"""Test configuration for SageMaker Notebook Instances."""
|
||||||
@ -186,6 +194,131 @@ class ModelTestConfig(TestConfig):
|
|||||||
if include_outputs:
|
if include_outputs:
|
||||||
template["Outputs"] = {
|
template["Outputs"] = {
|
||||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||||
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}},
|
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"]}},
|
||||||
}
|
}
|
||||||
return json.dumps(template)
|
return json.dumps(template)
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointConfigTestConfig(TestConfig):
|
||||||
|
"""Test configuration for SageMaker Endpoint Configs."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resource_name(self):
|
||||||
|
return "TestEndpointConfig"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def describe_function_name(self):
|
||||||
|
return "describe_endpoint_config"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name_parameter(self):
|
||||||
|
return "EndpointConfigName"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arn_parameter(self):
|
||||||
|
return "EndpointConfigArn"
|
||||||
|
|
||||||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||||
|
num_production_variants = kwargs.get("num_production_variants", 1)
|
||||||
|
|
||||||
|
production_variants = [
|
||||||
|
{
|
||||||
|
"InitialInstanceCount": 1,
|
||||||
|
"InitialVariantWeight": 1,
|
||||||
|
"InstanceType": "ml.c4.xlarge",
|
||||||
|
"ModelName": self.resource_name,
|
||||||
|
"VariantName": "variant-name-{}".format(i),
|
||||||
|
}
|
||||||
|
for i in range(num_production_variants)
|
||||||
|
]
|
||||||
|
|
||||||
|
template = {
|
||||||
|
"AWSTemplateFormatVersion": "2010-09-09",
|
||||||
|
"Resources": {
|
||||||
|
self.resource_name: {
|
||||||
|
"Type": "AWS::SageMaker::EndpointConfig",
|
||||||
|
"Properties": {"ProductionVariants": production_variants},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if include_outputs:
|
||||||
|
template["Outputs"] = {
|
||||||
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||||
|
"Name": {
|
||||||
|
"Value": {"Fn::GetAtt": [self.resource_name, "EndpointConfigName"]}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return json.dumps(template)
|
||||||
|
|
||||||
|
def run_setup_procedure(self, sagemaker_client):
|
||||||
|
"""Adds Model that can be referenced in the CloudFormation template."""
|
||||||
|
|
||||||
|
sagemaker_client.create_model(
|
||||||
|
ModelName=self.resource_name,
|
||||||
|
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||||
|
PrimaryContainer={
|
||||||
|
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointTestConfig(TestConfig):
|
||||||
|
"""Test configuration for SageMaker Endpoints."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resource_name(self):
|
||||||
|
return "TestEndpoint"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def describe_function_name(self):
|
||||||
|
return "describe_endpoint"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name_parameter(self):
|
||||||
|
return "EndpointName"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arn_parameter(self):
|
||||||
|
return "EndpointArn"
|
||||||
|
|
||||||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||||
|
endpoint_config_name = kwargs.get("endpoint_config_name", self.resource_name)
|
||||||
|
|
||||||
|
template = {
|
||||||
|
"AWSTemplateFormatVersion": "2010-09-09",
|
||||||
|
"Resources": {
|
||||||
|
self.resource_name: {
|
||||||
|
"Type": "AWS::SageMaker::Endpoint",
|
||||||
|
"Properties": {"EndpointConfigName": endpoint_config_name},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if include_outputs:
|
||||||
|
template["Outputs"] = {
|
||||||
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||||
|
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "EndpointName"]}},
|
||||||
|
}
|
||||||
|
return json.dumps(template)
|
||||||
|
|
||||||
|
def run_setup_procedure(self, sagemaker_client):
|
||||||
|
"""Adds Model and Endpoint Config that can be referenced in the CloudFormation template."""
|
||||||
|
|
||||||
|
sagemaker_client.create_model(
|
||||||
|
ModelName=self.resource_name,
|
||||||
|
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||||
|
PrimaryContainer={
|
||||||
|
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sagemaker_client.create_endpoint_config(
|
||||||
|
EndpointConfigName=self.resource_name,
|
||||||
|
ProductionVariants=[
|
||||||
|
{
|
||||||
|
"InitialInstanceCount": 1,
|
||||||
|
"InitialVariantWeight": 1,
|
||||||
|
"InstanceType": "ml.c4.xlarge",
|
||||||
|
"ModelName": self.resource_name,
|
||||||
|
"VariantName": "variant-name-1",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -5,25 +5,44 @@ import sure # noqa
|
|||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
from moto import mock_cloudformation, mock_sagemaker
|
from moto import mock_cloudformation, mock_sagemaker
|
||||||
|
from moto.sts.models import ACCOUNT_ID
|
||||||
|
|
||||||
from .cloudformation_test_configs import (
|
from .cloudformation_test_configs import (
|
||||||
NotebookInstanceTestConfig,
|
NotebookInstanceTestConfig,
|
||||||
NotebookInstanceLifecycleConfigTestConfig,
|
NotebookInstanceLifecycleConfigTestConfig,
|
||||||
ModelTestConfig,
|
ModelTestConfig,
|
||||||
|
EndpointConfigTestConfig,
|
||||||
|
EndpointTestConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_stack_outputs(cf_client, stack_name):
|
||||||
|
"""Returns the outputs for the first entry in describe_stacks."""
|
||||||
|
stack_description = cf_client.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||||
|
return {
|
||||||
|
output["OutputKey"]: output["OutputValue"]
|
||||||
|
for output in stack_description["Outputs"]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@mock_cloudformation
|
@mock_cloudformation
|
||||||
|
@mock_sagemaker
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
NotebookInstanceTestConfig(),
|
NotebookInstanceTestConfig(),
|
||||||
NotebookInstanceLifecycleConfigTestConfig(),
|
NotebookInstanceLifecycleConfigTestConfig(),
|
||||||
ModelTestConfig(),
|
ModelTestConfig(),
|
||||||
|
EndpointConfigTestConfig(),
|
||||||
|
EndpointTestConfig(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sagemaker_cloudformation_create(test_config):
|
def test_sagemaker_cloudformation_create(test_config):
|
||||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||||
|
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||||
|
|
||||||
|
# Utilize test configuration to set-up any mock SageMaker resources
|
||||||
|
test_config.run_setup_procedure(sm)
|
||||||
|
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
cf.create_stack(
|
cf.create_stack(
|
||||||
@ -46,22 +65,23 @@ def test_sagemaker_cloudformation_create(test_config):
|
|||||||
NotebookInstanceTestConfig(),
|
NotebookInstanceTestConfig(),
|
||||||
NotebookInstanceLifecycleConfigTestConfig(),
|
NotebookInstanceLifecycleConfigTestConfig(),
|
||||||
ModelTestConfig(),
|
ModelTestConfig(),
|
||||||
|
EndpointConfigTestConfig(),
|
||||||
|
EndpointTestConfig(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sagemaker_cloudformation_get_attr(test_config):
|
def test_sagemaker_cloudformation_get_attr(test_config):
|
||||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||||
|
|
||||||
|
# Utilize test configuration to set-up any mock SageMaker resources
|
||||||
|
test_config.run_setup_procedure(sm)
|
||||||
|
|
||||||
# Create stack and get description for output values
|
# Create stack and get description for output values
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
cf.create_stack(
|
cf.create_stack(
|
||||||
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
||||||
)
|
)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Using the describe function, ensure output ARN matches resource ARN
|
# Using the describe function, ensure output ARN matches resource ARN
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
@ -81,22 +101,24 @@ def test_sagemaker_cloudformation_get_attr(test_config):
|
|||||||
"Notebook Instance Lifecycle Config does not exist",
|
"Notebook Instance Lifecycle Config does not exist",
|
||||||
),
|
),
|
||||||
(ModelTestConfig(), "Could not find model"),
|
(ModelTestConfig(), "Could not find model"),
|
||||||
|
(EndpointConfigTestConfig(), "Could not find endpoint configuration"),
|
||||||
|
(EndpointTestConfig(), "Could not find endpoint"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
|
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
|
||||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||||
|
|
||||||
|
# Utilize test configuration to set-up any mock SageMaker resources
|
||||||
|
test_config.run_setup_procedure(sm)
|
||||||
|
|
||||||
# Create stack and verify existence
|
# Create stack and verify existence
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
cf.create_stack(
|
cf.create_stack(
|
||||||
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
||||||
)
|
)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
**{test_config.name_parameter: outputs["Name"]}
|
**{test_config.name_parameter: outputs["Name"]}
|
||||||
)
|
)
|
||||||
@ -119,7 +141,7 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
|||||||
|
|
||||||
test_config = NotebookInstanceTestConfig()
|
test_config = NotebookInstanceTestConfig()
|
||||||
|
|
||||||
# Set up template for stack with initial and update instance types
|
# Set up template for stack with two different instance types
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
initial_instance_type = "ml.c4.xlarge"
|
initial_instance_type = "ml.c4.xlarge"
|
||||||
updated_instance_type = "ml.c4.4xlarge"
|
updated_instance_type = "ml.c4.4xlarge"
|
||||||
@ -132,24 +154,18 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
|||||||
|
|
||||||
# Create stack with initial template and check attributes
|
# Create stack with initial template and check attributes
|
||||||
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
initial_notebook_name = outputs["Name"]
|
initial_notebook_name = outputs["Name"]
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
**{test_config.name_parameter: initial_notebook_name}
|
**{test_config.name_parameter: initial_notebook_name}
|
||||||
)
|
)
|
||||||
initial_instance_type.should.equal(resource_description["InstanceType"])
|
initial_instance_type.should.equal(resource_description["InstanceType"])
|
||||||
|
|
||||||
# Update stack with new instance type and check attributes
|
# Update stack and check attributes
|
||||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
updated_notebook_name = outputs["Name"]
|
updated_notebook_name = outputs["Name"]
|
||||||
updated_notebook_name.should.equal(initial_notebook_name)
|
updated_notebook_name.should.equal(initial_notebook_name)
|
||||||
|
|
||||||
@ -167,7 +183,7 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
|||||||
|
|
||||||
test_config = NotebookInstanceLifecycleConfigTestConfig()
|
test_config = NotebookInstanceLifecycleConfigTestConfig()
|
||||||
|
|
||||||
# Set up template for stack with initial and update instance types
|
# Set up template for stack with two different OnCreate scripts
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
initial_on_create_script = "echo Hello World"
|
initial_on_create_script = "echo Hello World"
|
||||||
updated_on_create_script = "echo Goodbye World"
|
updated_on_create_script = "echo Goodbye World"
|
||||||
@ -180,11 +196,8 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
|||||||
|
|
||||||
# Create stack with initial template and check attributes
|
# Create stack with initial template and check attributes
|
||||||
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
initial_config_name = outputs["Name"]
|
initial_config_name = outputs["Name"]
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
**{test_config.name_parameter: initial_config_name}
|
**{test_config.name_parameter: initial_config_name}
|
||||||
@ -194,13 +207,10 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
|||||||
resource_description["OnCreate"][0]["Content"]
|
resource_description["OnCreate"][0]["Content"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update stack with new instance type and check attributes
|
# Update stack and check attributes
|
||||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
updated_config_name = outputs["Name"]
|
updated_config_name = outputs["Name"]
|
||||||
updated_config_name.should.equal(initial_config_name)
|
updated_config_name.should.equal(initial_config_name)
|
||||||
|
|
||||||
@ -221,7 +231,7 @@ def test_sagemaker_cloudformation_model_update():
|
|||||||
|
|
||||||
test_config = ModelTestConfig()
|
test_config = ModelTestConfig()
|
||||||
|
|
||||||
# Set up template for stack with initial and update instance types
|
# Set up template for stack with two different image versions
|
||||||
stack_name = "{}_stack".format(test_config.resource_name)
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
|
image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
|
||||||
initial_image_version = 1
|
initial_image_version = 1
|
||||||
@ -235,32 +245,159 @@ def test_sagemaker_cloudformation_model_update():
|
|||||||
|
|
||||||
# Create stack with initial template and check attributes
|
# Create stack with initial template and check attributes
|
||||||
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
initial_model_name = outputs["Name"]
|
||||||
for output in stack_description["Outputs"]
|
|
||||||
}
|
|
||||||
inital_model_name = outputs["Name"]
|
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
**{test_config.name_parameter: inital_model_name}
|
**{test_config.name_parameter: initial_model_name}
|
||||||
)
|
)
|
||||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||||
image.format(initial_image_version)
|
image.format(initial_image_version)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update stack with new instance type and check attributes
|
# Update stack and check attributes
|
||||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
outputs = {
|
|
||||||
output["OutputKey"]: output["OutputValue"]
|
updated_model_name = outputs["Name"]
|
||||||
for output in stack_description["Outputs"]
|
updated_model_name.should_not.equal(initial_model_name)
|
||||||
}
|
|
||||||
updated_notebook_name = outputs["Name"]
|
|
||||||
updated_notebook_name.should_not.equal(inital_model_name)
|
|
||||||
|
|
||||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
**{test_config.name_parameter: updated_notebook_name}
|
**{test_config.name_parameter: updated_model_name}
|
||||||
)
|
)
|
||||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||||
image.format(updated_image_version)
|
image.format(updated_image_version)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_cloudformation
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_sagemaker_cloudformation_endpoint_config_update():
|
||||||
|
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||||
|
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||||
|
|
||||||
|
test_config = EndpointConfigTestConfig()
|
||||||
|
|
||||||
|
# Utilize test configuration to set-up any mock SageMaker resources
|
||||||
|
test_config.run_setup_procedure(sm)
|
||||||
|
|
||||||
|
# Set up template for stack with two different production variant counts
|
||||||
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
|
initial_num_production_variants = 1
|
||||||
|
updated_num_production_variants = 2
|
||||||
|
initial_template_json = test_config.get_cloudformation_template(
|
||||||
|
num_production_variants=initial_num_production_variants
|
||||||
|
)
|
||||||
|
updated_template_json = test_config.get_cloudformation_template(
|
||||||
|
num_production_variants=updated_num_production_variants
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create stack with initial template and check attributes
|
||||||
|
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||||
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
|
|
||||||
|
initial_endpoint_config_name = outputs["Name"]
|
||||||
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
|
**{test_config.name_parameter: initial_endpoint_config_name}
|
||||||
|
)
|
||||||
|
len(resource_description["ProductionVariants"]).should.equal(
|
||||||
|
initial_num_production_variants
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update stack and check attributes
|
||||||
|
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||||
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
|
|
||||||
|
updated_endpoint_config_name = outputs["Name"]
|
||||||
|
updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name)
|
||||||
|
|
||||||
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
|
**{test_config.name_parameter: updated_endpoint_config_name}
|
||||||
|
)
|
||||||
|
len(resource_description["ProductionVariants"]).should.equal(
|
||||||
|
updated_num_production_variants
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_cloudformation
|
||||||
|
@mock_sagemaker
|
||||||
|
def test_sagemaker_cloudformation_endpoint_update():
|
||||||
|
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||||
|
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||||
|
|
||||||
|
test_config = EndpointTestConfig()
|
||||||
|
|
||||||
|
# Set up template for stack with two different endpoint config names
|
||||||
|
stack_name = "{}_stack".format(test_config.resource_name)
|
||||||
|
initial_endpoint_config_name = test_config.resource_name
|
||||||
|
updated_endpoint_config_name = "updated-endpoint-config-name"
|
||||||
|
initial_template_json = test_config.get_cloudformation_template(
|
||||||
|
endpoint_config_name=initial_endpoint_config_name
|
||||||
|
)
|
||||||
|
updated_template_json = test_config.get_cloudformation_template(
|
||||||
|
endpoint_config_name=updated_endpoint_config_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create SM resources and stack with initial template and check attributes
|
||||||
|
sm.create_model(
|
||||||
|
ModelName=initial_endpoint_config_name,
|
||||||
|
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||||
|
PrimaryContainer={
|
||||||
|
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sm.create_endpoint_config(
|
||||||
|
EndpointConfigName=initial_endpoint_config_name,
|
||||||
|
ProductionVariants=[
|
||||||
|
{
|
||||||
|
"InitialInstanceCount": 1,
|
||||||
|
"InitialVariantWeight": 1,
|
||||||
|
"InstanceType": "ml.c4.xlarge",
|
||||||
|
"ModelName": initial_endpoint_config_name,
|
||||||
|
"VariantName": "variant-name-1",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||||
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
|
|
||||||
|
initial_endpoint_name = outputs["Name"]
|
||||||
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
|
**{test_config.name_parameter: initial_endpoint_name}
|
||||||
|
)
|
||||||
|
resource_description["EndpointConfigName"].should.match(
|
||||||
|
initial_endpoint_config_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create additional SM resources and update stack
|
||||||
|
sm.create_model(
|
||||||
|
ModelName=updated_endpoint_config_name,
|
||||||
|
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||||
|
PrimaryContainer={
|
||||||
|
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sm.create_endpoint_config(
|
||||||
|
EndpointConfigName=updated_endpoint_config_name,
|
||||||
|
ProductionVariants=[
|
||||||
|
{
|
||||||
|
"InitialInstanceCount": 1,
|
||||||
|
"InitialVariantWeight": 1,
|
||||||
|
"InstanceType": "ml.c4.xlarge",
|
||||||
|
"ModelName": updated_endpoint_config_name,
|
||||||
|
"VariantName": "variant-name-1",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||||
|
outputs = _get_stack_outputs(cf, stack_name)
|
||||||
|
|
||||||
|
updated_endpoint_name = outputs["Name"]
|
||||||
|
updated_endpoint_name.should.equal(initial_endpoint_name)
|
||||||
|
|
||||||
|
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||||
|
**{test_config.name_parameter: updated_endpoint_name}
|
||||||
|
)
|
||||||
|
resource_description["EndpointConfigName"].should.match(
|
||||||
|
updated_endpoint_config_name
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user