SageMaker mock calls: start_pipeline_execution, list_pipeline_executions, describe_pipeline_execution, describe_pipeline_definition_for_execution, list_pipeline_parameters_for_execution #5 (#5836)

This commit is contained in:
stiebels 2023-01-12 20:18:30 +01:00 committed by GitHub
parent 79616e11e6
commit afeebd8993
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 436 additions and 28 deletions

View File

@ -171,8 +171,8 @@ sagemaker
- [ ] describe_notebook_instance
- [X] describe_notebook_instance_lifecycle_config
- [X] describe_pipeline
- [ ] describe_pipeline_definition_for_execution
- [ ] describe_pipeline_execution
- [X] describe_pipeline_definition_for_execution
- [X] describe_pipeline_execution
- [X] describe_processing_job
- [ ] describe_project
- [ ] describe_space
@ -247,8 +247,8 @@ sagemaker
- [ ] list_notebook_instance_lifecycle_configs
- [ ] list_notebook_instances
- [ ] list_pipeline_execution_steps
- [ ] list_pipeline_executions
- [ ] list_pipeline_parameters_for_execution
- [X] list_pipeline_executions
- [X] list_pipeline_parameters_for_execution
- [X] list_pipelines
- [X] list_processing_jobs
- [ ] list_projects
@ -277,7 +277,7 @@ sagemaker
- [ ] start_inference_experiment
- [ ] start_monitoring_schedule
- [X] start_notebook_instance
- [ ] start_pipeline_execution
- [X] start_pipeline_execution
- [ ] stop_auto_ml_job
- [ ] stop_compilation_job
- [ ] stop_edge_deployment_stage

View File

@ -1,5 +1,7 @@
import json
import os
import random
import string
from datetime import datetime
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
@ -11,6 +13,11 @@ from .exceptions import (
AWSValidationException,
ResourceNotFound,
)
from .utils import (
get_pipeline_from_name,
get_pipeline_execution_from_arn,
get_pipeline_name_from_execution_arn,
)
from .utils import load_pipeline_definition_from_s3, arn_formatter
@ -72,6 +79,50 @@ class BaseObject(BaseModel):
return self.gen_response_object()
class FakePipelineExecution(BaseObject):
    """In-memory record of a single started pipeline execution.

    The execution is created already finished: its status is hard-coded to
    ``"Succeeded"`` and it has no failure reason. Creation, last-modified and
    start times all share the moment the object was built, and the
    ``created_by`` / ``last_modified_by`` entries describe a placeholder user
    profile.
    """

    def __init__(
        self,
        pipeline_execution_arn,
        pipeline_execution_display_name,
        pipeline_parameters,
        pipeline_execution_description,
        parallelism_configuration,
        pipeline_definition,
    ):
        self.pipeline_execution_arn = pipeline_execution_arn
        self.pipeline_execution_display_name = pipeline_execution_display_name
        self.pipeline_parameters = pipeline_parameters
        self.pipeline_execution_description = pipeline_execution_description
        # The mock never simulates a running or failed execution.
        self.pipeline_execution_status = "Succeeded"
        self.pipeline_execution_failure_reason = None
        self.parallelism_configuration = parallelism_configuration
        self.pipeline_definition_for_execution = pipeline_definition

        # One timestamp, taken once, is shared by all three time fields.
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = self.last_modified_time = self.start_time = timestamp

        domain_id = "fake-domain-id"
        profile_name = "fake-user-profile-name"
        # Fields 4 and 3 of the colon-separated execution ARN are forwarded to
        # arn_formatter (account id and region in the standard ARN layout).
        arn_fields = pipeline_execution_arn.split(":")
        profile_arn = arn_formatter(
            "user-profile",
            f"{domain_id}/{profile_name}",
            arn_fields[4],
            arn_fields[3],
        )

        def _placeholder_user():
            # A fresh dict per call so the two attributes never alias.
            return {
                "UserProfileArn": profile_arn,
                "UserProfileName": profile_name,
                "DomainId": domain_id,
            }

        self.created_by = _placeholder_user()
        self.last_modified_by = _placeholder_user()
