Add sagemaker mock call: update_pipeline (#5787)

This commit is contained in:
sist 2022-12-20 00:35:37 +01:00 committed by GitHub
parent 92396bce4f
commit 008d5b958e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 1 deletions

View File

@ -313,7 +313,7 @@ sagemaker
- [ ] update_monitoring_schedule
- [ ] update_notebook_instance
- [ ] update_notebook_instance_lifecycle_config
- [ ] update_pipeline
- [X] update_pipeline
- [ ] update_pipeline_execution
- [ ] update_project
- [ ] update_space

View File

@ -1787,6 +1787,40 @@ class SageMakerModelBackend(BaseBackend):
del self.pipelines[pipeline_name]
return pipeline_arn
def update_pipeline(
self,
pipeline_name,
**kwargs,
):
try:
pipeline_arn = self.pipelines[pipeline_name].pipeline_arn
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with name {pipeline_name}."
)
provided_kwargs = set(kwargs.keys())
allowed_kwargs = {
"pipeline_display_name",
"pipeline_definition",
"pipeline_definition_s3_location",
"pipeline_description",
"role_arn",
"parallelism_configuration",
}
invalid_kwargs = provided_kwargs - allowed_kwargs
if invalid_kwargs:
raise TypeError(
f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'"
)
for attr_key, attr_value in kwargs.items():
if attr_value:
setattr(self.pipelines[pipeline_name], attr_key, attr_value)
return pipeline_arn
def list_pipelines(
self,
pipeline_name_prefix,

View File

@ -493,6 +493,23 @@ class SageMakerResponse(BaseResponse):
response = {"PipelineArn": pipeline_arn}
return 200, {}, json.dumps(response)
@amzn_request_id
def update_pipeline(self):
pipeline_arn = self.sagemaker_backend.update_pipeline(
pipeline_name=self._get_param("PipelineName"),
pipeline_display_name=self._get_param("PipelineDisplayName"),
pipeline_definition=self._get_param("PipelineDefinition"),
pipeline_definition_s3_location=self._get_param(
"PipelineDefinitionS3Location"
),
pipeline_description=self._get_param("PipelineDescription"),
role_arn=self._get_param("RoleArn"),
parallelism_configuration=self._get_param("ParallelismConfiguration"),
)
response = {"PipelineArn": pipeline_arn}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_pipelines(self):
max_results_range = range(1, 101)

View File

@ -204,3 +204,52 @@ def test_delete_pipeline_exists(sagemaker_client):
def test_delete_pipeline_not_exists(sagemaker_client):
with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
def test_update_pipeline(sagemaker_client):
with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name")
def test_update_pipeline_no_update(sagemaker_client):
pipeline_name = "APipelineName"
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
response["PipelineArn"].should.equal(
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["PipelineName"].should.equal(pipeline_name)
def test_update_pipeline_add_attribute(sagemaker_client):
pipeline_name = "APipelineName"
pipeline_display_name_update = "APipelineDisplayName"
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
response = sagemaker_client.list_pipelines()
assert "PipelineDisplayName" not in response["PipelineSummaries"][0]
_ = sagemaker_client.update_pipeline(
PipelineName=pipeline_name,
PipelineDisplayName=pipeline_display_name_update,
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
pipeline_display_name_update
)
response["PipelineSummaries"][0].should.have.length_of(7)
def test_update_pipeline_update_change_attribute(sagemaker_client):
pipeline_name = "APipelineName"
role_arn_update = f"{FAKE_ROLE_ARN}Test"
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
_ = sagemaker_client.update_pipeline(
PipelineName=pipeline_name,
RoleArn=role_arn_update,
)
response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
response["PipelineSummaries"][0].should.have.length_of(6)