Sagemaker: Store client_request_token in FakePipelineExecution object (#6000)
This commit is contained in:
parent
98c8cd3b5d
commit
490e631245
@ -88,6 +88,7 @@ class FakePipelineExecution(BaseObject):
|
||||
pipeline_execution_description,
|
||||
parallelism_configuration,
|
||||
pipeline_definition,
|
||||
client_request_token,
|
||||
):
|
||||
self.pipeline_execution_arn = pipeline_execution_arn
|
||||
self.pipeline_execution_display_name = pipeline_execution_display_name
|
||||
@ -97,6 +98,7 @@ class FakePipelineExecution(BaseObject):
|
||||
self.pipeline_execution_failure_reason = None
|
||||
self.parallelism_configuration = parallelism_configuration
|
||||
self.pipeline_definition_for_execution = pipeline_definition
|
||||
self.client_request_token = client_request_token
|
||||
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = now_string
|
||||
@ -1909,6 +1911,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
pipeline_parameters,
|
||||
pipeline_execution_description,
|
||||
parallelism_configuration,
|
||||
client_request_token,
|
||||
):
|
||||
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
|
||||
execution_id = "".join(
|
||||
@ -1929,6 +1932,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
pipeline_definition=pipeline.pipeline_definition,
|
||||
parallelism_configuration=parallelism_configuration
|
||||
or pipeline.parallelism_configuration,
|
||||
client_request_token=client_request_token,
|
||||
)
|
||||
|
||||
self.pipelines[pipeline_name].pipeline_executions[
|
||||
|
@ -480,6 +480,7 @@ class SageMakerResponse(BaseResponse):
|
||||
self._get_param("PipelineParameters"),
|
||||
self._get_param("PipelineExecutionDescription"),
|
||||
self._get_param("ParallelismConfiguration"),
|
||||
self._get_param("ClientRequestToken"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
|
@ -16,7 +16,7 @@ from moto.sagemaker.utils import (
|
||||
get_pipeline_execution_from_arn,
|
||||
get_pipeline_name_from_execution_arn,
|
||||
)
|
||||
from moto.sagemaker.models import FakePipeline
|
||||
from moto.sagemaker.models import FakePipeline, sagemaker_backends
|
||||
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
|
||||
|
||||
|
||||
@ -203,6 +203,35 @@ def test_start_pipeline_execution(sagemaker_client):
|
||||
assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"]
|
||||
|
||||
|
||||
def test_start_pipeline_execution_contains_client_request_token(sagemaker_client):
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest(
|
||||
"Skipping test in server mode due to lack of access to sagemaker_backends."
|
||||
)
|
||||
|
||||
fake_pipeline_names = ["APipelineName"]
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_names[0],
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
},
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
pipeline_execution_arn = sagemaker_client.start_pipeline_execution(
|
||||
PipelineName=fake_pipeline_names[0]
|
||||
)["PipelineExecutionArn"]
|
||||
|
||||
# Verify that client_request_token is stored in FakePipelineExecution object
|
||||
assert (
|
||||
sagemaker_backends[ACCOUNT_ID][TEST_REGION_NAME]
|
||||
.pipelines[fake_pipeline_names[0]]
|
||||
.pipeline_executions[pipeline_execution_arn]
|
||||
.client_request_token
|
||||
!= ""
|
||||
)
|
||||
|
||||
|
||||
def test_describe_pipeline_execution_not_exists(sagemaker_client):
|
||||
pipeline_execution_arn = arn_formatter(
|
||||
# random ID (execution ID)
|
||||
|
Loading…
Reference in New Issue
Block a user