diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index cf2cdec7e..10f993e47 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -567,7 +567,7 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): backend.delete_notebook_instance(notebook_instance_name) -class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): +class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( self, region_name, notebook_instance_lifecycle_config_name, on_create, on_start ): @@ -606,6 +606,71 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject): def response_create(self): return {"TrainingJobArn": self.training_job_arn} + @property + def physical_resource_id(self): + return self.notebook_instance_lifecycle_config_arn + + def get_cfn_attribute(self, attribute_name): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html#aws-resource-sagemaker-notebookinstancelifecycleconfig-return-values + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "NotebookInstanceLifecycleConfigName": + return self.notebook_instance_lifecycle_config_name + raise UnformattedGetAttTemplateException() + + @staticmethod + def cloudformation_name_type(): + return None + + @staticmethod + def cloudformation_type(): + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html + return "AWS::SageMaker::NotebookInstanceLifecycleConfig" + + @classmethod + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + + config = sagemaker_backends[ + region_name + ].create_notebook_instance_lifecycle_config( + notebook_instance_lifecycle_config_name=resource_name, + on_create=properties.get("OnCreate"), + on_start=properties.get("OnStart"), + ) + return config + + @classmethod + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name, + ): + # Operations keep same resource name so delete old and create new to mimic update + cls.delete_from_cloudformation_json( + original_resource.notebook_instance_lifecycle_config_arn, + cloudformation_json, + region_name, + ) + new_resource = cls.create_from_cloudformation_json( + original_resource.notebook_instance_lifecycle_config_name, + cloudformation_json, + region_name, + ) + return new_resource + + @classmethod + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + # Get actual name because resource_name actually provides the ARN + # since the Physical Resource ID is the ARN despite SageMaker + # using the name for most of its operations. + config_name = resource_name.split("/")[-1] + + backend = sagemaker_backends[region_name] + backend.delete_notebook_instance_lifecycle_config(config_name) + class SageMakerModelBackend(BaseBackend): def __init__(self, region_name=None): diff --git a/tests/test_sagemaker/test_sagemaker_cloudformation.py b/tests/test_sagemaker/test_sagemaker_cloudformation.py index e90a6f151..60c463f58 100644 --- a/tests/test_sagemaker/test_sagemaker_cloudformation.py +++ b/tests/test_sagemaker/test_sagemaker_cloudformation.py @@ -10,6 +10,7 @@ from moto.sts.models import ACCOUNT_ID def _get_notebook_instance_template_string( + resource_name="TestNotebook", instance_type="ml.c4.xlarge", role_arn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID), include_outputs=True, @@ -17,7 +18,7 @@ def _get_notebook_instance_template_string( template = { "AWSTemplateFormatVersion": "2010-09-09", "Resources": { - "TestNotebookInstance": { + resource_name: { "Type": "AWS::SageMaker::NotebookInstance", "Properties": {"InstanceType": instance_type, "RoleArn": role_arn}, }, @@ -25,88 +26,172 @@ def _get_notebook_instance_template_string( } if include_outputs: template["Outputs"] = { - "NotebookInstanceArn": {"Value": {"Ref": "TestNotebookInstance"}}, - "NotebookInstanceName": { + "Arn": {"Value": {"Ref": resource_name}}, + "Name": {"Value": {"Fn::GetAtt": [resource_name, "NotebookInstanceName"]}}, + } + return json.dumps(template) + + +def _get_notebook_instance_lifecycle_config_template_string( + resource_name="TestConfig", on_create=None, on_start=None, include_outputs=True, +): + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + resource_name: { + "Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig", + "Properties": {}, + }, + }, + } + if on_create is not None: + template["Resources"][resource_name]["Properties"]["OnCreate"] = [ + {"Content": on_create} + ] + if on_start is not None: + template["Resources"][resource_name]["Properties"]["OnStart"] = [ + {"Content": on_start} + ] + if include_outputs: + template["Outputs"] = { + "Arn": {"Value": {"Ref": resource_name}}, + "Name": { "Value": { - "Fn::GetAtt": ["TestNotebookInstance", "NotebookInstanceName"] - }, + "Fn::GetAtt": [ + resource_name, + "NotebookInstanceLifecycleConfigName", + ] + } }, } return json.dumps(template) @mock_cloudformation -def test_sagemaker_cloudformation_create_notebook_instance(): +@pytest.mark.parametrize( + "stack_name,resource_name,template", + [ + ( + "test_sagemaker_notebook_instance", + "TestNotebookInstance", + _get_notebook_instance_template_string( + resource_name="TestNotebookInstance", include_outputs=False + ), + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + "TestNotebookInstanceLifecycleConfig", + _get_notebook_instance_lifecycle_config_template_string( + resource_name="TestNotebookInstanceLifecycleConfig", + include_outputs=False, + ), + ), + ], +) +def test_sagemaker_cloudformation_create(stack_name, resource_name, template): cf = boto3.client("cloudformation", region_name="us-east-1") - - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string(include_outputs=False) cf.create_stack(StackName=stack_name, TemplateBody=template) provisioned_resource = cf.list_stack_resources(StackName=stack_name)[ "StackResourceSummaries" ][0] - provisioned_resource["LogicalResourceId"].should.equal("TestNotebookInstance") + provisioned_resource["LogicalResourceId"].should.equal(resource_name) len(provisioned_resource["PhysicalResourceId"]).should.be.greater_than(0) @mock_cloudformation @mock_sagemaker -def test_sagemaker_cloudformation_notebook_instance_get_attr(): +@pytest.mark.parametrize( + "stack_name,template,describe_function_name,name_parameter,arn_parameter", + [ + ( + "test_sagemaker_notebook_instance", + _get_notebook_instance_template_string(), + "describe_notebook_instance", + "NotebookInstanceName", + "NotebookInstanceArn", + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + _get_notebook_instance_lifecycle_config_template_string(), + "describe_notebook_instance_lifecycle_config", + "NotebookInstanceLifecycleConfigName", + "NotebookInstanceLifecycleConfigArn", + ), + ], +) +def test_sagemaker_cloudformation_get_attr( + stack_name, template, describe_function_name, name_parameter, arn_parameter +): cf = boto3.client("cloudformation", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1") - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string() + # Create stack and get description for output values cf.create_stack(StackName=stack_name, TemplateBody=template) - stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = { output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - notebook_instance_name = outputs["NotebookInstanceName"] - notebook_instance_arn = outputs["NotebookInstanceArn"] - notebook_instance_description = sm.describe_notebook_instance( - NotebookInstanceName=notebook_instance_name, - ) - notebook_instance_arn.should.equal( - notebook_instance_description["NotebookInstanceArn"] + # Using the describe function, ensure output ARN matches resource ARN + resource_description = getattr(sm, describe_function_name)( + **{name_parameter: outputs["Name"]} ) + outputs["Arn"].should.equal(resource_description[arn_parameter]) @mock_cloudformation @mock_sagemaker -def test_sagemaker_cloudformation_notebook_instance_delete(): +@pytest.mark.parametrize( + "stack_name,template,describe_function_name,name_parameter,arn_parameter,error_message", + [ + ( + "test_sagemaker_notebook_instance", + _get_notebook_instance_template_string(), + "describe_notebook_instance", + "NotebookInstanceName", + "NotebookInstanceArn", + "RecordNotFound", + ), + ( + "test_sagemaker_notebook_instance_lifecycle_config", + _get_notebook_instance_lifecycle_config_template_string(), + "describe_notebook_instance_lifecycle_config", + "NotebookInstanceLifecycleConfigName", + "NotebookInstanceLifecycleConfigArn", + "Notebook Instance Lifecycle Config does not exist", + ), + ], +) +def test_sagemaker_cloudformation_notebook_instance_delete( + stack_name, + template, + describe_function_name, + name_parameter, + arn_parameter, + error_message, +): cf = boto3.client("cloudformation", region_name="us-east-1") sm = boto3.client("sagemaker", region_name="us-east-1") - # Create stack with notebook instance and verify existence - stack_name = "test_sagemaker_notebook_instance" - template = _get_notebook_instance_template_string() + # Create stack and verify existence cf.create_stack(StackName=stack_name, TemplateBody=template) - stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] outputs = { output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - notebook_instance = sm.describe_notebook_instance( - NotebookInstanceName=outputs["NotebookInstanceName"], - ) - outputs["NotebookInstanceArn"].should.equal( - notebook_instance["NotebookInstanceArn"] + resource_description = getattr(sm, describe_function_name)( + **{name_parameter: outputs["Name"]} ) + outputs["Arn"].should.equal(resource_description[arn_parameter]) - # Delete the stack and verify notebook instance has also been deleted - # TODO replace exception check with `list_notebook_instances` method when implemented + # Delete the stack and verify resource has also been deleted cf.delete_stack(StackName=stack_name) with pytest.raises(ClientError) as ce: - sm.describe_notebook_instance( - NotebookInstanceName=outputs["NotebookInstanceName"] - ) - ce.value.response["Error"]["Message"].should.contain("RecordNotFound") + getattr(sm, describe_function_name)(**{name_parameter: outputs["Name"]}) + ce.value.response["Error"]["Message"].should.contain(error_message) @mock_cloudformation @@ -133,7 +218,7 @@ def test_sagemaker_cloudformation_notebook_instance_update(): output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - initial_notebook_name = outputs["NotebookInstanceName"] + initial_notebook_name = outputs["Name"] notebook_instance_description = sm.describe_notebook_instance( NotebookInstanceName=initial_notebook_name, ) @@ -146,10 +231,62 @@ def test_sagemaker_cloudformation_notebook_instance_update(): output["OutputKey"]: output["OutputValue"] for output in stack_description["Outputs"] } - updated_notebook_name = outputs["NotebookInstanceName"] + updated_notebook_name = outputs["Name"] updated_notebook_name.should.equal(initial_notebook_name) notebook_instance_description = sm.describe_notebook_instance( NotebookInstanceName=updated_notebook_name, ) updated_instance_type.should.equal(notebook_instance_description["InstanceType"]) + + +@mock_cloudformation +@mock_sagemaker +def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update(): + cf = boto3.client("cloudformation", region_name="us-east-1") + sm = boto3.client("sagemaker", region_name="us-east-1") + + # Set up template for stack with initial and update instance types + stack_name = "test_sagemaker_notebook_instance_lifecycle_config" + initial_on_create_script = "echo Hello World" + updated_on_create_script = "echo Goodbye World" + initial_template_json = _get_notebook_instance_lifecycle_config_template_string( + on_create=initial_on_create_script + ) + updated_template_json = _get_notebook_instance_lifecycle_config_template_string( + on_create=updated_on_create_script + ) + + # Create stack with initial template and check attributes + cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + initial_config_name = outputs["Name"] + notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=initial_config_name, + ) + len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1) + initial_on_create_script.should.equal( + notebook_lifecycle_config_description["OnCreate"][0]["Content"] + ) + + # Update stack with new instance type and check attributes + cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json) + stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0] + outputs = { + output["OutputKey"]: output["OutputValue"] + for output in stack_description["Outputs"] + } + updated_config_name = outputs["Name"] + updated_config_name.should.equal(initial_config_name) + + notebook_lifecycle_config_description = sm.describe_notebook_instance_lifecycle_config( + NotebookInstanceLifecycleConfigName=updated_config_name, + ) + len(notebook_lifecycle_config_description["OnCreate"]).should.equal(1) + updated_on_create_script.should.equal( + notebook_lifecycle_config_description["OnCreate"][0]["Content"] + )