diff --git a/moto/batch/models.py b/moto/batch/models.py index 783c55fc2..4807fd2d5 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -16,7 +16,7 @@ from moto.ec2 import ec2_backends from moto.ecs import ecs_backends from moto.logs import logs_backends -from .exceptions import InvalidParameterValueException, ClientException +from .exceptions import InvalidParameterValueException, ClientException, ValidationError from .utils import ( make_arn_for_compute_env, make_arn_for_job_queue, @@ -28,6 +28,7 @@ from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES from moto.iam.exceptions import IAMNotFoundException from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.utilities.docker_utilities import DockerModel, parse_image_ref +from ..utilities.tagging_service import TaggingService logger = logging.getLogger(__name__) COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( @@ -187,6 +188,7 @@ class JobDefinition(CloudFormationModel): _type, container_properties, region_name, + tags={}, revision=0, retry_strategy=0, ): @@ -198,7 +200,7 @@ class JobDefinition(CloudFormationModel): self.container_properties = container_properties self.arn = None self.status = "ACTIVE" - + self.tagger = TaggingService() if parameters is None: parameters = {} self.parameters = parameters @@ -206,6 +208,17 @@ class JobDefinition(CloudFormationModel): self._validate() self._update_arn() + tags = self._format_tags(tags) + # Validate the tags before proceeding. + errmsg = self.tagger.validate_tags(tags or []) + if errmsg: + raise ValidationError(errmsg) + + self.tagger.tag_resource(self.arn, tags or []) + + def _format_tags(self, tags): + return [{"Key": k, "Value": v} for k, v in tags.items()] + def _update_arn(self): self.revision += 1 self.arn = make_arn_for_task_def( @@ -267,6 +280,7 @@ class JobDefinition(CloudFormationModel): "revision": self.revision, "status": self.status, "type": self.type, + "tags": self.tagger.get_tag_dict_for_resource(self.arn), } if self.container_properties is not None: result["containerProperties"] = self.container_properties @@ -294,15 +308,14 @@ class JobDefinition(CloudFormationModel): ): backend = batch_backends[region_name] properties = cloudformation_json["Properties"] - res = backend.register_job_definition( def_name=resource_name, parameters=lowercase_first_key(properties.get("Parameters", {})), _type="container", + tags=lowercase_first_key(properties.get("Tags", {})), retry_strategy=lowercase_first_key(properties["RetryStrategy"]), container_properties=lowercase_first_key(properties["ContainerProperties"]), ) - arn = res[1] return backend.get_job_definition_by_arn(arn) @@ -1209,7 +1222,7 @@ class BatchBackend(BaseBackend): del self._job_queues[job_queue.arn] def register_job_definition( - self, def_name, parameters, _type, retry_strategy, container_properties + self, def_name, parameters, _type, tags, retry_strategy, container_properties ): if def_name is None: raise ClientException("jobDefinitionName must be provided") @@ -1220,13 +1233,15 @@ class BatchBackend(BaseBackend): retry_strategy = retry_strategy["attempts"] except Exception: raise ClientException("retryStrategy is malformed") - if job_def is None: + if not tags: + tags = {} job_def = JobDefinition( def_name, parameters, _type, container_properties, + tags=tags, region_name=self.region_name, retry_strategy=retry_strategy, ) @@ -1275,6 +1290,8 @@ class BatchBackend(BaseBackend): # Got all the job defs were after, filter then by status if status is not None: return [job for job in jobs if job.status == status] + for job in jobs: + job.describe() return jobs def submit_job( diff --git a/moto/batch/responses.py b/moto/batch/responses.py index 25afbc365..cbe48d6c9 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -177,14 +177,15 @@ class BatchResponse(BaseResponse): container_properties = self._get_param("containerProperties") def_name = self._get_param("jobDefinitionName") parameters = self._get_param("parameters") + tags = self._get_param("tags") retry_strategy = self._get_param("retryStrategy") _type = self._get_param("type") - try: name, arn, revision = self.batch_backend.register_job_definition( def_name=def_name, parameters=parameters, _type=_type, + tags=tags, retry_strategy=retry_strategy, container_properties=container_properties, ) diff --git a/tests/test_batch/test_batch_task_definition.py b/tests/test_batch/test_batch_task_definition.py index 07c2c4690..11944c534 100644 --- a/tests/test_batch/test_batch_task_definition.py +++ b/tests/test_batch/test_batch_task_definition.py @@ -23,6 +23,25 @@ def test_register_task_definition(): ) +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_register_task_definition_with_tags(): + ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() + _setup(ec2_client, iam_client) + + resp = register_job_def_with_tags(batch_client) + + resp.should.contain("jobDefinitionArn") + resp.should.contain("jobDefinitionName") + resp.should.contain("revision") + + assert resp["jobDefinitionArn"].endswith( + "{0}:{1}".format(resp["jobDefinitionName"], resp["revision"]) + ) + + @mock_ec2 @mock_ecs @mock_iam @@ -89,15 +108,22 @@ def test_describe_task_definition(): register_job_def(batch_client, definition_name="sleep10") register_job_def(batch_client, definition_name="sleep10") register_job_def(batch_client, definition_name="test1") + register_job_def_with_tags(batch_client, definition_name="tagged_def") resp = batch_client.describe_job_definitions(jobDefinitionName="sleep10") len(resp["jobDefinitions"]).should.equal(2) resp = batch_client.describe_job_definitions() - len(resp["jobDefinitions"]).should.equal(3) + len(resp["jobDefinitions"]).should.equal(4) resp = batch_client.describe_job_definitions(jobDefinitions=["sleep10", "test1"]) len(resp["jobDefinitions"]).should.equal(3) + resp["jobDefinitions"][0]["tags"].should.equal({}) + + resp = batch_client.describe_job_definitions(jobDefinitionName="tagged_def") + resp["jobDefinitions"][0]["tags"].should.equal( + {"foo": "123", "bar": "456",} + ) for job_definition in resp["jobDefinitions"]: job_definition["status"].should.equal("ACTIVE") @@ -114,3 +140,17 @@ def register_job_def(batch_client, definition_name="sleep10"): "command": ["sleep", "10"], }, ) + + +def register_job_def_with_tags(batch_client, definition_name="sleep10"): + return batch_client.register_job_definition( + jobDefinitionName=definition_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": random.randint(4, 128), + "command": ["sleep", "10"], + }, + tags={"foo": "123", "bar": "456",}, + )