SageMaker: create_pipeline, list_pipelines (#5771)

This commit is contained in:
sist 2022-12-16 19:24:14 +01:00 committed by GitHub
parent 2cf770f697
commit e5d40f63f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 389 additions and 2 deletions

View File

@ -64,7 +64,7 @@ sagemaker
- [ ] create_monitoring_schedule
- [X] create_notebook_instance
- [X] create_notebook_instance_lifecycle_config
- [ ] create_pipeline
- [X] create_pipeline
- [ ] create_presigned_domain_url
- [ ] create_presigned_notebook_instance_url
- [X] create_processing_job
@ -222,7 +222,7 @@ sagemaker
- [ ] list_pipeline_execution_steps
- [ ] list_pipeline_executions
- [ ] list_pipeline_parameters_for_execution
- [ ] list_pipelines
- [X] list_pipelines
- [X] list_processing_jobs
- [ ] list_projects
- [ ] list_stage_devices

View File

@ -74,6 +74,38 @@ class BaseObject(BaseModel):
return self.gen_response_object()
class FakePipeline(BaseObject):
    """In-memory model of a SageMaker Pipeline held by moto's mock backend."""

    def __init__(
        self,
        pipeline_name,
        pipeline_display_name,
        pipeline_definition,
        pipeline_description,
        role_arn,
        tags,
        account_id,
        region_name,
        parallelism_configuration,
        pipeline_definition_s3_location,
    ):
        self.pipeline_name = pipeline_name
        # ARN is derived from the pipeline name plus the owning account/region.
        self.pipeline_arn = arn_formatter(
            "pipeline", pipeline_name, account_id, region_name
        )
        self.pipeline_display_name = pipeline_display_name
        self.pipeline_definition = pipeline_definition
        self.pipeline_description = pipeline_description
        self.role_arn = role_arn
        # Normalize a missing/None tag list to an empty list.
        self.tags = tags or []
        self.parallelism_configuration = parallelism_configuration
        self.pipeline_definition_s3_location = pipeline_definition_s3_location
        # Timestamps are stored as "YYYY-MM-DD HH:MM:SS" strings (one-second
        # resolution); the backend's list_pipelines filters compare these
        # strings lexicographically.
        now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = now_string
        self.last_modified_time = now_string
        self.last_execution_time = now_string
class FakeProcessingJob(BaseObject):
def __init__(
self,
@ -1039,6 +1071,7 @@ class SageMakerModelBackend(BaseBackend):
self.endpoint_configs = {}
self.endpoints = {}
self.experiments = {}
self.pipelines = {}
self.processing_jobs = {}
self.trials = {}
self.trial_components = {}
@ -1160,6 +1193,7 @@ class SageMakerModelBackend(BaseBackend):
"experiment-trial": self.trials,
"experiment-trial-component": self.trial_components,
"processing-job": self.processing_jobs,
"pipeline": self.pipelines,
}
target_resource, target_name = arn.split(":")[-1].split("/")
try:
@ -1713,6 +1747,116 @@ class SageMakerModelBackend(BaseBackend):
)
raise ValidationError(message=f"Could not find processing job '{arn}'.")
def create_pipeline(
    self,
    pipeline_name,
    pipeline_display_name,
    pipeline_definition,
    pipeline_definition_s3_location,
    pipeline_description,
    role_arn,
    tags,
    parallelism_configuration,
):
    """Create a FakePipeline, register it under its name, and return it.

    :param pipeline_name: unique pipeline name (also the registry key).
    :param tags: optional list of tag dicts; FakePipeline normalizes None.
    :returns: the stored FakePipeline instance.
    """
    # Use keyword arguments: the original positional call swapped
    # parallelism_configuration and pipeline_definition_s3_location,
    # storing each value on the wrong attribute.
    pipeline = FakePipeline(
        pipeline_name=pipeline_name,
        pipeline_display_name=pipeline_display_name,
        pipeline_definition=pipeline_definition,
        pipeline_description=pipeline_description,
        role_arn=role_arn,
        tags=tags,
        account_id=self.account_id,
        region_name=self.region_name,
        parallelism_configuration=parallelism_configuration,
        pipeline_definition_s3_location=pipeline_definition_s3_location,
    )
    self.pipelines[pipeline_name] = pipeline
    return pipeline
