4b117c4884
* Added mock for sagemaker describe-pipeline call * Added NotImplementedError for PipelineDefinitionS3Location * Added support for PipelineDefinitionS3Location * Extended unit tests * Moved arn_formatter into utils * Import arn_formatter in test_sagemaker_pipeline * Adding uniqueness check for PipelineName * Removed unused import * Swapped client for s3_backend * Corrected kwarg names * From direct s3_backend to mocked boto call due to strange error * Changed to using s3_backends from mocked boto3 call * Remove unused argument * Black formatting * Delete object and bucket to avoid duplicate bucket names error * Try to fix bucket collisions * Remove unused lines * Switched to mock * SkipTest in server mode * Switched to handling inside to-be tested method * added s3 mock * mock s3 * Change mocking s3 * Removed unnecessary tests * Switch to only s3_backend * Adding skiptest to load from s3
490 lines
16 KiB
Python
490 lines
16 KiB
Python
from contextlib import contextmanager
|
|
from moto import mock_sagemaker, settings
|
|
from time import sleep
|
|
from datetime import datetime
|
|
import boto3
|
|
import botocore
|
|
import json
|
|
import pytest
|
|
from unittest import SkipTest
|
|
|
|
from moto.s3 import mock_s3
|
|
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
|
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
|
|
|
|
# Region used by every mocked client and in every expected ARN in this module.
TEST_REGION_NAME = "us-west-1"

# Fake IAM role ARN (in moto's default account) passed as RoleArn to pipeline calls.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
|
|
|
|
|
@contextmanager
def setup_s3_pipeline_definition(bucket_name, object_key, pipeline_definition):
    """Store a JSON pipeline definition in a fresh S3 bucket for the ``with`` body.

    Creates ``bucket_name`` in ``TEST_REGION_NAME`` and writes
    ``json.dumps(pipeline_definition)`` under ``object_key``. On exit the
    object and bucket are deleted. This happens even when the body raises,
    so a failing test does not leave behind a bucket whose name would clash
    with a later test.

    Meant to be entered inside ``mock_s3()`` so that no real S3 calls are made.

    Args:
        bucket_name: name of the bucket to create.
        object_key: key under which the definition is stored.
        pipeline_definition: JSON-serialisable definition to store.
    """
    # Pin the client to the same region as the LocationConstraint instead of
    # depending on whatever default region the environment provides.
    client = boto3.client("s3", region_name=TEST_REGION_NAME)
    client.create_bucket(
        Bucket=bucket_name,
        CreateBucketConfiguration={"LocationConstraint": TEST_REGION_NAME},
    )
    client.put_object(
        Body=json.dumps(pipeline_definition),
        Bucket=bucket_name,
        Key=object_key,
    )
    try:
        yield
    finally:
        # Clean up even if an assertion inside the ``with`` body failed.
        client.delete_object(Bucket=bucket_name, Key=object_key)
        client.delete_bucket(Bucket=bucket_name)
