Add sagemaker mock call: describe_pipeline (#5797)
* Added mock for sagemaker describe-pipeline call * Added NotImplementedError for PipelineDefinitionS3Location * Added support for PipelineDefinitionS3Location * Extended unit tests * Moved arn_formatter into utils * Import arn_formatter in test_sagemaker_pipeline * Adding uniqueness check for PipelineName * Removed unused import * Swapped client for s3_backend * Corrected kwarg names * From direct s3_backend to mocked boto call due to strange error * Changed to using s3_backends from mocked boto3 call * Remove unused argument * Black formatting * Delete object and bucket to avoid duplicate bucket names error * Try to fix bucket collisions * Remove unused lines * Switched to mock * SkipTest in server mode * Switched to handling inside to-be tested method * added s3 mock * mock s3 * Change mocking s3 * Removed unnecessary tests * Switch to only s3_backend * Adding skiptest to load from s3
This commit is contained in:
parent
cf0bcbce91
commit
4b117c4884
@ -170,7 +170,7 @@ sagemaker
|
||||
- [ ] describe_monitoring_schedule
|
||||
- [ ] describe_notebook_instance
|
||||
- [X] describe_notebook_instance_lifecycle_config
|
||||
- [ ] describe_pipeline
|
||||
- [X] describe_pipeline
|
||||
- [ ] describe_pipeline_definition_for_execution
|
||||
- [ ] describe_pipeline_execution
|
||||
- [X] describe_processing_job
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
||||
from moto.sagemaker import validators
|
||||
from moto.utilities.paginator import paginate
|
||||
@ -10,6 +11,7 @@ from .exceptions import (
|
||||
AWSValidationException,
|
||||
ResourceNotFound,
|
||||
)
|
||||
from .utils import load_pipeline_definition_from_s3, arn_formatter
|
||||
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
@ -44,10 +46,6 @@ PAGINATION_MODEL = {
|
||||
}
|
||||
|
||||
|
||||
def arn_formatter(_type, _id, account_id, region_name):
|
||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
def camelCase(self, key):
|
||||
words = []
|
||||
@ -86,24 +84,42 @@ class FakePipeline(BaseObject):
|
||||
account_id,
|
||||
region_name,
|
||||
parallelism_configuration,
|
||||
pipeline_definition_s3_location,
|
||||
):
|
||||
self.pipeline_name = pipeline_name
|
||||
self.pipeline_arn = arn_formatter(
|
||||
"pipeline", pipeline_name, account_id, region_name
|
||||
)
|
||||
self.pipeline_display_name = pipeline_display_name
|
||||
self.pipeline_display_name = pipeline_display_name or pipeline_name
|
||||
self.pipeline_definition = pipeline_definition
|
||||
self.pipeline_description = pipeline_description
|
||||
self.role_arn = role_arn
|
||||
self.tags = tags or []
|
||||
self.parallelism_configuration = parallelism_configuration
|
||||
self.pipeline_definition_s3_location = pipeline_definition_s3_location
|
||||
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = now_string
|
||||
self.last_modified_time = now_string
|
||||
self.last_execution_time = now_string
|
||||
self.last_execution_time = None
|
||||
|
||||
self.pipeline_status = "Active"
|
||||
fake_user_profile_name = "fake-user-profile-name"
|
||||
fake_domain_id = "fake-domain-id"
|
||||
fake_user_profile_arn = arn_formatter(
|
||||
"user-profile",
|
||||
f"{fake_domain_id}/{fake_user_profile_name}",
|
||||
account_id,
|
||||
region_name,
|
||||
)
|
||||
self.created_by = {
|
||||
"UserProfileArn": fake_user_profile_arn,
|
||||
"UserProfileName": fake_user_profile_name,
|
||||
"DomainId": fake_domain_id,
|
||||
}
|
||||
self.last_modified_by = {
|
||||
"UserProfileArn": fake_user_profile_arn,
|
||||
"UserProfileName": fake_user_profile_name,
|
||||
"DomainId": fake_domain_id,
|
||||
}
|
||||
|
||||
|
||||
class FakeProcessingJob(BaseObject):
|
||||
@ -1758,6 +1774,28 @@ class SageMakerModelBackend(BaseBackend):
|
||||
tags,
|
||||
parallelism_configuration,
|
||||
):
|
||||
if not any([pipeline_definition, pipeline_definition_s3_location]):
|
||||
raise ValidationError(
|
||||
"An error occurred (ValidationException) when calling the CreatePipeline operation: Either "
|
||||
"Pipeline Definition or Pipeline Definition S3 location should be provided"
|
||||
)
|
||||
if all([pipeline_definition, pipeline_definition_s3_location]):
|
||||
raise ValidationError(
|
||||
"An error occurred (ValidationException) when calling the CreatePipeline operation: "
|
||||
"Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present"
|
||||
)
|
||||
|
||||
if pipeline_name in self.pipelines:
|
||||
raise ValidationError(
|
||||
f"An error occurred (ValidationException) when calling the CreatePipeline operation: Pipeline names "
|
||||
f"must be unique within an AWS account and region. Pipeline with name ({pipeline_name}) already exists."
|
||||
)
|
||||
|
||||
if pipeline_definition_s3_location:
|
||||
pipeline_definition = load_pipeline_definition_from_s3(
|
||||
pipeline_definition_s3_location, self.account_id
|
||||
)
|
||||
|
||||
pipeline = FakePipeline(
|
||||
pipeline_name,
|
||||
pipeline_display_name,
|
||||
@ -1767,7 +1805,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
tags,
|
||||
self.account_id,
|
||||
self.region_name,
|
||||
pipeline_definition_s3_location,
|
||||
parallelism_configuration,
|
||||
)
|
||||
|
||||
@ -1799,28 +1836,59 @@ class SageMakerModelBackend(BaseBackend):
|
||||
message=f"Could not find pipeline with name {pipeline_name}."
|
||||
)
|
||||
|
||||
provided_kwargs = set(kwargs.keys())
|
||||
allowed_kwargs = {
|
||||
"pipeline_display_name",
|
||||
"pipeline_definition",
|
||||
"pipeline_definition_s3_location",
|
||||
"pipeline_description",
|
||||
"role_arn",
|
||||
"parallelism_configuration",
|
||||
}
|
||||
invalid_kwargs = provided_kwargs - allowed_kwargs
|
||||
|
||||
if invalid_kwargs:
|
||||
raise TypeError(
|
||||
f"update_pipeline got unexpected keyword arguments '{invalid_kwargs}'"
|
||||
if all(
|
||||
[
|
||||
kwargs.get("pipeline_definition"),
|
||||
kwargs.get("pipeline_definition_s3_location"),
|
||||
]
|
||||
):
|
||||
raise ValidationError(
|
||||
"An error occurred (ValidationException) when calling the UpdatePipeline operation: "
|
||||
"Both Pipeline Definition and Pipeline Definition S3 Location shouldn't be present"
|
||||
)
|
||||
|
||||
for attr_key, attr_value in kwargs.items():
|
||||
if attr_value:
|
||||
if attr_key == "pipeline_definition_s3_location":
|
||||
self.pipelines[
|
||||
pipeline_name
|
||||
].pipeline_definition = load_pipeline_definition_from_s3(
|
||||
attr_value, self.account_id
|
||||
)
|
||||
continue
|
||||
setattr(self.pipelines[pipeline_name], attr_key, attr_value)
|
||||
|
||||
return pipeline_arn
|
||||
|
||||
def describe_pipeline(
|
||||
self,
|
||||
pipeline_name,
|
||||
):
|
||||
try:
|
||||
pipeline = self.pipelines[pipeline_name]
|
||||
except KeyError:
|
||||
raise ValidationError(
|
||||
message=f"Could not find pipeline with name {pipeline_name}."
|
||||
)
|
||||
|
||||
response = {
|
||||
"PipelineArn": pipeline.pipeline_arn,
|
||||
"PipelineName": pipeline.pipeline_name,
|
||||
"PipelineDisplayName": pipeline.pipeline_display_name,
|
||||
"PipelineDescription": pipeline.pipeline_description,
|
||||
"PipelineDefinition": pipeline.pipeline_definition,
|
||||
"RoleArn": pipeline.role_arn,
|
||||
"PipelineStatus": pipeline.pipeline_status,
|
||||
"CreationTime": pipeline.creation_time,
|
||||
"LastModifiedTime": pipeline.last_modified_time,
|
||||
"LastRunTime": pipeline.last_execution_time,
|
||||
"CreatedBy": pipeline.created_by,
|
||||
"LastModifiedBy": pipeline.last_modified_by,
|
||||
"ParallelismConfiguration": pipeline.parallelism_configuration,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
def list_pipelines(
|
||||
self,
|
||||
pipeline_name_prefix,
|
||||
|
@ -465,6 +465,13 @@ class SageMakerResponse(BaseResponse):
|
||||
response = self.sagemaker_backend.list_associations(self.request_params)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def describe_pipeline(self):
|
||||
response = self.sagemaker_backend.describe_pipeline(
|
||||
self._get_param("PipelineName")
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_pipeline(self):
|
||||
pipeline = self.sagemaker_backend.create_pipeline(
|
||||
|
15
moto/sagemaker/utils.py
Normal file
15
moto/sagemaker/utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
from moto.s3.models import s3_backends
|
||||
import json
|
||||
|
||||
|
||||
def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id):
|
||||
s3_backend = s3_backends[account_id]["global"]
|
||||
result = s3_backend.get_object(
|
||||
bucket_name=pipeline_definition_s3_location["Bucket"],
|
||||
key_name=pipeline_definition_s3_location["ObjectKey"],
|
||||
)
|
||||
return json.loads(result.value)
|
||||
|
||||
|
||||
def arn_formatter(_type, _id, account_id, region_name):
|
||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
@ -1,18 +1,37 @@
|
||||
from moto import mock_sagemaker
|
||||
from contextlib import contextmanager
|
||||
from moto import mock_sagemaker, settings
|
||||
from time import sleep
|
||||
from datetime import datetime
|
||||
import boto3
|
||||
import botocore
|
||||
import json
|
||||
import pytest
|
||||
from unittest import SkipTest
|
||||
|
||||
from moto.s3 import mock_s3
|
||||
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||||
from moto.sagemaker.utils import arn_formatter, load_pipeline_definition_from_s3
|
||||
|
||||
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
TEST_REGION_NAME = "us-west-1"
|
||||
|
||||
|
||||
def arn_formatter(_type, _id, account_id, region_name):
|
||||
return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}"
|
||||
@contextmanager
|
||||
def setup_s3_pipeline_definition(bucket_name, object_key, pipeline_definition):
|
||||
client = boto3.client("s3")
|
||||
client.create_bucket(
|
||||
Bucket=bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": TEST_REGION_NAME},
|
||||
)
|
||||
client.put_object(
|
||||
Body=json.dumps(pipeline_definition),
|
||||
Bucket=bucket_name,
|
||||
Key=object_key,
|
||||
)
|
||||
yield
|
||||
|
||||
client.delete_object(Bucket=bucket_name, Key=object_key)
|
||||
client.delete_bucket(Bucket=bucket_name)
|
||||
|
||||
|
||||
@pytest.fixture(name="sagemaker_client")
|
||||
@ -21,22 +40,46 @@ def fixture_sagemaker_client():
|
||||
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
def create_sagemaker_pipelines(sagemaker_client, pipeline_names, wait_seconds=0.0):
|
||||
def create_sagemaker_pipelines(sagemaker_client, pipelines, wait_seconds=0.0):
|
||||
responses = []
|
||||
for pipeline_name in pipeline_names:
|
||||
responses += sagemaker_client.create_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
)
|
||||
for pipeline in pipelines:
|
||||
responses += sagemaker_client.create_pipeline(**pipeline)
|
||||
sleep(wait_seconds)
|
||||
return responses
|
||||
|
||||
|
||||
def test_load_pipeline_definition_from_s3():
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest(
|
||||
"Skipping test in server mode due to lack of access to s3_backend."
|
||||
)
|
||||
|
||||
bucket_name = "some-bucket-1"
|
||||
object_key = "some/object/key.json"
|
||||
pipeline_definition = {"key": "value"}
|
||||
|
||||
with mock_s3():
|
||||
with setup_s3_pipeline_definition(
|
||||
bucket_name,
|
||||
object_key,
|
||||
pipeline_definition,
|
||||
):
|
||||
observed_pipeline_definition = load_pipeline_definition_from_s3(
|
||||
pipeline_definition_s3_location={
|
||||
"Bucket": bucket_name,
|
||||
"ObjectKey": object_key,
|
||||
},
|
||||
account_id=ACCOUNT_ID,
|
||||
)
|
||||
observed_pipeline_definition.should.equal(pipeline_definition)
|
||||
|
||||
|
||||
def test_create_pipeline(sagemaker_client):
|
||||
fake_pipeline_name = "MyPipelineName"
|
||||
response = sagemaker_client.create_pipeline(
|
||||
PipelineName=fake_pipeline_name,
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
PipelineDefinition=" ",
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
response["PipelineArn"].should.equal(
|
||||
@ -44,6 +87,48 @@ def test_create_pipeline(sagemaker_client):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"create_pipeline_kwargs",
|
||||
[
|
||||
{"PipelineName": "MyPipelineName", "RoleArn": FAKE_ROLE_ARN},
|
||||
{"RoleArn": FAKE_ROLE_ARN, "PipelineDefinition": " "},
|
||||
{"PipelineName": "MyPipelineName", "PipelineDefinition": " "},
|
||||
{
|
||||
"PipelineName": "MyPipelineName",
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
"PipelineDefinitionS3Location": {"key": "value"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_create_pipeline_invalid_required_kwargs(
|
||||
sagemaker_client, create_pipeline_kwargs
|
||||
):
|
||||
with pytest.raises(
|
||||
(
|
||||
botocore.exceptions.ParamValidationError,
|
||||
botocore.exceptions.ClientError,
|
||||
)
|
||||
):
|
||||
_ = sagemaker_client.create_pipeline(
|
||||
**create_pipeline_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_create_pipeline_duplicate_pipeline_name(sagemaker_client):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.create_pipeline(
|
||||
PipelineName="APipelineName",
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
PipelineDefinition=" ",
|
||||
)
|
||||
_ = sagemaker_client.create_pipeline(
|
||||
PipelineName="APipelineName",
|
||||
RoleArn=FAKE_ROLE_ARN,
|
||||
PipelineDefinition=" ",
|
||||
)
|
||||
|
||||
|
||||
def test_list_pipelines_none(sagemaker_client):
|
||||
response = sagemaker_client.list_pipelines()
|
||||
assert isinstance(response, dict)
|
||||
@ -52,7 +137,15 @@ def test_list_pipelines_none(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_single(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_names[0],
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
},
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineArn"].should.equal(
|
||||
@ -62,7 +155,16 @@ def test_list_pipelines_single(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_multiple(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
@ -72,7 +174,16 @@ def test_list_pipelines_multiple(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_sort_name_ascending(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="Name",
|
||||
SortOrder="Ascending",
|
||||
@ -90,7 +201,16 @@ def test_list_pipelines_sort_name_ascending(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 1)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines, 1.0)
|
||||
|
||||
response = sagemaker_client.list_pipelines(
|
||||
SortBy="CreationTime",
|
||||
SortOrder="Descending",
|
||||
@ -108,14 +228,30 @@ def test_list_pipelines_sort_creation_time_descending(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_max_results(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(MaxResults=2)
|
||||
response["PipelineSummaries"].should.have.length_of(2)
|
||||
|
||||
|
||||
def test_list_pipelines_next_token(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_names[0],
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
},
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(NextToken="0")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
@ -123,7 +259,16 @@ def test_list_pipelines_next_token(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
response = sagemaker_client.list_pipelines(PipelineNamePrefix="APipe")
|
||||
response["PipelineSummaries"].should.have.length_of(1)
|
||||
response["PipelineSummaries"][0]["PipelineName"].should.equal("APipelineName")
|
||||
@ -134,7 +279,15 @@ def test_list_pipelines_pipeline_name_prefix(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_created_after(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
created_after_str = "2099-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedAfter=created_after_str)
|
||||
@ -151,7 +304,15 @@ def test_list_pipelines_created_after(sagemaker_client):
|
||||
|
||||
def test_list_pipelines_created_before(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
|
||||
created_before_str = "2000-12-31 23:59:59"
|
||||
response = sagemaker_client.list_pipelines(CreatedBefore=created_before_str)
|
||||
@ -182,7 +343,15 @@ def test_list_pipelines_invalid_values(sagemaker_client, list_pipelines_kwargs):
|
||||
|
||||
def test_delete_pipeline_exists(sagemaker_client):
|
||||
fake_pipeline_names = ["APipelineName", "BPipelineName", "CPipelineName"]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, fake_pipeline_names, 0.0)
|
||||
pipelines = [
|
||||
{
|
||||
"PipelineName": fake_pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
for fake_pipeline_name in fake_pipeline_names
|
||||
]
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, pipelines)
|
||||
pipeline_name_delete, pipeline_names_remain = (
|
||||
fake_pipeline_names[0],
|
||||
fake_pipeline_names[1:],
|
||||
@ -198,7 +367,7 @@ def test_delete_pipeline_exists(sagemaker_client):
|
||||
pipeline_names_exist = [
|
||||
pipeline["PipelineName"] for pipeline in response["PipelineSummaries"]
|
||||
]
|
||||
assert pipeline_names_remain == pipeline_names_exist
|
||||
assert set(pipeline_names_remain) == set(pipeline_names_exist)
|
||||
|
||||
|
||||
def test_delete_pipeline_not_exists(sagemaker_client):
|
||||
@ -206,14 +375,36 @@ def test_delete_pipeline_not_exists(sagemaker_client):
|
||||
_ = sagemaker_client.delete_pipeline(PipelineName="some-pipeline-name")
|
||||
|
||||
|
||||
def test_update_pipeline(sagemaker_client):
|
||||
def test_update_pipeline_not_exists(sagemaker_client):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.update_pipeline(PipelineName="some-pipeline-name")
|
||||
|
||||
|
||||
def test_update_pipeline_invalid_kwargs(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
pipeline = {
|
||||
"PipelineName": pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
|
||||
with pytest.raises(botocore.exceptions.ParamValidationError):
|
||||
sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
**{"InvalidKwarg": "some-value"},
|
||||
)
|
||||
|
||||
|
||||
def test_update_pipeline_no_update(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
pipeline = {
|
||||
"PipelineName": pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
|
||||
response = sagemaker_client.update_pipeline(PipelineName=pipeline_name)
|
||||
response["PipelineArn"].should.equal(
|
||||
arn_formatter("pipeline", pipeline_name, ACCOUNT_ID, TEST_REGION_NAME)
|
||||
@ -226,9 +417,14 @@ def test_update_pipeline_add_attribute(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
pipeline_display_name_update = "APipelineDisplayName"
|
||||
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
pipeline = {
|
||||
"PipelineName": pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
response = sagemaker_client.list_pipelines()
|
||||
assert "PipelineDisplayName" not in response["PipelineSummaries"][0]
|
||||
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(pipeline_name)
|
||||
|
||||
_ = sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
@ -238,14 +434,19 @@ def test_update_pipeline_add_attribute(sagemaker_client):
|
||||
response["PipelineSummaries"][0]["PipelineDisplayName"].should.equal(
|
||||
pipeline_display_name_update
|
||||
)
|
||||
response["PipelineSummaries"][0].should.have.length_of(7)
|
||||
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||
|
||||
|
||||
def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||
pipeline_name = "APipelineName"
|
||||
role_arn_update = f"{FAKE_ROLE_ARN}Test"
|
||||
pipeline = {
|
||||
"PipelineName": pipeline_name,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
}
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline_name])
|
||||
_ = sagemaker_client.update_pipeline(
|
||||
PipelineName=pipeline_name,
|
||||
RoleArn=role_arn_update,
|
||||
@ -253,3 +454,36 @@ def test_update_pipeline_update_change_attribute(sagemaker_client):
|
||||
response = sagemaker_client.list_pipelines()
|
||||
response["PipelineSummaries"][0]["RoleArn"].should.equal(role_arn_update)
|
||||
response["PipelineSummaries"][0].should.have.length_of(6)
|
||||
|
||||
|
||||
def test_describe_pipeline_not_exists(sagemaker_client):
|
||||
with pytest.raises(botocore.exceptions.ClientError):
|
||||
_ = sagemaker_client.describe_pipeline(PipelineName="some-pipeline-name")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pipeline,expected_response_length",
|
||||
[
|
||||
(
|
||||
{
|
||||
"PipelineName": "APipelineName",
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
},
|
||||
11,
|
||||
),
|
||||
(
|
||||
{
|
||||
"PipelineName": "BPipelineName",
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"PipelineDefinition": " ",
|
||||
"PipelineDescription": "some pipeline description",
|
||||
},
|
||||
12,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_describe_pipeline_exists(sagemaker_client, pipeline, expected_response_length):
|
||||
_ = create_sagemaker_pipelines(sagemaker_client, [pipeline])
|
||||
response = sagemaker_client.describe_pipeline(PipelineName=pipeline["PipelineName"])
|
||||
response.should.have.length_of(expected_response_length)
|
||||
|
Loading…
Reference in New Issue
Block a user