From cf0bcbce919310362f579046a2da2d543b616997 Mon Sep 17 00:00:00 2001 From: shanishiri Date: Wed, 11 Jan 2023 21:28:07 +0200 Subject: [PATCH] Support tags in ECS start_task and describe_tasks (#5817) --- moto/ecs/models.py | 12 +- moto/ecs/responses.py | 11 +- tests/test_ecs/test_ecs_boto3.py | 217 ++++++++++++++++++++++--------- 3 files changed, 172 insertions(+), 68 deletions(-) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index c94887ad4..7374ce403 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -1143,6 +1143,7 @@ class EC2ContainerServiceBackend(BaseBackend): container_instances, overrides, started_by, + tags=None, ): cluster = self._get_cluster(cluster_str) @@ -1169,6 +1170,7 @@ class EC2ContainerServiceBackend(BaseBackend): backend=self, overrides=overrides or {}, started_by=started_by or "", + tags=tags, ) tasks.append(task) self.update_container_instance_resources( @@ -1177,7 +1179,10 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks[cluster.name][task.task_arn] = task return tasks - def describe_tasks(self, cluster_str, tasks): + def describe_tasks(self, cluster_str, tasks, include=None): + """ + Only include=TAGS is currently supported. + """ self._get_cluster(cluster_str) if not tasks: @@ -1192,6 +1197,11 @@ class EC2ContainerServiceBackend(BaseBackend): or any(task_id in task for task in tasks) ): response.append(task) + if "TAGS" in (include or []): + return response + + for task in response: + task.tags = [] return response def list_tasks( diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index 40c1f86e3..5a4c38988 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -163,7 +163,8 @@ class EC2ContainerServiceResponse(BaseResponse): def describe_tasks(self): cluster = self._get_param("cluster", "default") tasks = self._get_param("tasks") - data = self.ecs_backend.describe_tasks(cluster, tasks) + include = self._get_param("include") + data = self.ecs_backend.describe_tasks(cluster, tasks, include) return json.dumps( {"tasks": [task.response_object for task in data], "failures": []} ) @@ -174,8 +175,14 @@ class EC2ContainerServiceResponse(BaseResponse): task_definition_str = self._get_param("taskDefinition") container_instances = self._get_param("containerInstances") started_by = self._get_param("startedBy") + tags = self._get_param("tags") tasks = self.ecs_backend.start_task( - cluster_str, task_definition_str, container_instances, overrides, started_by + cluster_str, + task_definition_str, + container_instances, + overrides, + started_by, + tags, ) return json.dumps( {"tasks": [task.response_object for task in tasks], "failures": []} diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index dc7127112..161212b01 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -1889,46 +1889,14 @@ def test_run_task_exceptions(): @mock_ecs def test_start_task(): client = boto3.client("ecs", region_name="us-east-1") - ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = "test_ecs_cluster" - - _ = client.create_cluster(clusterName=test_cluster_name) - - test_instance = ec2.create_instances( - ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1 - )[0] - - instance_id_document = json.dumps( - ec2_utils.generate_instance_identity_document(test_instance) - ) - - client.register_container_instance( - cluster=test_cluster_name, instanceIdentityDocument=instance_id_document - ) + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) container_instances = client.list_container_instances(cluster=test_cluster_name) container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ -1 ] - _ = client.register_task_definition( - family="test_ecs_task", - containerDefinitions=[ - { - "name": "hello_world", - "image": "docker/hello-world:latest", - "cpu": 1024, - "memory": 400, - "essential": True, - "environment": [ - {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} - ], - "logConfiguration": {"logDriver": "json-file"}, - } - ], - ) - response = client.start_task( cluster="test_ecs_cluster", taskDefinition="test_ecs_task", @@ -1950,6 +1918,54 @@ def test_start_task(): response["tasks"][0]["containerInstanceArn"].should.equal( f"arn:aws:ecs:us-east-1:{ACCOUNT_ID}:container-instance/test_ecs_cluster/{container_instance_id}" ) + response["tasks"][0]["tags"].should.equal( + [], + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") + + +@mock_ec2 +@mock_ecs +def test_start_task_with_tags(): + client = boto3.client("ecs", region_name="us-east-1") + test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) + + container_instances = client.list_container_instances(cluster=test_cluster_name) + container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ + -1 + ] + + task_tags = [{"key": "Name", "value": "test_ecs_start_task"}] + response = client.start_task( + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", + overrides={}, + containerInstances=[container_instance_id], + startedBy="moto", + tags=task_tags, + ) + + len(response["tasks"]).should.equal(1) + response["tasks"][0]["taskArn"].should.contain( + f"arn:aws:ecs:us-east-1:{ACCOUNT_ID}:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + f"arn:aws:ecs:us-east-1:{ACCOUNT_ID}:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + f"arn:aws:ecs:us-east-1:{ACCOUNT_ID}:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.equal( + f"arn:aws:ecs:us-east-1:{ACCOUNT_ID}:container-instance/test_ecs_cluster/{container_instance_id}" + ) + response["tasks"][0]["tags"].should.equal( + task_tags, + ) response["tasks"][0]["overrides"].should.equal({}) response["tasks"][0]["lastStatus"].should.equal("RUNNING") response["tasks"][0]["desiredStatus"].should.equal("RUNNING") @@ -2056,40 +2072,9 @@ def test_list_tasks_exceptions(): @mock_ecs def test_describe_tasks(): client = boto3.client("ecs", region_name="us-east-1") - ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) - _ = client.create_cluster(clusterName=test_cluster_name) - - test_instance = ec2.create_instances( - ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1 - )[0] - - instance_id_document = json.dumps( - ec2_utils.generate_instance_identity_document(test_instance) - ) - - client.register_container_instance( - cluster=test_cluster_name, instanceIdentityDocument=instance_id_document - ) - - _ = client.register_task_definition( - family="test_ecs_task", - containerDefinitions=[ - { - "name": "hello_world", - "image": "docker/hello-world:latest", - "cpu": 1024, - "memory": 400, - "essential": True, - "environment": [ - {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} - ], - "logConfiguration": {"logDriver": "json-file"}, - } - ], - ) tasks_arns = [ task["taskArn"] for task in client.run_task( @@ -2114,6 +2099,76 @@ def test_describe_tasks(): len(response["tasks"]).should.equal(1) +@mock_ec2 +@mock_ecs +def test_describe_tasks_empty_tags(): + client = boto3.client("ecs", region_name="us-east-1") + test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) + + tasks_arns = [ + task["taskArn"] + for task in client.run_task( + cluster="test_ecs_cluster", + overrides={}, + taskDefinition="test_ecs_task", + count=2, + startedBy="moto", + )["tasks"] + ] + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"] + ) + + len(response["tasks"]).should.equal(2) + set( + [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] + ).should.equal(set(tasks_arns)) + response["tasks"][0]["tags"].should.equal([]) + + # Test we can pass task ids instead of ARNs + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] + ) + len(response["tasks"]).should.equal(1) + + +@mock_ec2 +@mock_ecs +def test_describe_tasks_include_tags(): + client = boto3.client("ecs", region_name="us-east-1") + test_cluster_name = "test_ecs_cluster" + setup_ecs_cluster_with_ec2_instance(client, test_cluster_name) + + task_tags = [{"key": "Name", "value": "test_ecs_task"}] + tasks_arns = [ + task["taskArn"] + for task in client.run_task( + cluster="test_ecs_cluster", + overrides={}, + taskDefinition="test_ecs_task", + count=2, + startedBy="moto", + tags=task_tags, + )["tasks"] + ] + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"] + ) + + len(response["tasks"]).should.equal(2) + set( + [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] + ).should.equal(set(tasks_arns)) + response["tasks"][0]["tags"].should.equal(task_tags) + + # Test we can pass task ids instead of ARNs + response = client.describe_tasks( + cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] + ) + len(response["tasks"]).should.equal(1) + + @mock_ecs def test_describe_tasks_exceptions(): client = boto3.client("ecs", region_name="us-east-1") @@ -3721,3 +3776,35 @@ def setup_ecs(client, ec2): ) return subnet, sg + + +def setup_ecs_cluster_with_ec2_instance(client, test_cluster_name): + ec2 = boto3.resource("ec2", region_name="us-east-1") + + _ = client.create_cluster(clusterName=test_cluster_name) + test_instance = ec2.create_instances( + ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1 + )[0] + instance_id_document = json.dumps( + ec2_utils.generate_instance_identity_document(test_instance) + ) + client.register_container_instance( + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) + + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + )