SageMaker: list_notebook_instances() (#6756)

This commit is contained in:
Bert Blommers 2023-09-02 06:34:49 +00:00 committed by GitHub
parent 5af4421524
commit e211366534
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 174 additions and 83 deletions

View File

@ -29,35 +29,36 @@ PAGINATION_MODEL = {
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "experiment_arn",
"fail_on_invalid_token": True,
},
"list_trials": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "trial_arn",
"fail_on_invalid_token": True,
},
"list_trial_components": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 100,
"unique_attribute": "trial_component_arn",
"fail_on_invalid_token": True,
},
"list_tags": {
"input_token": "NextToken",
"limit_key": "MaxResults",
"limit_default": 50,
"unique_attribute": "Key",
"fail_on_invalid_token": True,
},
"list_model_packages": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "ModelPackageArn",
"fail_on_invalid_token": True,
},
"list_notebook_instances": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "arn",
},
}
@ -1132,7 +1133,7 @@ class FakeSagemakerNotebookInstance(CloudFormationModel):
self.default_code_repository = default_code_repository
self.additional_code_repositories = additional_code_repositories
self.root_access = root_access
self.status: Optional[str] = None
self.status = "Pending"
self.creation_time = self.last_modified_time = datetime.now()
self.arn = arn_formatter(
"notebook-instance", notebook_instance_name, account_id, region_name
@ -1289,6 +1290,29 @@ class FakeSagemakerNotebookInstance(CloudFormationModel):
backend.stop_notebook_instance(notebook_instance_name)
backend.delete_notebook_instance(notebook_instance_name)
def to_dict(self) -> Dict[str, Any]:
return {
"NotebookInstanceArn": self.arn,
"NotebookInstanceName": self.notebook_instance_name,
"NotebookInstanceStatus": self.status,
"Url": self.url,
"InstanceType": self.instance_type,
"SubnetId": self.subnet_id,
"SecurityGroups": self.security_group_ids,
"RoleArn": self.role_arn,
"KmsKeyId": self.kms_key_id,
# ToDo: NetworkInterfaceId
"LastModifiedTime": str(self.last_modified_time),
"CreationTime": str(self.creation_time),
"NotebookInstanceLifecycleConfigName": self.lifecycle_config_name,
"DirectInternetAccess": self.direct_internet_access,
"VolumeSizeInGB": self.volume_size_in_gb,
"AcceleratorTypes": self.accelerator_types,
"DefaultCodeRepository": self.default_code_repository,
"AdditionalCodeRepositories": self.additional_code_repositories,
"RootAccess": self.root_access,
}
class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel):
def __init__(
@ -1942,6 +1966,38 @@ class SageMakerModelBackend(BaseBackend):
raise ValidationError(message=message)
del self.notebook_instances[notebook_instance_name]
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc]
def list_notebook_instances(
self,
sort_by: str,
sort_order: str,
name_contains: Optional[str],
status: Optional[str],
) -> Iterable[FakeSagemakerNotebookInstance]:
"""
The following parameters are not yet implemented:
CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals
"""
instances = list(self.notebook_instances.values())
if name_contains:
instances = [
i for i in instances if name_contains in i.notebook_instance_name
]
if status:
instances = [i for i in instances if i.status == status]
reverse = sort_order == "Descending"
if sort_by == "Name":
instances = sorted(
instances, key=lambda x: x.notebook_instance_name, reverse=reverse
)
if sort_by == "CreationTime":
instances = sorted(
instances, key=lambda x: x.creation_time, reverse=reverse
)
if sort_by == "Status":
instances = sorted(instances, key=lambda x: x.status, reverse=reverse)
return instances
def create_notebook_instance_lifecycle_config(
self,
notebook_instance_lifecycle_config_name: str,

View File

@ -73,33 +73,12 @@ class SageMakerResponse(BaseResponse):
return 200, {}, json.dumps({"NotebookInstanceArn": sagemaker_notebook.arn})
@amzn_request_id
def describe_notebook_instance(self) -> TYPE_RESPONSE:
def describe_notebook_instance(self) -> str:
notebook_instance_name = self._get_param("NotebookInstanceName")
notebook_instance = self.sagemaker_backend.get_notebook_instance(
notebook_instance_name
)
response = {
"NotebookInstanceArn": notebook_instance.arn,
"NotebookInstanceName": notebook_instance.notebook_instance_name,
"NotebookInstanceStatus": notebook_instance.status,
"Url": notebook_instance.url,
"InstanceType": notebook_instance.instance_type,
"SubnetId": notebook_instance.subnet_id,
"SecurityGroups": notebook_instance.security_group_ids,
"RoleArn": notebook_instance.role_arn,
"KmsKeyId": notebook_instance.kms_key_id,
# ToDo: NetworkInterfaceId
"LastModifiedTime": str(notebook_instance.last_modified_time),
"CreationTime": str(notebook_instance.creation_time),
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
"DirectInternetAccess": notebook_instance.direct_internet_access,
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
"AcceleratorTypes": notebook_instance.accelerator_types,
"DefaultCodeRepository": notebook_instance.default_code_repository,
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
"RootAccess": notebook_instance.root_access,
}
return 200, {}, json.dumps(response)
return json.dumps(notebook_instance.to_dict())
@amzn_request_id
def start_notebook_instance(self) -> TYPE_RESPONSE:
@ -119,6 +98,29 @@ class SageMakerResponse(BaseResponse):
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def list_notebook_instances(self) -> str:
sort_by = self._get_param("SortBy", "Name")
sort_order = self._get_param("SortOrder", "Ascending")
name_contains = self._get_param("NameContains")
status = self._get_param("StatusEquals")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken")
instances, next_token = self.sagemaker_backend.list_notebook_instances(
sort_by=sort_by,
sort_order=sort_order,
name_contains=name_contains,
status=status,
max_results=max_results,
next_token=next_token,
)
return json.dumps(
{
"NotebookInstances": [i.to_dict() for i in instances],
"NextToken": next_token,
}
)
@amzn_request_id
def list_tags(self) -> TYPE_RESPONSE:
arn = self._get_param("ResourceArn")

View File

@ -26,7 +26,7 @@ FAKE_NAME_PARAM = "MyNotebookInstance"
FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
@pytest.fixture(name="sagemaker_client")
@pytest.fixture(name="client")
def fixture_sagemaker_client():
with mock_sagemaker():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@ -43,19 +43,17 @@ def _get_notebook_instance_lifecycle_arn(lifecycle_name):
)
def test_create_notebook_instance_minimal_params(sagemaker_client):
def test_create_notebook_instance_minimal_params(client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker_client.create_notebook_instance(**args)
resp = client.create_notebook_instance(**args)
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
assert resp["NotebookInstanceArn"] == expected_notebook_arn
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert resp["NotebookInstanceArn"] == expected_notebook_arn
assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM
assert resp["NotebookInstanceStatus"] == "InService"
@ -71,7 +69,7 @@ def test_create_notebook_instance_minimal_params(sagemaker_client):
# assert resp["RootAccess"] == True # ToDo: Not sure if this defaults...
def test_create_notebook_instance_params(sagemaker_client):
def test_create_notebook_instance_params(client):
fake_direct_internet_access_param = "Enabled"
volume_size_in_gb_param = 7
accelerator_types_param = ["ml.eia1.medium", "ml.eia2.medium"]
@ -93,13 +91,11 @@ def test_create_notebook_instance_params(sagemaker_client):
"AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS,
"RootAccess": root_access_param,
}
resp = sagemaker_client.create_notebook_instance(**args)
resp = client.create_notebook_instance(**args)
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
assert resp["NotebookInstanceArn"] == expected_notebook_arn
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert resp["NotebookInstanceArn"] == expected_notebook_arn
assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM
assert resp["NotebookInstanceStatus"] == "InService"
@ -119,11 +115,11 @@ def test_create_notebook_instance_params(sagemaker_client):
assert resp["DefaultCodeRepository"] == FAKE_DEFAULT_CODE_REPO
assert resp["AdditionalCodeRepositories"] == FAKE_ADDL_CODE_REPOS
resp = sagemaker_client.list_tags(ResourceArn=resp["NotebookInstanceArn"])
resp = client.list_tags(ResourceArn=resp["NotebookInstanceArn"])
assert resp["Tags"] == GENERIC_TAGS_PARAM
def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
def test_create_notebook_instance_invalid_instance_type(client):
instance_type = "undefined_instance_type"
args = {
"NotebookInstanceName": "MyNotebookInstance",
@ -131,7 +127,7 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
"RoleArn": FAKE_ROLE_ARN,
}
with pytest.raises(ClientError) as ex:
sagemaker_client.create_notebook_instance(**args)
client.create_notebook_instance(**args)
assert ex.value.response["Error"]["Code"] == "ValidationException"
expected_message = (
f"Value '{instance_type}' at 'instanceType' failed to satisfy "
@ -141,23 +137,21 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
assert expected_message in ex.value.response["Error"]["Message"]
def test_notebook_instance_lifecycle(sagemaker_client):
def test_notebook_instance_lifecycle(client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker_client.create_notebook_instance(**args)
resp = client.create_notebook_instance(**args)
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
assert resp["NotebookInstanceArn"] == expected_notebook_arn
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
notebook_instance_arn = resp["NotebookInstanceArn"]
with pytest.raises(ClientError) as ex:
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert ex.value.response["Error"]["Code"] == "ValidationException"
expected_message = (
f"Status (InService) not in ([Stopped, Failed]). Unable to "
@ -165,54 +159,48 @@ def test_notebook_instance_lifecycle(sagemaker_client):
)
assert expected_message in ex.value.response["Error"]["Message"]
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert resp["NotebookInstanceStatus"] == "Stopped"
sagemaker_client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
client.list_notebook_instances()
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert resp["NotebookInstanceStatus"] == "InService"
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
resp = sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert resp["NotebookInstanceStatus"] == "Stopped"
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
with pytest.raises(ClientError) as ex:
sagemaker_client.describe_notebook_instance(
NotebookInstanceName=FAKE_NAME_PARAM
)
client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
assert ex.value.response["Error"]["Message"] == "RecordNotFound"
def test_describe_nonexistent_model(sagemaker_client):
def test_describe_nonexistent_model(client):
with pytest.raises(ClientError) as e:
sagemaker_client.describe_model(ModelName="Nonexistent")
client.describe_model(ModelName="Nonexistent")
assert e.value.response["Error"]["Message"].startswith("Could not find model")
def test_notebook_instance_lifecycle_config(sagemaker_client):
def test_notebook_instance_lifecycle_config(client):
name = "MyLifeCycleConfig"
on_create = [{"Content": "Create Script Line 1"}]
on_start = [{"Content": "Start Script Line 1"}]
resp = sagemaker_client.create_notebook_instance_lifecycle_config(
resp = client.create_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start
)
expected_arn = _get_notebook_instance_lifecycle_arn(name)
assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn
with pytest.raises(ClientError) as e:
sagemaker_client.create_notebook_instance_lifecycle_config(
client.create_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name,
OnCreate=on_create,
OnStart=on_start,
@ -221,7 +209,7 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
"Notebook Instance Lifecycle Config already exists.)"
)
resp = sagemaker_client.describe_notebook_instance_lifecycle_config(
resp = client.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name
)
assert resp["NotebookInstanceLifecycleConfigName"] == name
@ -231,12 +219,12 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
assert isinstance(resp["CreationTime"], datetime.datetime)
sagemaker_client.delete_notebook_instance_lifecycle_config(
client.delete_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name
)
with pytest.raises(ClientError) as e:
sagemaker_client.describe_notebook_instance_lifecycle_config(
client.describe_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name
)
assert e.value.response["Error"]["Message"].endswith(
@ -244,7 +232,7 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
)
with pytest.raises(ClientError) as e:
sagemaker_client.delete_notebook_instance_lifecycle_config(
client.delete_notebook_instance_lifecycle_config(
NotebookInstanceLifecycleConfigName=name
)
assert e.value.response["Error"]["Message"].endswith(
@ -252,43 +240,88 @@ def test_notebook_instance_lifecycle_config(sagemaker_client):
)
def test_add_tags_to_notebook(sagemaker_client):
def test_add_tags_to_notebook(client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker_client.create_notebook_instance(**args)
resp = client.create_notebook_instance(**args)
resource_arn = resp["NotebookInstanceArn"]
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
response = client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == tags
def test_delete_tags_from_notebook(sagemaker_client):
def test_delete_tags_from_notebook(client):
args = {
"NotebookInstanceName": FAKE_NAME_PARAM,
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
resp = sagemaker_client.create_notebook_instance(**args)
resp = client.create_notebook_instance(**args)
resource_arn = resp["NotebookInstanceArn"]
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
tag_keys = [tag["Key"] for tag in tags]
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
response = client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == []
def test_list_notebook_instances(client):
for i in range(3):
args = {
"NotebookInstanceName": f"Name{i}",
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
"RoleArn": FAKE_ROLE_ARN,
}
client.create_notebook_instance(**args)
client.stop_notebook_instance(NotebookInstanceName="Name1")
instances = client.list_notebook_instances()["NotebookInstances"]
assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name1", "Name2"]
instances = client.list_notebook_instances(SortBy="Status")["NotebookInstances"]
assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name2", "Name1"]
instances = client.list_notebook_instances(SortOrder="Descending")[
"NotebookInstances"
]
assert [i["NotebookInstanceName"] for i in instances] == ["Name2", "Name1", "Name0"]
instances = client.list_notebook_instances(NameContains="1")["NotebookInstances"]
assert [i["NotebookInstanceName"] for i in instances] == ["Name1"]
instances = client.list_notebook_instances(StatusEquals="InService")[
"NotebookInstances"
]
assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name2"]
instances = client.list_notebook_instances(StatusEquals="Pending")[
"NotebookInstances"
]
assert [i["NotebookInstanceName"] for i in instances] == []
resp = client.list_notebook_instances(MaxResults=1)
instances = resp["NotebookInstances"]
assert [i["NotebookInstanceName"] for i in instances] == ["Name0"]
resp = client.list_notebook_instances(NextToken=resp["NextToken"])
instances = resp["NotebookInstances"]
assert [i["NotebookInstanceName"] for i in instances] == ["Name1", "Name2"]
assert "NextToken" not in resp