Batch: add multinode support (#5840)
This commit is contained in:
parent
173e1549c0
commit
a17956927f
@ -21,6 +21,7 @@ from .exceptions import InvalidParameterValueException, ClientException, Validat
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
make_arn_for_compute_env,
|
make_arn_for_compute_env,
|
||||||
make_arn_for_job_queue,
|
make_arn_for_job_queue,
|
||||||
|
make_arn_for_job,
|
||||||
make_arn_for_task_def,
|
make_arn_for_task_def,
|
||||||
lowercase_first_key,
|
lowercase_first_key,
|
||||||
)
|
)
|
||||||
@ -221,6 +222,7 @@ class JobDefinition(CloudFormationModel):
|
|||||||
parameters: Optional[Dict[str, Any]],
|
parameters: Optional[Dict[str, Any]],
|
||||||
_type: str,
|
_type: str,
|
||||||
container_properties: Dict[str, Any],
|
container_properties: Dict[str, Any],
|
||||||
|
node_properties: Dict[str, Any],
|
||||||
tags: Dict[str, str],
|
tags: Dict[str, str],
|
||||||
retry_strategy: Dict[str, str],
|
retry_strategy: Dict[str, str],
|
||||||
timeout: Dict[str, int],
|
timeout: Dict[str, int],
|
||||||
@ -235,6 +237,7 @@ class JobDefinition(CloudFormationModel):
|
|||||||
self.revision = revision or 0
|
self.revision = revision or 0
|
||||||
self._region = backend.region_name
|
self._region = backend.region_name
|
||||||
self.container_properties = container_properties
|
self.container_properties = container_properties
|
||||||
|
self.node_properties = node_properties
|
||||||
self.status = "ACTIVE"
|
self.status = "ACTIVE"
|
||||||
self.parameters = parameters or {}
|
self.parameters = parameters or {}
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
@ -242,10 +245,11 @@ class JobDefinition(CloudFormationModel):
|
|||||||
self.platform_capabilities = platform_capabilities
|
self.platform_capabilities = platform_capabilities
|
||||||
self.propagate_tags = propagate_tags
|
self.propagate_tags = propagate_tags
|
||||||
|
|
||||||
if "resourceRequirements" not in self.container_properties:
|
if self.container_properties is not None:
|
||||||
self.container_properties["resourceRequirements"] = []
|
if "resourceRequirements" not in self.container_properties:
|
||||||
if "secrets" not in self.container_properties:
|
self.container_properties["resourceRequirements"] = []
|
||||||
self.container_properties["secrets"] = []
|
if "secrets" not in self.container_properties:
|
||||||
|
self.container_properties["secrets"] = []
|
||||||
|
|
||||||
self._validate()
|
self._validate()
|
||||||
self.revision += 1
|
self.revision += 1
|
||||||
@ -306,26 +310,28 @@ class JobDefinition(CloudFormationModel):
|
|||||||
|
|
||||||
def _validate(self) -> None:
|
def _validate(self) -> None:
|
||||||
# For future use when containers arnt the only thing in batch
|
# For future use when containers arnt the only thing in batch
|
||||||
if self.type not in ("container",):
|
VALID_TYPES = ("container", "multinode")
|
||||||
raise ClientException('type must be one of "container"')
|
if self.type not in VALID_TYPES:
|
||||||
|
raise ClientException(f"type must be one of {VALID_TYPES}")
|
||||||
|
|
||||||
if not isinstance(self.parameters, dict):
|
if not isinstance(self.parameters, dict):
|
||||||
raise ClientException("parameters must be a string to string map")
|
raise ClientException("parameters must be a string to string map")
|
||||||
|
|
||||||
if "image" not in self.container_properties:
|
if self.type == "container":
|
||||||
raise ClientException("containerProperties must contain image")
|
if "image" not in self.container_properties:
|
||||||
|
raise ClientException("containerProperties must contain image")
|
||||||
|
|
||||||
memory = self._get_resource_requirement("memory")
|
memory = self._get_resource_requirement("memory")
|
||||||
if memory is None:
|
if memory is None:
|
||||||
raise ClientException("containerProperties must contain memory")
|
raise ClientException("containerProperties must contain memory")
|
||||||
if memory < 4:
|
if memory < 4:
|
||||||
raise ClientException("container memory limit must be greater than 4")
|
raise ClientException("container memory limit must be greater than 4")
|
||||||
|
|
||||||
vcpus = self._get_resource_requirement("vcpus")
|
vcpus = self._get_resource_requirement("vcpus")
|
||||||
if vcpus is None:
|
if vcpus is None:
|
||||||
raise ClientException("containerProperties must contain vcpus")
|
raise ClientException("containerProperties must contain vcpus")
|
||||||
if vcpus <= 0:
|
if vcpus <= 0:
|
||||||
raise ClientException("container vcpus limit must be greater than 0")
|
raise ClientException("container vcpus limit must be greater than 0")
|
||||||
|
|
||||||
def deregister(self) -> None:
|
def deregister(self) -> None:
|
||||||
self.status = "INACTIVE"
|
self.status = "INACTIVE"
|
||||||
@ -335,6 +341,7 @@ class JobDefinition(CloudFormationModel):
|
|||||||
parameters: Optional[Dict[str, Any]],
|
parameters: Optional[Dict[str, Any]],
|
||||||
_type: str,
|
_type: str,
|
||||||
container_properties: Dict[str, Any],
|
container_properties: Dict[str, Any],
|
||||||
|
node_properties: Dict[str, Any],
|
||||||
retry_strategy: Dict[str, Any],
|
retry_strategy: Dict[str, Any],
|
||||||
tags: Dict[str, str],
|
tags: Dict[str, str],
|
||||||
timeout: Dict[str, int],
|
timeout: Dict[str, int],
|
||||||
@ -357,6 +364,7 @@ class JobDefinition(CloudFormationModel):
|
|||||||
parameters,
|
parameters,
|
||||||
_type,
|
_type,
|
||||||
container_properties,
|
container_properties,
|
||||||
|
node_properties=node_properties,
|
||||||
revision=self.revision,
|
revision=self.revision,
|
||||||
retry_strategy=retry_strategy,
|
retry_strategy=retry_strategy,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
@ -416,7 +424,16 @@ class JobDefinition(CloudFormationModel):
|
|||||||
_type="container",
|
_type="container",
|
||||||
tags=lowercase_first_key(properties.get("Tags", {})),
|
tags=lowercase_first_key(properties.get("Tags", {})),
|
||||||
retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
|
retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
|
||||||
container_properties=lowercase_first_key(properties["ContainerProperties"]),
|
container_properties=(
|
||||||
|
lowercase_first_key(properties["ContainerProperties"])
|
||||||
|
if "ContainerProperties" in properties
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
node_properties=(
|
||||||
|
lowercase_first_key(properties["NodeProperties"])
|
||||||
|
if "NodeProperties" in properties
|
||||||
|
else None
|
||||||
|
),
|
||||||
timeout=lowercase_first_key(properties.get("timeout", {})),
|
timeout=lowercase_first_key(properties.get("timeout", {})),
|
||||||
platform_capabilities=None,
|
platform_capabilities=None,
|
||||||
propagate_tags=None,
|
propagate_tags=None,
|
||||||
@ -466,6 +483,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.all_jobs = all_jobs
|
self.all_jobs = all_jobs
|
||||||
|
|
||||||
|
self.arn = make_arn_for_job(
|
||||||
|
job_def.backend.account_id, self.job_id, job_def._region
|
||||||
|
)
|
||||||
|
|
||||||
self.stop = False
|
self.stop = False
|
||||||
self.exit_code: Optional[int] = None
|
self.exit_code: Optional[int] = None
|
||||||
|
|
||||||
@ -483,6 +504,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
def describe_short(self) -> Dict[str, Any]:
|
def describe_short(self) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
"jobId": self.job_id,
|
"jobId": self.job_id,
|
||||||
|
"jobArn": self.arn,
|
||||||
"jobName": self.job_name,
|
"jobName": self.job_name,
|
||||||
"createdAt": datetime2int_milliseconds(self.job_created_at),
|
"createdAt": datetime2int_milliseconds(self.job_created_at),
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
@ -502,7 +524,13 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
result = self.describe_short()
|
result = self.describe_short()
|
||||||
result["jobQueue"] = self.job_queue.arn
|
result["jobQueue"] = self.job_queue.arn
|
||||||
result["dependsOn"] = self.depends_on or []
|
result["dependsOn"] = self.depends_on or []
|
||||||
result["container"] = self._container_details()
|
if self.job_definition.type == "container":
|
||||||
|
result["container"] = self._container_details()
|
||||||
|
elif self.job_definition.type == "multinode":
|
||||||
|
result["container"] = {
|
||||||
|
"logStreamName": self.log_stream_name,
|
||||||
|
}
|
||||||
|
result["nodeProperties"] = self.job_definition.node_properties
|
||||||
if self.job_stopped:
|
if self.job_stopped:
|
||||||
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
|
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
|
||||||
if self.timeout:
|
if self.timeout:
|
||||||
@ -575,6 +603,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
try:
|
try:
|
||||||
import docker
|
import docker
|
||||||
|
|
||||||
|
containers: List[docker.models.containers.Container] = []
|
||||||
|
|
||||||
self.advance()
|
self.advance()
|
||||||
while self.status == "SUBMITTED":
|
while self.status == "SUBMITTED":
|
||||||
# Wait until we've moved onto state 'PENDING'
|
# Wait until we've moved onto state 'PENDING'
|
||||||
@ -585,34 +615,87 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
if self.depends_on and not self._wait_for_dependencies():
|
if self.depends_on and not self._wait_for_dependencies():
|
||||||
return
|
return
|
||||||
|
|
||||||
image = self.job_definition.container_properties.get(
|
container_kwargs = []
|
||||||
"image", "alpine:latest"
|
if self.job_definition.container_properties:
|
||||||
)
|
volumes = {
|
||||||
privileged = self.job_definition.container_properties.get(
|
v["name"]: v["host"]
|
||||||
"privileged", False
|
for v in self._get_container_property("volumes", [])
|
||||||
)
|
}
|
||||||
cmd = self._get_container_property(
|
container_kwargs.append(
|
||||||
"command",
|
{
|
||||||
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
|
"image": self.job_definition.container_properties.get(
|
||||||
)
|
"image", "alpine:latest"
|
||||||
environment = {
|
),
|
||||||
e["name"]: e["value"]
|
"privileged": self.job_definition.container_properties.get(
|
||||||
for e in self._get_container_property("environment", [])
|
"privileged", False
|
||||||
}
|
),
|
||||||
volumes = {
|
"command": self._get_container_property(
|
||||||
v["name"]: v["host"]
|
"command",
|
||||||
for v in self._get_container_property("volumes", [])
|
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
|
||||||
}
|
),
|
||||||
mounts = [
|
"environment": {
|
||||||
docker.types.Mount(
|
e["name"]: e["value"]
|
||||||
m["containerPath"],
|
for e in self._get_container_property("environment", [])
|
||||||
volumes[m["sourceVolume"]]["sourcePath"],
|
},
|
||||||
type="bind",
|
"mounts": [
|
||||||
read_only=m["readOnly"],
|
docker.types.Mount(
|
||||||
|
m["containerPath"],
|
||||||
|
volumes[m["sourceVolume"]]["sourcePath"],
|
||||||
|
type="bind",
|
||||||
|
read_only=m["readOnly"],
|
||||||
|
)
|
||||||
|
for m in self._get_container_property("mountPoints", [])
|
||||||
|
],
|
||||||
|
"name": f"{self.job_name}-{self.job_id}",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
for m in self._get_container_property("mountPoints", [])
|
else:
|
||||||
]
|
node_properties = self.job_definition.node_properties
|
||||||
name = f"{self.job_name}-{self.job_id}"
|
num_nodes = node_properties["numNodes"]
|
||||||
|
node_containers = {}
|
||||||
|
for node_range in node_properties["nodeRangeProperties"]:
|
||||||
|
start, sep, end = node_range["targetNodes"].partition(":")
|
||||||
|
if sep == "":
|
||||||
|
start = end = int(start)
|
||||||
|
else:
|
||||||
|
if start == "":
|
||||||
|
start = 0
|
||||||
|
else:
|
||||||
|
start = int(start)
|
||||||
|
if end == "":
|
||||||
|
end = num_nodes - 1
|
||||||
|
else:
|
||||||
|
end = int(end)
|
||||||
|
for i in range(start, end + 1):
|
||||||
|
node_containers[i] = node_range["container"]
|
||||||
|
|
||||||
|
for i in range(num_nodes):
|
||||||
|
spec = node_containers[i]
|
||||||
|
volumes = {v["name"]: v["host"] for v in spec.get("volumes", [])}
|
||||||
|
container_kwargs.append(
|
||||||
|
{
|
||||||
|
"image": spec.get("image", "alpine:latest"),
|
||||||
|
"privileged": spec.get("privileged", False),
|
||||||
|
"command": spec.get(
|
||||||
|
"command",
|
||||||
|
'/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
|
||||||
|
),
|
||||||
|
"environment": {
|
||||||
|
e["name"]: e["value"]
|
||||||
|
for e in spec.get("environment", [])
|
||||||
|
},
|
||||||
|
"mounts": [
|
||||||
|
docker.types.Mount(
|
||||||
|
m["containerPath"],
|
||||||
|
volumes[m["sourceVolume"]]["sourcePath"],
|
||||||
|
type="bind",
|
||||||
|
read_only=m["readOnly"],
|
||||||
|
)
|
||||||
|
for m in spec.get("mountPoints", [])
|
||||||
|
],
|
||||||
|
"name": f"{self.job_name}-{self.job_id}-{i}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.advance()
|
self.advance()
|
||||||
while self.status == "PENDING":
|
while self.status == "PENDING":
|
||||||
@ -632,19 +715,20 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
environment["MOTO_HOST"] = settings.moto_server_host()
|
for kwargs in container_kwargs:
|
||||||
environment["MOTO_PORT"] = settings.moto_server_port()
|
environment = kwargs["environment"]
|
||||||
environment[
|
environment["MOTO_HOST"] = settings.moto_server_host()
|
||||||
"MOTO_HTTP_ENDPOINT"
|
environment["MOTO_PORT"] = settings.moto_server_port()
|
||||||
] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}'
|
environment[
|
||||||
|
"MOTO_HTTP_ENDPOINT"
|
||||||
|
] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}'
|
||||||
|
|
||||||
run_kwargs = dict()
|
network_name = settings.moto_network_name()
|
||||||
network_name = settings.moto_network_name()
|
network_mode = settings.moto_network_mode()
|
||||||
network_mode = settings.moto_network_mode()
|
if network_name:
|
||||||
if network_name:
|
kwargs["network"] = network_name
|
||||||
run_kwargs["network"] = network_name
|
elif network_mode:
|
||||||
elif network_mode:
|
kwargs["network_mode"] = network_mode
|
||||||
run_kwargs["network_mode"] = network_mode
|
|
||||||
|
|
||||||
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
|
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
|
||||||
self.advance()
|
self.advance()
|
||||||
@ -656,107 +740,122 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
|||||||
while self.status == "STARTING":
|
while self.status == "STARTING":
|
||||||
# Wait until the state is no longer runnable, but 'RUNNING'
|
# Wait until the state is no longer runnable, but 'RUNNING'
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
container = self.docker_client.containers.run(
|
|
||||||
image,
|
for kwargs in container_kwargs:
|
||||||
cmd,
|
if len(containers) > 0:
|
||||||
detach=True,
|
env = kwargs["environment"]
|
||||||
name=name,
|
ip = containers[0].attrs["NetworkSettings"]["IPAddress"]
|
||||||
log_config=log_config,
|
env["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"] = ip
|
||||||
environment=environment,
|
container = self.docker_client.containers.run(
|
||||||
mounts=mounts,
|
detach=True,
|
||||||
privileged=privileged,
|
log_config=log_config,
|
||||||
extra_hosts=extra_hosts,
|
extra_hosts=extra_hosts,
|
||||||
**run_kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
container.reload()
|
container.reload()
|
||||||
|
containers.append(container)
|
||||||
|
|
||||||
max_time = None
|
for i, container in enumerate(containers):
|
||||||
if self._get_attempt_duration():
|
try:
|
||||||
attempt_duration = self._get_attempt_duration()
|
|
||||||
max_time = self.job_started_at + datetime.timedelta(
|
|
||||||
seconds=attempt_duration # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
while container.status == "running" and not self.stop:
|
|
||||||
container.reload()
|
container.reload()
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
if max_time and datetime.datetime.now() > max_time:
|
max_time = None
|
||||||
raise Exception(
|
if self._get_attempt_duration():
|
||||||
"Job time exceeded the configured attemptDurationSeconds"
|
attempt_duration = self._get_attempt_duration()
|
||||||
|
max_time = self.job_started_at + datetime.timedelta(
|
||||||
|
seconds=attempt_duration # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Container should be stopped by this point... unless asked to stop
|
while container.status == "running" and not self.stop:
|
||||||
if container.status == "running":
|
container.reload()
|
||||||
container.kill()
|
time.sleep(0.5)
|
||||||
|
|
||||||
# Log collection
|
if max_time and datetime.datetime.now() > max_time:
|
||||||
logs_stdout = []
|
raise Exception(
|
||||||
logs_stderr = []
|
"Job time exceeded the configured attemptDurationSeconds"
|
||||||
logs_stderr.extend(
|
)
|
||||||
container.logs(
|
|
||||||
stdout=False,
|
# Container should be stopped by this point... unless asked to stop
|
||||||
stderr=True,
|
if container.status == "running":
|
||||||
timestamps=True,
|
container.kill()
|
||||||
since=datetime2int(self.job_started_at),
|
|
||||||
|
# Log collection
|
||||||
|
logs_stdout = []
|
||||||
|
logs_stderr = []
|
||||||
|
logs_stderr.extend(
|
||||||
|
container.logs(
|
||||||
|
stdout=False,
|
||||||
|
stderr=True,
|
||||||
|
timestamps=True,
|
||||||
|
since=datetime2int(self.job_started_at),
|
||||||
|
)
|
||||||
|
.decode()
|
||||||
|
.split("\n")
|
||||||
)
|
)
|
||||||
.decode()
|
logs_stdout.extend(
|
||||||
.split("\n")
|
container.logs(
|
||||||
)
|
stdout=True,
|
||||||
logs_stdout.extend(
|
stderr=False,
|
||||||
container.logs(
|
timestamps=True,
|
||||||
stdout=True,
|
since=datetime2int(self.job_started_at),
|
||||||
stderr=False,
|
)
|
||||||
timestamps=True,
|
.decode()
|
||||||
since=datetime2int(self.job_started_at),
|
.split("\n")
|
||||||
)
|
)
|
||||||
.decode()
|
|
||||||
.split("\n")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process logs
|
# Process logs
|
||||||
logs_stdout = [x for x in logs_stdout if len(x) > 0]
|
logs_stdout = [x for x in logs_stdout if len(x) > 0]
|
||||||
logs_stderr = [x for x in logs_stderr if len(x) > 0]
|
logs_stderr = [x for x in logs_stderr if len(x) > 0]
|
||||||
logs = []
|
logs = []
|
||||||
for line in logs_stdout + logs_stderr:
|
for line in logs_stdout + logs_stderr:
|
||||||
date, line = line.split(" ", 1)
|
date, line = line.split(" ", 1)
|
||||||
date_obj = (
|
date_obj = (
|
||||||
dateutil.parser.parse(date)
|
dateutil.parser.parse(date)
|
||||||
.astimezone(datetime.timezone.utc)
|
.astimezone(datetime.timezone.utc)
|
||||||
.replace(tzinfo=None)
|
.replace(tzinfo=None)
|
||||||
|
)
|
||||||
|
date = unix_time_millis(date_obj)
|
||||||
|
logs.append({"timestamp": date, "message": line.strip()})
|
||||||
|
logs = sorted(logs, key=lambda log: log["timestamp"])
|
||||||
|
|
||||||
|
# Send to cloudwatch
|
||||||
|
self.log_stream_name = self._stream_name
|
||||||
|
self._log_backend.ensure_log_group(self._log_group, None)
|
||||||
|
self._log_backend.ensure_log_stream(
|
||||||
|
self._log_group, self.log_stream_name
|
||||||
|
)
|
||||||
|
self._log_backend.put_log_events(
|
||||||
|
self._log_group, self.log_stream_name, logs
|
||||||
)
|
)
|
||||||
date = unix_time_millis(date_obj)
|
|
||||||
logs.append({"timestamp": date, "message": line.strip()})
|
|
||||||
logs = sorted(logs, key=lambda log: log["timestamp"])
|
|
||||||
|
|
||||||
# Send to cloudwatch
|
result = container.wait() or {}
|
||||||
self.log_stream_name = self._stream_name
|
exit_code = result.get("StatusCode", 0)
|
||||||
self._log_backend.ensure_log_group(self._log_group, None)
|
self.exit_code = exit_code
|
||||||
self._log_backend.create_log_stream(self._log_group, self._stream_name)
|
job_failed = self.stop or exit_code > 0
|
||||||
self._log_backend.put_log_events(
|
if job_failed:
|
||||||
self._log_group, self._stream_name, logs
|
self._mark_stopped(success=False)
|
||||||
)
|
break
|
||||||
|
|
||||||
result = container.wait() or {}
|
except Exception as err:
|
||||||
exit_code = result.get("StatusCode", 0)
|
logger.error(
|
||||||
self.exit_code = exit_code
|
f"Failed to run AWS Batch container {self.name}. Error {err}"
|
||||||
job_failed = self.stop or exit_code > 0
|
)
|
||||||
self._mark_stopped(success=not job_failed)
|
self._mark_stopped(success=False)
|
||||||
|
|
||||||
except Exception as err:
|
self._mark_stopped(success=True)
|
||||||
logger.error(
|
|
||||||
f"Failed to run AWS Batch container {self.name}. Error {err}"
|
|
||||||
)
|
|
||||||
self._mark_stopped(success=False)
|
|
||||||
container.kill()
|
|
||||||
finally:
|
|
||||||
container.remove()
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error(f"Failed to run AWS Batch container {self.name}. Error {err}")
|
logger.error(f"Failed to run AWS Batch container {self.name}. Error {err}")
|
||||||
self._mark_stopped(success=False)
|
self._mark_stopped(success=False)
|
||||||
|
finally:
|
||||||
|
for container in containers:
|
||||||
|
container.reload()
|
||||||
|
if container.status == "running":
|
||||||
|
container.kill()
|
||||||
|
container.remove()
|
||||||
|
|
||||||
def _mark_stopped(self, success: bool = True) -> None:
|
def _mark_stopped(self, success: bool = True) -> None:
|
||||||
|
if self.job_stopped:
|
||||||
|
return
|
||||||
# Ensure that job_stopped/job_stopped_at-attributes are set first
|
# Ensure that job_stopped/job_stopped_at-attributes are set first
|
||||||
# The describe-method needs them immediately when status is set
|
# The describe-method needs them immediately when status is set
|
||||||
self.job_stopped = True
|
self.job_stopped = True
|
||||||
@ -1437,6 +1536,7 @@ class BatchBackend(BaseBackend):
|
|||||||
tags: Dict[str, str],
|
tags: Dict[str, str],
|
||||||
retry_strategy: Dict[str, Any],
|
retry_strategy: Dict[str, Any],
|
||||||
container_properties: Dict[str, Any],
|
container_properties: Dict[str, Any],
|
||||||
|
node_properties: Dict[str, Any],
|
||||||
timeout: Dict[str, int],
|
timeout: Dict[str, int],
|
||||||
platform_capabilities: List[str],
|
platform_capabilities: List[str],
|
||||||
propagate_tags: bool,
|
propagate_tags: bool,
|
||||||
@ -1457,6 +1557,7 @@ class BatchBackend(BaseBackend):
|
|||||||
parameters,
|
parameters,
|
||||||
_type,
|
_type,
|
||||||
container_properties,
|
container_properties,
|
||||||
|
node_properties=node_properties,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
retry_strategy=retry_strategy,
|
retry_strategy=retry_strategy,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@ -1467,7 +1568,13 @@ class BatchBackend(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
# Make new jobdef
|
# Make new jobdef
|
||||||
job_def = job_def.update(
|
job_def = job_def.update(
|
||||||
parameters, _type, container_properties, retry_strategy, tags, timeout
|
parameters,
|
||||||
|
_type,
|
||||||
|
container_properties,
|
||||||
|
node_properties,
|
||||||
|
retry_strategy,
|
||||||
|
tags,
|
||||||
|
timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._job_definitions[job_def.arn] = job_def
|
self._job_definitions[job_def.arn] = job_def
|
||||||
@ -1562,7 +1669,11 @@ class BatchBackend(BaseBackend):
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
for key, job in self._jobs.items():
|
for key, job in self._jobs.items():
|
||||||
if len(job_filter) > 0 and key not in job_filter:
|
if (
|
||||||
|
len(job_filter) > 0
|
||||||
|
and key not in job_filter
|
||||||
|
and job.arn not in job_filter
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
result.append(job.describe())
|
result.append(job.describe())
|
||||||
@ -1570,7 +1681,10 @@ class BatchBackend(BaseBackend):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def list_jobs(
|
def list_jobs(
|
||||||
self, job_queue_name: str, job_status: Optional[str] = None
|
self,
|
||||||
|
job_queue_name: str,
|
||||||
|
job_status: Optional[str] = None,
|
||||||
|
filters: Optional[List[Dict[str, Any]]] = None,
|
||||||
) -> List[Job]:
|
) -> List[Job]:
|
||||||
"""
|
"""
|
||||||
Pagination is not yet implemented
|
Pagination is not yet implemented
|
||||||
@ -1598,6 +1712,18 @@ class BatchBackend(BaseBackend):
|
|||||||
if job_status is not None and job.status != job_status:
|
if job_status is not None and job.status != job_status:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if filters is not None:
|
||||||
|
matches = True
|
||||||
|
for filt in filters:
|
||||||
|
name = filt["name"]
|
||||||
|
values = filt["values"]
|
||||||
|
if name == "JOB_NAME":
|
||||||
|
if job.job_name not in values:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
if not matches:
|
||||||
|
continue
|
||||||
|
|
||||||
jobs.append(job)
|
jobs.append(job)
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
@ -134,6 +134,7 @@ class BatchResponse(BaseResponse):
|
|||||||
# RegisterJobDefinition
|
# RegisterJobDefinition
|
||||||
def registerjobdefinition(self) -> str:
|
def registerjobdefinition(self) -> str:
|
||||||
container_properties = self._get_param("containerProperties")
|
container_properties = self._get_param("containerProperties")
|
||||||
|
node_properties = self._get_param("nodeProperties")
|
||||||
def_name = self._get_param("jobDefinitionName")
|
def_name = self._get_param("jobDefinitionName")
|
||||||
parameters = self._get_param("parameters")
|
parameters = self._get_param("parameters")
|
||||||
tags = self._get_param("tags")
|
tags = self._get_param("tags")
|
||||||
@ -149,6 +150,7 @@ class BatchResponse(BaseResponse):
|
|||||||
tags=tags,
|
tags=tags,
|
||||||
retry_strategy=retry_strategy,
|
retry_strategy=retry_strategy,
|
||||||
container_properties=container_properties,
|
container_properties=container_properties,
|
||||||
|
node_properties=node_properties,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
platform_capabilities=platform_capabilities,
|
platform_capabilities=platform_capabilities,
|
||||||
propagate_tags=propagate_tags,
|
propagate_tags=propagate_tags,
|
||||||
@ -215,8 +217,9 @@ class BatchResponse(BaseResponse):
|
|||||||
def listjobs(self) -> str:
|
def listjobs(self) -> str:
|
||||||
job_queue = self._get_param("jobQueue")
|
job_queue = self._get_param("jobQueue")
|
||||||
job_status = self._get_param("jobStatus")
|
job_status = self._get_param("jobStatus")
|
||||||
|
filters = self._get_param("filters")
|
||||||
|
|
||||||
jobs = self.batch_backend.list_jobs(job_queue, job_status)
|
jobs = self.batch_backend.list_jobs(job_queue, job_status, filters)
|
||||||
|
|
||||||
result = {"jobSummaryList": [job.describe_short() for job in jobs]}
|
result = {"jobSummaryList": [job.describe_short() for job in jobs]}
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
@ -9,6 +9,10 @@ def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str:
|
|||||||
return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}"
|
return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}"
|
||||||
|
|
||||||
|
|
||||||
|
def make_arn_for_job(account_id: str, job_id: str, region_name: str) -> str:
    """Return the ARN of the Batch job ``job_id`` in the given account and region."""
    return "arn:aws:batch:{}:{}:job/{}".format(region_name, account_id, job_id)
|
||||||
|
|
||||||
|
|
||||||
def make_arn_for_task_def(
|
def make_arn_for_task_def(
|
||||||
account_id: str, name: str, revision: int, region_name: str
|
account_id: str, name: str, revision: int, region_name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -649,6 +649,15 @@ class LogsBackend(BaseBackend):
|
|||||||
log_group = self.groups[log_group_name]
|
log_group = self.groups[log_group_name]
|
||||||
return log_group.create_log_stream(log_stream_name)
|
return log_group.create_log_stream(log_stream_name)
|
||||||
|
|
||||||
|
def ensure_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
    """Create ``log_stream_name`` in the group unless it is already present.

    Raises:
        ResourceNotFoundException: if ``log_group_name`` is not a known group.
    """
    if log_group_name not in self.groups:
        raise ResourceNotFoundException()

    existing_streams = self.groups[log_group_name].streams
    if log_stream_name not in existing_streams:
        # Only streams that are not there yet go through the creation path.
        self.create_log_stream(log_group_name, log_stream_name)
|
||||||
|
|
||||||
def delete_log_stream(self, log_group_name, log_stream_name):
|
def delete_log_stream(self, log_group_name, log_stream_name):
|
||||||
if log_group_name not in self.groups:
|
if log_group_name not in self.groups:
|
||||||
raise ResourceNotFoundException()
|
raise ResourceNotFoundException()
|
||||||
|
@ -151,6 +151,72 @@ def test_submit_job():
|
|||||||
attempt.should.have.key("stoppedAt").equals(stopped_at)
|
attempt.should.have.key("stoppedAt").equals(stopped_at)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
@pytest.mark.network
def test_submit_job_multinode():
    """Submit a two-node job and verify its logs, timestamps and attempt data."""
    ec2_client, iam_client, _, logs_client, batch_client = _get_clients()
    _, _, _, iam_arn = _setup(ec2_client, iam_client)
    test_started_ms = time.time() * 1000

    job_def_name = str(uuid4())[0:6]
    job_def_arn, queue_arn = prepare_multinode_job(
        batch_client, ["echo", "hello"], iam_arn, job_def_name
    )

    submitted = batch_client.submit_job(
        jobName=str(uuid4())[0:6], jobQueue=queue_arn, jobDefinition=job_def_arn
    )
    job_id = submitted["jobId"]

    # describe_jobs() has to return 'createdAt' for a freshly submitted job
    # github.com/getmoto/moto/issues/4364
    described = batch_client.describe_jobs(jobs=[job_id])
    described["jobs"][0]["createdAt"].should.be.greater_than(test_started_ms)

    _wait_for_job_status(batch_client, job_id, "SUCCEEDED")

    log_streams = logs_client.describe_log_streams(
        logGroupName="/aws/batch/job", logStreamNamePrefix=job_def_name
    )["logStreams"]
    log_streams.should.have.length_of(1)

    log_events = logs_client.get_log_events(
        logGroupName="/aws/batch/job",
        logStreamName=log_streams[0]["logStreamName"],
    )["events"]
    # Each of the two nodes echoes "hello" once
    [event["message"] for event in log_events].should.equal(["hello", "hello"])

    # describe_jobs() has to return timestamps in milliseconds
    # github.com/getmoto/moto/issues/4364
    job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
    created_at = job["createdAt"]
    started_at = job["startedAt"]
    stopped_at = job["stoppedAt"]
    for timestamp in (created_at, started_at, stopped_at):
        timestamp.should.be.greater_than(test_started_ms)

    # Exactly one attempt must have been recorded for the job
    job.should.have.key("attempts").length_of(1)
    attempt = job["attempts"][0]
    attempt.should.have.key("container")
    attempt_container = attempt["container"]
    attempt_container.should.have.key("containerInstanceArn")
    attempt_container.should.have.key("logStreamName").equals(
        job["container"]["logStreamName"]
    )
    attempt_container.should.have.key("networkInterfaces")
    attempt_container.should.have.key("taskArn")
    attempt.should.have.key("startedAt").equals(started_at)
    attempt.should.have.key("stoppedAt").equals(stopped_at)
|
||||||
|
|
||||||
|
|
||||||
@mock_logs
|
@mock_logs
|
||||||
@mock_ec2
|
@mock_ec2
|
||||||
@mock_ecs
|
@mock_ecs
|
||||||
@ -205,6 +271,18 @@ def test_list_jobs():
|
|||||||
job.should.have.key("stoppedAt")
|
job.should.have.key("stoppedAt")
|
||||||
job.should.have.key("container").should.have.key("exitCode").equals(0)
|
job.should.have.key("container").should.have.key("exitCode").equals(0)
|
||||||
|
|
||||||
|
filtered_jobs = batch_client.list_jobs(
|
||||||
|
jobQueue=queue_arn,
|
||||||
|
filters=[
|
||||||
|
{
|
||||||
|
"name": "JOB_NAME",
|
||||||
|
"values": ["test2"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)["jobSummaryList"]
|
||||||
|
filtered_jobs.should.have.length_of(1)
|
||||||
|
filtered_jobs[0]["jobName"].should.equal("test2")
|
||||||
|
|
||||||
|
|
||||||
@mock_logs
|
@mock_logs
|
||||||
@mock_ec2
|
@mock_ec2
|
||||||
@ -772,6 +850,51 @@ def prepare_job(batch_client, commands, iam_arn, job_def_name):
|
|||||||
return job_def_arn, queue_arn
|
return job_def_arn, queue_arn
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_multinode_job(batch_client, commands, iam_arn, job_def_name):
    """Register a two-node ``multinode`` job definition plus the queue to run it.

    Creates an unmanaged compute environment, a job queue bound to it, and a
    job definition named ``job_def_name`` whose two nodes each run
    ``commands`` in a busybox container.

    Returns a ``(job_definition_arn, job_queue_arn)`` tuple.
    """
    compute_env = batch_client.create_compute_environment(
        computeEnvironmentName=str(uuid4())[0:6],
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    compute_env_arn = compute_env["computeEnvironmentArn"]

    queue = batch_client.create_job_queue(
        jobQueueName=str(uuid4())[0:6],
        state="ENABLED",
        priority=123,
        computeEnvironmentOrder=[
            {"order": 123, "computeEnvironment": compute_env_arn}
        ],
    )
    queue_arn = queue["jobQueueArn"]

    # Node 0 and node 1 share one container spec
    node_container = {
        "image": "busybox:latest",
        "vcpus": 1,
        "memory": 128,
        "command": commands,
    }
    node_ranges = [
        {"container": node_container, "targetNodes": str(node_index)}
        for node_index in range(2)
    ]

    registered = batch_client.register_job_definition(
        jobDefinitionName=job_def_name,
        type="multinode",
        nodeProperties={
            "mainNode": 0,
            "numNodes": 2,
            "nodeRangeProperties": node_ranges,
        },
    )
    return registered["jobDefinitionArn"], queue_arn
|
||||||
|
|
||||||
|
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_update_job_definition():
|
def test_update_job_definition():
|
||||||
_, _, _, _, batch_client = _get_clients()
|
_, _, _, _, batch_client = _get_clients()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user