192 lines
5.5 KiB
Python
192 lines
5.5 KiB
Python
|
import json
|
||
|
from abc import ABCMeta, abstractmethod
|
||
|
|
||
|
import six
|
||
|
|
||
|
from moto.sts.models import ACCOUNT_ID
|
||
|
|
||
|
|
||
|
@six.add_metaclass(ABCMeta)
|
||
|
class TestConfig:
|
||
|
"""Provides the interface to use for creating test configurations.
|
||
|
|
||
|
This class will provide the interface for what information will be
|
||
|
needed for the SageMaker CloudFormation tests. Ultimately, this will
|
||
|
improve the readability of the tests in `test_sagemaker_cloudformation.py`
|
||
|
because it will reduce the amount of information we pass through the
|
||
|
`pytest.mark.parametrize` decorator.
|
||
|
|
||
|
"""
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def resource_name(self):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def describe_function_name(self):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def name_parameter(self):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def arn_parameter(self):
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class NotebookInstanceTestConfig(TestConfig):
|
||
|
"""Test configuration for SageMaker Notebook Instances."""
|
||
|
|
||
|
@property
|
||
|
def resource_name(self):
|
||
|
return "TestNotebook"
|
||
|
|
||
|
@property
|
||
|
def describe_function_name(self):
|
||
|
return "describe_notebook_instance"
|
||
|
|
||
|
@property
|
||
|
def name_parameter(self):
|
||
|
return "NotebookInstanceName"
|
||
|
|
||
|
@property
|
||
|
def arn_parameter(self):
|
||
|
return "NotebookInstanceArn"
|
||
|
|
||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||
|
instance_type = kwargs.get("instance_type", "ml.c4.xlarge")
|
||
|
role_arn = kwargs.get(
|
||
|
"role_arn", "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||
|
)
|
||
|
|
||
|
template = {
|
||
|
"AWSTemplateFormatVersion": "2010-09-09",
|
||
|
"Resources": {
|
||
|
self.resource_name: {
|
||
|
"Type": "AWS::SageMaker::NotebookInstance",
|
||
|
"Properties": {"InstanceType": instance_type, "RoleArn": role_arn},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
if include_outputs:
|
||
|
template["Outputs"] = {
|
||
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||
|
"Name": {
|
||
|
"Value": {
|
||
|
"Fn::GetAtt": [self.resource_name, "NotebookInstanceName"]
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
return json.dumps(template)
|
||
|
|
||
|
|
||
|
class NotebookInstanceLifecycleConfigTestConfig(TestConfig):
|
||
|
"""Test configuration for SageMaker Notebook Instance Lifecycle Configs."""
|
||
|
|
||
|
@property
|
||
|
def resource_name(self):
|
||
|
return "TestNotebookLifecycleConfig"
|
||
|
|
||
|
@property
|
||
|
def describe_function_name(self):
|
||
|
return "describe_notebook_instance_lifecycle_config"
|
||
|
|
||
|
@property
|
||
|
def name_parameter(self):
|
||
|
return "NotebookInstanceLifecycleConfigName"
|
||
|
|
||
|
@property
|
||
|
def arn_parameter(self):
|
||
|
return "NotebookInstanceLifecycleConfigArn"
|
||
|
|
||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||
|
on_create = kwargs.get("on_create")
|
||
|
on_start = kwargs.get("on_start")
|
||
|
|
||
|
template = {
|
||
|
"AWSTemplateFormatVersion": "2010-09-09",
|
||
|
"Resources": {
|
||
|
self.resource_name: {
|
||
|
"Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
|
||
|
"Properties": {},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
if on_create is not None:
|
||
|
template["Resources"][self.resource_name]["Properties"]["OnCreate"] = [
|
||
|
{"Content": on_create}
|
||
|
]
|
||
|
if on_start is not None:
|
||
|
template["Resources"][self.resource_name]["Properties"]["OnStart"] = [
|
||
|
{"Content": on_start}
|
||
|
]
|
||
|
if include_outputs:
|
||
|
template["Outputs"] = {
|
||
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||
|
"Name": {
|
||
|
"Value": {
|
||
|
"Fn::GetAtt": [
|
||
|
self.resource_name,
|
||
|
"NotebookInstanceLifecycleConfigName",
|
||
|
]
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
return json.dumps(template)
|
||
|
|
||
|
|
||
|
class ModelTestConfig(TestConfig):
|
||
|
"""Test configuration for SageMaker Models."""
|
||
|
|
||
|
@property
|
||
|
def resource_name(self):
|
||
|
return "TestModel"
|
||
|
|
||
|
@property
|
||
|
def describe_function_name(self):
|
||
|
return "describe_model"
|
||
|
|
||
|
@property
|
||
|
def name_parameter(self):
|
||
|
return "ModelName"
|
||
|
|
||
|
@property
|
||
|
def arn_parameter(self):
|
||
|
return "ModelArn"
|
||
|
|
||
|
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||
|
execution_role_arn = kwargs.get(
|
||
|
"execution_role_arn", "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||
|
)
|
||
|
image = kwargs.get(
|
||
|
"image", "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1"
|
||
|
)
|
||
|
|
||
|
template = {
|
||
|
"AWSTemplateFormatVersion": "2010-09-09",
|
||
|
"Resources": {
|
||
|
self.resource_name: {
|
||
|
"Type": "AWS::SageMaker::Model",
|
||
|
"Properties": {
|
||
|
"ExecutionRoleArn": execution_role_arn,
|
||
|
"PrimaryContainer": {"Image": image,},
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
if include_outputs:
|
||
|
template["Outputs"] = {
|
||
|
"Arn": {"Value": {"Ref": self.resource_name}},
|
||
|
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}},
|
||
|
}
|
||
|
return json.dumps(template)
|