diff --git a/moto/batch/models.py b/moto/batch/models.py
index 422608543..c7fea1fda 100644
--- a/moto/batch/models.py
+++ b/moto/batch/models.py
@@ -310,7 +310,16 @@ class JobDefinition(CloudFormationModel):
 
 
 class Job(threading.Thread, BaseModel, DockerModel):
-    def __init__(self, name, job_def, job_queue, log_backend, container_overrides):
+    def __init__(
+        self,
+        name,
+        job_def,
+        job_queue,
+        log_backend,
+        container_overrides,
+        depends_on,
+        all_jobs,
+    ):
         """
         Docker Job
 
@@ -335,6 +344,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
         self.job_stopped_at = datetime.datetime(1970, 1, 1)
         self.job_stopped = False
         self.job_stopped_reason = None
+        self.depends_on = depends_on
+        self.all_jobs = all_jobs
 
         self.stop = False
 
@@ -351,27 +362,48 @@ class Job(threading.Thread, BaseModel, DockerModel):
             "jobName": self.job_name,
             "jobQueue": self.job_queue.arn,
             "status": self.job_state,
-            "dependsOn": [],
+            "dependsOn": self.depends_on if self.depends_on else [],
         }
         if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]:
             result["startedAt"] = datetime2int(self.job_started_at)
         if self.job_stopped:
             result["stoppedAt"] = datetime2int(self.job_stopped_at)
             result["container"] = {}
-            result["container"]["command"] = [
-                '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'
-            ]
-            result["container"]["privileged"] = False
-            result["container"]["readonlyRootFilesystem"] = False
-            result["container"]["ulimits"] = {}
-            result["container"]["vcpus"] = 1
-            result["container"]["volumes"] = ""
+            result["container"]["command"] = self._get_container_property("command", [])
+            result["container"]["privileged"] = self._get_container_property(
+                "privileged", False
+            )
+            result["container"][
+                "readonlyRootFilesystem"
+            ] = self._get_container_property("readonlyRootFilesystem", False)
+            result["container"]["ulimits"] = self._get_container_property("ulimits", {})
+            result["container"]["vcpus"] = self._get_container_property("vcpus", 1)
+            result["container"]["memory"] = self._get_container_property("memory", 512)
+            result["container"]["volumes"] = self._get_container_property("volumes", [])
+            result["container"]["environment"] = self._get_container_property(
+                "environment", []
+            )
             result["container"]["logStreamName"] = self.log_stream_name
         if self.job_stopped_reason is not None:
             result["statusReason"] = self.job_stopped_reason
         return result
 
     def _get_container_property(self, p, default):
+        if p == "environment":
+            job_env = list(self.container_overrides.get(p, default))
+            jd_env = self.job_definition.container_properties.get(p, default)
+
+            job_env_dict = {_env["name"]: _env["value"] for _env in job_env}
+            jd_env_dict = {_env["name"]: _env["value"] for _env in jd_env}
+
+            for key in jd_env_dict.keys():
+                if key not in job_env_dict.keys():
+                    job_env.append({"name": key, "value": jd_env_dict[key]})
+
+            job_env.append({"name": "AWS_BATCH_JOB_ID", "value": self.job_id})
+
+            return job_env
+
         return self.container_overrides.get(
             p, self.job_definition.container_properties.get(p, default)
         )
@@ -393,6 +425,9 @@ class Job(threading.Thread, BaseModel, DockerModel):
         try:
             self.job_state = "PENDING"
 
+            if self.depends_on and not self._wait_for_dependencies():
+                return
+
             image = self.job_definition.container_properties.get(
                 "image", "alpine:latest"
             )
@@ -497,7 +532,13 @@ class Job(threading.Thread, BaseModel, DockerModel):
             self._log_backend.create_log_stream(log_group, stream_name)
             self._log_backend.put_log_events(log_group, stream_name, logs, None)
 
-            self.job_state = "SUCCEEDED" if not self.stop else "FAILED"
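+            # Derive the final job state from the container's exit code instead
+            # of assuming success whenever the job was not manually stopped.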
else "FAILED" + result = container.wait() + if self.stop or result["StatusCode"] != 0: + self.job_state = "FAILED" + else: + self.job_state = "SUCCEEDED" except Exception as err: logger.error( @@ -525,6 +564,30 @@ class Job(threading.Thread, BaseModel, DockerModel): self.stop = True self.job_stopped_reason = reason + def _wait_for_dependencies(self): + dependent_ids = [dependency["jobId"] for dependency in self.depends_on] + successful_dependencies = set() + while len(successful_dependencies) != len(dependent_ids): + for dependent_id in dependent_ids: + if dependent_id in self.all_jobs: + dependent_job = self.all_jobs[dependent_id] + if dependent_job.job_state == "SUCCEEDED": + successful_dependencies.add(dependent_id) + if dependent_job.job_state == "FAILED": + logger.error( + "Terminating job {0} due to failed dependency {1}".format( + self.name, dependent_job.name + ) + ) + self.job_state = "FAILED" + self.job_stopped = True + self.job_stopped_at = datetime.datetime.now() + return False + + time.sleep(1) + + return True + class BatchBackend(BaseBackend): def __init__(self, region_name=None): @@ -1241,6 +1304,8 @@ class BatchBackend(BaseBackend): queue, log_backend=self.logs_backend, container_overrides=container_overrides, + depends_on=depends_on, + all_jobs=self._jobs, ) self._jobs[job.job_id] = job diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 67f24bebc..04554383d 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -886,3 +886,367 @@ def _wait_for_job_status(client, job_id, status, seconds_to_wait=30): status=status, last_status=last_job_status ) ) + + +@mock_logs +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_failed_job(): + ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = "test_compute_env" + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + resp = batch_client.create_job_queue( + jobQueueName="test_job_queue", + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + + resp = batch_client.register_job_definition( + jobDefinitionName="sayhellotomylittlefriend", + type="container", + containerProperties={ + "image": "busybox:latest", + "vcpus": 1, + "memory": 128, + "command": ["exit", "1"], + }, + ) + job_def_arn = resp["jobDefinitionArn"] + + resp = batch_client.submit_job( + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn + ) + job_id = resp["jobId"] + + future = datetime.datetime.now() + datetime.timedelta(seconds=30) + + while datetime.datetime.now() < future: + resp = batch_client.describe_jobs(jobs=[job_id]) + + if resp["jobs"][0]["status"] == "FAILED": + break + if resp["jobs"][0]["status"] == "SUCCEEDED": + raise RuntimeError("Batch job succeeded even though it had exit code 1") + time.sleep(0.5) + else: + raise RuntimeError("Batch job timed out") + + +@mock_logs +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_dependencies(): + ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = "test_compute_env" + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + 
state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + resp = batch_client.create_job_queue( + jobQueueName="test_job_queue", + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + + resp = batch_client.register_job_definition( + jobDefinitionName="sayhellotomylittlefriend", + type="container", + containerProperties={ + "image": "busybox:latest", + "vcpus": 1, + "memory": 128, + "command": ["echo", "hello"], + }, + ) + job_def_arn = resp["jobDefinitionArn"] + + resp = batch_client.submit_job( + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn + ) + job_id1 = resp["jobId"] + + resp = batch_client.submit_job( + jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn + ) + job_id2 = resp["jobId"] + + depends_on = [ + {"jobId": job_id1, "type": "SEQUENTIAL"}, + {"jobId": job_id2, "type": "SEQUENTIAL"}, + ] + resp = batch_client.submit_job( + jobName="test3", + jobQueue=queue_arn, + jobDefinition=job_def_arn, + dependsOn=depends_on, + ) + job_id3 = resp["jobId"] + + future = datetime.datetime.now() + datetime.timedelta(seconds=30) + + while datetime.datetime.now() < future: + resp = batch_client.describe_jobs(jobs=[job_id1, job_id2, job_id3]) + + if any([job["status"] == "FAILED" for job in resp["jobs"]]): + raise RuntimeError("Batch job failed") + if all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]): + break + time.sleep(0.5) + else: + raise RuntimeError("Batch job timed out") + + resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job") + len(resp["logStreams"]).should.equal(3) + for log_stream in resp["logStreams"]: + ls_name = log_stream["logStreamName"] + + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + [event["message"] for event in resp["events"]].should.equal(["hello"]) + + +@mock_logs +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_failed_dependencies(): + ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = "test_compute_env" + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + resp = batch_client.create_job_queue( + jobQueueName="test_job_queue", + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + + resp = batch_client.register_job_definition( + jobDefinitionName="sayhellotomylittlefriend", + type="container", + containerProperties={ + "image": "busybox:latest", + "vcpus": 1, + "memory": 128, + "command": ["echo", "hello"], + }, + ) + job_def_arn_success = resp["jobDefinitionArn"] + + resp = batch_client.register_job_definition( + jobDefinitionName="sayhellotomylittlefriend_failed", + type="container", + containerProperties={ + "image": "busybox:latest", + "vcpus": 1, + "memory": 128, + "command": ["exi1", "1"], + }, + ) + job_def_arn_failure = resp["jobDefinitionArn"] + + resp = batch_client.submit_job( + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn_success + ) + + job_id1 = resp["jobId"] + + resp = batch_client.submit_job( + jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn_failure + ) + job_id2 = resp["jobId"] + + depends_on = [ + {"jobId": job_id1, "type": "SEQUENTIAL"}, + {"jobId": 
job_id2, "type": "SEQUENTIAL"}, + ] + resp = batch_client.submit_job( + jobName="test3", + jobQueue=queue_arn, + jobDefinition=job_def_arn_success, + dependsOn=depends_on, + ) + job_id3 = resp["jobId"] + + future = datetime.datetime.now() + datetime.timedelta(seconds=30) + + # Query batch jobs until all jobs have run. + # Job 2 is supposed to fail and in consequence Job 3 should never run + # and status should change directly from PENDING to FAILED + while datetime.datetime.now() < future: + resp = batch_client.describe_jobs(jobs=[job_id2, job_id3]) + + assert resp["jobs"][0]["status"] != "SUCCEEDED", "Job 2 cannot succeed" + assert resp["jobs"][1]["status"] != "SUCCEEDED", "Job 3 cannot succeed" + + if resp["jobs"][1]["status"] == "FAILED": + break + + time.sleep(0.5) + else: + raise RuntimeError("Batch job timed out") + + +@mock_logs +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_container_overrides(): + """ + Test if container overrides have any effect. + Overwrites should be reflected in container description. + Environment variables should be accessible inside docker container + """ + + # Set up environment + + ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = "test_compute_env" + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + resp = batch_client.create_job_queue( + jobQueueName="test_job_queue", + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + + job_definition_name = "sleep10" + + # Set up Job Definition + # We will then override the container properties in the actual job + resp = batch_client.register_job_definition( + jobDefinitionName=job_definition_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 512, + "command": ["sleep", "10"], + "environment": [ + {"name": "TEST0", "value": "from job definition"}, + {"name": "TEST1", "value": "from job definition"}, + ], + }, + ) + + job_definition_arn = resp["jobDefinitionArn"] + + # The Job to run, including container overrides + resp = batch_client.submit_job( + jobName="test1", + jobQueue=queue_arn, + jobDefinition=job_definition_name, + containerOverrides={ + "vcpus": 2, + "memory": 1024, + "command": ["printenv"], + "environment": [ + {"name": "TEST0", "value": "from job"}, + {"name": "TEST2", "value": "from job"}, + ], + }, + ) + + job_id = resp["jobId"] + + # Wait until Job finishes + future = datetime.datetime.now() + datetime.timedelta(seconds=30) + + while datetime.datetime.now() < future: + resp_jobs = batch_client.describe_jobs(jobs=[job_id]) + + if resp_jobs["jobs"][0]["status"] == "FAILED": + raise RuntimeError("Batch job failed") + if resp_jobs["jobs"][0]["status"] == "SUCCEEDED": + break + time.sleep(0.5) + else: + raise RuntimeError("Batch job timed out") + + # Getting the log stream to read out env variables inside container + resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job") + + env_var = list() + for stream in resp["logStreams"]: + ls_name = stream["logStreamName"] + + stream_resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + + for event in stream_resp["events"]: + if "TEST" in event["message"] or "AWS" in event["message"]: + key, value = 
tuple(event["message"].split("=")) + env_var.append({"name": key, "value": value}) + + len(resp_jobs["jobs"]).should.equal(1) + resp_jobs["jobs"][0]["jobId"].should.equal(job_id) + resp_jobs["jobs"][0]["jobQueue"].should.equal(queue_arn) + resp_jobs["jobs"][0]["jobDefinition"].should.equal(job_definition_arn) + resp_jobs["jobs"][0]["container"]["vcpus"].should.equal(2) + resp_jobs["jobs"][0]["container"]["memory"].should.equal(1024) + resp_jobs["jobs"][0]["container"]["command"].should.equal(["printenv"]) + + sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain( + {"name": "TEST0", "value": "from job"} + ) + sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain( + {"name": "TEST1", "value": "from job definition"} + ) + sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain( + {"name": "TEST2", "value": "from job"} + ) + sure.expect(resp_jobs["jobs"][0]["container"]["environment"]).to.contain( + {"name": "AWS_BATCH_JOB_ID", "value": job_id} + ) + + sure.expect(env_var).to.contain({"name": "TEST0", "value": "from job"}) + sure.expect(env_var).to.contain({"name": "TEST1", "value": "from job definition"}) + sure.expect(env_var).to.contain({"name": "TEST2", "value": "from job"}) + + sure.expect(env_var).to.contain({"name": "AWS_BATCH_JOB_ID", "value": job_id})