From 490e6312456941e159cb7c0e441d98b309e24da7 Mon Sep 17 00:00:00 2001 From: Bogdan Girman Date: Wed, 1 Mar 2023 12:19:05 +0100 Subject: [PATCH] Sagemaker: Store client_request_token in FakePipelineExecution object (#6000) --- moto/sagemaker/models.py | 4 +++ moto/sagemaker/responses.py | 1 + .../test_sagemaker/test_sagemaker_pipeline.py | 31 ++++++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 2d5a2103b..d1ea299a4 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -88,6 +88,7 @@ class FakePipelineExecution(BaseObject): pipeline_execution_description, parallelism_configuration, pipeline_definition, + client_request_token, ): self.pipeline_execution_arn = pipeline_execution_arn self.pipeline_execution_display_name = pipeline_execution_display_name @@ -97,6 +98,7 @@ class FakePipelineExecution(BaseObject): self.pipeline_execution_failure_reason = None self.parallelism_configuration = parallelism_configuration self.pipeline_definition_for_execution = pipeline_definition + self.client_request_token = client_request_token now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string @@ -1909,6 +1911,7 @@ class SageMakerModelBackend(BaseBackend): pipeline_parameters, pipeline_execution_description, parallelism_configuration, + client_request_token, ): pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) execution_id = "".join( @@ -1929,6 +1932,7 @@ class SageMakerModelBackend(BaseBackend): pipeline_definition=pipeline.pipeline_definition, parallelism_configuration=parallelism_configuration or pipeline.parallelism_configuration, + client_request_token=client_request_token, ) self.pipelines[pipeline_name].pipeline_executions[ diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 2a6cc4bfc..c679cb5dd 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -480,6 +480,7 @@ class SageMakerResponse(BaseResponse): self._get_param("PipelineParameters"), self._get_param("PipelineExecutionDescription"), self._get_param("ParallelismConfiguration"), + self._get_param("ClientRequestToken"), ) return 200, {}, json.dumps(response) diff --git a/tests/test_sagemaker/test_sagemaker_pipeline.py b/tests/test_sagemaker/test_sagemaker_pipeline.py index 610c85b8c..c17af6d6b 100644 --- a/tests/test_sagemaker/test_sagemaker_pipeline.py +++ b/tests/test_sagemaker/test_sagemaker_pipeline.py @@ -16,7 +16,7 @@ from moto.sagemaker.utils import ( get_pipeline_execution_from_arn, get_pipeline_name_from_execution_arn, ) -from moto.sagemaker.models import FakePipeline +from moto.sagemaker.models import FakePipeline, sagemaker_backends from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3 @@ -203,6 +203,35 @@ def test_start_pipeline_execution(sagemaker_client): assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"] +def test_start_pipeline_execution_contains_client_request_token(sagemaker_client): + if settings.TEST_SERVER_MODE: + raise SkipTest( + "Skipping test in server mode due to lack of access to sagemaker_backends." + ) + + fake_pipeline_names = ["APipelineName"] + pipelines = [ + { + "PipelineName": fake_pipeline_names[0], + "RoleArn": FAKE_ROLE_ARN, + "PipelineDefinition": " ", + }, + ] + _ = create_sagemaker_pipelines(sagemaker_client, pipelines) + pipeline_execution_arn = sagemaker_client.start_pipeline_execution( + PipelineName=fake_pipeline_names[0] + )["PipelineExecutionArn"] + + # Verify that client_request_token is stored in FakePipelineExecution object + assert ( + sagemaker_backends[ACCOUNT_ID][TEST_REGION_NAME] + .pipelines[fake_pipeline_names[0]] + .pipeline_executions[pipeline_execution_arn] + .client_request_token + != "" + ) + + def test_describe_pipeline_execution_not_exists(sagemaker_client): pipeline_execution_arn = arn_formatter( # random ID (execution ID)