SageMaker: create_pipeline, list_pipelines (#5771)
This commit is contained in:
parent
2cf770f697
commit
e5d40f63f8
@ -64,7 +64,7 @@ sagemaker
|
||||
- [ ] create_monitoring_schedule
|
||||
- [X] create_notebook_instance
|
||||
- [X] create_notebook_instance_lifecycle_config
|
||||
- [ ] create_pipeline
|
||||
- [X] create_pipeline
|
||||
- [ ] create_presigned_domain_url
|
||||
- [ ] create_presigned_notebook_instance_url
|
||||
- [X] create_processing_job
|
||||
@ -222,7 +222,7 @@ sagemaker
|
||||
- [ ] list_pipeline_execution_steps
|
||||
- [ ] list_pipeline_executions
|
||||
- [ ] list_pipeline_parameters_for_execution
|
||||
- [ ] list_pipelines
|
||||
- [X] list_pipelines
|
||||
- [X] list_processing_jobs
|
||||
- [ ] list_projects
|
||||
- [ ] list_stage_devices
|
||||
|
@ -74,6 +74,38 @@ class BaseObject(BaseModel):
|
||||
return self.gen_response_object()
|
||||
|
||||
|
||||
class FakePipeline(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_name,
|
||||
pipeline_display_name,
|
||||
pipeline_definition,
|
||||
pipeline_description,
|
||||
role_arn,
|
||||
tags,
|
||||
account_id,
|
||||
region_name,
|
||||
parallelism_configuration,
|
||||
pipeline_definition_s3_location,
|
||||
):
|
||||
self.pipeline_name = pipeline_name
|
||||
self.pipeline_arn = arn_formatter(
|
||||
"pipeline", pipeline_name, account_id, region_name
|
||||
)
|
||||
self.pipeline_display_name = pipeline_display_name
|
||||
self.pipeline_definition = pipeline_definition
|
||||
self.pipeline_description = pipeline_description
|
||||
self.role_arn = role_arn
|
||||
self.tags = tags or []
|
||||
self.parallelism_configuration = parallelism_configuration
|
||||
self.pipeline_definition_s3_location = pipeline_definition_s3_location
|
||||
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = now_string
|
||||
self.last_modified_time = now_string
|
||||
self.last_execution_time = now_string
|
||||
|
||||
|
||||
class FakeProcessingJob(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1039,6 +1071,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.endpoint_configs = {}
|
||||
self.endpoints = {}
|
||||
self.experiments = {}
|
||||
self.pipelines = {}
|
||||
self.processing_jobs = {}
|
||||
self.trials = {}
|
||||
self.trial_components = {}
|
||||
@ -1160,6 +1193,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
"experiment-trial": self.trials,
|
||||
"experiment-trial-component": self.trial_components,
|
||||
"processing-job": self.processing_jobs,
|
||||
"pipeline": self.pipelines,
|
||||
}
|
||||
target_resource, target_name = arn.split(":")[-1].split("/")
|
||||
try:
|
||||
@ -1713,6 +1747,116 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=f"Could not find processing job '{arn}'.")
|
||||
|
||||
def create_pipeline(
|
||||
self,
|
||||
pipeline_name,
|
||||
pipeline_display_name,
|
||||
pipeline_definition,
|
||||
pipeline_definition_s3_location,
|
||||
pipeline_description,
|
||||
role_arn,
|
||||
tags,
|
||||
parallelism_configuration,
|
||||
):
|
||||
pipeline = FakePipeline(
|
||||
pipeline_name,
|
||||
pipeline_display_name,
|
||||
pipeline_definition,
|
||||
pipeline_description,
|
||||
role_arn,
|
||||
tags,
|
||||
self.account_id,
|
||||
self.region_name,
|
||||
pipeline_definition_s3_location,
|
||||
parallelism_configuration,
|
||||
)
|
||||
|
||||
self.pipelines[pipeline_name] = pipeline
|
||||
return pipeline
|
||||
|
||||
def list_pipelines(
|
||||
self,
|
||||
pipeline_name_prefix,
|
||||
created_after,
|
||||
created_before,
|
||||
next_token,
|
||||
max_results,
|
||||
sort_by,
|
||||
sort_order,
|
||||
):
|
||||
if next_token:
|
||||
try:
|
||||
starting_index = int(next_token)
|
||||
if starting_index > len(self.pipelines):
|
||||
raise ValueError # invalid next_token
|
||||
except ValueError:
|
||||
raise AWSValidationException('Invalid pagination token because "{0}".')
|
||||
else:
|
||||
starting_index = 0
|
||||
|
||||
if max_results:
|
||||
end_index = max_results + starting_index
|
||||
pipelines_fetched = list(self.pipelines.values())[starting_index:end_index]
|
||||
if end_index >= len(self.pipelines):
|
||||
next_index = None
|
||||
else:
|
||||
next_index = end_index
|
||||
else:
|
||||
pipelines_fetched = list(self.pipelines.values())
|
||||
next_index = None
|
||||
|
||||
if pipeline_name_prefix is not None:
|
||||
pipelines_fetched = filter(
|
||||
lambda x: pipeline_name_prefix in x.pipeline_name,
|
||||
pipelines_fetched,
|
||||
)
|
||||
|
||||
def format_time(x):
|
||||
return (
|
||||
x
|
||||
if isinstance(x, str)
|
||||
else datetime.fromtimestamp(x).strftime("%Y-%m-%d " "%H:%M:%S")
|
||||
)
|
||||
|
||||
if created_after is not None:
|
||||
pipelines_fetched = filter(
|
||||
lambda x: x.creation_time > format_time(created_after),
|
||||
pipelines_fetched,
|
||||
)
|
||||
|
||||
if created_before is not None:
|
||||
pipelines_fetched = filter(
|
||||
lambda x: x.creation_time < format_time(created_before),
|
||||
pipelines_fetched,
|
||||
)
|
||||
|
||||
sort_key = "pipeline_name" if sort_by == "Name" else "creation_time"
|
||||
sort_order = False if sort_order == "Ascending" else True
|
||||
pipelines_fetched = sorted(
|
||||
pipelines_fetched,
|
||||
key=lambda pipeline_fetched: getattr(pipeline_fetched, sort_key),
|
||||
reverse=sort_order,
|
||||
)
|
||||
|
||||
pipeline_summaries = [
|
||||
{
|
||||
"PipelineArn": pipeline_data.pipeline_arn,
|
||||
"PipelineName": pipeline_data.pipeline_name,
|
||||
"PipelineDisplayName": pipeline_data.pipeline_display_name,
|
||||
"PipelineDescription": pipeline_data.pipeline_description,
|
||||
"RoleArn": pipeline_data.role_arn,
|
||||
"CreationTime": pipeline_data.creation_time,
|
||||
"LastModifiedTime": pipeline_data.last_modified_time,
|
||||
"LastExecutionTime": pipeline_data.last_execution_time,
|
||||
}
|
||||
for pipeline_data in pipelines_fetched
|
||||
]
|
||||
|
||||
return {
|
||||
"PipelineSummaries": pipeline_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def list_processing_jobs(
|
||||
self,
|
||||
next_token,
|
||||
|
@ -465,6 +465,69 @@ class SageMakerResponse(BaseResponse):
|
||||
response = self.sagemaker_backend.list_associations(self.request_params)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_pipeline(self):
|
||||
pipeline = self.sagemaker_backend.create_pipeline(
|
||||
pipeline_name=self._get_param("PipelineName"),
|
||||
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
||||
pipeline_definition=self._get_param("PipelineDefinition"),
|
||||
pipeline_definition_s3_location=self._get_param(
|
||||
"PipelineDefinitionS3Location"
|
||||
),
|
||||
pipeline_description=self._get_param("PipelineDescription"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
tags=self._get_param("Tags"),
|
||||
parallelism_configuration=self._get_param("ParallelismConfiguration"),
|
||||
)
|
||||
response = {
|
||||
"PipelineArn": pipeline.pipeline_arn,
|
||||
}
|
||||
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_pipelines(self):
|
||||
max_results_range = range(1, 101)
|
||||
allowed_sort_by = ("Name", "CreationTime")
|
||||
allowed_sort_order = ("Ascending", "Descending")
|
||||
|
||||
pipeline_name_prefix = self._get_param("PipelineNamePrefix")
|
||||
created_after = self._get_param("CreatedAfter")
|
||||
created_before = self._get_param("CreatedBefore")
|
||||
sort_by = self._get_param("SortBy", "CreationTime")
|
||||
sort_order = self._get_param("SortOrder", "Descending")
|
||||
next_token = self._get_param("NextToken")
|
||||
max_results = self._get_param("MaxResults")
|
||||
|
||||
errors = []
|
||||
if max_results and max_results not in max_results_range:
|
||||
errors.append(
|
||||
f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}"
|
||||
)
|
||||
|
||||
if sort_by not in allowed_sort_by:
|
||||
errors.append(format_enum_error(sort_by, "SortBy", allowed_sort_by))
|
||||
if sort_order not in allowed_sort_order:
|
||||
errors.append(
|
||||
format_enum_error(sort_order, "SortOrder", allowed_sort_order)
|
||||
)
|
||||
if errors:
|
||||
raise AWSValidationException(
|
||||
f"{len(errors)} validation errors detected: {';'.join(errors)}"
|
||||
)
|
||||
|
||||
response = self.sagemaker_backend.list_pipelines(
|
||||
pipeline_name_prefix=pipeline_name_prefix,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
next_token=next_token,
|
||||
max_results=max_results,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_processing_jobs(self):
|
||||
max_results_range = range(1, 101)
|
||||
|
180
tests/test_sagemaker/test_sagemaker_pipeline.py
Normal file
180
tests/test_sagemaker/test_sagemaker_pipeline.py
Normal file
@ -0,0 +1,180 @@
|
||||
from moto import mock_sagemaker
|
||||
from time import sleep
|
||||
from datetime import datetime
|
||||
import boto3
|
||||
import botocore
|
||||
import pytest
|
||||
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
|
||||
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
def arn_formatter(_type, _id, account_id, region_name):
|
||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||
|
||||
|
||||
@pytest.fixture(name="sagemaker_client")
|
||||
def fixture_sagemaker_client():
|
||||
with mock_sagemaker():
|
||||
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
def create_sagemaker_pipelines(sagemaker_client, pipeline_names, wait_seconds=0.0):
|
||||
responses = []
|
||||
for pipeline_name in pipeline_names:
|
||||
responses += sagemaker_client.create_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
)
|
||||
sleep(wait_seconds)
|
||||
return responses
|
||||
|
||||
|
||||
def test_create_pipeline(sagemaker_client):
|
||||
fake_pipeline_name = "MyPipelineName"
|
||||
response = sagemaker_client.create_pipeline(
|
||||
PipelineName=fake_pipeline_name,
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
response["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
|
||||
def test_list_pipelines_none(sagemaker_client):
|
||||
response = sagemaker_client.list_pipelines()
|
||||
assert isinstance(response, dict)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
|
||||
def test_list_pipelines_single(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
|
||||
def test_list_pipelines_multiple(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
)
|
||||
response["PipelineSummaries"].should.have.length_of(len(fake_pipeline_names))
|
||||
|
||||
|
||||
def test_list_pipelines_sort_name_ascending(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][-1]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
|
||||
def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 1)
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="CreationTime",
|
||||
SortOrder="Descending",
|
||||
)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[-1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][1]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[1], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response["PipelineSummaries"][2]["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", fake_pipeline_names[0], ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
|
||||
|
||||
def test_list_pipelines_max_results(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
response = sagemaker_client.list_pipelines(MaxResults=2)
|
||||
response["PipelineSummaries"].should.have.length_of(2)
|
||||
|
||||
|
||||
def test_list_pipelines_next_token(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
|
||||
response = sagemaker_client.list_pipelines(NextToken="0")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
|
||||
|
||||
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
|
||||
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
|
||||
response["PipelineSummaries"].should.have.length_of(3)
|
||||
|
||||
|
||||
def test_list_pipelines_created_after(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
|
||||
created_after_str = "2099-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
created_after_datetime = datetime.strptime(created_after_str, "%Y-%m-%d %H:%M:%S")
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_datetime)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
created_after_timestamp = datetime.timestamp(created_after_datetime)
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_timestamp)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
|
||||
def test_list_pipelines_created_before(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
|
||||
created_before_str = "2000-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
created_before_datetime = datetime.strptime(created_before_str, "%Y-%m-%d %H:%M:%S")
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_datetime)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
created_before_timestamp = datetime.timestamp(created_before_datetime)
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_timestamp)
|
||||
assert response["PipelineSummaries"].should.be.empty
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"list_pipelines_kwargs",
|
||||
[
|
||||
{"MaxResults": 200},
|
||||
{"NextToken": "some-invalid-next-token"},
|
||||
{"SortOrder": "some-invalid-sort-order"},
|
||||
{"SortBy": "some-invalid-sort-by"},
|
||||
],
|
||||
)
|
||||
def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.list_pipelines(**list_pipelines_kwargs)
|
Loading…
Reference in New Issue
Block a user