class FakePipeline(BaseObject):
def __init__(
self,
@ -92,6 +143,7 @@ class FakePipeline(BaseObject):
self.pipeline_display_name = pipeline_display_name or pipeline_name
self.pipeline_definition = pipeline_definition
self.pipeline_description = pipeline_description
self.pipeline_executions = dict()
self.role_arn = role_arn
self.tags = tags or []
self.parallelism_configuration = parallelism_configuration
@ -1088,6 +1140,7 @@ class SageMakerModelBackend(BaseBackend):
self.endpoints = {}
self.experiments = {}
self.pipelines = {}
self.pipeline_executions = {}
self.processing_jobs = {}
self.trials = {}
self.trial_components = {}
@ -1815,27 +1868,16 @@ class SageMakerModelBackend(BaseBackend):
self,
pipeline_name,
):
try:
pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with name {pipeline_name}."
)
del self.pipelines[pipeline_name]
return pipeline_arn
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
del self.pipelines[pipeline.pipeline_name]
return pipeline.pipeline_arn
def update_pipeline(
self,
pipeline_name,
**kwargs,
):
try:
pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with name {pipeline_name}."
)
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
if all(
[
kwargs.get("pipeline_definition"),
@ -1858,19 +1900,127 @@ class SageMakerModelBackend(BaseBackend):
continue
setattr(self.pipelines[pipeline_name], attr_key, attr_value)
return pipeline_arn
return pipeline.pipeline_arn
def start_pipeline_execution(
    self,
    pipeline_name,
    pipeline_execution_display_name,
    pipeline_parameters,
    pipeline_execution_description,
    parallelism_configuration,
):
    """Record a new execution for an existing pipeline and return its ARN.

    The pipeline is looked up by name through ``get_pipeline_from_name``. A
    random 12-character id (lowercase letters and digits) is combined with
    the pipeline name to build the execution ARN. The execution is stored on
    the pipeline under that ARN, and the pipeline's ``last_execution_time``
    is set to the execution's start time.

    When ``parallelism_configuration`` is falsy, the pipeline's own
    configuration is used instead.

    Returns:
        dict with the single key ``"PipelineExecutionArn"``.
    """
    pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)

    id_alphabet = string.ascii_lowercase + string.digits
    execution_id = "".join(random.choices(id_alphabet, k=12))
    execution_arn = arn_formatter(
        _type="pipeline",
        _id=f"{pipeline.pipeline_name}/execution/{execution_id}",
        account_id=self.account_id,
        region_name=self.region_name,
    )

    effective_parallelism = (
        parallelism_configuration or pipeline.parallelism_configuration
    )
    execution = FakePipelineExecution(
        pipeline_execution_arn=execution_arn,
        pipeline_execution_display_name=pipeline_execution_display_name,
        pipeline_parameters=pipeline_parameters,
        pipeline_execution_description=pipeline_execution_description,
        pipeline_definition=pipeline.pipeline_definition,
        parallelism_configuration=effective_parallelism,
    )

    # Register the execution on its pipeline and bump the last-run marker.
    owner = self.pipelines[pipeline_name]
    owner.pipeline_executions[execution_arn] = execution
    owner.last_execution_time = execution.start_time

    return {"PipelineExecutionArn": execution_arn}
def list_pipeline_executions(
    self,
    pipeline_name,
):
    """Summarise every execution stored on the named pipeline.

    The pipeline is resolved through ``get_pipeline_from_name``. One summary
    dict is produced per stored execution, in the iteration order of the
    pipeline's ``pipeline_executions`` mapping. The failure reason is passed
    through ``str()``, so an unset reason is reported as the text ``"None"``.

    Returns:
        dict with the single key ``"PipelineExecutionSummaries"`` holding the
        list of summaries.
    """
    pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)

    summaries = []
    for execution_arn, execution in pipeline.pipeline_executions.items():
        summaries.append(
            {
                "PipelineExecutionArn": execution_arn,
                "StartTime": execution.start_time,
                "PipelineExecutionStatus": execution.pipeline_execution_status,
                "PipelineExecutionDescription": execution.pipeline_execution_description,
                "PipelineExecutionDisplayName": execution.pipeline_execution_display_name,
                "PipelineExecutionFailureReason": str(
                    execution.pipeline_execution_failure_reason
                ),
            }
        )

    return {"PipelineExecutionSummaries": summaries}
