From ed56ffd48432ccb1415f34783474d5ada514d7ee Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 15 Nov 2023 21:23:41 -0100 Subject: [PATCH] ECS: Tagging is now supported for Tasks (#7029) --- moto/ecs/models.py | 36 +++++++------- moto/ecs/responses.py | 19 ++++---- tests/test_ecs/test_ecs_boto3.py | 36 -------------- tests/test_ecs/test_ecs_task_tags.py | 71 ++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 64 deletions(-) create mode 100644 tests/test_ecs/test_ecs_task_tags.py diff --git a/moto/ecs/models.py b/moto/ecs/models.py index cc10d2746..41f122be7 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -437,9 +437,10 @@ class Task(BaseObject, ManagedState): return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.cluster_name}/{self.id}" return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.id}" - @property - def response_object(self) -> Dict[str, Any]: # type: ignore[misc] + def response_object(self, include_tags: bool = True) -> Dict[str, Any]: # type: ignore response_object = self.gen_response_object() + if not include_tags: + response_object.pop("tags", None) response_object["taskArn"] = self.task_arn response_object["lastStatus"] = self.last_status return response_object @@ -1471,12 +1472,7 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks[cluster.name][task.task_arn] = task return tasks - def describe_tasks( - self, - cluster_str: str, - tasks: Optional[str], - include: Optional[List[str]] = None, - ) -> List[Task]: + def describe_tasks(self, cluster_str: str, tasks: Optional[str]) -> List[Task]: """ Only include=TAGS is currently supported. """ @@ -1495,22 +1491,18 @@ class EC2ContainerServiceBackend(BaseBackend): ): task.advance() response.append(task) - if "TAGS" in (include or []): - return response - for task in response: - task.tags = [] return response def list_tasks( self, - cluster_str: str, - container_instance: Optional[str], - family: str, - started_by: str, - service_name: str, - desiredStatus: str, - ) -> List[str]: + cluster_str: Optional[str] = None, + container_instance: Optional[str] = None, + family: Optional[str] = None, + started_by: Optional[str] = None, + service_name: Optional[str] = None, + desiredStatus: Optional[str] = None, + ) -> List[Task]: filtered_tasks = [] for tasks in self.tasks.values(): for task in tasks.values(): @@ -1554,7 +1546,7 @@ class EC2ContainerServiceBackend(BaseBackend): filter(lambda t: t.desired_status == desiredStatus, filtered_tasks) ) - return [t.task_arn for t in filtered_tasks] + return filtered_tasks def stop_task(self, cluster_str: str, task_str: str, reason: str) -> Task: cluster = self._get_cluster(cluster_str) @@ -2080,6 +2072,10 @@ class EC2ContainerServiceBackend(BaseBackend): return task_def elif parsed_arn["service"] == "capacity-provider": return self._get_provider(parsed_arn["id"]) + elif parsed_arn["service"] == "task": + for task in self.list_tasks(): + if task.task_arn == resource_arn: + return task raise NotImplementedError() def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]: diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index 402aba9d4..40fa7278f 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -193,16 +193,19 @@ class EC2ContainerServiceResponse(BaseResponse): network_configuration, ) return json.dumps( - {"tasks": [task.response_object for task in tasks], "failures": []} + {"tasks": [task.response_object() for task in tasks], "failures": []} ) def describe_tasks(self) -> str: cluster = self._get_param("cluster", "default") tasks = self._get_param("tasks") - include = self._get_param("include") - data = self.ecs_backend.describe_tasks(cluster, tasks, include) + include_tags = "TAGS" in self._get_param("include", []) + data = self.ecs_backend.describe_tasks(cluster, tasks) return json.dumps( - {"tasks": [task.response_object for task in data], "failures": []} + { + "tasks": [task.response_object(include_tags) for task in data], + "failures": [], + } ) def start_task(self) -> str: @@ -221,7 +224,7 @@ class EC2ContainerServiceResponse(BaseResponse): tags, ) return json.dumps( - {"tasks": [task.response_object for task in tasks], "failures": []} + {"tasks": [task.response_object() for task in tasks], "failures": []} ) def list_tasks(self) -> str: @@ -231,7 +234,7 @@ class EC2ContainerServiceResponse(BaseResponse): started_by = self._get_param("startedBy") service_name = self._get_param("serviceName") desiredStatus = self._get_param("desiredStatus") - task_arns = self.ecs_backend.list_tasks( + tasks = self.ecs_backend.list_tasks( cluster_str, container_instance, family, @@ -239,14 +242,14 @@ class EC2ContainerServiceResponse(BaseResponse): service_name, desiredStatus, ) - return json.dumps({"taskArns": task_arns}) + return json.dumps({"taskArns": [t.task_arn for t in tasks]}) def stop_task(self) -> str: cluster_str = self._get_param("cluster", "default") task = self._get_param("task") reason = self._get_param("reason") task = self.ecs_backend.stop_task(cluster_str, task, reason) - return json.dumps({"task": task.response_object}) + return json.dumps({"task": task.response_object()}) def create_service(self) -> str: cluster_str = self._get_param("cluster", "default") diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index e88d3dff2..7b76b0a6b 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -2481,42 +2481,6 @@ def test_describe_tasks_empty_tags(): assert len(response["tasks"]) == 1 -@mock_ec2 -@mock_ecs -def test_describe_tasks_include_tags(): - client = boto3.client("ecs", region_name=ECS_REGION) - test_cluster_name = "test_ecs_cluster" - setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) - - task_tags = [{"key": "Name", "value": "test_ecs_task"}] - tasks_arns = [ - task["taskArn"] - for task in client.run_task( - cluster="test_ecs_cluster", - overrides={}, - taskDefinition="test_ecs_task", - count=2, - startedBy="moto", - tags=task_tags, - )["tasks"] - ] - response = client.describe_tasks( - cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"] - ) - - assert len(response["tasks"]) == 2 - assert set( - [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] - ) == set(tasks_arns) - assert response["tasks"][0]["tags"] == task_tags - - # Test we can pass task ids instead of ARNs - response = client.describe_tasks( - cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] - ) - assert len(response["tasks"]) == 1 - - @mock_ecs def test_describe_tasks_exceptions(): client = boto3.client("ecs", region_name=ECS_REGION) diff --git a/tests/test_ecs/test_ecs_task_tags.py b/tests/test_ecs/test_ecs_task_tags.py new file mode 100644 index 000000000..9b8d08c65 --- /dev/null +++ b/tests/test_ecs/test_ecs_task_tags.py @@ -0,0 +1,71 @@ +import boto3 + +from moto import mock_ec2, mock_ecs +from .test_ecs_boto3 import setup_ecs_cluster_with_ec2_instance + + +@mock_ec2 +@mock_ecs +def test_describe_tasks_include_tags(): + client = boto3.client("ecs", region_name="us-east-1") + test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) + + task_tags = [{"key": "Name", "value": "test_ecs_task"}] + tasks_arns = [ + task["taskArn"] + for task in client.run_task( + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", + count=2, + tags=task_tags, + )["tasks"] + ] + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"] + ) + + assert len(response["tasks"]) == 2 + assert set( + [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] + ) == set(tasks_arns) + assert response["tasks"][0]["tags"] == task_tags + + # Test we can pass task ids instead of ARNs + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] + ) + assert len(response["tasks"]) == 1 + + tags = client.list_tags_for_resource(resourceArn=tasks_arns[0])["tags"] + assert tags == task_tags + + +@mock_ec2 +@mock_ecs +def test_add_tags_to_task(): + client = boto3.client("ecs", region_name="us-east-1") + test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) + + task_tags = [{"key": "k1", "value": "v1"}] + task_arn = client.run_task( + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", + count=1, + tags=task_tags, + )["tasks"][0]["taskArn"] + + client.tag_resource(resourceArn=task_arn, tags=[{"key": "k2", "value": "v2"}]) + + tags = client.describe_tasks( + cluster="test_ecs_cluster", tasks=[task_arn], include=["TAGS"] + )["tasks"][0]["tags"] + assert len(tags) == 2 + assert {"key": "k1", "value": "v1"} in tags + assert {"key": "k2", "value": "v2"} in tags + + client.untag_resource(resourceArn=task_arn, tagKeys=["k2"]) + + resp = client.list_tags_for_resource(resourceArn=task_arn) + assert resp["tags"] == [{"key": "k1", "value": "v1"}]