Batch: arraySize and child jobs (#6541)

This commit is contained in:
rafcio19 2023-10-12 16:06:57 +02:00 committed by GitHub
parent 7e891f7880
commit e5944307fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 19 deletions

View File

@ -1,13 +1,14 @@
import datetime
import dateutil.parser
import logging
import re
import threading
import time
from sys import platform
from itertools import cycle
from time import sleep
from typing import Any, Dict, List, Tuple, Optional, Set
import datetime
import time
import logging
import threading
import dateutil.parser
from sys import platform
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.iam.models import iam_backends, IAMBackend
@ -482,6 +483,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
depends_on: Optional[List[Dict[str, str]]],
all_jobs: Dict[str, "Job"],
timeout: Optional[Dict[str, int]],
array_properties: Dict[str, Any],
provided_job_id: Optional[str] = None,
):
threading.Thread.__init__(self)
DockerModel.__init__(self)
@ -492,7 +495,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
)
self.job_name = name
self.job_id = str(mock_random.uuid4())
self.job_id = provided_job_id or str(mock_random.uuid4())
self.job_definition = job_def
self.container_overrides: Dict[str, Any] = container_overrides or {}
self.job_queue = job_queue
@ -505,6 +508,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.depends_on = depends_on
self.timeout = timeout
self.all_jobs = all_jobs
self.array_properties: Dict[str, Any] = array_properties
self.arn = make_arn_for_job(
job_def.backend.account_id, self.job_id, job_def._region
@ -514,6 +518,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.exit_code: Optional[int] = None
self.daemon = True
self.name = "MOTO-BATCH-" + self.job_id
self._log_backend = log_backend
@ -523,6 +528,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.attempts: List[Dict[str, Any]] = []
self.latest_attempt: Optional[Dict[str, Any]] = None
self._child_jobs: Optional[List[Job]] = None
def describe_short(self) -> Dict[str, Any]:
result = {
@ -560,6 +566,26 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
if self.timeout:
result["timeout"] = self.timeout
result["attempts"] = self.attempts
if self._child_jobs:
child_statuses = {
"STARTING": 0,
"FAILED": 0,
"RUNNING": 0,
"SUCCEEDED": 0,
"RUNNABLE": 0,
"SUBMITTED": 0,
"PENDING": 0,
}
for child_job in self._child_jobs:
if child_job.status is not None:
child_statuses[child_job.status] += 1
result["arrayProperties"] = {
"statusSummary": child_statuses,
"size": len(self._child_jobs),
}
if len(self._child_jobs) == child_statuses["SUCCEEDED"]:
self.status = "SUCCEEDED"
result["status"] = self.status
return result
def _container_details(self) -> Dict[str, Any]:
@ -675,7 +701,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
)
for m in self._get_container_property("mountPoints", [])
],
"name": f"{self.job_name}-{self.job_id}",
"name": f"{self.job_name}-{self.job_id.replace(':', '-')}",
}
)
else:
@ -1704,6 +1730,7 @@ class BatchBackend(BaseBackend):
job_name: str,
job_def_id: str,
job_queue: str,
array_properties: Dict[str, int],
depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None,
@ -1732,12 +1759,36 @@ class BatchBackend(BaseBackend):
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties=array_properties or {},
)
self._jobs[job.job_id] = job
# Here comes the fun
job.start()
if "size" in array_properties:
child_jobs = []
for array_index in range(array_properties["size"]):
provided_job_id = f"{job.job_id}:{array_index}"
child_job = Job(
job_name,
job_def,
queue,
log_backend=self.logs_backend,
container_overrides=container_overrides,
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties={"statusSummary": {}, "index": array_index},
provided_job_id=provided_job_id,
)
child_jobs.append(child_job)
self._jobs[child_job.job_id] = child_job
child_job.start()
# The 'parent' job doesn't need to be executed
# it just needs to keep track of its children
job._child_jobs = child_jobs
else:
# Here comes the fun
job.start()
return job_name, job.job_id
def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]:

View File

@ -210,6 +210,7 @@ class BatchResponse(BaseResponse):
job_name = self._get_param("jobName")
job_queue = self._get_param("jobQueue")
timeout = self._get_param("timeout")
array_properties = self._get_param("arrayProperties", {})
name, job_id = self.batch_backend.submit_job(
job_name,
@ -218,6 +219,7 @@ class BatchResponse(BaseResponse):
depends_on=depends_on,
container_overrides=container_overrides,
timeout=timeout,
array_properties=array_properties,
)
result = {"jobId": job_id, "jobName": name}

View File

@ -1,9 +1,9 @@
from ..batch.models import (
batch_backends,
BaseBackend,
Job,
ClientException,
BatchBackend,
ClientException,
Job,
)
from ..core import BackendDict
@ -42,7 +42,7 @@ class BatchSimpleBackend(BaseBackend):
"url_bases",
]:
return object.__getattribute__(self, name)
if name in ["submit_job"]:
if name in ["submit_job", "_mark_job_as_finished"]:
def newfunc(*args: Any, **kwargs: Any) -> Any:
attr = object.__getattribute__(self, name)
@ -57,6 +57,7 @@ class BatchSimpleBackend(BaseBackend):
job_name: str,
job_def_id: str,
job_queue: str,
array_properties: Dict[str, Any],
depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None,
@ -79,13 +80,40 @@ class BatchSimpleBackend(BaseBackend):
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties=array_properties,
)
self.backend._jobs[job.job_id] = job
if "size" in array_properties:
child_jobs: List[Job] = []
for array_index in range(array_properties.get("size", 0)):
provided_job_id = f"{job.job_id}:{array_index}"
child_job = Job(
job_name,
job_def,
queue,
log_backend=self.logs_backend,
container_overrides=container_overrides,
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties={"statusSummary": {}, "index": array_index},
provided_job_id=provided_job_id,
)
self._mark_job_as_finished(include_start_attempt=True, job=child_job)
child_jobs.append(child_job)
self._mark_job_as_finished(include_start_attempt=False, job=job)
job._child_jobs = child_jobs
else:
self._mark_job_as_finished(include_start_attempt=True, job=job)
return job_name, job.job_id
def _mark_job_as_finished(self, include_start_attempt: bool, job: Job) -> None:
self.backend._jobs[job.job_id] = job
job.job_started_at = datetime.datetime.now()
job.log_stream_name = job._stream_name
job._start_attempt()
if include_start_attempt:
job._start_attempt()
# We don't want to actually run the job - just mark it as succeeded or failed
# depending on whether env var MOTO_SIMPLE_BATCH_FAIL_AFTER is set
# if MOTO_SIMPLE_BATCH_FAIL_AFTER is set to an integer then batch will
@ -104,7 +132,5 @@ class BatchSimpleBackend(BaseBackend):
else:
job._mark_stopped(success=True)
return job_name, job.job_id
batch_simple_backends = BackendDict(BatchSimpleBackend, "batch")

View File

@ -86,6 +86,50 @@ def test_submit_job_by_name():
assert resp_jobs["jobs"][0]["jobDefinition"] == job_definition_arn
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_submit_job_array_size():
    """An array job spawns one child per index; the parent job is never
    executed itself and only aggregates the status of its children."""
    # Setup
    job_definition_name = f"sleep10_{str(uuid4())[0:6]}"
    ec2_client, iam_client, _, _, batch_client = _get_clients()
    job_commands = ["echo", "hello"]
    _, _, _, iam_arn = _setup(ec2_client, iam_client)
    _, queue_arn = prepare_job(batch_client, job_commands, iam_arn, job_definition_name)

    # Execute
    submit_resp = batch_client.submit_job(
        jobName="test1",
        jobQueue=queue_arn,
        jobDefinition=job_definition_name,
        arrayProperties={"size": 2},
    )

    # Verify
    parent_job_id = submit_resp["jobId"]
    first_child_id = f"{parent_job_id}:0"

    parent_job = batch_client.describe_jobs(jobs=[parent_job_id])["jobs"][0]
    assert parent_job["arrayProperties"]["size"] == 2
    assert parent_job["attempts"] == []

    _wait_for_job_status(batch_client, parent_job_id, "SUCCEEDED")

    parent_job = batch_client.describe_jobs(jobs=[parent_job_id])["jobs"][0]
    # A SUCCEEDED parent implies every child job succeeded
    assert parent_job["arrayProperties"]["size"] == 2
    assert parent_job["arrayProperties"]["statusSummary"]["SUCCEEDED"] == 2
    # The parent never runs, so it still records no attempts
    assert parent_job["attempts"] == []

    first_child = batch_client.describe_jobs(jobs=[first_child_id])["jobs"][0]
    assert first_child["status"] == "SUCCEEDED"
    # The child job actually executed, so it has exactly one attempt
    assert len(first_child["attempts"]) == 1
# SLOW TESTS

View File

@ -23,7 +23,9 @@ def test_submit_job_by_name():
)
resp = batch_client.submit_job(
jobName="test1", jobQueue=queue_arn, jobDefinition=job_definition_name
jobName="test1",
jobQueue=queue_arn,
jobDefinition=job_definition_name,
)
job_id = resp["jobId"]
@ -41,6 +43,44 @@ def test_submit_job_by_name():
assert "logStreamName" in job["container"]
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch_simple
def test_submit_job_array_size():
    """Array jobs in batch_simple finish synchronously: the parent reports
    an aggregated statusSummary immediately, while only the child jobs
    record execution attempts.

    Note: the duplicated size/attempts assertions from the non-simple
    variant of this test were removed — with no wait between them they
    asserted the same state twice.
    """
    # Setup
    job_definition_name = f"sleep10_{str(uuid4())[0:6]}"
    batch_client, _, queue_arn = setup_common_batch_simple(job_definition_name)

    # Execute
    resp = batch_client.submit_job(
        jobName="test1",
        jobQueue=queue_arn,
        jobDefinition=job_definition_name,
        arrayProperties={"size": 2},
    )

    # Verify
    job_id = resp["jobId"]
    child_job_1_id = f"{job_id}:0"

    job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
    # Children are already finished, so the parent aggregates them at once
    assert job["arrayProperties"]["size"] == 2
    assert job["arrayProperties"]["statusSummary"]["SUCCEEDED"] == 2
    # Main job has no attempts - because only the child jobs are executed
    assert job["attempts"] == []

    child_job_1 = batch_client.describe_jobs(jobs=[child_job_1_id])["jobs"][0]
    assert child_job_1["status"] == "SUCCEEDED"
    # Child job was executed
    assert len(child_job_1["attempts"]) == 1
@mock_batch_simple
def test_update_job_definition():
_, _, _, _, batch_client = _get_clients()