def describe_pipeline_definition_for_execution(
    self,
    pipeline_execution_arn,
):
    """Return the pipeline definition captured when the execution started.

    The execution is resolved with ``get_pipeline_execution_from_arn``; the
    stored definition is converted with ``str()`` before being returned.

    Returns:
        dict with ``"PipelineDefinition"`` and ``"CreationTime"``.
    """
    execution = get_pipeline_execution_from_arn(
        self.pipelines, pipeline_execution_arn
    )
    definition_text = str(execution.pipeline_definition_for_execution)
    return {
        "PipelineDefinition": definition_text,
        "CreationTime": execution.creation_time,
    }
def list_pipeline_parameters_for_execution(
    self,
    pipeline_execution_arn,
):
    """Return the parameters that were supplied when the execution started.

    The execution is resolved with ``get_pipeline_execution_from_arn`` and
    its stored ``pipeline_parameters`` value is returned unchanged.

    Returns:
        dict with the single key ``"PipelineParameters"``.
    """
    execution = get_pipeline_execution_from_arn(
        self.pipelines, pipeline_execution_arn
    )
    return {"PipelineParameters": execution.pipeline_parameters}
def describe_pipeline_execution(
    self,
    pipeline_execution_arn,
):
    """Describe a single pipeline execution identified by its ARN.

    The execution is resolved first. The owning pipeline is then looked up
    from the name embedded in the ARN so that its ``pipeline_arn`` can be
    included. ``"PipelineExperimentConfig"`` and ``"FailureReason"`` are
    always returned as an empty dict and an empty string respectively.

    Returns:
        dict describing the execution.
    """
    execution = get_pipeline_execution_from_arn(
        self.pipelines, pipeline_execution_arn
    )
    owner_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
    owner = get_pipeline_from_name(self.pipelines, owner_name)

    description = {
        "PipelineArn": owner.pipeline_arn,
        "PipelineExecutionArn": execution.pipeline_execution_arn,
        "PipelineExecutionDisplayName": execution.pipeline_execution_display_name,
        "PipelineExecutionStatus": execution.pipeline_execution_status,
        "PipelineExecutionDescription": execution.pipeline_execution_description,
        "PipelineExperimentConfig": {},
        "FailureReason": "",
        "CreationTime": execution.creation_time,
        "LastModifiedTime": execution.last_modified_time,
        "CreatedBy": execution.created_by,
        "LastModifiedBy": execution.last_modified_by,
        "ParallelismConfiguration": execution.parallelism_configuration,
    }
    return description
def describe_pipeline(
self,
pipeline_name,
):
try:
pipeline = self.pipelines[pipeline_name]
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with name {pipeline_name}."
)
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
response = {
"PipelineArn": pipeline.pipeline_arn,
"PipelineName": pipeline.pipeline_name,

View File

@ -472,6 +472,45 @@ class SageMakerResponse(BaseResponse):
)
return 200, {}, json.dumps(response)
@amzn_request_id
def start_pipeline_execution(self):
    """Handle StartPipelineExecution: forward request parameters to the backend.

    Returns:
        A ``(status, headers, body)`` tuple with status 200, no extra headers
        and the backend response serialised as JSON.
    """
    # Order matches the backend method's positional parameters.
    param_names = (
        "PipelineName",
        "PipelineExecutionDisplayName",
        "PipelineParameters",
        "PipelineExecutionDescription",
        "ParallelismConfiguration",
    )
    backend_args = [self._get_param(param_name) for param_name in param_names]
    result = self.sagemaker_backend.start_pipeline_execution(*backend_args)
    return 200, {}, json.dumps(result)
