Add support for SageMaker Processing job (#4533)

Przemysław Dąbek 2021-11-06 13:47:42 +01:00 committed by GitHub
parent 1e85e16f0f
commit aab2a25dfa
3 changed files with 603 additions and 0 deletions
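For orientation, the new mock can be exercised end to end with boto3 in the usual moto fashion. A minimal sketch, assuming the standard moto test setup (fake credentials) and using purely illustrative values for the job name, image URI, and role ARN; the full parameter surface is exercised in the test file further down:

import boto3
from moto import mock_sagemaker


@mock_sagemaker
def demo_processing_job():
    # All values below are illustrative placeholders.
    client = boto3.client("sagemaker", region_name="us-east-1")
    client.create_processing_job(
        ProcessingJobName="my-processing-job",
        RoleArn="arn:aws:iam::123456789012:role/FakeRole",
        AppSpecification={
            "ImageUri": "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest"
        },
        ProcessingResources={
            "ClusterConfig": {
                "InstanceCount": 1,
                "InstanceType": "ml.m5.large",
                "VolumeSizeInGB": 10,
            }
        },
        StoppingCondition={"MaxRuntimeInSeconds": 3600},
    )
    job = client.describe_processing_job(ProcessingJobName="my-processing-job")
    # The mock marks every processing job as Completed immediately.
    assert job["ProcessingJobStatus"] == "Completed"


demo_processing_job()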

View File

@@ -39,6 +39,62 @@ class BaseObject(BaseModel):
return self.gen_response_object()
class FakeProcessingJob(BaseObject):
def __init__(
self,
app_specification,
experiment_config,
network_config,
processing_inputs,
processing_job_name,
processing_output_config,
processing_resources,
region_name,
role_arn,
stopping_condition,
):
self.processing_job_name = processing_job_name
self.processing_job_arn = FakeProcessingJob.arn_formatter(
processing_job_name, region_name
)
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
self.last_modified_time = now_string
self.processing_end_time = now_string
self.role_arn = role_arn
self.app_specification = app_specification
self.experiment_config = experiment_config
self.network_config = network_config
self.processing_inputs = processing_inputs
self.processing_job_status = "Completed"
        self.processing_output_config = processing_output_config
        self.processing_resources = processing_resources
        self.stopping_condition = stopping_condition
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"ProcessingJobArn": self.processing_job_arn}
    @staticmethod
    def arn_formatter(processing_job_name, region_name):
        return (
            "arn:aws:sagemaker:"
            + region_name
            + ":"
            + str(ACCOUNT_ID)
            + ":processing-job/"
            + processing_job_name
        )
class FakeTrainingJob(BaseObject):
def __init__(
self,
@@ -898,6 +954,7 @@ class SageMakerModelBackend(BaseBackend):
self.endpoint_configs = {}
self.endpoints = {}
self.experiments = {}
self.processing_jobs = {}
self.trials = {}
self.trial_components = {}
self.training_jobs = {}
@@ -1672,6 +1729,129 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
def create_processing_job(
self,
app_specification,
experiment_config,
network_config,
processing_inputs,
processing_job_name,
processing_output_config,
processing_resources,
role_arn,
stopping_condition,
):
processing_job = FakeProcessingJob(
app_specification=app_specification,
experiment_config=experiment_config,
network_config=network_config,
processing_inputs=processing_inputs,
processing_job_name=processing_job_name,
processing_output_config=processing_output_config,
processing_resources=processing_resources,
region_name=self.region_name,
role_arn=role_arn,
stopping_condition=stopping_condition,
)
self.processing_jobs[processing_job_name] = processing_job
return processing_job
def describe_processing_job(self, processing_job_name):
try:
return self.processing_jobs[processing_job_name].response_object
except KeyError:
message = "Could not find processing job '{}'.".format(
FakeProcessingJob.arn_formatter(processing_job_name, self.region_name)
)
raise ValidationError(message=message)
def list_processing_jobs(
self,
next_token,
max_results,
creation_time_after,
creation_time_before,
last_modified_time_after,
last_modified_time_before,
name_contains,
status_equals,
sort_by,
sort_order,
):
if next_token:
try:
starting_index = int(next_token)
if starting_index > len(self.processing_jobs):
raise ValueError # invalid next_token
except ValueError:
                raise AWSValidationException(
                    'Invalid pagination token because "{0}".'.format(next_token)
                )
else:
starting_index = 0
if max_results:
end_index = max_results + starting_index
processing_jobs_fetched = list(self.processing_jobs.values())[
starting_index:end_index
]
if end_index >= len(self.processing_jobs):
next_index = None
else:
next_index = end_index
else:
processing_jobs_fetched = list(self.processing_jobs.values())
next_index = None
if name_contains is not None:
processing_jobs_fetched = filter(
lambda x: name_contains in x.processing_job_name,
processing_jobs_fetched,
)
if creation_time_after is not None:
processing_jobs_fetched = filter(
lambda x: x.creation_time > creation_time_after, processing_jobs_fetched
)
if creation_time_before is not None:
processing_jobs_fetched = filter(
lambda x: x.creation_time < creation_time_before,
processing_jobs_fetched,
)
if last_modified_time_after is not None:
processing_jobs_fetched = filter(
lambda x: x.last_modified_time > last_modified_time_after,
processing_jobs_fetched,
)
if last_modified_time_before is not None:
processing_jobs_fetched = filter(
lambda x: x.last_modified_time < last_modified_time_before,
processing_jobs_fetched,
)
        if status_equals is not None:
            processing_jobs_fetched = filter(
                lambda x: x.processing_job_status == status_equals,
                processing_jobs_fetched,
            )
processing_job_summaries = [
{
"ProcessingJobName": processing_job_data.processing_job_name,
"ProcessingJobArn": processing_job_data.processing_job_arn,
"CreationTime": processing_job_data.creation_time,
"ProcessingEndTime": processing_job_data.processing_end_time,
"LastModifiedTime": processing_job_data.last_modified_time,
"ProcessingJobStatus": processing_job_data.processing_job_status,
}
for processing_job_data in processing_jobs_fetched
]
return {
"ProcessingJobSummaries": processing_job_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def create_training_job(
self,
training_job_name,

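One non-obvious property of list_processing_jobs above: MaxResults/NextToken slice the stored jobs first, and the NameContains/StatusEquals/time filters are applied to that slice afterwards, so a page can come back empty while a NextToken is still returned. A simplified, self-contained sketch of that slice-then-filter behaviour (illustrative names, NameContains only):

def paginate_then_filter(jobs, next_token=None, max_results=None, name_contains=None):
    # Mirrors the backend: slice by the pagination parameters first ...
    starting_index = int(next_token) if next_token else 0
    if max_results:
        end_index = starting_index + max_results
        page = jobs[starting_index:end_index]
        next_index = None if end_index >= len(jobs) else end_index
    else:
        page, next_index = list(jobs), None
    # ... then filter the page that was sliced out.
    if name_contains is not None:
        page = [name for name in page if name_contains in name]
    return page, (str(next_index) if next_index is not None else None)


jobs = ["xgboost-0", "xgboost-1", "vgg-0"]
# The first page matches nothing, yet a NextToken is returned -- the behaviour the
# "target in middle" and "fragmented targets" tests below are written against.
assert paginate_then_filter(jobs, max_results=1, name_contains="vgg") == ([], "1")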
View File

@@ -225,6 +225,33 @@ class SageMakerResponse(BaseResponse):
self.sagemaker_backend.delete_endpoint(endpoint_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_processing_job(self):
try:
processing_job = self.sagemaker_backend.create_processing_job(
app_specification=self._get_param("AppSpecification"),
experiment_config=self._get_param("ExperimentConfig"),
network_config=self._get_param("NetworkConfig"),
processing_inputs=self._get_param("ProcessingInputs"),
processing_job_name=self._get_param("ProcessingJobName"),
processing_output_config=self._get_param("ProcessingOutputConfig"),
processing_resources=self._get_param("ProcessingResources"),
role_arn=self._get_param("RoleArn"),
stopping_condition=self._get_param("StoppingCondition"),
)
response = {
"ProcessingJobArn": processing_job.processing_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_processing_job(self):
processing_job_name = self._get_param("ProcessingJobName")
response = self.sagemaker_backend.describe_processing_job(processing_job_name)
return json.dumps(response)
@amzn_request_id
def create_training_job(self):
try:
@@ -420,6 +447,68 @@ class SageMakerResponse(BaseResponse):
response = self.sagemaker_backend.list_associations(self.request_params)
return 200, {}, json.dumps(response)
@amzn_request_id
def list_processing_jobs(self):
max_results_range = range(1, 101)
allowed_sort_by = ["Name", "CreationTime", "Status"]
allowed_sort_order = ["Ascending", "Descending"]
allowed_status_equals = [
"Completed",
"Stopped",
"InProgress",
"Stopping",
"Failed",
]
try:
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
"Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format(
max_results, max_results_range[-1]
)
)
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(
status_equals, "statusEquals", allowed_status_equals
)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_processing_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
sort_by=sort_by,
sort_order=sort_order,
)
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def list_training_jobs(self):
max_results_range = range(1, 101)

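format_enum_error is reused from the existing training-job handler and is not part of this diff. Judging from the message the test file below expects for an invalid StatusEquals, it presumably renders something close to the following (an approximation for illustration, not the actual helper):

def format_enum_error(value, attribute, allowed):
    # Approximation of the validation message checked by the tests below.
    return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}"


print(
    format_enum_error(
        "blah", "statusEquals", ["Completed", "Stopped", "InProgress", "Stopping", "Failed"]
    )
)
# Value 'blah' at 'statusEquals' failed to satisfy constraint: Member must satisfy
# enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']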
View File

@@ -0,0 +1,334 @@
import boto3
from botocore.exceptions import ClientError
import datetime
import pytest
import sure  # noqa -- activates the .should assertion syntax used throughout this file

from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
TEST_REGION_NAME = "us-east-1"
class MyProcessingJobModel(object):
def __init__(
self,
processing_job_name,
role_arn,
container=None,
bucket=None,
prefix=None,
app_specification=None,
network_config=None,
processing_inputs=None,
processing_output_config=None,
processing_resources=None,
stopping_condition=None,
):
self.processing_job_name = processing_job_name
self.role_arn = role_arn
self.container = (
container
or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
)
self.bucket = bucket or "my-bucket"
self.prefix = prefix or "sagemaker"
self.app_specification = app_specification or {
"ImageUri": self.container,
"ContainerEntrypoint": ["python3",],
}
self.network_config = network_config or {
"EnableInterContainerTrafficEncryption": False,
"EnableNetworkIsolation": False,
}
self.processing_inputs = processing_inputs or [
{
"InputName": "input",
"AppManaged": False,
"S3Input": {
"S3Uri": "s3://{}/{}/processing/".format(self.bucket, self.prefix),
"LocalPath": "/opt/ml/processing/input",
"S3DataType": "S3Prefix",
"S3InputMode": "File",
"S3DataDistributionType": "FullyReplicated",
"S3CompressionType": "None",
},
}
]
self.processing_output_config = processing_output_config or {
"Outputs": [
{
"OutputName": "output",
"S3Output": {
"S3Uri": "s3://{}/{}/processing/".format(
self.bucket, self.prefix
),
"LocalPath": "/opt/ml/processing/output",
"S3UploadMode": "EndOfJob",
},
"AppManaged": False,
}
]
}
self.processing_resources = processing_resources or {
"ClusterConfig": {
"InstanceCount": 1,
"InstanceType": "ml.m5.large",
"VolumeSizeInGB": 10,
},
}
self.stopping_condition = stopping_condition or {
"MaxRuntimeInSeconds": 3600,
}
def save(self):
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
params = {
"AppSpecification": self.app_specification,
"NetworkConfig": self.network_config,
"ProcessingInputs": self.processing_inputs,
"ProcessingJobName": self.processing_job_name,
"ProcessingOutputConfig": self.processing_output_config,
"ProcessingResources": self.processing_resources,
"RoleArn": self.role_arn,
"StoppingCondition": self.stopping_condition,
}
return sagemaker.create_processing_job(**params)
@mock_sagemaker
def test_create_processing_job():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
processing_job_name = "MyProcessingJob"
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
bucket = "my-bucket"
prefix = "my-prefix"
app_specification = {
"ImageUri": container,
"ContainerEntrypoint": ["python3", "app.py"],
}
processing_resources = {
"ClusterConfig": {
"InstanceCount": 2,
"InstanceType": "ml.m5.xlarge",
"VolumeSizeInGB": 20,
},
}
stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}
job = MyProcessingJobModel(
processing_job_name,
role_arn,
container=container,
bucket=bucket,
prefix=prefix,
app_specification=app_specification,
processing_resources=processing_resources,
stopping_condition=stopping_condition,
)
resp = job.save()
resp["ProcessingJobArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
)
resp = sagemaker.describe_processing_job(ProcessingJobName=processing_job_name)
resp["ProcessingJobName"].should.equal(processing_job_name)
resp["ProcessingJobArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
)
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
assert resp["RoleArn"] == role_arn
assert resp["ProcessingJobStatus"] == "Completed"
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
@mock_sagemaker
def test_list_processing_jobs():
client = boto3.client("sagemaker", region_name="us-east-1")
name = "blah"
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
test_processing_job = MyProcessingJobModel(processing_job_name=name, role_arn=arn)
test_processing_job.save()
processing_jobs = client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
assert processing_jobs["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal(name)
assert processing_jobs["ProcessingJobSummaries"][0][
"ProcessingJobArn"
].should.match(r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(name))
assert processing_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_processing_jobs_multiple():
client = boto3.client("sagemaker", region_name="us-east-1")
name_job_1 = "blah"
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
test_processing_job_1 = MyProcessingJobModel(
processing_job_name=name_job_1, role_arn=arn_job_1
)
test_processing_job_1.save()
name_job_2 = "blah2"
arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"
test_processing_job_2 = MyProcessingJobModel(
processing_job_name=name_job_2, role_arn=arn_job_2
)
test_processing_job_2.save()
processing_jobs_limit = client.list_processing_jobs(MaxResults=1)
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
processing_jobs = client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
assert processing_jobs.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_none():
client = boto3.client("sagemaker", region_name="us-east-1")
processing_jobs = client.list_processing_jobs()
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
@mock_sagemaker
def test_list_processing_jobs_should_validate_input():
client = boto3.client("sagemaker", region_name="us-east-1")
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
client.list_processing_jobs(StatusEquals=junk_status_equals)
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert ex.value.response["Error"]["Message"] == expected_error
junk_next_token = "asdf"
with pytest.raises(ClientError) as ex:
client.list_processing_jobs(NextToken=junk_next_token)
assert ex.value.response["Error"]["Code"] == "ValidationException"
    assert (
        ex.value.response["Error"]["Message"]
        == 'Invalid pagination token because "{0}".'.format(junk_next_token)
    )
@mock_sagemaker
def test_list_processing_jobs_with_name_filters():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = "xgboost-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
for i in range(5):
name = "vgg-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
xgboost_processing_jobs = client.list_processing_jobs(NameContains="xgboost")
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2")
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
@mock_sagemaker
def test_list_processing_jobs_paginated():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = "xgboost-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
xgboost_processing_job_1 = client.list_processing_jobs(
NameContains="xgboost", MaxResults=1
)
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("xgboost-0")
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
xgboost_processing_job_next = client.list_processing_jobs(
NameContains="xgboost",
MaxResults=1,
NextToken=xgboost_processing_job_1.get("NextToken"),
)
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("xgboost-1")
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_target_in_middle():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = "xgboost-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
for i in range(5):
name = "vgg-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
vgg_processing_job_1 = client.list_processing_jobs(NameContains="vgg", MaxResults=1)
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
assert vgg_processing_job_1.get("NextToken").should_not.be.none
vgg_processing_job_6 = client.list_processing_jobs(NameContains="vgg", MaxResults=6)
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
"ProcessingJobName"
].should.equal("vgg-0")
assert vgg_processing_job_6.get("NextToken").should_not.be.none
vgg_processing_job_10 = client.list_processing_jobs(
NameContains="vgg", MaxResults=10
)
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
"ProcessingJobName"
].should.equal("vgg-4")
assert vgg_processing_job_10.get("NextToken").should.be.none
@mock_sagemaker
def test_list_processing_jobs_paginated_with_fragmented_targets():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = "xgboost-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
for i in range(5):
name = "vgg-{}".format(i)
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2", MaxResults=8)
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
assert processing_jobs_with_2.get("NextToken").should_not.be.none
processing_jobs_with_2_next = client.list_processing_jobs(
NameContains="2",
MaxResults=1,
NextToken=processing_jobs_with_2.get("NextToken"),
)
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
processing_jobs_with_2_next_next = client.list_processing_jobs(
NameContains="2",
MaxResults=1,
NextToken=processing_jobs_with_2_next.get("NextToken"),
)
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
0
)
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none