Add sagemaker mock call: update_pipeline (#5787)
This commit is contained in:
		
							parent
							
								
									92396bce4f
								
							
						
					
					
						commit
						008d5b958e
					
				| @ -313,7 +313,7 @@ sagemaker | |||||||
| - [ ] update_monitoring_schedule | - [ ] update_monitoring_schedule | ||||||
| - [ ] update_notebook_instance | - [ ] update_notebook_instance | ||||||
| - [ ] update_notebook_instance_lifecycle_config | - [ ] update_notebook_instance_lifecycle_config | ||||||
| - [ ] update_pipeline | - [X] update_pipeline | ||||||
| - [ ] update_pipeline_execution | - [ ] update_pipeline_execution | ||||||
| - [ ] update_project | - [ ] update_project | ||||||
| - [ ] update_space | - [ ] update_space | ||||||
|  | |||||||
| @ -1787,6 +1787,40 @@ class SageMakerModelBackend(BaseBackend): | |||||||
|         del self.pipelines[pipeline_name] |         del self.pipelines[pipeline_name] | ||||||
|         return pipeline_arn |         return pipeline_arn | ||||||
| 
 | 
 | ||||||
|  |     def update_pipeline( | ||||||
|  |         self, | ||||||
|  |         pipeline_name, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         try: | ||||||
|  |             pipeline_arn = self.pipelines[pipeline_name].pipeline_arn | ||||||
|  |         except KeyError: | ||||||
|  |             raise ValidationError( | ||||||
|  |                 message=f"Could not find pipeline with name {pipeline_name}." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         provided_kwargs = set(kwargs.keys()) | ||||||
|  |         allowed_kwargs = { | ||||||
|  |             "pipeline_display_name", | ||||||
|  |             "pipeline_definition", | ||||||
|  |             "pipeline_definition_s3_location", | ||||||
|  |             "pipeline_description", | ||||||
|  |             "role_arn", | ||||||
|  |             "parallelism_configuration", | ||||||
|  |         } | ||||||
|  |         invalid_kwargs = provided_kwargs - allowed_kwargs | ||||||
|  | 
 | ||||||
|  |         if invalid_kwargs: | ||||||
|  |             raise TypeError( | ||||||
|  |                 f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         for attr_key, attr_value in kwargs.items(): | ||||||
|  |             if attr_value: | ||||||
|  |                 setattr(self.pipelines[pipeline_name], attr_key, attr_value) | ||||||
|  | 
 | ||||||
|  |         return pipeline_arn | ||||||
|  | 
 | ||||||
|     def list_pipelines( |     def list_pipelines( | ||||||
|         self, |         self, | ||||||
|         pipeline_name_prefix, |         pipeline_name_prefix, | ||||||
|  | |||||||
| @ -493,6 +493,23 @@ class SageMakerResponse(BaseResponse): | |||||||
|         response = {"PipelineArn": pipeline_arn} |         response = {"PipelineArn": pipeline_arn} | ||||||
|         return 200, {}, json.dumps(response) |         return 200, {}, json.dumps(response) | ||||||
| 
 | 
 | ||||||
|  |     @amzn_request_id | ||||||
|  |     def update_pipeline(self): | ||||||
|  |         pipeline_arn = self.sagemaker_backend.update_pipeline( | ||||||
|  |             pipeline_name=self._get_param("PipelineName"), | ||||||
|  |             pipeline_display_name=self._get_param("PipelineDisplayName"), | ||||||
|  |             pipeline_definition=self._get_param("PipelineDefinition"), | ||||||
|  |             pipeline_definition_s3_location=self._get_param( | ||||||
|  |                 "PipelineDefinitionS3Location" | ||||||
|  |             ), | ||||||
|  |             pipeline_description=self._get_param("PipelineDescription"), | ||||||
|  |             role_arn=self._get_param("RoleArn"), | ||||||
|  |             parallelism_configuration=self._get_param("ParallelismConfiguration"), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         response = {"PipelineArn": pipeline_arn} | ||||||
|  |         return 200, {}, json.dumps(response) | ||||||
|  | 
 | ||||||
|     @amzn_request_id |     @amzn_request_id | ||||||
|     def list_pipelines(self): |     def list_pipelines(self): | ||||||
|         max_results_range = range(1, 101) |         max_results_range = range(1, 101) | ||||||
|  | |||||||
| @ -204,3 +204,52 @@ def test_delete_pipeline_exists(sagemaker_client): | |||||||
| def test_delete_pipeline_not_exists(sagemaker_client): | def test_delete_pipeline_not_exists(sagemaker_client): | ||||||
|     with pytest.raises(botocore.exceptions.ClientError): |     with pytest.raises(botocore.exceptions.ClientError): | ||||||
|         _ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name") |         _ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_update_pipeline(sagemaker_client): | ||||||
|  |     with pytest.raises(botocore.exceptions.ClientError): | ||||||
|  |         _ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_update_pipeline_no_update(sagemaker_client): | ||||||
|  |     pipeline_name = "APipelineName" | ||||||
|  |     _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) | ||||||
|  |     response = sagemaker_client.update_pipeline(PipelineName=pipeline_name) | ||||||
|  |     response["PipelineArn"].should.equal( | ||||||
|  |         arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME) | ||||||
|  |     ) | ||||||
|  |     response = sagemaker_client.list_pipelines() | ||||||
|  |     response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_update_pipeline_add_attribute(sagemaker_client): | ||||||
|  |     pipeline_name = "APipelineName" | ||||||
|  |     pipeline_display_name_update = "APipelineDisplayName" | ||||||
|  | 
 | ||||||
|  |     _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) | ||||||
|  |     response = sagemaker_client.list_pipelines() | ||||||
|  |     assert "PipelineDisplayName" not in response["PipelineSummaries"][0] | ||||||
|  | 
 | ||||||
|  |     _ = sagemaker_client.update_pipeline( | ||||||
|  |         PipelineName=pipeline_name, | ||||||
|  |         PipelineDisplayName=pipeline_display_name_update, | ||||||
|  |     ) | ||||||
|  |     response = sagemaker_client.list_pipelines() | ||||||
|  |     response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal( | ||||||
|  |         pipeline_display_name_update | ||||||
|  |     ) | ||||||
|  |     response["PipelineSummaries"][0].should.have.length_of(7) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_update_pipeline_update_change_attribute(sagemaker_client): | ||||||
|  |     pipeline_name = "APipelineName" | ||||||
|  |     role_arn_update = f"{FAKE_ROLE_ARN}Test" | ||||||
|  | 
 | ||||||
|  |     _ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) | ||||||
|  |     _ = sagemaker_client.update_pipeline( | ||||||
|  |         PipelineName=pipeline_name, | ||||||
|  |         RoleArn=role_arn_update, | ||||||
|  |     ) | ||||||
|  |     response = sagemaker_client.list_pipelines() | ||||||
|  |     response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update) | ||||||
|  |     response["PipelineSummaries"][0].should.have.length_of(6) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user