@amzn_request_id
def describe_pipeline_execution(self):
    """Handle DescribePipelineExecution for the requested execution ARN.

    Returns:
        A ``(status, headers, body)`` tuple with status 200, no extra headers
        and the backend response serialised as JSON.
    """
    execution_arn = self._get_param("PipelineExecutionArn")
    result = self.sagemaker_backend.describe_pipeline_execution(execution_arn)
    return 200, {}, json.dumps(result)
@amzn_request_id
def describe_pipeline_definition_for_execution(self):
    """Handle DescribePipelineDefinitionForExecution for the requested ARN.

    Returns:
        A ``(status, headers, body)`` tuple with status 200, no extra headers
        and the backend response serialised as JSON.
    """
    execution_arn = self._get_param("PipelineExecutionArn")
    result = self.sagemaker_backend.describe_pipeline_definition_for_execution(
        execution_arn
    )
    return 200, {}, json.dumps(result)
@amzn_request_id
def list_pipeline_parameters_for_execution(self):
    """Handle ListPipelineParametersForExecution for the requested ARN.

    Returns:
        A ``(status, headers, body)`` tuple with status 200, no extra headers
        and the backend response serialised as JSON.
    """
    execution_arn = self._get_param("PipelineExecutionArn")
    result = self.sagemaker_backend.list_pipeline_parameters_for_execution(
        execution_arn
    )
    return 200, {}, json.dumps(result)
@amzn_request_id
def list_pipeline_executions(self):
    """Handle ListPipelineExecutions for the requested pipeline name.

    Returns:
        A ``(status, headers, body)`` tuple with status 200, no extra headers
        and the backend response serialised as JSON.
    """
    requested_name = self._get_param("PipelineName")
    result = self.sagemaker_backend.list_pipeline_executions(requested_name)
    return 200, {}, json.dumps(result)
@amzn_request_id
def create_pipeline(self):
pipeline = self.sagemaker_backend.create_pipeline(

View File

@ -1,5 +1,32 @@
from moto.s3.models import s3_backends
import json
from .exceptions import ValidationError
def get_pipeline_from_name(pipelines, pipeline_name):
    """Look up ``pipeline_name`` in the ``pipelines`` mapping.

    Returns:
        The value stored under ``pipeline_name``.

    Raises:
        ValidationError: if the mapping has no entry for ``pipeline_name``.
    """
    try:
        found = pipelines[pipeline_name]
    except KeyError:
        raise ValidationError(
            message=f"Could not find pipeline with PipelineName {pipeline_name}."
        )
    return found
def get_pipeline_name_from_execution_arn(pipeline_execution_arn):
    """Extract the pipeline name embedded in a pipeline execution ARN.

    The name is the second "/"-separated segment of the ARN, with anything up
    to and including the last ":" in that segment dropped.

    Raises:
        IndexError: if the ARN contains no "/" at all.
    """
    slash_segments = pipeline_execution_arn.split("/")
    name_segment = slash_segments[1]
    return name_segment.rsplit(":", 1)[-1]
def get_pipeline_execution_from_arn(pipelines, pipeline_execution_arn):
    """Resolve a pipeline execution from its ARN.

    The pipeline name is extracted from the ARN, the pipeline is looked up in
    ``pipelines``, and the execution stored under the full ARN in that
    pipeline's ``pipeline_executions`` mapping is returned.

    Args:
        pipelines: mapping of pipeline name to pipeline object.
        pipeline_execution_arn: ARN of the execution to find.

    Returns:
        The execution object stored under ``pipeline_execution_arn``.

    Raises:
        ValidationError: if the ARN has no "/"-separated pipeline name, if
            the pipeline is unknown (raised by ``get_pipeline_from_name``
            with its own message), or if the pipeline has no execution
            stored under this ARN.
    """
    try:
        pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
        pipeline = get_pipeline_from_name(pipelines, pipeline_name)
        return pipeline.pipeline_executions[pipeline_execution_arn]
    except (KeyError, IndexError):
        # IndexError comes from the name helper when the ARN contains no "/";
        # report it as a validation failure like a missing execution, rather
        # than letting it escape uncaught.
        raise ValidationError(
            message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}."
        )
