SageMaker mock calls: start_pipeline_execution, list_pipeline_executions, describe_pipeline_execution, describe_pipeline_definition_for_execution, list_pipeline_parameters_for_execution #5 (#5836)

stiebels 2023-01-12 20:18:30 +01:00 committed by GitHub
parent 79616e11e6
commit afeebd8993
5 changed files with 436 additions and 28 deletions

View File

@@ -171,8 +171,8 @@ sagemaker
- [ ] describe_notebook_instance
- [X] describe_notebook_instance_lifecycle_config
- [X] describe_pipeline
- [X] describe_pipeline_definition_for_execution
- [X] describe_pipeline_execution
- [X] describe_processing_job
- [ ] describe_project
- [ ] describe_space
@@ -247,8 +247,8 @@ sagemaker
- [ ] list_notebook_instance_lifecycle_configs
- [ ] list_notebook_instances
- [ ] list_pipeline_execution_steps
- [X] list_pipeline_executions
- [X] list_pipeline_parameters_for_execution
- [X] list_pipelines
- [X] list_processing_jobs
- [ ] list_projects
@@ -277,7 +277,7 @@ sagemaker
- [ ] start_inference_experiment
- [ ] start_monitoring_schedule
- [X] start_notebook_instance
- [X] start_pipeline_execution
- [ ] stop_auto_ml_job
- [ ] stop_compilation_job
- [ ] stop_edge_deployment_stage
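
With these boxes ticked, the five pipeline-execution calls are now served entirely in memory by moto. A minimal sketch of exercising them end to end, assuming moto's `mock_sagemaker` decorator and a boto3 client; the pipeline name, role ARN and definition below are placeholder values:

```python
import boto3
from moto import mock_sagemaker


@mock_sagemaker
def exercise_pipeline_execution_mock():
    # Placeholder names/ARNs; nothing here needs to exist in a real AWS account.
    client = boto3.client("sagemaker", region_name="us-west-1")
    client.create_pipeline(
        PipelineName="my-pipeline",
        RoleArn="arn:aws:iam::123456789012:role/FakeRole",
        PipelineDefinition=" ",
    )
    arn = client.start_pipeline_execution(PipelineName="my-pipeline")[
        "PipelineExecutionArn"
    ]
    client.describe_pipeline_execution(PipelineExecutionArn=arn)
    client.describe_pipeline_definition_for_execution(PipelineExecutionArn=arn)
    client.list_pipeline_parameters_for_execution(PipelineExecutionArn=arn)
    return client.list_pipeline_executions(PipelineName="my-pipeline")
```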

View File

@@ -1,5 +1,7 @@
import json
import os
import random
import string
from datetime import datetime

from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
@@ -11,6 +13,11 @@ from .exceptions import (
    AWSValidationException,
    ResourceNotFound,
)
from .utils import (
    get_pipeline_from_name,
    get_pipeline_execution_from_arn,
    get_pipeline_name_from_execution_arn,
)
from .utils import load_pipeline_definition_from_s3, arn_formatter
@@ -72,6 +79,50 @@ class BaseObject(BaseModel):
        return self.gen_response_object()
class FakePipelineExecution(BaseObject):
def __init__(
self,
pipeline_execution_arn,
pipeline_execution_display_name,
pipeline_parameters,
pipeline_execution_description,
parallelism_configuration,
pipeline_definition,
):
self.pipeline_execution_arn = pipeline_execution_arn
self.pipeline_execution_display_name = pipeline_execution_display_name
self.pipeline_parameters = pipeline_parameters
self.pipeline_execution_description = pipeline_execution_description
self.pipeline_execution_status = "Succeeded"
self.pipeline_execution_failure_reason = None
self.parallelism_configuration = parallelism_configuration
self.pipeline_definition_for_execution = pipeline_definition
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
self.last_modified_time = now_string
self.start_time = now_string
fake_user_profile_name = "fake-user-profile-name"
fake_domain_id = "fake-domain-id"
fake_user_profile_arn = arn_formatter(
"user-profile",
f"{fake_domain_id}/{fake_user_profile_name}",
pipeline_execution_arn.split(":")[4],
pipeline_execution_arn.split(":")[3],
)
self.created_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
"DomainId": fake_domain_id,
}
self.last_modified_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
"DomainId": fake_domain_id,
}
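
Note how the constructor above reuses the region and account fields of the execution ARN when it fabricates the `created_by`/`last_modified_by` user-profile ARN. A quick illustration with a made-up ARN:

```python
# Made-up execution ARN in the format the backend generates.
arn = "arn:aws:sagemaker:us-west-1:123456789012:pipeline/my-pipeline/execution/abc123def456"
parts = arn.split(":")
# parts[3] is the region, parts[4] the account id.
assert (parts[3], parts[4]) == ("us-west-1", "123456789012")
# arn_formatter("user-profile", "fake-domain-id/fake-user-profile-name", parts[4], parts[3])
# then yields:
# arn:aws:sagemaker:us-west-1:123456789012:user-profile/fake-domain-id/fake-user-profile-name
```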
class FakePipeline(BaseObject):
    def __init__(
        self,
@@ -92,6 +143,7 @@ class FakePipeline(BaseObject):
        self.pipeline_display_name = pipeline_display_name or pipeline_name
        self.pipeline_definition = pipeline_definition
        self.pipeline_description = pipeline_description
        self.pipeline_executions = dict()
        self.role_arn = role_arn
        self.tags = tags or []
        self.parallelism_configuration = parallelism_configuration
@@ -1088,6 +1140,7 @@ class SageMakerModelBackend(BaseBackend):
        self.endpoints = {}
        self.experiments = {}
        self.pipelines = {}
        self.pipeline_executions = {}
        self.processing_jobs = {}
        self.trials = {}
        self.trial_components = {}
@@ -1815,27 +1868,16 @@ class SageMakerModelBackend(BaseBackend):
        self,
        pipeline_name,
    ):
-        try:
-            pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
-        except KeyError:
-            raise ValidationError(
-                message=f"Could not find pipeline with name {pipeline_name}."
-            )
-        del self.pipelines[pipeline_name]
-        return pipeline_arn
+        pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
+        del self.pipelines[pipeline.pipeline_name]
+        return pipeline.pipeline_arn

    def update_pipeline(
        self,
        pipeline_name,
        **kwargs,
    ):
-        try:
-            pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
-        except KeyError:
-            raise ValidationError(
-                message=f"Could not find pipeline with name {pipeline_name}."
-            )
+        pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
        if all(
            [
                kwargs.get("pipeline_definition"),
@@ -1858,19 +1900,127 @@ class SageMakerModelBackend(BaseBackend):
                continue
            setattr(self.pipelines[pipeline_name], attr_key, attr_value)
-        return pipeline_arn
+        return pipeline.pipeline_arn
def start_pipeline_execution(
self,
pipeline_name,
pipeline_execution_display_name,
pipeline_parameters,
pipeline_execution_description,
parallelism_configuration,
):
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
execution_id = "".join(
random.choices(string.ascii_lowercase + string.digits, k=12)
)
pipeline_execution_arn = arn_formatter(
_type="pipeline",
_id=f"{pipeline.pipeline_name}/execution/{execution_id}",
account_id=self.account_id,
region_name=self.region_name,
)
fake_pipeline_execution = FakePipelineExecution(
pipeline_execution_arn=pipeline_execution_arn,
pipeline_execution_display_name=pipeline_execution_display_name,
pipeline_parameters=pipeline_parameters,
pipeline_execution_description=pipeline_execution_description,
pipeline_definition=pipeline.pipeline_definition,
parallelism_configuration=parallelism_configuration
or pipeline.parallelism_configuration,
)
self.pipelines[pipeline_name].pipeline_executions[
pipeline_execution_arn
] = fake_pipeline_execution
self.pipelines[
pipeline_name
].last_execution_time = fake_pipeline_execution.start_time
response = {"PipelineExecutionArn": pipeline_execution_arn}
return response
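
The execution ARN returned above is built from a random 12-character suffix, roughly as follows (region, account and pipeline name below are placeholders):

```python
import random
import string

# Mirrors the id generation in start_pipeline_execution; the suffix differs per call.
execution_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))
pipeline_execution_arn = (
    "arn:aws:sagemaker:us-west-1:123456789012:"
    f"pipeline/my-pipeline/execution/{execution_id}"
)
```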
def list_pipeline_executions(
self,
pipeline_name,
):
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
response = {
"PipelineExecutionSummaries": [
{
"PipelineExecutionArn": pipeline_execution_arn,
"StartTime": pipeline_execution.start_time,
"PipelineExecutionStatus": pipeline_execution.pipeline_execution_status,
"PipelineExecutionDescription": pipeline_execution.pipeline_execution_description,
"PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name,
"PipelineExecutionFailureReason": str(
pipeline_execution.pipeline_execution_failure_reason
),
}
for pipeline_execution_arn, pipeline_execution in pipeline.pipeline_executions.items()
]
}
return response
def describe_pipeline_definition_for_execution(
self,
pipeline_execution_arn,
):
pipeline_execution = get_pipeline_execution_from_arn(
self.pipelines, pipeline_execution_arn
)
response = {
"PipelineDefinition": str(
pipeline_execution.pipeline_definition_for_execution
),
"CreationTime": pipeline_execution.creation_time,
}
return response
def list_pipeline_parameters_for_execution(
self,
pipeline_execution_arn,
):
pipeline_execution = get_pipeline_execution_from_arn(
self.pipelines, pipeline_execution_arn
)
response = {
"PipelineParameters": pipeline_execution.pipeline_parameters,
}
return response
def describe_pipeline_execution(
self,
pipeline_execution_arn,
):
pipeline_execution = get_pipeline_execution_from_arn(
self.pipelines, pipeline_execution_arn
)
pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
pipeline_execution_summaries = {
"PipelineArn": pipeline.pipeline_arn,
"PipelineExecutionArn": pipeline_execution.pipeline_execution_arn,
"PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name,
"PipelineExecutionStatus": pipeline_execution.pipeline_execution_status,
"PipelineExecutionDescription": pipeline_execution.pipeline_execution_description,
"PipelineExperimentConfig": {},
"FailureReason": "",
"CreationTime": pipeline_execution.creation_time,
"LastModifiedTime": pipeline_execution.last_modified_time,
"CreatedBy": pipeline_execution.created_by,
"LastModifiedBy": pipeline_execution.last_modified_by,
"ParallelismConfiguration": pipeline_execution.parallelism_configuration,
}
return pipeline_execution_summaries
    def describe_pipeline(
        self,
        pipeline_name,
    ):
-        try:
-            pipeline = self.pipelines[pipeline_name]
-        except KeyError:
-            raise ValidationError(
-                message=f"Could not find pipeline with name {pipeline_name}."
-            )
+        pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
        response = {
            "PipelineArn": pipeline.pipeline_arn,
            "PipelineName": pipeline.pipeline_name,

View File

@@ -472,6 +472,45 @@ class SageMakerResponse(BaseResponse):
        )
        return 200, {}, json.dumps(response)
@amzn_request_id
def start_pipeline_execution(self):
response = self.sagemaker_backend.start_pipeline_execution(
self._get_param("PipelineName"),
self._get_param("PipelineExecutionDisplayName"),
self._get_param("PipelineParameters"),
self._get_param("PipelineExecutionDescription"),
self._get_param("ParallelismConfiguration"),
)
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_pipeline_execution(self):
response = self.sagemaker_backend.describe_pipeline_execution(
self._get_param("PipelineExecutionArn")
)
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_pipeline_definition_for_execution(self):
response = self.sagemaker_backend.describe_pipeline_definition_for_execution(
self._get_param("PipelineExecutionArn")
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_pipeline_parameters_for_execution(self):
response = self.sagemaker_backend.list_pipeline_parameters_for_execution(
self._get_param("PipelineExecutionArn")
)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_pipeline_executions(self):
response = self.sagemaker_backend.list_pipeline_executions(
self._get_param("PipelineName")
)
return 200, {}, json.dumps(response)
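
Each handler above pulls its arguments from the body of the incoming request via `_get_param`; for reference, a `StartPipelineExecution` call posts roughly this payload (all values are placeholders):

```python
# Approximate request body for StartPipelineExecution; keys match the _get_param lookups above.
start_pipeline_execution_payload = {
    "PipelineName": "my-pipeline",
    "PipelineExecutionDisplayName": "my-execution",
    "PipelineParameters": [{"Name": "foo", "Value": "bar"}],
    "PipelineExecutionDescription": "example run",
    "ParallelismConfiguration": {"MaxParallelExecutionSteps": 2},
}
```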
    @amzn_request_id
    def create_pipeline(self):
        pipeline = self.sagemaker_backend.create_pipeline(

View File

@@ -1,5 +1,32 @@
from moto.s3.models import s3_backends
import json
from .exceptions import ValidationError
def get_pipeline_from_name(pipelines, pipeline_name):
try:
pipeline = pipelines[pipeline_name]
return pipeline
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with PipelineName {pipeline_name}."
)
def get_pipeline_name_from_execution_arn(pipeline_execution_arn):
return pipeline_execution_arn.split("/")[1].split(":")[-1]
def get_pipeline_execution_from_arn(pipelines, pipeline_execution_arn):
try:
pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
pipeline = get_pipeline_from_name(pipelines, pipeline_name)
pipeline_execution = pipeline.pipeline_executions[pipeline_execution_arn]
return pipeline_execution
except KeyError:
raise ValidationError(
message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}."
)
def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id):
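
A quick check of what the new ARN-parsing helper does, using a made-up execution ARN:

```python
# Format matches the ARNs produced by start_pipeline_execution.
arn = "arn:aws:sagemaker:us-west-1:123456789012:pipeline/my-pipeline/execution/abc123def456"
# get_pipeline_name_from_execution_arn: the segment between "pipeline/" and "/execution".
assert arn.split("/")[1].split(":")[-1] == "my-pipeline"
```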

View File

@@ -10,8 +10,16 @@ from unittest import SkipTest
from moto.s3 import mock_s3
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from moto.sagemaker.exceptions import ValidationError
from moto.sagemaker.utils import (
get_pipeline_from_name,
get_pipeline_execution_from_arn,
get_pipeline_name_from_execution_arn,
)
from moto.sagemaker.models import FakePipeline
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3

FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-west-1"
@@ -48,6 +56,190 @@ def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
    return responses
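
The `create_sagemaker_pipelines` fixture helper itself is outside this diff; judging from its signature and the call sites below, it is presumably something like this sketch (body assumed, not taken from the repository):

```python
import time


def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
    # Sketch: create each pipeline definition dict via the client and collect the responses.
    responses = []
    for pipeline in pipelines:
        responses.append(sagemaker_client.create_pipeline(**pipeline))
        time.sleep(wait_seconds)
    return responses
```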
def test_utils_get_pipeline_from_name_exists():
fake_pipeline_names = ["APipelineName", "BPipelineName"]
pipelines = {
fake_pipeline_name: FakePipeline(
pipeline_name="BFakePipeline",
pipeline_display_name="BFakePipeline",
pipeline_description=" ",
tags=[],
parallelism_configuration={},
pipeline_definition=" ",
role_arn=FAKE_ROLE_ARN,
account_id=ACCOUNT_ID,
region_name=TEST_REGION_NAME,
)
for fake_pipeline_name in fake_pipeline_names
}
retrieved_pipeline = get_pipeline_from_name(
pipelines=pipelines, pipeline_name=fake_pipeline_names[0]
)
assert retrieved_pipeline == pipelines[fake_pipeline_names[0]]
def test_utils_get_pipeline_from_name_not_exists():
with pytest.raises(ValidationError):
_ = get_pipeline_from_name(pipelines={}, pipeline_name="foo")
def test_utils_get_pipeline_name_from_execution_arn():
expected_pipeline_name = "some-pipeline-name"
pipeline_execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{expected_pipeline_name}/execution/abc123def456"
observed_pipeline_name = get_pipeline_name_from_execution_arn(
pipeline_execution_arn=pipeline_execution_arn
)
assert expected_pipeline_name == observed_pipeline_name
def test_utils_get_pipeline_execution_from_arn_not_exists():
with pytest.raises(ValidationError):
_ = get_pipeline_execution_from_arn(
pipelines={},
pipeline_execution_arn="some/random/non/existent/arn",
)
def test_utils_arn_formatter():
expected_arn = (
f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/some-pipeline-name"
)
observed_arn = arn_formatter(
_type="pipeline",
_id="some-pipeline-name",
region_name=TEST_REGION_NAME,
account_id=ACCOUNT_ID,
)
assert expected_arn == observed_arn
def test_list_pipeline_executions(sagemaker_client):
fake_pipeline_names = ["APipelineName"]
pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
_ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[0])
_ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[0])
response = sagemaker_client.list_pipeline_executions(
PipelineName=fake_pipeline_names[0]
)
assert len(response["PipelineExecutionSummaries"]) == 2
assert (
fake_pipeline_names[0]
in response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
)
def test_describe_pipeline_definition_for_execution(sagemaker_client):
fake_pipeline_names = ["APipelineName"]
pipeline_definition = "some-pipeline-definition"
pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": pipeline_definition,
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.start_pipeline_execution(
PipelineName=fake_pipeline_names[0]
)
pipeline_execution_arn = response["PipelineExecutionArn"]
response = sagemaker_client.describe_pipeline_definition_for_execution(
PipelineExecutionArn=pipeline_execution_arn
)
assert set(response.keys()) == {
"PipelineDefinition",
"CreationTime",
"ResponseMetadata",
}
assert response["PipelineDefinition"] == pipeline_definition
def test_list_pipeline_parameters_for_execution(sagemaker_client):
fake_pipeline_names = ["APipelineName"]
pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
pipeline_execution_arn = sagemaker_client.start_pipeline_execution(
PipelineName=fake_pipeline_names[0],
PipelineParameters=[
{"Name": "foo", "Value": "bar"},
],
)["PipelineExecutionArn"]
response = sagemaker_client.list_pipeline_parameters_for_execution(
PipelineExecutionArn=pipeline_execution_arn
)
assert isinstance(response["PipelineParameters"], list)
assert len(response["PipelineParameters"]) == 1
assert response["PipelineParameters"][0]["Name"] == "foo"
assert response["PipelineParameters"][0]["Value"] == "bar"
def test_start_pipeline_execution(sagemaker_client):
fake_pipeline_names = ["APipelineName"]
pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
pipeline_execution_arn = sagemaker_client.start_pipeline_execution(
PipelineName=fake_pipeline_names[0]
)
assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"]
def test_describe_pipeline_execution_not_exists(sagemaker_client):
pipeline_execution_arn = arn_formatter(
# random ID (execution ID)
"pipeline-execution",
"some-pipeline-name",
ACCOUNT_ID,
TEST_REGION_NAME,
)
with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.describe_pipeline_execution(
PipelineExecutionArn=pipeline_execution_arn
)
def test_describe_pipeline_execution(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName"]
pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.start_pipeline_execution(
PipelineName=fake_pipeline_names[0]
)
_ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[1])
expected_pipeline_execution_arn = response["PipelineExecutionArn"]
pipeline_execution_summary = sagemaker_client.describe_pipeline_execution(
PipelineExecutionArn=response["PipelineExecutionArn"]
)
observed_pipeline_execution_arn = pipeline_execution_summary["PipelineExecutionArn"]
observed_pipeline_execution_arn.should.be.equal(expected_pipeline_execution_arn)
def test_load_pipeline_definition_from_s3():
    if settings.TEST_SERVER_MODE:
        raise SkipTest(