From f4e62f0dfde91f6fd5a4a5f2b29e1e2bd20fa739 Mon Sep 17 00:00:00 2001 From: Adam Richie-Halford Date: Mon, 1 Nov 2021 03:31:22 -0700 Subject: [PATCH] ENH: Add resource_requirements to batch job definition (#4506) --- moto/batch/models.py | 50 +++++++++++- tests/test_batch/test_batch_cloudformation.py | 6 +- .../test_batch/test_batch_task_definition.py | 79 +++++++++++++++---- 3 files changed, 113 insertions(+), 22 deletions(-) diff --git a/moto/batch/models.py b/moto/batch/models.py index 200fb7806..686e85347 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -233,6 +233,41 @@ class JobDefinition(CloudFormationModel): DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region ) + def _get_resource_requirement(self, req_type, default=None): + """ + Get resource requirement from container properties. + + Resource requirements like "memory" and "vcpus" are now specified in + "resourceRequirements". This function retrieves a resource requirement + from either container_properties.resourceRequirements (preferred) or + directly from container_properties (deprecated). + + :param req_type: The type of resource requirement to retrieve. + :type req_type: ["gpu", "memory", "vcpus"] + + :param default: The default value to return if the resource requirement is not found. + :type default: any, default=None + + :return: The value of the resource requirement, or None. + :rtype: any + """ + resource_reqs = self.container_properties.get("resourceRequirements", []) + + # Filter the resource requirements by the specified type. + # Note that VCPUS are specified in resourceRequirements without the + # trailing "s", so we strip that off in the comparison below. + required_resource = list( + filter( + lambda req: req["type"].lower() == req_type.lower().rstrip("s"), + resource_reqs, + ) + ) + + if required_resource: + return required_resource[0]["value"] + else: + return self.container_properties.get(req_type, default) + def _validate(self): if self.type not in ("container",): raise ClientException('type must be one of "container"') @@ -247,14 +282,16 @@ class JobDefinition(CloudFormationModel): if "image" not in self.container_properties: raise ClientException("containerProperties must contain image") - if "memory" not in self.container_properties: + memory = self._get_resource_requirement("memory") + if memory is None: raise ClientException("containerProperties must contain memory") - if self.container_properties["memory"] < 4: + if memory < 4: raise ClientException("container memory limit must be greater than 4") - if "vcpus" not in self.container_properties: + vcpus = self._get_resource_requirement("vcpus") + if vcpus is None: raise ClientException("containerProperties must contain vcpus") - if self.container_properties["vcpus"] < 1: + if vcpus < 1: raise ClientException("container vcpus limit must be greater than 0") def update(self, parameters, _type, container_properties, retry_strategy): @@ -426,6 +463,11 @@ class Job(threading.Thread, BaseModel, DockerModel): return job_env + if p in ["vcpus", "memory"]: + return self.container_overrides.get( + p, self.job_definition._get_resource_requirement(p, default) + ) + return self.container_overrides.get( p, self.job_definition.container_properties.get(p, default) ) diff --git a/tests/test_batch/test_batch_cloudformation.py b/tests/test_batch/test_batch_cloudformation.py index d6f16c074..5618812b4 100644 --- a/tests/test_batch/test_batch_cloudformation.py +++ b/tests/test_batch/test_batch_cloudformation.py @@ -227,8 +227,10 @@ def test_create_job_def_cf(): ], ] }, - "Vcpus": 2, - "Memory": 2000, + "ResourceRequirements": [ + {"Type": "MEMORY", "Value": 2000}, + {"Type": "VCPU", "Value": 2}, + ], "Command": ["echo", "Hello world"], "LinuxParameters": {"Devices": [{"HostPath": "test-path"}]}, }, diff --git a/tests/test_batch/test_batch_task_definition.py b/tests/test_batch/test_batch_task_definition.py index 092f6082e..d60c82e37 100644 --- a/tests/test_batch/test_batch_task_definition.py +++ b/tests/test_batch/test_batch_task_definition.py @@ -1,5 +1,6 @@ from . import _get_clients, _setup import random +import pytest import sure # noqa # pylint: disable=unused-import from moto import mock_batch, mock_iam, mock_ec2, mock_ecs from uuid import uuid4 @@ -9,11 +10,12 @@ from uuid import uuid4 @mock_ecs @mock_iam @mock_batch -def test_register_task_definition(): +@pytest.mark.parametrize("use_resource_reqs", [True, False]) +def test_register_task_definition(use_resource_reqs): ec2_client, iam_client, _, _, batch_client = _get_clients() _setup(ec2_client, iam_client) - resp = register_job_def(batch_client) + resp = register_job_def(batch_client, use_resource_reqs=use_resource_reqs) resp.should.contain("jobDefinitionArn") resp.should.contain("jobDefinitionName") @@ -47,13 +49,16 @@ def test_register_task_definition_with_tags(): @mock_ecs @mock_iam @mock_batch -def test_reregister_task_definition(): +@pytest.mark.parametrize("use_resource_reqs", [True, False]) +def test_reregister_task_definition(use_resource_reqs): # Reregistering task with the same name bumps the revision number ec2_client, iam_client, _, _, batch_client = _get_clients() _setup(ec2_client, iam_client) job_def_name = str(uuid4())[0:6] - resp1 = register_job_def(batch_client, definition_name=job_def_name) + resp1 = register_job_def( + batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs + ) resp1.should.contain("jobDefinitionArn") resp1.should.have.key("jobDefinitionName").equals(job_def_name) @@ -64,18 +69,24 @@ def test_reregister_task_definition(): ) resp1["revision"].should.equal(1) - resp2 = register_job_def(batch_client, definition_name=job_def_name) + resp2 = register_job_def( + batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs + ) resp2["revision"].should.equal(2) resp2["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) - resp3 = register_job_def(batch_client, definition_name=job_def_name) + resp3 = register_job_def( + batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs + ) resp3["revision"].should.equal(3) resp3["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) resp3["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"]) - resp4 = register_job_def(batch_client, definition_name=job_def_name) + resp4 = register_job_def( + batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs + ) resp4["revision"].should.equal(4) resp4["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) @@ -87,11 +98,14 @@ def test_reregister_task_definition(): @mock_ecs @mock_iam @mock_batch -def test_delete_task_definition(): +@pytest.mark.parametrize("use_resource_reqs", [True, False]) +def test_delete_task_definition(use_resource_reqs): ec2_client, iam_client, _, _, batch_client = _get_clients() _setup(ec2_client, iam_client) - resp = register_job_def(batch_client, definition_name=str(uuid4())) + resp = register_job_def( + batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs + ) name = resp["jobDefinitionName"] batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"]) @@ -104,11 +118,14 @@ def test_delete_task_definition(): @mock_ecs @mock_iam @mock_batch -def test_delete_task_definition_by_name(): +@pytest.mark.parametrize("use_resource_reqs", [True, False]) +def test_delete_task_definition_by_name(use_resource_reqs): ec2_client, iam_client, _, _, batch_client = _get_clients() _setup(ec2_client, iam_client) - resp = register_job_def(batch_client, definition_name=str(uuid4())) + resp = register_job_def( + batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs + ) name = resp["jobDefinitionName"] batch_client.deregister_job_definition(jobDefinition=f"{name}:{resp['revision']}") @@ -121,16 +138,27 @@ def test_delete_task_definition_by_name(): @mock_ecs @mock_iam @mock_batch -def test_describe_task_definition(): +@pytest.mark.parametrize("use_resource_reqs", [True, False]) +def test_describe_task_definition(use_resource_reqs): ec2_client, iam_client, _, _, batch_client = _get_clients() _setup(ec2_client, iam_client) sleep_def_name = f"sleep10_{str(uuid4())[0:6]}" other_name = str(uuid4())[0:6] tagged_name = str(uuid4())[0:6] - register_job_def(batch_client, definition_name=sleep_def_name) - register_job_def(batch_client, definition_name=sleep_def_name) - register_job_def(batch_client, definition_name=other_name) + register_job_def( + batch_client, + definition_name=sleep_def_name, + use_resource_reqs=use_resource_reqs, + ) + register_job_def( + batch_client, + definition_name=sleep_def_name, + use_resource_reqs=use_resource_reqs, + ) + register_job_def( + batch_client, definition_name=other_name, use_resource_reqs=use_resource_reqs + ) register_job_def_with_tags(batch_client, definition_name=tagged_name) resp = batch_client.describe_job_definitions(jobDefinitionName=sleep_def_name) @@ -157,7 +185,26 @@ def test_describe_task_definition(): job_definition["status"].should.equal("ACTIVE") -def register_job_def(batch_client, definition_name="sleep10"): +def register_job_def(batch_client, definition_name="sleep10", use_resource_reqs=True): + container_properties = { + "image": "busybox", + "command": ["sleep", "10"], + } + + if use_resource_reqs: + container_properties.update( + { + "resourceRequirements": [ + {"value": "1", "type": "VCPU"}, + {"value": str(random.randint(4, 128)), "type": "MEMORY"}, + ] + } + ) + else: + container_properties.update( + {"memory": random.randint(4, 128), "vcpus": 1,} + ) + return batch_client.register_job_definition( jobDefinitionName=definition_name, type="container",