diff --git a/moto/batch/models.py b/moto/batch/models.py index f729144d8..1338beb0c 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -392,7 +392,6 @@ class Job(threading.Thread, BaseModel, DockerModel): """ try: self.job_state = "PENDING" - time.sleep(1) image = self.job_definition.container_properties.get( "image", "alpine:latest" @@ -425,8 +424,8 @@ class Job(threading.Thread, BaseModel, DockerModel): self.job_state = "RUNNABLE" # TODO setup ecs container instance - time.sleep(1) + self.job_started_at = datetime.datetime.now() self.job_state = "STARTING" log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) container = self.docker_client.containers.run( @@ -440,58 +439,24 @@ class Job(threading.Thread, BaseModel, DockerModel): privileged=privileged, ) self.job_state = "RUNNING" - self.job_started_at = datetime.datetime.now() try: - # Log collection - logs_stdout = [] - logs_stderr = [] container.reload() - - # Dodgy hack, we can only check docker logs once a second, but we want to loop more - # so we can stop if asked to in a quick manner, should all go away if we go async - # There also be some dodgyness when sending an integer to docker logs and some - # events seem to be duplicated. - now = datetime.datetime.now() - i = 1 while container.status == "running" and not self.stop: - time.sleep(0.2) - if i % 5 == 0: - logs_stderr.extend( - container.logs( - stdout=False, - stderr=True, - timestamps=True, - since=datetime2int(now), - ) - .decode() - .split("\n") - ) - logs_stdout.extend( - container.logs( - stdout=True, - stderr=False, - timestamps=True, - since=datetime2int(now), - ) - .decode() - .split("\n") - ) - now = datetime.datetime.now() - container.reload() - i += 1 + container.reload() # Container should be stopped by this point... unless asked to stop if container.status == "running": container.kill() - self.job_stopped_at = datetime.datetime.now() - # Get final logs + # Log collection + logs_stdout = [] + logs_stderr = [] logs_stderr.extend( container.logs( stdout=False, stderr=True, timestamps=True, - since=datetime2int(now), + since=datetime2int(self.job_started_at), ) .decode() .split("\n") @@ -501,14 +466,12 @@ class Job(threading.Thread, BaseModel, DockerModel): stdout=True, stderr=False, timestamps=True, - since=datetime2int(now), + since=datetime2int(self.job_started_at), ) .decode() .split("\n") ) - self.job_state = "SUCCEEDED" if not self.stop else "FAILED" - # Process logs logs_stdout = [x for x in logs_stdout if len(x) > 0] logs_stderr = [x for x in logs_stderr if len(x) > 0] @@ -532,6 +495,8 @@ class Job(threading.Thread, BaseModel, DockerModel): self._log_backend.create_log_stream(log_group, stream_name) self._log_backend.put_log_events(log_group, stream_name, logs, None) + self.job_state = "SUCCEEDED" if not self.stop else "FAILED" + except Exception as err: logger.error( "Failed to run AWS Batch container {0}. Error {1}".format( diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 5a7757777..67f24bebc 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -725,18 +725,7 @@ def test_submit_job(): ) job_id = resp["jobId"] - future = datetime.datetime.now() + datetime.timedelta(seconds=30) - - while datetime.datetime.now() < future: - time.sleep(1) - resp = batch_client.describe_jobs(jobs=[job_id]) - - if resp["jobs"][0]["status"] == "FAILED": - raise RuntimeError("Batch job failed") - if resp["jobs"][0]["status"] == "SUCCEEDED": - break - else: - raise RuntimeError("Batch job timed out") + _wait_for_job_status(batch_client, job_id, "SUCCEEDED") resp = logs_client.describe_log_streams( logGroupName="/aws/batch/job", logStreamNamePrefix="sayhellotomylittlefriend" @@ -798,26 +787,13 @@ def test_list_jobs(): ) job_id2 = resp["jobId"] - future = datetime.datetime.now() + datetime.timedelta(seconds=30) - resp_finished_jobs = batch_client.list_jobs( jobQueue=queue_arn, jobStatus="SUCCEEDED" ) # Wait only as long as it takes to run the jobs - while datetime.datetime.now() < future: - resp = batch_client.describe_jobs(jobs=[job_id1, job_id2]) - - any_failed_jobs = any([job["status"] == "FAILED" for job in resp["jobs"]]) - succeeded_jobs = all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]) - - if any_failed_jobs: - raise RuntimeError("A Batch job failed") - if succeeded_jobs: - break - time.sleep(0.5) - else: - raise RuntimeError("Batch jobs timed out") + for job_id in [job_id1, job_id2]: + _wait_for_job_status(batch_client, job_id, "SUCCEEDED") resp_finished_jobs2 = batch_client.list_jobs( jobQueue=queue_arn, jobStatus="SUCCEEDED" @@ -854,13 +830,13 @@ def test_terminate_job(): queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName="sleep10", + jobDefinitionName="echo-sleep-echo", type="container", containerProperties={ "image": "busybox:latest", "vcpus": 1, "memory": 128, - "command": ["sleep", "10"], + "command": ["sh", "-c", "echo start && sleep 30 && echo stop"], }, ) job_def_arn = resp["jobDefinitionArn"] @@ -870,13 +846,43 @@ def test_terminate_job(): ) job_id = resp["jobId"] - time.sleep(2) + _wait_for_job_status(batch_client, job_id, "RUNNING") batch_client.terminate_job(jobId=job_id, reason="test_terminate") - time.sleep(2) + _wait_for_job_status(batch_client, job_id, "FAILED") resp = batch_client.describe_jobs(jobs=[job_id]) resp["jobs"][0]["jobName"].should.equal("test1") resp["jobs"][0]["status"].should.equal("FAILED") resp["jobs"][0]["statusReason"].should.equal("test_terminate") + + resp = logs_client.describe_log_streams( + logGroupName="/aws/batch/job", logStreamNamePrefix="echo-sleep-echo" + ) + len(resp["logStreams"]).should.equal(1) + ls_name = resp["logStreams"][0]["logStreamName"] + + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + # Events should only contain 'start' because we interrupted + # the job before 'stop' was written to the logs. + resp["events"].should.have.length_of(1) + resp["events"][0]["message"].should.equal("start") + + +def _wait_for_job_status(client, job_id, status, seconds_to_wait=30): + wait_time = datetime.datetime.now() + datetime.timedelta(seconds=seconds_to_wait) + last_job_status = None + while datetime.datetime.now() < wait_time: + resp = client.describe_jobs(jobs=[job_id]) + last_job_status = resp["jobs"][0]["status"] + if last_job_status == status: + break + else: + raise RuntimeError( + "Time out waiting for job status {status}!\n Last status: {last_status}".format( + status=status, last_status=last_job_status + ) + )