def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id):

View File

@ -10,8 +10,16 @@ from unittest import SkipTest
from moto.s3 import mock_s3
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from moto.sagemaker.exceptions import ValidationError
from moto.sagemaker.utils import (
get_pipeline_from_name,
get_pipeline_execution_from_arn,
get_pipeline_name_from_execution_arn,
)
from moto.sagemaker.models import FakePipeline
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-west-1"
@ -48,6 +56,190 @@ def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
return responses
def test_utils_get_pipeline_from_name_exists():
    """get_pipeline_from_name returns the pipeline stored under the given name.

    Each fixture pipeline is built with the same name as its dictionary key,
    so the assertions can check both that the exact stored object comes back
    and that it is the pipeline that was asked for.
    """
    fake_pipeline_names = ["APipelineName", "BPipelineName"]
    pipelines = {
        fake_pipeline_name: FakePipeline(
            pipeline_name=fake_pipeline_name,
            pipeline_display_name=fake_pipeline_name,
            pipeline_description=" ",
            tags=[],
            parallelism_configuration={},
            pipeline_definition=" ",
            role_arn=FAKE_ROLE_ARN,
            account_id=ACCOUNT_ID,
            region_name=TEST_REGION_NAME,
        )
        for fake_pipeline_name in fake_pipeline_names
    }
    requested_name = fake_pipeline_names[0]
    retrieved_pipeline = get_pipeline_from_name(
        pipelines=pipelines, pipeline_name=requested_name
    )
    assert retrieved_pipeline is pipelines[requested_name]
    assert retrieved_pipeline.pipeline_name == requested_name
def test_utils_get_pipeline_from_name_not_exists():
    """Looking up any name in an empty registry raises ValidationError."""
    empty_registry = {}
    with pytest.raises(ValidationError):
        get_pipeline_from_name(pipelines=empty_registry, pipeline_name="foo")
def test_utils_get_pipeline_name_from_execution_arn():
    """The pipeline name embedded in an execution ARN is extracted intact."""
    wanted_name = "some-pipeline-name"
    execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{wanted_name}/execution/abc123def456"
    extracted_name = get_pipeline_name_from_execution_arn(
        pipeline_execution_arn=execution_arn
    )
    assert wanted_name == extracted_name
def test_utils_get_pipeline_execution_from_arn_not_exists():
    """Resolving an execution against an empty registry raises ValidationError."""
    bogus_arn = "some/random/non/existent/arn"
    with pytest.raises(ValidationError):
        get_pipeline_execution_from_arn(
            pipelines={},
            pipeline_execution_arn=bogus_arn,
        )
def test_utils_arn_formatter():
    """arn_formatter assembles a SageMaker ARN from type, id, region and account."""
    built_arn = arn_formatter(
        _type="pipeline",
        _id="some-pipeline-name",
        region_name=TEST_REGION_NAME,
        account_id=ACCOUNT_ID,
    )
    wanted_arn = (
        f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/some-pipeline-name"
    )
    assert wanted_arn == built_arn
def test_list_pipeline_executions(sagemaker_client):
    """Two started executions are both listed, each ARN naming its pipeline."""
    pipeline_name = "APipelineName"
    create_sagemaker_pipelines(
        sagemaker_client,
        [
            {
                "PipelineName": pipeline_name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            },
        ],
    )

    # Start the same pipeline twice so the listing has two entries.
    for _ in range(2):
        sagemaker_client.start_pipeline_execution(PipelineName=pipeline_name)

    listing = sagemaker_client.list_pipeline_executions(PipelineName=pipeline_name)
    summaries = listing["PipelineExecutionSummaries"]
    assert len(summaries) == 2
    assert pipeline_name in summaries[0]["PipelineExecutionArn"]