|
|
|
|
|
|
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Provide a SageMaker client whose calls are served by moto's mock.

    The mock stays active for as long as the requesting test runs.
    """
    sagemaker_mock = mock_sagemaker()
    with sagemaker_mock:
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield client
|
|
|
|
|
|
def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
    """Create each pipeline and return the CreatePipeline responses.

    Args:
        sagemaker_client: client exposing ``create_pipeline(**kwargs)``.
        pipelines: iterable of keyword-argument dicts for ``create_pipeline``.
        wait_seconds: seconds to sleep after each creation. The creation-time
            sorting test passes a positive value, presumably so that the
            creation timestamps differ.

    Returns:
        list: one response dict per pipeline, in creation order.
    """
    responses = []
    for pipeline in pipelines:
        # ``append``, not ``+=``: ``list += dict`` would extend the list with
        # the response's *keys* instead of storing the response itself.
        responses.append(sagemaker_client.create_pipeline(**pipeline))
        sleep(wait_seconds)
    return responses
|
|
|
|
|
|
def test_load_pipeline_definition_from_s3():
    """load_pipeline_definition_from_s3 returns the JSON stored at Bucket/ObjectKey."""
    # The helper under test needs in-process access to the S3 backend, which
    # server mode does not provide (see the skip message).
    if settings.TEST_SERVER_MODE:
        raise SkipTest(
            "Skipping test in server mode due to lack of access to s3_backend."
        )

    bucket_name = "some-bucket-1"
    object_key = "some/object/key.json"
    pipeline_definition = {"key": "value"}

    with mock_s3(), setup_s3_pipeline_definition(
        bucket_name, object_key, pipeline_definition
    ):
        observed_pipeline_definition = load_pipeline_definition_from_s3(
            pipeline_definition_s3_location={
                "Bucket": bucket_name,
                "ObjectKey": object_key,
            },
            account_id=ACCOUNT_ID,
        )

    # Plain assert: this module never imports ``sure``, so ``.should`` only
    # works if some other module happened to import it first.
    assert observed_pipeline_definition == pipeline_definition
|
|
|
|
|
|
def test_create_pipeline(sagemaker_client):
    """CreatePipeline returns a dict whose PipelineArn is built from the name."""
    fake_pipeline_name = "MyPipelineName"
    response = sagemaker_client.create_pipeline(
        PipelineName=fake_pipeline_name,
        RoleArn=FAKE_ROLE_ARN,
        PipelineDefinition=" ",
    )

    assert isinstance(response, dict)
    assert response["PipelineArn"] == arn_formatter(
        "pipeline", fake_pipeline_name, ACCOUNT_ID, TEST_REGION_NAME
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "create_pipeline_kwargs",
    [
        # Neither PipelineDefinition nor PipelineDefinitionS3Location.
        {"PipelineName": "MyPipelineName", "RoleArn": FAKE_ROLE_ARN},
        # Missing PipelineName.
        {"RoleArn": FAKE_ROLE_ARN, "PipelineDefinition": " "},
        # Missing RoleArn.
        {"PipelineName": "MyPipelineName", "PipelineDefinition": " "},
        # Both definition sources, but the S3 location has an unexpected key,
        # so the request presumably fails client-side shape validation.
        {
            "PipelineName": "MyPipelineName",
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
            "PipelineDefinitionS3Location": {"key": "value"},
        },
        # Both definition sources with a well-formed S3 location, so the
        # request reaches the mocked service and exercises the rule that only
        # one definition source may be supplied.
        {
            "PipelineName": "MyPipelineName",
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
            "PipelineDefinitionS3Location": {
                "Bucket": "some-bucket",
                "ObjectKey": "some/object/key.json",
            },
        },
    ],
)
def test_create_pipeline_invalid_required_kwargs(
    sagemaker_client, create_pipeline_kwargs
):
    """Invalid CreatePipeline argument combinations are rejected.

    Both client-side validation errors and service errors are accepted.
    """
    with pytest.raises(
        (
            botocore.exceptions.ParamValidationError,
            botocore.exceptions.ClientError,
        )
    ):
        sagemaker_client.create_pipeline(**create_pipeline_kwargs)
|
|
|
|
|
|
def test_create_pipeline_duplicate_pipeline_name(sagemaker_client):
    """Creating a second pipeline with an existing name raises a ClientError."""
    pipeline_kwargs = {
        "PipelineName": "APipelineName",
        "RoleArn": FAKE_ROLE_ARN,
        "PipelineDefinition": " ",
    }
    # The first creation must succeed. Keeping it outside ``pytest.raises``
    # stops the test from passing because of an unrelated error here.
    sagemaker_client.create_pipeline(**pipeline_kwargs)

    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.create_pipeline(**pipeline_kwargs)
|
|
|
|
|
|
def test_list_pipelines_none(sagemaker_client):
    """ListPipelines with no pipelines created returns an empty summary list."""
    response = sagemaker_client.list_pipelines()

    assert isinstance(response, dict)
    assert response["PipelineSummaries"] == []
|
|
|
|
|
|
def test_list_pipelines_single(sagemaker_client):
    """A single created pipeline is listed with the expected ARN."""
    fake_pipeline_names = ["APipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_names[0],
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        },
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines()
    summaries = response["PipelineSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["PipelineArn"] == arn_formatter(
        "pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME
    )
|
|
|
|
|
|
def test_list_pipelines_multiple(sagemaker_client):
    """Every created pipeline appears in a name-sorted listing."""
    fake_pipeline_names = ["APipelineName", "BPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines(
        SortBy="Name",
        SortOrder="Ascending",
    )
    assert len(response["PipelineSummaries"]) == len(fake_pipeline_names)
|
|
|
|
|
|
def test_list_pipelines_sort_name_ascending(sagemaker_client):
    """SortBy=Name / Ascending lists pipelines in alphabetical order."""
    # Already in ascending order, so this list is also the expected order.
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines(
        SortBy="Name",
        SortOrder="Ascending",
    )
    observed_arns = [summary["PipelineArn"] for summary in response["PipelineSummaries"]]
    expected_arns = [
        arn_formatter("pipeline", name, ACCOUNT_ID, TEST_REGION_NAME)
        for name in fake_pipeline_names
    ]
    assert observed_arns == expected_arns
|
|
|
|
|
|
def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
    """SortBy=CreationTime / Descending lists the newest pipeline first."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    # Pipelines are created in list order, one second apart, presumably so
    # that their creation timestamps are distinct.
    create_sagemaker_pipelines(sagemaker_client, pipelines, 1.0)

    response = sagemaker_client.list_pipelines(
        SortBy="CreationTime",
        SortOrder="Descending",
    )
    observed_arns = [summary["PipelineArn"] for summary in response["PipelineSummaries"]]
    expected_arns = [
        arn_formatter("pipeline", name, ACCOUNT_ID, TEST_REGION_NAME)
        for name in reversed(fake_pipeline_names)
    ]
    assert observed_arns == expected_arns
