Add sagemaker mock call: update_pipeline (#5787)
This commit is contained in:
parent
92396bce4f
commit
008d5b958e
@ -313,7 +313,7 @@ sagemaker
|
|||||||
- [ ] update_monitoring_schedule
|
- [ ] update_monitoring_schedule
|
||||||
- [ ] update_notebook_instance
|
- [ ] update_notebook_instance
|
||||||
- [ ] update_notebook_instance_lifecycle_config
|
- [ ] update_notebook_instance_lifecycle_config
|
||||||
- [ ] update_pipeline
|
- [X] update_pipeline
|
||||||
- [ ] update_pipeline_execution
|
- [ ] update_pipeline_execution
|
||||||
- [ ] update_project
|
- [ ] update_project
|
||||||
- [ ] update_space
|
- [ ] update_space
|
||||||
|
@ -1787,6 +1787,40 @@ class SageMakerModelBackend(BaseBackend):
|
|||||||
del self.pipelines[pipeline_name]
|
del self.pipelines[pipeline_name]
|
||||||
return pipeline_arn
|
return pipeline_arn
|
||||||
|
|
||||||
|
def update_pipeline(
|
||||||
|
self,
|
||||||
|
pipeline_name,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
|
||||||
|
except KeyError:
|
||||||
|
raise ValidationError(
|
||||||
|
message=f"Could not find pipeline with name {pipeline_name}."
|
||||||
|
)
|
||||||
|
|
||||||
|
provided_kwargs = set(kwargs.keys())
|
||||||
|
allowed_kwargs = {
|
||||||
|
"pipeline_display_name",
|
||||||
|
"pipeline_definition",
|
||||||
|
"pipeline_definition_s3_location",
|
||||||
|
"pipeline_description",
|
||||||
|
"role_arn",
|
||||||
|
"parallelism_configuration",
|
||||||
|
}
|
||||||
|
invalid_kwargs = provided_kwargs - allowed_kwargs
|
||||||
|
|
||||||
|
if invalid_kwargs:
|
||||||
|
raise TypeError(
|
||||||
|
f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
for attr_key, attr_value in kwargs.items():
|
||||||
|
if attr_value:
|
||||||
|
setattr(self.pipelines[pipeline_name], attr_key, attr_value)
|
||||||
|
|
||||||
|
return pipeline_arn
|
||||||
|
|
||||||
def list_pipelines(
|
def list_pipelines(
|
||||||
self,
|
self,
|
||||||
pipeline_name_prefix,
|
pipeline_name_prefix,
|
||||||
|
@ -493,6 +493,23 @@ class SageMakerResponse(BaseResponse):
|
|||||||
response = {"PipelineArn": pipeline_arn}
|
response = {"PipelineArn": pipeline_arn}
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
|
@amzn_request_id
|
||||||
|
def update_pipeline(self):
|
||||||
|
pipeline_arn = self.sagemaker_backend.update_pipeline(
|
||||||
|
pipeline_name=self._get_param("PipelineName"),
|
||||||
|
pipeline_display_name=self._get_param("PipelineDisplayName"),
|
||||||
|
pipeline_definition=self._get_param("PipelineDefinition"),
|
||||||
|
pipeline_definition_s3_location=self._get_param(
|
||||||
|
"PipelineDefinitionS3Location"
|
||||||
|
),
|
||||||
|
pipeline_description=self._get_param("PipelineDescription"),
|
||||||
|
role_arn=self._get_param("RoleArn"),
|
||||||
|
parallelism_configuration=self._get_param("ParallelismConfiguration"),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = {"PipelineArn": pipeline_arn}
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_pipelines(self):
|
def list_pipelines(self):
|
||||||
max_results_range = range(1, 101)
|
max_results_range = range(1, 101)
|
||||||
|
@ -204,3 +204,52 @@ def test_delete_pipeline_exists(sagemaker_client):
|
|||||||
def test_delete_pipeline_not_exists(sagemaker_client):
|
def test_delete_pipeline_not_exists(sagemaker_client):
|
||||||
with pytest.raises(botocore.exceptions.ClientError):
|
with pytest.raises(botocore.exceptions.ClientError):
|
||||||
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
|
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline(sagemaker_client):
|
||||||
|
with pytest.raises(botocore.exceptions.ClientError):
|
||||||
|
_ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_no_update(sagemaker_client):
|
||||||
|
pipeline_name = "APipelineName"
|
||||||
|
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||||
|
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
|
||||||
|
response["PipelineArn"].should.equal(
|
||||||
|
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||||
|
)
|
||||||
|
response = sagemaker_client.list_pipelines()
|
||||||
|
response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_add_attribute(sagemaker_client):
|
||||||
|
pipeline_name = "APipelineName"
|
||||||
|
pipeline_display_name_update = "APipelineDisplayName"
|
||||||
|
|
||||||
|
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||||
|
response = sagemaker_client.list_pipelines()
|
||||||
|
assert "PipelineDisplayName" not in response["PipelineSummaries"][0]
|
||||||
|
|
||||||
|
_ = sagemaker_client.update_pipeline(
|
||||||
|
PipelineName=pipeline_name,
|
||||||
|
PipelineDisplayName=pipeline_display_name_update,
|
||||||
|
)
|
||||||
|
response = sagemaker_client.list_pipelines()
|
||||||
|
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
|
||||||
|
pipeline_display_name_update
|
||||||
|
)
|
||||||
|
response["PipelineSummaries"][0].should.have.length_of(7)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||||
|
pipeline_name = "APipelineName"
|
||||||
|
role_arn_update = f"{FAKE_ROLE_ARN}Test"
|
||||||
|
|
||||||
|
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||||
|
_ = sagemaker_client.update_pipeline(
|
||||||
|
PipelineName=pipeline_name,
|
||||||
|
RoleArn=role_arn_update,
|
||||||
|
)
|
||||||
|
response = sagemaker_client.list_pipelines()
|
||||||
|
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
|
||||||
|
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||||
|
Loading…
Reference in New Issue
Block a user