Add CloudFormation support for SageMaker Models (#3861)

* Create a formal interface for SM Cloudformation test configurations

* Create SageMaker Models with CloudFormation

* Utilize six for adding metaclass to TestConfig

* Update SM backend to return Model objects instead of response objects
This commit is contained in:
Zach Churchill 2021-04-16 10:23:05 -04:00 committed by GitHub
parent 0b11b0c716
commit f6dda54a6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 388 additions and 162 deletions

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
import os
from boto3 import Session
from copy import deepcopy
from datetime import datetime
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
@ -310,7 +309,7 @@ class FakeEndpointConfig(BaseObject):
)
class Model(BaseObject):
class Model(BaseObject, CloudFormationModel):
def __init__(
self,
region_name,
@ -353,6 +352,72 @@ class Model(BaseObject):
+ model_name
)
@property
def physical_resource_id(self):
    """Return the model ARN, used as the CloudFormation physical resource ID."""
    arn = self.model_arn
    return arn
def get_cfn_attribute(self, attribute_name):
    """Resolve a CloudFormation ``Fn::GetAtt`` lookup for this model.

    Only ``"ModelName"`` is supported, which yields ``self.model_name``.
    Any other attribute name raises ``UnformattedGetAttTemplateException``.
    """
    # Supported return values are listed at:
    # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html#aws-resource-sagemaker-model-return-values
    # NOTE(review): the import is kept function-local as in the original,
    # presumably to avoid a circular import -- confirm before hoisting it.
    from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

    if attribute_name != "ModelName":
        raise UnformattedGetAttTemplateException()
    return self.model_name
@staticmethod
def cloudformation_name_type():
    """Return the CloudFormation name type for this resource (always ``None``)."""
    name_type = None
    return name_type
@staticmethod
def cloudformation_type():
    """Return the CloudFormation resource type string handled by this class."""
    # Resource reference:
    # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html
    resource_type = "AWS::SageMaker::Model"
    return resource_type
@classmethod
def create_from_cloudformation_json(
    cls, resource_name, cloudformation_json, region_name
):
    """Create a SageMaker model from a CloudFormation resource definition.

    Args:
        resource_name: Name forwarded to the backend as ``ModelName``.
        cloudformation_json: Resource definition; its ``"Properties"``
            mapping supplies the model settings.
        region_name: Key used to select the backend from
            ``sagemaker_backends``.

    Returns:
        Whatever the backend's ``create_model`` returns.

    Raises:
        KeyError: If ``"Properties"`` or ``"ExecutionRoleArn"`` is missing,
            or if no backend exists for ``region_name``.
    """
    sagemaker_backend = sagemaker_backends[region_name]

    properties = cloudformation_json["Properties"]
    # ExecutionRoleArn is still looked up strictly, as before.
    execution_role_arn = properties["ExecutionRoleArn"]

    return sagemaker_backend.create_model(
        ModelName=resource_name,
        ExecutionRoleArn=execution_role_arn,
        # PrimaryContainer is optional in the AWS::SageMaker::Model schema
        # (a template may define only ``Containers``), so a missing key must
        # not abort stack creation with a KeyError.
        # NOTE(review): assumes the backend accepts an empty mapping here --
        # confirm against ``create_model``.
        PrimaryContainer=properties.get("PrimaryContainer", {}),
        VpcConfig=properties.get("VpcConfig", {}),
        Containers=properties.get("Containers", []),
        Tags=properties.get("Tags", []),
    )
@classmethod
def update_from_cloudformation_json(
    cls, original_resource, new_resource_name, cloudformation_json, region_name,
):
    """Update a model by deleting the existing one and creating a new one.

    Args:
        original_resource: Existing model; its ``model_arn`` identifies
            what to delete.
        new_resource_name: Name for the newly created model.
        cloudformation_json: Resource definition passed to both the delete
            and the create step.
        region_name: Region passed to both steps.

    Returns:
        The resource returned by ``create_from_cloudformation_json``.
    """
    # Most changes to a model change its resource name, so the update is
    # done as delete-then-create rather than in place.
    previous_arn = original_resource.model_arn
    cls.delete_from_cloudformation_json(
        previous_arn, cloudformation_json, region_name
    )
    return cls.create_from_cloudformation_json(
        new_resource_name, cloudformation_json, region_name
    )
