Batch: add multinode support (#5840)

Author: Tristan Rice (2023-01-14 08:02:32 -08:00), committed by GitHub
parent 173e1549c0
commit a17956927f
5 changed files with 410 additions and 145 deletions
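For reference, a minimal usage sketch of what this commit enables: registering a "multinode" job definition and submitting a job against it with boto3 under moto. The compute environment, job queue, and role are assumed to exist already (see the prepare_multinode_job test helper added below); names such as "multinode-example" are illustrative only.

def submit_multinode_job(batch_client, queue_arn):
    # One container spec shared by both nodes, as in the new test helper.
    container = {
        "image": "busybox:latest",
        "vcpus": 1,
        "memory": 128,
        "command": ["echo", "hello"],
    }
    job_def = batch_client.register_job_definition(
        jobDefinitionName="multinode-example",
        type="multinode",
        nodeProperties={
            "mainNode": 0,
            "numNodes": 2,
            "nodeRangeProperties": [
                {"targetNodes": "0", "container": container},
                {"targetNodes": "1", "container": container},
            ],
        },
    )
    resp = batch_client.submit_job(
        jobName="multinode-example",
        jobQueue=queue_arn,
        jobDefinition=job_def["jobDefinitionArn"],
    )
    return resp["jobId"]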


@ -21,6 +21,7 @@ from .exceptions import InvalidParameterValueException, ClientException, Validat
from .utils import (
make_arn_for_compute_env,
make_arn_for_job_queue,
make_arn_for_job,
make_arn_for_task_def,
lowercase_first_key,
)
@ -221,6 +222,7 @@ class JobDefinition(CloudFormationModel):
parameters: Optional[Dict[str, Any]],
_type: str,
container_properties: Dict[str, Any],
node_properties: Dict[str, Any],
tags: Dict[str, str],
retry_strategy: Dict[str, str],
timeout: Dict[str, int],
@ -235,6 +237,7 @@ class JobDefinition(CloudFormationModel):
self.revision = revision or 0
self._region = backend.region_name
self.container_properties = container_properties
self.node_properties = node_properties
self.status = "ACTIVE"
self.parameters = parameters or {}
self.timeout = timeout
@ -242,6 +245,7 @@ class JobDefinition(CloudFormationModel):
self.platform_capabilities = platform_capabilities
self.propagate_tags = propagate_tags
if self.container_properties is not None:
if "resourceRequirements" not in self.container_properties:
self.container_properties["resourceRequirements"] = []
if "secrets" not in self.container_properties:
@ -306,12 +310,14 @@ class JobDefinition(CloudFormationModel):
def _validate(self) -> None:
# For future use when containers aren't the only thing in batch
if self.type not in ("container",):
raise ClientException('type must be one of "container"')
VALID_TYPES = ("container", "multinode")
if self.type not in VALID_TYPES:
raise ClientException(f"type must be one of {VALID_TYPES}")
if not isinstance(self.parameters, dict):
raise ClientException("parameters must be a string to string map")
if self.type == "container":
if "image" not in self.container_properties:
raise ClientException("containerProperties must contain image")
@ -335,6 +341,7 @@ class JobDefinition(CloudFormationModel):
parameters: Optional[Dict[str, Any]],
_type: str,
container_properties: Dict[str, Any],
node_properties: Dict[str, Any],
retry_strategy: Dict[str, Any],
tags: Dict[str, str],
timeout: Dict[str, int],
@ -357,6 +364,7 @@ class JobDefinition(CloudFormationModel):
parameters,
_type,
container_properties,
node_properties=node_properties,
revision=self.revision,
retry_strategy=retry_strategy,
tags=tags,
@ -416,7 +424,16 @@ class JobDefinition(CloudFormationModel):
_type="container",
tags=lowercase_first_key(properties.get("Tags", {})),
retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
container_properties=lowercase_first_key(properties["ContainerProperties"]),
container_properties=(
lowercase_first_key(properties["ContainerProperties"])
if "ContainerProperties" in properties
else None
),
node_properties=(
lowercase_first_key(properties["NodeProperties"])
if "NodeProperties" in properties
else None
),
timeout=lowercase_first_key(properties.get("timeout", {})),
platform_capabilities=None,
propagate_tags=None,
@ -466,6 +483,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.timeout = timeout
self.all_jobs = all_jobs
self.arn = make_arn_for_job(
job_def.backend.account_id, self.job_id, job_def._region
)
self.stop = False
self.exit_code: Optional[int] = None
@ -483,6 +504,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
def describe_short(self) -> Dict[str, Any]:
result = {
"jobId": self.job_id,
"jobArn": self.arn,
"jobName": self.job_name,
"createdAt": datetime2int_milliseconds(self.job_created_at),
"status": self.status,
@ -502,7 +524,13 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
result = self.describe_short()
result["jobQueue"] = self.job_queue.arn
result["dependsOn"] = self.depends_on or []
if self.job_definition.type == "container":
result["container"] = self._container_details()
elif self.job_definition.type == "multinode":
result["container"] = {
"logStreamName": self.log_stream_name,
}
result["nodeProperties"] = self.job_definition.node_properties
if self.job_stopped:
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
if self.timeout:
@ -575,6 +603,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
try:
import docker
containers: List[docker.models.containers.Container] = []
self.advance()
while self.status == "SUBMITTED":
# Wait until we've moved onto state 'PENDING'
@ -585,25 +615,29 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
if self.depends_on and not self._wait_for_dependencies():
return
image = self.job_definition.container_properties.get(
"image", "alpine:latest"
)
privileged = self.job_definition.container_properties.get(
"privileged", False
)
cmd = self._get_container_property(
"command",
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
)
environment = {
e["name"]: e["value"]
for e in self._get_container_property("environment", [])
}
container_kwargs = []
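# Collect one docker run() kwargs dict per container to launch: a single entry
# for a plain container job, or one entry per node for a multinode job.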
if self.job_definition.container_properties:
volumes = {
v["name"]: v["host"]
for v in self._get_container_property("volumes", [])
}
mounts = [
container_kwargs.append(
{
"image": self.job_definition.container_properties.get(
"image", "alpine:latest"
),
"privileged": self.job_definition.container_properties.get(
"privileged", False
),
"command": self._get_container_property(
"command",
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
),
"environment": {
e["name"]: e["value"]
for e in self._get_container_property("environment", [])
},
"mounts": [
docker.types.Mount(
m["containerPath"],
volumes[m["sourceVolume"]]["sourcePath"],
@ -611,8 +645,57 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
read_only=m["readOnly"],
)
for m in self._get_container_property("mountPoints", [])
]
name = f"{self.job_name}-{self.job_id}"
],
"name": f"{self.job_name}-{self.job_id}",
}
)
else:
node_properties = self.job_definition.node_properties
num_nodes = node_properties["numNodes"]
node_containers = {}
for node_range in node_properties["nodeRangeProperties"]:
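# targetNodes is either a single node index ("0") or a "start:end" range;
# an omitted bound is open-ended, e.g. "4:" covers node 4 through numNodes - 1
# and ":3" covers nodes 0 through 3.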
start, sep, end = node_range["targetNodes"].partition(":")
if sep == "":
start = end = int(start)
else:
if start == "":
start = 0
else:
start = int(start)
if end == "":
end = num_nodes - 1
else:
end = int(end)
for i in range(start, end + 1):
node_containers[i] = node_range["container"]
for i in range(num_nodes):
spec = node_containers[i]
volumes = {v["name"]: v["host"] for v in spec.get("volumes", [])}
container_kwargs.append(
{
"image": spec.get("image", "alpine:latest"),
"privileged": spec.get("privileged", False),
"command": spec.get(
"command",
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
),
"environment": {
e["name"]: e["value"]
for e in spec.get("environment", [])
},
"mounts": [
docker.types.Mount(
m["containerPath"],
volumes[m["sourceVolume"]]["sourcePath"],
type="bind",
read_only=m["readOnly"],
)
for m in spec.get("mountPoints", [])
],
"name": f"{self.job_name}-{self.job_id}-{i}",
}
)
self.advance()
while self.status == "PENDING":
@ -632,19 +715,20 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
else {}
)
for kwargs in container_kwargs:
environment = kwargs["environment"]
environment["MOTO_HOST"] = settings.moto_server_host()
environment["MOTO_PORT"] = settings.moto_server_port()
environment[
"MOTO_HTTP_ENDPOINT"
] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}'
run_kwargs = dict()
network_name = settings.moto_network_name()
network_mode = settings.moto_network_mode()
if network_name:
run_kwargs["network"] = network_name
kwargs["network"] = network_name
elif network_mode:
run_kwargs["network_mode"] = network_mode
kwargs["network_mode"] = network_mode
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
self.advance()
@ -656,18 +740,22 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
while self.status == "STARTING":
# Wait until the state is no longer runnable, but 'RUNNING'
sleep(0.5)
for kwargs in container_kwargs:
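# Containers start one at a time; the first container acts as the main node, and
# its IP is passed to the remaining nodes via AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS,
# mirroring the environment real AWS Batch provides to multinode jobs.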
if len(containers) > 0:
env = kwargs["environment"]
ip = containers[0].attrs["NetworkSettings"]["IPAddress"]
env["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"] = ip
container = self.docker_client.containers.run(
image,
cmd,
detach=True,
name=name,
log_config=log_config,
environment=environment,
mounts=mounts,
privileged=privileged,
extra_hosts=extra_hosts,
**run_kwargs,
**kwargs,
)
container.reload()
containers.append(container)
for i, container in enumerate(containers):
try:
container.reload()
@ -733,30 +821,41 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
# Send to cloudwatch
self.log_stream_name = self._stream_name
self._log_backend.ensure_log_group(self._log_group, None)
self._log_backend.create_log_stream(self._log_group, self._stream_name)
self._log_backend.ensure_log_stream(
self._log_group, self.log_stream_name
)
self._log_backend.put_log_events(
self._log_group, self._stream_name, logs
self._log_group, self.log_stream_name, logs
)
result = container.wait() or {}
exit_code = result.get("StatusCode", 0)
self.exit_code = exit_code
job_failed = self.stop or exit_code > 0
self._mark_stopped(success=not job_failed)
if job_failed:
self._mark_stopped(success=False)
break
except Exception as err:
logger.error(
f"Failed to run AWS Batch container {self.name}. Error {err}"
)
self._mark_stopped(success=False)
container.kill()
finally:
container.remove()
self._mark_stopped(success=True)
except Exception as err:
logger.error(f"Failed to run AWS Batch container {self.name}. Error {err}")
self._mark_stopped(success=False)
finally:
for container in containers:
container.reload()
if container.status == "running":
container.kill()
container.remove()
def _mark_stopped(self, success: bool = True) -> None:
if self.job_stopped:
return
# Ensure that job_stopped/job_stopped_at-attributes are set first
# The describe-method needs them immediately when status is set
self.job_stopped = True
@ -1437,6 +1536,7 @@ class BatchBackend(BaseBackend):
tags: Dict[str, str],
retry_strategy: Dict[str, Any],
container_properties: Dict[str, Any],
node_properties: Dict[str, Any],
timeout: Dict[str, int],
platform_capabilities: List[str],
propagate_tags: bool,
@ -1457,6 +1557,7 @@ class BatchBackend(BaseBackend):
parameters,
_type,
container_properties,
node_properties=node_properties,
tags=tags,
retry_strategy=retry_strategy,
timeout=timeout,
@ -1467,7 +1568,13 @@ class BatchBackend(BaseBackend):
else:
# Make new jobdef
job_def = job_def.update(
parameters, _type, container_properties, retry_strategy, tags, timeout
parameters,
_type,
container_properties,
node_properties,
retry_strategy,
tags,
timeout,
)
self._job_definitions[job_def.arn] = job_def
@ -1562,7 +1669,11 @@ class BatchBackend(BaseBackend):
result = []
for key, job in self._jobs.items():
if len(job_filter) > 0 and key not in job_filter:
if (
len(job_filter) > 0
and key not in job_filter
and job.arn not in job_filter
):
continue
result.append(job.describe())
@ -1570,7 +1681,10 @@ class BatchBackend(BaseBackend):
return result
def list_jobs(
self, job_queue_name: str, job_status: Optional[str] = None
self,
job_queue_name: str,
job_status: Optional[str] = None,
filters: Optional[List[Dict[str, Any]]] = None,
) -> List[Job]:
"""
Pagination is not yet implemented
@ -1598,6 +1712,18 @@ class BatchBackend(BaseBackend):
if job_status is not None and job.status != job_status:
continue
if filters is not None:
matches = True
for filt in filters:
name = filt["name"]
values = filt["values"]
if name == "JOB_NAME":
if job.job_name not in values:
matches = False
break
if not matches:
continue
jobs.append(job)
return jobs


@ -134,6 +134,7 @@ class BatchResponse(BaseResponse):
# RegisterJobDefinition
def registerjobdefinition(self) -> str:
container_properties = self._get_param("containerProperties")
node_properties = self._get_param("nodeProperties")
def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters")
tags = self._get_param("tags")
@ -149,6 +150,7 @@ class BatchResponse(BaseResponse):
tags=tags,
retry_strategy=retry_strategy,
container_properties=container_properties,
node_properties=node_properties,
timeout=timeout,
platform_capabilities=platform_capabilities,
propagate_tags=propagate_tags,
@ -215,8 +217,9 @@ class BatchResponse(BaseResponse):
def listjobs(self) -> str:
job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus")
filters = self._get_param("filters")
jobs = self.batch_backend.list_jobs(job_queue, job_status)
jobs = self.batch_backend.list_jobs(job_queue, job_status, filters)
result = {"jobSummaryList": [job.describe_short() for job in jobs]}
return json.dumps(result)
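ListJobs now also passes a filters parameter through to the backend; a minimal boto3 sketch of the new JOB_NAME filter, assuming queue_arn is an existing queue ARN. Note that only the JOB_NAME filter name is evaluated by the backend here; other filter names are accepted but not applied.

filtered = batch_client.list_jobs(
    jobQueue=queue_arn,
    filters=[{"name": "JOB_NAME", "values": ["test2"]}],
)["jobSummaryList"]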


@ -9,6 +9,10 @@ def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str:
return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}"
def make_arn_for_job(account_id: str, job_id: str, region_name: str) -> str:
return f"arn:aws:batch:{region_name}:{account_id}:job/{job_id}"
def make_arn_for_task_def(
account_id: str, name: str, revision: int, region_name: str
) -> str:


@ -649,6 +649,15 @@ class LogsBackend(BaseBackend):
log_group = self.groups[log_group_name]
return log_group.create_log_stream(log_stream_name)
def ensure_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
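# Idempotent variant of create_log_stream: silently returns when the stream
# already exists, so every container of a multinode job can log to the same stream.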
if log_group_name not in self.groups:
raise ResourceNotFoundException()
if log_stream_name in self.groups[log_group_name].streams:
return
self.create_log_stream(log_group_name, log_stream_name)
def delete_log_stream(self, log_group_name, log_stream_name):
if log_group_name not in self.groups:
raise ResourceNotFoundException()


@ -151,6 +151,72 @@ def test_submit_job():
attempt.should.have.key("stoppedAt").equals(stopped_at)
@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
@pytest.mark.network
def test_submit_job_multinode():
ec2_client, iam_client, _, logs_client, batch_client = _get_clients()
_, _, _, iam_arn = _setup(ec2_client, iam_client)
start_time_milliseconds = time.time() * 1000
job_def_name = str(uuid4())[0:6]
commands = ["echo", "hello"]
job_def_arn, queue_arn = prepare_multinode_job(
batch_client, commands, iam_arn, job_def_name
)
resp = batch_client.submit_job(
jobName=str(uuid4())[0:6], jobQueue=queue_arn, jobDefinition=job_def_arn
)
job_id = resp["jobId"]
# Test that describe_jobs() returns 'createdAt'
# github.com/getmoto/moto/issues/4364
resp = batch_client.describe_jobs(jobs=[job_id])
created_at = resp["jobs"][0]["createdAt"]
created_at.should.be.greater_than(start_time_milliseconds)
_wait_for_job_status(batch_client, job_id, "SUCCEEDED")
resp = logs_client.describe_log_streams(
logGroupName="/aws/batch/job", logStreamNamePrefix=job_def_name
)
resp["logStreams"].should.have.length_of(1)
ls_name = resp["logStreams"][0]["logStreamName"]
resp = logs_client.get_log_events(
logGroupName="/aws/batch/job", logStreamName=ls_name
)
[event["message"] for event in resp["events"]].should.equal(["hello", "hello"])
# Test that describe_jobs() returns timestamps in milliseconds
# github.com/getmoto/moto/issues/4364
job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
created_at = job["createdAt"]
started_at = job["startedAt"]
stopped_at = job["stoppedAt"]
created_at.should.be.greater_than(start_time_milliseconds)
started_at.should.be.greater_than(start_time_milliseconds)
stopped_at.should.be.greater_than(start_time_milliseconds)
# Verify we track attempts
job.should.have.key("attempts").length_of(1)
attempt = job["attempts"][0]
attempt.should.have.key("container")
attempt["container"].should.have.key("containerInstanceArn")
attempt["container"].should.have.key("logStreamName").equals(
job["container"]["logStreamName"]
)
attempt["container"].should.have.key("networkInterfaces")
attempt["container"].should.have.key("taskArn")
attempt.should.have.key("startedAt").equals(started_at)
attempt.should.have.key("stoppedAt").equals(stopped_at)
@mock_logs
@mock_ec2
@mock_ecs
@ -205,6 +271,18 @@ def test_list_jobs():
job.should.have.key("stoppedAt")
job.should.have.key("container").should.have.key("exitCode").equals(0)
filtered_jobs = batch_client.list_jobs(
jobQueue=queue_arn,
filters=[
{
"name": "JOB_NAME",
"values": ["test2"],
}
],
)["jobSummaryList"]
filtered_jobs.should.have.length_of(1)
filtered_jobs[0]["jobName"].should.equal("test2")
@mock_logs
@mock_ec2
@ -772,6 +850,51 @@ def prepare_job(batch_client, commands, iam_arn, job_def_name):
return job_def_arn, queue_arn
def prepare_multinode_job(batch_client, commands, iam_arn, job_def_name):
compute_name = str(uuid4())[0:6]
resp = batch_client.create_compute_environment(
computeEnvironmentName=compute_name,
type="UNMANAGED",
state="ENABLED",
serviceRole=iam_arn,
)
arn = resp["computeEnvironmentArn"]
resp = batch_client.create_job_queue(
jobQueueName=str(uuid4())[0:6],
state="ENABLED",
priority=123,
computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
)
queue_arn = resp["jobQueueArn"]
container = {
"image": "busybox:latest",
"vcpus": 1,
"memory": 128,
"command": commands,
}
resp = batch_client.register_job_definition(
jobDefinitionName=job_def_name,
type="multinode",
nodeProperties={
"mainNode": 0,
"numNodes": 2,
"nodeRangeProperties": [
{
"container": container,
"targetNodes": "0",
},
{
"container": container,
"targetNodes": "1",
},
],
},
)
job_def_arn = resp["jobDefinitionArn"]
return job_def_arn, queue_arn
@mock_batch
def test_update_job_definition():
_, _, _, _, batch_client = _get_clients()