ENH: Add resource_requirements to batch job definition (#4506)
This commit is contained in:
parent
e8700bd533
commit
f4e62f0dfd
@ -233,6 +233,41 @@ class JobDefinition(CloudFormationModel):
|
|||||||
DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region
|
DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_resource_requirement(self, req_type, default=None):
|
||||||
|
"""
|
||||||
|
Get resource requirement from container properties.
|
||||||
|
|
||||||
|
Resource requirements like "memory" and "vcpus" are now specified in
|
||||||
|
"resourceRequirements". This function retrieves a resource requirement
|
||||||
|
from either container_properties.resourceRequirements (preferred) or
|
||||||
|
directly from container_properties (deprecated).
|
||||||
|
|
||||||
|
:param req_type: The type of resource requirement to retrieve.
|
||||||
|
:type req_type: ["gpu", "memory", "vcpus"]
|
||||||
|
|
||||||
|
:param default: The default value to return if the resource requirement is not found.
|
||||||
|
:type default: any, default=None
|
||||||
|
|
||||||
|
:return: The value of the resource requirement, or None.
|
||||||
|
:rtype: any
|
||||||
|
"""
|
||||||
|
resource_reqs = self.container_properties.get("resourceRequirements", [])
|
||||||
|
|
||||||
|
# Filter the resource requirements by the specified type.
|
||||||
|
# Note that VCPUS are specified in resourceRequirements without the
|
||||||
|
# trailing "s", so we strip that off in the comparison below.
|
||||||
|
required_resource = list(
|
||||||
|
filter(
|
||||||
|
lambda req: req["type"].lower() == req_type.lower().rstrip("s"),
|
||||||
|
resource_reqs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if required_resource:
|
||||||
|
return required_resource[0]["value"]
|
||||||
|
else:
|
||||||
|
return self.container_properties.get(req_type, default)
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
if self.type not in ("container",):
|
if self.type not in ("container",):
|
||||||
raise ClientException('type must be one of "container"')
|
raise ClientException('type must be one of "container"')
|
||||||
@ -247,14 +282,16 @@ class JobDefinition(CloudFormationModel):
|
|||||||
if "image" not in self.container_properties:
|
if "image" not in self.container_properties:
|
||||||
raise ClientException("containerProperties must contain image")
|
raise ClientException("containerProperties must contain image")
|
||||||
|
|
||||||
if "memory" not in self.container_properties:
|
memory = self._get_resource_requirement("memory")
|
||||||
|
if memory is None:
|
||||||
raise ClientException("containerProperties must contain memory")
|
raise ClientException("containerProperties must contain memory")
|
||||||
if self.container_properties["memory"] < 4:
|
if memory < 4:
|
||||||
raise ClientException("container memory limit must be greater than 4")
|
raise ClientException("container memory limit must be greater than 4")
|
||||||
|
|
||||||
if "vcpus" not in self.container_properties:
|
vcpus = self._get_resource_requirement("vcpus")
|
||||||
|
if vcpus is None:
|
||||||
raise ClientException("containerProperties must contain vcpus")
|
raise ClientException("containerProperties must contain vcpus")
|
||||||
if self.container_properties["vcpus"] < 1:
|
if vcpus < 1:
|
||||||
raise ClientException("container vcpus limit must be greater than 0")
|
raise ClientException("container vcpus limit must be greater than 0")
|
||||||
|
|
||||||
def update(self, parameters, _type, container_properties, retry_strategy):
|
def update(self, parameters, _type, container_properties, retry_strategy):
|
||||||
@ -426,6 +463,11 @@ class Job(threading.Thread, BaseModel, DockerModel):
|
|||||||
|
|
||||||
return job_env
|
return job_env
|
||||||
|
|
||||||
|
if p in ["vcpus", "memory"]:
|
||||||
|
return self.container_overrides.get(
|
||||||
|
p, self.job_definition._get_resource_requirement(p, default)
|
||||||
|
)
|
||||||
|
|
||||||
return self.container_overrides.get(
|
return self.container_overrides.get(
|
||||||
p, self.job_definition.container_properties.get(p, default)
|
p, self.job_definition.container_properties.get(p, default)
|
||||||
)
|
)
|
||||||
|
@ -227,8 +227,10 @@ def test_create_job_def_cf():
|
|||||||
],
|
],
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"Vcpus": 2,
|
"ResourceRequirements": [
|
||||||
"Memory": 2000,
|
{"Type": "MEMORY", "Value": 2000},
|
||||||
|
{"Type": "VCPU", "Value": 2},
|
||||||
|
],
|
||||||
"Command": ["echo", "Hello world"],
|
"Command": ["echo", "Hello world"],
|
||||||
"LinuxParameters": {"Devices": [{"HostPath": "test-path"}]},
|
"LinuxParameters": {"Devices": [{"HostPath": "test-path"}]},
|
||||||
},
|
},
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from . import _get_clients, _setup
|
from . import _get_clients, _setup
|
||||||
import random
|
import random
|
||||||
|
import pytest
|
||||||
import sure # noqa # pylint: disable=unused-import
|
import sure # noqa # pylint: disable=unused-import
|
||||||
from moto import mock_batch, mock_iam, mock_ec2, mock_ecs
|
from moto import mock_batch, mock_iam, mock_ec2, mock_ecs
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@ -9,11 +10,12 @@ from uuid import uuid4
|
|||||||
@mock_ecs
|
@mock_ecs
|
||||||
@mock_iam
|
@mock_iam
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_register_task_definition():
|
@pytest.mark.parametrize("use_resource_reqs", [True, False])
|
||||||
|
def test_register_task_definition(use_resource_reqs):
|
||||||
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
||||||
_setup(ec2_client, iam_client)
|
_setup(ec2_client, iam_client)
|
||||||
|
|
||||||
resp = register_job_def(batch_client)
|
resp = register_job_def(batch_client, use_resource_reqs=use_resource_reqs)
|
||||||
|
|
||||||
resp.should.contain("jobDefinitionArn")
|
resp.should.contain("jobDefinitionArn")
|
||||||
resp.should.contain("jobDefinitionName")
|
resp.should.contain("jobDefinitionName")
|
||||||
@ -47,13 +49,16 @@ def test_register_task_definition_with_tags():
|
|||||||
@mock_ecs
|
@mock_ecs
|
||||||
@mock_iam
|
@mock_iam
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_reregister_task_definition():
|
@pytest.mark.parametrize("use_resource_reqs", [True, False])
|
||||||
|
def test_reregister_task_definition(use_resource_reqs):
|
||||||
# Reregistering task with the same name bumps the revision number
|
# Reregistering task with the same name bumps the revision number
|
||||||
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
||||||
_setup(ec2_client, iam_client)
|
_setup(ec2_client, iam_client)
|
||||||
|
|
||||||
job_def_name = str(uuid4())[0:6]
|
job_def_name = str(uuid4())[0:6]
|
||||||
resp1 = register_job_def(batch_client, definition_name=job_def_name)
|
resp1 = register_job_def(
|
||||||
|
batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
|
|
||||||
resp1.should.contain("jobDefinitionArn")
|
resp1.should.contain("jobDefinitionArn")
|
||||||
resp1.should.have.key("jobDefinitionName").equals(job_def_name)
|
resp1.should.have.key("jobDefinitionName").equals(job_def_name)
|
||||||
@ -64,18 +69,24 @@ def test_reregister_task_definition():
|
|||||||
)
|
)
|
||||||
resp1["revision"].should.equal(1)
|
resp1["revision"].should.equal(1)
|
||||||
|
|
||||||
resp2 = register_job_def(batch_client, definition_name=job_def_name)
|
resp2 = register_job_def(
|
||||||
|
batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
resp2["revision"].should.equal(2)
|
resp2["revision"].should.equal(2)
|
||||||
|
|
||||||
resp2["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
resp2["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
||||||
|
|
||||||
resp3 = register_job_def(batch_client, definition_name=job_def_name)
|
resp3 = register_job_def(
|
||||||
|
batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
resp3["revision"].should.equal(3)
|
resp3["revision"].should.equal(3)
|
||||||
|
|
||||||
resp3["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
resp3["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
||||||
resp3["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"])
|
resp3["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"])
|
||||||
|
|
||||||
resp4 = register_job_def(batch_client, definition_name=job_def_name)
|
resp4 = register_job_def(
|
||||||
|
batch_client, definition_name=job_def_name, use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
resp4["revision"].should.equal(4)
|
resp4["revision"].should.equal(4)
|
||||||
|
|
||||||
resp4["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
resp4["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"])
|
||||||
@ -87,11 +98,14 @@ def test_reregister_task_definition():
|
|||||||
@mock_ecs
|
@mock_ecs
|
||||||
@mock_iam
|
@mock_iam
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_delete_task_definition():
|
@pytest.mark.parametrize("use_resource_reqs", [True, False])
|
||||||
|
def test_delete_task_definition(use_resource_reqs):
|
||||||
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
||||||
_setup(ec2_client, iam_client)
|
_setup(ec2_client, iam_client)
|
||||||
|
|
||||||
resp = register_job_def(batch_client, definition_name=str(uuid4()))
|
resp = register_job_def(
|
||||||
|
batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
name = resp["jobDefinitionName"]
|
name = resp["jobDefinitionName"]
|
||||||
|
|
||||||
batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"])
|
batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"])
|
||||||
@ -104,11 +118,14 @@ def test_delete_task_definition():
|
|||||||
@mock_ecs
|
@mock_ecs
|
||||||
@mock_iam
|
@mock_iam
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_delete_task_definition_by_name():
|
@pytest.mark.parametrize("use_resource_reqs", [True, False])
|
||||||
|
def test_delete_task_definition_by_name(use_resource_reqs):
|
||||||
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
||||||
_setup(ec2_client, iam_client)
|
_setup(ec2_client, iam_client)
|
||||||
|
|
||||||
resp = register_job_def(batch_client, definition_name=str(uuid4()))
|
resp = register_job_def(
|
||||||
|
batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
name = resp["jobDefinitionName"]
|
name = resp["jobDefinitionName"]
|
||||||
|
|
||||||
batch_client.deregister_job_definition(jobDefinition=f"{name}:{resp['revision']}")
|
batch_client.deregister_job_definition(jobDefinition=f"{name}:{resp['revision']}")
|
||||||
@ -121,16 +138,27 @@ def test_delete_task_definition_by_name():
|
|||||||
@mock_ecs
|
@mock_ecs
|
||||||
@mock_iam
|
@mock_iam
|
||||||
@mock_batch
|
@mock_batch
|
||||||
def test_describe_task_definition():
|
@pytest.mark.parametrize("use_resource_reqs", [True, False])
|
||||||
|
def test_describe_task_definition(use_resource_reqs):
|
||||||
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
ec2_client, iam_client, _, _, batch_client = _get_clients()
|
||||||
_setup(ec2_client, iam_client)
|
_setup(ec2_client, iam_client)
|
||||||
|
|
||||||
sleep_def_name = f"sleep10_{str(uuid4())[0:6]}"
|
sleep_def_name = f"sleep10_{str(uuid4())[0:6]}"
|
||||||
other_name = str(uuid4())[0:6]
|
other_name = str(uuid4())[0:6]
|
||||||
tagged_name = str(uuid4())[0:6]
|
tagged_name = str(uuid4())[0:6]
|
||||||
register_job_def(batch_client, definition_name=sleep_def_name)
|
register_job_def(
|
||||||
register_job_def(batch_client, definition_name=sleep_def_name)
|
batch_client,
|
||||||
register_job_def(batch_client, definition_name=other_name)
|
definition_name=sleep_def_name,
|
||||||
|
use_resource_reqs=use_resource_reqs,
|
||||||
|
)
|
||||||
|
register_job_def(
|
||||||
|
batch_client,
|
||||||
|
definition_name=sleep_def_name,
|
||||||
|
use_resource_reqs=use_resource_reqs,
|
||||||
|
)
|
||||||
|
register_job_def(
|
||||||
|
batch_client, definition_name=other_name, use_resource_reqs=use_resource_reqs
|
||||||
|
)
|
||||||
register_job_def_with_tags(batch_client, definition_name=tagged_name)
|
register_job_def_with_tags(batch_client, definition_name=tagged_name)
|
||||||
|
|
||||||
resp = batch_client.describe_job_definitions(jobDefinitionName=sleep_def_name)
|
resp = batch_client.describe_job_definitions(jobDefinitionName=sleep_def_name)
|
||||||
@ -157,7 +185,26 @@ def test_describe_task_definition():
|
|||||||
job_definition["status"].should.equal("ACTIVE")
|
job_definition["status"].should.equal("ACTIVE")
|
||||||
|
|
||||||
|
|
||||||
def register_job_def(batch_client, definition_name="sleep10"):
|
def register_job_def(batch_client, definition_name="sleep10", use_resource_reqs=True):
|
||||||
|
container_properties = {
|
||||||
|
"image": "busybox",
|
||||||
|
"command": ["sleep", "10"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if use_resource_reqs:
|
||||||
|
container_properties.update(
|
||||||
|
{
|
||||||
|
"resourceRequirements": [
|
||||||
|
{"value": "1", "type": "VCPU"},
|
||||||
|
{"value": str(random.randint(4, 128)), "type": "MEMORY"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
container_properties.update(
|
||||||
|
{"memory": random.randint(4, 128), "vcpus": 1,}
|
||||||
|
)
|
||||||
|
|
||||||
return batch_client.register_job_definition(
|
return batch_client.register_job_definition(
|
||||||
jobDefinitionName=definition_name,
|
jobDefinitionName=definition_name,
|
||||||
type="container",
|
type="container",
|
||||||
|
Loading…
Reference in New Issue
Block a user