@classmethod
def delete_from_cloudformation_json(
    cls, resource_name, cloudformation_json, region_name
):
    """Delete the model identified by a CloudFormation physical resource ID.

    Args:
        resource_name: The model ARN (the physical resource ID); only the
            text after the final ``/`` is used.
        cloudformation_json: Unused here.
        region_name: Key used to select the backend from
            ``sagemaker_backends``.
    """
    # The physical resource ID is the ARN, but the backend deletes by model
    # name, so keep only the trailing path segment.
    _, _, model_name = resource_name.rpartition("/")
    backend = sagemaker_backends[region_name]
    backend.delete_model(model_name)
class VpcConfig(BaseObject):
def __init__(self, security_group_ids, subnets):
@ -699,23 +764,19 @@ class SageMakerModelBackend(BaseBackend):
)
self._models[kwargs.get("ModelName")] = model_obj
return model_obj.response_create
return model_obj
def describe_model(self, model_name=None):
model = self._models.get(model_name)
if model:
return model.response_object
return model
message = "Could not find model '{}'.".format(
Model.arn_for_model_name(model_name, self.region_name)
)
raise ValidationError(message=message)
def list_models(self):
models = []
for model in self._models.values():
model_response = deepcopy(model.response_object)
models.append(model_response)
return {"Models": models}
return self._models.values()
def delete_model(self, model_name=None):
for model in self._models.values():

View File

@ -22,12 +22,12 @@ class SageMakerResponse(BaseResponse):
def describe_model(self):
model_name = self._get_param("ModelName")
response = self.sagemaker_backend.describe_model(model_name)
return json.dumps(response)
model = self.sagemaker_backend.describe_model(model_name)
return json.dumps(model.response_object)
def create_model(self):
response = self.sagemaker_backend.create_model(**self.request_params)
return json.dumps(response)
model = self.sagemaker_backend.create_model(**self.request_params)
return json.dumps(model.response_create)
def delete_model(self):
model_name = self._get_param("ModelName")
@ -35,8 +35,8 @@ class SageMakerResponse(BaseResponse):
return json.dumps(response)
def list_models(self):
response = self.sagemaker_backend.list_models(**self.request_params)
return json.dumps(response)
models = self.sagemaker_backend.list_models(**self.request_params)
return json.dumps({"Models": [model.response_object for model in models]})
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)

View File

@ -0,0 +1,191 @@
import json
from abc import ABCMeta, abstractmethod
import six
from moto.sts.models import ACCOUNT_ID
@six.add_metaclass(ABCMeta)
class TestConfig:
    """Interface for SageMaker CloudFormation test configurations.

    Each implementation bundles the information one resource type needs in
    the SageMaker CloudFormation tests. Passing a single configuration object
    keeps the `pytest.mark.parametrize` argument lists in
    `test_sagemaker_cloudformation.py` short and readable.
    """

    @property
    @abstractmethod
    def resource_name(self):
        """Key under which the resource appears in the template's ``Resources``."""

    @property
    @abstractmethod
    def describe_function_name(self):
        """Name of the describe function for this resource type."""

    @property
    @abstractmethod
    def name_parameter(self):
        """Name of the parameter that carries the resource name."""

    @property
    @abstractmethod
    def arn_parameter(self):
        """Name of the parameter that carries the resource ARN."""

    @abstractmethod
    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Build the CloudFormation template for this resource.

        Implementations accept resource-specific overrides via ``kwargs`` and
        add an ``Outputs`` section when ``include_outputs`` is true.
        """
class NotebookInstanceTestConfig(TestConfig):
    """Test configuration for SageMaker Notebook Instances."""

    @property
    def resource_name(self):
        """Logical ID of the notebook instance in the template."""
        return "TestNotebook"

    @property
    def describe_function_name(self):
        """Name of the describe function for notebook instances."""
        return "describe_notebook_instance"

    @property
    def name_parameter(self):
        """Name of the parameter that carries the notebook instance name."""
        return "NotebookInstanceName"

    @property
    def arn_parameter(self):
        """Name of the parameter that carries the notebook instance ARN."""
        return "NotebookInstanceArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Return a JSON template string declaring one notebook instance.

        Optional ``kwargs``: ``instance_type`` (default ``"ml.c4.xlarge"``)
        and ``role_arn`` (default: a ``FakeRole`` ARN built from
        ``ACCOUNT_ID``). When ``include_outputs`` is true, ``Arn`` (a ``Ref``)
        and ``Name`` (a ``Fn::GetAtt``) outputs are added.
        """
        default_role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
        properties = {
            "InstanceType": kwargs.get("instance_type", "ml.c4.xlarge"),
            "RoleArn": kwargs.get("role_arn", default_role_arn),
        }

        logical_id = self.resource_name
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::NotebookInstance",
                    "Properties": properties,
                },
            },
        }

        if include_outputs:
            name_lookup = {"Fn::GetAtt": [logical_id, "NotebookInstanceName"]}
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": name_lookup},
            }
        return json.dumps(template)
