From f7467164e4a1a2874952a46d56d7673c2dd27208 Mon Sep 17 00:00:00 2001 From: Brian Pandola Date: Wed, 18 Nov 2020 02:49:25 -0800 Subject: [PATCH] Fix Race Condition in batch:SubmitJob (#3480) * Extract Duplicate Code into Helper Method DRY up the tests and replace the arbitrary `sleep()` calls with a more explicit check before progressing. * Improve Testing of batch:TerminateJob The test now confirms that the job was terminated by sandwiching a `sleep` command between two `echo` commands. In addition to the original checks of the terminated job status/reason, the test now asserts that only the first echo command succeeded, confirming that the job was indeed terminated while in progress. * Fix Race Condition in batch:SubmitJob The `test_submit_job` in `test_batch.py` kicks off a job, calls `describe_jobs` in a loop until the job status returned is SUCCEEDED, and then asserts against the logged events. The backend code that runs the submitted job does so in a separate thread. If the job was successful, the job status was being set to SUCCEEDED *before* the event logs had been written to the logging backend. As a result, it was possible for the primary thread running the test to detect that the job was successful immediately after the secondary thread had updated the job status but before the secondary thread had written the logs to the logging backend. Under the right conditions, this could cause the subsequent logging assertions in the primary thread to fail. Additionally, the code that collected the logs from the container was using a "dodgy hack" of time.sleep() and a modulo-based conditional that was ultimately non-deterministic and could result in log messages being dropped or duplicated in certain scenarios. In order to address these issues, this commit does the following: * Carefully re-orders any code that sets a job status or timestamp to avoid any obvious race conditions. * Removes the "dodgy hack" in favor of a much more straightforward (and less error-prone) method of collecting logs from the container. * Removes arbitrary and unnecessary calls to time.sleep() Before applying any changes, the flaky test was failing about 12% of the time. Putting a sleep() call between setting the `job_status` to SUCCEEDED and collecting the logs, resulted in a 100% failure rate. Simply moving the code that sets the job status to SUCCEEDED to the end of the code block, dropped the failure rate to ~2%. Finally, removing the log collection hack allowed the test suite to run ~1000 times without a single failure. Taken in aggregate, these changes make the batch backend more deterministic and should put the nail in the coffin of this flaky test. Closes #3475 --- moto/batch/models.py | 53 +++++--------------------- tests/test_batch/test_batch.py | 68 ++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 75 deletions(-) diff --git a/moto/batch/models.py b/moto/batch/models.py index f729144d8..1338beb0c 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -392,7 +392,6 @@ class Job(threading.Thread, BaseModel, DockerModel): """ try: self.job_state = "PENDING" - time.sleep(1) image = self.job_definition.container_properties.get( "image", "alpine:latest" @@ -425,8 +424,8 @@ class Job(threading.Thread, BaseModel, DockerModel): self.job_state = "RUNNABLE" # TODO setup ecs container instance - time.sleep(1) + self.job_started_at = datetime.datetime.now() self.job_state = "STARTING" log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) container = self.docker_client.containers.run( @@ -440,58 +439,24 @@ class Job(threading.Thread, BaseModel, DockerModel): privileged=privileged, ) self.job_state = "RUNNING" - self.job_started_at = datetime.datetime.now() try: - # Log collection - logs_stdout = [] - logs_stderr = [] container.reload() - - # Dodgy hack, we can only check docker logs once a second, but we want to loop more - # so we can stop if asked to in a quick manner, should all go away if we go async - # There also be some dodgyness when sending an integer to docker logs and some - # events seem to be duplicated. - now = datetime.datetime.now() - i = 1 while container.status == "running" and not self.stop: - time.sleep(0.2) - if i % 5 == 0: - logs_stderr.extend( - container.logs( - stdout=False, - stderr=True, - timestamps=True, - since=datetime2int(now), - ) - .decode() - .split("\n") - ) - logs_stdout.extend( - container.logs( - stdout=True, - stderr=False, - timestamps=True, - since=datetime2int(now), - ) - .decode() - .split("\n") - ) - now = datetime.datetime.now() - container.reload() - i += 1 + container.reload() # Container should be stopped by this point... unless asked to stop if container.status == "running": container.kill() - self.job_stopped_at = datetime.datetime.now() - # Get final logs + # Log collection + logs_stdout = [] + logs_stderr = [] logs_stderr.extend( container.logs( stdout=False, stderr=True, timestamps=True, - since=datetime2int(now), + since=datetime2int(self.job_started_at), ) .decode() .split("\n") @@ -501,14 +466,12 @@ class Job(threading.Thread, BaseModel, DockerModel): stdout=True, stderr=False, timestamps=True, - since=datetime2int(now), + since=datetime2int(self.job_started_at), ) .decode() .split("\n") ) - self.job_state = "SUCCEEDED" if not self.stop else "FAILED" - # Process logs logs_stdout = [x for x in logs_stdout if len(x) > 0] logs_stderr = [x for x in logs_stderr if len(x) > 0] @@ -532,6 +495,8 @@ class Job(threading.Thread, BaseModel, DockerModel): self._log_backend.create_log_stream(log_group, stream_name) self._log_backend.put_log_events(log_group, stream_name, logs, None) + self.job_state = "SUCCEEDED" if not self.stop else "FAILED" + except Exception as err: logger.error( "Failed to run AWS Batch container {0}. Error {1}".format( diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 5a7757777..67f24bebc 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -725,18 +725,7 @@ def test_submit_job(): ) job_id = resp["jobId"] - future = datetime.datetime.now() + datetime.timedelta(seconds=30) - - while datetime.datetime.now() < future: - time.sleep(1) - resp = batch_client.describe_jobs(jobs=[job_id]) - - if resp["jobs"][0]["status"] == "FAILED": - raise RuntimeError("Batch job failed") - if resp["jobs"][0]["status"] == "SUCCEEDED": - break - else: - raise RuntimeError("Batch job timed out") + _wait_for_job_status(batch_client, job_id, "SUCCEEDED") resp = logs_client.describe_log_streams( logGroupName="/aws/batch/job", logStreamNamePrefix="sayhellotomylittlefriend" @@ -798,26 +787,13 @@ def test_list_jobs(): ) job_id2 = resp["jobId"] - future = datetime.datetime.now() + datetime.timedelta(seconds=30) - resp_finished_jobs = batch_client.list_jobs( jobQueue=queue_arn, jobStatus="SUCCEEDED" ) # Wait only as long as it takes to run the jobs - while datetime.datetime.now() < future: - resp = batch_client.describe_jobs(jobs=[job_id1, job_id2]) - - any_failed_jobs = any([job["status"] == "FAILED" for job in resp["jobs"]]) - succeeded_jobs = all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]) - - if any_failed_jobs: - raise RuntimeError("A Batch job failed") - if succeeded_jobs: - break - time.sleep(0.5) - else: - raise RuntimeError("Batch jobs timed out") + for job_id in [job_id1, job_id2]: + _wait_for_job_status(batch_client, job_id, "SUCCEEDED") resp_finished_jobs2 = batch_client.list_jobs( jobQueue=queue_arn, jobStatus="SUCCEEDED" @@ -854,13 +830,13 @@ def test_terminate_job(): queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName="sleep10", + jobDefinitionName="echo-sleep-echo", type="container", containerProperties={ "image": "busybox:latest", "vcpus": 1, "memory": 128, - "command": ["sleep", "10"], + "command": ["sh", "-c", "echo start && sleep 30 && echo stop"], }, ) job_def_arn = resp["jobDefinitionArn"] @@ -870,13 +846,43 @@ def test_terminate_job(): ) job_id = resp["jobId"] - time.sleep(2) + _wait_for_job_status(batch_client, job_id, "RUNNING") batch_client.terminate_job(jobId=job_id, reason="test_terminate") - time.sleep(2) + _wait_for_job_status(batch_client, job_id, "FAILED") resp = batch_client.describe_jobs(jobs=[job_id]) resp["jobs"][0]["jobName"].should.equal("test1") resp["jobs"][0]["status"].should.equal("FAILED") resp["jobs"][0]["statusReason"].should.equal("test_terminate") + + resp = logs_client.describe_log_streams( + logGroupName="/aws/batch/job", logStreamNamePrefix="echo-sleep-echo" + ) + len(resp["logStreams"]).should.equal(1) + ls_name = resp["logStreams"][0]["logStreamName"] + + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + # Events should only contain 'start' because we interrupted + # the job before 'stop' was written to the logs. + resp["events"].should.have.length_of(1) + resp["events"][0]["message"].should.equal("start") + + +def _wait_for_job_status(client, job_id, status, seconds_to_wait=30): + wait_time = datetime.datetime.now() + datetime.timedelta(seconds=seconds_to_wait) + last_job_status = None + while datetime.datetime.now() < wait_time: + resp = client.describe_jobs(jobs=[job_id]) + last_job_status = resp["jobs"][0]["status"] + if last_job_status == status: + break + else: + raise RuntimeError( + "Time out waiting for job status {status}!\n Last status: {last_status}".format( + status=status, last_status=last_job_status + ) + )