def list_pipelines(
    self,
    pipeline_name_prefix,
    created_after,
    created_before,
    next_token,
    max_results,
    sort_by,
    sort_order,
):
    """Return a page of pipeline summaries, filtered and sorted.

    :param next_token: opaque pagination token (stringified start index).
    :param max_results: page size; None returns everything.
    :param sort_by: "Name" sorts by pipeline_name, anything else by creation_time.
    :param sort_order: "Ascending" or anything else (treated as descending).
    :raises AWSValidationException: if next_token is not a valid index.
    :returns: dict with "PipelineSummaries" and "NextToken" keys.
    """
    if next_token:
        try:
            starting_index = int(next_token)
            if starting_index > len(self.pipelines):
                raise ValueError  # invalid next_token
        except ValueError:
            # f-string fix: the original passed a literal "{0}" placeholder
            # that was never substituted with the offending token.
            raise AWSValidationException(
                f'Invalid pagination token because "{next_token}".'
            )
    else:
        starting_index = 0

    # NOTE: pagination slices BEFORE filtering/sorting, mirroring the
    # original implementation's ordering of operations.
    if max_results:
        end_index = max_results + starting_index
        pipelines_fetched = list(self.pipelines.values())[starting_index:end_index]
        if end_index >= len(self.pipelines):
            next_index = None
        else:
            next_index = end_index
    else:
        pipelines_fetched = list(self.pipelines.values())
        next_index = None

    if pipeline_name_prefix is not None:
        pipelines_fetched = filter(
            lambda x: pipeline_name_prefix in x.pipeline_name,
            pipelines_fetched,
        )

    def format_time(x):
        # Accept either a preformatted string or a POSIX timestamp; pipeline
        # creation_time is stored as a "YYYY-MM-DD HH:MM:SS" string, so the
        # comparisons below are lexicographic and order-consistent.
        return (
            x
            if isinstance(x, str)
            else datetime.fromtimestamp(x).strftime("%Y-%m-%d %H:%M:%S")
        )

    if created_after is not None:
        pipelines_fetched = filter(
            lambda x: x.creation_time > format_time(created_after),
            pipelines_fetched,
        )
    if created_before is not None:
        pipelines_fetched = filter(
            lambda x: x.creation_time < format_time(created_before),
            pipelines_fetched,
        )

    sort_key = "pipeline_name" if sort_by == "Name" else "creation_time"
    sort_order = False if sort_order == "Ascending" else True
    pipelines_fetched = sorted(
        pipelines_fetched,
        key=lambda pipeline_fetched: getattr(pipeline_fetched, sort_key),
        reverse=sort_order,
    )

    pipeline_summaries = [
        {
            "PipelineArn": pipeline_data.pipeline_arn,
            "PipelineName": pipeline_data.pipeline_name,
            "PipelineDisplayName": pipeline_data.pipeline_display_name,
            "PipelineDescription": pipeline_data.pipeline_description,
            "RoleArn": pipeline_data.role_arn,
            "CreationTime": pipeline_data.creation_time,
            "LastModifiedTime": pipeline_data.last_modified_time,
            "LastExecutionTime": pipeline_data.last_execution_time,
        }
        for pipeline_data in pipelines_fetched
    ]

    return {
        "PipelineSummaries": pipeline_summaries,
        "NextToken": str(next_index) if next_index is not None else None,
    }
