Batch: add multinode support (#5840)

Author: Tristan Rice
Date:   2023-01-14 08:02:32 -08:00 (committed by GitHub)
Commit: a17956927f
Parent: 173e1549c0
5 changed files with 410 additions and 145 deletions
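
For context, a minimal sketch of the workflow this change enables, assuming moto is active (e.g. via the mock_batch/mock_iam/mock_ec2/mock_ecs/mock_logs decorators used in the tests below) and that a job queue ARN already exists; the names here are illustrative, not part of the commit:

    import boto3

    batch = boto3.client("batch", region_name="us-east-1")
    container = {"image": "busybox:latest", "vcpus": 1, "memory": 128, "command": ["echo", "hello"]}

    # "multinode" is now accepted alongside "container", and nodeProperties is stored on the definition.
    job_def = batch.register_job_definition(
        jobDefinitionName="multinode-example",
        type="multinode",
        nodeProperties={
            "mainNode": 0,
            "numNodes": 2,
            "nodeRangeProperties": [
                {"targetNodes": "0", "container": container},
                {"targetNodes": "1", "container": container},
            ],
        },
    )
    job = batch.submit_job(
        jobName="example",
        jobQueue=queue_arn,  # assumed to exist already
        jobDefinition=job_def["jobDefinitionArn"],
    )

    # describe_jobs now includes "jobArn" and, for multinode jobs, "nodeProperties",
    # and list_jobs accepts a JOB_NAME filter.
    described = batch.describe_jobs(jobs=[job["jobId"]])["jobs"][0]
    named = batch.list_jobs(
        jobQueue=queue_arn, filters=[{"name": "JOB_NAME", "values": ["example"]}]
    )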

View File

@@ -21,6 +21,7 @@ from .exceptions import InvalidParameterValueException, ClientException, Validat
 from .utils import (
     make_arn_for_compute_env,
     make_arn_for_job_queue,
+    make_arn_for_job,
     make_arn_for_task_def,
     lowercase_first_key,
 )
@@ -221,6 +222,7 @@ class JobDefinition(CloudFormationModel):
         parameters: Optional[Dict[str, Any]],
         _type: str,
         container_properties: Dict[str, Any],
+        node_properties: Dict[str, Any],
         tags: Dict[str, str],
         retry_strategy: Dict[str, str],
         timeout: Dict[str, int],
@@ -235,6 +237,7 @@ class JobDefinition(CloudFormationModel):
         self.revision = revision or 0
         self._region = backend.region_name
         self.container_properties = container_properties
+        self.node_properties = node_properties
         self.status = "ACTIVE"
         self.parameters = parameters or {}
         self.timeout = timeout
@@ -242,10 +245,11 @@ class JobDefinition(CloudFormationModel):
         self.platform_capabilities = platform_capabilities
         self.propagate_tags = propagate_tags

-        if "resourceRequirements" not in self.container_properties:
-            self.container_properties["resourceRequirements"] = []
-        if "secrets" not in self.container_properties:
-            self.container_properties["secrets"] = []
+        if self.container_properties is not None:
+            if "resourceRequirements" not in self.container_properties:
+                self.container_properties["resourceRequirements"] = []
+            if "secrets" not in self.container_properties:
+                self.container_properties["secrets"] = []

         self._validate()
         self.revision += 1
@@ -306,26 +310,28 @@ class JobDefinition(CloudFormationModel):
     def _validate(self) -> None:
         # For future use when containers arnt the only thing in batch
-        if self.type not in ("container",):
-            raise ClientException('type must be one of "container"')
+        VALID_TYPES = ("container", "multinode")
+        if self.type not in VALID_TYPES:
+            raise ClientException(f"type must be one of {VALID_TYPES}")

         if not isinstance(self.parameters, dict):
             raise ClientException("parameters must be a string to string map")

-        if "image" not in self.container_properties:
-            raise ClientException("containerProperties must contain image")
+        if self.type == "container":
+            if "image" not in self.container_properties:
+                raise ClientException("containerProperties must contain image")

-        memory = self._get_resource_requirement("memory")
-        if memory is None:
-            raise ClientException("containerProperties must contain memory")
-        if memory < 4:
-            raise ClientException("container memory limit must be greater than 4")
+            memory = self._get_resource_requirement("memory")
+            if memory is None:
+                raise ClientException("containerProperties must contain memory")
+            if memory < 4:
+                raise ClientException("container memory limit must be greater than 4")

-        vcpus = self._get_resource_requirement("vcpus")
-        if vcpus is None:
-            raise ClientException("containerProperties must contain vcpus")
-        if vcpus <= 0:
-            raise ClientException("container vcpus limit must be greater than 0")
+            vcpus = self._get_resource_requirement("vcpus")
+            if vcpus is None:
+                raise ClientException("containerProperties must contain vcpus")
+            if vcpus <= 0:
+                raise ClientException("container vcpus limit must be greater than 0")

     def deregister(self) -> None:
         self.status = "INACTIVE"
@@ -335,6 +341,7 @@ class JobDefinition(CloudFormationModel):
         parameters: Optional[Dict[str, Any]],
         _type: str,
         container_properties: Dict[str, Any],
+        node_properties: Dict[str, Any],
         retry_strategy: Dict[str, Any],
         tags: Dict[str, str],
         timeout: Dict[str, int],
@@ -357,6 +364,7 @@ class JobDefinition(CloudFormationModel):
             parameters,
             _type,
             container_properties,
+            node_properties=node_properties,
             revision=self.revision,
             retry_strategy=retry_strategy,
             tags=tags,
@@ -416,7 +424,16 @@ class JobDefinition(CloudFormationModel):
             _type="container",
             tags=lowercase_first_key(properties.get("Tags", {})),
             retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
-            container_properties=lowercase_first_key(properties["ContainerProperties"]),
+            container_properties=(
+                lowercase_first_key(properties["ContainerProperties"])
+                if "ContainerProperties" in properties
+                else None
+            ),
+            node_properties=(
+                lowercase_first_key(properties["NodeProperties"])
+                if "NodeProperties" in properties
+                else None
+            ),
             timeout=lowercase_first_key(properties.get("timeout", {})),
             platform_capabilities=None,
             propagate_tags=None,
@@ -466,6 +483,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
         self.timeout = timeout
         self.all_jobs = all_jobs

+        self.arn = make_arn_for_job(
+            job_def.backend.account_id, self.job_id, job_def._region
+        )
+
         self.stop = False
         self.exit_code: Optional[int] = None
@@ -483,6 +504,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
     def describe_short(self) -> Dict[str, Any]:
         result = {
             "jobId": self.job_id,
+            "jobArn": self.arn,
             "jobName": self.job_name,
             "createdAt": datetime2int_milliseconds(self.job_created_at),
             "status": self.status,
@@ -502,7 +524,13 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
         result = self.describe_short()
         result["jobQueue"] = self.job_queue.arn
         result["dependsOn"] = self.depends_on or []
-        result["container"] = self._container_details()
+        if self.job_definition.type == "container":
+            result["container"] = self._container_details()
+        elif self.job_definition.type == "multinode":
+            result["container"] = {
+                "logStreamName": self.log_stream_name,
+            }
+            result["nodeProperties"] = self.job_definition.node_properties
         if self.job_stopped:
             result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
         if self.timeout:
@@ -575,6 +603,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
         try:
             import docker

+            containers: List[docker.models.containers.Container] = []
+
             self.advance()
             while self.status == "SUBMITTED":
                 # Wait until we've moved onto state 'PENDING'
@@ -585,34 +615,87 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
             if self.depends_on and not self._wait_for_dependencies():
                 return

-            image = self.job_definition.container_properties.get(
-                "image", "alpine:latest"
-            )
-            privileged = self.job_definition.container_properties.get(
-                "privileged", False
-            )
-            cmd = self._get_container_property(
-                "command",
-                '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
-            )
-            environment = {
-                e["name"]: e["value"]
-                for e in self._get_container_property("environment", [])
-            }
-            volumes = {
-                v["name"]: v["host"]
-                for v in self._get_container_property("volumes", [])
-            }
-            mounts = [
-                docker.types.Mount(
-                    m["containerPath"],
-                    volumes[m["sourceVolume"]]["sourcePath"],
-                    type="bind",
-                    read_only=m["readOnly"],
-                )
-                for m in self._get_container_property("mountPoints", [])
-            ]
-            name = f"{self.job_name}-{self.job_id}"
+            container_kwargs = []
+            if self.job_definition.container_properties:
+                volumes = {
+                    v["name"]: v["host"]
+                    for v in self._get_container_property("volumes", [])
+                }
+                container_kwargs.append(
+                    {
+                        "image": self.job_definition.container_properties.get(
+                            "image", "alpine:latest"
+                        ),
+                        "privileged": self.job_definition.container_properties.get(
+                            "privileged", False
+                        ),
+                        "command": self._get_container_property(
+                            "command",
+                            '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
+                        ),
+                        "environment": {
+                            e["name"]: e["value"]
+                            for e in self._get_container_property("environment", [])
+                        },
+                        "mounts": [
+                            docker.types.Mount(
+                                m["containerPath"],
+                                volumes[m["sourceVolume"]]["sourcePath"],
+                                type="bind",
+                                read_only=m["readOnly"],
+                            )
+                            for m in self._get_container_property("mountPoints", [])
+                        ],
+                        "name": f"{self.job_name}-{self.job_id}",
+                    }
+                )
+            else:
+                node_properties = self.job_definition.node_properties
+                num_nodes = node_properties["numNodes"]
+                node_containers = {}
+                for node_range in node_properties["nodeRangeProperties"]:
+                    start, sep, end = node_range["targetNodes"].partition(":")
+                    if sep == "":
+                        start = end = int(start)
+                    else:
+                        if start == "":
+                            start = 0
+                        else:
+                            start = int(start)
+                        if end == "":
+                            end = num_nodes - 1
+                        else:
+                            end = int(end)
+                    for i in range(start, end + 1):
+                        node_containers[i] = node_range["container"]
+
+                for i in range(num_nodes):
+                    spec = node_containers[i]
+                    volumes = {v["name"]: v["host"] for v in spec.get("volumes", [])}
+                    container_kwargs.append(
+                        {
+                            "image": spec.get("image", "alpine:latest"),
+                            "privileged": spec.get("privileged", False),
+                            "command": spec.get(
+                                "command",
+                                '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
+                            ),
+                            "environment": {
+                                e["name"]: e["value"]
+                                for e in spec.get("environment", [])
+                            },
+                            "mounts": [
+                                docker.types.Mount(
+                                    m["containerPath"],
+                                    volumes[m["sourceVolume"]]["sourcePath"],
+                                    type="bind",
+                                    read_only=m["readOnly"],
+                                )
+                                for m in spec.get("mountPoints", [])
+                            ],
+                            "name": f"{self.job_name}-{self.job_id}-{i}",
+                        }
+                    )

             self.advance()
             while self.status == "PENDING":
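
As a reference for the targetNodes handling above: each entry's range is split with str.partition(":"), a bare index targets a single node, and an open-ended range falls back to node 0 or numNodes - 1. A standalone sketch of the same expansion (the helper name expand_target_nodes is illustrative, not part of the commit):

    def expand_target_nodes(target_nodes: str, num_nodes: int) -> range:
        # "0" -> only node 0; "1:" -> nodes 1..num_nodes-1; ":1" -> nodes 0..1; "1:3" -> nodes 1..3 (inclusive).
        start, sep, end = target_nodes.partition(":")
        if sep == "":
            return range(int(start), int(start) + 1)
        lo = int(start) if start else 0
        hi = int(end) if end else num_nodes - 1
        return range(lo, hi + 1)

    assert list(expand_target_nodes("0", 4)) == [0]
    assert list(expand_target_nodes("1:", 4)) == [1, 2, 3]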
@@ -632,19 +715,20 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
                 else {}
             )

-            environment["MOTO_HOST"] = settings.moto_server_host()
-            environment["MOTO_PORT"] = settings.moto_server_port()
-            environment[
-                "MOTO_HTTP_ENDPOINT"
-            ] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}'
-
-            run_kwargs = dict()
-            network_name = settings.moto_network_name()
-            network_mode = settings.moto_network_mode()
-            if network_name:
-                run_kwargs["network"] = network_name
-            elif network_mode:
-                run_kwargs["network_mode"] = network_mode
+            for kwargs in container_kwargs:
+                environment = kwargs["environment"]
+                environment["MOTO_HOST"] = settings.moto_server_host()
+                environment["MOTO_PORT"] = settings.moto_server_port()
+                environment[
+                    "MOTO_HTTP_ENDPOINT"
+                ] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}'
+
+                network_name = settings.moto_network_name()
+                network_mode = settings.moto_network_mode()
+                if network_name:
+                    kwargs["network"] = network_name
+                elif network_mode:
+                    kwargs["network_mode"] = network_mode

             log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
             self.advance()
@@ -656,107 +740,122 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
             while self.status == "STARTING":
                 # Wait until the state is no longer runnable, but 'RUNNING'
                 sleep(0.5)

-            container = self.docker_client.containers.run(
-                image,
-                cmd,
-                detach=True,
-                name=name,
-                log_config=log_config,
-                environment=environment,
-                mounts=mounts,
-                privileged=privileged,
-                extra_hosts=extra_hosts,
-                **run_kwargs,
-            )
-            try:
-                container.reload()
-
-                max_time = None
-                if self._get_attempt_duration():
-                    attempt_duration = self._get_attempt_duration()
-                    max_time = self.job_started_at + datetime.timedelta(
-                        seconds=attempt_duration  # type: ignore[arg-type]
-                    )
-
-                while container.status == "running" and not self.stop:
-                    container.reload()
-                    time.sleep(0.5)
-
-                    if max_time and datetime.datetime.now() > max_time:
-                        raise Exception(
-                            "Job time exceeded the configured attemptDurationSeconds"
-                        )
-
-                # Container should be stopped by this point... unless asked to stop
-                if container.status == "running":
-                    container.kill()
-
-                # Log collection
-                logs_stdout = []
-                logs_stderr = []
-                logs_stderr.extend(
-                    container.logs(
-                        stdout=False,
-                        stderr=True,
-                        timestamps=True,
-                        since=datetime2int(self.job_started_at),
-                    )
-                    .decode()
-                    .split("\n")
-                )
-                logs_stdout.extend(
-                    container.logs(
-                        stdout=True,
-                        stderr=False,
-                        timestamps=True,
-                        since=datetime2int(self.job_started_at),
-                    )
-                    .decode()
-                    .split("\n")
-                )
-
-                # Process logs
-                logs_stdout = [x for x in logs_stdout if len(x) > 0]
-                logs_stderr = [x for x in logs_stderr if len(x) > 0]
-                logs = []
-                for line in logs_stdout + logs_stderr:
-                    date, line = line.split(" ", 1)
-                    date_obj = (
-                        dateutil.parser.parse(date)
-                        .astimezone(datetime.timezone.utc)
-                        .replace(tzinfo=None)
-                    )
-                    date = unix_time_millis(date_obj)
-                    logs.append({"timestamp": date, "message": line.strip()})
-                logs = sorted(logs, key=lambda log: log["timestamp"])
-
-                # Send to cloudwatch
-                self.log_stream_name = self._stream_name
-                self._log_backend.ensure_log_group(self._log_group, None)
-                self._log_backend.create_log_stream(self._log_group, self._stream_name)
-                self._log_backend.put_log_events(
-                    self._log_group, self._stream_name, logs
-                )
-
-                result = container.wait() or {}
-                exit_code = result.get("StatusCode", 0)
-                self.exit_code = exit_code
-                job_failed = self.stop or exit_code > 0
-                self._mark_stopped(success=not job_failed)
-
-            except Exception as err:
-                logger.error(
-                    f"Failed to run AWS Batch container {self.name}. Error {err}"
-                )
-                self._mark_stopped(success=False)
-                container.kill()
-            finally:
-                container.remove()
+            for kwargs in container_kwargs:
+                if len(containers) > 0:
+                    env = kwargs["environment"]
+                    ip = containers[0].attrs["NetworkSettings"]["IPAddress"]
+                    env["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"] = ip
+                container = self.docker_client.containers.run(
+                    detach=True,
+                    log_config=log_config,
+                    extra_hosts=extra_hosts,
+                    **kwargs,
+                )
+                container.reload()
+                containers.append(container)
+
+            for i, container in enumerate(containers):
+                try:
+                    container.reload()
+
+                    max_time = None
+                    if self._get_attempt_duration():
+                        attempt_duration = self._get_attempt_duration()
+                        max_time = self.job_started_at + datetime.timedelta(
+                            seconds=attempt_duration  # type: ignore[arg-type]
+                        )
+
+                    while container.status == "running" and not self.stop:
+                        container.reload()
+                        time.sleep(0.5)
+
+                        if max_time and datetime.datetime.now() > max_time:
+                            raise Exception(
+                                "Job time exceeded the configured attemptDurationSeconds"
+                            )
+
+                    # Container should be stopped by this point... unless asked to stop
+                    if container.status == "running":
+                        container.kill()
+
+                    # Log collection
+                    logs_stdout = []
+                    logs_stderr = []
+                    logs_stderr.extend(
+                        container.logs(
+                            stdout=False,
+                            stderr=True,
+                            timestamps=True,
+                            since=datetime2int(self.job_started_at),
+                        )
+                        .decode()
+                        .split("\n")
+                    )
+                    logs_stdout.extend(
+                        container.logs(
+                            stdout=True,
+                            stderr=False,
+                            timestamps=True,
+                            since=datetime2int(self.job_started_at),
+                        )
+                        .decode()
+                        .split("\n")
+                    )
+
+                    # Process logs
+                    logs_stdout = [x for x in logs_stdout if len(x) > 0]
+                    logs_stderr = [x for x in logs_stderr if len(x) > 0]
+                    logs = []
+                    for line in logs_stdout + logs_stderr:
+                        date, line = line.split(" ", 1)
+                        date_obj = (
+                            dateutil.parser.parse(date)
+                            .astimezone(datetime.timezone.utc)
+                            .replace(tzinfo=None)
+                        )
+                        date = unix_time_millis(date_obj)
+                        logs.append({"timestamp": date, "message": line.strip()})
+                    logs = sorted(logs, key=lambda log: log["timestamp"])
+
+                    # Send to cloudwatch
+                    self.log_stream_name = self._stream_name
+                    self._log_backend.ensure_log_group(self._log_group, None)
+                    self._log_backend.ensure_log_stream(
+                        self._log_group, self.log_stream_name
+                    )
+                    self._log_backend.put_log_events(
+                        self._log_group, self.log_stream_name, logs
+                    )
+
+                    result = container.wait() or {}
+                    exit_code = result.get("StatusCode", 0)
+                    self.exit_code = exit_code
+                    job_failed = self.stop or exit_code > 0
+                    if job_failed:
+                        self._mark_stopped(success=False)
+                        break
+
+                except Exception as err:
+                    logger.error(
+                        f"Failed to run AWS Batch container {self.name}. Error {err}"
+                    )
+                    self._mark_stopped(success=False)
+
+            self._mark_stopped(success=True)
         except Exception as err:
             logger.error(f"Failed to run AWS Batch container {self.name}. Error {err}")
             self._mark_stopped(success=False)
+        finally:
+            for container in containers:
+                container.reload()
+                if container.status == "running":
+                    container.kill()
+                container.remove()

     def _mark_stopped(self, success: bool = True) -> None:
+        if self.job_stopped:
+            return
         # Ensure that job_stopped/job_stopped_at-attributes are set first
         # The describe-method needs them immediately when status is set
         self.job_stopped = True
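
Note that the first loop above starts the containers in node order and, for every node after the first, injects the node-0 container's IP address as AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS, mirroring the variable real AWS Batch provides to multi-node jobs. A hypothetical container spec that consumes it might look like this (illustrative only, not part of the commit):

    container = {
        "image": "busybox:latest",
        "vcpus": 1,
        "memory": 128,
        # Non-main nodes can reach the main node via the injected address.
        "command": ["sh", "-c", "echo main node is at $AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"],
    }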
@@ -1437,6 +1536,7 @@ class BatchBackend(BaseBackend):
         tags: Dict[str, str],
         retry_strategy: Dict[str, Any],
         container_properties: Dict[str, Any],
+        node_properties: Dict[str, Any],
         timeout: Dict[str, int],
         platform_capabilities: List[str],
         propagate_tags: bool,
@@ -1457,6 +1557,7 @@ class BatchBackend(BaseBackend):
                 parameters,
                 _type,
                 container_properties,
+                node_properties=node_properties,
                 tags=tags,
                 retry_strategy=retry_strategy,
                 timeout=timeout,
@@ -1467,7 +1568,13 @@ class BatchBackend(BaseBackend):
         else:
             # Make new jobdef
             job_def = job_def.update(
-                parameters, _type, container_properties, retry_strategy, tags, timeout
+                parameters,
+                _type,
+                container_properties,
+                node_properties,
+                retry_strategy,
+                tags,
+                timeout,
             )

         self._job_definitions[job_def.arn] = job_def
@@ -1562,7 +1669,11 @@ class BatchBackend(BaseBackend):
         result = []
         for key, job in self._jobs.items():
-            if len(job_filter) > 0 and key not in job_filter:
+            if (
+                len(job_filter) > 0
+                and key not in job_filter
+                and job.arn not in job_filter
+            ):
                 continue

             result.append(job.describe())
@@ -1570,7 +1681,10 @@ class BatchBackend(BaseBackend):
         return result

     def list_jobs(
-        self, job_queue_name: str, job_status: Optional[str] = None
+        self,
+        job_queue_name: str,
+        job_status: Optional[str] = None,
+        filters: Optional[List[Dict[str, Any]]] = None,
     ) -> List[Job]:
         """
         Pagination is not yet implemented
@@ -1598,6 +1712,18 @@ class BatchBackend(BaseBackend):
             if job_status is not None and job.status != job_status:
                 continue

+            if filters is not None:
+                matches = True
+                for filt in filters:
+                    name = filt["name"]
+                    values = filt["values"]
+                    if name == "JOB_NAME":
+                        if job.job_name not in values:
+                            matches = False
+                            break
+                if not matches:
+                    continue
+
             jobs.append(job)

         return jobs

View File

@@ -134,6 +134,7 @@ class BatchResponse(BaseResponse):
     # RegisterJobDefinition
     def registerjobdefinition(self) -> str:
         container_properties = self._get_param("containerProperties")
+        node_properties = self._get_param("nodeProperties")
         def_name = self._get_param("jobDefinitionName")
         parameters = self._get_param("parameters")
         tags = self._get_param("tags")
@@ -149,6 +150,7 @@ class BatchResponse(BaseResponse):
             tags=tags,
             retry_strategy=retry_strategy,
             container_properties=container_properties,
+            node_properties=node_properties,
             timeout=timeout,
             platform_capabilities=platform_capabilities,
             propagate_tags=propagate_tags,
@@ -215,8 +217,9 @@ class BatchResponse(BaseResponse):
     def listjobs(self) -> str:
         job_queue = self._get_param("jobQueue")
         job_status = self._get_param("jobStatus")
+        filters = self._get_param("filters")

-        jobs = self.batch_backend.list_jobs(job_queue, job_status)
+        jobs = self.batch_backend.list_jobs(job_queue, job_status, filters)

         result = {"jobSummaryList": [job.describe_short() for job in jobs]}
         return json.dumps(result)

View File

@@ -9,6 +9,10 @@ def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str:
     return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}"


+def make_arn_for_job(account_id: str, job_id: str, region_name: str) -> str:
+    return f"arn:aws:batch:{region_name}:{account_id}:job/{job_id}"
+
+
 def make_arn_for_task_def(
     account_id: str, name: str, revision: int, region_name: str
 ) -> str:

View File

@@ -649,6 +649,15 @@ class LogsBackend(BaseBackend):
         log_group = self.groups[log_group_name]
         return log_group.create_log_stream(log_stream_name)

+    def ensure_log_stream(self, log_group_name: str, log_stream_name: str) -> None:
+        if log_group_name not in self.groups:
+            raise ResourceNotFoundException()
+
+        if log_stream_name in self.groups[log_group_name].streams:
+            return
+
+        self.create_log_stream(log_group_name, log_stream_name)
+
     def delete_log_stream(self, log_group_name, log_stream_name):
         if log_group_name not in self.groups:
             raise ResourceNotFoundException()

View File

@@ -151,6 +151,72 @@ def test_submit_job():
     attempt.should.have.key("stoppedAt").equals(stopped_at)


+@mock_logs
+@mock_ec2
+@mock_ecs
+@mock_iam
+@mock_batch
+@pytest.mark.network
+def test_submit_job_multinode():
+    ec2_client, iam_client, _, logs_client, batch_client = _get_clients()
+    _, _, _, iam_arn = _setup(ec2_client, iam_client)
+    start_time_milliseconds = time.time() * 1000
+
+    job_def_name = str(uuid4())[0:6]
+    commands = ["echo", "hello"]
+    job_def_arn, queue_arn = prepare_multinode_job(
+        batch_client, commands, iam_arn, job_def_name
+    )
+
+    resp = batch_client.submit_job(
+        jobName=str(uuid4())[0:6], jobQueue=queue_arn, jobDefinition=job_def_arn
+    )
+    job_id = resp["jobId"]
+
+    # Test that describe_jobs() returns 'createdAt'
+    # github.com/getmoto/moto/issues/4364
+    resp = batch_client.describe_jobs(jobs=[job_id])
+    created_at = resp["jobs"][0]["createdAt"]
+    created_at.should.be.greater_than(start_time_milliseconds)
+
+    _wait_for_job_status(batch_client, job_id, "SUCCEEDED")
+
+    resp = logs_client.describe_log_streams(
+        logGroupName="/aws/batch/job", logStreamNamePrefix=job_def_name
+    )
+    resp["logStreams"].should.have.length_of(1)
+    ls_name = resp["logStreams"][0]["logStreamName"]
+
+    resp = logs_client.get_log_events(
+        logGroupName="/aws/batch/job", logStreamName=ls_name
+    )
+    [event["message"] for event in resp["events"]].should.equal(["hello", "hello"])
+
+    # Test that describe_jobs() returns timestamps in milliseconds
+    # github.com/getmoto/moto/issues/4364
+    job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0]
+    created_at = job["createdAt"]
+    started_at = job["startedAt"]
+    stopped_at = job["stoppedAt"]
+    created_at.should.be.greater_than(start_time_milliseconds)
+    started_at.should.be.greater_than(start_time_milliseconds)
+    stopped_at.should.be.greater_than(start_time_milliseconds)
+
+    # Verify we track attempts
+    job.should.have.key("attempts").length_of(1)
+    attempt = job["attempts"][0]
+    attempt.should.have.key("container")
+    attempt["container"].should.have.key("containerInstanceArn")
+    attempt["container"].should.have.key("logStreamName").equals(
+        job["container"]["logStreamName"]
+    )
+    attempt["container"].should.have.key("networkInterfaces")
+    attempt["container"].should.have.key("taskArn")
+    attempt.should.have.key("startedAt").equals(started_at)
+    attempt.should.have.key("stoppedAt").equals(stopped_at)
+
+
 @mock_logs
 @mock_ec2
 @mock_ecs
@@ -205,6 +271,18 @@ def test_list_jobs():
         job.should.have.key("stoppedAt")
         job.should.have.key("container").should.have.key("exitCode").equals(0)

+    filtered_jobs = batch_client.list_jobs(
+        jobQueue=queue_arn,
+        filters=[
+            {
+                "name": "JOB_NAME",
+                "values": ["test2"],
+            }
+        ],
+    )["jobSummaryList"]
+    filtered_jobs.should.have.length_of(1)
+    filtered_jobs[0]["jobName"].should.equal("test2")
+
@mock_logs @mock_logs
@mock_ec2 @mock_ec2
@@ -772,6 +850,51 @@ def prepare_job(batch_client, commands, iam_arn, job_def_name):
     return job_def_arn, queue_arn


+def prepare_multinode_job(batch_client, commands, iam_arn, job_def_name):
+    compute_name = str(uuid4())[0:6]
+    resp = batch_client.create_compute_environment(
+        computeEnvironmentName=compute_name,
+        type="UNMANAGED",
+        state="ENABLED",
+        serviceRole=iam_arn,
+    )
+    arn = resp["computeEnvironmentArn"]
+
+    resp = batch_client.create_job_queue(
+        jobQueueName=str(uuid4())[0:6],
+        state="ENABLED",
+        priority=123,
+        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
+    )
+    queue_arn = resp["jobQueueArn"]
+
+    container = {
+        "image": "busybox:latest",
+        "vcpus": 1,
+        "memory": 128,
+        "command": commands,
+    }
+    resp = batch_client.register_job_definition(
+        jobDefinitionName=job_def_name,
+        type="multinode",
+        nodeProperties={
+            "mainNode": 0,
+            "numNodes": 2,
+            "nodeRangeProperties": [
+                {
+                    "container": container,
+                    "targetNodes": "0",
+                },
+                {
+                    "container": container,
+                    "targetNodes": "1",
+                },
+            ],
+        },
+    )
+    job_def_arn = resp["jobDefinitionArn"]
+
+    return job_def_arn, queue_arn
+
+
 @mock_batch
 def test_update_job_definition():
     _, _, _, _, batch_client = _get_clients()