Add CloudFormation support for SageMaker Endpoint Configs and Endpoints (#3863)
* Create SageMaker EndpointConfig with CloudFormation Implement attributes for SM Endpoint Configs with CloudFormation Delete SM Endpoint Configs with CloudFormation Update SM Endpoint Configs with CloudFormation * Fix typos in SM CF Model update test and refactor helper function for CF stack outputs * Fixup weird commas in SM CF Test Configs from using black * Create SageMaker Endpoints with CloudFormation * Fix typos in SM CF update tests
This commit is contained in:
		
							parent
							
								
									f6dda54a6c
								
							
						
					
					
						commit
						9b3e932822
					
				@ -140,7 +140,7 @@ class FakeTrainingJob(BaseObject):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FakeEndpoint(BaseObject):
 | 
					class FakeEndpoint(BaseObject, CloudFormationModel):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        region_name,
 | 
					        region_name,
 | 
				
			||||||
@ -184,8 +184,70 @@ class FakeEndpoint(BaseObject):
 | 
				
			|||||||
            + endpoint_name
 | 
					            + endpoint_name
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def physical_resource_id(self):
 | 
				
			||||||
 | 
					        return self.endpoint_arn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FakeEndpointConfig(BaseObject):
 | 
					    def get_cfn_attribute(self, attribute_name):
 | 
				
			||||||
 | 
					        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values
 | 
				
			||||||
 | 
					        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if attribute_name == "EndpointName":
 | 
				
			||||||
 | 
					            return self.endpoint_name
 | 
				
			||||||
 | 
					        raise UnformattedGetAttTemplateException()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def cloudformation_name_type():
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def cloudformation_type():
 | 
				
			||||||
 | 
					        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html
 | 
				
			||||||
 | 
					        return "AWS::SageMaker::Endpoint"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def create_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, resource_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        sagemaker_backend = sagemaker_backends[region_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Get required properties from provided CloudFormation template
 | 
				
			||||||
 | 
					        properties = cloudformation_json["Properties"]
 | 
				
			||||||
 | 
					        endpoint_config_name = properties["EndpointConfigName"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        endpoint = sagemaker_backend.create_endpoint(
 | 
				
			||||||
 | 
					            endpoint_name=resource_name,
 | 
				
			||||||
 | 
					            endpoint_config_name=endpoint_config_name,
 | 
				
			||||||
 | 
					            tags=properties.get("Tags", []),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return endpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def update_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, original_resource, new_resource_name, cloudformation_json, region_name,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # Changes to the Endpoint will not change resource name
 | 
				
			||||||
 | 
					        cls.delete_from_cloudformation_json(
 | 
				
			||||||
 | 
					            original_resource.endpoint_arn, cloudformation_json, region_name
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        new_resource = cls.create_from_cloudformation_json(
 | 
				
			||||||
 | 
					            original_resource.endpoint_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return new_resource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def delete_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, resource_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # Get actual name because resource_name actually provides the ARN
 | 
				
			||||||
 | 
					        # since the Physical Resource ID is the ARN despite SageMaker
 | 
				
			||||||
 | 
					        # using the name for most of its operations.
 | 
				
			||||||
 | 
					        endpoint_name = resource_name.split("/")[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sagemaker_backends[region_name].delete_endpoint(endpoint_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FakeEndpointConfig(BaseObject, CloudFormationModel):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        region_name,
 | 
					        region_name,
 | 
				
			||||||
@ -308,6 +370,70 @@ class FakeEndpointConfig(BaseObject):
 | 
				
			|||||||
            + model_name
 | 
					            + model_name
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def physical_resource_id(self):
 | 
				
			||||||
 | 
					        return self.endpoint_config_arn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_cfn_attribute(self, attribute_name):
 | 
				
			||||||
 | 
					        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values
 | 
				
			||||||
 | 
					        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if attribute_name == "EndpointConfigName":
 | 
				
			||||||
 | 
					            return self.endpoint_config_name
 | 
				
			||||||
 | 
					        raise UnformattedGetAttTemplateException()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def cloudformation_name_type():
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def cloudformation_type():
 | 
				
			||||||
 | 
					        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html
 | 
				
			||||||
 | 
					        return "AWS::SageMaker::EndpointConfig"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def create_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, resource_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        sagemaker_backend = sagemaker_backends[region_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Get required properties from provided CloudFormation template
 | 
				
			||||||
 | 
					        properties = cloudformation_json["Properties"]
 | 
				
			||||||
 | 
					        production_variants = properties["ProductionVariants"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        endpoint_config = sagemaker_backend.create_endpoint_config(
 | 
				
			||||||
 | 
					            endpoint_config_name=resource_name,
 | 
				
			||||||
 | 
					            production_variants=production_variants,
 | 
				
			||||||
 | 
					            data_capture_config=properties.get("DataCaptureConfig", {}),
 | 
				
			||||||
 | 
					            kms_key_id=properties.get("KmsKeyId"),
 | 
				
			||||||
 | 
					            tags=properties.get("Tags", []),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return endpoint_config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def update_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, original_resource, new_resource_name, cloudformation_json, region_name,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # Most changes to the endpoint config will change resource name for EndpointConfigs
 | 
				
			||||||
 | 
					        cls.delete_from_cloudformation_json(
 | 
				
			||||||
 | 
					            original_resource.endpoint_config_arn, cloudformation_json, region_name
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        new_resource = cls.create_from_cloudformation_json(
 | 
				
			||||||
 | 
					            new_resource_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return new_resource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def delete_from_cloudformation_json(
 | 
				
			||||||
 | 
					        cls, resource_name, cloudformation_json, region_name
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # Get actual name because resource_name actually provides the ARN
 | 
				
			||||||
 | 
					        # since the Physical Resource ID is the ARN despite SageMaker
 | 
				
			||||||
 | 
					        # using the name for most of its operations.
 | 
				
			||||||
 | 
					        endpoint_config_name = resource_name.split("/")[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sagemaker_backends[region_name].delete_endpoint_config(endpoint_config_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Model(BaseObject, CloudFormationModel):
 | 
					class Model(BaseObject, CloudFormationModel):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
 | 
				
			|||||||
@ -42,6 +42,14 @@ class TestConfig:
 | 
				
			|||||||
    def get_cloudformation_template(self, include_outputs=True, **kwargs):
 | 
					    def get_cloudformation_template(self, include_outputs=True, **kwargs):
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_setup_procedure(self, sagemaker_client):
 | 
				
			||||||
 | 
					        """Provides a method to set up resources with a SageMaker client.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Note: This procedure should be called while within a `mock_sagemaker`
 | 
				
			||||||
 | 
					        context so that no actual resources are created with the sagemaker_client.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NotebookInstanceTestConfig(TestConfig):
 | 
					class NotebookInstanceTestConfig(TestConfig):
 | 
				
			||||||
    """Test configuration for SageMaker Notebook Instances."""
 | 
					    """Test configuration for SageMaker Notebook Instances."""
 | 
				
			||||||
@ -186,6 +194,131 @@ class ModelTestConfig(TestConfig):
 | 
				
			|||||||
        if include_outputs:
 | 
					        if include_outputs:
 | 
				
			||||||
            template["Outputs"] = {
 | 
					            template["Outputs"] = {
 | 
				
			||||||
                "Arn": {"Value": {"Ref": self.resource_name}},
 | 
					                "Arn": {"Value": {"Ref": self.resource_name}},
 | 
				
			||||||
                "Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}},
 | 
					                "Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"]}},
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        return json.dumps(template)
 | 
					        return json.dumps(template)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EndpointConfigTestConfig(TestConfig):
 | 
				
			||||||
 | 
					    """Test configuration for SageMaker Endpoint Configs."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def resource_name(self):
 | 
				
			||||||
 | 
					        return "TestEndpointConfig"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def describe_function_name(self):
 | 
				
			||||||
 | 
					        return "describe_endpoint_config"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def name_parameter(self):
 | 
				
			||||||
 | 
					        return "EndpointConfigName"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def arn_parameter(self):
 | 
				
			||||||
 | 
					        return "EndpointConfigArn"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_cloudformation_template(self, include_outputs=True, **kwargs):
 | 
				
			||||||
 | 
					        num_production_variants = kwargs.get("num_production_variants", 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        production_variants = [
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "InitialInstanceCount": 1,
 | 
				
			||||||
 | 
					                "InitialVariantWeight": 1,
 | 
				
			||||||
 | 
					                "InstanceType": "ml.c4.xlarge",
 | 
				
			||||||
 | 
					                "ModelName": self.resource_name,
 | 
				
			||||||
 | 
					                "VariantName": "variant-name-{}".format(i),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            for i in range(num_production_variants)
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        template = {
 | 
				
			||||||
 | 
					            "AWSTemplateFormatVersion": "2010-09-09",
 | 
				
			||||||
 | 
					            "Resources": {
 | 
				
			||||||
 | 
					                self.resource_name: {
 | 
				
			||||||
 | 
					                    "Type": "AWS::SageMaker::EndpointConfig",
 | 
				
			||||||
 | 
					                    "Properties": {"ProductionVariants": production_variants},
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if include_outputs:
 | 
				
			||||||
 | 
					            template["Outputs"] = {
 | 
				
			||||||
 | 
					                "Arn": {"Value": {"Ref": self.resource_name}},
 | 
				
			||||||
 | 
					                "Name": {
 | 
				
			||||||
 | 
					                    "Value": {"Fn::GetAtt": [self.resource_name, "EndpointConfigName"]}
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        return json.dumps(template)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_setup_procedure(self, sagemaker_client):
 | 
				
			||||||
 | 
					        """Adds Model that can be referenced in the CloudFormation template."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sagemaker_client.create_model(
 | 
				
			||||||
 | 
					            ModelName=self.resource_name,
 | 
				
			||||||
 | 
					            ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
 | 
				
			||||||
 | 
					            PrimaryContainer={
 | 
				
			||||||
 | 
					                "Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EndpointTestConfig(TestConfig):
 | 
				
			||||||
 | 
					    """Test configuration for SageMaker Endpoints."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def resource_name(self):
 | 
				
			||||||
 | 
					        return "TestEndpoint"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def describe_function_name(self):
 | 
				
			||||||
 | 
					        return "describe_endpoint"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def name_parameter(self):
 | 
				
			||||||
 | 
					        return "EndpointName"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def arn_parameter(self):
 | 
				
			||||||
 | 
					        return "EndpointArn"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_cloudformation_template(self, include_outputs=True, **kwargs):
 | 
				
			||||||
 | 
					        endpoint_config_name = kwargs.get("endpoint_config_name", self.resource_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        template = {
 | 
				
			||||||
 | 
					            "AWSTemplateFormatVersion": "2010-09-09",
 | 
				
			||||||
 | 
					            "Resources": {
 | 
				
			||||||
 | 
					                self.resource_name: {
 | 
				
			||||||
 | 
					                    "Type": "AWS::SageMaker::Endpoint",
 | 
				
			||||||
 | 
					                    "Properties": {"EndpointConfigName": endpoint_config_name},
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if include_outputs:
 | 
				
			||||||
 | 
					            template["Outputs"] = {
 | 
				
			||||||
 | 
					                "Arn": {"Value": {"Ref": self.resource_name}},
 | 
				
			||||||
 | 
					                "Name": {"Value": {"Fn::GetAtt": [self.resource_name, "EndpointName"]}},
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        return json.dumps(template)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def run_setup_procedure(self, sagemaker_client):
 | 
				
			||||||
 | 
					        """Adds Model and Endpoint Config that can be referenced in the CloudFormation template."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sagemaker_client.create_model(
 | 
				
			||||||
 | 
					            ModelName=self.resource_name,
 | 
				
			||||||
 | 
					            ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
 | 
				
			||||||
 | 
					            PrimaryContainer={
 | 
				
			||||||
 | 
					                "Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        sagemaker_client.create_endpoint_config(
 | 
				
			||||||
 | 
					            EndpointConfigName=self.resource_name,
 | 
				
			||||||
 | 
					            ProductionVariants=[
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "InitialInstanceCount": 1,
 | 
				
			||||||
 | 
					                    "InitialVariantWeight": 1,
 | 
				
			||||||
 | 
					                    "InstanceType": "ml.c4.xlarge",
 | 
				
			||||||
 | 
					                    "ModelName": self.resource_name,
 | 
				
			||||||
 | 
					                    "VariantName": "variant-name-1",
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -5,25 +5,44 @@ import sure  # noqa
 | 
				
			|||||||
from botocore.exceptions import ClientError
 | 
					from botocore.exceptions import ClientError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from moto import mock_cloudformation, mock_sagemaker
 | 
					from moto import mock_cloudformation, mock_sagemaker
 | 
				
			||||||
 | 
					from moto.sts.models import ACCOUNT_ID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .cloudformation_test_configs import (
 | 
					from .cloudformation_test_configs import (
 | 
				
			||||||
    NotebookInstanceTestConfig,
 | 
					    NotebookInstanceTestConfig,
 | 
				
			||||||
    NotebookInstanceLifecycleConfigTestConfig,
 | 
					    NotebookInstanceLifecycleConfigTestConfig,
 | 
				
			||||||
    ModelTestConfig,
 | 
					    ModelTestConfig,
 | 
				
			||||||
 | 
					    EndpointConfigTestConfig,
 | 
				
			||||||
 | 
					    EndpointTestConfig,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_stack_outputs(cf_client, stack_name):
 | 
				
			||||||
 | 
					    """Returns the outputs for the first entry in describe_stacks."""
 | 
				
			||||||
 | 
					    stack_description = cf_client.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        output["OutputKey"]: output["OutputValue"]
 | 
				
			||||||
 | 
					        for output in stack_description["Outputs"]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@mock_cloudformation
 | 
					@mock_cloudformation
 | 
				
			||||||
 | 
					@mock_sagemaker
 | 
				
			||||||
@pytest.mark.parametrize(
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
    "test_config",
 | 
					    "test_config",
 | 
				
			||||||
    [
 | 
					    [
 | 
				
			||||||
        NotebookInstanceTestConfig(),
 | 
					        NotebookInstanceTestConfig(),
 | 
				
			||||||
        NotebookInstanceLifecycleConfigTestConfig(),
 | 
					        NotebookInstanceLifecycleConfigTestConfig(),
 | 
				
			||||||
        ModelTestConfig(),
 | 
					        ModelTestConfig(),
 | 
				
			||||||
 | 
					        EndpointConfigTestConfig(),
 | 
				
			||||||
 | 
					        EndpointTestConfig(),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_sagemaker_cloudformation_create(test_config):
 | 
					def test_sagemaker_cloudformation_create(test_config):
 | 
				
			||||||
    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
					    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
				
			||||||
 | 
					    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Utilize test configuration to set-up any mock SageMaker resources
 | 
				
			||||||
 | 
					    test_config.run_setup_procedure(sm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    cf.create_stack(
 | 
					    cf.create_stack(
 | 
				
			||||||
@ -46,22 +65,23 @@ def test_sagemaker_cloudformation_create(test_config):
 | 
				
			|||||||
        NotebookInstanceTestConfig(),
 | 
					        NotebookInstanceTestConfig(),
 | 
				
			||||||
        NotebookInstanceLifecycleConfigTestConfig(),
 | 
					        NotebookInstanceLifecycleConfigTestConfig(),
 | 
				
			||||||
        ModelTestConfig(),
 | 
					        ModelTestConfig(),
 | 
				
			||||||
 | 
					        EndpointConfigTestConfig(),
 | 
				
			||||||
 | 
					        EndpointTestConfig(),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_sagemaker_cloudformation_get_attr(test_config):
 | 
					def test_sagemaker_cloudformation_get_attr(test_config):
 | 
				
			||||||
    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
					    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
				
			||||||
    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
					    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Utilize test configuration to set-up any mock SageMaker resources
 | 
				
			||||||
 | 
					    test_config.run_setup_procedure(sm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Create stack and get description for output values
 | 
					    # Create stack and get description for output values
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    cf.create_stack(
 | 
					    cf.create_stack(
 | 
				
			||||||
        StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
 | 
					        StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Using the describe function, ensure output ARN matches resource ARN
 | 
					    # Using the describe function, ensure output ARN matches resource ARN
 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
@ -81,22 +101,24 @@ def test_sagemaker_cloudformation_get_attr(test_config):
 | 
				
			|||||||
            "Notebook Instance Lifecycle Config does not exist",
 | 
					            "Notebook Instance Lifecycle Config does not exist",
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
        (ModelTestConfig(), "Could not find model"),
 | 
					        (ModelTestConfig(), "Could not find model"),
 | 
				
			||||||
 | 
					        (EndpointConfigTestConfig(), "Could not find endpoint configuration"),
 | 
				
			||||||
 | 
					        (EndpointTestConfig(), "Could not find endpoint"),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
 | 
					def test_sagemaker_cloudformation_notebook_instance_delete(test_config, error_message):
 | 
				
			||||||
    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
					    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
				
			||||||
    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
					    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Utilize test configuration to set-up any mock SageMaker resources
 | 
				
			||||||
 | 
					    test_config.run_setup_procedure(sm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Create stack and verify existence
 | 
					    # Create stack and verify existence
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    cf.create_stack(
 | 
					    cf.create_stack(
 | 
				
			||||||
        StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
 | 
					        StackName=stack_name, TemplateBody=test_config.get_cloudformation_template()
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
        **{test_config.name_parameter: outputs["Name"]}
 | 
					        **{test_config.name_parameter: outputs["Name"]}
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@ -119,7 +141,7 @@ def test_sagemaker_cloudformation_notebook_instance_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    test_config = NotebookInstanceTestConfig()
 | 
					    test_config = NotebookInstanceTestConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Set up template for stack with initial and update instance types
 | 
					    # Set up template for stack with two different instance types
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    initial_instance_type = "ml.c4.xlarge"
 | 
					    initial_instance_type = "ml.c4.xlarge"
 | 
				
			||||||
    updated_instance_type = "ml.c4.4xlarge"
 | 
					    updated_instance_type = "ml.c4.4xlarge"
 | 
				
			||||||
@ -132,24 +154,18 @@ def test_sagemaker_cloudformation_notebook_instance_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Create stack with initial template and check attributes
 | 
					    # Create stack with initial template and check attributes
 | 
				
			||||||
    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
					    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    initial_notebook_name = outputs["Name"]
 | 
					    initial_notebook_name = outputs["Name"]
 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
        **{test_config.name_parameter: initial_notebook_name}
 | 
					        **{test_config.name_parameter: initial_notebook_name}
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    initial_instance_type.should.equal(resource_description["InstanceType"])
 | 
					    initial_instance_type.should.equal(resource_description["InstanceType"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Update stack with new instance type and check attributes
 | 
					    # Update stack and check attributes
 | 
				
			||||||
    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
					    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    updated_notebook_name = outputs["Name"]
 | 
					    updated_notebook_name = outputs["Name"]
 | 
				
			||||||
    updated_notebook_name.should.equal(initial_notebook_name)
 | 
					    updated_notebook_name.should.equal(initial_notebook_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -167,7 +183,7 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    test_config = NotebookInstanceLifecycleConfigTestConfig()
 | 
					    test_config = NotebookInstanceLifecycleConfigTestConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Set up template for stack with initial and update instance types
 | 
					    # Set up template for stack with two different OnCreate scripts
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    initial_on_create_script = "echo Hello World"
 | 
					    initial_on_create_script = "echo Hello World"
 | 
				
			||||||
    updated_on_create_script = "echo Goodbye World"
 | 
					    updated_on_create_script = "echo Goodbye World"
 | 
				
			||||||
@ -180,11 +196,8 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Create stack with initial template and check attributes
 | 
					    # Create stack with initial template and check attributes
 | 
				
			||||||
    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
					    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    initial_config_name = outputs["Name"]
 | 
					    initial_config_name = outputs["Name"]
 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
        **{test_config.name_parameter: initial_config_name}
 | 
					        **{test_config.name_parameter: initial_config_name}
 | 
				
			||||||
@ -194,13 +207,10 @@ def test_sagemaker_cloudformation_notebook_instance_lifecycle_config_update():
 | 
				
			|||||||
        resource_description["OnCreate"][0]["Content"]
 | 
					        resource_description["OnCreate"][0]["Content"]
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Update stack with new instance type and check attributes
 | 
					    # Update stack and check attributes
 | 
				
			||||||
    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
					    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    updated_config_name = outputs["Name"]
 | 
					    updated_config_name = outputs["Name"]
 | 
				
			||||||
    updated_config_name.should.equal(initial_config_name)
 | 
					    updated_config_name.should.equal(initial_config_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -221,7 +231,7 @@ def test_sagemaker_cloudformation_model_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    test_config = ModelTestConfig()
 | 
					    test_config = ModelTestConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Set up template for stack with initial and update instance types
 | 
					    # Set up template for stack with two different image versions
 | 
				
			||||||
    stack_name = "{}_stack".format(test_config.resource_name)
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
    image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
 | 
					    image = "404615174143.dkr.ecr.us-east-2.amazonaws.com/kmeans:{}"
 | 
				
			||||||
    initial_image_version = 1
 | 
					    initial_image_version = 1
 | 
				
			||||||
@ -235,32 +245,159 @@ def test_sagemaker_cloudformation_model_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Create stack with initial template and check attributes
 | 
					    # Create stack with initial template and check attributes
 | 
				
			||||||
    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
					    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					    initial_model_name = outputs["Name"]
 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    inital_model_name = outputs["Name"]
 | 
					 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
        **{test_config.name_parameter: inital_model_name}
 | 
					        **{test_config.name_parameter: initial_model_name}
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    resource_description["PrimaryContainer"]["Image"].should.equal(
 | 
					    resource_description["PrimaryContainer"]["Image"].should.equal(
 | 
				
			||||||
        image.format(initial_image_version)
 | 
					        image.format(initial_image_version)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Update stack with new instance type and check attributes
 | 
					    # Update stack and check attributes
 | 
				
			||||||
    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
					    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
				
			||||||
    stack_description = cf.describe_stacks(StackName=stack_name)["Stacks"][0]
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
    outputs = {
 | 
					
 | 
				
			||||||
        output["OutputKey"]: output["OutputValue"]
 | 
					    updated_model_name = outputs["Name"]
 | 
				
			||||||
        for output in stack_description["Outputs"]
 | 
					    updated_model_name.should_not.equal(initial_model_name)
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    updated_notebook_name = outputs["Name"]
 | 
					 | 
				
			||||||
    updated_notebook_name.should_not.equal(inital_model_name)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
        **{test_config.name_parameter: updated_notebook_name}
 | 
					        **{test_config.name_parameter: updated_model_name}
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    resource_description["PrimaryContainer"]["Image"].should.equal(
 | 
					    resource_description["PrimaryContainer"]["Image"].should.equal(
 | 
				
			||||||
        image.format(updated_image_version)
 | 
					        image.format(updated_image_version)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_cloudformation
 | 
				
			||||||
 | 
					@mock_sagemaker
 | 
				
			||||||
 | 
					def test_sagemaker_cloudformation_endpoint_config_update():
 | 
				
			||||||
 | 
					    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
				
			||||||
 | 
					    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    test_config = EndpointConfigTestConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Utilize test configuration to set-up any mock SageMaker resources
 | 
				
			||||||
 | 
					    test_config.run_setup_procedure(sm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Set up template for stack with two different production variant counts
 | 
				
			||||||
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
 | 
					    initial_num_production_variants = 1
 | 
				
			||||||
 | 
					    updated_num_production_variants = 2
 | 
				
			||||||
 | 
					    initial_template_json = test_config.get_cloudformation_template(
 | 
				
			||||||
 | 
					        num_production_variants=initial_num_production_variants
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    updated_template_json = test_config.get_cloudformation_template(
 | 
				
			||||||
 | 
					        num_production_variants=updated_num_production_variants
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create stack with initial template and check attributes
 | 
				
			||||||
 | 
					    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
				
			||||||
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    initial_endpoint_config_name = outputs["Name"]
 | 
				
			||||||
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
 | 
					        **{test_config.name_parameter: initial_endpoint_config_name}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    len(resource_description["ProductionVariants"]).should.equal(
 | 
				
			||||||
 | 
					        initial_num_production_variants
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Update stack and check attributes
 | 
				
			||||||
 | 
					    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
				
			||||||
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    updated_endpoint_config_name = outputs["Name"]
 | 
				
			||||||
 | 
					    updated_endpoint_config_name.should_not.equal(initial_endpoint_config_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
 | 
					        **{test_config.name_parameter: updated_endpoint_config_name}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    len(resource_description["ProductionVariants"]).should.equal(
 | 
				
			||||||
 | 
					        updated_num_production_variants
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_cloudformation
 | 
				
			||||||
 | 
					@mock_sagemaker
 | 
				
			||||||
 | 
					def test_sagemaker_cloudformation_endpoint_update():
 | 
				
			||||||
 | 
					    cf = boto3.client("cloudformation", region_name="us-east-1")
 | 
				
			||||||
 | 
					    sm = boto3.client("sagemaker", region_name="us-east-1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    test_config = EndpointTestConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Set up template for stack with two different endpoint config names
 | 
				
			||||||
 | 
					    stack_name = "{}_stack".format(test_config.resource_name)
 | 
				
			||||||
 | 
					    initial_endpoint_config_name = test_config.resource_name
 | 
				
			||||||
 | 
					    updated_endpoint_config_name = "updated-endpoint-config-name"
 | 
				
			||||||
 | 
					    initial_template_json = test_config.get_cloudformation_template(
 | 
				
			||||||
 | 
					        endpoint_config_name=initial_endpoint_config_name
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    updated_template_json = test_config.get_cloudformation_template(
 | 
				
			||||||
 | 
					        endpoint_config_name=updated_endpoint_config_name
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create SM resources and stack with initial template and check attributes
 | 
				
			||||||
 | 
					    sm.create_model(
 | 
				
			||||||
 | 
					        ModelName=initial_endpoint_config_name,
 | 
				
			||||||
 | 
					        ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
 | 
				
			||||||
 | 
					        PrimaryContainer={
 | 
				
			||||||
 | 
					            "Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    sm.create_endpoint_config(
 | 
				
			||||||
 | 
					        EndpointConfigName=initial_endpoint_config_name,
 | 
				
			||||||
 | 
					        ProductionVariants=[
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "InitialInstanceCount": 1,
 | 
				
			||||||
 | 
					                "InitialVariantWeight": 1,
 | 
				
			||||||
 | 
					                "InstanceType": "ml.c4.xlarge",
 | 
				
			||||||
 | 
					                "ModelName": initial_endpoint_config_name,
 | 
				
			||||||
 | 
					                "VariantName": "variant-name-1",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    cf.create_stack(StackName=stack_name, TemplateBody=initial_template_json)
 | 
				
			||||||
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    initial_endpoint_name = outputs["Name"]
 | 
				
			||||||
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
 | 
					        **{test_config.name_parameter: initial_endpoint_name}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    resource_description["EndpointConfigName"].should.match(
 | 
				
			||||||
 | 
					        initial_endpoint_config_name
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create additional SM resources and update stack
 | 
				
			||||||
 | 
					    sm.create_model(
 | 
				
			||||||
 | 
					        ModelName=updated_endpoint_config_name,
 | 
				
			||||||
 | 
					        ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
 | 
				
			||||||
 | 
					        PrimaryContainer={
 | 
				
			||||||
 | 
					            "Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    sm.create_endpoint_config(
 | 
				
			||||||
 | 
					        EndpointConfigName=updated_endpoint_config_name,
 | 
				
			||||||
 | 
					        ProductionVariants=[
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "InitialInstanceCount": 1,
 | 
				
			||||||
 | 
					                "InitialVariantWeight": 1,
 | 
				
			||||||
 | 
					                "InstanceType": "ml.c4.xlarge",
 | 
				
			||||||
 | 
					                "ModelName": updated_endpoint_config_name,
 | 
				
			||||||
 | 
					                "VariantName": "variant-name-1",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    cf.update_stack(StackName=stack_name, TemplateBody=updated_template_json)
 | 
				
			||||||
 | 
					    outputs = _get_stack_outputs(cf, stack_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    updated_endpoint_name = outputs["Name"]
 | 
				
			||||||
 | 
					    updated_endpoint_name.should.equal(initial_endpoint_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    resource_description = getattr(sm, test_config.describe_function_name)(
 | 
				
			||||||
 | 
					        **{test_config.name_parameter: updated_endpoint_name}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    resource_description["EndpointConfigName"].should.match(
 | 
				
			||||||
 | 
					        updated_endpoint_config_name
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user