def list_processing_jobs(
self,
next_token,

View File

@ -465,6 +465,69 @@ class SageMakerResponse(BaseResponse):
response = self.sagemaker_backend.list_associations(self.request_params)
return 200, {}, json.dumps(response)
@amzn_request_id
def create_pipeline(self):
    """Handle the CreatePipeline API call: create the pipeline in the
    backend and respond with its ARN."""
    created = self.sagemaker_backend.create_pipeline(
        pipeline_name=self._get_param("PipelineName"),
        pipeline_display_name=self._get_param("PipelineDisplayName"),
        pipeline_definition=self._get_param("PipelineDefinition"),
        pipeline_definition_s3_location=self._get_param("PipelineDefinitionS3Location"),
        pipeline_description=self._get_param("PipelineDescription"),
        role_arn=self._get_param("RoleArn"),
        tags=self._get_param("Tags"),
        parallelism_configuration=self._get_param("ParallelismConfiguration"),
    )
    return 200, {}, json.dumps({"PipelineArn": created.pipeline_arn})
@amzn_request_id
def list_pipelines(self):
    """Handle the ListPipelines API call: validate query parameters, then
    delegate filtering/sorting/pagination to the backend."""
    valid_max_results = range(1, 101)
    valid_sort_by = ("Name", "CreationTime")
    valid_sort_order = ("Ascending", "Descending")

    next_token = self._get_param("NextToken")
    max_results = self._get_param("MaxResults")
    pipeline_name_prefix = self._get_param("PipelineNamePrefix")
    created_after = self._get_param("CreatedAfter")
    created_before = self._get_param("CreatedBefore")
    sort_by = self._get_param("SortBy", "CreationTime")
    sort_order = self._get_param("SortOrder", "Descending")

    # Collect every validation failure so a single exception reports them all.
    errors = []
    if max_results and max_results not in valid_max_results:
        errors.append(
            f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {valid_max_results[-1]}"
        )
    if sort_by not in valid_sort_by:
        errors.append(format_enum_error(sort_by, "SortBy", valid_sort_by))
    if sort_order not in valid_sort_order:
        errors.append(format_enum_error(sort_order, "SortOrder", valid_sort_order))
    if errors:
        raise AWSValidationException(
            f"{len(errors)} validation errors detected: {';'.join(errors)}"
        )

    response = self.sagemaker_backend.list_pipelines(
        pipeline_name_prefix=pipeline_name_prefix,
        created_after=created_after,
        created_before=created_before,
        next_token=next_token,
        max_results=max_results,
        sort_by=sort_by,
        sort_order=sort_order,
    )
    return 200, {}, json.dumps(response)
@amzn_request_id
def list_processing_jobs(self):
max_results_range = range(1, 101)

View File

@ -0,0 +1,180 @@
from moto import mock_sagemaker
from time import sleep
from datetime import datetime
import boto3
import botocore
import pytest
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
# Role ARN passed to every CreatePipeline call; moto does not validate it.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# Region used for the mocked SageMaker client and for expected ARNs.
TEST_REGION_NAME = "us-east-1"
def arn_formatter(_type, _id, account_id, region_name):
    """Build the expected SageMaker ARN for a resource type and identifier."""
    return "arn:aws:sagemaker:{}:{}:{}/{}".format(region_name, account_id, _type, _id)
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a boto3 SageMaker client backed by moto's in-memory mock."""
    with mock_sagemaker():
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield client
def create_sagemaker_pipelines(sagemaker_client, pipeline_names, wait_seconds=0.0):
    """Create one pipeline per name and return the list of API responses.

    :param wait_seconds: sleep between creations so creation timestamps differ
        (the backend stores them with one-second resolution).
    :returns: list of CreatePipeline response dicts, in creation order.
    """
    responses = []
    for pipeline_name in pipeline_names:
        # append, not "+=": extending a list with a dict response would add
        # the dict's KEYS to the list instead of the response objects.
        responses.append(
            sagemaker_client.create_pipeline(
                PipelineName=pipeline_name,
                RoleArn=FAKE_ROLE_ARN,
            )
        )
        sleep(wait_seconds)
    return responses
def test_create_pipeline(sagemaker_client):
    """CreatePipeline responds with the ARN of the newly created pipeline."""
    pipeline_name = "MyPipelineName"
    resp = sagemaker_client.create_pipeline(
        PipelineName=pipeline_name,
        RoleArn=FAKE_ROLE_ARN,
    )
    assert isinstance(resp, dict)
    expected_arn = arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
    resp["PipelineArn"].should.equal(expected_arn)
def test_list_pipelines_none(sagemaker_client):
    """With no pipelines created, ListPipelines returns an empty summary list."""
    resp = sagemaker_client.list_pipelines()
    assert isinstance(resp, dict)
    assert resp["PipelineSummaries"].should.be.empty
def test_list_pipelines_single(sagemaker_client):
    """A single created pipeline shows up in the listing with its ARN."""
    names = ["APipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    resp = sagemaker_client.list_pipelines()
    resp["PipelineSummaries"].should.have.length_of(1)
    resp["PipelineSummaries"][0]["PipelineArn"].should.equal(
        arn_formatter("pipeline", names[0], ACCOUNT_ID, TEST_REGION_NAME)
    )
def test_list_pipelines_multiple(sagemaker_client):
    """All created pipelines are returned by a single listing call."""
    names = ["APipelineName", "BPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    resp = sagemaker_client.list_pipelines(SortBy="Name", SortOrder="Ascending")
    resp["PipelineSummaries"].should.have.length_of(len(names))
def test_list_pipelines_sort_name_ascending(sagemaker_client):
    """SortBy=Name ascending returns summaries in alphabetical name order."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    resp = sagemaker_client.list_pipelines(SortBy="Name", SortOrder="Ascending")
    # names are already alphabetical, so positions must match one-to-one
    for position, name in enumerate(names):
        resp["PipelineSummaries"][position]["PipelineArn"].should.equal(
            arn_formatter("pipeline", name, ACCOUNT_ID, TEST_REGION_NAME)
        )
def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
    """SortBy=CreationTime descending returns newest pipelines first."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    # one-second gap between creations so creation_time (stored with
    # one-second resolution) actually differs per pipeline
    create_sagemaker_pipelines(sagemaker_client, names, 1)
    resp = sagemaker_client.list_pipelines(SortBy="CreationTime", SortOrder="Descending")
    for position, name in enumerate(reversed(names)):
        resp["PipelineSummaries"][position]["PipelineArn"].should.equal(
            arn_formatter("pipeline", name, ACCOUNT_ID, TEST_REGION_NAME)
        )
def test_list_pipelines_max_results(sagemaker_client):
    """MaxResults caps the number of summaries returned per page."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    resp = sagemaker_client.list_pipelines(MaxResults=2)
    resp["PipelineSummaries"].should.have.length_of(2)
def test_list_pipelines_next_token(sagemaker_client):
    """A numeric NextToken is accepted as the pagination starting index."""
    create_sagemaker_pipelines(sagemaker_client, ["APipelineName"])
    resp = sagemaker_client.list_pipelines(NextToken="0")
    resp["PipelineSummaries"].should.have.length_of(1)
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
    """PipelineNamePrefix filters the listing by name substring match."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    resp = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
    resp["PipelineSummaries"].should.have.length_of(1)
    resp["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
    # a fragment shared by all three names matches every pipeline
    resp = sagemaker_client.list_pipelines(PipelineNamePrefix="Pipeline")
    resp["PipelineSummaries"].should.have.length_of(3)
def test_list_pipelines_created_after(sagemaker_client):
    """A far-future CreatedAfter excludes everything, for every accepted time format."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    future_str = "2099-12-31 23:59:59"
    future_dt = datetime.strptime(future_str, "%Y-%m-%d %H:%M:%S")
    # string, datetime, and POSIX-timestamp forms must all behave the same
    for created_after in (future_str, future_dt, datetime.timestamp(future_dt)):
        resp = sagemaker_client.list_pipelines(CreatedAfter=created_after)
        assert resp["PipelineSummaries"].should.be.empty
def test_list_pipelines_created_before(sagemaker_client):
    """A long-past CreatedBefore excludes everything, for every accepted time format."""
    names = ["APipelineName", "BPipelineName", "CPipelineName"]
    create_sagemaker_pipelines(sagemaker_client, names)
    past_str = "2000-12-31 23:59:59"
    past_dt = datetime.strptime(past_str, "%Y-%m-%d %H:%M:%S")
    # string, datetime, and POSIX-timestamp forms must all behave the same
    for created_before in (past_str, past_dt, datetime.timestamp(past_dt)):
        resp = sagemaker_client.list_pipelines(CreatedBefore=created_before)
        assert resp["PipelineSummaries"].should.be.empty
@pytest.mark.parametrize(
    "list_pipelines_kwargs",
    [
        {"MaxResults": 200},
        {"NextToken": "some-invalid-next-token"},
        {"SortOrder": "some-invalid-sort-order"},
        {"SortBy": "some-invalid-sort-by"},
    ],
)
def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
    """Each out-of-range or unrecognized parameter is rejected with a ClientError."""
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.list_pipelines(**list_pipelines_kwargs)