Add tagging to batch job definitions (#4316)

This commit is contained in:
oakbramble 2021-09-21 18:12:18 +02:00 committed by GitHub
parent b95d8aaebc
commit 82158096d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 8 deletions

View File

@ -16,7 +16,7 @@ from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
from moto.logs import logs_backends
from .exceptions import InvalidParameterValueException, ClientException
from .exceptions import InvalidParameterValueException, ClientException, ValidationError
from .utils import (
make_arn_for_compute_env,
make_arn_for_job_queue,
@ -28,6 +28,7 @@ from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES
from moto.iam.exceptions import IAMNotFoundException
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from moto.utilities.docker_utilities import DockerModel, parse_image_ref
from ..utilities.tagging_service import TaggingService
logger = logging.getLogger(__name__)
COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
@ -187,6 +188,7 @@ class JobDefinition(CloudFormationModel):
_type,
container_properties,
region_name,
tags={},
revision=0,
retry_strategy=0,
):
@ -198,7 +200,7 @@ class JobDefinition(CloudFormationModel):
self.container_properties = container_properties
self.arn = None
self.status = "ACTIVE"
self.tagger = TaggingService()
if parameters is None:
parameters = {}
self.parameters = parameters
@ -206,6 +208,17 @@ class JobDefinition(CloudFormationModel):
self._validate()
self._update_arn()
tags = self._format_tags(tags)
# Validate the tags before proceeding.
errmsg = self.tagger.validate_tags(tags or [])
if errmsg:
raise ValidationError(errmsg)
self.tagger.tag_resource(self.arn, tags or [])
def _format_tags(self, tags):
return [{"Key": k, "Value": v} for k, v in tags.items()]
def _update_arn(self):
self.revision += 1
self.arn = make_arn_for_task_def(
@ -267,6 +280,7 @@ class JobDefinition(CloudFormationModel):
"revision": self.revision,
"status": self.status,
"type": self.type,
"tags": self.tagger.get_tag_dict_for_resource(self.arn),
}
if self.container_properties is not None:
result["containerProperties"] = self.container_properties
@ -294,15 +308,14 @@ class JobDefinition(CloudFormationModel):
):
backend = batch_backends[region_name]
properties = cloudformation_json["Properties"]
res = backend.register_job_definition(
def_name=resource_name,
parameters=lowercase_first_key(properties.get("Parameters", {})),
_type="container",
tags=lowercase_first_key(properties.get("Tags", {})),
retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
container_properties=lowercase_first_key(properties["ContainerProperties"]),
)
arn = res[1]
return backend.get_job_definition_by_arn(arn)
@ -1209,7 +1222,7 @@ class BatchBackend(BaseBackend):
del self._job_queues[job_queue.arn]
def register_job_definition(
self, def_name, parameters, _type, retry_strategy, container_properties
self, def_name, parameters, _type, tags, retry_strategy, container_properties
):
if def_name is None:
raise ClientException("jobDefinitionName must be provided")
@ -1220,13 +1233,15 @@ class BatchBackend(BaseBackend):
retry_strategy = retry_strategy["attempts"]
except Exception:
raise ClientException("retryStrategy is malformed")
if job_def is None:
if not tags:
tags = {}
job_def = JobDefinition(
def_name,
parameters,
_type,
container_properties,
tags=tags,
region_name=self.region_name,
retry_strategy=retry_strategy,
)
@ -1275,6 +1290,8 @@ class BatchBackend(BaseBackend):
# Got all the job defs were after, filter then by status
if status is not None:
return [job for job in jobs if job.status == status]
for job in jobs:
job.describe()
return jobs
def submit_job(

View File

@ -177,14 +177,15 @@ class BatchResponse(BaseResponse):
container_properties = self._get_param("containerProperties")
def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters")
tags = self._get_param("tags")
retry_strategy = self._get_param("retryStrategy")
_type = self._get_param("type")
try:
name, arn, revision = self.batch_backend.register_job_definition(
def_name=def_name,
parameters=parameters,
_type=_type,
tags=tags,
retry_strategy=retry_strategy,
container_properties=container_properties,
)

View File

@ -23,6 +23,25 @@ def test_register_task_definition():
)
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_register_task_definition_with_tags():
ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
_setup(ec2_client, iam_client)
resp = register_job_def_with_tags(batch_client)
resp.should.contain("jobDefinitionArn")
resp.should.contain("jobDefinitionName")
resp.should.contain("revision")
assert resp["jobDefinitionArn"].endswith(
"{0}:{1}".format(resp["jobDefinitionName"], resp["revision"])
)
@mock_ec2
@mock_ecs
@mock_iam
@ -89,15 +108,22 @@ def test_describe_task_definition():
register_job_def(batch_client, definition_name="sleep10")
register_job_def(batch_client, definition_name="sleep10")
register_job_def(batch_client, definition_name="test1")
register_job_def_with_tags(batch_client, definition_name="tagged_def")
resp = batch_client.describe_job_definitions(jobDefinitionName="sleep10")
len(resp["jobDefinitions"]).should.equal(2)
resp = batch_client.describe_job_definitions()
len(resp["jobDefinitions"]).should.equal(3)
len(resp["jobDefinitions"]).should.equal(4)
resp = batch_client.describe_job_definitions(jobDefinitions=["sleep10", "test1"])
len(resp["jobDefinitions"]).should.equal(3)
resp["jobDefinitions"][0]["tags"].should.equal({})
resp = batch_client.describe_job_definitions(jobDefinitionName="tagged_def")
resp["jobDefinitions"][0]["tags"].should.equal(
{"foo": "123", "bar": "456",}
)
for job_definition in resp["jobDefinitions"]:
job_definition["status"].should.equal("ACTIVE")
@ -114,3 +140,17 @@ def register_job_def(batch_client, definition_name="sleep10"):
"command": ["sleep", "10"],
},
)
def register_job_def_with_tags(batch_client, definition_name="sleep10"):
return batch_client.register_job_definition(
jobDefinitionName=definition_name,
type="container",
containerProperties={
"image": "busybox",
"vcpus": 1,
"memory": random.randint(4, 128),
"command": ["sleep", "10"],
},
tags={"foo": "123", "bar": "456",},
)