diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 850516ae4..863275ab4 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -129,6 +129,8 @@ class TaskDefinition(BaseObject, CloudFormationModel): requires_compatibilities=None, cpu=None, memory=None, + task_role_arn=None, + execution_role_arn=None, ): self.family = family self.revision = revision @@ -169,6 +171,11 @@ class TaskDefinition(BaseObject, CloudFormationModel): else: self.network_mode = network_mode + if task_role_arn is not None: + self.task_role_arn = task_role_arn + if execution_role_arn is not None: + self.execution_role_arn = execution_role_arn + self.placement_constraints = ( placement_constraints if placement_constraints is not None else [] ) @@ -737,6 +744,8 @@ class EC2ContainerServiceBackend(BaseBackend): requires_compatibilities=None, cpu=None, memory=None, + task_role_arn=None, + execution_role_arn=None, ): if family in self.task_definitions: last_id = self._get_last_task_definition_revision_id(family) @@ -756,6 +765,8 @@ class EC2ContainerServiceBackend(BaseBackend): requires_compatibilities=requires_compatibilities, cpu=cpu, memory=memory, + task_role_arn=task_role_arn, + execution_role_arn=execution_role_arn, ) self.task_definitions[family][revision] = task_definition diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index 03b8b2618..3fed25ecd 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -67,6 +67,9 @@ class EC2ContainerServiceResponse(BaseResponse): requires_compatibilities = self._get_param("requiresCompatibilities") cpu = self._get_param("cpu") memory = self._get_param("memory") + task_role_arn = self._get_param("taskRoleArn") + execution_role_arn = self._get_param("executionRoleArn") + task_definition = self.ecs_backend.register_task_definition( family, container_definitions, @@ -77,6 +80,8 @@ class EC2ContainerServiceResponse(BaseResponse): requires_compatibilities=requires_compatibilities, cpu=cpu, memory=memory, + task_role_arn=task_role_arn, + execution_role_arn=execution_role_arn, ) return json.dumps({"taskDefinition": task_definition.response_object}) diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index 47a5b00bd..9d130ba83 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -151,10 +151,16 @@ def test_register_task_definition(): # Registering with optional top-level params definition["requiresCompatibilities"] = ["FARGATE"] + definition["taskRoleArn"] = "my-custom-task-role-arn" + definition["executionRoleArn"] = "my-custom-execution-role-arn" response = client.register_task_definition(**definition) response["taskDefinition"]["requiresCompatibilities"].should.equal(["FARGATE"]) response["taskDefinition"]["compatibilities"].should.equal(["EC2", "FARGATE"]) response["taskDefinition"]["networkMode"].should.equal("awsvpc") + response["taskDefinition"]["taskRoleArn"].should.equal("my-custom-task-role-arn") + response["taskDefinition"]["executionRoleArn"].should.equal( + "my-custom-execution-role-arn" + ) definition["requiresCompatibilities"] = ["EC2", "FARGATE"] response = client.register_task_definition(**definition) @@ -333,6 +339,8 @@ def test_describe_task_definitions(): ) _ = client.register_task_definition( family="test_ecs_task", + taskRoleArn="my-task-role-arn", + executionRoleArn="my-execution-role-arn", containerDefinitions=[ { "name": "hello_world2", @@ -372,6 +380,8 @@ def test_describe_task_definitions(): response["taskDefinition"]["taskDefinitionArn"].should.equal( "arn:aws:ecs:us-east-1:{}:task-definition/test_ecs_task:2".format(ACCOUNT_ID) ) + response["taskDefinition"]["taskRoleArn"].should.equal("my-task-role-arn") + response["taskDefinition"]["executionRoleArn"].should.equal("my-execution-role-arn") response = client.describe_task_definition( taskDefinition="test_ecs_task:1", include=["TAGS"]