"""Tests for the mocked SageMaker pipeline APIs (create_pipeline / list_pipelines)."""
from datetime import datetime
from time import sleep

import boto3
import botocore
import pytest
from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID

FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"


def arn_formatter(_type, _id, account_id, region_name):
    """Build a SageMaker ARN string from its resource type, id, account and region."""
    return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"


def _pipeline_arn(pipeline_name):
    """Return the ARN expected for a pipeline created in the test account/region."""
    return arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)


@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a boto3 SageMaker client that lives inside a ``mock_sagemaker`` context."""
    with mock_sagemaker():
        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)


def create_sagemaker_pipelines(sagemaker_client, pipeline_names, wait_seconds=0.0):
    """Create one pipeline per name and return the ``create_pipeline`` responses.

    Args:
        sagemaker_client: client object exposing ``create_pipeline``.
        pipeline_names: iterable of pipeline names to create, in order.
        wait_seconds: seconds to sleep after each creation; tests that sort by
            creation time pass a non-zero value to space the pipelines apart.

    Returns:
        A list with one response per created pipeline, in creation order.
    """
    responses = []
    for pipeline_name in pipeline_names:
        # ``append`` (not ``+=``): extending a list with a dict would add the
        # dict's keys instead of the response itself.
        responses.append(
            sagemaker_client.create_pipeline(
                PipelineName=pipeline_name,
                RoleArn=FAKE_ROLE_ARN,
            )
        )
        sleep(wait_seconds)
    return responses


def test_create_pipeline(sagemaker_client):
    """create_pipeline returns a dict carrying the expected pipeline ARN."""
    fake_pipeline_name = "MyPipelineName"
    response = sagemaker_client.create_pipeline(
        PipelineName=fake_pipeline_name,
        RoleArn=FAKE_ROLE_ARN,
    )
    assert isinstance(response, dict)
    assert response["PipelineArn"] == _pipeline_arn(fake_pipeline_name)


def test_list_pipelines_none(sagemaker_client):
    """list_pipelines returns an empty summary list when nothing was created."""
    response = sagemaker_client.list_pipelines()
    assert isinstance(response, dict)
    assert response["PipelineSummaries"] == []


def test_list_pipelines_single(sagemaker_client):
    """A single created pipeline is the only entry returned by list_pipelines."""
    fake_pipeline_names = ["APipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
    response = sagemaker_client.list_pipelines()
    summaries = response["PipelineSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[0])


def test_list_pipelines_multiple(sagemaker_client):
    """Every created pipeline shows up in the listing."""
    fake_pipeline_names = ["APipelineName", "BPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
    response = sagemaker_client.list_pipelines(
        SortBy="Name",
        SortOrder="Ascending",
    )
    assert len(response["PipelineSummaries"]) == len(fake_pipeline_names)


def test_list_pipelines_sort_name_ascending(sagemaker_client):
    """SortBy=Name with SortOrder=Ascending lists pipelines in name order."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
    response = sagemaker_client.list_pipelines(
        SortBy="Name",
        SortOrder="Ascending",
    )
    summaries = response["PipelineSummaries"]
    assert summaries[0]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[0])
    assert summaries[-1]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[-1])
    assert summaries[1]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[1])


def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
    """SortBy=CreationTime with SortOrder=Descending lists the newest pipeline first."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    # One second between creations so the creation times are distinguishable.
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 1)
    response = sagemaker_client.list_pipelines(
        SortBy="CreationTime",
        SortOrder="Descending",
    )
    summaries = response["PipelineSummaries"]
    assert summaries[0]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[-1])
    assert summaries[1]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[1])
    assert summaries[2]["PipelineArn"] == _pipeline_arn(fake_pipeline_names[0])


def test_list_pipelines_max_results(sagemaker_client):
    """MaxResults caps the number of returned summaries."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
    response = sagemaker_client.list_pipelines(MaxResults=2)
    assert len(response["PipelineSummaries"]) == 2


def test_list_pipelines_next_token(sagemaker_client):
    """NextToken="0" is accepted and returns the single created pipeline."""
    fake_pipeline_names = ["APipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
    response = sagemaker_client.list_pipelines(NextToken="0")
    assert len(response["PipelineSummaries"]) == 1


def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
    """PipelineNamePrefix filters the listing by pipeline name."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)

    response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
    summaries = response["PipelineSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["PipelineName"] == "APipelineName"

    # NOTE(review): "Pipeline" is not a literal prefix of any created name, yet
    # three matches are expected -- presumably the mock matches substrings;
    # confirm against the list_pipelines implementation under test.
    response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
    assert len(response["PipelineSummaries"]) == 3


def test_list_pipelines_created_after(sagemaker_client):
    """A far-future CreatedAfter yields no results for str, datetime and timestamp inputs."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)

    created_after_str = "2099-12-31 23:59:59"
    response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
    assert response["PipelineSummaries"] == []

    created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S")
    response = sagemaker_client.list_pipelines(CreatedAfter=created_after_datetime)
    assert response["PipelineSummaries"] == []

    created_after_timestamp = datetime.timestamp(created_after_datetime)
    response = sagemaker_client.list_pipelines(CreatedAfter=created_after_timestamp)
    assert response["PipelineSummaries"] == []


def test_list_pipelines_created_before(sagemaker_client):
    """A long-past CreatedBefore yields no results for str, datetime and timestamp inputs."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)

    created_before_str = "2000-12-31 23:59:59"
    response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
    assert response["PipelineSummaries"] == []

    created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S")
    response = sagemaker_client.list_pipelines(CreatedBefore=created_before_datetime)
    assert response["PipelineSummaries"] == []

    created_before_timestamp = datetime.timestamp(created_before_datetime)
    response = sagemaker_client.list_pipelines(CreatedBefore=created_before_timestamp)
    assert response["PipelineSummaries"] == []


@pytest.mark.parametrize(
    "list_pipelines_kwargs",
    [
        {"MaxResults": 200},
        {"NextToken": "some-invalid-next-token"},
        {"SortOrder": "some-invalid-sort-order"},
        {"SortBy": "some-invalid-sort-by"},
    ],
)
def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
    """Each parametrized argument set is expected to make list_pipelines raise ClientError."""
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.list_pipelines(**list_pipelines_kwargs)