ECS: Tagging is now supported for Tasks (#7029)

This commit is contained in:
Bert Blommers 2023-11-15 21:23:41 -01:00 committed by GitHub
parent d3efa2afb9
commit ed56ffd484
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 64 deletions

View File

@ -437,9 +437,10 @@ class Task(BaseObject, ManagedState):
return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.cluster_name}/{self.id}"
return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.id}"
@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
def response_object(self, include_tags: bool = True) -> Dict[str, Any]: # type: ignore
response_object = self.gen_response_object()
if not include_tags:
response_object.pop("tags", None)
response_object["taskArn"] = self.task_arn
response_object["lastStatus"] = self.last_status
return response_object
@ -1471,12 +1472,7 @@ class EC2ContainerServiceBackend(BaseBackend):
self.tasks[cluster.name][task.task_arn] = task
return tasks
def describe_tasks(
self,
cluster_str: str,
tasks: Optional[str],
include: Optional[List[str]] = None,
) -> List[Task]:
def describe_tasks(self, cluster_str: str, tasks: Optional[str]) -> List[Task]:
"""
Only include=TAGS is currently supported.
"""
@ -1495,22 +1491,18 @@ class EC2ContainerServiceBackend(BaseBackend):
):
task.advance()
response.append(task)
if "TAGS" in (include or []):
return response
for task in response:
task.tags = []
return response
def list_tasks(
self,
cluster_str: str,
container_instance: Optional[str],
family: str,
started_by: str,
service_name: str,
desiredStatus: str,
) -> List[str]:
cluster_str: Optional[str] = None,
container_instance: Optional[str] = None,
family: Optional[str] = None,
started_by: Optional[str] = None,
service_name: Optional[str] = None,
desiredStatus: Optional[str] = None,
) -> List[Task]:
filtered_tasks = []
for tasks in self.tasks.values():
for task in tasks.values():
@ -1554,7 +1546,7 @@ class EC2ContainerServiceBackend(BaseBackend):
filter(lambda t: t.desired_status == desiredStatus, filtered_tasks)
)
return [t.task_arn for t in filtered_tasks]
return filtered_tasks
def stop_task(self, cluster_str: str, task_str: str, reason: str) -> Task:
cluster = self._get_cluster(cluster_str)
@ -2080,6 +2072,10 @@ class EC2ContainerServiceBackend(BaseBackend):
return task_def
elif parsed_arn["service"] == "capacity-provider":
return self._get_provider(parsed_arn["id"])
elif parsed_arn["service"] == "task":
for task in self.list_tasks():
if task.task_arn == resource_arn:
return task
raise NotImplementedError()
def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:

View File

@ -193,16 +193,19 @@ class EC2ContainerServiceResponse(BaseResponse):
network_configuration,
)
return json.dumps(
{"tasks": [task.response_object for task in tasks], "failures": []}
{"tasks": [task.response_object() for task in tasks], "failures": []}
)
def describe_tasks(self) -> str:
cluster = self._get_param("cluster", "default")
tasks = self._get_param("tasks")
include = self._get_param("include")
data = self.ecs_backend.describe_tasks(cluster, tasks, include)
include_tags = "TAGS" in self._get_param("include", [])
data = self.ecs_backend.describe_tasks(cluster, tasks)
return json.dumps(
{"tasks": [task.response_object for task in data], "failures": []}
{
"tasks": [task.response_object(include_tags) for task in data],
"failures": [],
}
)
def start_task(self) -> str:
@ -221,7 +224,7 @@ class EC2ContainerServiceResponse(BaseResponse):
tags,
)
return json.dumps(
{"tasks": [task.response_object for task in tasks], "failures": []}
{"tasks": [task.response_object() for task in tasks], "failures": []}
)
def list_tasks(self) -> str:
@ -231,7 +234,7 @@ class EC2ContainerServiceResponse(BaseResponse):
started_by = self._get_param("startedBy")
service_name = self._get_param("serviceName")
desiredStatus = self._get_param("desiredStatus")
task_arns = self.ecs_backend.list_tasks(
tasks = self.ecs_backend.list_tasks(
cluster_str,
container_instance,
family,
@ -239,14 +242,14 @@ class EC2ContainerServiceResponse(BaseResponse):
service_name,
desiredStatus,
)
return json.dumps({"taskArns": task_arns})
return json.dumps({"taskArns": [t.task_arn for t in tasks]})
def stop_task(self) -> str:
cluster_str = self._get_param("cluster", "default")
task = self._get_param("task")
reason = self._get_param("reason")
task = self.ecs_backend.stop_task(cluster_str, task, reason)
return json.dumps({"task": task.response_object})
return json.dumps({"task": task.response_object()})
def create_service(self) -> str:
cluster_str = self._get_param("cluster", "default")

View File

@ -2481,42 +2481,6 @@ def test_describe_tasks_empty_tags():
assert len(response["tasks"]) == 1
@mock_ec2
@mock_ecs
def test_describe_tasks_include_tags():
client = boto3.client("ecs", region_name=ECS_REGION)
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)
task_tags = [{"key": "Name", "value": "test_ecs_task"}]
tasks_arns = [
task["taskArn"]
for task in client.run_task(
cluster="test_ecs_cluster",
overrides={},
taskDefinition="test_ecs_task",
count=2,
startedBy="moto",
tags=task_tags,
)["tasks"]
]
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"]
)
assert len(response["tasks"]) == 2
assert set(
[response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]]
) == set(tasks_arns)
assert response["tasks"][0]["tags"] == task_tags
# Test we can pass task ids instead of ARNs
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]]
)
assert len(response["tasks"]) == 1
@mock_ecs
def test_describe_tasks_exceptions():
client = boto3.client("ecs", region_name=ECS_REGION)

View File

@ -0,0 +1,71 @@
import boto3
from moto import mock_ec2, mock_ecs
from .test_ecs_boto3 import setup_ecs_cluster_with_ec2_instance
@mock_ec2
@mock_ecs
def test_describe_tasks_include_tags():
client = boto3.client("ecs", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)
task_tags = [{"key": "Name", "value": "test_ecs_task"}]
tasks_arns = [
task["taskArn"]
for task in client.run_task(
cluster="test_ecs_cluster",
taskDefinition="test_ecs_task",
count=2,
tags=task_tags,
)["tasks"]
]
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"]
)
assert len(response["tasks"]) == 2
assert set(
[response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]]
) == set(tasks_arns)
assert response["tasks"][0]["tags"] == task_tags
# Test we can pass task ids instead of ARNs
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]]
)
assert len(response["tasks"]) == 1
tags = client.list_tags_for_resource(resourceArn=tasks_arns[0])["tags"]
assert tags == task_tags
@mock_ec2
@mock_ecs
def test_add_tags_to_task():
client = boto3.client("ecs", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)
task_tags = [{"key": "k1", "value": "v1"}]
task_arn = client.run_task(
cluster="test_ecs_cluster",
taskDefinition="test_ecs_task",
count=1,
tags=task_tags,
)["tasks"][0]["taskArn"]
client.tag_resource(resourceArn=task_arn, tags=[{"key": "k2", "value": "v2"}])
tags = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[task_arn], include=["TAGS"]
)["tasks"][0]["tags"]
assert len(tags) == 2
assert {"key": "k1", "value": "v1"} in tags
assert {"key": "k2", "value": "v2"} in tags
client.untag_resource(resourceArn=task_arn, tagKeys=["k2"])
resp = client.list_tags_for_resource(resourceArn=task_arn)
assert resp["tags"] == [{"key": "k1", "value": "v1"}]