# moto/tests/test_sagemaker/test_sagemaker_notebooks.py
#
# Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.
#
# 328 lines
# 13 KiB
# Python
# Raw Permalink Normal View History

import datetime
import boto3
import pytest
from botocore.exceptions import ClientError
# (blame view residue) 2024-01-07 12:03:33 +00:00
from moto import mock_aws
# (blame view residue) 2022-08-13 09:49:43 +00:00
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
# Region used for the mocked SageMaker client and for the expected ARNs/URLs.
TEST_REGION_NAME = "us-east-1"
# Placeholder networking identifiers passed as SubnetId / SecurityGroupIds.
FAKE_SUBNET_ID = "subnet-012345678"
FAKE_SECURITY_GROUP_IDS = ["sg-0123456789abcdef0", "sg-0123456789abcdef1"]
# Placeholder IAM role ARN, built from the ACCOUNT_ID imported from moto.core.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# Placeholder KMS key id passed as KmsKeyId.
FAKE_KMS_KEY_ID = "62d4509a-9f96-446c-a9ba-6b1c353c8c58"
# Tags supplied when creating a notebook instance and expected back from list_tags.
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]
# Placeholder values for the optional create_notebook_instance parameters.
FAKE_LIFECYCLE_CONFIG_NAME = "FakeLifecycleConfigName"
FAKE_DEFAULT_CODE_REPO = "https://github.com/user/repo1"
# The entries are deliberately distinct: with duplicates, the round-trip
# assertion could not notice a dropped, de-duplicated or reordered repository.
FAKE_ADDL_CODE_REPOS = [
    "https://github.com/user/repo2",
    "https://github.com/user/repo3",
]
# Name and instance type shared by most tests in this module.
FAKE_NAME_PARAM = "MyNotebookInstance"
FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
@pytest.fixture(name="client")
def fixture_sagemaker_client():
    """Yield a SageMaker boto3 client created inside an active ``mock_aws`` context.

    The client targets ``TEST_REGION_NAME``. Because the fixture yields from
    within the ``with`` block, the mock remains active while the test runs.
    """
    with mock_aws():
        sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield sagemaker
def _get_notebook_instance_arn(notebook_name):
    """Return the ARN the tests expect for a notebook instance named ``notebook_name``."""
    arn_prefix = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}"
    return f"{arn_prefix}:notebook-instance/{notebook_name}"
def _get_notebook_instance_lifecycle_arn(lifecycle_name):
    """Return the ARN the tests expect for the lifecycle config named ``lifecycle_name``."""
    resource = f"notebook-instance-lifecycle-configuration/{lifecycle_name}"
    return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:{resource}"
def test_create_notebook_instance_minimal_params(client):
    """Create a notebook with only the required parameters and check what describe returns."""
    notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)

    created = client.create_notebook_instance(
        NotebookInstanceName=FAKE_NAME_PARAM,
        InstanceType=FAKE_INSTANCE_TYPE_PARAM,
        RoleArn=FAKE_ROLE_ARN,
    )
    assert created["NotebookInstanceArn"] == notebook_arn

    described = client.describe_notebook_instance(
        NotebookInstanceName=FAKE_NAME_PARAM
    )

    # Values echoed from the request or derived from the name/region.
    expected_fields = {
        "NotebookInstanceArn": notebook_arn,
        "NotebookInstanceName": FAKE_NAME_PARAM,
        "NotebookInstanceStatus": "InService",
        "Url": f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws",
        "InstanceType": FAKE_INSTANCE_TYPE_PARAM,
        "RoleArn": FAKE_ROLE_ARN,
    }
    for field, expected in expected_fields.items():
        assert described[field] == expected

    # Timestamps only need to be parsed into datetimes.
    for field in ("LastModifiedTime", "CreationTime"):
        assert isinstance(described[field], datetime.datetime)

    # Values expected when the optional parameters are omitted.
    assert described["DirectInternetAccess"] == "Enabled"
    assert described["VolumeSizeInGB"] == 5
    # TODO: RootAccess is not asserted; it is unclear whether it has a default.
