AWS Batch enhancements (#3956)
* Check exit status of container
* Added support for job dependencies
* batch container overrides
* add AWS_BATCH_JOB_ID to container env variables
* lint with black
* refactor batch dependency test
* refactor batch dependency test
* fix index

Co-authored-by: jterry64 <justin.terry@wri.org>
Co-authored-by: Daniel Mannarino <daniel.mannarino@gmail.com>
This commit is contained in:
parent fbbc8fc472 · commit d635c78bd1
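For context, a minimal usage sketch (not part of the commit) of how the new behaviour could be exercised with boto3 against moto's mocks. The role, queue, and job definition names are illustrative placeholders, and actually running it requires a local Docker daemon because moto's Batch mock runs the container for real:

import boto3
from moto import mock_batch, mock_iam


@mock_iam
@mock_batch
def run_jobs_with_dependencies_and_overrides():
    iam = boto3.client("iam", region_name="us-east-1")
    batch = boto3.client("batch", region_name="us-east-1")

    # Illustrative service role for an UNMANAGED compute environment.
    role_arn = iam.create_role(
        RoleName="batch-service-role", AssumeRolePolicyDocument="{}"
    )["Role"]["Arn"]

    env_arn = batch.create_compute_environment(
        computeEnvironmentName="example-env",
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=role_arn,
    )["computeEnvironmentArn"]

    queue_arn = batch.create_job_queue(
        jobQueueName="example-queue",
        state="ENABLED",
        priority=1,
        computeEnvironmentOrder=[{"order": 1, "computeEnvironment": env_arn}],
    )["jobQueueArn"]

    job_def_arn = batch.register_job_definition(
        jobDefinitionName="example-def",
        type="container",
        containerProperties={
            "image": "busybox:latest",
            "vcpus": 1,
            "memory": 128,
            "command": ["echo", "hello"],
            "environment": [{"name": "TEST0", "value": "from job definition"}],
        },
    )["jobDefinitionArn"]

    first = batch.submit_job(
        jobName="first", jobQueue=queue_arn, jobDefinition=job_def_arn
    )["jobId"]

    # New in this commit: dependsOn is honoured, containerOverrides take effect,
    # and AWS_BATCH_JOB_ID is injected into the container environment.
    second = batch.submit_job(
        jobName="second",
        jobQueue=queue_arn,
        jobDefinition=job_def_arn,
        dependsOn=[{"jobId": first, "type": "SEQUENTIAL"}],
        containerOverrides={
            "command": ["printenv"],
            "environment": [{"name": "TEST0", "value": "from job"}],
        },
    )["jobId"]

    return batch.describe_jobs(jobs=[first, second])["jobs"]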
@@ -310,7 +310,16 @@ class JobDefinition(CloudFormationModel):


class Job(threading.Thread, BaseModel, DockerModel):
-   def __init__(self, name, job_def, job_queue, log_backend, container_overrides):
+   def __init__(
+       self,
+       name,
+       job_def,
+       job_queue,
+       log_backend,
+       container_overrides,
+       depends_on,
+       all_jobs,
+   ):
        """
        Docker Job
@@ -335,6 +344,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
        self.job_stopped_at = datetime.datetime(1970, 1, 1)
        self.job_stopped = False
        self.job_stopped_reason = None
+       self.depends_on = depends_on
+       self.all_jobs = all_jobs

        self.stop = False
@@ -351,27 +362,48 @@ class Job(threading.Thread, BaseModel, DockerModel):
            "jobName": self.job_name,
            "jobQueue": self.job_queue.arn,
            "status": self.job_state,
-           "dependsOn": [],
+           "dependsOn": self.depends_on if self.depends_on else [],
        }
        if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]:
            result["startedAt"] = datetime2int(self.job_started_at)
        if self.job_stopped:
            result["stoppedAt"] = datetime2int(self.job_stopped_at)
            result["container"] = {}
-           result["container"]["command"] = [
-               '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'
-           ]
-           result["container"]["privileged"] = False
-           result["container"]["readonlyRootFilesystem"] = False
-           result["container"]["ulimits"] = {}
-           result["container"]["vcpus"] = 1
-           result["container"]["volumes"] = ""
+           result["container"]["command"] = self._get_container_property("command", [])
+           result["container"]["privileged"] = self._get_container_property(
+               "privileged", False
+           )
+           result["container"][
+               "readonlyRootFilesystem"
+           ] = self._get_container_property("readonlyRootFilesystem", False)
+           result["container"]["ulimits"] = self._get_container_property("ulimits", {})
+           result["container"]["vcpus"] = self._get_container_property("vcpus", 1)
+           result["container"]["memory"] = self._get_container_property("memory", 512)
+           result["container"]["volumes"] = self._get_container_property("volumes", [])
+           result["container"]["environment"] = self._get_container_property(
+               "environment", []
+           )
            result["container"]["logStreamName"] = self.log_stream_name
        if self.job_stopped_reason is not None:
            result["statusReason"] = self.job_stopped_reason
        return result

+   def _get_container_property(self, p, default):
+       if p == "environment":
+           job_env = self.container_overrides.get(p, default)
+           jd_env = self.job_definition.container_properties.get(p, default)
+
+           job_env_dict = {_env["name"]: _env["value"] for _env in job_env}
+           jd_env_dict = {_env["name"]: _env["value"] for _env in jd_env}
+
+           for key in jd_env_dict.keys():
+               if key not in job_env_dict.keys():
+                   job_env.append({"name": key, "value": jd_env_dict[key]})
+
+           job_env.append({"name": "AWS_BATCH_JOB_ID", "value": self.job_id})
+
+           return job_env
+
+       return self.container_overrides.get(
+           p, self.job_definition.container_properties.get(p, default)
+       )
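As an aside (not part of the diff), the "environment" branch of _get_container_property above merges the job definition's variables with the per-job overrides: the override value wins on a name clash, job-definition-only variables are kept, and AWS_BATCH_JOB_ID is appended. A small standalone sketch of that behaviour:

def merge_environment(job_definition_env, override_env, job_id):
    """Standalone mirror of the "environment" branch of _get_container_property."""
    merged = list(override_env)
    override_names = {env["name"] for env in override_env}
    for env in job_definition_env:
        # Keep variables that exist only in the job definition.
        if env["name"] not in override_names:
            merged.append(env)
    # The job id is always injected into the container environment.
    merged.append({"name": "AWS_BATCH_JOB_ID", "value": job_id})
    return merged


# The override for TEST0 wins, TEST1 survives from the job definition.
print(merge_environment(
    [{"name": "TEST0", "value": "from job definition"},
     {"name": "TEST1", "value": "from job definition"}],
    [{"name": "TEST0", "value": "from job"}],
    "example-job-id",
))
# [{'name': 'TEST0', 'value': 'from job'},
#  {'name': 'TEST1', 'value': 'from job definition'},
#  {'name': 'AWS_BATCH_JOB_ID', 'value': 'example-job-id'}]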
@@ -393,6 +425,9 @@ class Job(threading.Thread, BaseModel, DockerModel):
        try:
            self.job_state = "PENDING"

+           if self.depends_on and not self._wait_for_dependencies():
+               return
+
            image = self.job_definition.container_properties.get(
                "image", "alpine:latest"
            )
@@ -497,7 +532,11 @@ class Job(threading.Thread, BaseModel, DockerModel):
            self._log_backend.create_log_stream(log_group, stream_name)
            self._log_backend.put_log_events(log_group, stream_name, logs, None)

-           self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
+           result = container.wait()
+           if self.stop or result["StatusCode"] != 0:
+               self.job_state = "FAILED"
+           else:
+               self.job_state = "SUCCEEDED"

        except Exception as err:
            logger.error(
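For reference (not part of the diff), docker-py's Container.wait() returns a dict containing the container's exit code, which is what the exit-status check above relies on. A quick sketch, assuming a local Docker daemon:

import docker

client = docker.from_env()
# Run a container that exits with a non-zero status.
container = client.containers.run("busybox:latest", ["sh", "-c", "exit 1"], detach=True)
result = container.wait()  # e.g. {"Error": None, "StatusCode": 1}
job_state = "FAILED" if result["StatusCode"] != 0 else "SUCCEEDED"
print(job_state)  # FAILED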
@@ -525,6 +564,30 @@ class Job(threading.Thread, BaseModel, DockerModel):
        self.stop = True
        self.job_stopped_reason = reason

+   def _wait_for_dependencies(self):
+       dependent_ids = [dependency["jobId"] for dependency in self.depends_on]
+       successful_dependencies = set()
+       while len(successful_dependencies) != len(dependent_ids):
+           for dependent_id in dependent_ids:
+               if dependent_id in self.all_jobs:
+                   dependent_job = self.all_jobs[dependent_id]
+                   if dependent_job.job_state == "SUCCEEDED":
+                       successful_dependencies.add(dependent_id)
+                   if dependent_job.job_state == "FAILED":
+                       logger.error(
+                           "Terminating job {0} due to failed dependency {1}".format(
+                               self.name, dependent_job.name
+                           )
+                       )
+                       self.job_state = "FAILED"
+                       self.job_stopped = True
+                       self.job_stopped_at = datetime.datetime.now()
+                       return False
+
+           time.sleep(1)
+
+       return True
+

class BatchBackend(BaseBackend):
    def __init__(self, region_name=None):
@@ -1241,6 +1304,8 @@ class BatchBackend(BaseBackend):
            queue,
            log_backend=self.logs_backend,
            container_overrides=container_overrides,
+           depends_on=depends_on,
+           all_jobs=self._jobs,
        )
        self._jobs[job.job_id] = job

@@ -886,3 +886,367 @@ def _wait_for_job_status(client, job_id, status, seconds_to_wait=30):
                status=status, last_status=last_job_status
            )
        )

@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_failed_job():
    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
    vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)

    compute_name = "test_compute_env"
    resp = batch_client.create_compute_environment(
        computeEnvironmentName=compute_name,
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    arn = resp["computeEnvironmentArn"]

    resp = batch_client.create_job_queue(
        jobQueueName="test_job_queue",
        state="ENABLED",
        priority=123,
        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
    )
    queue_arn = resp["jobQueueArn"]

    resp = batch_client.register_job_definition(
        jobDefinitionName="sayhellotomylittlefriend",
        type="container",
        containerProperties={
            "image": "busybox:latest",
            "vcpus": 1,
            "memory": 128,
            "command": ["exit", "1"],
        },
    )
    job_def_arn = resp["jobDefinitionArn"]

    resp = batch_client.submit_job(
        jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn
    )
    job_id = resp["jobId"]

    future = datetime.datetime.now() + datetime.timedelta(seconds=30)

    while datetime.datetime.now() < future:
        resp = batch_client.describe_jobs(jobs=[job_id])

        if resp["jobs"][0]["status"] == "FAILED":
            break
        if resp["jobs"][0]["status"] == "SUCCEEDED":
            raise RuntimeError("Batch job succeeded even though it had exit code 1")
        time.sleep(0.5)
    else:
        raise RuntimeError("Batch job timed out")

@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_dependencies():
    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
    vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)

    compute_name = "test_compute_env"
    resp = batch_client.create_compute_environment(
        computeEnvironmentName=compute_name,
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    arn = resp["computeEnvironmentArn"]

    resp = batch_client.create_job_queue(
        jobQueueName="test_job_queue",
        state="ENABLED",
        priority=123,
        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
    )
    queue_arn = resp["jobQueueArn"]

    resp = batch_client.register_job_definition(
        jobDefinitionName="sayhellotomylittlefriend",
        type="container",
        containerProperties={
            "image": "busybox:latest",
            "vcpus": 1,
            "memory": 128,
            "command": ["echo", "hello"],
        },
    )
    job_def_arn = resp["jobDefinitionArn"]

    resp = batch_client.submit_job(
        jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn
    )
    job_id1 = resp["jobId"]

    resp = batch_client.submit_job(
        jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn
    )
    job_id2 = resp["jobId"]

    depends_on = [
        {"jobId": job_id1, "type": "SEQUENTIAL"},
        {"jobId": job_id2, "type": "SEQUENTIAL"},
    ]
    resp = batch_client.submit_job(
        jobName="test3",
        jobQueue=queue_arn,
        jobDefinition=job_def_arn,
        dependsOn=depends_on,
    )
    job_id3 = resp["jobId"]

    future = datetime.datetime.now() + datetime.timedelta(seconds=30)

    while datetime.datetime.now() < future:
        resp = batch_client.describe_jobs(jobs=[job_id1, job_id2, job_id3])

        if any([job["status"] == "FAILED" for job in resp["jobs"]]):
            raise RuntimeError("Batch job failed")
        if all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]):
            break
        time.sleep(0.5)
    else:
        raise RuntimeError("Batch job timed out")

    resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job")
    len(resp["logStreams"]).should.equal(3)
    for log_stream in resp["logStreams"]:
        ls_name = log_stream["logStreamName"]

        resp = logs_client.get_log_events(
            logGroupName="/aws/batch/job", logStreamName=ls_name
        )
        [event["message"] for event in resp["events"]].should.equal(["hello"])

@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_failed_dependencies():
    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
    vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)

    compute_name = "test_compute_env"
    resp = batch_client.create_compute_environment(
        computeEnvironmentName=compute_name,
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    arn = resp["computeEnvironmentArn"]

    resp = batch_client.create_job_queue(
        jobQueueName="test_job_queue",
        state="ENABLED",
        priority=123,
        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
    )
    queue_arn = resp["jobQueueArn"]

    resp = batch_client.register_job_definition(
        jobDefinitionName="sayhellotomylittlefriend",
        type="container",
        containerProperties={
            "image": "busybox:latest",
            "vcpus": 1,
            "memory": 128,
            "command": ["echo", "hello"],
        },
    )
    job_def_arn_success = resp["jobDefinitionArn"]

    resp = batch_client.register_job_definition(
        jobDefinitionName="sayhellotomylittlefriend_failed",
        type="container",
        containerProperties={
            "image": "busybox:latest",
            "vcpus": 1,
            "memory": 128,
            "command": ["exi1", "1"],
        },
    )
    job_def_arn_failure = resp["jobDefinitionArn"]

    resp = batch_client.submit_job(
        jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn_success
    )

    job_id1 = resp["jobId"]

    resp = batch_client.submit_job(
        jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn_failure
    )
    job_id2 = resp["jobId"]

    depends_on = [
        {"jobId": job_id1, "type": "SEQUENTIAL"},
        {"jobId": job_id2, "type": "SEQUENTIAL"},
    ]
    resp = batch_client.submit_job(
        jobName="test3",
        jobQueue=queue_arn,
        jobDefinition=job_def_arn_success,
        dependsOn=depends_on,
    )
    job_id3 = resp["jobId"]

    future = datetime.datetime.now() + datetime.timedelta(seconds=30)

    # Query batch jobs until all jobs have run.
    # Job 2 is supposed to fail and in consequence Job 3 should never run
    # and status should change directly from PENDING to FAILED
    while datetime.datetime.now() < future:
        resp = batch_client.describe_jobs(jobs=[job_id2, job_id3])

        assert resp["jobs"][0]["status"] != "SUCCEEDED", "Job 2 cannot succeed"
        assert resp["jobs"][1]["status"] != "SUCCEEDED", "Job 3 cannot succeed"

        if resp["jobs"][1]["status"] == "FAILED":
            break

        time.sleep(0.5)
    else:
        raise RuntimeError("Batch job timed out")

@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_container_overrides():
    """
    Test if container overrides have any effect.
    Overrides should be reflected in the container description.
    Environment variables should be accessible inside the docker container.
    """

    # Set up environment

    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
    vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)

    compute_name = "test_compute_env"
    resp = batch_client.create_compute_environment(
        computeEnvironmentName=compute_name,
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    arn = resp["computeEnvironmentArn"]

    resp = batch_client.create_job_queue(
        jobQueueName="test_job_queue",
        state="ENABLED",
        priority=123,
        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
    )
    queue_arn = resp["jobQueueArn"]

    job_definition_name = "sleep10"

    # Set up the job definition.
    # We will then override the container properties in the actual job.
    resp = batch_client.register_job_definition(
        jobDefinitionName=job_definition_name,
        type="container",
        containerProperties={
            "image": "busybox",
            "vcpus": 1,
            "memory": 512,
            "command": ["sleep", "10"],
            "environment": [
                {"name": "TEST0", "value": "from job definition"},
                {"name": "TEST1", "value": "from job definition"},
            ],
        },
    )

    job_definition_arn = resp["jobDefinitionArn"]

    # The job to run, including container overrides
    resp = batch_client.submit_job(
        jobName="test1",
        jobQueue=queue_arn,
        jobDefinition=job_definition_name,
        containerOverrides={
            "vcpus": 2,
            "memory": 1024,
            "command": ["printenv"],
            "environment": [
                {"name": "TEST0", "value": "from job"},
                {"name": "TEST2", "value": "from job"},
            ],
        },
    )

    job_id = resp["jobId"]

    # Wait until the job finishes
    future = datetime.datetime.now() + datetime.timedelta(seconds=30)

    while datetime.datetime.now() < future:
        resp_jobs = batch_client.describe_jobs(jobs=[job_id])

        if resp_jobs["jobs"][0]["status"] == "FAILED":
            raise RuntimeError("Batch job failed")
        if resp_jobs["jobs"][0]["status"] == "SUCCEEDED":
            break
        time.sleep(0.5)
    else:
        raise RuntimeError("Batch job timed out")

    # Get the log stream to read out the env variables set inside the container
    resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job")

    env_var = list()
    for stream in resp["logStreams"]:
        ls_name = stream["logStreamName"]

        stream_resp = logs_client.get_log_events(
            logGroupName="/aws/batch/job", logStreamName=ls_name
        )

        for event in stream_resp["events"]:
            if "TEST" in event["message"] or "AWS" in event["message"]:
                key, value = tuple(event["message"].split("="))
                env_var.append({"name": key, "value": value})

    len(resp_jobs["jobs"]).should.equal(1)
    resp_jobs["jobs"][0]["jobId"].should.equal(job_id)
    resp_jobs["jobs"][0]["jobQueue"].should.equal(queue_arn)
    resp_jobs["jobs"][0]["jobDefinition"].should.equal(job_definition_arn)
    resp_jobs["jobs"][0]["container"]["vcpus"].should.equal(2)
    resp_jobs["jobs"][0]["container"]["memory"].should.equal(1024)
    resp_jobs["jobs"][0]["container"]["command"].should.equal(["printenv"])

    sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain(
        {"name": "TEST0", "value": "from job"}
    )
    sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain(
        {"name": "TEST1", "value": "from job definition"}
    )
    sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain(
        {"name": "TEST2", "value": "from job"}
    )
    sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain(
        {"name": "AWS_BATCH_JOB_ID", "value": job_id}
    )

    sure.expect(env_var).to.contain({"name": "TEST0", "value": "from job"})
    sure.expect(env_var).to.contain({"name": "TEST1", "value": "from job definition"})
    sure.expect(env_var).to.contain({"name": "TEST2", "value": "from job"})

    sure.expect(env_var).to.contain({"name": "AWS_BATCH_JOB_ID", "value": job_id})