Sagemaker: Store client_request_token in FakePipelineExecution object (#6000)

This commit is contained in:
Bogdan Girman 2023-03-01 12:19:05 +01:00 committed by GitHub
parent 98c8cd3b5d
commit 490e631245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 1 deletions

View File

@ -88,6 +88,7 @@ class FakePipelineExecution(BaseObject):
pipeline_execution_description,
parallelism_configuration,
pipeline_definition,
client_request_token,
):
self.pipeline_execution_arn = pipeline_execution_arn
self.pipeline_execution_display_name = pipeline_execution_display_name
@ -97,6 +98,7 @@ class FakePipelineExecution(BaseObject):
self.pipeline_execution_failure_reason = None
self.parallelism_configuration = parallelism_configuration
self.pipeline_definition_for_execution = pipeline_definition
self.client_request_token = client_request_token
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
@ -1909,6 +1911,7 @@ class SageMakerModelBackend(BaseBackend):
pipeline_parameters,
pipeline_execution_description,
parallelism_configuration,
client_request_token,
):
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
execution_id = "".join(
@ -1929,6 +1932,7 @@ class SageMakerModelBackend(BaseBackend):
pipeline_definition=pipeline.pipeline_definition,
parallelism_configuration=parallelism_configuration
or pipeline.parallelism_configuration,
client_request_token=client_request_token,
)
self.pipelines[pipeline_name].pipeline_executions[

View File

@ -480,6 +480,7 @@ class SageMakerResponse(BaseResponse):
self._get_param("PipelineParameters"),
self._get_param("PipelineExecutionDescription"),
self._get_param("ParallelismConfiguration"),
self._get_param("ClientRequestToken"),
)
return 200, {}, json.dumps(response)

View File

@ -16,7 +16,7 @@ from moto.sagemaker.utils import (
get_pipeline_execution_from_arn,
get_pipeline_name_from_execution_arn,
)
from moto.sagemaker.models import FakePipeline
from moto.sagemaker.models import FakePipeline, sagemaker_backends
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
@ -203,6 +203,35 @@ def test_start_pipeline_execution(sagemaker_client):
assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"]
def test_start_pipeline_execution_contains_client_request_token(sagemaker_client):
if settings.TEST_SERVER_MODE:
raise SkipTest(
"Skipping test in server mode due to lack of access to sagemaker_backends."
)
fake_pipeline_names = ["APipelineName"]
pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
pipeline_execution_arn = sagemaker_client.start_pipeline_execution(
PipelineName=fake_pipeline_names[0]
)["PipelineExecutionArn"]
# Verify that client_request_token is stored in FakePipelineExecution object
assert (
sagemaker_backends[ACCOUNT_ID][TEST_REGION_NAME]
.pipelines[fake_pipeline_names[0]]
.pipeline_executions[pipeline_execution_arn]
.client_request_token
!= ""
)
def test_describe_pipeline_execution_not_exists(sagemaker_client):
pipeline_execution_arn = arn_formatter(
# random ID (execution ID)