From d45233fa0037b7c2fc5cf0284abfa90de7398eb6 Mon Sep 17 00:00:00 2001 From: Zach Churchill <47751011+zachurchill-root@users.noreply.github.com> Date: Fri, 9 Apr 2021 13:54:00 -0400 Subject: [PATCH] Add CloudFormation support for SageMaker Notebook Instances (#3845) * Create SageMaker Notebook Instances with CloudFormation * Implement attributes for SageMaker notebook instance in Cloudformation * Delete SageMaker Notebook Instances with CloudFormation * Update SageMaker Notebook Instances with CloudFormation * Factor out template into function for SageMaker notebook instance tests --- moto/cloudformation/parsing.py | 1 + moto/sagemaker/models.py | 67 +++++++- .../test_sagemaker_cloudformation.py | 155 ++++++++++++++++++ 3 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 tests/test_sagemaker/test_sagemaker_cloudformation.py diff --git a/moto/cloudformation/parsing.py b/moto/cloudformation/parsing.py index 33d6fa6d1..f92a402db 100644 --- a/moto/cloudformation/parsing.py +++ b/moto/cloudformation/parsing.py @@ -37,6 +37,7 @@ from moto.redshift import models as redshift_models # noqa from moto.route53 import models as route53_models # noqa from moto.s3 import models as s3_models, s3_backend # noqa from moto.s3.utils import bucket_and_name_from_url +from moto.sagemaker import models as sagemaker_models # noqa from moto.sns import models as sns_models # noqa from moto.sqs import models as sqs_models # noqa from moto.stepfunctions import models as stepfunctions_models # noqa diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 8fef306b8..cf2cdec7e 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -5,7 +5,7 @@ from boto3 import Session from copy import deepcopy from datetime import datetime -from moto.core import ACCOUNT_ID, BaseBackend, BaseModel +from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import RESTError from moto.sagemaker import validators from .exceptions import MissingModel, ValidationError @@ -383,7 +383,7 @@ class Container(BaseObject): } -class FakeSagemakerNotebookInstance: +class FakeSagemakerNotebookInstance(CloudFormationModel): def __init__( self, region_name, @@ -503,6 +503,69 @@ class FakeSagemakerNotebookInstance: def stop(self): self.status = "Stopped" + @property + def physical_resource_id(self): + return self.arn + + def get_cfn_attribute(self, attribute_name): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html#aws-resource-sagemaker-notebookinstance-return-values + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "NotebookInstanceName": + return self.notebook_instance_name + raise UnformattedGetAttTemplateException() + + @staticmethod + def cloudformation_name_type(): + return None + + @staticmethod + def cloudformation_type(): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html + return "AWS::SageMaker::NotebookInstance" + + @classmethod + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + # Get required properties from provided CloudFormation template + properties = cloudformation_json["Properties"] + instance_type = properties["InstanceType"] + role_arn = properties["RoleArn"] + + notebook = sagemaker_backends[region_name].create_notebook_instance( + notebook_instance_name=resource_name, + instance_type=instance_type, + role_arn=role_arn, + ) + return notebook + + @classmethod + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name, + ): + # Operations keep same resource name so delete old and create new to mimic update + cls.delete_from_cloudformation_json( + original_resource.arn, cloudformation_json, region_name + ) + new_resource = cls.create_from_cloudformation_json( + original_resource.notebook_instance_name, cloudformation_json, region_name + ) + return new_resource + + @classmethod + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + # Get actual name because resource_name actually provides the ARN + # since the Physical Resource ID is the ARN despite SageMaker + # using the name for most of its operations. + notebook_instance_name = resource_name.split("/")[-1] + + backend = sagemaker_backends[region_name] + backend.stop_notebook_instance(notebook_instance_name) + backend.delete_notebook_instance(notebook_instance_name) + class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): def __init__( diff --git a/tests/test_sagemaker/test_sagemaker_cloudformation.py b/tests/test_sagemaker/test_sagemaker_cloudformation.py new file mode 100644 index 000000000..e90a6f151 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_cloudformation.py @@ -0,0 +1,155 @@ +import json +import boto3 + +import pytest +import sure # noqa +from botocore.exceptions import ClientError + +from moto import mock_cloudformation, mock_sagemaker +from moto.sts.models import ACCOUNT_ID + + +def _get_notebook_instance_template_string( + instance_type="ml.c4.xlarge", + role_arn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID), + include_outputs=True, +): + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + "TestNotebookInstance": { + "Type": "AWS::SageMaker::NotebookInstance", + "Properties": {"InstanceType": instance_type, "RoleArn": role_arn}, + }, + }, + } + if include_outputs: + template["Outputs"] = { + "NotebookInstanceArn": {"Value": {"Ref": "TestNotebookInstance"}}, + "NotebookInstanceName": { + "Value": { + "Fn::GetAtt": ["TestNotebookInstance", "NotebookInstanceName"] + }, + }, + } + return json.dumps(template) + + +@mock_cloudformation +def test_sagemaker_cloudformation_create_notebook_instance(): + cf = boto3.client("cloudformation", region_name="us-east-1") + + stack_name = "test_sagemaker_notebook_instance" + template = _get_notebook_instance_template_string(include_outputs=False) + cf.create_stack(StackName=stack_name, TemplateBody=template) + + provisioned_resource = cf.list_stack_resources(StackName=stack_name)[ + "StackResourceSummaries" + ][0] + provisioned_resource["LogicalResourceId"].should.equal("TestNotebookInstance") + len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0) + + +@mock_cloudformation +@mock_sagemaker +def test_sagemaker_cloudformation_notebook_instance_get_attr(): + cf = boto3.client("cloudformation", region_name="us-east-1") + sm = boto3.client("sagemaker", region_name="us-east-1") + + stack_name = "test_sagemaker_notebook_instance" + template = _get_notebook_instance_template_string() + cf.create_stack(StackName=stack_name, TemplateBody=template) + + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + notebook_instance_name = outputs["NotebookInstanceName"] + notebook_instance_arn = outputs["NotebookInstanceArn"] + + notebook_instance_description = sm.describe_notebook_instance( + NotebookInstanceName=notebook_instance_name, + ) + notebook_instance_arn.should.equal( + notebook_instance_description["NotebookInstanceArn"] + ) + + +@mock_cloudformation +@mock_sagemaker +def test_sagemaker_cloudformation_notebook_instance_delete(): + cf = boto3.client("cloudformation", region_name="us-east-1") + sm = boto3.client("sagemaker", region_name="us-east-1") + + # Create stack with notebook instance and verify existence + stack_name = "test_sagemaker_notebook_instance" + template = _get_notebook_instance_template_string() + cf.create_stack(StackName=stack_name, TemplateBody=template) + + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + notebook_instance = sm.describe_notebook_instance( + NotebookInstanceName=outputs["NotebookInstanceName"], + ) + outputs["NotebookInstanceArn"].should.equal( + notebook_instance["NotebookInstanceArn"] + ) + + # Delete the stack and verify notebook instance has also been deleted + # TODO replace exception check with `list_notebook_instances` method when implemented + cf.delete_stack(StackName=stack_name) + with pytest.raises(ClientError) as ce: + sm.describe_notebook_instance( + NotebookInstanceName=outputs["NotebookInstanceName"] + ) + ce.value.response["Error"]["Message"].should.contain("RecordNotFound") + + +@mock_cloudformation +@mock_sagemaker +def test_sagemaker_cloudformation_notebook_instance_update(): + cf = boto3.client("cloudformation", region_name="us-east-1") + sm = boto3.client("sagemaker", region_name="us-east-1") + + # Set up template for stack with initial and update instance types + stack_name = "test_sagemaker_notebook_instance" + initial_instance_type = "ml.c4.xlarge" + updated_instance_type = "ml.c4.4xlarge" + initial_template_json = _get_notebook_instance_template_string( + instance_type=initial_instance_type + ) + updated_template_json = _get_notebook_instance_template_string( + instance_type=updated_instance_type + ) + + # Create stack with initial template and check attributes + cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + initial_notebook_name = outputs["NotebookInstanceName"] + notebook_instance_description = sm.describe_notebook_instance( + NotebookInstanceName=initial_notebook_name, + ) + initial_instance_type.should.equal(notebook_instance_description["InstanceType"]) + + # Update stack with new instance type and check attributes + cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + updated_notebook_name = outputs["NotebookInstanceName"] + updated_notebook_name.should.equal(initial_notebook_name) + + notebook_instance_description = sm.describe_notebook_instance( + NotebookInstanceName=updated_notebook_name, + ) + updated_instance_type.should.equal(notebook_instance_description["InstanceType"])