Implement state transitions for ECS tasks (#6857)

This commit is contained in:
Edgar Ramírez Mondragón 2023-09-27 08:56:34 -06:00 committed by GitHub
parent fea098310a
commit 39b9c2f121
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 185 additions and 2 deletions

View File

@ -134,3 +134,16 @@ Transition type: Manual - describe the resource 1 time before the state advances
Advancement:
Call `boto3.client("transcribe").get_medical_transcription_job(..)`
Service: ECS
--------------
**Model**: `ecs::task` :raw-html:`<br />`
Available states:
"RUNNING" --> "DEACTIVATING" --> "STOPPING" --> "DEPROVISIONING" --> "STOPPED"
Transition type: Manual - describe the resource 1 time before the state advances :raw-html:`<br />`
Advancement:
Call `boto3.client("ecs").describe_tasks(..)`

View File

@ -11,7 +11,10 @@ from moto.core.utils import unix_time, pascal_to_camelcase, remap_nested_keys
from ..ec2.utils import random_private_ip
from moto.ec2 import ec2_backends
from moto.moto_api import state_manager
from moto.moto_api._internal import mock_random
from moto.moto_api._internal.managed_state_model import ManagedState
from .exceptions import (
EcsClientException,
ServiceNotFoundException,
@ -334,7 +337,7 @@ class TaskDefinition(BaseObject, CloudFormationModel):
return original_resource
class Task(BaseObject):
class Task(BaseObject, ManagedState):
def __init__(
self,
cluster: Cluster,
@ -348,11 +351,28 @@ class Task(BaseObject):
tags: Optional[List[Dict[str, str]]] = None,
networking_configuration: Optional[Dict[str, Any]] = None,
):
# Configure ManagedState
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-lifecycle.html
super().__init__(
model_name="ecs::task",
transitions=[
# We start in RUNNING state in order not to break existing tests.
# ("PROVISIONING", "PENDING"),
# ("PENDING", "ACTIVATING"),
# ("ACTIVATING", "RUNNING"),
("RUNNING", "DEACTIVATING"),
("DEACTIVATING", "STOPPING"),
("STOPPING", "DEPROVISIONING"),
("DEPROVISIONING", "STOPPED"),
# There seems to be race condition, where the waiter expects the task to be in
# STOPPED state, but it is already in DELETED state.
# ("STOPPED", "DELETED"),
],
)
self.id = str(mock_random.uuid4())
self.cluster_name = cluster.name
self.cluster_arn = cluster.arn
self.container_instance_arn = container_instance_arn
self.last_status = "RUNNING"
self.desired_status = "RUNNING"
self.task_definition_arn = task_definition.arn
self.overrides = overrides or {}
@ -401,6 +421,14 @@ class Task(BaseObject):
}
)
@property
def last_status(self) -> Optional[str]:
return self.status # managed state
@last_status.setter
def last_status(self, value: Optional[str]) -> None:
self.status = value
@property
def task_arn(self) -> str:
if self._backend.enable_long_arn_for_name(name="taskLongArnFormat"):
@ -411,6 +439,7 @@ class Task(BaseObject):
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["taskArn"] = self.task_arn
response_object["lastStatus"] = self.last_status
return response_object
@ -929,6 +958,11 @@ class EC2ContainerServiceBackend(BaseBackend):
self.services: Dict[str, Service] = {}
self.container_instances: Dict[str, Dict[str, ContainerInstance]] = {}
state_manager.register_default_transition(
model_name="ecs::task",
transition={"progression": "manual", "times": 1},
)
@staticmethod
def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service."""
@ -1438,6 +1472,7 @@ class EC2ContainerServiceBackend(BaseBackend):
or task.task_arn in tasks
or any(task_id in task for task in tasks)
):
task.advance()
response.append(task)
if "TAGS" in (include or []):
return response

View File

@ -11,6 +11,7 @@ from unittest import mock, SkipTest
from uuid import UUID
from moto import mock_ecs, mock_ec2, settings
from moto.moto_api import state_manager
from tests import EXAMPLE_AMI_ID
@ -1845,6 +1846,140 @@ def test_run_task():
assert task["tags"][0].get("value") == "tagValue0"
@mock_ec2
@mock_ecs
def test_wait_tasks_stopped():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't set transition directly in ServerMode")
state_manager.set_transition(
model_name="ecs::task",
transition={"progression": "immediate"},
)
client = boto3.client("ecs", region_name="us-east-1")
ec2 = boto3.resource("ec2", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
client.create_cluster(clusterName=test_cluster_name)
test_instance = ec2.create_instances(
ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1
)[0]
instance_id_document = json.dumps(
ec2_utils.generate_instance_identity_document(test_instance)
)
response = client.register_container_instance(
cluster=test_cluster_name,
instanceIdentityDocument=instance_id_document,
)
client.register_task_definition(
family="test_ecs_task",
containerDefinitions=[
{
"name": "hello_world",
"image": "docker/hello-world:latest",
"cpu": 1024,
"memory": 400,
"essential": True,
"environment": [
{"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"}
],
"logConfiguration": {"logDriver": "json-file"},
}
],
)
response = client.run_task(
cluster="test_ecs_cluster",
overrides={},
taskDefinition="test_ecs_task",
startedBy="moto",
)
task_arn = response["tasks"][0]["taskArn"]
assert len(response["tasks"]) == 1
client.get_waiter("tasks_stopped").wait(
cluster="test_ecs_cluster",
tasks=[task_arn],
)
response = client.describe_tasks(cluster="test_ecs_cluster", tasks=[task_arn])
assert response["tasks"][0]["lastStatus"] == "STOPPED"
state_manager.unset_transition("ecs::task")
@mock_ec2
@mock_ecs
def test_task_state_transitions():
if settings.TEST_SERVER_MODE:
raise SkipTest("Can't set transition directly in ServerMode")
state_manager.set_transition(
model_name="ecs::task",
transition={"progression": "manual", "times": 1},
)
client = boto3.client("ecs", region_name="us-east-1")
ec2 = boto3.resource("ec2", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
client.create_cluster(clusterName=test_cluster_name)
test_instance = ec2.create_instances(
ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1
)[0]
instance_id_document = json.dumps(
ec2_utils.generate_instance_identity_document(test_instance)
)
response = client.register_container_instance(
cluster=test_cluster_name,
instanceIdentityDocument=instance_id_document,
)
client.register_task_definition(
family="test_ecs_task",
containerDefinitions=[
{
"name": "hello_world",
"image": "docker/hello-world:latest",
"cpu": 1024,
"memory": 400,
"essential": True,
"environment": [
{"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"}
],
"logConfiguration": {"logDriver": "json-file"},
}
],
)
response = client.run_task(
cluster="test_ecs_cluster",
overrides={},
taskDefinition="test_ecs_task",
startedBy="moto",
)
task_arn = response["tasks"][0]["taskArn"]
assert len(response["tasks"]) == 1
task_status = response["tasks"][0]["lastStatus"]
assert task_status == "RUNNING"
for status in ("DEACTIVATING", "STOPPING", "DEPROVISIONING", "STOPPED"):
response = client.describe_tasks(cluster="test_ecs_cluster", tasks=[task_arn])
assert response["tasks"][0]["lastStatus"] == status
state_manager.unset_transition("ecs::task")
@mock_ec2
@mock_ecs
def test_run_task_awsvpc_network():