Sagemaker: Store client_request_token in FakePipelineExecution object (#6000)
This commit is contained in:
parent
98c8cd3b5d
commit
490e631245
@ -88,6 +88,7 @@ class FakePipelineExecution(BaseObject):
|
|||||||
pipeline_execution_description,
|
pipeline_execution_description,
|
||||||
parallelism_configuration,
|
parallelism_configuration,
|
||||||
pipeline_definition,
|
pipeline_definition,
|
||||||
|
client_request_token,
|
||||||
):
|
):
|
||||||
self.pipeline_execution_arn = pipeline_execution_arn
|
self.pipeline_execution_arn = pipeline_execution_arn
|
||||||
self.pipeline_execution_display_name = pipeline_execution_display_name
|
self.pipeline_execution_display_name = pipeline_execution_display_name
|
||||||
@ -97,6 +98,7 @@ class FakePipelineExecution(BaseObject):
|
|||||||
self.pipeline_execution_failure_reason = None
|
self.pipeline_execution_failure_reason = None
|
||||||
self.parallelism_configuration = parallelism_configuration
|
self.parallelism_configuration = parallelism_configuration
|
||||||
self.pipeline_definition_for_execution = pipeline_definition
|
self.pipeline_definition_for_execution = pipeline_definition
|
||||||
|
self.client_request_token = client_request_token
|
||||||
|
|
||||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
self.creation_time = now_string
|
self.creation_time = now_string
|
||||||
@ -1909,6 +1911,7 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
pipeline_parameters,
|
pipeline_parameters,
|
||||||
pipeline_execution_description,
|
pipeline_execution_description,
|
||||||
parallelism_configuration,
|
parallelism_configuration,
|
||||||
|
client_request_token,
|
||||||
):
|
):
|
||||||
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
|
pipeline = get_pipeline_from_name(self.pipelines, pipeline_name)
|
||||||
execution_id = "".join(
|
execution_id = "".join(
|
||||||
@ -1929,6 +1932,7 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
pipeline_definition=pipeline.pipeline_definition,
|
pipeline_definition=pipeline.pipeline_definition,
|
||||||
parallelism_configuration=parallelism_configuration
|
parallelism_configuration=parallelism_configuration
|
||||||
or pipeline.parallelism_configuration,
|
or pipeline.parallelism_configuration,
|
||||||
|
client_request_token=client_request_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pipelines[pipeline_name].pipeline_executions[
|
self.pipelines[pipeline_name].pipeline_executions[
|
||||||
|
@ -480,6 +480,7 @@ class SageMakerResponse(BaseResponse):
|
|||||||
self._get_param("PipelineParameters"),
|
self._get_param("PipelineParameters"),
|
||||||
self._get_param("PipelineExecutionDescription"),
|
self._get_param("PipelineExecutionDescription"),
|
||||||
self._get_param("ParallelismConfiguration"),
|
self._get_param("ParallelismConfiguration"),
|
||||||
|
self._get_param("ClientRequestToken"),
|
||||||
)
|
)
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from moto.sagemaker.utils import (
|
|||||||
get_pipeline_execution_from_arn,
|
get_pipeline_execution_from_arn,
|
||||||
get_pipeline_name_from_execution_arn,
|
get_pipeline_name_from_execution_arn,
|
||||||
)
|
)
|
||||||
from moto.sagemaker.models import FakePipeline
|
from moto.sagemaker.models import FakePipeline, sagemaker_backends
|
||||||
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
|
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
|
||||||
|
|
||||||
|
|
||||||
@ -203,6 +203,35 @@ def test_start_pipeline_execution(sagemaker_client):
|
|||||||
assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"]
|
assert fake_pipeline_names[0] in pipeline_execution_arn["PipelineExecutionArn"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_pipeline_execution_contains_client_request_token(sagemaker_client):
|
||||||
|
if settings.TEST_SERVER_MODE:
|
||||||
|
raise SkipTest(
|
||||||
|
"Skipping test in server mode due to lack of access to sagemaker_backends."
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_pipeline_names = ["APipelineName"]
|
||||||
|
pipelines = [
|
||||||
|
{
|
||||||
|
"PipelineName": fake_pipeline_names[0],
|
||||||
|
"RoleArn": FAKE_ROLE_ARN,
|
||||||
|
"PipelineDefinition": " ",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||||
|
pipeline_execution_arn = sagemaker_client.start_pipeline_execution(
|
||||||
|
PipelineName=fake_pipeline_names[0]
|
||||||
|
)["PipelineExecutionArn"]
|
||||||
|
|
||||||
|
# Verify that client_request_token is stored in FakePipelineExecution object
|
||||||
|
assert (
|
||||||
|
sagemaker_backends[ACCOUNT_ID][TEST_REGION_NAME]
|
||||||
|
.pipelines[fake_pipeline_names[0]]
|
||||||
|
.pipeline_executions[pipeline_execution_arn]
|
||||||
|
.client_request_token
|
||||||
|
!= ""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_describe_pipeline_execution_not_exists(sagemaker_client):
|
def test_describe_pipeline_execution_not_exists(sagemaker_client):
|
||||||
pipeline_execution_arn = arn_formatter(
|
pipeline_execution_arn = arn_formatter(
|
||||||
# random ID (execution ID)
|
# random ID (execution ID)
|
||||||
|
Loading…
Reference in New Issue
Block a user