class NotebookInstanceLifecycleConfigTestConfig(TestConfig):
    """Test configuration for SageMaker Notebook Instance Lifecycle Configs."""

    @property
    def resource_name(self):
        """Logical ID of the lifecycle config in the template."""
        return "TestNotebookLifecycleConfig"

    @property
    def describe_function_name(self):
        """Name of the describe function for lifecycle configs."""
        return "describe_notebook_instance_lifecycle_config"

    @property
    def name_parameter(self):
        """Name of the parameter that carries the lifecycle config name."""
        return "NotebookInstanceLifecycleConfigName"

    @property
    def arn_parameter(self):
        """Name of the parameter that carries the lifecycle config ARN."""
        return "NotebookInstanceLifecycleConfigArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Return a JSON template string declaring one lifecycle config.

        Optional ``kwargs``: ``on_create`` and ``on_start``. Each one, when
        not ``None``, becomes a single-entry ``OnCreate`` / ``OnStart`` list
        of the form ``[{"Content": <value>}]``. When ``include_outputs`` is
        true, ``Arn`` (a ``Ref``) and ``Name`` (a ``Fn::GetAtt``) outputs are
        added.
        """
        # Only scripts that were actually supplied end up in the template;
        # OnCreate is inserted before OnStart, as in the original.
        properties = {}
        script_hooks = (
            ("OnCreate", kwargs.get("on_create")),
            ("OnStart", kwargs.get("on_start")),
        )
        for property_key, script in script_hooks:
            if script is not None:
                properties[property_key] = [{"Content": script}]

        logical_id = self.resource_name
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
                    "Properties": properties,
                },
            },
        }

        if include_outputs:
            name_lookup = {
                "Fn::GetAtt": [logical_id, "NotebookInstanceLifecycleConfigName"]
            }
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": name_lookup},
            }
        return json.dumps(template)
class ModelTestConfig(TestConfig):
    """Test configuration for SageMaker Models."""

    @property
    def resource_name(self):
        """Logical ID of the model in the template."""
        return "TestModel"

    @property
    def describe_function_name(self):
        """Name of the describe function for models."""
        return "describe_model"

    @property
    def name_parameter(self):
        """Name of the parameter that carries the model name."""
        return "ModelName"

    @property
    def arn_parameter(self):
        """Name of the parameter that carries the model ARN."""
        return "ModelArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Return a JSON template string declaring one SageMaker model.

        Optional ``kwargs``: ``execution_role_arn`` (default: a ``FakeRole``
        ARN built from ``ACCOUNT_ID``) and ``image`` (default: a
        linear-learner image URI). When ``include_outputs`` is true, ``Arn``
        (a ``Ref``) and ``Name`` (a ``Fn::GetAtt``) outputs are added.
        """
        default_role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
        default_image = (
            "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1"
        )
        properties = {
            "ExecutionRoleArn": kwargs.get("execution_role_arn", default_role_arn),
            "PrimaryContainer": {"Image": kwargs.get("image", default_image)},
        }

        logical_id = self.resource_name
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::Model",
                    "Properties": properties,
                },
            },
        }

        if include_outputs:
            name_lookup = {"Fn::GetAtt": [logical_id, "ModelName"]}
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": name_lookup},
            }
        return json.dumps(template)

