From e211366534b239a999c148f50ff9c83a24b9dbc7 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 2 Sep 2023 06:34:49 +0000 Subject: [PATCH] SageMaker: list_notebook_instances() (#6756) --- moto/sagemaker/models.py | 68 ++++++++- moto/sagemaker/responses.py | 48 +++--- .../test_sagemaker_notebooks.py | 141 +++++++++++------- 3 files changed, 174 insertions(+), 83 deletions(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 726eab895..7d1419587 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -29,35 +29,36 @@ PAGINATION_MODEL = { "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "experiment_arn", - "fail_on_invalid_token": True, }, "list_trials": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "trial_arn", - "fail_on_invalid_token": True, }, "list_trial_components": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 100, "unique_attribute": "trial_component_arn", - "fail_on_invalid_token": True, }, "list_tags": { "input_token": "NextToken", "limit_key": "MaxResults", "limit_default": 50, "unique_attribute": "Key", - "fail_on_invalid_token": True, }, "list_model_packages": { "input_token": "next_token", "limit_key": "max_results", "limit_default": 100, "unique_attribute": "ModelPackageArn", - "fail_on_invalid_token": True, + }, + "list_notebook_instances": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "arn", }, } @@ -1132,7 +1133,7 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): self.default_code_repository = default_code_repository self.additional_code_repositories = additional_code_repositories self.root_access = root_access - self.status: Optional[str] = None + self.status = "Pending" self.creation_time = self.last_modified_time = datetime.now() self.arn = arn_formatter( "notebook-instance", notebook_instance_name, account_id, region_name @@ -1289,6 +1290,29 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): backend.stop_notebook_instance(notebook_instance_name) backend.delete_notebook_instance(notebook_instance_name) + def to_dict(self) -> Dict[str, Any]: + return { + "NotebookInstanceArn": self.arn, + "NotebookInstanceName": self.notebook_instance_name, + "NotebookInstanceStatus": self.status, + "Url": self.url, + "InstanceType": self.instance_type, + "SubnetId": self.subnet_id, + "SecurityGroups": self.security_group_ids, + "RoleArn": self.role_arn, + "KmsKeyId": self.kms_key_id, + # ToDo: NetworkInterfaceId + "LastModifiedTime": str(self.last_modified_time), + "CreationTime": str(self.creation_time), + "NotebookInstanceLifecycleConfigName": self.lifecycle_config_name, + "DirectInternetAccess": self.direct_internet_access, + "VolumeSizeInGB": self.volume_size_in_gb, + "AcceleratorTypes": self.accelerator_types, + "DefaultCodeRepository": self.default_code_repository, + "AdditionalCodeRepositories": self.additional_code_repositories, + "RootAccess": self.root_access, + } + class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( @@ -1942,6 +1966,38 @@ class SageMakerModelBackend(BaseBackend): raise ValidationError(message=message) del self.notebook_instances[notebook_instance_name] + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore[misc] + def list_notebook_instances( + self, + sort_by: str, + sort_order: str, + name_contains: Optional[str], + status: Optional[str], + ) -> Iterable[FakeSagemakerNotebookInstance]: + """ + The following parameters are not yet implemented: + CreationTimeBefore, CreationTimeAfter, LastModifiedTimeBefore, LastModifiedTimeAfter, NotebookInstanceLifecycleConfigNameContains, DefaultCodeRepositoryContains, AdditionalCodeRepositoryEquals + """ + instances = list(self.notebook_instances.values()) + if name_contains: + instances = [ + i for i in instances if name_contains in i.notebook_instance_name + ] + if status: + instances = [i for i in instances if i.status == status] + reverse = sort_order == "Descending" + if sort_by == "Name": + instances = sorted( + instances, key=lambda x: x.notebook_instance_name, reverse=reverse + ) + if sort_by == "CreationTime": + instances = sorted( + instances, key=lambda x: x.creation_time, reverse=reverse + ) + if sort_by == "Status": + instances = sorted(instances, key=lambda x: x.status, reverse=reverse) + return instances + def create_notebook_instance_lifecycle_config( self, notebook_instance_lifecycle_config_name: str, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index d04184dfc..b40f8b537 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -73,33 +73,12 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps({"NotebookInstanceArn": sagemaker_notebook.arn}) @amzn_request_id - def describe_notebook_instance(self) -> TYPE_RESPONSE: + def describe_notebook_instance(self) -> str: notebook_instance_name = self._get_param("NotebookInstanceName") notebook_instance = self.sagemaker_backend.get_notebook_instance( notebook_instance_name ) - response = { - "NotebookInstanceArn": notebook_instance.arn, - "NotebookInstanceName": notebook_instance.notebook_instance_name, - "NotebookInstanceStatus": notebook_instance.status, - "Url": notebook_instance.url, - "InstanceType": notebook_instance.instance_type, - "SubnetId": notebook_instance.subnet_id, - "SecurityGroups": notebook_instance.security_group_ids, - "RoleArn": notebook_instance.role_arn, - "KmsKeyId": notebook_instance.kms_key_id, - # ToDo: NetworkInterfaceId - "LastModifiedTime": str(notebook_instance.last_modified_time), - "CreationTime": str(notebook_instance.creation_time), - "NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name, - "DirectInternetAccess": notebook_instance.direct_internet_access, - "VolumeSizeInGB": notebook_instance.volume_size_in_gb, - "AcceleratorTypes": notebook_instance.accelerator_types, - "DefaultCodeRepository": notebook_instance.default_code_repository, - "AdditionalCodeRepositories": notebook_instance.additional_code_repositories, - "RootAccess": notebook_instance.root_access, - } - return 200, {}, json.dumps(response) + return json.dumps(notebook_instance.to_dict()) @amzn_request_id def start_notebook_instance(self) -> TYPE_RESPONSE: @@ -119,6 +98,29 @@ class SageMakerResponse(BaseResponse): self.sagemaker_backend.delete_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") + @amzn_request_id + def list_notebook_instances(self) -> str: + sort_by = self._get_param("SortBy", "Name") + sort_order = self._get_param("SortOrder", "Ascending") + name_contains = self._get_param("NameContains") + status = self._get_param("StatusEquals") + max_results = self._get_param("MaxResults") + next_token = self._get_param("NextToken") + instances, next_token = self.sagemaker_backend.list_notebook_instances( + sort_by=sort_by, + sort_order=sort_order, + name_contains=name_contains, + status=status, + max_results=max_results, + next_token=next_token, + ) + return json.dumps( + { + "NotebookInstances": [i.to_dict() for i in instances], + "NextToken": next_token, + } + ) + @amzn_request_id def list_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py index 01556ef42..f1d7e015a 100644 --- a/tests/test_sagemaker/test_sagemaker_notebooks.py +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ -26,7 +26,7 @@ FAKE_NAME_PARAM = "MyNotebookInstance" FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium" -@pytest.fixture(name="sagemaker_client") +@pytest.fixture(name="client") def fixture_sagemaker_client(): with mock_sagemaker(): yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) @@ -43,19 +43,17 @@ def _get_notebook_instance_lifecycle_arn(lifecycle_name): ) -def test_create_notebook_instance_minimal_params(sagemaker_client): +def test_create_notebook_instance_minimal_params(client): args = { "NotebookInstanceName": FAKE_NAME_PARAM, "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker_client.create_notebook_instance(**args) + resp = client.create_notebook_instance(**args) expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert resp["NotebookInstanceArn"] == expected_notebook_arn assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM assert resp["NotebookInstanceStatus"] == "InService" @@ -71,7 +69,7 @@ def test_create_notebook_instance_minimal_params(sagemaker_client): # assert resp["RootAccess"] == True # ToDo: Not sure if this defaults... -def test_create_notebook_instance_params(sagemaker_client): +def test_create_notebook_instance_params(client): fake_direct_internet_access_param = "Enabled" volume_size_in_gb_param = 7 accelerator_types_param = ["ml.eia1.medium", "ml.eia2.medium"] @@ -93,13 +91,11 @@ def test_create_notebook_instance_params(sagemaker_client): "AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS, "RootAccess": root_access_param, } - resp = sagemaker_client.create_notebook_instance(**args) + resp = client.create_notebook_instance(**args) expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert resp["NotebookInstanceArn"] == expected_notebook_arn assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM assert resp["NotebookInstanceStatus"] == "InService" @@ -119,11 +115,11 @@ def test_create_notebook_instance_params(sagemaker_client): assert resp["DefaultCodeRepository"] == FAKE_DEFAULT_CODE_REPO assert resp["AdditionalCodeRepositories"] == FAKE_ADDL_CODE_REPOS - resp = sagemaker_client.list_tags(ResourceArn=resp["NotebookInstanceArn"]) + resp = client.list_tags(ResourceArn=resp["NotebookInstanceArn"]) assert resp["Tags"] == GENERIC_TAGS_PARAM -def test_create_notebook_instance_invalid_instance_type(sagemaker_client): +def test_create_notebook_instance_invalid_instance_type(client): instance_type = "undefined_instance_type" args = { "NotebookInstanceName": "MyNotebookInstance", @@ -131,7 +127,7 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client): "RoleArn": FAKE_ROLE_ARN, } with pytest.raises(ClientError) as ex: - sagemaker_client.create_notebook_instance(**args) + client.create_notebook_instance(**args) assert ex.value.response["Error"]["Code"] == "ValidationException" expected_message = ( f"Value '{instance_type}' at 'instanceType' failed to satisfy " @@ -141,23 +137,21 @@ def test_create_notebook_instance_invalid_instance_type(sagemaker_client): assert expected_message in ex.value.response["Error"]["Message"] -def test_notebook_instance_lifecycle(sagemaker_client): +def test_notebook_instance_lifecycle(client): args = { "NotebookInstanceName": FAKE_NAME_PARAM, "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker_client.create_notebook_instance(**args) + resp = client.create_notebook_instance(**args) expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) notebook_instance_arn = resp["NotebookInstanceArn"] with pytest.raises(ClientError) as ex: - sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert ex.value.response["Error"]["Code"] == "ValidationException" expected_message = ( f"Status (InService) not in ([Stopped, Failed]). Unable to " @@ -165,54 +159,48 @@ def test_notebook_instance_lifecycle(sagemaker_client): ) assert expected_message in ex.value.response["Error"]["Message"] - sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert resp["NotebookInstanceStatus"] == "Stopped" - sagemaker_client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + client.list_notebook_instances() - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert resp["NotebookInstanceStatus"] == "InService" - sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) - resp = sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + resp = client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert resp["NotebookInstanceStatus"] == "Stopped" - sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) + client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) with pytest.raises(ClientError) as ex: - sagemaker_client.describe_notebook_instance( - NotebookInstanceName=FAKE_NAME_PARAM - ) + client.describe_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert ex.value.response["Error"]["Message"] == "RecordNotFound" -def test_describe_nonexistent_model(sagemaker_client): +def test_describe_nonexistent_model(client): with pytest.raises(ClientError) as e: - sagemaker_client.describe_model(ModelName="Nonexistent") + client.describe_model(ModelName="Nonexistent") assert e.value.response["Error"]["Message"].startswith("Could not find model") -def test_notebook_instance_lifecycle_config(sagemaker_client): +def test_notebook_instance_lifecycle_config(client): name = "MyLifeCycleConfig" on_create = [{"Content": "Create Script Line 1"}] on_start = [{"Content": "Start Script Line 1"}] - resp = sagemaker_client.create_notebook_instance_lifecycle_config( + resp = client.create_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start ) expected_arn = _get_notebook_instance_lifecycle_arn(name) assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn with pytest.raises(ClientError) as e: - sagemaker_client.create_notebook_instance_lifecycle_config( + client.create_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start, @@ -221,7 +209,7 @@ def test_notebook_instance_lifecycle_config(sagemaker_client): "Notebook Instance Lifecycle Config already exists.)" ) - resp = sagemaker_client.describe_notebook_instance_lifecycle_config( + resp = client.describe_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert resp["NotebookInstanceLifecycleConfigName"] == name @@ -231,12 +219,12 @@ def test_notebook_instance_lifecycle_config(sagemaker_client): assert isinstance(resp["LastModifiedTime"], datetime.datetime) assert isinstance(resp["CreationTime"], datetime.datetime) - sagemaker_client.delete_notebook_instance_lifecycle_config( + client.delete_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) with pytest.raises(ClientError) as e: - sagemaker_client.describe_notebook_instance_lifecycle_config( + client.describe_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert e.value.response["Error"]["Message"].endswith( @@ -244,7 +232,7 @@ def test_notebook_instance_lifecycle_config(sagemaker_client): ) with pytest.raises(ClientError) as e: - sagemaker_client.delete_notebook_instance_lifecycle_config( + client.delete_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert e.value.response["Error"]["Message"].endswith( @@ -252,43 +240,88 @@ def test_notebook_instance_lifecycle_config(sagemaker_client): ) -def test_add_tags_to_notebook(sagemaker_client): +def test_add_tags_to_notebook(client): args = { "NotebookInstanceName": FAKE_NAME_PARAM, "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker_client.create_notebook_instance(**args) + resp = client.create_notebook_instance(**args) resource_arn = resp["NotebookInstanceArn"] tags = [ {"Key": "myKey", "Value": "myValue"}, ] - response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + response = client.add_tags(ResourceArn=resource_arn, Tags=tags) assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 - response = sagemaker_client.list_tags(ResourceArn=resource_arn) + response = client.list_tags(ResourceArn=resource_arn) assert response["Tags"] == tags -def test_delete_tags_from_notebook(sagemaker_client): +def test_delete_tags_from_notebook(client): args = { "NotebookInstanceName": FAKE_NAME_PARAM, "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker_client.create_notebook_instance(**args) + resp = client.create_notebook_instance(**args) resource_arn = resp["NotebookInstanceArn"] tags = [ {"Key": "myKey", "Value": "myValue"}, ] - response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + response = client.add_tags(ResourceArn=resource_arn, Tags=tags) assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 tag_keys = [tag["Key"] for tag in tags] - response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) + response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 - response = sagemaker_client.list_tags(ResourceArn=resource_arn) + response = client.list_tags(ResourceArn=resource_arn) assert response["Tags"] == [] + + +def test_list_notebook_instances(client): + for i in range(3): + args = { + "NotebookInstanceName": f"Name{i}", + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, + "RoleArn": FAKE_ROLE_ARN, + } + client.create_notebook_instance(**args) + + client.stop_notebook_instance(NotebookInstanceName="Name1") + + instances = client.list_notebook_instances()["NotebookInstances"] + assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name1", "Name2"] + + instances = client.list_notebook_instances(SortBy="Status")["NotebookInstances"] + assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name2", "Name1"] + + instances = client.list_notebook_instances(SortOrder="Descending")[ + "NotebookInstances" + ] + assert [i["NotebookInstanceName"] for i in instances] == ["Name2", "Name1", "Name0"] + + instances = client.list_notebook_instances(NameContains="1")["NotebookInstances"] + assert [i["NotebookInstanceName"] for i in instances] == ["Name1"] + + instances = client.list_notebook_instances(StatusEquals="InService")[ + "NotebookInstances" + ] + assert [i["NotebookInstanceName"] for i in instances] == ["Name0", "Name2"] + + instances = client.list_notebook_instances(StatusEquals="Pending")[ + "NotebookInstances" + ] + assert [i["NotebookInstanceName"] for i in instances] == [] + + resp = client.list_notebook_instances(MaxResults=1) + instances = resp["NotebookInstances"] + assert [i["NotebookInstanceName"] for i in instances] == ["Name0"] + + resp = client.list_notebook_instances(NextToken=resp["NextToken"]) + instances = resp["NotebookInstances"] + assert [i["NotebookInstanceName"] for i in instances] == ["Name1", "Name2"] + assert "NextToken" not in resp