diff --git a/moto/batch/models.py b/moto/batch/models.py index 01c4a71fd..09d3fd3c9 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -21,6 +21,7 @@ from .exceptions import InvalidParameterValueException, ClientException, Validat from .utils import ( make_arn_for_compute_env, make_arn_for_job_queue, + make_arn_for_job, make_arn_for_task_def, lowercase_first_key, ) @@ -221,6 +222,7 @@ class JobDefinition(CloudFormationModel): parameters: Optional[Dict[str, Any]], _type: str, container_properties: Dict[str, Any], + node_properties: Dict[str, Any], tags: Dict[str, str], retry_strategy: Dict[str, str], timeout: Dict[str, int], @@ -235,6 +237,7 @@ class JobDefinition(CloudFormationModel): self.revision = revision or 0 self._region = backend.region_name self.container_properties = container_properties + self.node_properties = node_properties self.status = "ACTIVE" self.parameters = parameters or {} self.timeout = timeout @@ -242,10 +245,11 @@ class JobDefinition(CloudFormationModel): self.platform_capabilities = platform_capabilities self.propagate_tags = propagate_tags - if "resourceRequirements" not in self.container_properties: - self.container_properties["resourceRequirements"] = [] - if "secrets" not in self.container_properties: - self.container_properties["secrets"] = [] + if self.container_properties is not None: + if "resourceRequirements" not in self.container_properties: + self.container_properties["resourceRequirements"] = [] + if "secrets" not in self.container_properties: + self.container_properties["secrets"] = [] self._validate() self.revision += 1 @@ -306,26 +310,28 @@ class JobDefinition(CloudFormationModel): def _validate(self) -> None: # For future use when containers arnt the only thing in batch - if self.type not in ("container",): - raise ClientException('type must be one of "container"') + VALID_TYPES = ("container", "multinode") + if self.type not in VALID_TYPES: + raise ClientException(f"type must be one of {VALID_TYPES}") if not isinstance(self.parameters, dict): raise ClientException("parameters must be a string to string map") - if "image" not in self.container_properties: - raise ClientException("containerProperties must contain image") + if self.type == "container": + if "image" not in self.container_properties: + raise ClientException("containerProperties must contain image") - memory = self._get_resource_requirement("memory") - if memory is None: - raise ClientException("containerProperties must contain memory") - if memory < 4: - raise ClientException("container memory limit must be greater than 4") + memory = self._get_resource_requirement("memory") + if memory is None: + raise ClientException("containerProperties must contain memory") + if memory < 4: + raise ClientException("container memory limit must be greater than 4") - vcpus = self._get_resource_requirement("vcpus") - if vcpus is None: - raise ClientException("containerProperties must contain vcpus") - if vcpus <= 0: - raise ClientException("container vcpus limit must be greater than 0") + vcpus = self._get_resource_requirement("vcpus") + if vcpus is None: + raise ClientException("containerProperties must contain vcpus") + if vcpus <= 0: + raise ClientException("container vcpus limit must be greater than 0") def deregister(self) -> None: self.status = "INACTIVE" @@ -335,6 +341,7 @@ class JobDefinition(CloudFormationModel): parameters: Optional[Dict[str, Any]], _type: str, container_properties: Dict[str, Any], + node_properties: Dict[str, Any], retry_strategy: Dict[str, Any], tags: Dict[str, str], timeout: Dict[str, int], @@ -357,6 +364,7 @@ class JobDefinition(CloudFormationModel): parameters, _type, container_properties, + node_properties=node_properties, revision=self.revision, retry_strategy=retry_strategy, tags=tags, @@ -416,7 +424,16 @@ class JobDefinition(CloudFormationModel): _type="container", tags=lowercase_first_key(properties.get("Tags", {})), retry_strategy=lowercase_first_key(properties["RetryStrategy"]), - container_properties=lowercase_first_key(properties["ContainerProperties"]), + container_properties=( + lowercase_first_key(properties["ContainerProperties"]) + if "ContainerProperties" in properties + else None + ), + node_properties=( + lowercase_first_key(properties["NodeProperties"]) + if "NodeProperties" in properties + else None + ), timeout=lowercase_first_key(properties.get("timeout", {})), platform_capabilities=None, propagate_tags=None, @@ -466,6 +483,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): self.timeout = timeout self.all_jobs = all_jobs + self.arn = make_arn_for_job( + job_def.backend.account_id, self.job_id, job_def._region + ) + self.stop = False self.exit_code: Optional[int] = None @@ -483,6 +504,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): def describe_short(self) -> Dict[str, Any]: result = { "jobId": self.job_id, + "jobArn": self.arn, "jobName": self.job_name, "createdAt": datetime2int_milliseconds(self.job_created_at), "status": self.status, @@ -502,7 +524,13 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): result = self.describe_short() result["jobQueue"] = self.job_queue.arn result["dependsOn"] = self.depends_on or [] - result["container"] = self._container_details() + if self.job_definition.type == "container": + result["container"] = self._container_details() + elif self.job_definition.type == "multinode": + result["container"] = { + "logStreamName": self.log_stream_name, + } + result["nodeProperties"] = self.job_definition.node_properties if self.job_stopped: result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at) if self.timeout: @@ -575,6 +603,8 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): try: import docker + containers: List[docker.models.containers.Container] = [] + self.advance() while self.status == "SUBMITTED": # Wait until we've moved onto state 'PENDING' @@ -585,34 +615,87 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): if self.depends_on and not self._wait_for_dependencies(): return - image = self.job_definition.container_properties.get( - "image", "alpine:latest" - ) - privileged = self.job_definition.container_properties.get( - "privileged", False - ) - cmd = self._get_container_property( - "command", - '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"', - ) - environment = { - e["name"]: e["value"] - for e in self._get_container_property("environment", []) - } - volumes = { - v["name"]: v["host"] - for v in self._get_container_property("volumes", []) - } - mounts = [ - docker.types.Mount( - m["containerPath"], - volumes[m["sourceVolume"]]["sourcePath"], - type="bind", - read_only=m["readOnly"], + container_kwargs = [] + if self.job_definition.container_properties: + volumes = { + v["name"]: v["host"] + for v in self._get_container_property("volumes", []) + } + container_kwargs.append( + { + "image": self.job_definition.container_properties.get( + "image", "alpine:latest" + ), + "privileged": self.job_definition.container_properties.get( + "privileged", False + ), + "command": self._get_container_property( + "command", + '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"', + ), + "environment": { + e["name"]: e["value"] + for e in self._get_container_property("environment", []) + }, + "mounts": [ + docker.types.Mount( + m["containerPath"], + volumes[m["sourceVolume"]]["sourcePath"], + type="bind", + read_only=m["readOnly"], + ) + for m in self._get_container_property("mountPoints", []) + ], + "name": f"{self.job_name}-{self.job_id}", + } ) - for m in self._get_container_property("mountPoints", []) - ] - name = f"{self.job_name}-{self.job_id}" + else: + node_properties = self.job_definition.node_properties + num_nodes = node_properties["numNodes"] + node_containers = {} + for node_range in node_properties["nodeRangeProperties"]: + start, sep, end = node_range["targetNodes"].partition(":") + if sep == "": + start = end = int(start) + else: + if start == "": + start = 0 + else: + start = int(start) + if end == "": + end = num_nodes - 1 + else: + end = int(end) + for i in range(start, end + 1): + node_containers[i] = node_range["container"] + + for i in range(num_nodes): + spec = node_containers[i] + volumes = {v["name"]: v["host"] for v in spec.get("volumes", [])} + container_kwargs.append( + { + "image": spec.get("image", "alpine:latest"), + "privileged": spec.get("privileged", False), + "command": spec.get( + "command", + '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"', + ), + "environment": { + e["name"]: e["value"] + for e in spec.get("environment", []) + }, + "mounts": [ + docker.types.Mount( + m["containerPath"], + volumes[m["sourceVolume"]]["sourcePath"], + type="bind", + read_only=m["readOnly"], + ) + for m in spec.get("mountPoints", []) + ], + "name": f"{self.job_name}-{self.job_id}-{i}", + } + ) self.advance() while self.status == "PENDING": @@ -632,19 +715,20 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): else {} ) - environment["MOTO_HOST"] = settings.moto_server_host() - environment["MOTO_PORT"] = settings.moto_server_port() - environment[ - "MOTO_HTTP_ENDPOINT" - ] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}' + for kwargs in container_kwargs: + environment = kwargs["environment"] + environment["MOTO_HOST"] = settings.moto_server_host() + environment["MOTO_PORT"] = settings.moto_server_port() + environment[ + "MOTO_HTTP_ENDPOINT" + ] = f'{environment["MOTO_HOST"]}:{environment["MOTO_PORT"]}' - run_kwargs = dict() - network_name = settings.moto_network_name() - network_mode = settings.moto_network_mode() - if network_name: - run_kwargs["network"] = network_name - elif network_mode: - run_kwargs["network_mode"] = network_mode + network_name = settings.moto_network_name() + network_mode = settings.moto_network_mode() + if network_name: + kwargs["network"] = network_name + elif network_mode: + kwargs["network_mode"] = network_mode log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) self.advance() @@ -656,107 +740,122 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): while self.status == "STARTING": # Wait until the state is no longer runnable, but 'RUNNING' sleep(0.5) - container = self.docker_client.containers.run( - image, - cmd, - detach=True, - name=name, - log_config=log_config, - environment=environment, - mounts=mounts, - privileged=privileged, - extra_hosts=extra_hosts, - **run_kwargs, - ) - try: + + for kwargs in container_kwargs: + if len(containers) > 0: + env = kwargs["environment"] + ip = containers[0].attrs["NetworkSettings"]["IPAddress"] + env["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"] = ip + container = self.docker_client.containers.run( + detach=True, + log_config=log_config, + extra_hosts=extra_hosts, + **kwargs, + ) container.reload() + containers.append(container) - max_time = None - if self._get_attempt_duration(): - attempt_duration = self._get_attempt_duration() - max_time = self.job_started_at + datetime.timedelta( - seconds=attempt_duration # type: ignore[arg-type] - ) - - while container.status == "running" and not self.stop: + for i, container in enumerate(containers): + try: container.reload() - time.sleep(0.5) - if max_time and datetime.datetime.now() > max_time: - raise Exception( - "Job time exceeded the configured attemptDurationSeconds" + max_time = None + if self._get_attempt_duration(): + attempt_duration = self._get_attempt_duration() + max_time = self.job_started_at + datetime.timedelta( + seconds=attempt_duration # type: ignore[arg-type] ) - # Container should be stopped by this point... unless asked to stop - if container.status == "running": - container.kill() + while container.status == "running" and not self.stop: + container.reload() + time.sleep(0.5) - # Log collection - logs_stdout = [] - logs_stderr = [] - logs_stderr.extend( - container.logs( - stdout=False, - stderr=True, - timestamps=True, - since=datetime2int(self.job_started_at), + if max_time and datetime.datetime.now() > max_time: + raise Exception( + "Job time exceeded the configured attemptDurationSeconds" + ) + + # Container should be stopped by this point... unless asked to stop + if container.status == "running": + container.kill() + + # Log collection + logs_stdout = [] + logs_stderr = [] + logs_stderr.extend( + container.logs( + stdout=False, + stderr=True, + timestamps=True, + since=datetime2int(self.job_started_at), + ) + .decode() + .split("\n") ) - .decode() - .split("\n") - ) - logs_stdout.extend( - container.logs( - stdout=True, - stderr=False, - timestamps=True, - since=datetime2int(self.job_started_at), + logs_stdout.extend( + container.logs( + stdout=True, + stderr=False, + timestamps=True, + since=datetime2int(self.job_started_at), + ) + .decode() + .split("\n") ) - .decode() - .split("\n") - ) - # Process logs - logs_stdout = [x for x in logs_stdout if len(x) > 0] - logs_stderr = [x for x in logs_stderr if len(x) > 0] - logs = [] - for line in logs_stdout + logs_stderr: - date, line = line.split(" ", 1) - date_obj = ( - dateutil.parser.parse(date) - .astimezone(datetime.timezone.utc) - .replace(tzinfo=None) + # Process logs + logs_stdout = [x for x in logs_stdout if len(x) > 0] + logs_stderr = [x for x in logs_stderr if len(x) > 0] + logs = [] + for line in logs_stdout + logs_stderr: + date, line = line.split(" ", 1) + date_obj = ( + dateutil.parser.parse(date) + .astimezone(datetime.timezone.utc) + .replace(tzinfo=None) + ) + date = unix_time_millis(date_obj) + logs.append({"timestamp": date, "message": line.strip()}) + logs = sorted(logs, key=lambda log: log["timestamp"]) + + # Send to cloudwatch + self.log_stream_name = self._stream_name + self._log_backend.ensure_log_group(self._log_group, None) + self._log_backend.ensure_log_stream( + self._log_group, self.log_stream_name + ) + self._log_backend.put_log_events( + self._log_group, self.log_stream_name, logs ) - date = unix_time_millis(date_obj) - logs.append({"timestamp": date, "message": line.strip()}) - logs = sorted(logs, key=lambda log: log["timestamp"]) - # Send to cloudwatch - self.log_stream_name = self._stream_name - self._log_backend.ensure_log_group(self._log_group, None) - self._log_backend.create_log_stream(self._log_group, self._stream_name) - self._log_backend.put_log_events( - self._log_group, self._stream_name, logs - ) + result = container.wait() or {} + exit_code = result.get("StatusCode", 0) + self.exit_code = exit_code + job_failed = self.stop or exit_code > 0 + if job_failed: + self._mark_stopped(success=False) + break - result = container.wait() or {} - exit_code = result.get("StatusCode", 0) - self.exit_code = exit_code - job_failed = self.stop or exit_code > 0 - self._mark_stopped(success=not job_failed) + except Exception as err: + logger.error( + f"Failed to run AWS Batch container {self.name}. Error {err}" + ) + self._mark_stopped(success=False) - except Exception as err: - logger.error( - f"Failed to run AWS Batch container {self.name}. Error {err}" - ) - self._mark_stopped(success=False) - container.kill() - finally: - container.remove() + self._mark_stopped(success=True) except Exception as err: logger.error(f"Failed to run AWS Batch container {self.name}. Error {err}") self._mark_stopped(success=False) + finally: + for container in containers: + container.reload() + if container.status == "running": + container.kill() + container.remove() def _mark_stopped(self, success: bool = True) -> None: + if self.job_stopped: + return # Ensure that job_stopped/job_stopped_at-attributes are set first # The describe-method needs them immediately when status is set self.job_stopped = True @@ -1437,6 +1536,7 @@ class BatchBackend(BaseBackend): tags: Dict[str, str], retry_strategy: Dict[str, Any], container_properties: Dict[str, Any], + node_properties: Dict[str, Any], timeout: Dict[str, int], platform_capabilities: List[str], propagate_tags: bool, @@ -1457,6 +1557,7 @@ class BatchBackend(BaseBackend): parameters, _type, container_properties, + node_properties=node_properties, tags=tags, retry_strategy=retry_strategy, timeout=timeout, @@ -1467,7 +1568,13 @@ class BatchBackend(BaseBackend): else: # Make new jobdef job_def = job_def.update( - parameters, _type, container_properties, retry_strategy, tags, timeout + parameters, + _type, + container_properties, + node_properties, + retry_strategy, + tags, + timeout, ) self._job_definitions[job_def.arn] = job_def @@ -1562,7 +1669,11 @@ class BatchBackend(BaseBackend): result = [] for key, job in self._jobs.items(): - if len(job_filter) > 0 and key not in job_filter: + if ( + len(job_filter) > 0 + and key not in job_filter + and job.arn not in job_filter + ): continue result.append(job.describe()) @@ -1570,7 +1681,10 @@ class BatchBackend(BaseBackend): return result def list_jobs( - self, job_queue_name: str, job_status: Optional[str] = None + self, + job_queue_name: str, + job_status: Optional[str] = None, + filters: Optional[List[Dict[str, Any]]] = None, ) -> List[Job]: """ Pagination is not yet implemented @@ -1598,6 +1712,18 @@ class BatchBackend(BaseBackend): if job_status is not None and job.status != job_status: continue + if filters is not None: + matches = True + for filt in filters: + name = filt["name"] + values = filt["values"] + if name == "JOB_NAME": + if job.job_name not in values: + matches = False + break + if not matches: + continue + jobs.append(job) return jobs diff --git a/moto/batch/responses.py b/moto/batch/responses.py index be4375dd5..1e42b6c76 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -134,6 +134,7 @@ class BatchResponse(BaseResponse): # RegisterJobDefinition def registerjobdefinition(self) -> str: container_properties = self._get_param("containerProperties") + node_properties = self._get_param("nodeProperties") def_name = self._get_param("jobDefinitionName") parameters = self._get_param("parameters") tags = self._get_param("tags") @@ -149,6 +150,7 @@ class BatchResponse(BaseResponse): tags=tags, retry_strategy=retry_strategy, container_properties=container_properties, + node_properties=node_properties, timeout=timeout, platform_capabilities=platform_capabilities, propagate_tags=propagate_tags, @@ -215,8 +217,9 @@ class BatchResponse(BaseResponse): def listjobs(self) -> str: job_queue = self._get_param("jobQueue") job_status = self._get_param("jobStatus") + filters = self._get_param("filters") - jobs = self.batch_backend.list_jobs(job_queue, job_status) + jobs = self.batch_backend.list_jobs(job_queue, job_status, filters) result = {"jobSummaryList": [job.describe_short() for job in jobs]} return json.dumps(result) diff --git a/moto/batch/utils.py b/moto/batch/utils.py index c565f5636..6df059dca 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -9,6 +9,10 @@ def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str: return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}" +def make_arn_for_job(account_id: str, job_id: str, region_name: str) -> str: + return f"arn:aws:batch:{region_name}:{account_id}:job/{job_id}" + + def make_arn_for_task_def( account_id: str, name: str, revision: int, region_name: str ) -> str: diff --git a/moto/logs/models.py b/moto/logs/models.py index 581899503..5e241c9c4 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -649,6 +649,15 @@ class LogsBackend(BaseBackend): log_group = self.groups[log_group_name] return log_group.create_log_stream(log_stream_name) + def ensure_log_stream(self, log_group_name: str, log_stream_name: str) -> None: + if log_group_name not in self.groups: + raise ResourceNotFoundException() + + if log_stream_name in self.groups[log_group_name].streams: + return + + self.create_log_stream(log_group_name, log_stream_name) + def delete_log_stream(self, log_group_name, log_stream_name): if log_group_name not in self.groups: raise ResourceNotFoundException() diff --git a/tests/test_batch/test_batch_jobs.py b/tests/test_batch/test_batch_jobs.py index 05d93898f..a521810c2 100644 --- a/tests/test_batch/test_batch_jobs.py +++ b/tests/test_batch/test_batch_jobs.py @@ -151,6 +151,72 @@ def test_submit_job(): attempt.should.have.key("stoppedAt").equals(stopped_at) +@mock_logs +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +@pytest.mark.network +def test_submit_job_multinode(): + ec2_client, iam_client, _, logs_client, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + start_time_milliseconds = time.time() * 1000 + + job_def_name = str(uuid4())[0:6] + commands = ["echo", "hello"] + job_def_arn, queue_arn = prepare_multinode_job( + batch_client, commands, iam_arn, job_def_name + ) + + resp = batch_client.submit_job( + jobName=str(uuid4())[0:6], jobQueue=queue_arn, jobDefinition=job_def_arn + ) + job_id = resp["jobId"] + + # Test that describe_jobs() returns 'createdAt' + # github.com/getmoto/moto/issues/4364 + resp = batch_client.describe_jobs(jobs=[job_id]) + created_at = resp["jobs"][0]["createdAt"] + created_at.should.be.greater_than(start_time_milliseconds) + + _wait_for_job_status(batch_client, job_id, "SUCCEEDED") + + resp = logs_client.describe_log_streams( + logGroupName="/aws/batch/job", logStreamNamePrefix=job_def_name + ) + resp["logStreams"].should.have.length_of(1) + ls_name = resp["logStreams"][0]["logStreamName"] + + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + [event["message"] for event in resp["events"]].should.equal(["hello", "hello"]) + + # Test that describe_jobs() returns timestamps in milliseconds + # github.com/getmoto/moto/issues/4364 + job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0] + created_at = job["createdAt"] + started_at = job["startedAt"] + stopped_at = job["stoppedAt"] + + created_at.should.be.greater_than(start_time_milliseconds) + started_at.should.be.greater_than(start_time_milliseconds) + stopped_at.should.be.greater_than(start_time_milliseconds) + + # Verify we track attempts + job.should.have.key("attempts").length_of(1) + attempt = job["attempts"][0] + attempt.should.have.key("container") + attempt["container"].should.have.key("containerInstanceArn") + attempt["container"].should.have.key("logStreamName").equals( + job["container"]["logStreamName"] + ) + attempt["container"].should.have.key("networkInterfaces") + attempt["container"].should.have.key("taskArn") + attempt.should.have.key("startedAt").equals(started_at) + attempt.should.have.key("stoppedAt").equals(stopped_at) + + @mock_logs @mock_ec2 @mock_ecs @@ -205,6 +271,18 @@ def test_list_jobs(): job.should.have.key("stoppedAt") job.should.have.key("container").should.have.key("exitCode").equals(0) + filtered_jobs = batch_client.list_jobs( + jobQueue=queue_arn, + filters=[ + { + "name": "JOB_NAME", + "values": ["test2"], + } + ], + )["jobSummaryList"] + filtered_jobs.should.have.length_of(1) + filtered_jobs[0]["jobName"].should.equal("test2") + @mock_logs @mock_ec2 @@ -772,6 +850,51 @@ def prepare_job(batch_client, commands, iam_arn, job_def_name): return job_def_arn, queue_arn +def prepare_multinode_job(batch_client, commands, iam_arn, job_def_name): + compute_name = str(uuid4())[0:6] + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + resp = batch_client.create_job_queue( + jobQueueName=str(uuid4())[0:6], + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + container = { + "image": "busybox:latest", + "vcpus": 1, + "memory": 128, + "command": commands, + } + resp = batch_client.register_job_definition( + jobDefinitionName=job_def_name, + type="multinode", + nodeProperties={ + "mainNode": 0, + "numNodes": 2, + "nodeRangeProperties": [ + { + "container": container, + "targetNodes": "0", + }, + { + "container": container, + "targetNodes": "1", + }, + ], + }, + ) + job_def_arn = resp["jobDefinitionArn"] + return job_def_arn, queue_arn + + @mock_batch def test_update_job_definition(): _, _, _, _, batch_client = _get_clients()