diff --git a/docs/docs/services/sagemaker.rst b/docs/docs/services/sagemaker.rst index a9e90a411..69b07f598 100644 --- a/docs/docs/services/sagemaker.rst +++ b/docs/docs/services/sagemaker.rst @@ -171,8 +171,8 @@ sagemaker - [ ] describe_notebook_instance - [X] describe_notebook_instance_lifecycle_config - [X] describe_pipeline -- [ ] describe_pipeline_definition_for_execution -- [ ] describe_pipeline_execution +- [X] describe_pipeline_definition_for_execution +- [X] describe_pipeline_execution - [X] describe_processing_job - [ ] describe_project - [ ] describe_space @@ -247,8 +247,8 @@ sagemaker - [ ] list_notebook_instance_lifecycle_configs - [ ] list_notebook_instances - [ ] list_pipeline_execution_steps -- [ ] list_pipeline_executions -- [ ] list_pipeline_parameters_for_execution +- [X] list_pipeline_executions +- [X] list_pipeline_parameters_for_execution - [X] list_pipelines - [X] list_processing_jobs - [ ] list_projects @@ -277,7 +277,7 @@ sagemaker - [ ] start_inference_experiment - [ ] start_monitoring_schedule - [X] start_notebook_instance -- [ ] start_pipeline_execution +- [X] start_pipeline_execution - [ ] stop_auto_ml_job - [ ] stop_compilation_job - [ ] stop_edge_deployment_stage diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 515091d46..2d5a2103b 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1,5 +1,7 @@ import json import os +import random +import string from datetime import datetime from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel @@ -11,6 +13,11 @@ from .exceptions import ( AWSValidationException, ResourceNotFound, ) +from .utils import ( + get_pipeline_from_name, + get_pipeline_execution_from_arn, + get_pipeline_name_from_execution_arn, +) from .utils import load_pipeline_definition_from_s3, arn_formatter @@ -72,6 +79,50 @@ class BaseObject(BaseModel): return self.gen_response_object() +class FakePipelineExecution(BaseObject): + def __init__( + self, + pipeline_execution_arn, + pipeline_execution_display_name, + pipeline_parameters, + pipeline_execution_description, + parallelism_configuration, + pipeline_definition, + ): + self.pipeline_execution_arn = pipeline_execution_arn + self.pipeline_execution_display_name = pipeline_execution_display_name + self.pipeline_parameters = pipeline_parameters + self.pipeline_execution_description = pipeline_execution_description + self.pipeline_execution_status = "Succeeded" + self.pipeline_execution_failure_reason = None + self.parallelism_configuration = parallelism_configuration + self.pipeline_definition_for_execution = pipeline_definition + + now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.creation_time = now_string + self.last_modified_time = now_string + self.start_time = now_string + + fake_user_profile_name = "fake-user-profile-name" + fake_domain_id = "fake-domain-id" + fake_user_profile_arn = arn_formatter( + "user-profile", + f"{fake_domain_id}/{fake_user_profile_name}", + pipeline_execution_arn.split(":")[4], + pipeline_execution_arn.split(":")[3], + ) + self.created_by = { + "UserProfileArn": fake_user_profile_arn, + "UserProfileName": fake_user_profile_name, + "DomainId": fake_domain_id, + } + self.last_modified_by = { + "UserProfileArn": fake_user_profile_arn, + "UserProfileName": fake_user_profile_name, + "DomainId": fake_domain_id, + } + + class FakePipeline(BaseObject): def __init__( self, @@ -92,6 +143,7 @@ class FakePipeline(BaseObject): self.pipeline_display_name = pipeline_display_name or pipeline_name self.pipeline_definition = pipeline_definition self.pipeline_description = pipeline_description + self.pipeline_executions = dict() self.role_arn = role_arn self.tags = tags or [] self.parallelism_configuration = parallelism_configuration @@ -1088,6 +1140,7 @@ class SageMakerModelBackend(BaseBackend): self.endpoints = {} self.experiments = {} self.pipelines = {} + self.pipeline_executions = {} self.processing_jobs = {} self.trials = {} self.trial_components = {} @@ -1815,27 +1868,16 @@ class SageMakerModelBackend(BaseBackend): self, pipeline_name, ): - try: - pipeline_arn = self.pipelines[pipeline_name].pipeline_arn - except KeyError: - raise ValidationError( - message=f"Could not find pipeline with name {pipeline_name}." - ) - del self.pipelines[pipeline_name] - return pipeline_arn + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) + del self.pipelines[pipeline.pipeline_name] + return pipeline.pipeline_arn def update_pipeline( self, pipeline_name, **kwargs, ): - try: - pipeline_arn = self.pipelines[pipeline_name].pipeline_arn - except KeyError: - raise ValidationError( - message=f"Could not find pipeline with name {pipeline_name}." - ) - + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) if all( [ kwargs.get("pipeline_definition"), @@ -1858,19 +1900,127 @@ class SageMakerModelBackend(BaseBackend): continue setattr(self.pipelines[pipeline_name], attr_key, attr_value) - return pipeline_arn + return pipeline.pipeline_arn + + def start_pipeline_execution( + self, + pipeline_name, + pipeline_execution_display_name, + pipeline_parameters, + pipeline_execution_description, + parallelism_configuration, + ): + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) + execution_id = "".join( + random.choices(string.ascii_lowercase + string.digits, k=12) + ) + pipeline_execution_arn = arn_formatter( + _type="pipeline", + _id=f"{pipeline.pipeline_name}/execution/{execution_id}", + account_id=self.account_id, + region_name=self.region_name, + ) + + fake_pipeline_execution = FakePipelineExecution( + pipeline_execution_arn=pipeline_execution_arn, + pipeline_execution_display_name=pipeline_execution_display_name, + pipeline_parameters=pipeline_parameters, + pipeline_execution_description=pipeline_execution_description, + pipeline_definition=pipeline.pipeline_definition, + parallelism_configuration=parallelism_configuration + or pipeline.parallelism_configuration, + ) + + self.pipelines[pipeline_name].pipeline_executions[ + pipeline_execution_arn + ] = fake_pipeline_execution + self.pipelines[ + pipeline_name + ].last_execution_time = fake_pipeline_execution.start_time + + response = {"PipelineExecutionArn": pipeline_execution_arn} + return response + + def list_pipeline_executions( + self, + pipeline_name, + ): + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) + response = { + "PipelineExecutionSummaries": [ + { + "PipelineExecutionArn": pipeline_execution_arn, + "StartTime": pipeline_execution.start_time, + "PipelineExecutionStatus": pipeline_execution.pipeline_execution_status, + "PipelineExecutionDescription": pipeline_execution.pipeline_execution_description, + "PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name, + "PipelineExecutionFailureReason": str( + pipeline_execution.pipeline_execution_failure_reason + ), + } + for pipeline_execution_arn, pipeline_execution in pipeline.pipeline_executions.items() + ] + } + return response + + def describe_pipeline_definition_for_execution( + self, + pipeline_execution_arn, + ): + pipeline_execution = get_pipeline_execution_from_arn( + self.pipelines, pipeline_execution_arn + ) + response = { + "PipelineDefinition": str( + pipeline_execution.pipeline_definition_for_execution + ), + "CreationTime": pipeline_execution.creation_time, + } + return response + + def list_pipeline_parameters_for_execution( + self, + pipeline_execution_arn, + ): + pipeline_execution = get_pipeline_execution_from_arn( + self.pipelines, pipeline_execution_arn + ) + response = { + "PipelineParameters": pipeline_execution.pipeline_parameters, + } + return response + + def describe_pipeline_execution( + self, + pipeline_execution_arn, + ): + pipeline_execution = get_pipeline_execution_from_arn( + self.pipelines, pipeline_execution_arn + ) + pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn) + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) + + pipeline_execution_summaries = { + "PipelineArn": pipeline.pipeline_arn, + "PipelineExecutionArn": pipeline_execution.pipeline_execution_arn, + "PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name, + "PipelineExecutionStatus": pipeline_execution.pipeline_execution_status, + "PipelineExecutionDescription": pipeline_execution.pipeline_execution_description, + "PipelineExperimentConfig": {}, + "FailureReason": "", + "CreationTime": pipeline_execution.creation_time, + "LastModifiedTime": pipeline_execution.last_modified_time, + "CreatedBy": pipeline_execution.created_by, + "LastModifiedBy": pipeline_execution.last_modified_by, + "ParallelismConfiguration": pipeline_execution.parallelism_configuration, + } + return pipeline_execution_summaries def describe_pipeline( self, pipeline_name, ): - try: - pipeline = self.pipelines[pipeline_name] - except KeyError: - raise ValidationError( - message=f"Could not find pipeline with name {pipeline_name}." - ) - + pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) response = { "PipelineArn": pipeline.pipeline_arn, "PipelineName": pipeline.pipeline_name, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 3c33ce99a..2a6cc4bfc 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -472,6 +472,45 @@ class SageMakerResponse(BaseResponse): ) return 200, {}, json.dumps(response) + @amzn_request_id + def start_pipeline_execution(self): + response = self.sagemaker_backend.start_pipeline_execution( + self._get_param("PipelineName"), + self._get_param("PipelineExecutionDisplayName"), + self._get_param("PipelineParameters"), + self._get_param("PipelineExecutionDescription"), + self._get_param("ParallelismConfiguration"), + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_pipeline_execution(self): + response = self.sagemaker_backend.describe_pipeline_execution( + self._get_param("PipelineExecutionArn") + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_pipeline_definition_for_execution(self): + response = self.sagemaker_backend.describe_pipeline_definition_for_execution( + self._get_param("PipelineExecutionArn") + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_pipeline_parameters_for_execution(self): + response = self.sagemaker_backend.list_pipeline_parameters_for_execution( + self._get_param("PipelineExecutionArn") + ) + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_pipeline_executions(self): + response = self.sagemaker_backend.list_pipeline_executions( + self._get_param("PipelineName") + ) + return 200, {}, json.dumps(response) + @amzn_request_id def create_pipeline(self): pipeline = self.sagemaker_backend.create_pipeline( diff --git a/moto/sagemaker/utils.py b/moto/sagemaker/utils.py index 1a83177ff..130ce41d2 100644 --- a/moto/sagemaker/utils.py +++ b/moto/sagemaker/utils.py @@ -1,5 +1,32 @@ from moto.s3.models import s3_backends import json +from .exceptions import ValidationError + + +def get_pipeline_from_name(pipelines, pipeline_name): + try: + pipeline = pipelines[pipeline_name] + return pipeline + except KeyError: + raise ValidationError( + message=f"Could not find pipeline with PipelineName {pipeline_name}." + ) + + +def get_pipeline_name_from_execution_arn(pipeline_execution_arn): + return pipeline_execution_arn.split("/")[1].split(":")[-1] + + +def get_pipeline_execution_from_arn(pipelines, pipeline_execution_arn): + try: + pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn) + pipeline = get_pipeline_from_name(pipelines, pipeline_name) + pipeline_execution = pipeline.pipeline_executions[pipeline_execution_arn] + return pipeline_execution + except KeyError: + raise ValidationError( + message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}." + ) def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id): diff --git a/tests/test_sagemaker/test_sagemaker_pipeline.py b/tests/test_sagemaker/test_sagemaker_pipeline.py index db76a42e5..610c85b8c 100644 --- a/tests/test_sagemaker/test_sagemaker_pipeline.py +++ b/tests/test_sagemaker/test_sagemaker_pipeline.py @@ -10,8 +10,16 @@ from unittest import SkipTest from moto.s3 import mock_s3 from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID +from moto.sagemaker.exceptions import ValidationError +from moto.sagemaker.utils import ( + get_pipeline_from_name, + get_pipeline_execution_from_arn, + get_pipeline_name_from_execution_arn, +) +from moto.sagemaker.models import FakePipeline from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3 + FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" TEST_REGION_NAME = "us-west-1" @@ -48,6 +56,190 @@ def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0): return responses +def test_utils_get_pipeline_from_name_exists(): + fake_pipeline_names = ["APipelineName", "BPipelineName"] + pipelines = { + fake_pipeline_name: FakePipeline( + pipeline_name="BFakePipeline", + pipeline_display_name="BFakePipeline", + pipeline_description=" ", + tags=[], + parallelism_configuration={}, + pipeline_definition=" ", + role_arn=FAKE_ROLE_ARN, + account_id=ACCOUNT_ID, + region_name=TEST_REGION_NAME, + ) + for fake_pipeline_name in fake_pipeline_names + } + retrieved_pipeline = get_pipeline_from_name( + pipelines=pipelines, pipeline_name=fake_pipeline_names[0] + ) + assert retrieved_pipeline == pipelines[fake_pipeline_names[0]] + + +def test_utils_get_pipeline_from_name_not_exists(): + with pytest.raises(ValidationError): + _ = get_pipeline_from_name(pipelines={}, pipeline_name="foo") + + +def test_utils_get_pipeline_name_from_execution_arn(): + expected_pipeline_name = "some-pipeline-name" + pipeline_execution_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/{expected_pipeline_name}/execution/abc123def456" + observed_pipeline_name = get_pipeline_name_from_execution_arn( + pipeline_execution_arn=pipeline_execution_arn + ) + assert expected_pipeline_name == observed_pipeline_name + + +def test_utils_get_pipeline_execution_from_arn_not_exists(): + with pytest.raises(ValidationError): + _ = get_pipeline_execution_from_arn( + pipelines={}, + pipeline_execution_arn="some/random/non/existent/arn", + ) + + +def test_utils_arn_formatter(): + expected_arn = ( + f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:pipeline/some-pipeline-name" + ) + observed_arn = arn_formatter( + _type="pipeline", + _id="some-pipeline-name", + region_name=TEST_REGION_NAME, + account_id=ACCOUNT_ID, + ) + assert expected_arn == observed_arn + + +def test_list_pipeline_executions(sagemaker_client): + fake_pipeline_names = ["APipelineName"] + pipelines = [ + { + "PipelineName": fake_pipeline_names[0], + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": " ", + }, + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + _ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[0]) + _ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[0]) + response = sagemaker_client.list_pipeline_executions( + PipelineName=fake_pipeline_names[0] + ) + assert len(response["PipelineExecutionSummaries"]) == 2 + assert ( + fake_pipeline_names[0] + in response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"] + ) + + +def test_describe_pipeline_definition_for_execution(sagemaker_client): + fake_pipeline_names = ["APipelineName"] + pipeline_definition = "some-pipeline-definition" + pipelines = [ + { + "PipelineName": fake_pipeline_names[0], + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": pipeline_definition, + }, + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + response = sagemaker_client.start_pipeline_execution( + PipelineName=fake_pipeline_names[0] + ) + pipeline_execution_arn = response["PipelineExecutionArn"] + response = sagemaker_client.describe_pipeline_definition_for_execution( + PipelineExecutionArn=pipeline_execution_arn + ) + assert set(response.keys()) == { + "PipelineDefinition", + "CreationTime", + "ResponseMetadata", + } + assert response["PipelineDefinition"] == pipeline_definition + + +def test_list_pipeline_parameters_for_execution(sagemaker_client): + fake_pipeline_names = ["APipelineName"] + pipelines = [ + { + "PipelineName": fake_pipeline_names[0], + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": " ", + }, + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + pipeline_execution_arn = sagemaker_client.start_pipeline_execution( + PipelineName=fake_pipeline_names[0], + PipelineParameters=[ + {"Name": "foo", "Value": "bar"}, + ], + )["PipelineExecutionArn"] + + response = sagemaker_client.list_pipeline_parameters_for_execution( + PipelineExecutionArn=pipeline_execution_arn + ) + assert isinstance(response["PipelineParameters"], list) + assert len(response["PipelineParameters"]) == 1 + assert response["PipelineParameters"][0]["Name"] == "foo" + assert response["PipelineParameters"][0]["Value"] == "bar" + + +def test_start_pipeline_execution(sagemaker_client): + fake_pipeline_names = ["APipelineName"] + pipelines = [ + { + "PipelineName": fake_pipeline_names[0], + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": " ", + }, + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + pipeline_execution_arn = sagemaker_client.start_pipeline_execution( + PipelineName=fake_pipeline_names[0] + ) + assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"] + + +def test_describe_pipeline_execution_not_exists(sagemaker_client): + pipeline_execution_arn = arn_formatter( + # random ID (execution ID) + "pipeline-execution", + "some-pipeline-name", + ACCOUNT_ID, + TEST_REGION_NAME, + ) + with pytest.raises(botocore.exceptions.ClientError): + _ = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + ) + + +def test_describe_pipeline_execution(sagemaker_client): + fake_pipeline_names = ["APipelineName", "BPipelineName"] + pipelines = [ + { + "PipelineName": fake_pipeline_name, + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": " ", + } + for fake_pipeline_name in fake_pipeline_names + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + response = sagemaker_client.start_pipeline_execution( + PipelineName=fake_pipeline_names[0] + ) + _ = sagemaker_client.start_pipeline_execution(PipelineName=fake_pipeline_names[1]) + expected_pipeline_execution_arn = response["PipelineExecutionArn"] + pipeline_execution_summary = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=response["PipelineExecutionArn"] + ) + observed_pipeline_execution_arn = pipeline_execution_summary["PipelineExecutionArn"] + observed_pipeline_execution_arn.should.be.equal(expected_pipeline_execution_arn) + + def test_load_pipeline_definition_from_s3(): if settings.TEST_SERVER_MODE: raise SkipTest(