Add sagemaker mock call: describe_pipeline (#5797)

* Added mock for sagemaker describe-pipeline call

* Added NotImplementedError for PipelineDefinitionS3Location

* Added support for PipelineDefinitionS3Location

* Extended unit tests

* Moved arn_formatter into utils

* Import arn_formatter in test_sagemaker_pipeline

* Adding uniqueness check for PipelineName

* Removed unused import

* Swapped client for s3_backend

* Corrected kwarg names

* From direct s3_backend to mocked boto call due to strange error

* Changed to using s3_backends from mocked boto3 call

* Remove unused argument

* Black formatting

* Delete object and bucket to avoid duplicate bucket names error

* Try to fix bucket collisions

* Remove unused lines

* Switched to mock

* SkipTest in server mode

* Switched to handling inside to-be tested method

* added s3 mock

* mock s3

* Change mocking s3

* Removed unnecessary tests

* Switch to only s3_backend

* Adding skiptest to load from s3
This commit is contained in:
stiebels 2023-01-11 20:30:07 +01:00 committed by GitHub
parent cf0bcbce91
commit 4b117c4884
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 375 additions and 51 deletions

View File

@ -170,7 +170,7 @@ sagemaker
- [ ] describe_monitoring_schedule - [ ] describe_monitoring_schedule
- [ ] describe_notebook_instance - [ ] describe_notebook_instance
- [X] describe_notebook_instance_lifecycle_config - [X] describe_notebook_instance_lifecycle_config
- [ ] describe_pipeline - [X] describe_pipeline
- [ ] describe_pipeline_definition_for_execution - [ ] describe_pipeline_definition_for_execution
- [ ] describe_pipeline_execution - [ ] describe_pipeline_execution
- [X] describe_processing_job - [X] describe_processing_job

View File

@ -1,6 +1,7 @@
import json import json
import os import os
from datetime import datetime from datetime import datetime
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.sagemaker import validators from moto.sagemaker import validators
from moto.utilities.paginator import paginate from moto.utilities.paginator import paginate
@ -10,6 +11,7 @@ from .exceptions import (
AWSValidationException, AWSValidationException,
ResourceNotFound, ResourceNotFound,
) )
from .utils import load_pipeline_definition_from_s3, arn_formatter
PAGINATION_MODEL = { PAGINATION_MODEL = {
@ -44,10 +46,6 @@ PAGINATION_MODEL = {
} }
def arn_formatter(_type, _id, account_id, region_name):
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
class BaseObject(BaseModel): class BaseObject(BaseModel):
def camelCase(self, key): def camelCase(self, key):
words = [] words = []
@ -86,24 +84,42 @@ class FakePipeline(BaseObject):
account_id, account_id,
region_name, region_name,
parallelism_configuration, parallelism_configuration,
pipeline_definition_s3_location,
): ):
self.pipeline_name = pipeline_name self.pipeline_name = pipeline_name
self.pipeline_arn = arn_formatter( self.pipeline_arn = arn_formatter(
"pipeline", pipeline_name, account_id, region_name "pipeline", pipeline_name, account_id, region_name
) )
self.pipeline_display_name = pipeline_display_name self.pipeline_display_name = pipeline_display_name or pipeline_name
self.pipeline_definition = pipeline_definition self.pipeline_definition = pipeline_definition
self.pipeline_description = pipeline_description self.pipeline_description = pipeline_description
self.role_arn = role_arn self.role_arn = role_arn
self.tags = tags or [] self.tags = tags or []
self.parallelism_configuration = parallelism_configuration self.parallelism_configuration = parallelism_configuration
self.pipeline_definition_s3_location = pipeline_definition_s3_location
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string self.creation_time = now_string
self.last_modified_time = now_string self.last_modified_time = now_string
self.last_execution_time = now_string self.last_execution_time = None
self.pipeline_status = "Active"
fake_user_profile_name = "fake-user-profile-name"
fake_domain_id = "fake-domain-id"
fake_user_profile_arn = arn_formatter(
"user-profile",
f"{fake_domain_id}/{fake_user_profile_name}",
account_id,
region_name,
)
self.created_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
"DomainId": fake_domain_id,
}
self.last_modified_by = {
"UserProfileArn": fake_user_profile_arn,
"UserProfileName": fake_user_profile_name,
"DomainId": fake_domain_id,
}
class FakeProcessingJob(BaseObject): class FakeProcessingJob(BaseObject):
@ -1758,6 +1774,28 @@ class SageMakerModelBackend(BaseBackend):
tags, tags,
parallelism_configuration, parallelism_configuration,
): ):
if not any([pipeline_definition, pipeline_definition_s3_location]):
raise ValidationError(
"An error occurred (ValidationException) when calling the CreatePipeline operation: Either "
"Pipeline Definition or Pipeline Definition S3 location should be provided"
)
if all([pipeline_definition, pipeline_definition_s3_location]):
raise ValidationError(
"An error occurred (ValidationException) when calling the CreatePipeline operation: "
"Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present"
)
if pipeline_name in self.pipelines:
raise ValidationError(
f"An error occurred (ValidationException) when calling the CreatePipeline operation: Pipeline names "
f"must be unique within an AWS account and region. Pipeline with name ({pipeline_name}) already exists."
)
if pipeline_definition_s3_location:
pipeline_definition = load_pipeline_definition_from_s3(
pipeline_definition_s3_location, self.account_id
)
pipeline = FakePipeline( pipeline = FakePipeline(
pipeline_name, pipeline_name,
pipeline_display_name, pipeline_display_name,
@ -1767,7 +1805,6 @@ class SageMakerModelBackend(BaseBackend):
tags, tags,
self.account_id, self.account_id,
self.region_name, self.region_name,
pipeline_definition_s3_location,
parallelism_configuration, parallelism_configuration,
) )
@ -1799,28 +1836,59 @@ class SageMakerModelBackend(BaseBackend):
message=f"Could not find pipeline with name {pipeline_name}." message=f"Could not find pipeline with name {pipeline_name}."
) )
provided_kwargs = set(kwargs.keys()) if all(
allowed_kwargs = { [
"pipeline_display_name", kwargs.get("pipeline_definition"),
"pipeline_definition", kwargs.get("pipeline_definition_s3_location"),
"pipeline_definition_s3_location", ]
"pipeline_description", ):
"role_arn", raise ValidationError(
"parallelism_configuration", "An error occurred (ValidationException) when calling the UpdatePipeline operation: "
} "Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present"
invalid_kwargs = provided_kwargs - allowed_kwargs
if invalid_kwargs:
raise TypeError(
f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'"
) )
for attr_key, attr_value in kwargs.items(): for attr_key, attr_value in kwargs.items():
if attr_value: if attr_value:
if attr_key == "pipeline_definition_s3_location":
self.pipelines[
pipeline_name
].pipeline_definition = load_pipeline_definition_from_s3(
attr_value, self.account_id
)
continue
setattr(self.pipelines[pipeline_name], attr_key, attr_value) setattr(self.pipelines[pipeline_name], attr_key, attr_value)
return pipeline_arn return pipeline_arn
def describe_pipeline(
self,
pipeline_name,
):
try:
pipeline = self.pipelines[pipeline_name]
except KeyError:
raise ValidationError(
message=f"Could not find pipeline with name {pipeline_name}."
)
response = {
"PipelineArn": pipeline.pipeline_arn,
"PipelineName": pipeline.pipeline_name,
"PipelineDisplayName": pipeline.pipeline_display_name,
"PipelineDescription": pipeline.pipeline_description,
"PipelineDefinition": pipeline.pipeline_definition,
"RoleArn": pipeline.role_arn,
"PipelineStatus": pipeline.pipeline_status,
"CreationTime": pipeline.creation_time,
"LastModifiedTime": pipeline.last_modified_time,
"LastRunTime": pipeline.last_execution_time,
"CreatedBy": pipeline.created_by,
"LastModifiedBy": pipeline.last_modified_by,
"ParallelismConfiguration": pipeline.parallelism_configuration,
}
return response
def list_pipelines( def list_pipelines(
self, self,
pipeline_name_prefix, pipeline_name_prefix,

View File

@ -465,6 +465,13 @@ class SageMakerResponse(BaseResponse):
response = self.sagemaker_backend.list_associations(self.request_params) response = self.sagemaker_backend.list_associations(self.request_params)
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id
def describe_pipeline(self):
response = self.sagemaker_backend.describe_pipeline(
self._get_param("PipelineName")
)
return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def create_pipeline(self): def create_pipeline(self):
pipeline = self.sagemaker_backend.create_pipeline( pipeline = self.sagemaker_backend.create_pipeline(

15
moto/sagemaker/utils.py Normal file
View File

@ -0,0 +1,15 @@
from moto.s3.models import s3_backends
import json
def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id):
s3_backend = s3_backends[account_id]["global"]
result = s3_backend.get_object(
bucket_name=pipeline_definition_s3_location["Bucket"],
key_name=pipeline_definition_s3_location["ObjectKey"],
)
return json.loads(result.value)
def arn_formatter(_type, _id, account_id, region_name):
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"

View File

@ -1,18 +1,37 @@
from moto import mock_sagemaker from contextlib import contextmanager
from moto import mock_sagemaker, settings
from time import sleep from time import sleep
from datetime import datetime from datetime import datetime
import boto3 import boto3
import botocore import botocore
import json
import pytest import pytest
from unittest import SkipTest
from moto.s3 import mock_s3
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1" TEST_REGION_NAME = "us-west-1"
def arn_formatter(_type, _id, account_id, region_name): @contextmanager
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}" def setup_s3_pipeline_definition(bucket_name, object_key, pipeline_definition):
client = boto3.client("s3")
client.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": TEST_REGION_NAME},
)
client.put_object(
Body=json.dumps(pipeline_definition),
Bucket=bucket_name,
Key=object_key,
)
yield
client.delete_object(Bucket=bucket_name, Key=object_key)
client.delete_bucket(Bucket=bucket_name)
@pytest.fixture(name="sagemaker_client") @pytest.fixture(name="sagemaker_client")
@ -21,22 +40,46 @@ def fixture_sagemaker_client():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
def create_sagemaker_pipelines(sagemaker_client, pipeline_names, wait_seconds=0.0): def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
responses = [] responses = []
for pipeline_name in pipeline_names: for pipeline in pipelines:
responses += sagemaker_client.create_pipeline( responses += sagemaker_client.create_pipeline(**pipeline)
PipelineName=pipeline_name,
RoleArn=FAKE_ROLE_ARN,
)
sleep(wait_seconds) sleep(wait_seconds)
return responses return responses
def test_load_pipeline_definition_from_s3():
if settings.TEST_SERVER_MODE:
raise SkipTest(
"Skipping test in server mode due to lack of access to s3_backend."
)
bucket_name = "some-bucket-1"
object_key = "some/object/key.json"
pipeline_definition = {"key": "value"}
with mock_s3():
with setup_s3_pipeline_definition(
bucket_name,
object_key,
pipeline_definition,
):
observed_pipeline_definition = load_pipeline_definition_from_s3(
pipeline_definition_s3_location={
"Bucket": bucket_name,
"ObjectKey": object_key,
},
account_id=ACCOUNT_ID,
)
observed_pipeline_definition.should.equal(pipeline_definition)
def test_create_pipeline(sagemaker_client): def test_create_pipeline(sagemaker_client):
fake_pipeline_name = "MyPipelineName" fake_pipeline_name = "MyPipelineName"
response = sagemaker_client.create_pipeline( response = sagemaker_client.create_pipeline(
PipelineName=fake_pipeline_name, PipelineName=fake_pipeline_name,
RoleArn=FAKE_ROLE_ARN, RoleArn=FAKE_ROLE_ARN,
PipelineDefinition=" ",
) )
assert isinstance(response, dict) assert isinstance(response, dict)
response["PipelineArn"].should.equal( response["PipelineArn"].should.equal(
@ -44,6 +87,48 @@ def test_create_pipeline(sagemaker_client):
) )
@pytest.mark.parametrize(
"create_pipeline_kwargs",
[
{"PipelineName": "MyPipelineName", "RoleArn": FAKE_ROLE_ARN},
{"RoleArn": FAKE_ROLE_ARN, "PipelineDefinition": " "},
{"PipelineName": "MyPipelineName", "PipelineDefinition": " "},
{
"PipelineName": "MyPipelineName",
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
"PipelineDefinitionS3Location": {"key": "value"},
},
],
)
def test_create_pipeline_invalid_required_kwargs(
sagemaker_client, create_pipeline_kwargs
):
with pytest.raises(
(
botocore.exceptions.ParamValidationError,
botocore.exceptions.ClientError,
)
):
_ = sagemaker_client.create_pipeline(
**create_pipeline_kwargs,
)
def test_create_pipeline_duplicate_pipeline_name(sagemaker_client):
with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.create_pipeline(
PipelineName="APipelineName",
RoleArn=FAKE_ROLE_ARN,
PipelineDefinition=" ",
)
_ = sagemaker_client.create_pipeline(
PipelineName="APipelineName",
RoleArn=FAKE_ROLE_ARN,
PipelineDefinition=" ",
)
def test_list_pipelines_none(sagemaker_client): def test_list_pipelines_none(sagemaker_client):
response = sagemaker_client.list_pipelines() response = sagemaker_client.list_pipelines()
assert isinstance(response, dict) assert isinstance(response, dict)
@ -52,7 +137,15 @@ def test_list_pipelines_none(sagemaker_client):
def test_list_pipelines_single(sagemaker_client): def test_list_pipelines_single(sagemaker_client):
fake_pipeline_names = ["APipelineName"] fake_pipeline_names = ["APipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names) pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines() response = sagemaker_client.list_pipelines()
response["PipelineSummaries"].should.have.length_of(1) response["PipelineSummaries"].should.have.length_of(1)
response["PipelineSummaries"][0]["PipelineArn"].should.equal( response["PipelineSummaries"][0]["PipelineArn"].should.equal(
@ -62,7 +155,16 @@ def test_list_pipelines_single(sagemaker_client):
def test_list_pipelines_multiple(sagemaker_client): def test_list_pipelines_multiple(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines( response = sagemaker_client.list_pipelines(
SortBy="Name", SortBy="Name",
SortOrder="Ascending", SortOrder="Ascending",
@ -72,7 +174,16 @@ def test_list_pipelines_multiple(sagemaker_client):
def test_list_pipelines_sort_name_ascending(sagemaker_client): def test_list_pipelines_sort_name_ascending(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines( response = sagemaker_client.list_pipelines(
SortBy="Name", SortBy="Name",
SortOrder="Ascending", SortOrder="Ascending",
@ -90,7 +201,16 @@ def test_list_pipelines_sort_name_ascending(sagemaker_client):
def test_list_pipelines_sort_creation_time_descending(sagemaker_client): def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 1) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines, 1.0)
response = sagemaker_client.list_pipelines( response = sagemaker_client.list_pipelines(
SortBy="CreationTime", SortBy="CreationTime",
SortOrder="Descending", SortOrder="Descending",
@ -108,14 +228,30 @@ def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
def test_list_pipelines_max_results(sagemaker_client): def test_list_pipelines_max_results(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(MaxResults=2) response = sagemaker_client.list_pipelines(MaxResults=2)
response["PipelineSummaries"].should.have.length_of(2) response["PipelineSummaries"].should.have.length_of(2)
def test_list_pipelines_next_token(sagemaker_client): def test_list_pipelines_next_token(sagemaker_client):
fake_pipeline_names = ["APipelineName"] fake_pipeline_names = ["APipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_names[0],
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(NextToken="0") response = sagemaker_client.list_pipelines(NextToken="0")
response["PipelineSummaries"].should.have.length_of(1) response["PipelineSummaries"].should.have.length_of(1)
@ -123,7 +259,16 @@ def test_list_pipelines_next_token(sagemaker_client):
def test_list_pipelines_pipeline_name_prefix(sagemaker_client): def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe") response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
response["PipelineSummaries"].should.have.length_of(1) response["PipelineSummaries"].should.have.length_of(1)
response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName") response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
@ -134,7 +279,15 @@ def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
def test_list_pipelines_created_after(sagemaker_client): def test_list_pipelines_created_after(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
created_after_str = "2099-12-31 23:59:59" created_after_str = "2099-12-31 23:59:59"
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str) response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
@ -151,7 +304,15 @@ def test_list_pipelines_created_after(sagemaker_client):
def test_list_pipelines_created_before(sagemaker_client): def test_list_pipelines_created_before(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
created_before_str = "2000-12-31 23:59:59" created_before_str = "2000-12-31 23:59:59"
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str) response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
@ -182,7 +343,15 @@ def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
def test_delete_pipeline_exists(sagemaker_client): def test_delete_pipeline_exists(sagemaker_client):
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"] fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0) pipelines = [
{
"PipelineName": fake_pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
for fake_pipeline_name in fake_pipeline_names
]
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
pipeline_name_delete, pipeline_names_remain = ( pipeline_name_delete, pipeline_names_remain = (
fake_pipeline_names[0], fake_pipeline_names[0],
fake_pipeline_names[1:], fake_pipeline_names[1:],
@ -198,7 +367,7 @@ def test_delete_pipeline_exists(sagemaker_client):
pipeline_names_exist = [ pipeline_names_exist = [
pipeline["PipelineName"] for pipeline in response["PipelineSummaries"] pipeline["PipelineName"] for pipeline in response["PipelineSummaries"]
] ]
assert pipeline_names_remain == pipeline_names_exist assert set(pipeline_names_remain) == set(pipeline_names_exist)
def test_delete_pipeline_not_exists(sagemaker_client): def test_delete_pipeline_not_exists(sagemaker_client):
@ -206,14 +375,36 @@ def test_delete_pipeline_not_exists(sagemaker_client):
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name") _ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
def test_update_pipeline(sagemaker_client): def test_update_pipeline_not_exists(sagemaker_client):
with pytest.raises(botocore.exceptions.ClientError): with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name") _ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name")
def test_update_pipeline_invalid_kwargs(sagemaker_client):
pipeline_name = "APipelineName"
pipeline = {
"PipelineName": pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
with pytest.raises(botocore.exceptions.ParamValidationError):
sagemaker_client.update_pipeline(
PipelineName=pipeline_name,
**{"InvalidKwarg": "some-value"},
)
def test_update_pipeline_no_update(sagemaker_client): def test_update_pipeline_no_update(sagemaker_client):
pipeline_name = "APipelineName" pipeline_name = "APipelineName"
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) pipeline = {
"PipelineName": pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name) response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
response["PipelineArn"].should.equal( response["PipelineArn"].should.equal(
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME) arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
@ -226,9 +417,14 @@ def test_update_pipeline_add_attribute(sagemaker_client):
pipeline_name = "APipelineName" pipeline_name = "APipelineName"
pipeline_display_name_update = "APipelineDisplayName" pipeline_display_name_update = "APipelineDisplayName"
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name]) pipeline = {
"PipelineName": pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.list_pipelines() response = sagemaker_client.list_pipelines()
assert "PipelineDisplayName" not in response["PipelineSummaries"][0] response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(pipeline_name)
_ = sagemaker_client.update_pipeline( _ = sagemaker_client.update_pipeline(
PipelineName=pipeline_name, PipelineName=pipeline_name,
@ -238,14 +434,19 @@ def test_update_pipeline_add_attribute(sagemaker_client):
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal( response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
pipeline_display_name_update pipeline_display_name_update
) )
response["PipelineSummaries"][0].should.have.length_of(7) response["PipelineSummaries"][0].should.have.length_of(6)
def test_update_pipeline_update_change_attribute(sagemaker_client): def test_update_pipeline_update_change_attribute(sagemaker_client):
pipeline_name = "APipelineName" pipeline_name = "APipelineName"
role_arn_update = f"{FAKE_ROLE_ARN}Test" role_arn_update = f"{FAKE_ROLE_ARN}Test"
pipeline = {
"PipelineName": pipeline_name,
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
}
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
_ = sagemaker_client.update_pipeline( _ = sagemaker_client.update_pipeline(
PipelineName=pipeline_name, PipelineName=pipeline_name,
RoleArn=role_arn_update, RoleArn=role_arn_update,
@ -253,3 +454,36 @@ def test_update_pipeline_update_change_attribute(sagemaker_client):
response = sagemaker_client.list_pipelines() response = sagemaker_client.list_pipelines()
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update) response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
response["PipelineSummaries"][0].should.have.length_of(6) response["PipelineSummaries"][0].should.have.length_of(6)
def test_describe_pipeline_not_exists(sagemaker_client):
with pytest.raises(botocore.exceptions.ClientError):
_ = sagemaker_client.describe_pipeline(PipelineName="some-pipeline-name")
@pytest.mark.parametrize(
"pipeline,expected_response_length",
[
(
{
"PipelineName": "APipelineName",
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
},
11,
),
(
{
"PipelineName": "BPipelineName",
"RoleArn": FAKE_ROLE_ARN,
"PipelineDefinition": " ",
"PipelineDescription": "some pipeline description",
},
12,
),
],
)
def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length):
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"])
response.should.have.length_of(expected_response_length)