def test_create_notebook_instance_params(client):
    """Create a notebook with every optional parameter and check the round trip.

    Each value sent to ``create_notebook_instance`` is compared with what
    ``describe_notebook_instance`` returns; the tags are checked via ``list_tags``.
    """
    fake_direct_internet_access_param = "Enabled"
    volume_size_in_gb_param = 7
    accelerator_types_param = ["ml.eia1.medium", "ml.eia2.medium"]
    root_access_param = "Disabled"

    args = {
        "NotebookInstanceName": FAKE_NAME_PARAM,
        "InstanceType": FAKE_INSTANCE_TYPE_PARAM,
        "SubnetId": FAKE_SUBNET_ID,
        "SecurityGroupIds": FAKE_SECURITY_GROUP_IDS,
        "RoleArn": FAKE_ROLE_ARN,
        "KmsKeyId": FAKE_KMS_KEY_ID,
        "Tags": GENERIC_TAGS_PARAM,
        "LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME,
        "DirectInternetAccess": fake_direct_internet_access_param,
        "VolumeSizeInGB": volume_size_in_gb_param,
        "AcceleratorTypes": accelerator_types_param,
        "DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO,
        "AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS,
        "RootAccess": root_access_param,
    }
    resp = client.create_notebook_instance(**args)
    expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
    assert resp["NotebookInstanceArn"] == expected_notebook_arn

    resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert resp["NotebookInstanceArn"] == expected_notebook_arn
    assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM
    assert resp["NotebookInstanceStatus"] == "InService"
    assert resp["Url"] == f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws"
    assert resp["InstanceType"] == FAKE_INSTANCE_TYPE_PARAM
    assert resp["RoleArn"] == FAKE_ROLE_ARN
    assert isinstance(resp["LastModifiedTime"], datetime.datetime)
    assert isinstance(resp["CreationTime"], datetime.datetime)
    # Compare with the value that was sent, not a separate literal, so the
    # request and the expectation cannot drift apart.
    assert resp["DirectInternetAccess"] == fake_direct_internet_access_param
    assert resp["VolumeSizeInGB"] == volume_size_in_gb_param
    # TODO: RootAccess is sent above but not asserted; it is unclear what
    # describe returns for it.
    assert resp["SubnetId"] == FAKE_SUBNET_ID
    # The request key is SecurityGroupIds; the describe key is SecurityGroups.
    assert resp["SecurityGroups"] == FAKE_SECURITY_GROUP_IDS
    assert resp["KmsKeyId"] == FAKE_KMS_KEY_ID
    assert resp["NotebookInstanceLifecycleConfigName"] == FAKE_LIFECYCLE_CONFIG_NAME
    assert resp["AcceleratorTypes"] == accelerator_types_param
    assert resp["DefaultCodeRepository"] == FAKE_DEFAULT_CODE_REPO
    assert resp["AdditionalCodeRepositories"] == FAKE_ADDL_CODE_REPOS

    resp = client.list_tags(ResourceArn=resp["NotebookInstanceArn"])
    assert resp["Tags"] == GENERIC_TAGS_PARAM
def test_create_notebook_instance_invalid_instance_type(client):
    """Creating a notebook with an unknown instance type must raise a ValidationException."""
    instance_type = "undefined_instance_type"
    args = {
        # Use the shared constant, as every other test in this module does.
        "NotebookInstanceName": FAKE_NAME_PARAM,
        "InstanceType": instance_type,
        "RoleArn": FAKE_ROLE_ARN,
    }
    with pytest.raises(ClientError) as ex:
        client.create_notebook_instance(**args)

    assert ex.value.response["Error"]["Code"] == "ValidationException"
    # Only the message prefix is checked; the enum values that follow the
    # opening bracket are not pinned by this test.
    expected_message = (
        f"Value '{instance_type}' at 'instanceType' failed to satisfy "
        "constraint: Member must satisfy enum value set: ["
    )
    assert expected_message in ex.value.response["Error"]["Message"]
