Add support for SageMaker Processing job (#4533)
This commit is contained in:
parent
1e85e16f0f
commit
aab2a25dfa
@ -39,6 +39,62 @@ class BaseObject(BaseModel):
|
||||
return self.gen_response_object()
|
||||
|
||||
|
||||
class FakeProcessingJob(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
app_specification,
|
||||
experiment_config,
|
||||
network_config,
|
||||
processing_inputs,
|
||||
processing_job_name,
|
||||
processing_output_config,
|
||||
processing_resources,
|
||||
region_name,
|
||||
role_arn,
|
||||
stopping_condition,
|
||||
):
|
||||
self.processing_job_name = processing_job_name
|
||||
self.processing_job_arn = FakeProcessingJob.arn_formatter(
|
||||
processing_job_name, region_name
|
||||
)
|
||||
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = now_string
|
||||
self.last_modified_time = now_string
|
||||
self.processing_end_time = now_string
|
||||
|
||||
self.role_arn = role_arn
|
||||
self.app_specification = app_specification
|
||||
self.experiment_config = experiment_config
|
||||
self.network_config = network_config
|
||||
self.processing_inputs = processing_inputs
|
||||
self.processing_job_status = "Completed"
|
||||
self.processing_output_config = processing_output_config
|
||||
self.stopping_condition = stopping_condition
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"ProcessingJobArn": self.processing_job_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(endpoint_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":processing-job/"
|
||||
+ endpoint_name
|
||||
)
|
||||
|
||||
|
||||
class FakeTrainingJob(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
@ -898,6 +954,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
self.endpoint_configs = {}
|
||||
self.endpoints = {}
|
||||
self.experiments = {}
|
||||
self.processing_jobs = {}
|
||||
self.trials = {}
|
||||
self.trial_components = {}
|
||||
self.training_jobs = {}
|
||||
@ -1672,6 +1729,129 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_processing_job(
|
||||
self,
|
||||
app_specification,
|
||||
experiment_config,
|
||||
network_config,
|
||||
processing_inputs,
|
||||
processing_job_name,
|
||||
processing_output_config,
|
||||
processing_resources,
|
||||
role_arn,
|
||||
stopping_condition,
|
||||
):
|
||||
processing_job = FakeProcessingJob(
|
||||
app_specification=app_specification,
|
||||
experiment_config=experiment_config,
|
||||
network_config=network_config,
|
||||
processing_inputs=processing_inputs,
|
||||
processing_job_name=processing_job_name,
|
||||
processing_output_config=processing_output_config,
|
||||
processing_resources=processing_resources,
|
||||
region_name=self.region_name,
|
||||
role_arn=role_arn,
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
self.processing_jobs[processing_job_name] = processing_job
|
||||
return processing_job
|
||||
|
||||
def describe_processing_job(self, processing_job_name):
|
||||
try:
|
||||
return self.processing_jobs[processing_job_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find processing job '{}'.".format(
|
||||
FakeProcessingJob.arn_formatter(processing_job_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def list_processing_jobs(
|
||||
self,
|
||||
next_token,
|
||||
max_results,
|
||||
creation_time_after,
|
||||
creation_time_before,
|
||||
last_modified_time_after,
|
||||
last_modified_time_before,
|
||||
name_contains,
|
||||
status_equals,
|
||||
sort_by,
|
||||
sort_order,
|
||||
):
|
||||
if next_token:
|
||||
try:
|
||||
starting_index = int(next_token)
|
||||
if starting_index > len(self.processing_jobs):
|
||||
raise ValueError # invalid next_token
|
||||
except ValueError:
|
||||
raise AWSValidationException('Invalid pagination token because "{0}".')
|
||||
else:
|
||||
starting_index = 0
|
||||
|
||||
if max_results:
|
||||
end_index = max_results + starting_index
|
||||
processing_jobs_fetched = list(self.processing_jobs.values())[
|
||||
starting_index:end_index
|
||||
]
|
||||
if end_index >= len(self.processing_jobs):
|
||||
next_index = None
|
||||
else:
|
||||
next_index = end_index
|
||||
else:
|
||||
processing_jobs_fetched = list(self.processing_jobs.values())
|
||||
next_index = None
|
||||
|
||||
if name_contains is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: name_contains in x.processing_job_name,
|
||||
processing_jobs_fetched,
|
||||
)
|
||||
|
||||
if creation_time_after is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: x.creation_time > creation_time_after, processing_jobs_fetched
|
||||
)
|
||||
|
||||
if creation_time_before is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: x.creation_time < creation_time_before,
|
||||
processing_jobs_fetched,
|
||||
)
|
||||
|
||||
if last_modified_time_after is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: x.last_modified_time > last_modified_time_after,
|
||||
processing_jobs_fetched,
|
||||
)
|
||||
|
||||
if last_modified_time_before is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: x.last_modified_time < last_modified_time_before,
|
||||
processing_jobs_fetched,
|
||||
)
|
||||
if status_equals is not None:
|
||||
processing_jobs_fetched = filter(
|
||||
lambda x: x.training_job_status == status_equals,
|
||||
processing_jobs_fetched,
|
||||
)
|
||||
|
||||
processing_job_summaries = [
|
||||
{
|
||||
"ProcessingJobName": processing_job_data.processing_job_name,
|
||||
"ProcessingJobArn": processing_job_data.processing_job_arn,
|
||||
"CreationTime": processing_job_data.creation_time,
|
||||
"ProcessingEndTime": processing_job_data.processing_end_time,
|
||||
"LastModifiedTime": processing_job_data.last_modified_time,
|
||||
"ProcessingJobStatus": processing_job_data.processing_job_status,
|
||||
}
|
||||
for processing_job_data in processing_jobs_fetched
|
||||
]
|
||||
|
||||
return {
|
||||
"ProcessingJobSummaries": processing_job_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def create_training_job(
|
||||
self,
|
||||
training_job_name,
|
||||
|
@ -225,6 +225,33 @@ class SageMakerResponse(BaseResponse):
|
||||
self.sagemaker_backend.delete_endpoint(endpoint_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_processing_job(self):
|
||||
try:
|
||||
processing_job = self.sagemaker_backend.create_processing_job(
|
||||
app_specification=self._get_param("AppSpecification"),
|
||||
experiment_config=self._get_param("ExperimentConfig"),
|
||||
network_config=self._get_param("NetworkConfig"),
|
||||
processing_inputs=self._get_param("ProcessingInputs"),
|
||||
processing_job_name=self._get_param("ProcessingJobName"),
|
||||
processing_output_config=self._get_param("ProcessingOutputConfig"),
|
||||
processing_resources=self._get_param("ProcessingResources"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
stopping_condition=self._get_param("StoppingCondition"),
|
||||
)
|
||||
response = {
|
||||
"ProcessingJobArn": processing_job.processing_job_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_processing_job(self):
|
||||
processing_job_name = self._get_param("ProcessingJobName")
|
||||
response = self.sagemaker_backend.describe_processing_job(processing_job_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_training_job(self):
|
||||
try:
|
||||
@ -420,6 +447,68 @@ class SageMakerResponse(BaseResponse):
|
||||
response = self.sagemaker_backend.list_associations(self.request_params)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_processing_jobs(self):
|
||||
max_results_range = range(1, 101)
|
||||
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
||||
allowed_sort_order = ["Ascending", "Descending"]
|
||||
allowed_status_equals = [
|
||||
"Completed",
|
||||
"Stopped",
|
||||
"InProgress",
|
||||
"Stopping",
|
||||
"Failed",
|
||||
]
|
||||
|
||||
try:
|
||||
max_results = self._get_int_param("MaxResults")
|
||||
sort_by = self._get_param("SortBy", "CreationTime")
|
||||
sort_order = self._get_param("SortOrder", "Ascending")
|
||||
status_equals = self._get_param("StatusEquals")
|
||||
next_token = self._get_param("NextToken")
|
||||
errors = []
|
||||
if max_results and max_results not in max_results_range:
|
||||
errors.append(
|
||||
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
|
||||
max_results, max_results_range[-1]
|
||||
)
|
||||
)
|
||||
|
||||
if sort_by not in allowed_sort_by:
|
||||
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
|
||||
if sort_order not in allowed_sort_order:
|
||||
errors.append(
|
||||
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
|
||||
)
|
||||
|
||||
if status_equals and status_equals not in allowed_status_equals:
|
||||
errors.append(
|
||||
format_enum_error(
|
||||
status_equals, "statusEquals", allowed_status_equals
|
||||
)
|
||||
)
|
||||
|
||||
if errors != []:
|
||||
raise AWSValidationException(
|
||||
f"{len(errors)} validation errors detected: {';'.join(errors)}"
|
||||
)
|
||||
|
||||
response = self.sagemaker_backend.list_processing_jobs(
|
||||
next_token=next_token,
|
||||
max_results=max_results,
|
||||
creation_time_after=self._get_param("CreationTimeAfter"),
|
||||
creation_time_before=self._get_param("CreationTimeBefore"),
|
||||
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
|
||||
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
|
||||
name_contains=self._get_param("NameContains"),
|
||||
status_equals=status_equals,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def list_training_jobs(self):
|
||||
max_results_range = range(1, 101)
|
||||
|
334
tests/test_sagemaker/test_sagemaker_processing.py
Normal file
334
tests/test_sagemaker/test_sagemaker_processing.py
Normal file
@ -0,0 +1,334 @@
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import datetime
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
class MyProcessingJobModel(object):
|
||||
def __init__(
|
||||
self,
|
||||
processing_job_name,
|
||||
role_arn,
|
||||
container=None,
|
||||
bucket=None,
|
||||
prefix=None,
|
||||
app_specification=None,
|
||||
network_config=None,
|
||||
processing_inputs=None,
|
||||
processing_output_config=None,
|
||||
processing_resources=None,
|
||||
stopping_condition=None,
|
||||
):
|
||||
self.processing_job_name = processing_job_name
|
||||
self.role_arn = role_arn
|
||||
self.container = (
|
||||
container
|
||||
or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
|
||||
)
|
||||
self.bucket = bucket or "my-bucket"
|
||||
self.prefix = prefix or "sagemaker"
|
||||
self.app_specification = app_specification or {
|
||||
"ImageUri": self.container,
|
||||
"ContainerEntrypoint": ["python3",],
|
||||
}
|
||||
self.network_config = network_config or {
|
||||
"EnableInterContainerTrafficEncryption": False,
|
||||
"EnableNetworkIsolation": False,
|
||||
}
|
||||
self.processing_inputs = processing_inputs or [
|
||||
{
|
||||
"InputName": "input",
|
||||
"AppManaged": False,
|
||||
"S3Input": {
|
||||
"S3Uri": "s3://{}/{}/processing/".format(self.bucket, self.prefix),
|
||||
"LocalPath": "/opt/ml/processing/input",
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3InputMode": "File",
|
||||
"S3DataDistributionType": "FullyReplicated",
|
||||
"S3CompressionType": "None",
|
||||
},
|
||||
}
|
||||
]
|
||||
self.processing_output_config = processing_output_config or {
|
||||
"Outputs": [
|
||||
{
|
||||
"OutputName": "output",
|
||||
"S3Output": {
|
||||
"S3Uri": "s3://{}/{}/processing/".format(
|
||||
self.bucket, self.prefix
|
||||
),
|
||||
"LocalPath": "/opt/ml/processing/output",
|
||||
"S3UploadMode": "EndOfJob",
|
||||
},
|
||||
"AppManaged": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
self.processing_resources = processing_resources or {
|
||||
"ClusterConfig": {
|
||||
"InstanceCount": 1,
|
||||
"InstanceType": "ml.m5.large",
|
||||
"VolumeSizeInGB": 10,
|
||||
},
|
||||
}
|
||||
self.stopping_condition = stopping_condition or {
|
||||
"MaxRuntimeInSeconds": 3600,
|
||||
}
|
||||
|
||||
def save(self):
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
params = {
|
||||
"AppSpecification": self.app_specification,
|
||||
"NetworkConfig": self.network_config,
|
||||
"ProcessingInputs": self.processing_inputs,
|
||||
"ProcessingJobName": self.processing_job_name,
|
||||
"ProcessingOutputConfig": self.processing_output_config,
|
||||
"ProcessingResources": self.processing_resources,
|
||||
"RoleArn": self.role_arn,
|
||||
"StoppingCondition": self.stopping_condition,
|
||||
}
|
||||
|
||||
return sagemaker.create_processing_job(**params)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_processing_job():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
processing_job_name = "MyProcessingJob"
|
||||
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
bucket = "my-bucket"
|
||||
prefix = "my-prefix"
|
||||
app_specification = {
|
||||
"ImageUri": container,
|
||||
"ContainerEntrypoint": ["python3", "app.py"],
|
||||
}
|
||||
processing_resources = {
|
||||
"ClusterConfig": {
|
||||
"InstanceCount": 2,
|
||||
"InstanceType": "ml.m5.xlarge",
|
||||
"VolumeSizeInGB": 20,
|
||||
},
|
||||
}
|
||||
stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}
|
||||
|
||||
job = MyProcessingJobModel(
|
||||
processing_job_name,
|
||||
role_arn,
|
||||
container=container,
|
||||
bucket=bucket,
|
||||
prefix=prefix,
|
||||
app_specification=app_specification,
|
||||
processing_resources=processing_resources,
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save()
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_processing_job(ProcessingJobName=processing_job_name)
|
||||
resp["ProcessingJobName"].should.equal(processing_job_name)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
|
||||
)
|
||||
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert resp["RoleArn"] == role_arn
|
||||
assert resp["ProcessingJobStatus"] == "Completed"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_processing_job = MyProcessingJobModel(processing_job_name=name, role_arn=arn)
|
||||
test_processing_job.save()
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal(name)
|
||||
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobArn"
|
||||
].should.match(r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(name))
|
||||
assert processing_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_multiple():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name_job_1 = "blah"
|
||||
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_processing_job_1 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_1, role_arn=arn_job_1
|
||||
)
|
||||
test_processing_job_1.save()
|
||||
|
||||
name_job_2 = "blah2"
|
||||
arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"
|
||||
test_processing_job_2 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_2, role_arn=arn_job_2
|
||||
)
|
||||
test_processing_job_2.save()
|
||||
processing_jobs_limit = client.list_processing_jobs(MaxResults=1)
|
||||
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
|
||||
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_should_validate_input():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_processing_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
junk_next_token = "asdf"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_processing_jobs(NextToken=junk_next_token)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert (
|
||||
ex.value.response["Error"]["Message"]
|
||||
== 'Invalid pagination token because "{0}".'
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_with_name_filters():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
xgboost_processing_jobs = client.list_processing_jobs(NameContains="xgboost")
|
||||
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
|
||||
|
||||
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2")
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
xgboost_processing_job_1 = client.list_processing_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
xgboost_processing_job_next = client.list_processing_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_processing_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_target_in_middle():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
|
||||
vgg_processing_job_1 = client.list_processing_jobs(NameContains="vgg", MaxResults=1)
|
||||
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert vgg_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_6 = client.list_processing_jobs(NameContains="vgg", MaxResults=6)
|
||||
|
||||
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_processing_job_6.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_10 = client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=10
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
|
||||
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_processing_job_10.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_fragmented_targets():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
|
||||
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2", MaxResults=8)
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs_with_2.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next = client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next_next = client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
|
||||
0
|
||||
)
|
||||
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
|
Loading…
Reference in New Issue
Block a user