|
|
|
|
|
|
def test_list_pipelines_max_results(sagemaker_client):
    """MaxResults caps the number of summaries returned."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines(MaxResults=2)
    assert len(response["PipelineSummaries"]) == 2
|
|
|
|
|
|
def test_list_pipelines_next_token(sagemaker_client):
    """A NextToken of "0" still returns the single existing pipeline."""
    fake_pipeline_names = ["APipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_names[0],
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        },
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines(NextToken="0")
    assert len(response["PipelineSummaries"]) == 1
|
|
|
|
|
|
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
    """PipelineNamePrefix filters the listing by pipeline name."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
    summaries = response["PipelineSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["PipelineName"] == "APipelineName"

    # "Pipeline" is not a leading prefix of any name, yet all three are
    # expected: the filter is treated as a substring match.
    response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
    assert len(response["PipelineSummaries"]) == 3
|
|
|
|
|
|
def test_list_pipelines_created_after(sagemaker_client):
    """A far-future CreatedAfter filter matches no pipelines, in any input form."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    created_after_str = "2099-12-31 23:59:59"
    created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S")
    created_after_timestamp = datetime.timestamp(created_after_datetime)

    # The same far-future moment passed as a string, a datetime and an epoch
    # timestamp; each form must produce an empty listing.
    for created_after in (
        created_after_str,
        created_after_datetime,
        created_after_timestamp,
    ):
        response = sagemaker_client.list_pipelines(CreatedAfter=created_after)
        assert response["PipelineSummaries"] == []
|
|
|
|
|
|
def test_list_pipelines_created_before(sagemaker_client):
    """A long-past CreatedBefore filter matches no pipelines, in any input form."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)

    created_before_str = "2000-12-31 23:59:59"
    created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S")
    created_before_timestamp = datetime.timestamp(created_before_datetime)

    # The same long-past moment passed as a string, a datetime and an epoch
    # timestamp; each form must produce an empty listing.
    for created_before in (
        created_before_str,
        created_before_datetime,
        created_before_timestamp,
    ):
        response = sagemaker_client.list_pipelines(CreatedBefore=created_before)
        assert response["PipelineSummaries"] == []
|
|
|
|
|
|
_INVALID_LIST_PIPELINES_KWARGS = [
    {"MaxResults": 200},
    {"NextToken": "some-invalid-next-token"},
    {"SortOrder": "some-invalid-sort-order"},
    {"SortBy": "some-invalid-sort-by"},
]


@pytest.mark.parametrize("list_pipelines_kwargs", _INVALID_LIST_PIPELINES_KWARGS)
def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
    """Each invalid ListPipelines argument is rejected with a ClientError."""
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.list_pipelines(**list_pipelines_kwargs)
|
|
|
|
|
|
def test_delete_pipeline_exists(sagemaker_client):
    """Deleting one pipeline removes it and leaves the others listed."""
    fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
    pipelines = [
        {
            "PipelineName": fake_pipeline_name,
            "RoleArn": FAKE_ROLE_ARN,
            "PipelineDefinition": " ",
        }
        for fake_pipeline_name in fake_pipeline_names
    ]
    create_sagemaker_pipelines(sagemaker_client, pipelines)
    pipeline_name_delete, *pipeline_names_remain = fake_pipeline_names

    response = sagemaker_client.delete_pipeline(PipelineName=pipeline_name_delete)
    assert response["PipelineArn"] == arn_formatter(
        "pipeline", pipeline_name_delete, ACCOUNT_ID, TEST_REGION_NAME
    )

    response = sagemaker_client.list_pipelines(PipelineNamePrefix=pipeline_name_delete)
    assert response["PipelineSummaries"] == []

    response = sagemaker_client.list_pipelines()
    pipeline_names_exist = {
        pipeline["PipelineName"] for pipeline in response["PipelineSummaries"]
    }
    assert pipeline_names_exist == set(pipeline_names_remain)
|
|
|
|
|
|
def test_delete_pipeline_not_exists(sagemaker_client):
    """Deleting a pipeline that was never created raises a ClientError."""
    missing_pipeline_name = "some-pipeline-name"
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.delete_pipeline(PipelineName=missing_pipeline_name)
|
|
|
|
|
|
def test_update_pipeline_not_exists(sagemaker_client):
    """Updating a pipeline that was never created raises a ClientError."""
    missing_pipeline_name = "some-pipeline-name"
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.update_pipeline(PipelineName=missing_pipeline_name)
|
|
|
|
|
|
def test_update_pipeline_invalid_kwargs(sagemaker_client):
    """An unknown UpdatePipeline argument raises botocore's ParamValidationError."""
    pipeline_name = "APipelineName"
    create_sagemaker_pipelines(
        sagemaker_client,
        [
            {
                "PipelineName": pipeline_name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            }
        ],
    )

    unknown_kwargs = {"InvalidKwarg": "some-value"}
    with pytest.raises(botocore.exceptions.ParamValidationError):
        sagemaker_client.update_pipeline(PipelineName=pipeline_name, **unknown_kwargs)
|
|
|
|
|
|
def test_update_pipeline_no_update(sagemaker_client):
    """UpdatePipeline with only a name returns the ARN and keeps the pipeline."""
    pipeline_name = "APipelineName"
    pipeline = {
        "PipelineName": pipeline_name,
        "RoleArn": FAKE_ROLE_ARN,
        "PipelineDefinition": " ",
    }
    create_sagemaker_pipelines(sagemaker_client, [pipeline])

    response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
    assert response["PipelineArn"] == arn_formatter(
        "pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME
    )

    response = sagemaker_client.list_pipelines()
    assert response["PipelineSummaries"][0]["PipelineName"] == pipeline_name
|
|
|
|
|
|
def test_update_pipeline_add_attribute(sagemaker_client):
    """UpdatePipeline can set PipelineDisplayName after creation."""
    pipeline_name = "APipelineName"
    pipeline_display_name_update = "APipelineDisplayName"

    pipeline = {
        "PipelineName": pipeline_name,
        "RoleArn": FAKE_ROLE_ARN,
        "PipelineDefinition": " ",
    }
    create_sagemaker_pipelines(sagemaker_client, [pipeline])

    # Before any update, the listed display name equals the pipeline name.
    response = sagemaker_client.list_pipelines()
    assert response["PipelineSummaries"][0]["PipelineDisplayName"] == pipeline_name

    sagemaker_client.update_pipeline(
        PipelineName=pipeline_name,
        PipelineDisplayName=pipeline_display_name_update,
    )

    response = sagemaker_client.list_pipelines()
    summary = response["PipelineSummaries"][0]
    assert summary["PipelineDisplayName"] == pipeline_display_name_update
    # The summary still carries exactly six fields after the update.
    assert len(summary) == 6
|
|
|
|
|
|
def test_update_pipeline_update_change_attribute(sagemaker_client):
    """UpdatePipeline replaces an existing RoleArn."""
    pipeline_name = "APipelineName"
    role_arn_update = f"{FAKE_ROLE_ARN}Test"
    pipeline = {
        "PipelineName": pipeline_name,
        "RoleArn": FAKE_ROLE_ARN,
        "PipelineDefinition": " ",
    }
    create_sagemaker_pipelines(sagemaker_client, [pipeline])

    sagemaker_client.update_pipeline(
        PipelineName=pipeline_name,
        RoleArn=role_arn_update,
    )

    response = sagemaker_client.list_pipelines()
    summary = response["PipelineSummaries"][0]
    assert summary["RoleArn"] == role_arn_update
    # The summary still carries exactly six fields after the update.
    assert len(summary) == 6
|
|
|
|
|
|
def test_describe_pipeline_not_exists(sagemaker_client):
    """Describing a pipeline that was never created raises a ClientError."""
    missing_pipeline_name = "some-pipeline-name"
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.describe_pipeline(PipelineName=missing_pipeline_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "pipeline,expected_response_length",
    [
        (
            {
                "PipelineName": "APipelineName",
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            },
            11,
        ),
        # Same as above plus PipelineDescription, so one more top-level key.
        (
            {
                "PipelineName": "BPipelineName",
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
                "PipelineDescription": "some pipeline description",
            },
            12,
        ),
    ],
)
def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length):
    """DescribePipeline returns the expected number of top-level keys."""
    create_sagemaker_pipelines(sagemaker_client, [pipeline])

    response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"])
    assert len(response) == expected_response_length
|