def test_describe_pipeline_definition_for_execution(sagemaker_client):
    """The definition reported for an execution matches the pipeline's definition."""
    pipeline_name = "APipelineName"
    definition = "some-pipeline-definition"
    create_sagemaker_pipelines(
        sagemaker_client,
        [
            {
                "PipelineName": pipeline_name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": definition,
            },
        ],
    )

    started = sagemaker_client.start_pipeline_execution(PipelineName=pipeline_name)
    execution_arn = started["PipelineExecutionArn"]

    described = sagemaker_client.describe_pipeline_definition_for_execution(
        PipelineExecutionArn=execution_arn
    )
    # The response must carry exactly these keys and nothing else.
    expected_keys = {"PipelineDefinition", "CreationTime", "ResponseMetadata"}
    assert set(described.keys()) == expected_keys
    assert described["PipelineDefinition"] == definition
def test_list_pipeline_parameters_for_execution(sagemaker_client):
    """Parameters passed at start time are returned for that execution."""
    pipeline_name = "APipelineName"
    create_sagemaker_pipelines(
        sagemaker_client,
        [
            {
                "PipelineName": pipeline_name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            },
        ],
    )

    started = sagemaker_client.start_pipeline_execution(
        PipelineName=pipeline_name,
        PipelineParameters=[
            {"Name": "foo", "Value": "bar"},
        ],
    )
    execution_arn = started["PipelineExecutionArn"]

    listing = sagemaker_client.list_pipeline_parameters_for_execution(
        PipelineExecutionArn=execution_arn
    )
    parameters = listing["PipelineParameters"]
    assert isinstance(parameters, list)
    assert len(parameters) == 1
    assert parameters[0]["Name"] == "foo"
    assert parameters[0]["Value"] == "bar"
def test_start_pipeline_execution(sagemaker_client):
    """Starting an execution returns an ARN that contains the pipeline name."""
    pipeline_name = "APipelineName"
    create_sagemaker_pipelines(
        sagemaker_client,
        [
            {
                "PipelineName": pipeline_name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            },
        ],
    )

    start_response = sagemaker_client.start_pipeline_execution(
        PipelineName=pipeline_name
    )
    assert pipeline_name in start_response["PipelineExecutionArn"]
def test_describe_pipeline_execution_not_exists(sagemaker_client):
    """Describing an execution that was never started raises a ClientError."""
    # An ARN that no pipeline or execution has been created for.
    unknown_execution_arn = arn_formatter(
        "pipeline-execution",
        "some-pipeline-name",
        ACCOUNT_ID,
        TEST_REGION_NAME,
    )
    with pytest.raises(botocore.exceptions.ClientError):
        sagemaker_client.describe_pipeline_execution(
            PipelineExecutionArn=unknown_execution_arn
        )
def test_describe_pipeline_execution(sagemaker_client):
    """Describing an execution returns that execution, not another pipeline's."""
    first_name, second_name = "APipelineName", "BPipelineName"
    pipeline_specs = []
    for name in (first_name, second_name):
        pipeline_specs.append(
            {
                "PipelineName": name,
                "RoleArn": FAKE_ROLE_ARN,
                "PipelineDefinition": " ",
            }
        )
    create_sagemaker_pipelines(sagemaker_client, pipeline_specs)

    first_start = sagemaker_client.start_pipeline_execution(PipelineName=first_name)
    # A second execution on a different pipeline, so there is more than one stored.
    sagemaker_client.start_pipeline_execution(PipelineName=second_name)

    wanted_arn = first_start["PipelineExecutionArn"]
    described = sagemaker_client.describe_pipeline_execution(
        PipelineExecutionArn=first_start["PipelineExecutionArn"]
    )
    reported_arn = described["PipelineExecutionArn"]
    reported_arn.should.be.equal(wanted_arn)
def test_load_pipeline_definition_from_s3():
if settings.TEST_SERVER_MODE:
raise SkipTest(