Fix Race Condition in batch:SubmitJob (#3480)

* Extract Duplicate Code into Helper Method

DRY up the tests and replace the arbitrary `sleep()` calls with an explicit
check that the job has reached the expected status before proceeding.
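
The explicit check polls `describe_jobs` until the expected status appears;
a minimal standalone version is sketched below (the real `_wait_for_job_status`
helper is shown at the bottom of the test diff, and the half-second pause
between polls here is an illustrative addition, not part of the change):

    import datetime
    import time

    def wait_for_job_status(client, job_id, status, seconds_to_wait=30):
        """Poll describe_jobs until the job reaches `status` or time runs out."""
        deadline = datetime.datetime.now() + datetime.timedelta(seconds=seconds_to_wait)
        last_status = None
        while datetime.datetime.now() < deadline:
            last_status = client.describe_jobs(jobs=[job_id])["jobs"][0]["status"]
            if last_status == status:
                return
            time.sleep(0.5)  # illustrative pause between polls
        raise RuntimeError(
            "Timed out waiting for job status {0} (last seen: {1})".format(
                status, last_status
            )
        )

    # e.g. wait_for_job_status(batch_client, job_id, "SUCCEEDED")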

* Improve Testing of batch:TerminateJob

The test now confirms that the job was terminated while in progress by
sandwiching a `sleep` command between two `echo` commands. In addition to
the original checks of the terminated job's status and reason, the test
asserts that only the first `echo` made it into the logs, confirming that
the job was indeed interrupted mid-run.

* Fix Race Condition in batch:SubmitJob

The `test_submit_job` in `test_batch.py` kicks off a job, calls `describe_jobs`
in a loop until the job status returned is SUCCEEDED, and then asserts against
the logged events.

The backend code that runs the submitted job does so in a separate thread. If
the job was successful, the job status was being set to SUCCEEDED *before* the
event logs had been written to the logging backend.

As a result, it was possible for the primary thread running the test to detect
that the job was successful immediately after the secondary thread had updated
the job status but before the secondary thread had written the logs to the
logging backend.  Under the right conditions, this could cause the subsequent
logging assertions in the primary thread to fail.
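
A self-contained toy reproduction of that ordering (hypothetical `FakeJob`
class, not the moto code) shows the window; the `time.sleep(0.1)` stands in
for the log-writing work that used to happen after the status flip, and the
final assert is the one that would fail intermittently:

    import threading
    import time

    class FakeJob(threading.Thread):
        # Toy stand-in for the backend job thread, with the buggy ordering.
        def __init__(self):
            super().__init__()
            self.status = "RUNNING"
            self.logs = []

        def run(self):
            self.status = "SUCCEEDED"   # status flips first...
            time.sleep(0.1)             # ...leaving a window before...
            self.logs.append("hello")   # ...the "logs" are written

    job = FakeJob()
    job.start()
    while job.status != "SUCCEEDED":    # the test's polling loop
        time.sleep(0.01)
    assert job.logs, "status is SUCCEEDED but the logs are not there yet"
    job.join()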

Additionally, the code that collected the logs from the container was using
a "dodgy hack" of time.sleep() and a modulo-based conditional that was
ultimately non-deterministic and could result in log messages being dropped
or duplicated in certain scenarios.

In order to address these issues, this commit does the following:

* Carefully re-orders any code that sets a job status or timestamp
  to avoid any obvious race conditions.
* Removes the "dodgy hack" in favor of a much more straightforward
  (and less error-prone) method of collecting logs from the container.
* Removes arbitrary and unnecessary calls to time.sleep()
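
As a rough sketch of that simpler log-collection approach (standalone
docker SDK usage with a throwaway busybox container, not the actual backend
code): wait for the container to exit, then read each stream exactly once,
scoped to the recorded start time.

    import datetime

    import docker

    client = docker.from_env()
    started_at = datetime.datetime.now()
    container = client.containers.run(
        "busybox:latest",
        ["sh", "-c", "echo start && echo stop"],
        detach=True,
    )

    # Block until the container exits; no polling loop, no modulo trick.
    container.wait()

    # One call per stream, limited to output since the job started, so
    # nothing is dropped or duplicated by repeated partial reads.
    stdout = container.logs(
        stdout=True, stderr=False, timestamps=True, since=started_at
    ).decode().splitlines()
    stderr = container.logs(
        stdout=False, stderr=True, timestamps=True, since=started_at
    ).decode().splitlines()

    container.remove()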

Before applying any changes, the flaky test was failing about 12% of the
time. Putting a sleep() call between setting the `job_status` to SUCCEEDED
and collecting the logs resulted in a 100% failure rate. Simply moving the
code that sets the job status to SUCCEEDED to the end of the code block
dropped the failure rate to ~2%. Finally, removing the log-collection hack
allowed the test suite to run ~1000 times without a single failure.

Taken in aggregate, these changes make the batch backend more deterministic
and should put the nail in the coffin of this flaky test.

Closes #3475

@@ -392,7 +392,6 @@ class Job(threading.Thread, BaseModel, DockerModel):
"""
try:
self.job_state = "PENDING"
time.sleep(1)
image = self.job_definition.container_properties.get(
"image", "alpine:latest"
@@ -425,8 +424,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
self.job_state = "RUNNABLE"
# TODO setup ecs container instance
time.sleep(1)
self.job_started_at = datetime.datetime.now()
self.job_state = "STARTING"
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
container = self.docker_client.containers.run(
@@ -440,58 +439,24 @@ class Job(threading.Thread, BaseModel, DockerModel):
privileged=privileged,
)
self.job_state = "RUNNING"
self.job_started_at = datetime.datetime.now()
try:
# Log collection
logs_stdout = []
logs_stderr = []
container.reload()
# Dodgy hack, we can only check docker logs once a second, but we want to loop more
# so we can stop if asked to in a quick manner, should all go away if we go async
# There also be some dodgyness when sending an integer to docker logs and some
# events seem to be duplicated.
now = datetime.datetime.now()
i = 1
while container.status == "running" and not self.stop:
time.sleep(0.2)
if i % 5 == 0:
logs_stderr.extend(
container.logs(
stdout=False,
stderr=True,
timestamps=True,
since=datetime2int(now),
)
.decode()
.split("\n")
)
logs_stdout.extend(
container.logs(
stdout=True,
stderr=False,
timestamps=True,
since=datetime2int(now),
)
.decode()
.split("\n")
)
now = datetime.datetime.now()
container.reload()
i += 1
container.reload()
# Container should be stopped by this point... unless asked to stop
if container.status == "running":
container.kill()
self.job_stopped_at = datetime.datetime.now()
# Get final logs
# Log collection
logs_stdout = []
logs_stderr = []
logs_stderr.extend(
container.logs(
stdout=False,
stderr=True,
timestamps=True,
since=datetime2int(now),
since=datetime2int(self.job_started_at),
)
.decode()
.split("\n")
@@ -501,14 +466,12 @@ class Job(threading.Thread, BaseModel, DockerModel):
stdout=True,
stderr=False,
timestamps=True,
since=datetime2int(now),
since=datetime2int(self.job_started_at),
)
.decode()
.split("\n")
)
self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
# Process logs
logs_stdout = [x for x in logs_stdout if len(x) > 0]
logs_stderr = [x for x in logs_stderr if len(x) > 0]
@@ -532,6 +495,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
self._log_backend.create_log_stream(log_group, stream_name)
self._log_backend.put_log_events(log_group, stream_name, logs, None)
self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
except Exception as err:
logger.error(
"Failed to run AWS Batch container {0}. Error {1}".format(

@@ -725,18 +725,7 @@ def test_submit_job():
)
job_id = resp["jobId"]
future = datetime.datetime.now() + datetime.timedelta(seconds=30)
while datetime.datetime.now() < future:
time.sleep(1)
resp = batch_client.describe_jobs(jobs=[job_id])
if resp["jobs"][0]["status"] == "FAILED":
raise RuntimeError("Batch job failed")
if resp["jobs"][0]["status"] == "SUCCEEDED":
break
else:
raise RuntimeError("Batch job timed out")
_wait_for_job_status(batch_client, job_id, "SUCCEEDED")
resp = logs_client.describe_log_streams(
logGroupName="/aws/batch/job", logStreamNamePrefix="sayhellotomylittlefriend"
@@ -798,26 +787,13 @@ def test_list_jobs():
)
job_id2 = resp["jobId"]
future = datetime.datetime.now() + datetime.timedelta(seconds=30)
resp_finished_jobs = batch_client.list_jobs(
jobQueue=queue_arn, jobStatus="SUCCEEDED"
)
# Wait only as long as it takes to run the jobs
while datetime.datetime.now() < future:
resp = batch_client.describe_jobs(jobs=[job_id1, job_id2])
any_failed_jobs = any([job["status"] == "FAILED" for job in resp["jobs"]])
succeeded_jobs = all([job["status"] == "SUCCEEDED" for job in resp["jobs"]])
if any_failed_jobs:
raise RuntimeError("A Batch job failed")
if succeeded_jobs:
break
time.sleep(0.5)
else:
raise RuntimeError("Batch jobs timed out")
for job_id in [job_id1, job_id2]:
_wait_for_job_status(batch_client, job_id, "SUCCEEDED")
resp_finished_jobs2 = batch_client.list_jobs(
jobQueue=queue_arn, jobStatus="SUCCEEDED"
@@ -854,13 +830,13 @@ def test_terminate_job():
queue_arn = resp["jobQueueArn"]
resp = batch_client.register_job_definition(
jobDefinitionName="sleep10",
jobDefinitionName="echo-sleep-echo",
type="container",
containerProperties={
"image": "busybox:latest",
"vcpus": 1,
"memory": 128,
"command": ["sleep", "10"],
"command": ["sh", "-c", "echo start && sleep 30 && echo stop"],
},
)
job_def_arn = resp["jobDefinitionArn"]
@@ -870,13 +846,43 @@ def test_terminate_job():
)
job_id = resp["jobId"]
time.sleep(2)
_wait_for_job_status(batch_client, job_id, "RUNNING")
batch_client.terminate_job(jobId=job_id, reason="test_terminate")
time.sleep(2)
_wait_for_job_status(batch_client, job_id, "FAILED")
resp = batch_client.describe_jobs(jobs=[job_id])
resp["jobs"][0]["jobName"].should.equal("test1")
resp["jobs"][0]["status"].should.equal("FAILED")
resp["jobs"][0]["statusReason"].should.equal("test_terminate")
resp = logs_client.describe_log_streams(
logGroupName="/aws/batch/job", logStreamNamePrefix="echo-sleep-echo"
)
len(resp["logStreams"]).should.equal(1)
ls_name = resp["logStreams"][0]["logStreamName"]
resp = logs_client.get_log_events(
logGroupName="/aws/batch/job", logStreamName=ls_name
)
# Events should only contain 'start' because we interrupted
# the job before 'stop' was written to the logs.
resp["events"].should.have.length_of(1)
resp["events"][0]["message"].should.equal("start")
def _wait_for_job_status(client, job_id, status, seconds_to_wait=30):
wait_time = datetime.datetime.now() + datetime.timedelta(seconds=seconds_to_wait)
last_job_status = None
while datetime.datetime.now() < wait_time:
resp = client.describe_jobs(jobs=[job_id])
last_job_status = resp["jobs"][0]["status"]
if last_job_status == status:
break
else:
raise RuntimeError(
"Time out waiting for job status {status}!\n Last status: {last_status}".format(
status=status, last_status=last_job_status
)
)