From 008d5b958ebd4d2d4fc5c4d2e24055073dab1415 Mon Sep 17 00:00:00 2001 From: sist <15859432+stiebels@users.noreply.github.com> Date: Tue, 20 Dec 2022 00:35:37 +0100 Subject: [PATCH] Add sagemaker mock call: update_pipeline (#5787) --- docs/docs/services/sagemaker.rst | 2 +- moto/sagemaker/models.py | 34 +++++++++++++ moto/sagemaker/responses.py | 17 +++++++ .../test_sagemaker/test_sagemaker_pipeline.py | 49 +++++++++++++++++++ 4 files changed, 101 insertions(+), 1 deletion(-) diff --git a/docs/docs/services/sagemaker.rst b/docs/docs/services/sagemaker.rst index 7156caf3d..2202239a0 100644 --- a/docs/docs/services/sagemaker.rst +++ b/docs/docs/services/sagemaker.rst @@ -313,7 +313,7 @@ sagemaker - [ ] update_monitoring_schedule - [ ] update_notebook_instance - [ ] update_notebook_instance_lifecycle_config -- [ ] update_pipeline +- [X] update_pipeline - [ ] update_pipeline_execution - [ ] update_project - [ ] update_space diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index dae4167a1..1fbf54bc5 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1787,6 +1787,40 @@ class SageMakerModelBackend(BaseBackend): del self.pipelines[pipeline_name] return pipeline_arn + def update_pipeline( + self, + pipeline_name, + **kwargs, + ): + try: + pipeline_arn = self.pipelines[pipeline_name].pipeline_arn + except KeyError: + raise ValidationError( + message=f"Could not find pipeline with name {pipeline_name}." + ) + + provided_kwargs = set(kwargs.keys()) + allowed_kwargs = { + "pipeline_display_name", + "pipeline_definition", + "pipeline_definition_s3_location", + "pipeline_description", + "role_arn", + "parallelism_configuration", + } + invalid_kwargs = provided_kwargs - allowed_kwargs + + if invalid_kwargs: + raise TypeError( + f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'" + ) + + for attr_key, attr_value in kwargs.items(): + if attr_value: + setattr(self.pipelines[pipeline_name], attr_key, attr_value) + + return pipeline_arn + def list_pipelines( self, pipeline_name_prefix, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index e40060a15..34b3154aa 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -493,6 +493,23 @@ class SageMakerResponse(BaseResponse): response = {"PipelineArn": pipeline_arn} return 200, {}, json.dumps(response) + @amzn_request_id + def update_pipeline(self): + pipeline_arn = self.sagemaker_backend.update_pipeline( + pipeline_name=self._get_param("PipelineName"), + pipeline_display_name=self._get_param("PipelineDisplayName"), + pipeline_definition=self._get_param("PipelineDefinition"), + pipeline_definition_s3_location=self._get_param( + "PipelineDefinitionS3Location" + ), + pipeline_description=self._get_param("PipelineDescription"), + role_arn=self._get_param("RoleArn"), + parallelism_configuration=self._get_param("ParallelismConfiguration"), + ) + + response = {"PipelineArn": pipeline_arn} + return 200, {}, json.dumps(response) + @amzn_request_id def list_pipelines(self): max_results_range = range(1, 101) diff --git a/tests/test_sagemaker/test_sagemaker_pipeline.py b/tests/test_sagemaker/test_sagemaker_pipeline.py index 582eeadf9..ea562646e 100644 --- a/tests/test_sagemaker/test_sagemaker_pipeline.py +++ b/tests/test_sagemaker/test_sagemaker_pipeline.py @@ -204,3 +204,52 @@ def test_delete_pipeline_exists(sagemaker_client): def test_delete_pipeline_not_exists(sagemaker_client): with pytest.raises(botocore.exceptions.ClientError): _ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name") + + +def test_update_pipeline(sagemaker_client): + with pytest.raises(botocore.exceptions.ClientError): + _ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name") + + +def test_update_pipeline_no_update(sagemaker_client): + pipeline_name = "APipelineName" + _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) + response = sagemaker_client.update_pipeline(PipelineName=pipeline_name) + response["PipelineArn"].should.equal( + arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME) + ) + response = sagemaker_client.list_pipelines() + response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name) + + +def test_update_pipeline_add_attribute(sagemaker_client): + pipeline_name = "APipelineName" + pipeline_display_name_update = "APipelineDisplayName" + + _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) + response = sagemaker_client.list_pipelines() + assert "PipelineDisplayName" not in response["PipelineSummaries"][0] + + _ = sagemaker_client.update_pipeline( + PipelineName=pipeline_name, + PipelineDisplayName=pipeline_display_name_update, + ) + response = sagemaker_client.list_pipelines() + response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal( + pipeline_display_name_update + ) + response["PipelineSummaries"][0].should.have.length_of(7) + + +def test_update_pipeline_update_change_attribute(sagemaker_client): + pipeline_name = "APipelineName" + role_arn_update = f"{FAKE_ROLE_ARN}Test" + + _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) + _ = sagemaker_client.update_pipeline( + PipelineName=pipeline_name, + RoleArn=role_arn_update, + ) + response = sagemaker_client.list_pipelines() + response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update) + response["PipelineSummaries"][0].should.have.length_of(6)