From efb13a17a8465e269cd914f099a941e693d12344 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 19 Feb 2023 10:20:28 -0100 Subject: [PATCH] Techdebt: MyPy ECS (#5944) --- moto/ecs/exceptions.py | 16 +- moto/ecs/models.py | 766 +++++++++++++++++++++++------------------ moto/ecs/responses.py | 122 +++---- moto/settings.py | 2 +- setup.cfg | 2 +- 5 files changed, 496 insertions(+), 412 deletions(-) diff --git a/moto/ecs/exceptions.py b/moto/ecs/exceptions.py index 4f01ba258..ce18fd40f 100644 --- a/moto/ecs/exceptions.py +++ b/moto/ecs/exceptions.py @@ -4,7 +4,7 @@ from moto.core.exceptions import RESTError, JsonRESTError class ServiceNotFoundException(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( error_type="ServiceNotFoundException", message="Service not found." ) @@ -13,7 +13,7 @@ class ServiceNotFoundException(RESTError): class TaskDefinitionNotFoundException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( error_type="ClientException", message="Unable to describe task definition.", @@ -23,14 +23,14 @@ class TaskDefinitionNotFoundException(JsonRESTError): class RevisionNotFoundException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__(error_type="ClientException", message="Revision is missing.") class TaskSetNotFoundException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( error_type="ClientException", message="The specified task set does not exist.", @@ -40,7 +40,7 @@ class TaskSetNotFoundException(JsonRESTError): class ClusterNotFoundException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( error_type="ClusterNotFoundException", message="Cluster not found." ) @@ -49,19 +49,19 @@ class ClusterNotFoundException(JsonRESTError): class EcsClientException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__(error_type="ClientException", message=message) class InvalidParameterException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__(error_type="InvalidParameterException", message=message) class UnknownAccountSettingException(InvalidParameterException): - def __init__(self): + def __init__(self) -> None: super().__init__( "unknown should be one of [serviceLongArnFormat,taskLongArnFormat,containerInstanceLongArnFormat,containerLongArnFormat,awsvpcTrunking,containerInsights,dualStackIPv6]" ) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 6ea027735..1ad967e36 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -1,7 +1,7 @@ import re from copy import copy from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple from moto import settings from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel @@ -24,7 +24,7 @@ from .exceptions import ( class BaseObject(BaseModel): - def camelCase(self, key): + def camelCase(self, key: str) -> str: words = [] for i, word in enumerate(key.split("_")): if i > 0: @@ -33,7 +33,7 @@ class BaseObject(BaseModel): words.append(word) return "".join(words) - def gen_response_object(self): + def gen_response_object(self) -> Dict[str, Any]: response_object = copy(self.__dict__) for key, value in self.__dict__.items(): if key.startswith("_"): @@ -44,12 +44,12 @@ class BaseObject(BaseModel): return response_object @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] return self.gen_response_object() class AccountSetting(BaseObject): - def __init__(self, name, value): + def __init__(self, name: str, value: str): self.name = name self.value = value @@ -57,14 +57,14 @@ class AccountSetting(BaseObject): class Cluster(BaseObject, CloudFormationModel): def __init__( self, - cluster_name, - account_id, - region_name, - cluster_settings=None, - configuration=None, - capacity_providers=None, - default_capacity_provider_strategy=None, - tags=None, + cluster_name: str, + account_id: str, + region_name: str, + cluster_settings: Optional[List[Dict[str, str]]] = None, + configuration: Optional[Dict[str, Any]] = None, + capacity_providers: Optional[List[str]] = None, + default_capacity_provider_strategy: Optional[List[Dict[str, Any]]] = None, + tags: Optional[List[Dict[str, str]]] = None, ): self.active_services_count = 0 self.arn = f"arn:aws:ecs:{region_name}:{account_id}:cluster/{cluster_name}" @@ -81,11 +81,11 @@ class Cluster(BaseObject, CloudFormationModel): self.tags = tags @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.name @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["clusterArn"] = self.arn response_object["clusterName"] = self.name @@ -97,18 +97,23 @@ class Cluster(BaseObject, CloudFormationModel): return response_object @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "ClusterName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ecs-cluster.html return "AWS::ECS::Cluster" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Cluster": ecs_backend = ecs_backends[account_id][region_name] return ecs_backend.create_cluster( # ClusterName is optional in CloudFormation, thus create a random @@ -117,14 +122,14 @@ class Cluster(BaseObject, CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Cluster": if original_resource.name != new_resource_name: ecs_backend = ecs_backends[account_id][region_name] ecs_backend.delete_cluster(original_resource.arn) @@ -138,10 +143,10 @@ class Cluster(BaseObject, CloudFormationModel): return original_resource @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -152,26 +157,26 @@ class Cluster(BaseObject, CloudFormationModel): class TaskDefinition(BaseObject, CloudFormationModel): def __init__( self, - family, - revision, - container_definitions, - account_id, - region_name, - network_mode=None, - volumes=None, - tags=None, - placement_constraints=None, - requires_compatibilities=None, - cpu=None, - memory=None, - task_role_arn=None, - execution_role_arn=None, - proxy_configuration=None, - inference_accelerators=None, - runtime_platform=None, - ipc_mode=None, - pid_mode=None, - ephemeral_storage=None, + family: str, + revision: int, + container_definitions: List[Dict[str, Any]], + account_id: str, + region_name: str, + network_mode: Optional[str] = None, + volumes: Optional[List[Dict[str, Any]]] = None, + tags: Optional[List[Dict[str, str]]] = None, + placement_constraints: Optional[List[Dict[str, str]]] = None, + requires_compatibilities: Optional[List[str]] = None, + cpu: Optional[str] = None, + memory: Optional[str] = None, + task_role_arn: Optional[str] = None, + execution_role_arn: Optional[str] = None, + proxy_configuration: Optional[Dict[str, Any]] = None, + inference_accelerators: Optional[List[Dict[str, str]]] = None, + runtime_platform: Optional[Dict[str, str]] = None, + ipc_mode: Optional[str] = None, + pid_mode: Optional[str] = None, + ephemeral_storage: Optional[Dict[str, int]] = None, ): self.family = family self.revision = revision @@ -210,9 +215,9 @@ class TaskDefinition(BaseObject, CloudFormationModel): self.compatibilities = ["EC2", "FARGATE"] if network_mode is None and "FARGATE" not in self.compatibilities: - self.network_mode = "bridge" + self.network_mode: Optional[str] = "bridge" elif "FARGATE" in self.compatibilities: - self.network_mode = "awsvpc" + self.network_mode: Optional[str] = "awsvpc" # type: ignore[no-redef] else: self.network_mode = network_mode @@ -238,7 +243,7 @@ class TaskDefinition(BaseObject, CloudFormationModel): self.status = "ACTIVE" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["taskDefinitionArn"] = response_object["arn"] del response_object["arn"] @@ -254,22 +259,27 @@ class TaskDefinition(BaseObject, CloudFormationModel): return response_object @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ecs-taskdefinition.html return "AWS::ECS::TaskDefinition" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "TaskDefinition": properties = cloudformation_json["Properties"] family = properties.get( @@ -286,14 +296,14 @@ class TaskDefinition(BaseObject, CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "TaskDefinition": properties = cloudformation_json["Properties"] family = properties.get( "Family", f"task-definition-{int(mock_random.random() * 10**6)}" @@ -322,16 +332,16 @@ class TaskDefinition(BaseObject, CloudFormationModel): class Task(BaseObject): def __init__( self, - cluster, - task_definition, - container_instance_arn, - resource_requirements, - backend, - launch_type="", - overrides=None, - started_by="", - tags=None, - networking_configuration=None, + cluster: Cluster, + task_definition: TaskDefinition, + container_instance_arn: Optional[str], + resource_requirements: Optional[Dict[str, str]], + backend: "EC2ContainerServiceBackend", + launch_type: str = "", + overrides: Optional[Dict[str, Any]] = None, + started_by: str = "", + tags: Optional[List[Dict[str, str]]] = None, + networking_configuration: Optional[Dict[str, Any]] = None, ): self.id = str(mock_random.uuid4()) self.cluster_name = cluster.name @@ -341,7 +351,7 @@ class Task(BaseObject): self.desired_status = "RUNNING" self.task_definition_arn = task_definition.arn self.overrides = overrides or {} - self.containers = [] + self.containers: List[Dict[str, Any]] = [] self.started_by = started_by self.tags = tags or [] self.launch_type = launch_type @@ -354,7 +364,6 @@ class Task(BaseObject): if task_definition.network_mode == "awsvpc": if not networking_configuration: - raise InvalidParameterException( "Network Configuration must be provided when networkMode 'awsvpc' is specified." ) @@ -388,20 +397,27 @@ class Task(BaseObject): ) @property - def task_arn(self): + def task_arn(self) -> str: if self._backend.enable_long_arn_for_name(name="taskLongArnFormat"): return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.cluster_name}/{self.id}" return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.id}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["taskArn"] = self.task_arn return response_object class CapacityProvider(BaseObject): - def __init__(self, account_id, region_name, name, asg_details, tags): + def __init__( + self, + account_id: str, + region_name: str, + name: str, + asg_details: Dict[str, Any], + tags: Optional[List[Dict[str, str]]], + ): self._id = str(mock_random.uuid4()) self.capacity_provider_arn = ( f"arn:aws:ecs:{region_name}:{account_id}:capacity-provider/{name}" @@ -411,9 +427,9 @@ class CapacityProvider(BaseObject): self.auto_scaling_group_provider = self._prepare_asg_provider(asg_details) self.tags = tags - self.update_status = None + self.update_status: Optional[str] = None - def _prepare_asg_provider(self, asg_details): + def _prepare_asg_provider(self, asg_details: Dict[str, Any]) -> Dict[str, Any]: if "managedScaling" not in asg_details: asg_details["managedScaling"] = {} if not asg_details["managedScaling"].get("instanceWarmupPeriod"): @@ -430,7 +446,7 @@ class CapacityProvider(BaseObject): asg_details["managedTerminationProtection"] = "DISABLED" return asg_details - def update(self, asg_details): + def update(self, asg_details: Dict[str, Any]) -> None: if "managedTerminationProtection" in asg_details: self.auto_scaling_group_provider[ "managedTerminationProtection" @@ -455,12 +471,12 @@ class CapacityProvider(BaseObject): class CapacityProviderFailure(BaseObject): - def __init__(self, reason, name, account_id, region_name): + def __init__(self, reason: str, name: str, account_id: str, region_name: str): self.reason = reason self.arn = f"arn:aws:ecs:{region_name}:{account_id}:capacity_provider/{name}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["reason"] = self.reason response_object["arn"] = self.arn @@ -470,32 +486,29 @@ class CapacityProviderFailure(BaseObject): class Service(BaseObject, CloudFormationModel): def __init__( self, - cluster, - service_name, - desired_count, - task_definition=None, - load_balancers=None, - scheduling_strategy=None, - tags=None, - deployment_controller=None, - launch_type=None, - backend=None, - service_registries=None, - platform_version=None, + cluster: Cluster, + service_name: str, + desired_count: int, + backend: "EC2ContainerServiceBackend", + task_definition: Optional[TaskDefinition] = None, + load_balancers: Optional[List[Dict[str, Any]]] = None, + scheduling_strategy: Optional[List[Dict[str, Any]]] = None, + tags: Optional[List[Dict[str, str]]] = None, + deployment_controller: Optional[Dict[str, str]] = None, + launch_type: Optional[str] = None, + service_registries: Optional[List[Dict[str, Any]]] = None, + platform_version: Optional[str] = None, ): self.cluster_name = cluster.name self.cluster_arn = cluster.arn self.name = service_name self.status = "ACTIVE" self.running_count = 0 - if task_definition: - self.task_definition = task_definition.arn - else: - self.task_definition = None + self.task_definition = task_definition.arn if task_definition else None self.desired_count = desired_count - self.task_sets = [] + self.task_sets: List[TaskSet] = [] self.deployment_controller = deployment_controller or {"type": "ECS"} - self.events = [] + self.events: List[Dict[str, Any]] = [] self.launch_type = launch_type self.service_registries = service_registries or [] if self.deployment_controller["type"] == "ECS": @@ -526,17 +539,17 @@ class Service(BaseObject, CloudFormationModel): self._backend = backend @property - def arn(self): + def arn(self) -> str: if self._backend.enable_long_arn_for_name(name="serviceLongArnFormat"): return f"arn:aws:ecs:{self.region_name}:{self._account_id}:service/{self.cluster_name}/{self.name}" return f"arn:aws:ecs:{self.region_name}:{self._account_id}:service/{self.name}" @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() del response_object["name"], response_object["tags"] response_object["serviceName"] = self.name @@ -564,18 +577,23 @@ class Service(BaseObject, CloudFormationModel): return response_object @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "ServiceName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ecs-service.html return "AWS::ECS::Service" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Service": properties = cloudformation_json["Properties"] if isinstance(properties["Cluster"], Cluster): cluster = properties["Cluster"].name @@ -595,14 +613,14 @@ class Service(BaseObject, CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Service": properties = cloudformation_json["Properties"] if isinstance(properties["Cluster"], Cluster): cluster_name = properties["Cluster"].name @@ -635,10 +653,10 @@ class Service(BaseObject, CloudFormationModel): ) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Name"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Name": @@ -647,11 +665,18 @@ class Service(BaseObject, CloudFormationModel): class ContainerInstance(BaseObject): - def __init__(self, ec2_instance_id, account_id, region_name, cluster_name, backend): + def __init__( + self, + ec2_instance_id: str, + account_id: str, + region_name: str, + cluster_name: str, + backend: "EC2ContainerServiceBackend", + ): self.ec2_instance_id = ec2_instance_id self.agent_connected = True self.status = "ACTIVE" - self.registered_resources = [ + self.registered_resources: List[Dict[str, Any]] = [ { "doubleValue": 0.0, "integerValue": 4096, @@ -684,7 +709,7 @@ class ContainerInstance(BaseObject): }, ] self.pending_tasks_count = 0 - self.remaining_resources = [ + self.remaining_resources: List[Dict[str, Any]] = [ { "doubleValue": 0.0, "integerValue": 4096, @@ -740,7 +765,7 @@ class ContainerInstance(BaseObject): self._backend = backend @property - def container_instance_arn(self): + def container_instance_arn(self) -> str: if self._backend.enable_long_arn_for_name( name="containerInstanceLongArnFormat" ): @@ -748,7 +773,7 @@ class ContainerInstance(BaseObject): return f"arn:aws:ecs:{self.region_name}:{self._account_id}:container-instance/{self.id}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["containerInstanceArn"] = self.container_instance_arn response_object["attributes"] = [ @@ -761,7 +786,7 @@ class ContainerInstance(BaseObject): ) return response_object - def _format_attribute(self, name, value): + def _format_attribute(self, name: str, value: Optional[str]) -> Dict[str, str]: formatted_attr = {"name": name} if value is not None: formatted_attr["value"] = value @@ -769,12 +794,14 @@ class ContainerInstance(BaseObject): class ClusterFailure(BaseObject): - def __init__(self, reason, cluster_name, account_id, region_name): + def __init__( + self, reason: str, cluster_name: str, account_id: str, region_name: str + ): self.reason = reason self.arn = f"arn:aws:ecs:{region_name}:{account_id}:cluster/{cluster_name}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["reason"] = self.reason response_object["arn"] = self.arn @@ -782,12 +809,14 @@ class ClusterFailure(BaseObject): class ContainerInstanceFailure(BaseObject): - def __init__(self, reason, container_instance_id, account_id, region_name): + def __init__( + self, reason: str, container_instance_id: str, account_id: str, region_name: str + ): self.reason = reason self.arn = f"arn:aws:ecs:{region_name}:{account_id}:container-instance/{container_instance_id}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() response_object["reason"] = self.reason response_object["arn"] = self.arn @@ -797,21 +826,21 @@ class ContainerInstanceFailure(BaseObject): class TaskSet(BaseObject): def __init__( self, - service, - cluster, - task_definition, - account_id, - region_name, - external_id=None, - network_configuration=None, - load_balancers=None, - service_registries=None, - launch_type=None, - capacity_provider_strategy=None, - platform_version=None, - scale=None, - client_token=None, - tags=None, + service: str, + cluster: str, + task_definition: str, + account_id: str, + region_name: str, + external_id: Optional[str] = None, + network_configuration: Optional[Dict[str, Any]] = None, + load_balancers: Optional[List[Dict[str, Any]]] = None, + service_registries: Optional[List[Dict[str, Any]]] = None, + launch_type: Optional[str] = None, + capacity_provider_strategy: Optional[List[Dict[str, Any]]] = None, + platform_version: Optional[str] = None, + scale: Optional[Dict[str, Any]] = None, + client_token: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, ): self.service = service self.cluster = cluster @@ -841,7 +870,7 @@ class TaskSet(BaseObject): self.task_set_arn = f"arn:aws:ecs:{region_name}:{account_id}:task-set/{cluster_name}/{service_name}/{self.id}" @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() if isinstance(response_object["createdAt"], datetime): response_object["createdAt"] = unix_time( @@ -871,16 +900,16 @@ class EC2ContainerServiceBackend(BaseBackend): def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.account_settings = dict() - self.capacity_providers = dict() + self.account_settings: Dict[str, AccountSetting] = dict() + self.capacity_providers: Dict[str, CapacityProvider] = dict() self.clusters: Dict[str, Cluster] = {} - self.task_definitions = {} - self.tasks = {} + self.task_definitions: Dict[str, Dict[int, TaskDefinition]] = {} + self.tasks: Dict[str, Dict[str, Task]] = {} self.services: Dict[str, Service] = {} - self.container_instances: Dict[str, ContainerInstance] = {} + self.container_instances: Dict[str, Dict[str, ContainerInstance]] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc] """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "ecs" @@ -896,18 +925,23 @@ class EC2ContainerServiceBackend(BaseBackend): return cluster - def create_capacity_provider(self, name, asg_details, tags): + def create_capacity_provider( + self, + name: str, + asg_details: Dict[str, Any], + tags: Optional[List[Dict[str, str]]], + ) -> CapacityProvider: capacity_provider = CapacityProvider( self.account_id, self.region_name, name, asg_details, tags ) self.capacity_providers[name] = capacity_provider return capacity_provider - def describe_task_definition(self, task_definition_str): + def describe_task_definition(self, task_definition_str: str) -> TaskDefinition: task_definition_name = task_definition_str.split("/")[-1] if ":" in task_definition_name: - family, revision = task_definition_name.split(":") - revision = int(revision) + family, rev = task_definition_name.split(":") + revision = int(rev) else: family = task_definition_name revision = self._get_last_task_definition_revision_id(family) @@ -942,7 +976,12 @@ class EC2ContainerServiceBackend(BaseBackend): self.clusters[cluster_name] = cluster return cluster - def update_cluster(self, cluster_name, cluster_settings, configuration) -> Cluster: + def update_cluster( + self, + cluster_name: str, + cluster_settings: Optional[List[Dict[str, str]]], + configuration: Optional[Dict[str, Any]], + ) -> Cluster: """ The serviceConnectDefaults-parameter is not yet implemented """ @@ -954,8 +993,11 @@ class EC2ContainerServiceBackend(BaseBackend): return cluster def put_cluster_capacity_providers( - self, cluster_name, capacity_providers, default_capacity_provider_strategy - ): + self, + cluster_name: str, + capacity_providers: Optional[List[str]], + default_capacity_provider_strategy: Optional[List[Dict[str, Any]]], + ) -> Cluster: cluster = self._get_cluster(cluster_name) if capacity_providers is not None: cluster.capacity_providers = capacity_providers @@ -965,15 +1007,18 @@ class EC2ContainerServiceBackend(BaseBackend): ) return cluster - def _get_provider(self, name_or_arn) -> CapacityProvider: + def _get_provider(self, name_or_arn: str) -> Optional[CapacityProvider]: for provider in self.capacity_providers.values(): if ( provider.name == name_or_arn or provider.capacity_provider_arn == name_or_arn ): return provider + return None - def describe_capacity_providers(self, names): + def describe_capacity_providers( + self, names: List[str] + ) -> Tuple[List[CapacityProvider], List[CapacityProviderFailure]]: providers = [] failures = [] for name in names: @@ -988,23 +1033,29 @@ class EC2ContainerServiceBackend(BaseBackend): ) return providers, failures - def delete_capacity_provider(self, name_or_arn): - provider = self._get_provider(name_or_arn) + def delete_capacity_provider(self, name_or_arn: str) -> CapacityProvider: + provider: CapacityProvider = self._get_provider(name_or_arn) # type: ignore[assignment] self.capacity_providers.pop(provider.name) - return provider + return provider # type: ignore[return-value] - def update_capacity_provider(self, name_or_arn, asg_provider) -> CapacityProvider: - provider = self._get_provider(name_or_arn) + def update_capacity_provider( + self, name_or_arn: str, asg_provider: Dict[str, Any] + ) -> CapacityProvider: + provider: CapacityProvider = self._get_provider(name_or_arn) # type: ignore[assignment] provider.update(asg_provider) return provider - def list_clusters(self): + def list_clusters(self) -> List[str]: """ maxSize and pagination not implemented """ return [cluster.arn for cluster in self.clusters.values()] - def describe_clusters(self, list_clusters_name=None, include=None): + def describe_clusters( + self, + list_clusters_name: Optional[List[str]] = None, + include: Optional[List[str]] = None, + ) -> Tuple[List[Dict[str, Any]], List[ClusterFailure]]: """ Only include=TAGS is currently supported. """ @@ -1014,8 +1065,8 @@ class EC2ContainerServiceBackend(BaseBackend): if "default" in self.clusters: list_clusters.append(self.clusters["default"].response_object) else: - for cluster in list_clusters_name: - cluster_name = cluster.split("/")[-1] + for cluster_name in list_clusters_name: + cluster_name = cluster_name.split("/")[-1] if cluster_name in self.clusters: list_clusters.append(self.clusters[cluster_name].response_object) else: @@ -1043,24 +1094,24 @@ class EC2ContainerServiceBackend(BaseBackend): def register_task_definition( self, - family, - container_definitions, - volumes=None, - network_mode=None, - tags=None, - placement_constraints=None, - requires_compatibilities=None, - cpu=None, - memory=None, - task_role_arn=None, - execution_role_arn=None, - proxy_configuration=None, - inference_accelerators=None, - runtime_platform=None, - ipc_mode=None, - pid_mode=None, - ephemeral_storage=None, - ): + family: str, + container_definitions: List[Dict[str, Any]], + volumes: Optional[List[Dict[str, Any]]] = None, + network_mode: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + placement_constraints: Optional[List[Dict[str, str]]] = None, + requires_compatibilities: Optional[List[str]] = None, + cpu: Optional[str] = None, + memory: Optional[str] = None, + task_role_arn: Optional[str] = None, + execution_role_arn: Optional[str] = None, + proxy_configuration: Optional[Dict[str, Any]] = None, + inference_accelerators: Optional[List[Dict[str, str]]] = None, + runtime_platform: Optional[Dict[str, str]] = None, + ipc_mode: Optional[str] = None, + pid_mode: Optional[str] = None, + ephemeral_storage: Optional[Dict[str, int]] = None, + ) -> TaskDefinition: if family in self.task_definitions: last_id = self._get_last_task_definition_revision_id(family) revision = (last_id or 0) + 1 @@ -1093,7 +1144,7 @@ class EC2ContainerServiceBackend(BaseBackend): return task_definition - def list_task_definitions(self, family_prefix): + def list_task_definitions(self, family_prefix: str) -> List[str]: task_arns = [] for task_definition_list in self.task_definitions.values(): task_arns.extend( @@ -1105,18 +1156,16 @@ class EC2ContainerServiceBackend(BaseBackend): ) return task_arns - def deregister_task_definition(self, task_definition_str): + def deregister_task_definition(self, task_definition_str: str) -> TaskDefinition: task_definition_name = task_definition_str.split("/")[-1] try: - family, revision = task_definition_name.split(":") + family, rev = task_definition_name.split(":") except ValueError: raise RevisionNotFoundException try: - revision = int(revision) + revision = int(rev) except ValueError: - raise InvalidParameterException( - "Invalid revision number. Number: " + revision - ) + raise InvalidParameterException("Invalid revision number. Number: " + rev) if ( family in self.task_definitions and revision in self.task_definitions[family] @@ -1131,15 +1180,15 @@ class EC2ContainerServiceBackend(BaseBackend): def run_task( self, - cluster_str, - task_definition_str, - count, - overrides, - started_by, - tags, - launch_type, - networking_configuration=None, - ): + cluster_str: str, + task_definition_str: str, + count: int, + overrides: Optional[Dict[str, Any]], + started_by: str, + tags: Optional[List[Dict[str, str]]], + launch_type: Optional[str], + networking_configuration: Optional[Dict[str, Any]] = None, + ) -> List[Task]: cluster = self._get_cluster(cluster_str) task_definition = self.describe_task_definition(task_definition_str) @@ -1152,10 +1201,10 @@ class EC2ContainerServiceBackend(BaseBackend): if launch_type == "FARGATE": for _ in range(count): task = Task( - cluster, - task_definition, - None, - resource_requirements, + cluster=cluster, + task_definition=task_definition, + container_instance_arn=None, + resource_requirements=resource_requirements, backend=self, overrides=overrides or {}, started_by=started_by or "", @@ -1179,10 +1228,8 @@ class EC2ContainerServiceBackend(BaseBackend): ] # TODO: return event about unable to place task if not able to place enough tasks to meet count placed_count = 0 - for container_instance in active_container_instances: - container_instance = self.container_instances[cluster.name][ - container_instance - ] + for name in active_container_instances: + container_instance = self.container_instances[cluster.name][name] container_instance_arn = container_instance.container_instance_arn try_to_place = True while try_to_place: @@ -1215,8 +1262,13 @@ class EC2ContainerServiceBackend(BaseBackend): return tasks @staticmethod - def _calculate_task_resource_requirements(task_definition): - resource_requirements = {"CPU": 0, "MEMORY": 0, "PORTS": [], "PORTS_UDP": []} + def _calculate_task_resource_requirements(task_definition: TaskDefinition) -> Dict[str, Any]: # type: ignore[misc] + resource_requirements: Dict[str, Any] = { + "CPU": 0, + "MEMORY": 0, + "PORTS": [], + "PORTS_UDP": [], + } for container_definition in task_definition.container_definitions: # cloudformation uses capitalized properties, while boto uses all lower case @@ -1243,7 +1295,7 @@ class EC2ContainerServiceBackend(BaseBackend): if "PortMappings" in container_definition else "portMappings" ) - for port_mapping in container_definition.get(port_mapping_key, []): + for port_mapping in container_definition.get(port_mapping_key, []): # type: ignore[attr-defined] if "hostPort" in port_mapping: resource_requirements["PORTS"].append(port_mapping.get("hostPort")) elif "HostPort" in port_mapping: @@ -1252,7 +1304,7 @@ class EC2ContainerServiceBackend(BaseBackend): return resource_requirements @staticmethod - def _can_be_placed(container_instance, task_resource_requirements): + def _can_be_placed(container_instance: ContainerInstance, task_resource_requirements: Dict[str, Any]) -> bool: # type: ignore[misc] """ :param container_instance: The container instance trying to be placed onto @@ -1264,33 +1316,33 @@ class EC2ContainerServiceBackend(BaseBackend): # docs.aws.amazon.com/AmazonECS/latest/developerguide/task-placement.html remaining_cpu = 0 remaining_memory = 0 - reserved_ports = [] + reserved_ports: List[str] = [] for resource in container_instance.remaining_resources: if resource.get("name") == "CPU": - remaining_cpu = resource.get("integerValue") + remaining_cpu = resource.get("integerValue") # type: ignore[assignment] elif resource.get("name") == "MEMORY": - remaining_memory = resource.get("integerValue") + remaining_memory = resource.get("integerValue") # type: ignore[assignment] elif resource.get("name") == "PORTS": - reserved_ports = resource.get("stringSetValue") - if task_resource_requirements.get("CPU") > remaining_cpu: + reserved_ports = resource.get("stringSetValue") # type: ignore[assignment] + if task_resource_requirements.get("CPU") > remaining_cpu: # type: ignore[operator] return False - if task_resource_requirements.get("MEMORY") > remaining_memory: + if task_resource_requirements.get("MEMORY") > remaining_memory: # type: ignore[operator] return False ports_needed = task_resource_requirements.get("PORTS") - for port in ports_needed: + for port in ports_needed: # type: ignore[union-attr] if str(port) in reserved_ports: return False return True def start_task( self, - cluster_str, - task_definition_str, - container_instances, - overrides, - started_by, - tags=None, - ): + cluster_str: str, + task_definition_str: str, + container_instances: List[str], + overrides: Dict[str, Any], + started_by: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> List[Task]: cluster = self._get_cluster(cluster_str) task_definition = self.describe_task_definition(task_definition_str) @@ -1325,7 +1377,12 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks[cluster.name][task.task_arn] = task return tasks - def describe_tasks(self, cluster_str, tasks, include=None): + def describe_tasks( + self, + cluster_str: str, + tasks: Optional[str], + include: Optional[List[str]] = None, + ) -> List[Task]: """ Only include=TAGS is currently supported. """ @@ -1352,15 +1409,15 @@ class EC2ContainerServiceBackend(BaseBackend): def list_tasks( self, - cluster_str, - container_instance, - family, - started_by, - service_name, - desiredStatus, - ): + cluster_str: str, + container_instance: Optional[str], + family: str, + started_by: str, + service_name: str, + desiredStatus: str, + ) -> List[str]: filtered_tasks = [] - for cluster, tasks in self.tasks.items(): + for tasks in self.tasks.values(): for task in tasks.values(): filtered_tasks.append(task) if cluster_str: @@ -1373,7 +1430,7 @@ class EC2ContainerServiceBackend(BaseBackend): if container_instance: filtered_tasks = list( filter( - lambda t: container_instance in t.container_instance_arn, + lambda t: container_instance in t.container_instance_arn, # type: ignore filtered_tasks, ) ) @@ -1404,7 +1461,7 @@ class EC2ContainerServiceBackend(BaseBackend): return [t.task_arn for t in filtered_tasks] - def stop_task(self, cluster_str, task_str, reason): + def stop_task(self, cluster_str: str, task_str: str, reason: str) -> Task: cluster = self._get_cluster(cluster_str) task_id = task_str.split("/")[-1] @@ -1420,7 +1477,7 @@ class EC2ContainerServiceBackend(BaseBackend): ] self.update_container_instance_resources( container_instance, - tasks[task].resource_requirements, + tasks[task].resource_requirements, # type: ignore[arg-type] removing=True, ) tasks[task].last_status = "STOPPED" @@ -1429,7 +1486,7 @@ class EC2ContainerServiceBackend(BaseBackend): return tasks[task] raise Exception(f"Could not find task {task_str} on cluster {cluster.name}") - def _get_service(self, cluster_str, service_str): + def _get_service(self, cluster_str: str, service_str: str) -> Service: cluster = self._get_cluster(cluster_str) for service in self.services.values(): if service.cluster_name == cluster.name and ( @@ -1440,18 +1497,18 @@ class EC2ContainerServiceBackend(BaseBackend): def create_service( self, - cluster_str, - service_name, - desired_count, - task_definition_str=None, - load_balancers=None, - scheduling_strategy=None, - tags=None, - deployment_controller=None, - launch_type=None, - service_registries=None, - platform_version=None, - ): + cluster_str: str, + service_name: str, + desired_count: int, + task_definition_str: Optional[str] = None, + load_balancers: Optional[List[Dict[str, Any]]] = None, + scheduling_strategy: Optional[List[Dict[str, Any]]] = None, + tags: Optional[List[Dict[str, str]]] = None, + deployment_controller: Optional[Dict[str, str]] = None, + launch_type: Optional[str] = None, + service_registries: Optional[List[Dict[str, Any]]] = None, + platform_version: Optional[str] = None, + ) -> Service: cluster = self._get_cluster(cluster_str) if task_definition_str: @@ -1465,15 +1522,15 @@ class EC2ContainerServiceBackend(BaseBackend): raise EcsClientException("launch type should be one of [EC2,FARGATE]") service = Service( - cluster, - service_name, - desired_count, - task_definition, - load_balancers, - scheduling_strategy, - tags, - deployment_controller, - launch_type, + cluster=cluster, + service_name=service_name, + desired_count=desired_count, + task_definition=task_definition, + load_balancers=load_balancers, + scheduling_strategy=scheduling_strategy, + tags=tags, + deployment_controller=deployment_controller, + launch_type=launch_type, backend=self, service_registries=service_registries, platform_version=platform_version, @@ -1483,7 +1540,12 @@ class EC2ContainerServiceBackend(BaseBackend): return service - def list_services(self, cluster_str, scheduling_strategy=None, launch_type=None): + def list_services( + self, + cluster_str: str, + scheduling_strategy: Optional[str] = None, + launch_type: Optional[str] = None, + ) -> List[str]: cluster = self._get_cluster(cluster_str) service_arns = [] for key, service in self.services.items(): @@ -1503,7 +1565,9 @@ class EC2ContainerServiceBackend(BaseBackend): return sorted(service_arns) - def describe_services(self, cluster_str, service_names_or_arns): + def describe_services( + self, cluster_str: str, service_names_or_arns: List[str] + ) -> Tuple[List[Service], List[Dict[str, str]]]: cluster = self._get_cluster(cluster_str) result = [] @@ -1523,8 +1587,12 @@ class EC2ContainerServiceBackend(BaseBackend): return result, failures def update_service( - self, cluster_str, service_str, task_definition_str, desired_count - ): + self, + cluster_str: str, + service_str: str, + task_definition_str: str, + desired_count: Optional[int], + ) -> Service: cluster = self._get_cluster(cluster_str) service_name = service_str.split("/")[-1] @@ -1541,7 +1609,9 @@ class EC2ContainerServiceBackend(BaseBackend): else: raise ServiceNotFoundException - def delete_service(self, cluster_name, service_name, force): + def delete_service( + self, cluster_name: str, service_name: str, force: bool + ) -> Service: cluster = self._get_cluster(cluster_name) service = self._get_service(cluster_name, service_name) @@ -1559,7 +1629,9 @@ class EC2ContainerServiceBackend(BaseBackend): service.status = "INACTIVE" return service - def register_container_instance(self, cluster_str, ec2_instance_id): + def register_container_instance( + self, cluster_str: str, ec2_instance_id: str + ) -> ContainerInstance: cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception(f"{cluster_name} is not a cluster") @@ -1579,7 +1651,7 @@ class EC2ContainerServiceBackend(BaseBackend): self.clusters[cluster_name].registered_container_instances_count += 1 return container_instance - def list_container_instances(self, cluster_str): + def list_container_instances(self, cluster_str: str) -> List[str]: cluster_name = cluster_str.split("/")[-1] container_instances_values = self.container_instances.get( cluster_name, {} @@ -1589,7 +1661,9 @@ class EC2ContainerServiceBackend(BaseBackend): ] return sorted(container_instances) - def describe_container_instances(self, cluster_str, list_container_instance_ids): + def describe_container_instances( + self, cluster_str: str, list_container_instance_ids: List[str] + ) -> Tuple[List[ContainerInstance], List[ContainerInstanceFailure]]: cluster = self._get_cluster(cluster_str) if not list_container_instance_ids: @@ -1616,8 +1690,8 @@ class EC2ContainerServiceBackend(BaseBackend): return container_instance_objects, failures def update_container_instances_state( - self, cluster_str, list_container_instance_ids, status - ): + self, cluster_str: str, list_container_instance_ids: List[str], status: str + ) -> Tuple[List[ContainerInstance], List[ContainerInstanceFailure]]: cluster = self._get_cluster(cluster_str) status = status.upper() @@ -1650,29 +1724,34 @@ class EC2ContainerServiceBackend(BaseBackend): return container_instance_objects, failures def update_container_instance_resources( - self, container_instance, task_resources, removing=False - ): + self, + container_instance: ContainerInstance, + task_resources: Dict[str, Any], + removing: bool = False, + ) -> None: resource_multiplier = 1 if removing: resource_multiplier = -1 for resource in container_instance.remaining_resources: if resource.get("name") == "CPU": resource["integerValue"] -= ( - task_resources.get("CPU") * resource_multiplier + task_resources.get("CPU") * resource_multiplier # type: ignore[operator] ) elif resource.get("name") == "MEMORY": resource["integerValue"] -= ( - task_resources.get("MEMORY") * resource_multiplier + task_resources.get("MEMORY") * resource_multiplier # type: ignore[operator] ) elif resource.get("name") == "PORTS": - for port in task_resources.get("PORTS"): + for port in task_resources.get("PORTS"): # type: ignore[union-attr] if removing: resource["stringSetValue"].remove(str(port)) else: resource["stringSetValue"].append(str(port)) container_instance.running_tasks_count += resource_multiplier * 1 - def deregister_container_instance(self, cluster_str, container_instance_str, force): + def deregister_container_instance( + self, cluster_str: str, container_instance_str: str, force: bool + ) -> ContainerInstance: cluster = self._get_cluster(cluster_str) container_instance_id = container_instance_str.split("/")[-1] @@ -1695,12 +1774,14 @@ class EC2ContainerServiceBackend(BaseBackend): self._respond_to_cluster_state_update(cluster_str) return container_instance - def _respond_to_cluster_state_update(self, cluster_str): + def _respond_to_cluster_state_update(self, cluster_str: str) -> None: self._get_cluster(cluster_str) pass - def put_attributes(self, cluster_name, attributes=None): + def put_attributes( + self, cluster_name: str, attributes: Optional[List[Dict[str, Any]]] = None + ) -> None: cluster = self._get_cluster(cluster_name) if attributes is None: @@ -1716,15 +1797,20 @@ class EC2ContainerServiceBackend(BaseBackend): ) def _put_attribute( - self, cluster_name, name, value=None, target_id=None, target_type=None - ): + self, + cluster_name: str, + name: str, + value: Optional[str] = None, + target_id: Optional[str] = None, + target_type: Optional[str] = None, + ) -> None: if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): instance.attributes[name] = value elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit("/", 1)[-1] + arn = target_id.rsplit("/", 1)[-1] # type: ignore[union-attr] self.container_instances[cluster_name][arn].attributes[name] = value except KeyError: raise JsonRESTError( @@ -1738,7 +1824,7 @@ class EC2ContainerServiceBackend(BaseBackend): "TargetNotFoundException", f"Could not find {target_id}" ) - self.container_instances[cluster_name][target_id].attributes[ + self.container_instances[cluster_name][target_id].attributes[ # type: ignore[index] name ] = value except KeyError: @@ -1748,11 +1834,11 @@ class EC2ContainerServiceBackend(BaseBackend): def list_attributes( self, - target_type, - cluster_name=None, - attr_name=None, - attr_value=None, - ): + target_type: str, + cluster_name: Optional[str] = None, + attr_name: Optional[str] = None, + attr_value: Optional[str] = None, + ) -> Any: """ Pagination is not yet implemented """ @@ -1784,9 +1870,11 @@ class EC2ContainerServiceBackend(BaseBackend): ) ) - return filter(lambda x: all(f(x) for f in filters), all_attrs) + return filter(lambda x: all(f(x) for f in filters), all_attrs) # type: ignore - def delete_attributes(self, cluster_name, attributes=None): + def delete_attributes( + self, cluster_name: str, attributes: Optional[List[Dict[str, Any]]] = None + ) -> None: cluster = self._get_cluster(cluster_name) if attributes is None: @@ -1804,8 +1892,13 @@ class EC2ContainerServiceBackend(BaseBackend): ) def _delete_attribute( - self, cluster_name, name, value=None, target_id=None, target_type=None - ): + self, + cluster_name: str, + name: str, + value: Optional[str] = None, + target_id: Optional[str] = None, + target_type: Optional[str] = None, + ) -> None: if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): if name in instance.attributes and instance.attributes[name] == value: @@ -1813,7 +1906,7 @@ class EC2ContainerServiceBackend(BaseBackend): elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit("/", 1)[-1] + arn = target_id.rsplit("/", 1)[-1] # type: ignore[union-attr] instance = self.container_instances[cluster_name][arn] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] @@ -1829,7 +1922,7 @@ class EC2ContainerServiceBackend(BaseBackend): "TargetNotFoundException", f"Could not find {target_id}" ) - instance = self.container_instances[cluster_name][target_id] + instance = self.container_instances[cluster_name][target_id] # type: ignore[index] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] except KeyError: @@ -1837,7 +1930,9 @@ class EC2ContainerServiceBackend(BaseBackend): "TargetNotFoundException", f"Could not find {target_id}" ) - def list_task_definition_families(self, family_prefix=None): + def list_task_definition_families( + self, family_prefix: Optional[str] = None + ) -> Iterator[str]: """ The Status and pagination parameters are not yet implemented """ @@ -1848,7 +1943,7 @@ class EC2ContainerServiceBackend(BaseBackend): yield task_fam @staticmethod - def _parse_resource_arn(resource_arn): + def _parse_resource_arn(resource_arn: str) -> Dict[str, str]: # type: ignore[misc] regexes = [ "^arn:aws:ecs:(?P[^:]+):(?P[^:]+):(?P[^:]+)/(?P[^:]+)/(?P[^:]+)/ecs-svc/(?P.*)$", "^arn:aws:ecs:(?P[^:]+):(?P[^:]+):(?P[^:]+)/(?P[^:]+)/(?P.*)$", @@ -1860,7 +1955,7 @@ class EC2ContainerServiceBackend(BaseBackend): return match.groupdict() raise JsonRESTError("InvalidParameterException", "The ARN provided is invalid.") - def _get_resource(self, resource_arn, parsed_arn): + def _get_resource(self, resource_arn: str, parsed_arn: Dict[str, str]) -> Any: if parsed_arn["service"] == "cluster": return self._get_cluster(parsed_arn["id"]) if parsed_arn["service"] == "service": @@ -1888,23 +1983,25 @@ class EC2ContainerServiceBackend(BaseBackend): return self._get_provider(parsed_arn["id"]) raise NotImplementedError() - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]: """Currently implemented only for task definitions and services""" parsed_arn = self._parse_resource_arn(resource_arn) resource = self._get_resource(resource_arn, parsed_arn) return resource.tags - def _get_last_task_definition_revision_id(self, family): - definitions = self.task_definitions.get(family, {}) + def _get_last_task_definition_revision_id(self, family: str) -> int: # type: ignore[return] + definitions = self.task_definitions.get(family) if definitions: return max(definitions.keys()) - def tag_resource(self, resource_arn, tags) -> None: + def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: parsed_arn = self._parse_resource_arn(resource_arn) resource = self._get_resource(resource_arn, parsed_arn) resource.tags = self._merge_tags(resource.tags, tags) - def _merge_tags(self, existing_tags, new_tags): + def _merge_tags( + self, existing_tags: List[Dict[str, str]], new_tags: List[Dict[str, str]] + ) -> List[Dict[str, str]]: merged_tags = new_tags new_keys = self._get_keys(new_tags) for existing_tag in existing_tags: @@ -1913,30 +2010,30 @@ class EC2ContainerServiceBackend(BaseBackend): return merged_tags @staticmethod - def _get_keys(tags): + def _get_keys(tags: List[Dict[str, str]]) -> List[str]: return [tag["key"] for tag in tags] - def untag_resource(self, resource_arn, tag_keys) -> None: + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: parsed_arn = self._parse_resource_arn(resource_arn) resource = self._get_resource(resource_arn, parsed_arn) resource.tags = [tag for tag in resource.tags if tag["key"] not in tag_keys] def create_task_set( self, - service, - cluster_str, - task_definition, - external_id=None, - network_configuration=None, - load_balancers=None, - service_registries=None, - launch_type=None, - capacity_provider_strategy=None, - platform_version=None, - scale=None, - client_token=None, - tags=None, - ): + service: str, + cluster_str: str, + task_definition: str, + external_id: Optional[str] = None, + network_configuration: Optional[Dict[str, Any]] = None, + load_balancers: Optional[List[Dict[str, Any]]] = None, + service_registries: Optional[List[Dict[str, Any]]] = None, + launch_type: Optional[str] = None, + capacity_provider_strategy: Optional[List[Dict[str, Any]]] = None, + platform_version: Optional[str] = None, + scale: Optional[Dict[str, Any]] = None, + client_token: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + ) -> TaskSet: launch_type = launch_type if launch_type is not None else "EC2" if launch_type not in ["EC2", "FARGATE"]: raise EcsClientException("launch type should be one of [EC2,FARGATE]") @@ -1975,9 +2072,10 @@ class EC2ContainerServiceBackend(BaseBackend): return task_set - def describe_task_sets(self, cluster_str, service, task_sets=None, include=None): + def describe_task_sets( + self, cluster_str: str, service: str, task_sets: Optional[List[str]] = None + ) -> List[TaskSet]: task_sets = task_sets or [] - include = include or [] cluster_obj = self._get_cluster(cluster_str) @@ -2002,7 +2100,7 @@ class EC2ContainerServiceBackend(BaseBackend): return task_set_results - def delete_task_set(self, cluster, service, task_set): + def delete_task_set(self, cluster: str, service: str, task_set: str) -> TaskSet: """ The Force-parameter is not yet implemented """ @@ -2028,7 +2126,9 @@ class EC2ContainerServiceBackend(BaseBackend): return deleted_task_set - def update_task_set(self, cluster, service, task_set, scale): + def update_task_set( + self, cluster: str, service: str, task_set: str, scale: Dict[str, Any] + ) -> TaskSet: cluster_name = cluster.split("/")[-1] service_name = service.split("/")[-1] task_set_obj = self.describe_task_sets( @@ -2037,7 +2137,9 @@ class EC2ContainerServiceBackend(BaseBackend): task_set_obj.scale = scale return task_set_obj - def update_service_primary_task_set(self, cluster, service, primary_task_set): + def update_service_primary_task_set( + self, cluster: str, service: str, primary_task_set: str + ) -> TaskSet: """Updates task sets be PRIMARY or ACTIVE for given cluster:service task sets""" cluster_name = cluster.split("/")[-1] service_name = service.split("/")[-1] @@ -2057,7 +2159,9 @@ class EC2ContainerServiceBackend(BaseBackend): task_set.status = "ACTIVE" return task_set_obj - def list_account_settings(self, name=None, value=None): + def list_account_settings( + self, name: Optional[str] = None, value: Optional[str] = None + ) -> List[AccountSetting]: expected_names = [ "serviceLongArnFormat", "taskLongArnFormat", @@ -2076,15 +2180,15 @@ class EC2ContainerServiceBackend(BaseBackend): if (not name or s.name == name) and (not value or s.value == value) ] - def put_account_setting(self, name, value): + def put_account_setting(self, name: str, value: str) -> AccountSetting: account_setting = AccountSetting(name, value) self.account_settings[name] = account_setting return account_setting - def delete_account_setting(self, name): + def delete_account_setting(self, name: str) -> None: self.account_settings.pop(name, None) - def enable_long_arn_for_name(self, name): + def enable_long_arn_for_name(self, name: str) -> bool: account = self.account_settings.get(name, None) if account and account.value == "disabled": return False diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index d8d64d6fb..27d42098d 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -1,35 +1,26 @@ import json +from typing import Any, Dict from moto.core.responses import BaseResponse from .models import ecs_backends, EC2ContainerServiceBackend class EC2ContainerServiceResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="ecs") @property def ecs_backend(self) -> EC2ContainerServiceBackend: return ecs_backends[self.current_account][self.region] - @property - def request_params(self): - try: - return json.loads(self.body) - except ValueError: - return {} - - def _get_param(self, param_name, if_none=None): - return self.request_params.get(param_name, if_none) - - def create_capacity_provider(self): + def create_capacity_provider(self) -> str: name = self._get_param("name") asg_provider = self._get_param("autoScalingGroupProvider") tags = self._get_param("tags") provider = self.ecs_backend.create_capacity_provider(name, asg_provider, tags) return json.dumps({"capacityProvider": provider.response_object}) - def create_cluster(self): + def create_cluster(self) -> str: cluster_name = self._get_param("clusterName") tags = self._get_param("tags") settings = self._get_param("settings") @@ -50,23 +41,18 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"cluster": cluster.response_object}) - def list_clusters(self): + def list_clusters(self) -> str: cluster_arns = self.ecs_backend.list_clusters() - return json.dumps( - { - "clusterArns": cluster_arns - # 'nextToken': str(uuid.uuid4()) - } - ) + return json.dumps({"clusterArns": cluster_arns}) - def update_cluster(self): + def update_cluster(self) -> str: cluster_name = self._get_param("cluster") settings = self._get_param("settings") configuration = self._get_param("configuration") cluster = self.ecs_backend.update_cluster(cluster_name, settings, configuration) return json.dumps({"cluster": cluster.response_object}) - def put_cluster_capacity_providers(self): + def put_cluster_capacity_providers(self) -> str: cluster_name = self._get_param("cluster") capacity_providers = self._get_param("capacityProviders") default_capacity_provider_strategy = self._get_param( @@ -77,18 +63,18 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"cluster": cluster.response_object}) - def delete_capacity_provider(self): + def delete_capacity_provider(self) -> str: name = self._get_param("capacityProvider") provider = self.ecs_backend.delete_capacity_provider(name) return json.dumps({"capacityProvider": provider.response_object}) - def update_capacity_provider(self): + def update_capacity_provider(self) -> str: name = self._get_param("name") asg_provider = self._get_param("autoScalingGroupProvider") provider = self.ecs_backend.update_capacity_provider(name, asg_provider) return json.dumps({"capacityProvider": provider.response_object}) - def describe_capacity_providers(self): + def describe_capacity_providers(self) -> str: names = self._get_param("capacityProviders") providers, failures = self.ecs_backend.describe_capacity_providers(names) return json.dumps( @@ -98,7 +84,7 @@ class EC2ContainerServiceResponse(BaseResponse): } ) - def describe_clusters(self): + def describe_clusters(self) -> str: names = self._get_param("clusters") include = self._get_param("include") clusters, failures = self.ecs_backend.describe_clusters(names, include) @@ -109,12 +95,12 @@ class EC2ContainerServiceResponse(BaseResponse): } ) - def delete_cluster(self): + def delete_cluster(self) -> str: cluster_str = self._get_param("cluster") cluster = self.ecs_backend.delete_cluster(cluster_str) return json.dumps({"cluster": cluster.response_object}) - def register_task_definition(self): + def register_task_definition(self) -> str: family = self._get_param("family") container_definitions = self._get_param("containerDefinitions") volumes = self._get_param("volumes") @@ -154,7 +140,7 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"taskDefinition": task_definition.response_object}) - def list_task_definitions(self): + def list_task_definitions(self) -> str: family_prefix = self._get_param("familyPrefix") task_definition_arns = self.ecs_backend.list_task_definitions(family_prefix) return json.dumps( @@ -164,22 +150,22 @@ class EC2ContainerServiceResponse(BaseResponse): } ) - def describe_task_definition(self): + def describe_task_definition(self) -> str: task_definition_str = self._get_param("taskDefinition") data = self.ecs_backend.describe_task_definition(task_definition_str) - resp = {"taskDefinition": data.response_object, "failures": []} + resp: Dict[str, Any] = {"taskDefinition": data.response_object, "failures": []} if "TAGS" in self._get_param("include", []): resp["tags"] = self.ecs_backend.list_tags_for_resource(data.arn) return json.dumps(resp) - def deregister_task_definition(self): + def deregister_task_definition(self) -> str: task_definition_str = self._get_param("taskDefinition") task_definition = self.ecs_backend.deregister_task_definition( task_definition_str ) return json.dumps({"taskDefinition": task_definition.response_object}) - def run_task(self): + def run_task(self) -> str: cluster_str = self._get_param("cluster", "default") overrides = self._get_param("overrides") task_definition_str = self._get_param("taskDefinition") @@ -202,7 +188,7 @@ class EC2ContainerServiceResponse(BaseResponse): {"tasks": [task.response_object for task in tasks], "failures": []} ) - def describe_tasks(self): + def describe_tasks(self) -> str: cluster = self._get_param("cluster", "default") tasks = self._get_param("tasks") include = self._get_param("include") @@ -211,7 +197,7 @@ class EC2ContainerServiceResponse(BaseResponse): {"tasks": [task.response_object for task in data], "failures": []} ) - def start_task(self): + def start_task(self) -> str: cluster_str = self._get_param("cluster", "default") overrides = self._get_param("overrides") task_definition_str = self._get_param("taskDefinition") @@ -230,7 +216,7 @@ class EC2ContainerServiceResponse(BaseResponse): {"tasks": [task.response_object for task in tasks], "failures": []} ) - def list_tasks(self): + def list_tasks(self) -> str: cluster_str = self._get_param("cluster", "default") container_instance = self._get_param("containerInstance") family = self._get_param("family") @@ -247,14 +233,14 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"taskArns": task_arns}) - def stop_task(self): + def stop_task(self) -> str: cluster_str = self._get_param("cluster", "default") task = self._get_param("task") reason = self._get_param("reason") task = self.ecs_backend.stop_task(cluster_str, task, reason) return json.dumps({"task": task.response_object}) - def create_service(self): + def create_service(self) -> str: cluster_str = self._get_param("cluster", "default") service_name = self._get_param("serviceName") task_definition_str = self._get_param("taskDefinition") @@ -281,22 +267,16 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"service": service.response_object}) - def list_services(self): + def list_services(self) -> str: cluster_str = self._get_param("cluster", "default") scheduling_strategy = self._get_param("schedulingStrategy") launch_type = self._get_param("launchType") service_arns = self.ecs_backend.list_services( cluster_str, scheduling_strategy, launch_type=launch_type ) - return json.dumps( - { - "serviceArns": service_arns - # , - # 'nextToken': str(uuid.uuid4()) - } - ) + return json.dumps({"serviceArns": service_arns}) - def describe_services(self): + def describe_services(self) -> str: cluster_str = self._get_param("cluster", "default") service_names = self._get_param("services") services, failures = self.ecs_backend.describe_services( @@ -313,7 +293,7 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps(resp) - def update_service(self): + def update_service(self) -> str: cluster_str = self._get_param("cluster", "default") service_name = self._get_param("service") task_definition = self._get_param("taskDefinition") @@ -323,14 +303,14 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"service": service.response_object}) - def delete_service(self): + def delete_service(self) -> str: service_name = self._get_param("service") cluster_name = self._get_param("cluster", "default") force = self._get_param("force", False) service = self.ecs_backend.delete_service(cluster_name, service_name, force) return json.dumps({"service": service.response_object}) - def register_container_instance(self): + def register_container_instance(self) -> str: cluster_str = self._get_param("cluster", "default") instance_identity_document_str = self._get_param("instanceIdentityDocument") instance_identity_document = json.loads(instance_identity_document_str) @@ -340,7 +320,7 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"containerInstance": container_instance.response_object}) - def deregister_container_instance(self): + def deregister_container_instance(self) -> str: cluster_str = self._get_param("cluster", "default") container_instance_str = self._get_param("containerInstance") force = self._get_param("force") @@ -349,12 +329,12 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"containerInstance": container_instance.response_object}) - def list_container_instances(self): + def list_container_instances(self) -> str: cluster_str = self._get_param("cluster", "default") container_instance_arns = self.ecs_backend.list_container_instances(cluster_str) return json.dumps({"containerInstanceArns": container_instance_arns}) - def describe_container_instances(self): + def describe_container_instances(self) -> str: cluster_str = self._get_param("cluster", "default") list_container_instance_arns = self._get_param("containerInstances") container_instances, failures = self.ecs_backend.describe_container_instances( @@ -369,7 +349,7 @@ class EC2ContainerServiceResponse(BaseResponse): } ) - def update_container_instances_state(self): + def update_container_instances_state(self) -> str: cluster_str = self._get_param("cluster", "default") list_container_instance_arns = self._get_param("containerInstances") status_str = self._get_param("status") @@ -388,7 +368,7 @@ class EC2ContainerServiceResponse(BaseResponse): } ) - def put_attributes(self): + def put_attributes(self) -> str: cluster_name = self._get_param("cluster") attributes = self._get_param("attributes") @@ -396,7 +376,7 @@ class EC2ContainerServiceResponse(BaseResponse): return json.dumps({"attributes": attributes}) - def list_attributes(self): + def list_attributes(self) -> str: cluster_name = self._get_param("cluster") attr_name = self._get_param("attributeName") attr_value = self._get_param("attributeValue") @@ -416,7 +396,7 @@ class EC2ContainerServiceResponse(BaseResponse): return json.dumps({"attributes": formatted_results}) - def delete_attributes(self): + def delete_attributes(self) -> str: cluster_name = self._get_param("cluster", "default") attributes = self._get_param("attributes") @@ -424,7 +404,7 @@ class EC2ContainerServiceResponse(BaseResponse): return json.dumps({"attributes": attributes}) - def discover_poll_endpoint(self): + def discover_poll_endpoint(self) -> str: # Here are the arguments, this api is used by the ecs client so obviously no decent # documentation. Hence I've responded with valid but useless data # cluster_name = self._get_param('cluster') @@ -433,30 +413,30 @@ class EC2ContainerServiceResponse(BaseResponse): {"endpoint": "http://localhost", "telemetryEndpoint": "http://localhost"} ) - def list_task_definition_families(self): + def list_task_definition_families(self) -> str: family_prefix = self._get_param("familyPrefix") results = self.ecs_backend.list_task_definition_families(family_prefix) return json.dumps({"families": list(results)}) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: resource_arn = self._get_param("resourceArn") tags = self.ecs_backend.list_tags_for_resource(resource_arn) return json.dumps({"tags": tags}) - def tag_resource(self): + def tag_resource(self) -> str: resource_arn = self._get_param("resourceArn") tags = self._get_param("tags") self.ecs_backend.tag_resource(resource_arn, tags) return json.dumps({}) - def untag_resource(self): + def untag_resource(self) -> str: resource_arn = self._get_param("resourceArn") tag_keys = self._get_param("tagKeys") self.ecs_backend.untag_resource(resource_arn, tag_keys) return json.dumps({}) - def create_task_set(self): + def create_task_set(self) -> str: service_str = self._get_param("service") cluster_str = self._get_param("cluster", "default") task_definition = self._get_param("taskDefinition") @@ -487,13 +467,13 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"taskSet": task_set.response_object}) - def describe_task_sets(self): + def describe_task_sets(self) -> str: cluster_str = self._get_param("cluster", "default") service_str = self._get_param("service") task_sets = self._get_param("taskSets") include = self._get_param("include", []) task_set_objs = self.ecs_backend.describe_task_sets( - cluster_str, service_str, task_sets, include + cluster_str, service_str, task_sets ) response_objs = [t.response_object for t in task_set_objs] @@ -502,14 +482,14 @@ class EC2ContainerServiceResponse(BaseResponse): del ro["tags"] return json.dumps({"taskSets": response_objs}) - def delete_task_set(self): + def delete_task_set(self) -> str: cluster_str = self._get_param("cluster") service_str = self._get_param("service") task_set = self._get_param("taskSet") task_set = self.ecs_backend.delete_task_set(cluster_str, service_str, task_set) return json.dumps({"taskSet": task_set.response_object}) - def update_task_set(self): + def update_task_set(self) -> str: cluster_str = self._get_param("cluster", "default") service_str = self._get_param("service") task_set = self._get_param("taskSet") @@ -520,7 +500,7 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"taskSet": task_set.response_object}) - def update_service_primary_task_set(self): + def update_service_primary_task_set(self) -> str: cluster_str = self._get_param("cluster", "default") service_str = self._get_param("service") primary_task_set = self._get_param("primaryTaskSet") @@ -530,19 +510,19 @@ class EC2ContainerServiceResponse(BaseResponse): ) return json.dumps({"taskSet": task_set.response_object}) - def put_account_setting(self): + def put_account_setting(self) -> str: name = self._get_param("name") value = self._get_param("value") account_setting = self.ecs_backend.put_account_setting(name, value) return json.dumps({"setting": account_setting.response_object}) - def list_account_settings(self): + def list_account_settings(self) -> str: name = self._get_param("name") value = self._get_param("value") account_settings = self.ecs_backend.list_account_settings(name, value) return json.dumps({"settings": [s.response_object for s in account_settings]}) - def delete_account_setting(self): + def delete_account_setting(self) -> str: name = self._get_param("name") self.ecs_backend.delete_account_setting(name) return "{}" diff --git a/moto/settings.py b/moto/settings.py index 652b74163..9467f63c5 100644 --- a/moto/settings.py +++ b/moto/settings.py @@ -60,7 +60,7 @@ def get_s3_default_key_buffer_size(): ) -def ecs_new_arn_format(): +def ecs_new_arn_format() -> bool: # True by default - only the value 'false' will return false return os.environ.get("MOTO_ECS_NEW_ARN", "true").lower() != "false" diff --git a/setup.cfg b/setup.cfg index dc38dd9e4..741d0ddf1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -229,7 +229,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/es,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/ecs,moto/es,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract