Batch: arraySize and child jobs (#6541)

This commit is contained in:
rafcio19 2023-10-12 16:06:57 +02:00 committed by GitHub
parent 7e891f7880
commit e5944307fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 19 deletions

View File

@ -1,13 +1,14 @@
import datetime
import dateutil.parser
import logging
import re import re
import threading
import time
from sys import platform
from itertools import cycle from itertools import cycle
from time import sleep from time import sleep
from typing import Any, Dict, List, Tuple, Optional, Set from typing import Any, Dict, List, Tuple, Optional, Set
import datetime
import time
import logging
import threading
import dateutil.parser
from sys import platform
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.iam.models import iam_backends, IAMBackend from moto.iam.models import iam_backends, IAMBackend
@ -482,6 +483,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
depends_on: Optional[List[Dict[str, str]]], depends_on: Optional[List[Dict[str, str]]],
all_jobs: Dict[str, "Job"], all_jobs: Dict[str, "Job"],
timeout: Optional[Dict[str, int]], timeout: Optional[Dict[str, int]],
array_properties: Dict[str, Any],
provided_job_id: Optional[str] = None,
): ):
threading.Thread.__init__(self) threading.Thread.__init__(self)
DockerModel.__init__(self) DockerModel.__init__(self)
@ -492,7 +495,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
) )
self.job_name = name self.job_name = name
self.job_id = str(mock_random.uuid4()) self.job_id = provided_job_id or str(mock_random.uuid4())
self.job_definition = job_def self.job_definition = job_def
self.container_overrides: Dict[str, Any] = container_overrides or {} self.container_overrides: Dict[str, Any] = container_overrides or {}
self.job_queue = job_queue self.job_queue = job_queue
@ -505,6 +508,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.depends_on = depends_on self.depends_on = depends_on
self.timeout = timeout self.timeout = timeout
self.all_jobs = all_jobs self.all_jobs = all_jobs
self.array_properties: Dict[str, Any] = array_properties
self.arn = make_arn_for_job( self.arn = make_arn_for_job(
job_def.backend.account_id, self.job_id, job_def._region job_def.backend.account_id, self.job_id, job_def._region
@ -514,6 +518,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.exit_code: Optional[int] = None self.exit_code: Optional[int] = None
self.daemon = True self.daemon = True
self.name = "MOTO-BATCH-" + self.job_id self.name = "MOTO-BATCH-" + self.job_id
self._log_backend = log_backend self._log_backend = log_backend
@ -523,6 +528,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.attempts: List[Dict[str, Any]] = [] self.attempts: List[Dict[str, Any]] = []
self.latest_attempt: Optional[Dict[str, Any]] = None self.latest_attempt: Optional[Dict[str, Any]] = None
self._child_jobs: Optional[List[Job]] = None
def describe_short(self) -> Dict[str, Any]: def describe_short(self) -> Dict[str, Any]:
result = { result = {
@ -560,6 +566,26 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
if self.timeout: if self.timeout:
result["timeout"] = self.timeout result["timeout"] = self.timeout
result["attempts"] = self.attempts result["attempts"] = self.attempts
if self._child_jobs:
child_statuses = {
"STARTING": 0,
"FAILED": 0,
"RUNNING": 0,
"SUCCEEDED": 0,
"RUNNABLE": 0,
"SUBMITTED": 0,
"PENDING": 0,
}
for child_job in self._child_jobs:
if child_job.status is not None:
child_statuses[child_job.status] += 1
result["arrayProperties"] = {
"statusSummary": child_statuses,
"size": len(self._child_jobs),
}
if len(self._child_jobs) == child_statuses["SUCCEEDED"]:
self.status = "SUCCEEDED"
result["status"] = self.status
return result return result
def _container_details(self) -> Dict[str, Any]: def _container_details(self) -> Dict[str, Any]:
@ -675,7 +701,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
) )
for m in self._get_container_property("mountPoints", []) for m in self._get_container_property("mountPoints", [])
], ],
"name": f"{self.job_name}-{self.job_id}", "name": f"{self.job_name}-{self.job_id.replace(':', '-')}",
} }
) )
else: else:
@ -1704,6 +1730,7 @@ class BatchBackend(BaseBackend):
job_name: str, job_name: str,
job_def_id: str, job_def_id: str,
job_queue: str, job_queue: str,
array_properties: Dict[str, int],
depends_on: Optional[List[Dict[str, str]]] = None, depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None, container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None, timeout: Optional[Dict[str, int]] = None,
@ -1732,12 +1759,36 @@ class BatchBackend(BaseBackend):
depends_on=depends_on, depends_on=depends_on,
all_jobs=self._jobs, all_jobs=self._jobs,
timeout=timeout, timeout=timeout,
array_properties=array_properties or {},
) )
self._jobs[job.job_id] = job self._jobs[job.job_id] = job
# Here comes the fun if "size" in array_properties:
job.start() child_jobs = []
for array_index in range(array_properties["size"]):
provided_job_id = f"{job.job_id}:{array_index}"
child_job = Job(
job_name,
job_def,
queue,
log_backend=self.logs_backend,
container_overrides=container_overrides,
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties={"statusSummary": {}, "index": array_index},
provided_job_id=provided_job_id,
)
child_jobs.append(child_job)
self._jobs[child_job.job_id] = child_job
child_job.start()
# The 'parent' job doesn't need to be executed
# it just needs to keep track of its children
job._child_jobs = child_jobs
else:
# Here comes the fun
job.start()
return job_name, job.job_id return job_name, job.job_id
def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]: def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]:

View File

@ -210,6 +210,7 @@ class BatchResponse(BaseResponse):
job_name = self._get_param("jobName") job_name = self._get_param("jobName")
job_queue = self._get_param("jobQueue") job_queue = self._get_param("jobQueue")
timeout = self._get_param("timeout") timeout = self._get_param("timeout")
array_properties = self._get_param("arrayProperties", {})
name, job_id = self.batch_backend.submit_job( name, job_id = self.batch_backend.submit_job(
job_name, job_name,
@ -218,6 +219,7 @@ class BatchResponse(BaseResponse):
depends_on=depends_on, depends_on=depends_on,
container_overrides=container_overrides, container_overrides=container_overrides,
timeout=timeout, timeout=timeout,
array_properties=array_properties,
) )
result = {"jobId": job_id, "jobName": name} result = {"jobId": job_id, "jobName": name}

View File

@ -1,9 +1,9 @@
from ..batch.models import ( from ..batch.models import (
batch_backends, batch_backends,
BaseBackend, BaseBackend,
Job,
ClientException,
BatchBackend, BatchBackend,
ClientException,
Job,
) )
from ..core import BackendDict from ..core import BackendDict
@ -42,7 +42,7 @@ class BatchSimpleBackend(BaseBackend):
"url_bases", "url_bases",
]: ]:
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
if name in ["submit_job"]: if name in ["submit_job", "_mark_job_as_finished"]:
def newfunc(*args: Any, **kwargs: Any) -> Any: def newfunc(*args: Any, **kwargs: Any) -> Any:
attr = object.__getattribute__(self, name) attr = object.__getattribute__(self, name)
@ -57,6 +57,7 @@ class BatchSimpleBackend(BaseBackend):
job_name: str, job_name: str,
job_def_id: str, job_def_id: str,
job_queue: str, job_queue: str,
array_properties: Dict[str, Any],
depends_on: Optional[List[Dict[str, str]]] = None, depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None, container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None, timeout: Optional[Dict[str, int]] = None,
@ -79,13 +80,40 @@ class BatchSimpleBackend(BaseBackend):
depends_on=depends_on, depends_on=depends_on,
all_jobs=self._jobs, all_jobs=self._jobs,
timeout=timeout, timeout=timeout,
array_properties=array_properties,
) )
self.backend._jobs[job.job_id] = job
if "size" in array_properties:
child_jobs: List[Job] = []
for array_index in range(array_properties.get("size", 0)):
provided_job_id = f"{job.job_id}:{array_index}"
child_job = Job(
job_name,
job_def,
queue,
log_backend=self.logs_backend,
container_overrides=container_overrides,
depends_on=depends_on,
all_jobs=self._jobs,
timeout=timeout,
array_properties={"statusSummary": {}, "index": array_index},
provided_job_id=provided_job_id,
)
self._mark_job_as_finished(include_start_attempt=True, job=child_job)
child_jobs.append(child_job)
self._mark_job_as_finished(include_start_attempt=False, job=job)
job._child_jobs = child_jobs
else:
self._mark_job_as_finished(include_start_attempt=True, job=job)
return job_name, job.job_id
def _mark_job_as_finished(self, include_start_attempt: bool, job: Job) -> None:
self.backend._jobs[job.job_id] = job
job.job_started_at = datetime.datetime.now() job.job_started_at = datetime.datetime.now()
job.log_stream_name = job._stream_name job.log_stream_name = job._stream_name
job._start_attempt() if include_start_attempt:
job._start_attempt()
# We don't want to actually run the job - just mark it as succeeded or failed # We don't want to actually run the job - just mark it as succeeded or failed
# depending on whether env var MOTO_SIMPLE_BATCH_FAIL_AFTER is set # depending on whether env var MOTO_SIMPLE_BATCH_FAIL_AFTER is set
# if MOTO_SIMPLE_BATCH_FAIL_AFTER is set to an integer then batch will # if MOTO_SIMPLE_BATCH_FAIL_AFTER is set to an integer then batch will
@ -104,7 +132,5 @@ class BatchSimpleBackend(BaseBackend):
else: else:
job._mark_stopped(success=True) job._mark_stopped(success=True)
return job_name, job.job_id
batch_simple_backends = BackendDict(BatchSimpleBackend, "batch") batch_simple_backends = BackendDict(BatchSimpleBackend, "batch")

View File

@ -86,6 +86,50 @@ def test_submit_job_by_name():
assert resp_jobs["jobs"][0]["jobDefinition"] == job_definition_arn assert resp_jobs["jobs"][0]["jobDefinition"] == job_definition_arn
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_submit_job_array_size():
    """Submitting a job with arrayProperties creates child jobs; the parent
    job only aggregates child status and is never executed itself."""
    # Setup
    job_definition_name = f"sleep10_{str(uuid4())[0:6]}"
    ec2_client, iam_client, _, _, batch_client = _get_clients()
    commands = ["echo", "hello"]
    _, _, _, iam_arn = _setup(ec2_client, iam_client)
    _, queue_arn = prepare_job(batch_client, commands, iam_arn, job_definition_name)

    # Execute
    resp = batch_client.submit_job(
        jobName="test1",
        jobQueue=queue_arn,
        jobDefinition=job_definition_name,
        arrayProperties={"size": 2},
    )

    # Verify
    job_id = resp["jobId"]
    child_job_1_id = f"{job_id}:0"

    parent = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
    assert parent["arrayProperties"]["size"] == 2
    assert parent["attempts"] == []

    _wait_for_job_status(batch_client, job_id, "SUCCEEDED")

    # If the main job is successful, that means that all child jobs are successful
    parent = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
    assert parent["arrayProperties"]["size"] == 2
    assert parent["arrayProperties"]["statusSummary"]["SUCCEEDED"] == 2
    # Main job still has no attempts - because only the child jobs are executed
    assert parent["attempts"] == []

    # Child job was executed (exactly one attempt recorded)
    child = batch_client.describe_jobs(jobs=[child_job_1_id])["jobs"][0]
    assert child["status"] == "SUCCEEDED"
    assert len(child["attempts"]) == 1
# SLOW TESTS # SLOW TESTS

View File

@ -23,7 +23,9 @@ def test_submit_job_by_name():
) )
resp = batch_client.submit_job( resp = batch_client.submit_job(
jobName="test1", jobQueue=queue_arn, jobDefinition=job_definition_name jobName="test1",
jobQueue=queue_arn,
jobDefinition=job_definition_name,
) )
job_id = resp["jobId"] job_id = resp["jobId"]
@ -41,6 +43,44 @@ def test_submit_job_by_name():
assert "logStreamName" in job["container"] assert "logStreamName" in job["container"]
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch_simple
def test_submit_job_array_size():
    """batch_simple: array jobs finish immediately; the parent reports the
    aggregated child status without ever recording its own attempts."""
    # Setup
    job_definition_name = f"sleep10_{str(uuid4())[0:6]}"
    batch_client, _, queue_arn = setup_common_batch_simple(job_definition_name)

    # Execute
    resp = batch_client.submit_job(
        jobName="test1",
        jobQueue=queue_arn,
        jobDefinition=job_definition_name,
        arrayProperties={"size": 2},
    )

    # Verify
    job_id = resp["jobId"]
    child_job_1_id = f"{job_id}:0"

    parent = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
    assert parent["arrayProperties"]["size"] == 2
    assert parent["attempts"] == []

    # If the main job is successful, that means that all child jobs are successful
    assert parent["arrayProperties"]["size"] == 2
    assert parent["arrayProperties"]["statusSummary"]["SUCCEEDED"] == 2
    # Main job still has no attempts - because only the child jobs are executed
    assert parent["attempts"] == []

    # Child job was executed (exactly one attempt recorded)
    child = batch_client.describe_jobs(jobs=[child_job_1_id])["jobs"][0]
    assert child["status"] == "SUCCEEDED"
    assert len(child["attempts"]) == 1
@mock_batch_simple @mock_batch_simple
def test_update_job_definition(): def test_update_job_definition():
_, _, _, _, batch_client = _get_clients() _, _, _, _, batch_client = _get_clients()