View File

@ -1,4 +1,3 @@
import json
import boto3
import pytest
@ -6,128 +5,58 @@ import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_cloudformation, mock_sagemaker
from moto.sts.models import ACCOUNT_ID
def _get_notebook_instance_template_string(
resource_name="TestNotebook",
instance_type="ml.c4.xlarge",
role_arn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
include_outputs=True,
):
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
resource_name: {
"Type": "AWS::SageMaker::NotebookInstance",
"Properties": {"InstanceType": instance_type, "RoleArn": role_arn},
},
},
}
if include_outputs:
template["Outputs"] = {
"Arn": {"Value": {"Ref": resource_name}},
"Name": {"Value": {"Fn::GetAtt": [resource_name, "NotebookInstanceName"]}},
}
return json.dumps(template)
def _get_notebook_instance_lifecycle_config_template_string(
resource_name="TestConfig", on_create=None, on_start=None, include_outputs=True,
):
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
resource_name: {
"Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
"Properties": {},
},
},
}
if on_create is not None:
template["Resources"][resource_name]["Properties"]["OnCreate"] = [
{"Content": on_create}
]
if on_start is not None:
template["Resources"][resource_name]["Properties"]["OnStart"] = [
{"Content": on_start}
]
if include_outputs:
template["Outputs"] = {
"Arn": {"Value": {"Ref": resource_name}},
"Name": {
"Value": {
"Fn::GetAtt": [
resource_name,
"NotebookInstanceLifecycleConfigName",
]
}
},
}
return json.dumps(template)
from .cloudformation_test_configs import (
NotebookInstanceTestConfig,
NotebookInstanceLifecycleConfigTestConfig,
ModelTestConfig,
)
@mock_cloudformation
@pytest.mark.parametrize(
"stack_name,resource_name,template",
"test_config",
[
(
"test_sagemaker_notebook_instance",
"TestNotebookInstance",
_get_notebook_instance_template_string(
resource_name="TestNotebookInstance", include_outputs=False
),
),
(
"test_sagemaker_notebook_instance_lifecycle_config",
"TestNotebookInstanceLifecycleConfig",
_get_notebook_instance_lifecycle_config_template_string(
resource_name="TestNotebookInstanceLifecycleConfig",
include_outputs=False,
),
),
NotebookInstanceTestConfig(),
NotebookInstanceLifecycleConfigTestConfig(),
ModelTestConfig(),
],
)
def test_sagemaker_cloudformation_create(stack_name, resource_name, template):
def test_sagemaker_cloudformation_create(test_config):
cf = boto3.client("cloudformation", region_name="us-east-1")
cf.create_stack(StackName=stack_name, TemplateBody=template)
stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack(
StackName=stack_name,
TemplateBody=test_config.get_cloudformation_template(include_outputs=False),
)
provisioned_resource = cf.list_stack_resources(StackName=stack_name)[
"StackResourceSummaries"
][0]
provisioned_resource["LogicalResourceId"].should.equal(resource_name)
provisioned_resource["LogicalResourceId"].should.equal(test_config.resource_name)
len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0)
@mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize(
"stack_name,template,describe_function_name,name_parameter,arn_parameter",
"test_config",
[
(
"test_sagemaker_notebook_instance",
_get_notebook_instance_template_string(),
"describe_notebook_instance",
"NotebookInstanceName",
"NotebookInstanceArn",
),
(
"test_sagemaker_notebook_instance_lifecycle_config",
_get_notebook_instance_lifecycle_config_template_string(),
"describe_notebook_instance_lifecycle_config",
"NotebookInstanceLifecycleConfigName",
"NotebookInstanceLifecycleConfigArn",
),
NotebookInstanceTestConfig(),
NotebookInstanceLifecycleConfigTestConfig(),
ModelTestConfig(),
],
)
def test_sagemaker_cloudformation_get_attr(
stack_name, template, describe_function_name, name_parameter, arn_parameter
):
def test_sagemaker_cloudformation_get_attr(test_config):
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
# Create stack and get description for output values
cf.create_stack(StackName=stack_name, TemplateBody=template)
stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack(
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
outputs = {
output["OutputKey"]: output["OutputValue"]
@ -135,62 +64,50 @@ def test_sagemaker_cloudformation_get_attr(
}
# Using the describe function, ensure output ARN matches resource ARN
resource_description = getattr(sm, describe_function_name)(
**{name_parameter: outputs["Name"]}
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
outputs["Arn"].should.equal(resource_description[arn_parameter])
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
@mock_cloudformation
@mock_sagemaker
@pytest.mark.parametrize(
"stack_name,template,describe_function_name,name_parameter,arn_parameter,error_message",
"test_config,error_message",
[
(NotebookInstanceTestConfig(), "RecordNotFound"),
(
"test_sagemaker_notebook_instance",
_get_notebook_instance_template_string(),
"describe_notebook_instance",
"NotebookInstanceName",
"NotebookInstanceArn",
"RecordNotFound",
),
(
"test_sagemaker_notebook_instance_lifecycle_config",
_get_notebook_instance_lifecycle_config_template_string(),
"describe_notebook_instance_lifecycle_config",
"NotebookInstanceLifecycleConfigName",
"NotebookInstanceLifecycleConfigArn",
NotebookInstanceLifecycleConfigTestConfig(),
"Notebook Instance Lifecycle Config does not exist",
),
(ModelTestConfig(), "Could not find model"),
],
)
def test_sagemaker_cloudformation_notebook_instance_delete(
stack_name,
template,
describe_function_name,
name_parameter,
arn_parameter,
error_message,
):
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
# Create stack and verify existence
cf.create_stack(StackName=stack_name, TemplateBody=template)
stack_name = "{}_stack".format(test_config.resource_name)
cf.create_stack(
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
)
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
outputs = {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}
resource_description = getattr(sm, describe_function_name)(
**{name_parameter: outputs["Name"]}
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
outputs["Arn"].should.equal(resource_description[arn_parameter])
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
# Delete the stack and verify resource has also been deleted
cf.delete_stack(StackName=stack_name)
with pytest.raises(ClientError) as ce:
getattr(sm, describe_function_name)(**{name_parameter: outputs["Name"]})
getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: outputs["Name"]}
)
ce.value.response["Error"]["Message"].should.contain(error_message)
@ -200,14 +117,16 @@ def test_sagemaker_cloudformation_notebook_instance_update():
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
test_config = NotebookInstanceTestConfig()
# Set up templates for the stack with the initial and updated instance types
stack_name = "test_sagemaker_notebook_instance"
stack_name = "{}_stack".format(test_config.resource_name)
initial_instance_type = "ml.c4.xlarge"
updated_instance_type = "ml.c4.4xlarge"
initial_template_json = _get_notebook_instance_template_string(
initial_template_json = test_config.get_cloudformation_template(
instance_type=initial_instance_type
)
updated_template_json = _get_notebook_instance_template_string(
updated_template_json = test_config.get_cloudformation_template(
instance_type=updated_instance_type
)
@ -219,10 +138,10 @@ def test_sagemaker_cloudformation_notebook_instance_update():
for output in stack_description["Outputs"]
}
initial_notebook_name = outputs["Name"]
notebook_instance_description = sm.describe_notebook_instance(
NotebookInstanceName=initial_notebook_name,
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_notebook_name}
)
initial_instance_type.should.equal(notebook_instance_description["InstanceType"])
initial_instance_type.should.equal(resource_description["InstanceType"])
# Update stack with new instance type and check attributes
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
@ -234,10 +153,10 @@ def test_sagemaker_cloudformation_notebook_instance_update():
updated_notebook_name = outputs["Name"]
updated_notebook_name.should.equal(initial_notebook_name)
notebook_instance_description = sm.describe_notebook_instance(
NotebookInstanceName=updated_notebook_name,
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_notebook_name}
)
updated_instance_type.should.equal(notebook_instance_description["InstanceType"])
updated_instance_type.should.equal(resource_description["InstanceType"])
@mock_cloudformation
@ -246,14 +165,16 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
cf = boto3.client("cloudformation", region_name="us-east-1")
sm = boto3.client("sagemaker", region_name="us-east-1")
test_config = NotebookInstanceLifecycleConfigTestConfig()
# Set up templates for the stack with the initial and updated on-create scripts
stack_name = "test_sagemaker_notebook_instance_lifecycle_config"
stack_name = "{}_stack".format(test_config.resource_name)
initial_on_create_script = "echo Hello World"
updated_on_create_script = "echo Goodbye World"
initial_template_json = _get_notebook_instance_lifecycle_config_template_string(
initial_template_json = test_config.get_cloudformation_template(
on_create=initial_on_create_script
)
updated_template_json = _get_notebook_instance_lifecycle_config_template_string(
updated_template_json = test_config.get_cloudformation_template(
on_create=updated_on_create_script
)
@ -265,12 +186,12 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
for output in stack_description["Outputs"]
}
initial_config_name = outputs["Name"]
notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=initial_config_name,
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: initial_config_name}
)
len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1)
len(resource_description["OnCreate"]).should.equal(1)
initial_on_create_script.should.equal(
notebook_lifecycle_config_description["OnCreate"][0]["Content"]
resource_description["OnCreate"][0]["Content"]
)
# Update stack with the new on-create script and check attributes
@ -283,10 +204,63 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
updated_config_name = outputs["Name"]
updated_config_name.should.equal(initial_config_name)
notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=updated_config_name,
resource_description = getattr(sm, test_config.describe_function_name)(
**{test_config.name_parameter: updated_config_name}
)
len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1)
len(resource_description["OnCreate"]).should.equal(1)
updated_on_create_script.should.equal(
notebook_lifecycle_config_description["OnCreate"][0]["Content"]
resource_description["OnCreate"][0]["Content"]
)
@mock_cloudformation
@mock_sagemaker
def test_sagemaker_cloudformation_model_update():
    """Updating a model's image via CloudFormation yields a differently named model."""
    cf = boto3.client("cloudformation", region_name="us-east-1")
    sm = boto3.client("sagemaker", region_name="us-east-1")

    test_config = ModelTestConfig()

    # Build two templates that differ only in the container image version
    stack_name = "{}_stack".format(test_config.resource_name)
    image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
    initial_image = image.format(1)
    updated_image = image.format(2)
    initial_template_json = test_config.get_cloudformation_template(
        image=initial_image
    )
    updated_template_json = test_config.get_cloudformation_template(
        image=updated_image
    )

    def stack_outputs():
        # Map each stack output key to its value for the current stack state
        stack = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
        return {entry["OutputKey"]: entry["OutputValue"] for entry in stack["Outputs"]}

    def describe(model_name):
        # Look up the describe call by name, as the config-driven tests do
        describe_call = getattr(sm, test_config.describe_function_name)
        return describe_call(**{test_config.name_parameter: model_name})

    # Create the stack from the initial template and check the model's image
    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
    initial_model_name = stack_outputs()["Name"]
    resource_description = describe(initial_model_name)
    resource_description["PrimaryContainer"]["Image"].should.equal(initial_image)

    # Update the stack with the new image; the model name is expected to change
    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
    updated_model_name = stack_outputs()["Name"]
    updated_model_name.should_not.equal(initial_model_name)

    resource_description = describe(updated_model_name)
    resource_description["PrimaryContainer"]["Image"].should.equal(updated_image)