Add sagemaker mock call: update_pipeline (#5787)
This commit is contained in:
parent
92396bce4f
commit
008d5b958e
@ -313,7 +313,7 @@ sagemaker
|
||||
- [ ] update_monitoring_schedule
|
||||
- [ ] update_notebook_instance
|
||||
- [ ] update_notebook_instance_lifecycle_config
|
||||
- [ ] update_pipeline
|
||||
- [X] update_pipeline
|
||||
- [ ] update_pipeline_execution
|
||||
- [ ] update_project
|
||||
- [ ] update_space
|
||||
|
@ -1787,6 +1787,40 @@ class SageMakerModelBackend(BaseBackend):
|
||||
del self.pipelines[pipeline_name]
|
||||
return pipeline_arn
|
||||
|
||||
def update_pipeline(
|
||||
self,
|
||||
pipeline_name,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
|
||||
except KeyError:
|
||||
raise ValidationError(
|
||||
message=f"Could not find pipeline with name {pipeline_name}."
|
||||
)
|
||||
|
||||
provided_kwargs = set(kwargs.keys())
|
||||
allowed_kwargs = {
|
||||
"pipeline_display_name",
|
||||
"pipeline_definition",
|
||||
"pipeline_definition_s3_location",
|
||||
"pipeline_description",
|
||||
"role_arn",
|
||||
"parallelism_configuration",
|
||||
}
|
||||
invalid_kwargs = provided_kwargs - allowed_kwargs
|
||||
|
||||
if invalid_kwargs:
|
||||
raise TypeError(
|
||||
f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'"
|
||||
)
|
||||
|
||||
for attr_key, attr_value in kwargs.items():
|
||||
if attr_value:
|
||||
setattr(self.pipelines[pipeline_name], attr_key, attr_value)
|
||||
|
||||
return pipeline_arn
|
||||
|
||||
def list_pipelines(
|
||||
self,
|
||||
pipeline_name_prefix,
|
||||
|
@ -493,6 +493,23 @@ class SageMakerResponse(BaseResponse):
|
||||
response = {"PipelineArn": pipeline_arn}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def update_pipeline(self):
|
||||
pipeline_arn = self.sagemaker_backend.update_pipeline(
|
||||
pipeline_name=self._get_param("PipelineName"),
|
||||
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
||||
pipeline_definition=self._get_param("PipelineDefinition"),
|
||||
pipeline_definition_s3_location=self._get_param(
|
||||
"PipelineDefinitionS3Location"
|
||||
),
|
||||
pipeline_description=self._get_param("PipelineDescription"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
parallelism_configuration=self._get_param("ParallelismConfiguration"),
|
||||
)
|
||||
|
||||
response = {"PipelineArn": pipeline_arn}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_pipelines(self):
|
||||
max_results_range = range(1, 101)
|
||||
|
@ -204,3 +204,52 @@ def test_delete_pipeline_exists(sagemaker_client):
|
||||
def test_delete_pipeline_not_exists(sagemaker_client):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
|
||||
|
||||
|
||||
def test_update_pipeline(sagemaker_client):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name")
|
||||
|
||||
|
||||
def test_update_pipeline_no_update(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
|
||||
response["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name)
|
||||
|
||||
|
||||
def test_update_pipeline_add_attribute(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
pipeline_display_name_update = "APipelineDisplayName"
|
||||
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
response = sagemaker_client.list_pipelines()
|
||||
assert "PipelineDisplayName" not in response["PipelineSummaries"][0]
|
||||
|
||||
_ = sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
PipelineDisplayName=pipeline_display_name_update,
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
|
||||
pipeline_display_name_update
|
||||
)
|
||||
response["PipelineSummaries"][0].should.have.length_of(7)
|
||||
|
||||
|
||||
def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
role_arn_update = f"{FAKE_ROLE_ARN}Test"
|
||||
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
_ = sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
RoleArn=role_arn_update,
|
||||
)
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
|
||||
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||
|
Loading…
Reference in New Issue
Block a user