From 0b11b0c71614c446e03e0eacdba47c65d31f23fa Mon Sep 17 00:00:00 2001 From: Zach Churchill Date: Tue, 13 Apr 2021 07:03:25 -0400 Subject: [PATCH] Add CloudFormation support for SageMaker Notebook Instance Lifecycle Configs (#3855) * Create SageMaker Notebook Instance Lifecycle Configs with CloudFormation Implement attributes for SM Notebook Instance Lifecycle Config in CloudFormation Delete SM Notebook Instance Lifecycle Configs with CloudFormation Update SM Notebook Instance Lifecycle Configs with CloudFormation Also fixed error in create_from method where the properties where not being referenced when setting OnCreate and OnStart. Factor out template for SM Notebook Lifecycle Config CF tests * Refactor SM CloudFormation create tests to use pytest.mark.parametrize * Refactor SM CloudFormation get_attr tests to use pytest.mark.parametrize Also update the NotebookInstance template function to use Name and Arn for the output IDs so that the parametrization is easier. * Refactor SM CloudFormation delete tests to use pytest.mark.parametrize --- moto/sagemaker/models.py | 67 +++++- .../test_sagemaker_cloudformation.py | 215 ++++++++++++++---- 2 files changed, 242 insertions(+), 40 deletions(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index cf2cdec7e..10f993e47 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -567,7 +567,7 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): backend.delete_notebook_instance(notebook_instance_name) -class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): +class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start ): @@ -606,6 +606,71 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): def response_create(self): return {"TrainingJobArn": self.training_job_arn} + @property + def physical_resource_id(self): + return self.notebook_instance_lifecycle_config_arn + + def get_cfn_attribute(self, attribute_name): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html#aws-resource-sagemaker-notebookinstancelifecycleconfig-return-values + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "NotebookInstanceLifecycleConfigName": + return self.notebook_instance_lifecycle_config_name + raise UnformattedGetAttTemplateException() + + @staticmethod + def cloudformation_name_type(): + return None + + @staticmethod + def cloudformation_type(): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html + return "AWS::SageMaker::NotebookInstanceLifecycleConfig" + + @classmethod + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + + config = sagemaker_backends[ + region_name + ].create_notebook_instance_lifecycle_config( + notebook_instance_lifecycle_config_name=resource_name, + on_create=properties.get("OnCreate"), + on_start=properties.get("OnStart"), + ) + return config + + @classmethod + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name, + ): + # Operations keep same resource name so delete old and create new to mimic update + cls.delete_from_cloudformation_json( + original_resource.notebook_instance_lifecycle_config_arn, + cloudformation_json, + region_name, + ) + new_resource = cls.create_from_cloudformation_json( + original_resource.notebook_instance_lifecycle_config_name, + cloudformation_json, + region_name, + ) + return new_resource + + @classmethod + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + # Get actual name because resource_name actually provides the ARN + # since the Physical Resource ID is the ARN despite SageMaker + # using the name for most of its operations. + config_name = resource_name.split("/")[-1] + + backend = sagemaker_backends[region_name] + backend.delete_notebook_instance_lifecycle_config(config_name) + class SageMakerModelBackend(BaseBackend): def __init__(self, region_name=None): diff --git a/tests/test_sagemaker/test_sagemaker_cloudformation.py b/tests/test_sagemaker/test_sagemaker_cloudformation.py index e90a6f151..60c463f58 100644 --- a/tests/test_sagemaker/test_sagemaker_cloudformation.py +++ b/tests/test_sagemaker/test_sagemaker_cloudformation.py @@ -10,6 +10,7 @@ from moto.sts.models import ACCOUNT_ID def _get_notebook_instance_template_string( + resource_name="TestNotebook", instance_type="ml.c4.xlarge", role_arn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID), include_outputs=True, @@ -17,7 +18,7 @@ def _get_notebook_instance_template_string( template = { "AWSTemplateFormatVersion": "2010-09-09", "Resources": { - "TestNotebookInstance": { + resource_name: { "Type": "AWS::SageMaker::NotebookInstance", "Properties": {"InstanceType": instance_type, "RoleArn": role_arn}, }, @@ -25,88 +26,172 @@ def _get_notebook_instance_template_string( } if include_outputs: template["Outputs"] = { - "NotebookInstanceArn": {"Value": {"Ref": "TestNotebookInstance"}}, - "NotebookInstanceName": { + "Arn": {"Value": {"Ref": resource_name}}, + "Name": {"Value": {"Fn::GetAtt": [resource_name, "NotebookInstanceName"]}}, + } + return json.dumps(template) + + +def _get_notebook_instance_lifecycle_config_template_string( + resource_name="TestConfig", on_create=None, on_start=None, include_outputs=True, +): + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + resource_name: { + "Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig", + "Properties": {}, + }, + }, + } + if on_create is not None: + template["Resources"][resource_name]["Properties"]["OnCreate"] = [ + {"Content": on_create} + ] + if on_start is not None: + template["Resources"][resource_name]["Properties"]["OnStart"] = [ + {"Content": on_start} + ] + if include_outputs: + template["Outputs"] = { + "Arn": {"Value": {"Ref": resource_name}}, + "Name": { "Value": { - "Fn::GetAtt": ["TestNotebookInstance", "NotebookInstanceName"] - }, + "Fn::GetAtt": [ + resource_name, + "NotebookInstanceLifecycleConfigName", + ] + } }, } return json.dumps(template) @mock_cloudformation -def test_sagemaker_cloudformation_create_notebook_instance(): +@pytest.mark.parametrize( + "stack_name,resource_name,template", + [ + ( + "test_sagemaker_notebook_instance", + "TestNotebookInstance", + _get_notebook_instance_template_string( + resource_name="TestNotebookInstance", include_outputs=False + ), + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + "TestNotebookInstanceLifecycleConfig", + _get_notebook_instance_lifecycle_config_template_string( + resource_name="TestNotebookInstanceLifecycleConfig", + include_outputs=False, + ), + ), + ], +) +def test_sagemaker_cloudformation_create(stack_name, resource_name, template): cf = boto3.client("cloudformation", region_name="us-east-1") - - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string(include_outputs=False) cf.create_stack(StackName=stack_name, TemplateBody=template) provisioned_resource = cf.list_stack_resources(StackName=stack_name)[ "StackResourceSummaries" ][0] - provisioned_resource["LogicalResourceId"].should.equal("TestNotebookInstance") + provisioned_resource["LogicalResourceId"].should.equal(resource_name) len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0) @mock_cloudformation @mock_sagemaker -def test_sagemaker_cloudformation_notebook_instance_get_attr(): +@pytest.mark.parametrize( + "stack_name,template,describe_function_name,name_parameter,arn_parameter", + [ + ( + "test_sagemaker_notebook_instance", + _get_notebook_instance_template_string(), + "describe_notebook_instance", + "NotebookInstanceName", + "NotebookInstanceArn", + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + _get_notebook_instance_lifecycle_config_template_string(), + "describe_notebook_instance_lifecycle_config", + "NotebookInstanceLifecycleConfigName", + "NotebookInstanceLifecycleConfigArn", + ), + ], +) +def test_sagemaker_cloudformation_get_attr( + stack_name, template, describe_function_name, name_parameter, arn_parameter +): cf = boto3.client("cloudformation", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1") - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string() + # Create stack and get description for output values cf.create_stack(StackName=stack_name, TemplateBody=template) - stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = { output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - notebook_instance_name = outputs["NotebookInstanceName"] - notebook_instance_arn = outputs["NotebookInstanceArn"] - notebook_instance_description = sm.describe_notebook_instance( - NotebookInstanceName=notebook_instance_name, - ) - notebook_instance_arn.should.equal( - notebook_instance_description["NotebookInstanceArn"] + # Using the describe function, ensure output ARN matches resource ARN + resource_description = getattr(sm, describe_function_name)( + **{name_parameter: outputs["Name"]} ) + outputs["Arn"].should.equal(resource_description[arn_parameter]) @mock_cloudformation @mock_sagemaker -def test_sagemaker_cloudformation_notebook_instance_delete(): +@pytest.mark.parametrize( + "stack_name,template,describe_function_name,name_parameter,arn_parameter,error_message", + [ + ( + "test_sagemaker_notebook_instance", + _get_notebook_instance_template_string(), + "describe_notebook_instance", + "NotebookInstanceName", + "NotebookInstanceArn", + "RecordNotFound", + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + _get_notebook_instance_lifecycle_config_template_string(), + "describe_notebook_instance_lifecycle_config", + "NotebookInstanceLifecycleConfigName", + "NotebookInstanceLifecycleConfigArn", + "Notebook Instance Lifecycle Config does not exist", + ), + ], +) +def test_sagemaker_cloudformation_notebook_instance_delete( + stack_name, + template, + describe_function_name, + name_parameter, + arn_parameter, + error_message, +): cf = boto3.client("cloudformation", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1") - # Create stack with notebook instance and verify existence - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string() + # Create stack and verify existence cf.create_stack(StackName=stack_name, TemplateBody=template) - stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = { output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - notebook_instance = sm.describe_notebook_instance( - NotebookInstanceName=outputs["NotebookInstanceName"], - ) - outputs["NotebookInstanceArn"].should.equal( - notebook_instance["NotebookInstanceArn"] + resource_description = getattr(sm, describe_function_name)( + **{name_parameter: outputs["Name"]} ) + outputs["Arn"].should.equal(resource_description[arn_parameter]) - # Delete the stack and verify notebook instance has also been deleted - # TODO replace exception check with `list_notebook_instances` method when implemented + # Delete the stack and verify resource has also been deleted cf.delete_stack(StackName=stack_name) with pytest.raises(ClientError) as ce: - sm.describe_notebook_instance( - NotebookInstanceName=outputs["NotebookInstanceName"] - ) - ce.value.response["Error"]["Message"].should.contain("RecordNotFound") + getattr(sm, describe_function_name)(**{name_parameter: outputs["Name"]}) + ce.value.response["Error"]["Message"].should.contain(error_message) @mock_cloudformation @@ -133,7 +218,7 @@ def test_sagemaker_cloudformation_notebook_instance_update(): output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - initial_notebook_name = outputs["NotebookInstanceName"] + initial_notebook_name = outputs["Name"] notebook_instance_description = sm.describe_notebook_instance( NotebookInstanceName=initial_notebook_name, ) @@ -146,10 +231,62 @@ def test_sagemaker_cloudformation_notebook_instance_update(): output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - updated_notebook_name = outputs["NotebookInstanceName"] + updated_notebook_name = outputs["Name"] updated_notebook_name.should.equal(initial_notebook_name) notebook_instance_description = sm.describe_notebook_instance( NotebookInstanceName=updated_notebook_name, ) updated_instance_type.should.equal(notebook_instance_description["InstanceType"]) + + +@mock_cloudformation +@mock_sagemaker +def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update(): + cf = boto3.client("cloudformation", region_name="us-east-1") + sm = boto3.client("sagemaker", region_name="us-east-1") + + # Set up template for stack with initial and update instance types + stack_name = "test_sagemaker_notebook_instance_lifecycle_config" + initial_on_create_script = "echo Hello World" + updated_on_create_script = "echo Goodbye World" + initial_template_json = _get_notebook_instance_lifecycle_config_template_string( + on_create=initial_on_create_script + ) + updated_template_json = _get_notebook_instance_lifecycle_config_template_string( + on_create=updated_on_create_script + ) + + # Create stack with initial template and check attributes + cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + initial_config_name = outputs["Name"] + notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=initial_config_name, + ) + len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1) + initial_on_create_script.should.equal( + notebook_lifecycle_config_description["OnCreate"][0]["Content"] + ) + + # Update stack with new instance type and check attributes + cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + updated_config_name = outputs["Name"] + updated_config_name.should.equal(initial_config_name) + + notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=updated_config_name, + ) + len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1) + updated_on_create_script.should.equal( + notebook_lifecycle_config_description["OnCreate"][0]["Content"] + )