def test_notebook_instance_lifecycle(client):
    """Walk a notebook instance through create, stop, start, stop and delete.

    Deleting while the instance is ``InService`` must fail with a
    ``ValidationException``. After a successful delete, describing the
    instance must fail with the message ``RecordNotFound``.
    """
    args = {
        "NotebookInstanceName": FAKE_NAME_PARAM,
        "InstanceType": FAKE_INSTANCE_TYPE_PARAM,
        "RoleArn": FAKE_ROLE_ARN,
    }
    resp = client.create_notebook_instance(**args)
    expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
    assert resp["NotebookInstanceArn"] == expected_notebook_arn

    resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    notebook_instance_arn = resp["NotebookInstanceArn"]

    # A freshly created instance is InService, so the delete must be rejected.
    with pytest.raises(ClientError) as ex:
        client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    expected_message = (
        "Status (InService) not in ([Stopped, Failed]). Unable to "
        f"transition to (Deleting) for Notebook Instance ({notebook_instance_arn})"
    )
    assert expected_message in ex.value.response["Error"]["Message"]

    client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert resp["NotebookInstanceStatus"] == "Stopped"

    # The result is not inspected; the call only has to succeed at this point.
    client.list_notebook_instances()

    client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert resp["NotebookInstanceStatus"] == "InService"

    client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert resp["NotebookInstanceStatus"] == "Stopped"

    # Once stopped, the delete is expected to go through.
    client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)

    with pytest.raises(ClientError) as ex:
        client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
    assert ex.value.response["Error"]["Message"] == "RecordNotFound"
def test_describe_nonexistent_model(client):
    """Describing a model that was never created must raise a 'Could not find model' error."""
    with pytest.raises(ClientError) as exc_info:
        client.describe_model(ModelName="Nonexistent")

    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")
def test_notebook_instance_lifecycle_config(client):
    """Exercise create, duplicate create, describe, delete and missing-config errors."""
    config_name = "MyLifeCycleConfig"
    create_scripts = [{"Content": "Create Script Line 1"}]
    start_scripts = [{"Content": "Start Script Line 1"}]
    create_kwargs = {
        "NotebookInstanceLifecycleConfigName": config_name,
        "OnCreate": create_scripts,
        "OnStart": start_scripts,
    }
    config_arn = _get_notebook_instance_lifecycle_arn(config_name)

    created = client.create_notebook_instance_lifecycle_config(**create_kwargs)
    assert created["NotebookInstanceLifecycleConfigArn"] == config_arn

    # Creating the same config a second time must be rejected.
    with pytest.raises(ClientError) as duplicate_err:
        client.create_notebook_instance_lifecycle_config(**create_kwargs)
    duplicate_message = duplicate_err.value.response["Error"]["Message"]
    assert duplicate_message.endswith(
        "Notebook Instance Lifecycle Config already exists.)"
    )

    described = client.describe_notebook_instance_lifecycle_config(
        NotebookInstanceLifecycleConfigName=config_name
    )
    assert described["NotebookInstanceLifecycleConfigName"] == config_name
    assert described["NotebookInstanceLifecycleConfigArn"] == config_arn
    assert described["OnStart"] == start_scripts
    assert described["OnCreate"] == create_scripts
    assert isinstance(described["LastModifiedTime"], datetime.datetime)
    assert isinstance(described["CreationTime"], datetime.datetime)

    client.delete_notebook_instance_lifecycle_config(
        NotebookInstanceLifecycleConfigName=config_name
    )

    # After deletion, describe and a second delete must both report a missing config.
    with pytest.raises(ClientError) as describe_err:
        client.describe_notebook_instance_lifecycle_config(
            NotebookInstanceLifecycleConfigName=config_name
        )
    describe_message = describe_err.value.response["Error"]["Message"]
    assert describe_message.endswith(
        "Notebook Instance Lifecycle Config does not exist.)"
    )

    with pytest.raises(ClientError) as delete_err:
        client.delete_notebook_instance_lifecycle_config(
            NotebookInstanceLifecycleConfigName=config_name
        )
    delete_message = delete_err.value.response["Error"]["Message"]
    assert delete_message.endswith(
        "Notebook Instance Lifecycle Config does not exist.)"
    )
