Add CloudFormation support for SageMaker Models (#3861)
* Create a formal interface for SM Cloudformation test configurations * Create SageMaker Models with CloudFormation * Utilize six for adding metaclass to TestConfig * Update SM backend to return Model objects instead of response objects
This commit is contained in:
parent
0b11b0c716
commit
f6dda54a6c
@ -2,7 +2,6 @@ from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from boto3 import Session
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
|
||||
@ -310,7 +309,7 @@ class FakeEndpointConfig(BaseObject):
|
||||
)
|
||||
|
||||
|
||||
class Model(BaseObject):
|
||||
class Model(BaseObject, CloudFormationModel):
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
@ -353,6 +352,72 @@ class Model(BaseObject):
|
||||
+ model_name
|
||||
)
|
||||
|
||||
@property
|
||||
def physical_resource_id(self):
|
||||
return self.model_arn
|
||||
|
||||
def get_cfn_attribute(self, attribute_name):
|
||||
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html#aws-resource-sagemaker-model-return-values
|
||||
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||
|
||||
if attribute_name == "ModelName":
|
||||
return self.model_name
|
||||
raise UnformattedGetAttTemplateException()
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_name_type():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_type():
|
||||
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html
|
||||
return "AWS::SageMaker::Model"
|
||||
|
||||
@classmethod
|
||||
def create_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, region_name
|
||||
):
|
||||
sagemaker_backend = sagemaker_backends[region_name]
|
||||
|
||||
# Get required properties from provided CloudFormation template
|
||||
properties = cloudformation_json["Properties"]
|
||||
execution_role_arn = properties["ExecutionRoleArn"]
|
||||
primary_container = properties["PrimaryContainer"]
|
||||
|
||||
model = sagemaker_backend.create_model(
|
||||
ModelName=resource_name,
|
||||
ExecutionRoleArn=execution_role_arn,
|
||||
PrimaryContainer=primary_container,
|
||||
VpcConfig=properties.get("VpcConfig", {}),
|
||||
Containers=properties.get("Containers", []),
|
||||
Tags=properties.get("Tags", []),
|
||||
)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def update_from_cloudformation_json(
|
||||
cls, original_resource, new_resource_name, cloudformation_json, region_name,
|
||||
):
|
||||
# Most changes to the model will change resource name for Models
|
||||
cls.delete_from_cloudformation_json(
|
||||
original_resource.model_arn, cloudformation_json, region_name
|
||||
)
|
||||
new_resource = cls.create_from_cloudformation_json(
|
||||
new_resource_name, cloudformation_json, region_name
|
||||
)
|
||||
return new_resource
|
||||
|
||||
@classmethod
|
||||
def delete_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, region_name
|
||||
):
|
||||
# Get actual name because resource_name actually provides the ARN
|
||||
# since the Physical Resource ID is the ARN despite SageMaker
|
||||
# using the name for most of its operations.
|
||||
model_name = resource_name.split("/")[-1]
|
||||
|
||||
sagemaker_backends[region_name].delete_model(model_name)
|
||||
|
||||
|
||||
class VpcConfig(BaseObject):
|
||||
def __init__(self, security_group_ids, subnets):
|
||||
@ -699,23 +764,19 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
|
||||
self._models[kwargs.get("ModelName")] = model_obj
|
||||
return model_obj.response_create
|
||||
return model_obj
|
||||
|
||||
def describe_model(self, model_name=None):
|
||||
model = self._models.get(model_name)
|
||||
if model:
|
||||
return model.response_object
|
||||
return model
|
||||
message = "Could not find model '{}'.".format(
|
||||
Model.arn_for_model_name(model_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def list_models(self):
|
||||
models = []
|
||||
for model in self._models.values():
|
||||
model_response = deepcopy(model.response_object)
|
||||
models.append(model_response)
|
||||
return {"Models": models}
|
||||
return self._models.values()
|
||||
|
||||
def delete_model(self, model_name=None):
|
||||
for model in self._models.values():
|
||||
|
@ -22,12 +22,12 @@ class SageMakerResponse(BaseResponse):
|
||||
|
||||
def describe_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
response = self.sagemaker_backend.describe_model(model_name)
|
||||
return json.dumps(response)
|
||||
model = self.sagemaker_backend.describe_model(model_name)
|
||||
return json.dumps(model.response_object)
|
||||
|
||||
def create_model(self):
|
||||
response = self.sagemaker_backend.create_model(**self.request_params)
|
||||
return json.dumps(response)
|
||||
model = self.sagemaker_backend.create_model(**self.request_params)
|
||||
return json.dumps(model.response_create)
|
||||
|
||||
def delete_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
@ -35,8 +35,8 @@ class SageMakerResponse(BaseResponse):
|
||||
return json.dumps(response)
|
||||
|
||||
def list_models(self):
|
||||
response = self.sagemaker_backend.list_models(**self.request_params)
|
||||
return json.dumps(response)
|
||||
models = self.sagemaker_backend.list_models(**self.request_params)
|
||||
return json.dumps({"Models": [model.response_object for model in models]})
|
||||
|
||||
def _get_param(self, param, if_none=None):
|
||||
return self.request_params.get(param, if_none)
|
||||
|
191
tests/test_sagemaker/cloudformation_test_configs.py
Normal file
191
tests/test_sagemaker/cloudformation_test_configs.py
Normal file
@ -0,0 +1,191 @@
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import six
|
||||
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
|
||||
class TestConfig:
|
||||
"""Provides the interface to use for creating test configurations.
|
||||
|
||||
This class will provide the interface for what information will be
|
||||
needed for the SageMaker CloudFormation tests. Ultimately, this will
|
||||
improve the readability of the tests in `test_sagemaker_cloudformation.py`
|
||||
because it will reduce the amount of information we pass through the
|
||||
`pytest.mark.parametrize` decorator.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def resource_name(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def describe_function_name(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name_parameter(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def arn_parameter(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class NotebookInstanceTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Notebook Instances."""
|
||||
|
||||
@property
|
||||
def resource_name(self):
|
||||
return "TestNotebook"
|
||||
|
||||
@property
|
||||
def describe_function_name(self):
|
||||
return "describe_notebook_instance"
|
||||
|
||||
@property
|
||||
def name_parameter(self):
|
||||
return "NotebookInstanceName"
|
||||
|
||||
@property
|
||||
def arn_parameter(self):
|
||||
return "NotebookInstanceArn"
|
||||
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
instance_type = kwargs.get("instance_type", "ml.c4.xlarge")
|
||||
role_arn = kwargs.get(
|
||||
"role_arn", "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
)
|
||||
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
self.resource_name: {
|
||||
"Type": "AWS::SageMaker::NotebookInstance",
|
||||
"Properties": {"InstanceType": instance_type, "RoleArn": role_arn},
|
||||
},
|
||||
},
|
||||
}
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {
|
||||
"Value": {
|
||||
"Fn::GetAtt": [self.resource_name, "NotebookInstanceName"]
|
||||
}
|
||||
},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
|
||||
class NotebookInstanceLifecycleConfigTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Notebook Instance Lifecycle Configs."""
|
||||
|
||||
@property
|
||||
def resource_name(self):
|
||||
return "TestNotebookLifecycleConfig"
|
||||
|
||||
@property
|
||||
def describe_function_name(self):
|
||||
return "describe_notebook_instance_lifecycle_config"
|
||||
|
||||
@property
|
||||
def name_parameter(self):
|
||||
return "NotebookInstanceLifecycleConfigName"
|
||||
|
||||
@property
|
||||
def arn_parameter(self):
|
||||
return "NotebookInstanceLifecycleConfigArn"
|
||||
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
on_create = kwargs.get("on_create")
|
||||
on_start = kwargs.get("on_start")
|
||||
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
self.resource_name: {
|
||||
"Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
|
||||
"Properties": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
if on_create is not None:
|
||||
template["Resources"][self.resource_name]["Properties"]["OnCreate"] = [
|
||||
{"Content": on_create}
|
||||
]
|
||||
if on_start is not None:
|
||||
template["Resources"][self.resource_name]["Properties"]["OnStart"] = [
|
||||
{"Content": on_start}
|
||||
]
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {
|
||||
"Value": {
|
||||
"Fn::GetAtt": [
|
||||
self.resource_name,
|
||||
"NotebookInstanceLifecycleConfigName",
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
|
||||
class ModelTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Models."""
|
||||
|
||||
@property
|
||||
def resource_name(self):
|
||||
return "TestModel"
|
||||
|
||||
@property
|
||||
def describe_function_name(self):
|
||||
return "describe_model"
|
||||
|
||||
@property
|
||||
def name_parameter(self):
|
||||
return "ModelName"
|
||||
|
||||
@property
|
||||
def arn_parameter(self):
|
||||
return "ModelArn"
|
||||
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
execution_role_arn = kwargs.get(
|
||||
"execution_role_arn", "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
)
|
||||
image = kwargs.get(
|
||||
"image", "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1"
|
||||
)
|
||||
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
self.resource_name: {
|
||||
"Type": "AWS::SageMaker::Model",
|
||||
"Properties": {
|
||||
"ExecutionRoleArn": execution_role_arn,
|
||||
"PrimaryContainer": {"Image": image,},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}},
|
||||
}
|
||||
return json.dumps(template)
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import boto3
|
||||
|
||||
import pytest
|
||||
@ -6,128 +5,58 @@ import sure # noqa
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from moto import mock_cloudformation, mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
|
||||
def _get_notebook_instance_template_string(
|
||||
resource_name="TestNotebook",
|
||||
instance_type="ml.c4.xlarge",
|
||||
role_arn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||
include_outputs=True,
|
||||
):
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
resource_name: {
|
||||
"Type": "AWS::SageMaker::NotebookInstance",
|
||||
"Properties": {"InstanceType": instance_type, "RoleArn": role_arn},
|
||||
},
|
||||
},
|
||||
}
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": resource_name}},
|
||||
"Name": {"Value": {"Fn::GetAtt": [resource_name, "NotebookInstanceName"]}},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
|
||||
def _get_notebook_instance_lifecycle_config_template_string(
|
||||
resource_name="TestConfig", on_create=None, on_start=None, include_outputs=True,
|
||||
):
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
resource_name: {
|
||||
"Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
|
||||
"Properties": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
if on_create is not None:
|
||||
template["Resources"][resource_name]["Properties"]["OnCreate"] = [
|
||||
{"Content": on_create}
|
||||
]
|
||||
if on_start is not None:
|
||||
template["Resources"][resource_name]["Properties"]["OnStart"] = [
|
||||
{"Content": on_start}
|
||||
]
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": resource_name}},
|
||||
"Name": {
|
||||
"Value": {
|
||||
"Fn::GetAtt": [
|
||||
resource_name,
|
||||
"NotebookInstanceLifecycleConfigName",
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
return json.dumps(template)
|
||||
from .cloudformation_test_configs import (
|
||||
NotebookInstanceTestConfig,
|
||||
NotebookInstanceLifecycleConfigTestConfig,
|
||||
ModelTestConfig,
|
||||
)
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@pytest.mark.parametrize(
|
||||
"stack_name,resource_name,template",
|
||||
"test_config",
|
||||
[
|
||||
(
|
||||
"test_sagemaker_notebook_instance",
|
||||
"TestNotebookInstance",
|
||||
_get_notebook_instance_template_string(
|
||||
resource_name="TestNotebookInstance", include_outputs=False
|
||||
),
|
||||
),
|
||||
(
|
||||
"test_sagemaker_notebook_instance_lifecycle_config",
|
||||
"TestNotebookInstanceLifecycleConfig",
|
||||
_get_notebook_instance_lifecycle_config_template_string(
|
||||
resource_name="TestNotebookInstanceLifecycleConfig",
|
||||
include_outputs=False,
|
||||
),
|
||||
),
|
||||
NotebookInstanceTestConfig(),
|
||||
NotebookInstanceLifecycleConfigTestConfig(),
|
||||
ModelTestConfig(),
|
||||
],
|
||||
)
|
||||
def test_sagemaker_cloudformation_create(stack_name, resource_name, template):
|
||||
def test_sagemaker_cloudformation_create(test_config):
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
cf.create_stack(StackName=stack_name, TemplateBody=template)
|
||||
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
cf.create_stack(
|
||||
StackName=stack_name,
|
||||
TemplateBody=test_config.get_cloudformation_template(include_outputs=False),
|
||||
)
|
||||
|
||||
provisioned_resource = cf.list_stack_resources(StackName=stack_name)[
|
||||
"StackResourceSummaries"
|
||||
][0]
|
||||
provisioned_resource["LogicalResourceId"].should.equal(resource_name)
|
||||
provisioned_resource["LogicalResourceId"].should.equal(test_config.resource_name)
|
||||
len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0)
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@mock_sagemaker
|
||||
@pytest.mark.parametrize(
|
||||
"stack_name,template,describe_function_name,name_parameter,arn_parameter",
|
||||
"test_config",
|
||||
[
|
||||
(
|
||||
"test_sagemaker_notebook_instance",
|
||||
_get_notebook_instance_template_string(),
|
||||
"describe_notebook_instance",
|
||||
"NotebookInstanceName",
|
||||
"NotebookInstanceArn",
|
||||
),
|
||||
(
|
||||
"test_sagemaker_notebook_instance_lifecycle_config",
|
||||
_get_notebook_instance_lifecycle_config_template_string(),
|
||||
"describe_notebook_instance_lifecycle_config",
|
||||
"NotebookInstanceLifecycleConfigName",
|
||||
"NotebookInstanceLifecycleConfigArn",
|
||||
),
|
||||
NotebookInstanceTestConfig(),
|
||||
NotebookInstanceLifecycleConfigTestConfig(),
|
||||
ModelTestConfig(),
|
||||
],
|
||||
)
|
||||
def test_sagemaker_cloudformation_get_attr(
|
||||
stack_name, template, describe_function_name, name_parameter, arn_parameter
|
||||
):
|
||||
def test_sagemaker_cloudformation_get_attr(test_config):
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
# Create stack and get description for output values
|
||||
cf.create_stack(StackName=stack_name, TemplateBody=template)
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
cf.create_stack(
|
||||
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
||||
)
|
||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
outputs = {
|
||||
output["OutputKey"]: output["OutputValue"]
|
||||
@ -135,62 +64,50 @@ def test_sagemaker_cloudformation_get_attr(
|
||||
}
|
||||
|
||||
# Using the describe function, ensure output ARN matches resource ARN
|
||||
resource_description = getattr(sm, describe_function_name)(
|
||||
**{name_parameter: outputs["Name"]}
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
outputs["Arn"].should.equal(resource_description[arn_parameter])
|
||||
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@mock_sagemaker
|
||||
@pytest.mark.parametrize(
|
||||
"stack_name,template,describe_function_name,name_parameter,arn_parameter,error_message",
|
||||
"test_config,error_message",
|
||||
[
|
||||
(NotebookInstanceTestConfig(), "RecordNotFound"),
|
||||
(
|
||||
"test_sagemaker_notebook_instance",
|
||||
_get_notebook_instance_template_string(),
|
||||
"describe_notebook_instance",
|
||||
"NotebookInstanceName",
|
||||
"NotebookInstanceArn",
|
||||
"RecordNotFound",
|
||||
),
|
||||
(
|
||||
"test_sagemaker_notebook_instance_lifecycle_config",
|
||||
_get_notebook_instance_lifecycle_config_template_string(),
|
||||
"describe_notebook_instance_lifecycle_config",
|
||||
"NotebookInstanceLifecycleConfigName",
|
||||
"NotebookInstanceLifecycleConfigArn",
|
||||
NotebookInstanceLifecycleConfigTestConfig(),
|
||||
"Notebook Instance Lifecycle Config does not exist",
|
||||
),
|
||||
(ModelTestConfig(), "Could not find model"),
|
||||
],
|
||||
)
|
||||
def test_sagemaker_cloudformation_notebook_instance_delete(
|
||||
stack_name,
|
||||
template,
|
||||
describe_function_name,
|
||||
name_parameter,
|
||||
arn_parameter,
|
||||
error_message,
|
||||
):
|
||||
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
# Create stack and verify existence
|
||||
cf.create_stack(StackName=stack_name, TemplateBody=template)
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
cf.create_stack(
|
||||
StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
|
||||
)
|
||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
outputs = {
|
||||
output["OutputKey"]: output["OutputValue"]
|
||||
for output in stack_description["Outputs"]
|
||||
}
|
||||
resource_description = getattr(sm, describe_function_name)(
|
||||
**{name_parameter: outputs["Name"]}
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
outputs["Arn"].should.equal(resource_description[arn_parameter])
|
||||
outputs["Arn"].should.equal(resource_description[test_config.arn_parameter])
|
||||
|
||||
# Delete the stack and verify resource has also been deleted
|
||||
cf.delete_stack(StackName=stack_name)
|
||||
with pytest.raises(ClientError) as ce:
|
||||
getattr(sm, describe_function_name)(**{name_parameter: outputs["Name"]})
|
||||
getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: outputs["Name"]}
|
||||
)
|
||||
ce.value.response["Error"]["Message"].should.contain(error_message)
|
||||
|
||||
|
||||
@ -200,14 +117,16 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
test_config = NotebookInstanceTestConfig()
|
||||
|
||||
# Set up template for stack with initial and update instance types
|
||||
stack_name = "test_sagemaker_notebook_instance"
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
initial_instance_type = "ml.c4.xlarge"
|
||||
updated_instance_type = "ml.c4.4xlarge"
|
||||
initial_template_json = _get_notebook_instance_template_string(
|
||||
initial_template_json = test_config.get_cloudformation_template(
|
||||
instance_type=initial_instance_type
|
||||
)
|
||||
updated_template_json = _get_notebook_instance_template_string(
|
||||
updated_template_json = test_config.get_cloudformation_template(
|
||||
instance_type=updated_instance_type
|
||||
)
|
||||
|
||||
@ -219,10 +138,10 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
||||
for output in stack_description["Outputs"]
|
||||
}
|
||||
initial_notebook_name = outputs["Name"]
|
||||
notebook_instance_description = sm.describe_notebook_instance(
|
||||
NotebookInstanceName=initial_notebook_name,
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_notebook_name}
|
||||
)
|
||||
initial_instance_type.should.equal(notebook_instance_description["InstanceType"])
|
||||
initial_instance_type.should.equal(resource_description["InstanceType"])
|
||||
|
||||
# Update stack with new instance type and check attributes
|
||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||
@ -234,10 +153,10 @@ def test_sagemaker_cloudformation_notebook_instance_update():
|
||||
updated_notebook_name = outputs["Name"]
|
||||
updated_notebook_name.should.equal(initial_notebook_name)
|
||||
|
||||
notebook_instance_description = sm.describe_notebook_instance(
|
||||
NotebookInstanceName=updated_notebook_name,
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_notebook_name}
|
||||
)
|
||||
updated_instance_type.should.equal(notebook_instance_description["InstanceType"])
|
||||
updated_instance_type.should.equal(resource_description["InstanceType"])
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@ -246,14 +165,16 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
test_config = NotebookInstanceLifecycleConfigTestConfig()
|
||||
|
||||
# Set up template for stack with initial and update instance types
|
||||
stack_name = "test_sagemaker_notebook_instance_lifecycle_config"
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
initial_on_create_script = "echo Hello World"
|
||||
updated_on_create_script = "echo Goodbye World"
|
||||
initial_template_json = _get_notebook_instance_lifecycle_config_template_string(
|
||||
initial_template_json = test_config.get_cloudformation_template(
|
||||
on_create=initial_on_create_script
|
||||
)
|
||||
updated_template_json = _get_notebook_instance_lifecycle_config_template_string(
|
||||
updated_template_json = test_config.get_cloudformation_template(
|
||||
on_create=updated_on_create_script
|
||||
)
|
||||
|
||||
@ -265,12 +186,12 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
||||
for output in stack_description["Outputs"]
|
||||
}
|
||||
initial_config_name = outputs["Name"]
|
||||
notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=initial_config_name,
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: initial_config_name}
|
||||
)
|
||||
len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1)
|
||||
len(resource_description["OnCreate"]).should.equal(1)
|
||||
initial_on_create_script.should.equal(
|
||||
notebook_lifecycle_config_description["OnCreate"][0]["Content"]
|
||||
resource_description["OnCreate"][0]["Content"]
|
||||
)
|
||||
|
||||
# Update stack with new instance type and check attributes
|
||||
@ -283,10 +204,63 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
|
||||
updated_config_name = outputs["Name"]
|
||||
updated_config_name.should.equal(initial_config_name)
|
||||
|
||||
notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=updated_config_name,
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_config_name}
|
||||
)
|
||||
len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1)
|
||||
len(resource_description["OnCreate"]).should.equal(1)
|
||||
updated_on_create_script.should.equal(
|
||||
notebook_lifecycle_config_description["OnCreate"][0]["Content"]
|
||||
resource_description["OnCreate"][0]["Content"]
|
||||
)
|
||||
|
||||
|
||||
@mock_cloudformation
|
||||
@mock_sagemaker
|
||||
def test_sagemaker_cloudformation_model_update():
|
||||
cf = boto3.client("cloudformation", region_name="us-east-1")
|
||||
sm = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
test_config = ModelTestConfig()
|
||||
|
||||
# Set up template for stack with initial and update instance types
|
||||
stack_name = "{}_stack".format(test_config.resource_name)
|
||||
image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
|
||||
initial_image_version = 1
|
||||
updated_image_version = 2
|
||||
initial_template_json = test_config.get_cloudformation_template(
|
||||
image=image.format(initial_image_version)
|
||||
)
|
||||
updated_template_json = test_config.get_cloudformation_template(
|
||||
image=image.format(updated_image_version)
|
||||
)
|
||||
|
||||
# Create stack with initial template and check attributes
|
||||
cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
|
||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
outputs = {
|
||||
output["OutputKey"]: output["OutputValue"]
|
||||
for output in stack_description["Outputs"]
|
||||
}
|
||||
inital_model_name = outputs["Name"]
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: inital_model_name}
|
||||
)
|
||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||
image.format(initial_image_version)
|
||||
)
|
||||
|
||||
# Update stack with new instance type and check attributes
|
||||
cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
|
||||
stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
outputs = {
|
||||
output["OutputKey"]: output["OutputValue"]
|
||||
for output in stack_description["Outputs"]
|
||||
}
|
||||
updated_notebook_name = outputs["Name"]
|
||||
updated_notebook_name.should_not.equal(inital_model_name)
|
||||
|
||||
resource_description = getattr(sm, test_config.describe_function_name)(
|
||||
**{test_config.name_parameter: updated_notebook_name}
|
||||
)
|
||||
resource_description["PrimaryContainer"]["Image"].should.equal(
|
||||
image.format(updated_image_version)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user