def test_add_tags_to_notebook(client):
    """Tags added to a notebook instance must be returned by ``list_tags``."""
    created = client.create_notebook_instance(
        NotebookInstanceName=FAKE_NAME_PARAM,
        InstanceType=FAKE_INSTANCE_TYPE_PARAM,
        RoleArn=FAKE_ROLE_ARN,
    )
    notebook_arn = created["NotebookInstanceArn"]
    new_tags = [{"Key": "myKey", "Value": "myValue"}]

    add_result = client.add_tags(ResourceArn=notebook_arn, Tags=new_tags)
    assert add_result["ResponseMetadata"]["HTTPStatusCode"] == 200

    listed = client.list_tags(ResourceArn=notebook_arn)
    assert listed["Tags"] == new_tags
def test_delete_tags_from_notebook(client):
    """Deleting every tag key from a notebook instance must leave ``list_tags`` empty."""
    created = client.create_notebook_instance(
        NotebookInstanceName=FAKE_NAME_PARAM,
        InstanceType=FAKE_INSTANCE_TYPE_PARAM,
        RoleArn=FAKE_ROLE_ARN,
    )
    notebook_arn = created["NotebookInstanceArn"]
    new_tags = [{"Key": "myKey", "Value": "myValue"}]

    add_result = client.add_tags(ResourceArn=notebook_arn, Tags=new_tags)
    assert add_result["ResponseMetadata"]["HTTPStatusCode"] == 200

    keys_to_remove = [entry["Key"] for entry in new_tags]
    delete_result = client.delete_tags(
        ResourceArn=notebook_arn, TagKeys=keys_to_remove
    )
    assert delete_result["ResponseMetadata"]["HTTPStatusCode"] == 200

    listed = client.list_tags(ResourceArn=notebook_arn)
    assert listed["Tags"] == []
def test_list_notebook_instances(client):
    """Check default order, sorting, filtering and pagination of ``list_notebook_instances``."""

    def names_in(page):
        """Extract the instance names from one list response, in response order."""
        return [nb["NotebookInstanceName"] for nb in page["NotebookInstances"]]

    for index in range(3):
        client.create_notebook_instance(
            NotebookInstanceName=f"Name{index}",
            InstanceType=FAKE_INSTANCE_TYPE_PARAM,
            RoleArn=FAKE_ROLE_ARN,
        )
    # Stop one instance so that status-based sorting and filtering have two states.
    client.stop_notebook_instance(NotebookInstanceName="Name1")

    assert names_in(client.list_notebook_instances()) == ["Name0", "Name1", "Name2"]

    by_status = client.list_notebook_instances(SortBy="Status")
    assert names_in(by_status) == ["Name0", "Name2", "Name1"]

    descending = client.list_notebook_instances(SortOrder="Descending")
    assert names_in(descending) == ["Name2", "Name1", "Name0"]

    containing_one = client.list_notebook_instances(NameContains="1")
    assert names_in(containing_one) == ["Name1"]

    in_service = client.list_notebook_instances(StatusEquals="InService")
    assert names_in(in_service) == ["Name0", "Name2"]

    pending = client.list_notebook_instances(StatusEquals="Pending")
    assert names_in(pending) == []

    # Pagination: the first page holds one item, the token yields the remainder.
    first_page = client.list_notebook_instances(MaxResults=1)
    assert names_in(first_page) == ["Name0"]

    last_page = client.list_notebook_instances(NextToken=first_page["NextToken"])
    assert names_in(last_page) == ["Name1", "Name2"]
    assert "NextToken" not in last_page