From 9b969f7e3fc2a5f72b1ac06458170ab3aed5e930 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 26 Apr 2023 10:56:53 +0000 Subject: [PATCH] Techdebt: MyPy SWF (#6256) --- moto/swf/exceptions.py | 26 ++-- moto/swf/models/__init__.py | 178 ++++++++++++++++--------- moto/swf/models/activity_task.py | 44 +++--- moto/swf/models/activity_type.py | 5 +- moto/swf/models/decision_task.py | 39 +++--- moto/swf/models/domain.py | 65 ++++++--- moto/swf/models/generic_type.py | 22 +-- moto/swf/models/history_event.py | 13 +- moto/swf/models/timeout.py | 5 +- moto/swf/models/timer.py | 9 +- moto/swf/models/workflow_execution.py | 185 +++++++++++++++----------- moto/swf/models/workflow_type.py | 5 +- moto/swf/responses.py | 103 +++++++------- moto/swf/utils.py | 2 +- setup.cfg | 2 +- 15 files changed, 421 insertions(+), 282 deletions(-) diff --git a/moto/swf/exceptions.py b/moto/swf/exceptions.py index cd5e3b476..4e3c9a503 100644 --- a/moto/swf/exceptions.py +++ b/moto/swf/exceptions.py @@ -1,12 +1,16 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING from moto.core.exceptions import JsonRESTError +if TYPE_CHECKING: + from .models.generic_type import GenericType + class SWFClientError(JsonRESTError): code = 400 class SWFUnknownResourceFault(SWFClientError): - def __init__(self, resource_type, resource_name=None): + def __init__(self, resource_type: str, resource_name: Optional[str] = None): if resource_name: message = f"Unknown {resource_type}: {resource_name}" else: @@ -15,21 +19,21 @@ class SWFUnknownResourceFault(SWFClientError): class SWFDomainAlreadyExistsFault(SWFClientError): - def __init__(self, domain_name): + def __init__(self, domain_name: str): super().__init__( "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name ) class SWFDomainDeprecatedFault(SWFClientError): - def __init__(self, domain_name): + def __init__(self, domain_name: str): super().__init__( "com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name ) class SWFSerializationException(SWFClientError): - def __init__(self, value): + def __init__(self, value: Any): message = "class java.lang.Foo can not be converted to an String " message += f" (not a real SWF exception ; happened on: {value})" __type = "com.amazonaws.swf.base.model#SerializationException" @@ -37,7 +41,7 @@ class SWFSerializationException(SWFClientError): class SWFTypeAlreadyExistsFault(SWFClientError): - def __init__(self, _type): + def __init__(self, _type: "GenericType"): super().__init__( "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]", @@ -45,7 +49,7 @@ class SWFTypeAlreadyExistsFault(SWFClientError): class SWFTypeDeprecatedFault(SWFClientError): - def __init__(self, _type): + def __init__(self, _type: "GenericType"): super().__init__( "com.amazonaws.swf.base.model#TypeDeprecatedFault", f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]", @@ -53,7 +57,7 @@ class SWFTypeDeprecatedFault(SWFClientError): class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", "Already Started", @@ -61,7 +65,7 @@ class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): class SWFDefaultUndefinedFault(SWFClientError): - def __init__(self, key): + def __init__(self, key: str): # TODO: move that into moto.core.utils maybe? words = key.split("_") key_camel_case = words.pop(0) @@ -73,12 +77,12 @@ class SWFDefaultUndefinedFault(SWFClientError): class SWFValidationException(SWFClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("com.amazon.coral.validate#ValidationException", message) class SWFDecisionValidationException(SWFClientError): - def __init__(self, problems): + def __init__(self, problems: List[Dict[str, Any]]): # messages messages = [] for pb in problems: @@ -106,5 +110,5 @@ class SWFDecisionValidationException(SWFClientError): class SWFWorkflowExecutionClosedError(Exception): - def __str__(self): + def __str__(self) -> str: return repr("Cannot change this object because the WorkflowExecution is closed") diff --git a/moto/swf/models/__init__.py b/moto/swf/models/__init__.py index 628904a80..4eb499e51 100644 --- a/moto/swf/models/__init__.py +++ b/moto/swf/models/__init__.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Optional from moto.core import BaseBackend, BackendDict from ..exceptions import ( @@ -12,7 +13,7 @@ from .activity_task import ActivityTask # noqa from .activity_type import ActivityType # noqa from .decision_task import DecisionTask # noqa from .domain import Domain # noqa -from .generic_type import GenericType # noqa +from .generic_type import GenericType, TGenericType # noqa from .history_event import HistoryEvent # noqa from .timeout import Timeout # noqa from .timer import Timer # noqa @@ -24,33 +25,39 @@ KNOWN_SWF_TYPES = {"activity": ActivityType, "workflow": WorkflowType} class SWFBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.domains = [] + self.domains: List[Domain] = [] - def _get_domain(self, name, ignore_empty=False): + def _get_domain(self, name: str, ignore_empty: bool = False) -> Domain: matching = [domain for domain in self.domains if domain.name == name] if not matching and not ignore_empty: raise SWFUnknownResourceFault("domain", name) if matching: return matching[0] - return None + return None # type: ignore - def _process_timeouts(self): + def _process_timeouts(self) -> None: for domain in self.domains: for wfe in domain.workflow_executions: wfe._process_timeouts() - def list_domains(self, status, reverse_order=None): + def list_domains( + self, status: str, reverse_order: Optional[bool] = None + ) -> List[Domain]: domains = [domain for domain in self.domains if domain.status == status] domains = sorted(domains, key=lambda domain: domain.name) if reverse_order: - domains = reversed(domains) + domains = reversed(domains) # type: ignore[assignment] return domains def list_open_workflow_executions( - self, domain_name, maximum_page_size, tag_filter, reverse_order - ): + self, + domain_name: str, + maximum_page_size: int, + tag_filter: Dict[str, str], + reverse_order: bool, + ) -> List[WorkflowExecution]: self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": @@ -64,17 +71,17 @@ class SWFBackend(BaseBackend): if tag_filter["tag"] not in open_wfe.tag_list: open_wfes.remove(open_wfe) if reverse_order: - open_wfes = reversed(open_wfes) + open_wfes = reversed(open_wfes) # type: ignore[assignment] return open_wfes[0:maximum_page_size] def list_closed_workflow_executions( self, - domain_name, - tag_filter, - close_status_filter, - maximum_page_size, - reverse_order, - ): + domain_name: str, + tag_filter: Dict[str, str], + close_status_filter: Dict[str, str], + maximum_page_size: int, + reverse_order: bool, + ) -> List[WorkflowExecution]: self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": @@ -90,15 +97,18 @@ class SWFBackend(BaseBackend): closed_wfes.remove(closed_wfe) if close_status_filter: for closed_wfe in closed_wfes: - if close_status_filter != closed_wfe.close_status: + if close_status_filter != closed_wfe.close_status: # type: ignore closed_wfes.remove(closed_wfe) if reverse_order: - closed_wfes = reversed(closed_wfes) + closed_wfes = reversed(closed_wfes) # type: ignore[assignment] return closed_wfes[0:maximum_page_size] def register_domain( - self, name, workflow_execution_retention_period_in_days, description=None - ): + self, + name: str, + workflow_execution_retention_period_in_days: int, + description: Optional[str] = None, + ) -> None: if self._get_domain(name, ignore_empty=True): raise SWFDomainAlreadyExistsFault(name) domain = Domain( @@ -110,68 +120,82 @@ class SWFBackend(BaseBackend): ) self.domains.append(domain) - def deprecate_domain(self, name): + def deprecate_domain(self, name: str) -> None: domain = self._get_domain(name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(name) domain.status = "DEPRECATED" - def undeprecate_domain(self, name): + def undeprecate_domain(self, name: str) -> None: domain = self._get_domain(name) if domain.status == "REGISTERED": raise SWFDomainAlreadyExistsFault(name) domain.status = "REGISTERED" - def describe_domain(self, name): + def describe_domain(self, name: str) -> Optional[Domain]: return self._get_domain(name) - def list_types(self, kind, domain_name, status, reverse_order=None): + def list_types( + self, + kind: str, + domain_name: str, + status: str, + reverse_order: Optional[bool] = None, + ) -> List[GenericType]: domain = self._get_domain(domain_name) - _types = domain.find_types(kind, status) + _types: List[GenericType] = domain.find_types(kind, status) _types = sorted(_types, key=lambda domain: domain.name) if reverse_order: - _types = reversed(_types) + _types = reversed(_types) # type: ignore return _types - def register_type(self, kind, domain_name, name, version, **kwargs): + def register_type( + self, kind: str, domain_name: str, name: str, version: str, **kwargs: Any + ) -> None: domain = self._get_domain(domain_name) - _type = domain.get_type(kind, name, version, ignore_empty=True) + _type: GenericType = domain.get_type(kind, name, version, ignore_empty=True) if _type: raise SWFTypeAlreadyExistsFault(_type) _class = KNOWN_SWF_TYPES[kind] _type = _class(name, version, **kwargs) domain.add_type(_type) - def deprecate_type(self, kind, domain_name, name, version): + def deprecate_type( + self, kind: str, domain_name: str, name: str, version: str + ) -> None: domain = self._get_domain(domain_name) - _type = domain.get_type(kind, name, version) + _type: GenericType = domain.get_type(kind, name, version) if _type.status == "DEPRECATED": raise SWFTypeDeprecatedFault(_type) _type.status = "DEPRECATED" - def undeprecate_type(self, kind, domain_name, name, version): + def undeprecate_type( + self, kind: str, domain_name: str, name: str, version: str + ) -> None: domain = self._get_domain(domain_name) - _type = domain.get_type(kind, name, version) + _type: GenericType = domain.get_type(kind, name, version) if _type.status == "REGISTERED": raise SWFTypeAlreadyExistsFault(_type) _type.status = "REGISTERED" - def describe_type(self, kind, domain_name, name, version): + def describe_type( + self, kind: str, domain_name: str, name: str, version: str + ) -> GenericType: domain = self._get_domain(domain_name) return domain.get_type(kind, name, version) def start_workflow_execution( self, - domain_name, - workflow_id, - workflow_name, - workflow_version, - tag_list=None, - workflow_input=None, - **kwargs, - ): + domain_name: str, + workflow_id: str, + workflow_name: str, + workflow_version: str, + tag_list: Optional[Dict[str, str]] = None, + workflow_input: Optional[str] = None, + **kwargs: Any, + ) -> WorkflowExecution: domain = self._get_domain(domain_name) - wf_type = domain.get_type("workflow", workflow_name, workflow_version) + wf_type: WorkflowType = domain.get_type("workflow", workflow_name, workflow_version) # type: ignore if wf_type.status == "DEPRECATED": raise SWFTypeDeprecatedFault(wf_type) wfe = WorkflowExecution( @@ -187,13 +211,17 @@ class SWFBackend(BaseBackend): return wfe - def describe_workflow_execution(self, domain_name, run_id, workflow_id): + def describe_workflow_execution( + self, domain_name: str, run_id: str, workflow_id: str + ) -> Optional[WorkflowExecution]: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) return domain.get_workflow_execution(workflow_id, run_id=run_id) - def poll_for_decision_task(self, domain_name, task_list, identity=None): + def poll_for_decision_task( + self, domain_name: str, task_list: List[str], identity: Optional[str] = None + ) -> Optional[DecisionTask]: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -245,7 +273,9 @@ class SWFBackend(BaseBackend): sleep(1) return None - def count_pending_decision_tasks(self, domain_name, task_list): + def count_pending_decision_tasks( + self, domain_name: str, task_list: List[str] + ) -> int: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -256,8 +286,11 @@ class SWFBackend(BaseBackend): return count def respond_decision_task_completed( - self, task_token, decisions=None, execution_context=None - ): + self, + task_token: str, + decisions: Optional[List[Dict[str, Any]]] = None, + execution_context: Optional[str] = None, + ) -> None: # process timeouts on all objects self._process_timeouts() # let's find decision task @@ -308,7 +341,9 @@ class SWFBackend(BaseBackend): execution_context=execution_context, ) - def poll_for_activity_task(self, domain_name, task_list, identity=None): + def poll_for_activity_task( + self, domain_name: str, task_list: List[str], identity: Optional[str] = None + ) -> Optional[ActivityTask]: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -342,7 +377,9 @@ class SWFBackend(BaseBackend): sleep(1) return None - def count_pending_activity_tasks(self, domain_name, task_list): + def count_pending_activity_tasks( + self, domain_name: str, task_list: List[str] + ) -> int: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -353,7 +390,7 @@ class SWFBackend(BaseBackend): count += len(pending) return count - def _find_activity_task_from_token(self, task_token): + def _find_activity_task_from_token(self, task_token: str) -> ActivityTask: activity_task = None for domain in self.domains: for wfe in domain.workflow_executions: @@ -389,14 +426,18 @@ class SWFBackend(BaseBackend): # everything's good return activity_task - def respond_activity_task_completed(self, task_token, result=None): + def respond_activity_task_completed( + self, task_token: str, result: Any = None + ) -> None: # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) wfe = activity_task.workflow_execution wfe.complete_activity_task(activity_task.task_token, result=result) - def respond_activity_task_failed(self, task_token, reason=None, details=None): + def respond_activity_task_failed( + self, task_token: str, reason: Optional[str] = None, details: Any = None + ) -> None: # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) @@ -405,22 +446,24 @@ class SWFBackend(BaseBackend): def terminate_workflow_execution( self, - domain_name, - workflow_id, - child_policy=None, - details=None, - reason=None, - run_id=None, - ): + domain_name: str, + workflow_id: str, + child_policy: Any = None, + details: Any = None, + reason: Optional[str] = None, + run_id: Optional[str] = None, + ) -> None: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( workflow_id, run_id=run_id, raise_if_closed=True ) - wfe.terminate(child_policy=child_policy, details=details, reason=reason) + wfe.terminate(child_policy=child_policy, details=details, reason=reason) # type: ignore[union-attr] - def record_activity_task_heartbeat(self, task_token, details=None): + def record_activity_task_heartbeat( + self, task_token: str, details: Any = None + ) -> None: # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) @@ -429,15 +472,20 @@ class SWFBackend(BaseBackend): activity_task.details = details def signal_workflow_execution( - self, domain_name, signal_name, workflow_id, workflow_input=None, run_id=None - ): + self, + domain_name: str, + signal_name: str, + workflow_id: str, + workflow_input: Any = None, + run_id: Optional[str] = None, + ) -> None: # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( workflow_id, run_id=run_id, raise_if_closed=True ) - wfe.signal(signal_name, workflow_input) + wfe.signal(signal_name, workflow_input) # type: ignore[union-attr] swf_backends = BackendDict(SWFBackend, "swf") diff --git a/moto/swf/models/activity_task.py b/moto/swf/models/activity_task.py index 2119188fa..37b42f29c 100644 --- a/moto/swf/models/activity_task.py +++ b/moto/swf/models/activity_task.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any, Dict, Optional, TYPE_CHECKING from moto.core import BaseModel from moto.core.utils import unix_time @@ -7,16 +8,20 @@ from ..exceptions import SWFWorkflowExecutionClosedError from .timeout import Timeout +if TYPE_CHECKING: + from .activity_type import ActivityType + from .workflow_execution import WorkflowExecution + class ActivityTask(BaseModel): def __init__( self, - activity_id, - activity_type, - scheduled_event_id, - workflow_execution, - timeouts, - workflow_input=None, + activity_id: str, + activity_type: "ActivityType", + scheduled_event_id: int, + workflow_execution: "WorkflowExecution", + timeouts: Dict[str, Any], + workflow_input: Any = None, ): self.activity_id = activity_id self.activity_type = activity_type @@ -24,26 +29,26 @@ class ActivityTask(BaseModel): self.input = workflow_input self.last_heartbeat_timestamp = unix_time() self.scheduled_event_id = scheduled_event_id - self.started_event_id = None + self.started_event_id: Optional[int] = None self.state = "SCHEDULED" self.task_token = str(mock_random.uuid4()) self.timeouts = timeouts - self.timeout_type = None + self.timeout_type: Optional[str] = None self.workflow_execution = workflow_execution # this is *not* necessarily coherent with workflow execution history, # but that shouldn't be a problem for tests self.scheduled_at = datetime.utcnow() - def _check_workflow_execution_open(self): + def _check_workflow_execution_open(self) -> None: if not self.workflow_execution.open: raise SWFWorkflowExecutionClosedError() @property - def open(self): + def open(self) -> bool: return self.state in ["SCHEDULED", "STARTED"] - def to_full_dict(self): - hsh = { + def to_full_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = { "activityId": self.activity_id, "activityType": self.activity_type.to_short_dict(), "taskToken": self.task_token, @@ -54,22 +59,22 @@ class ActivityTask(BaseModel): hsh["input"] = self.input return hsh - def start(self, started_event_id): + def start(self, started_event_id: int) -> None: self.state = "STARTED" self.started_event_id = started_event_id - def complete(self): + def complete(self) -> None: self._check_workflow_execution_open() self.state = "COMPLETED" - def fail(self): + def fail(self) -> None: self._check_workflow_execution_open() self.state = "FAILED" - def reset_heartbeat_clock(self): + def reset_heartbeat_clock(self) -> None: self.last_heartbeat_timestamp = unix_time() - def first_timeout(self): + def first_timeout(self) -> Optional[Timeout]: if not self.open or not self.workflow_execution.open: return None @@ -82,13 +87,14 @@ class ActivityTask(BaseModel): _timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT") if _timeout.reached: return _timeout + return None - def process_timeouts(self): + def process_timeouts(self) -> None: _timeout = self.first_timeout() if _timeout: self.timeout(_timeout) - def timeout(self, _timeout): + def timeout(self, _timeout: Timeout) -> None: self._check_workflow_execution_open() self.state = "TIMED_OUT" self.timeout_type = _timeout.kind diff --git a/moto/swf/models/activity_type.py b/moto/swf/models/activity_type.py index 95a83ca7a..fcee00fb9 100644 --- a/moto/swf/models/activity_type.py +++ b/moto/swf/models/activity_type.py @@ -1,9 +1,10 @@ +from typing import List from .generic_type import GenericType class ActivityType(GenericType): @property - def _configuration_keys(self): + def _configuration_keys(self) -> List[str]: return [ "defaultTaskHeartbeatTimeout", "defaultTaskScheduleToCloseTimeout", @@ -12,5 +13,5 @@ class ActivityType(GenericType): ] @property - def kind(self): + def kind(self) -> str: return "activity" diff --git a/moto/swf/models/decision_task.py b/moto/swf/models/decision_task.py index ead089c5f..34f7e67c1 100644 --- a/moto/swf/models/decision_task.py +++ b/moto/swf/models/decision_task.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any, Dict, Optional, TYPE_CHECKING from moto.core import BaseModel from moto.core.utils import unix_time @@ -7,16 +8,21 @@ from ..exceptions import SWFWorkflowExecutionClosedError from .timeout import Timeout +if TYPE_CHECKING: + from .workflow_execution import WorkflowExecution + class DecisionTask(BaseModel): - def __init__(self, workflow_execution, scheduled_event_id): + def __init__( + self, workflow_execution: "WorkflowExecution", scheduled_event_id: int + ): self.workflow_execution = workflow_execution self.workflow_type = workflow_execution.workflow_type self.task_token = str(mock_random.uuid4()) self.scheduled_event_id = scheduled_event_id - self.previous_started_event_id = None - self.started_event_id = None - self.started_timestamp = None + self.previous_started_event_id: Optional[int] = None + self.started_event_id: Optional[int] = None + self.started_timestamp: Optional[float] = None self.start_to_close_timeout = ( self.workflow_execution.task_start_to_close_timeout ) @@ -24,19 +30,19 @@ class DecisionTask(BaseModel): # this is *not* necessarily coherent with workflow execution history, # but that shouldn't be a problem for tests self.scheduled_at = datetime.utcnow() - self.timeout_type = None + self.timeout_type: Optional[str] = None @property - def started(self): + def started(self) -> bool: return self.state == "STARTED" - def _check_workflow_execution_open(self): + def _check_workflow_execution_open(self) -> None: if not self.workflow_execution.open: raise SWFWorkflowExecutionClosedError() - def to_full_dict(self, reverse_order=False): + def to_full_dict(self, reverse_order: bool = False) -> Dict[str, Any]: events = self.workflow_execution.events(reverse_order=reverse_order) - hsh = { + hsh: Dict[str, Any] = { "events": [evt.to_dict() for evt in events], "taskToken": self.task_token, "workflowExecution": self.workflow_execution.to_short_dict(), @@ -48,31 +54,34 @@ class DecisionTask(BaseModel): hsh["startedEventId"] = self.started_event_id return hsh - def start(self, started_event_id, previous_started_event_id=None): + def start( + self, started_event_id: int, previous_started_event_id: Optional[int] = None + ) -> None: self.state = "STARTED" self.started_timestamp = unix_time() self.started_event_id = started_event_id self.previous_started_event_id = previous_started_event_id - def complete(self): + def complete(self) -> None: self._check_workflow_execution_open() self.state = "COMPLETED" - def first_timeout(self): + def first_timeout(self) -> Optional[Timeout]: if not self.started or not self.workflow_execution.open: return None # TODO: handle the "NONE" case - start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) + start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) # type: ignore _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout + return None - def process_timeouts(self): + def process_timeouts(self) -> None: _timeout = self.first_timeout() if _timeout: self.timeout(_timeout) - def timeout(self, _timeout): + def timeout(self, _timeout: Timeout) -> None: self._check_workflow_execution_open() self.state = "TIMED_OUT" self.timeout_type = _timeout.kind diff --git a/moto/swf/models/domain.py b/moto/swf/models/domain.py index 596983a1d..bb12d7947 100644 --- a/moto/swf/models/domain.py +++ b/moto/swf/models/domain.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Any, Dict, List, Optional, TYPE_CHECKING from moto.core import BaseModel from ..exceptions import ( @@ -6,29 +7,45 @@ from ..exceptions import ( SWFWorkflowExecutionAlreadyStartedFault, ) +if TYPE_CHECKING: + from .activity_task import ActivityTask + from .decision_task import DecisionTask + from .generic_type import GenericType, TGenericType + from .workflow_execution import WorkflowExecution + class Domain(BaseModel): - def __init__(self, name, retention, account_id, region_name, description=None): + def __init__( + self, + name: str, + retention: int, + account_id: str, + region_name: str, + description: Optional[str] = None, + ): self.name = name self.retention = retention self.account_id = account_id self.region_name = region_name self.description = description self.status = "REGISTERED" - self.types = {"activity": defaultdict(dict), "workflow": defaultdict(dict)} + self.types: Dict[str, Dict[str, Dict[str, GenericType]]] = { + "activity": defaultdict(dict), + "workflow": defaultdict(dict), + } # Workflow executions have an id, which unicity is guaranteed # at domain level (not super clear in the docs, but I checked # that against SWF API) ; hence the storage method as a dict # of "workflow_id (client determined)" => WorkflowExecution() # here. - self.workflow_executions = [] - self.activity_task_lists = {} - self.decision_task_lists = {} + self.workflow_executions: List["WorkflowExecution"] = [] + self.activity_task_lists: Dict[List[str], List["ActivityTask"]] = {} + self.decision_task_lists: Dict[str, List["DecisionTask"]] = {} - def __repr__(self): + def __repr__(self) -> str: return f"Domain(name: {self.name}, status: {self.status})" - def to_short_dict(self): + def to_short_dict(self) -> Dict[str, str]: hsh = {"name": self.name, "status": self.status} if self.description: hsh["description"] = self.description @@ -37,13 +54,13 @@ class Domain(BaseModel): ] = f"arn:aws:swf:{self.region_name}:{self.account_id}:/domain/{self.name}" return hsh - def to_full_dict(self): + def to_full_dict(self) -> Dict[str, Any]: return { "domainInfo": self.to_short_dict(), "configuration": {"workflowExecutionRetentionPeriodInDays": self.retention}, } - def get_type(self, kind, name, version, ignore_empty=False): + def get_type(self, kind: str, name: str, version: str, ignore_empty: bool = False) -> "GenericType": # type: ignore try: return self.types[kind][name][version] except KeyError: @@ -53,26 +70,30 @@ class Domain(BaseModel): f"{kind.capitalize()}Type=[name={name}, version={version}]", ) - def add_type(self, _type): + def add_type(self, _type: "TGenericType") -> None: self.types[_type.kind][_type.name][_type.version] = _type - def find_types(self, kind, status): + def find_types(self, kind: str, status: str) -> List["GenericType"]: _all = [] for family in self.types[kind].values(): for _type in family.values(): - if _type.status == status: + if _type.status == status: # type: ignore _all.append(_type) return _all - def add_workflow_execution(self, workflow_execution): + def add_workflow_execution(self, workflow_execution: "WorkflowExecution") -> None: _id = workflow_execution.workflow_id if self.get_workflow_execution(_id, raise_if_none=False): raise SWFWorkflowExecutionAlreadyStartedFault() self.workflow_executions.append(workflow_execution) def get_workflow_execution( - self, workflow_id, run_id=None, raise_if_none=True, raise_if_closed=False - ): + self, + workflow_id: str, + run_id: Optional[str] = None, + raise_if_none: bool = True, + raise_if_closed: bool = False, + ) -> Optional["WorkflowExecution"]: # query if run_id: _all = [ @@ -103,26 +124,28 @@ class Domain(BaseModel): # at last return workflow execution return wfe - def add_to_activity_task_list(self, task_list, obj): + def add_to_activity_task_list( + self, task_list: List[str], obj: "ActivityTask" + ) -> None: if task_list not in self.activity_task_lists: self.activity_task_lists[task_list] = [] self.activity_task_lists[task_list].append(obj) @property - def activity_tasks(self): - _all = [] + def activity_tasks(self) -> List["ActivityTask"]: + _all: List["ActivityTask"] = [] for tasks in self.activity_task_lists.values(): _all += tasks return _all - def add_to_decision_task_list(self, task_list, obj): + def add_to_decision_task_list(self, task_list: str, obj: "DecisionTask") -> None: if task_list not in self.decision_task_lists: self.decision_task_lists[task_list] = [] self.decision_task_lists[task_list].append(obj) @property - def decision_tasks(self): - _all = [] + def decision_tasks(self) -> List["DecisionTask"]: + _all: List["DecisionTask"] = [] for tasks in self.decision_task_lists.values(): _all += tasks return _all diff --git a/moto/swf/models/generic_type.py b/moto/swf/models/generic_type.py index f5797193a..ac97b1976 100644 --- a/moto/swf/models/generic_type.py +++ b/moto/swf/models/generic_type.py @@ -1,9 +1,10 @@ +from typing import Any, Dict, List, TypeVar from moto.core import BaseModel from moto.core.utils import camelcase_to_underscores class GenericType(BaseModel): - def __init__(self, name, version, **kwargs): + def __init__(self, name: str, version: str, **kwargs: Any): self.name = name self.version = version self.status = "REGISTERED" @@ -19,24 +20,24 @@ class GenericType(BaseModel): if not hasattr(self, "task_list"): self.task_list = None - def __repr__(self): + def __repr__(self) -> str: cls = self.__class__.__name__ attrs = f"name: {self.name}, version: {self.version}, status: {self.status}" return f"{cls}({attrs})" @property - def kind(self): + def kind(self) -> str: raise NotImplementedError() @property - def _configuration_keys(self): + def _configuration_keys(self) -> List[str]: raise NotImplementedError() - def to_short_dict(self): + def to_short_dict(self) -> Dict[str, str]: return {"name": self.name, "version": self.version} - def to_medium_dict(self): - hsh = { + def to_medium_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = { f"{self.kind}Type": self.to_short_dict(), "creationDate": 1420066800, "status": self.status, @@ -47,8 +48,8 @@ class GenericType(BaseModel): hsh["description"] = self.description return hsh - def to_full_dict(self): - hsh = {"typeInfo": self.to_medium_dict(), "configuration": {}} + def to_full_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = {"typeInfo": self.to_medium_dict(), "configuration": {}} if self.task_list: hsh["configuration"]["defaultTaskList"] = {"name": self.task_list} for key in self._configuration_keys: @@ -57,3 +58,6 @@ class GenericType(BaseModel): continue hsh["configuration"][key] = getattr(self, attr) return hsh + + +TGenericType = TypeVar("TGenericType", bound=GenericType) diff --git a/moto/swf/models/history_event.py b/moto/swf/models/history_event.py index 36b13fe08..3d42c41d8 100644 --- a/moto/swf/models/history_event.py +++ b/moto/swf/models/history_event.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, Optional from moto.core import BaseModel from moto.core.utils import underscores_to_camelcase, unix_time @@ -36,7 +37,13 @@ SUPPORTED_HISTORY_EVENT_TYPES = ( class HistoryEvent(BaseModel): - def __init__(self, event_id, event_type, event_timestamp=None, **kwargs): + def __init__( + self, + event_id: int, + event_type: str, + event_timestamp: Optional[float] = None, + **kwargs: Any, + ): if event_type not in SUPPORTED_HISTORY_EVENT_TYPES: raise NotImplementedError( f"HistoryEvent does not implement attributes for type '{event_type}'" @@ -60,7 +67,7 @@ class HistoryEvent(BaseModel): value = value.to_short_dict() self.event_attributes[camel_key] = value - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "eventId": self.event_id, "eventType": self.event_type, @@ -68,6 +75,6 @@ class HistoryEvent(BaseModel): self._attributes_key(): self.event_attributes, } - def _attributes_key(self): + def _attributes_key(self) -> str: key = f"{self.event_type}EventAttributes" return decapitalize(key) diff --git a/moto/swf/models/timeout.py b/moto/swf/models/timeout.py index bc576bb64..7e1c42e41 100644 --- a/moto/swf/models/timeout.py +++ b/moto/swf/models/timeout.py @@ -1,13 +1,14 @@ +from typing import Any from moto.core import BaseModel from moto.core.utils import unix_time class Timeout(BaseModel): - def __init__(self, obj, timestamp, kind): + def __init__(self, obj: Any, timestamp: float, kind: str): self.obj = obj self.timestamp = timestamp self.kind = kind @property - def reached(self): + def reached(self) -> bool: return unix_time() >= self.timestamp diff --git a/moto/swf/models/timer.py b/moto/swf/models/timer.py index 05a8fea37..780e58992 100644 --- a/moto/swf/models/timer.py +++ b/moto/swf/models/timer.py @@ -1,16 +1,17 @@ +from threading import Timer as ThreadingTimer from moto.core import BaseModel class Timer(BaseModel): - def __init__(self, background_timer, started_event_id): + def __init__(self, background_timer: ThreadingTimer, started_event_id: int): self.background_timer = background_timer self.started_event_id = started_event_id - def start(self): + def start(self) -> None: return self.background_timer.start() - def is_alive(self): + def is_alive(self) -> bool: return self.background_timer.is_alive() - def cancel(self): + def cancel(self) -> None: return self.background_timer.cancel() diff --git a/moto/swf/models/workflow_execution.py b/moto/swf/models/workflow_execution.py index 6f9f50112..ecc2b552a 100644 --- a/moto/swf/models/workflow_execution.py +++ b/moto/swf/models/workflow_execution.py @@ -1,4 +1,5 @@ from threading import Timer as ThreadingTimer, Lock +from typing import Any, Dict, Iterable, List, Optional from moto.core import BaseModel from moto.core.utils import camelcase_to_underscores, unix_time @@ -14,9 +15,11 @@ from ..utils import decapitalize from .activity_task import ActivityTask from .activity_type import ActivityType from .decision_task import DecisionTask +from .domain import Domain from .history_event import HistoryEvent from .timeout import Timeout from .timer import Timer +from .workflow_type import WorkflowType # TODO: extract decision related logic into a Decision class @@ -40,7 +43,13 @@ class WorkflowExecution(BaseModel): "CancelWorkflowExecution", ] - def __init__(self, domain, workflow_type, workflow_id, **kwargs): + def __init__( + self, + domain: Domain, + workflow_type: "WorkflowType", + workflow_id: str, + **kwargs: Any, + ): self.domain = domain self.workflow_id = workflow_id self.run_id = mock_random.uuid4().hex @@ -49,27 +58,33 @@ class WorkflowExecution(BaseModel): # TODO: check valid values among: # COMPLETED | FAILED | CANCELED | TERMINATED | CONTINUED_AS_NEW | TIMED_OUT # TODO: implement them all - self.close_cause = None - self.close_status = None - self.close_timestamp = None + self.close_cause: Optional[str] = None + self.close_status: Optional[str] = None + self.close_timestamp: Optional[float] = None self.execution_status = "OPEN" - self.latest_activity_task_timestamp = None - self.latest_execution_context = None + self.latest_activity_task_timestamp: Optional[float] = None + self.latest_execution_context: Optional[str] = None self.parent = None - self.start_timestamp = None + self.start_timestamp: Optional[float] = None self.tag_list = kwargs.get("tag_list", None) or [] - self.timeout_type = None + self.timeout_type: Optional[str] = None self.workflow_type = workflow_type # args processing # NB: the order follows boto/SWF order of exceptions appearance (if no # param is set, # SWF will raise DefaultUndefinedFault errors in the # same order as the few lines that follow) - self._set_from_kwargs_or_workflow_type( + self.execution_start_to_close_timeout = self._get_from_kwargs_or_workflow_type( kwargs, "execution_start_to_close_timeout" ) - self._set_from_kwargs_or_workflow_type(kwargs, "task_list", "task_list") - self._set_from_kwargs_or_workflow_type(kwargs, "task_start_to_close_timeout") - self._set_from_kwargs_or_workflow_type(kwargs, "child_policy") + self.task_list = self._get_from_kwargs_or_workflow_type( + kwargs, "task_list", "task_list" + ) + self.task_start_to_close_timeout = self._get_from_kwargs_or_workflow_type( + kwargs, "task_start_to_close_timeout" + ) + self.child_policy = self._get_from_kwargs_or_workflow_type( + kwargs, "child_policy" + ) self.input = kwargs.get("workflow_input") # counters self.open_counts = { @@ -80,20 +95,23 @@ class WorkflowExecution(BaseModel): "openLambdaFunctions": 0, } # events - self._events = [] + self._events: List[HistoryEvent] = [] # child workflows - self.child_workflow_executions = [] - self._previous_started_event_id = None + self.child_workflow_executions: List[WorkflowExecution] = [] + self._previous_started_event_id: Optional[int] = None # timers/thread utils self.threading_lock = Lock() - self._timers = {} + self._timers: Dict[str, Timer] = {} - def __repr__(self): + def __repr__(self) -> str: return f"WorkflowExecution(run_id: {self.run_id})" - def _set_from_kwargs_or_workflow_type( - self, kwargs, local_key, workflow_type_key=None - ): + def _get_from_kwargs_or_workflow_type( + self, + kwargs: Dict[str, Any], + local_key: str, + workflow_type_key: Optional[str] = None, + ) -> Any: if workflow_type_key is None: workflow_type_key = "default_" + local_key value = kwargs.get(local_key) @@ -101,10 +119,10 @@ class WorkflowExecution(BaseModel): value = getattr(self.workflow_type, workflow_type_key) if not value: raise SWFDefaultUndefinedFault(local_key) - setattr(self, local_key, value) + return value @property - def _configuration_keys(self): + def _configuration_keys(self) -> List[str]: return [ "executionStartToCloseTimeout", "childPolicy", @@ -112,11 +130,11 @@ class WorkflowExecution(BaseModel): "taskStartToCloseTimeout", ] - def to_short_dict(self): + def to_short_dict(self) -> Dict[str, str]: return {"workflowId": self.workflow_id, "runId": self.run_id} - def to_medium_dict(self): - hsh = { + def to_medium_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = { "execution": self.to_short_dict(), "workflowType": self.workflow_type.to_short_dict(), "startTimestamp": 1420066800.123, @@ -127,8 +145,8 @@ class WorkflowExecution(BaseModel): hsh["tagList"] = self.tag_list return hsh - def to_full_dict(self): - hsh = { + def to_full_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = { "executionInfo": self.to_medium_dict(), "executionConfiguration": {"taskList": {"name": self.task_list}}, } @@ -153,8 +171,8 @@ class WorkflowExecution(BaseModel): hsh["latestActivityTaskTimestamp"] = self.latest_activity_task_timestamp return hsh - def to_list_dict(self): - hsh = { + def to_list_dict(self) -> Dict[str, Any]: + hsh: Dict[str, Any] = { "execution": {"workflowId": self.workflow_id, "runId": self.run_id}, "workflowType": self.workflow_type.to_short_dict(), "startTimestamp": self.start_timestamp, @@ -171,7 +189,7 @@ class WorkflowExecution(BaseModel): hsh["closeTimestamp"] = self.close_timestamp return hsh - def _process_timeouts(self): + def _process_timeouts(self) -> None: """ SWF timeouts can happen on different objects (workflow executions, activity tasks, decision tasks) and should be processed in order. @@ -187,26 +205,24 @@ class WorkflowExecution(BaseModel): triggered, process it, then make the workflow state progress and repeat the whole process. """ - timeout_candidates = [] - # workflow execution timeout - timeout_candidates.append(self.first_timeout()) + timeout_candidates_or_none = [self.first_timeout()] # decision tasks timeouts - for task in self.decision_tasks: - timeout_candidates.append(task.first_timeout()) + for d_task in self.decision_tasks: + timeout_candidates_or_none.append(d_task.first_timeout()) # activity tasks timeouts - for task in self.activity_tasks: - timeout_candidates.append(task.first_timeout()) + for a_task in self.activity_tasks: + timeout_candidates_or_none.append(a_task.first_timeout()) # remove blank values (foo.first_timeout() is a Timeout or None) - timeout_candidates = list(filter(None, timeout_candidates)) + timeout_candidates = list(filter(None, timeout_candidates_or_none)) # now find the first timeout to process first_timeout = None if timeout_candidates: - first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) + first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) # type: ignore if first_timeout: should_schedule_decision_next = False @@ -229,17 +245,17 @@ class WorkflowExecution(BaseModel): # timeout should be processed self._process_timeouts() - def events(self, reverse_order=False): + def events(self, reverse_order: bool = False) -> Iterable[HistoryEvent]: if reverse_order: return reversed(self._events) else: return self._events - def next_event_id(self): + def next_event_id(self) -> int: event_ids = [evt.event_id for evt in self._events] return max(event_ids or [0]) + 1 - def _add_event(self, *args, **kwargs): + def _add_event(self, *args: Any, **kwargs: Any) -> HistoryEvent: # lock here because the fire_timer function is called # async, and want to ensure uniqueness in event ids with self.threading_lock: @@ -247,7 +263,7 @@ class WorkflowExecution(BaseModel): self._events.append(evt) return evt - def start(self): + def start(self) -> None: self.start_timestamp = unix_time() self._add_event( "WorkflowExecutionStarted", @@ -262,7 +278,7 @@ class WorkflowExecution(BaseModel): ) self.schedule_decision_task() - def _schedule_decision_task(self): + def _schedule_decision_task(self) -> None: has_scheduled_task = False has_started_task = False for task in self.decision_tasks: @@ -285,30 +301,32 @@ class WorkflowExecution(BaseModel): ) self.open_counts["openDecisionTasks"] += 1 - def schedule_decision_task(self): + def schedule_decision_task(self) -> None: self._schedule_decision_task() # Shortcut for tests: helps having auto-starting decision tasks when needed - def schedule_and_start_decision_task(self, identity=None): + def schedule_and_start_decision_task(self, identity: Optional[str] = None) -> None: self._schedule_decision_task() decision_task = self.decision_tasks[-1] self.start_decision_task(decision_task.task_token, identity=identity) @property - def decision_tasks(self): + def decision_tasks(self) -> List[DecisionTask]: return [t for t in self.domain.decision_tasks if t.workflow_execution == self] @property - def activity_tasks(self): + def activity_tasks(self) -> List[ActivityTask]: return [t for t in self.domain.activity_tasks if t.workflow_execution == self] - def _find_decision_task(self, task_token): + def _find_decision_task(self, task_token: str) -> DecisionTask: for dt in self.decision_tasks: if dt.task_token == task_token: return dt raise ValueError(f"No decision task with token: {task_token}") - def start_decision_task(self, task_token, identity=None): + def start_decision_task( + self, task_token: str, identity: Optional[str] = None + ) -> None: dt = self._find_decision_task(task_token) evt = self._add_event( "DecisionTaskStarted", @@ -319,8 +337,11 @@ class WorkflowExecution(BaseModel): self._previous_started_event_id = evt.event_id def complete_decision_task( - self, task_token, decisions=None, execution_context=None - ): + self, + task_token: str, + decisions: Optional[List[Dict[str, Any]]] = None, + execution_context: Optional[str] = None, + ) -> None: # 'decisions' can be None per boto.swf defaults, so replace it with something iterable if not decisions: decisions = [] @@ -341,7 +362,9 @@ class WorkflowExecution(BaseModel): self.schedule_decision_task() self.latest_execution_context = execution_context - def _check_decision_attributes(self, kind, value, decision_id): + def _check_decision_attributes( + self, kind: str, value: Dict[str, Any], decision_id: int + ) -> List[Dict[str, str]]: problems = [] constraints = DECISIONS_FIELDS.get(kind, {}) for key, constraint in constraints.items(): @@ -354,7 +377,7 @@ class WorkflowExecution(BaseModel): ) return problems - def validate_decisions(self, decisions): + def validate_decisions(self, decisions: List[Dict[str, Any]]) -> None: """ Performs some basic validations on decisions. The real SWF service seems to break early and *not* process any decision if there's a @@ -404,7 +427,7 @@ class WorkflowExecution(BaseModel): if any(problems): raise SWFDecisionValidationException(problems) - def handle_decisions(self, event_id, decisions): + def handle_decisions(self, event_id: int, decisions: List[Dict[str, Any]]) -> None: """ Handles a Decision according to SWF docs. See: http://docs.aws.amazon.com/amazonswf/latest/apireference/API_Decision.html @@ -440,7 +463,7 @@ class WorkflowExecution(BaseModel): # finally decrement counter if and only if everything went well self.open_counts["openDecisionTasks"] -= 1 - def complete(self, event_id, result=None): + def complete(self, event_id: int, result: Any = None) -> None: self.execution_status = "CLOSED" self.close_status = "COMPLETED" self.close_timestamp = unix_time() @@ -450,7 +473,9 @@ class WorkflowExecution(BaseModel): result=result, ) - def fail(self, event_id, details=None, reason=None): + def fail( + self, event_id: int, details: Any = None, reason: Optional[str] = None + ) -> None: # TODO: implement length constraints on details/reason self.execution_status = "CLOSED" self.close_status = "FAILED" @@ -462,7 +487,7 @@ class WorkflowExecution(BaseModel): reason=reason, ) - def cancel(self, event_id, details=None): + def cancel(self, event_id: int, details: Any = None) -> None: # TODO: implement length constraints on details self.cancel_requested = True # Can only cancel if there are no other pending desicion tasks @@ -483,9 +508,9 @@ class WorkflowExecution(BaseModel): details=details, ) - def schedule_activity_task(self, event_id, attributes): + def schedule_activity_task(self, event_id: int, attributes: Dict[str, Any]) -> None: # Helper function to avoid repeating ourselves in the next sections - def fail_schedule_activity_task(_type, _cause): + def fail_schedule_activity_task(_type: "ActivityType", _cause: str) -> None: # TODO: implement other possible failure mode: OPEN_ACTIVITIES_LIMIT_EXCEEDED # NB: some failure modes are not implemented and probably won't be implemented in # the future, such as ACTIVITY_CREATION_RATE_EXCEEDED or @@ -499,7 +524,7 @@ class WorkflowExecution(BaseModel): ) self.should_schedule_decision_next = True - activity_type = self.domain.get_type( + activity_type: ActivityType = self.domain.get_type( # type: ignore[assignment] "activity", attributes["activityType"]["name"], attributes["activityType"]["version"], @@ -576,13 +601,13 @@ class WorkflowExecution(BaseModel): self.open_counts["openActivityTasks"] += 1 self.latest_activity_task_timestamp = unix_time() - def _find_activity_task(self, task_token): + def _find_activity_task(self, task_token: str) -> ActivityTask: for task in self.activity_tasks: if task.task_token == task_token: return task raise ValueError(f"No activity task with token: {task_token}") - def start_activity_task(self, task_token, identity=None): + def start_activity_task(self, task_token: str, identity: Any = None) -> None: task = self._find_activity_task(task_token) evt = self._add_event( "ActivityTaskStarted", @@ -591,7 +616,7 @@ class WorkflowExecution(BaseModel): ) task.start(evt.event_id) - def complete_activity_task(self, task_token, result=None): + def complete_activity_task(self, task_token: str, result: Any = None) -> None: task = self._find_activity_task(task_token) self._add_event( "ActivityTaskCompleted", @@ -604,7 +629,9 @@ class WorkflowExecution(BaseModel): # TODO: ensure we don't schedule multiple decisions at the same time! self.schedule_decision_task() - def fail_activity_task(self, task_token, reason=None, details=None): + def fail_activity_task( + self, task_token: str, reason: Optional[str] = None, details: Any = None + ) -> None: task = self._find_activity_task(task_token) self._add_event( "ActivityTaskFailed", @@ -618,7 +645,12 @@ class WorkflowExecution(BaseModel): # TODO: ensure we don't schedule multiple decisions at the same time! self.schedule_decision_task() - def terminate(self, child_policy=None, details=None, reason=None): + def terminate( + self, + child_policy: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + reason: Optional[str] = None, + ) -> None: # TODO: handle child policy for child workflows here # TODO: handle cause="CHILD_POLICY_APPLIED" # Until this, we set cause manually to "OPERATOR_INITIATED" @@ -636,13 +668,13 @@ class WorkflowExecution(BaseModel): self.close_status = "TERMINATED" self.close_cause = "OPERATOR_INITIATED" - def signal(self, signal_name, workflow_input): + def signal(self, signal_name: str, workflow_input: Dict[str, Any]) -> None: self._add_event( "WorkflowExecutionSignaled", signal_name=signal_name, input=workflow_input ) self.schedule_decision_task() - def first_timeout(self): + def first_timeout(self) -> Optional[Timeout]: if not self.open or not self.start_timestamp: return None start_to_close_at = self.start_timestamp + int( @@ -651,8 +683,9 @@ class WorkflowExecution(BaseModel): _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout + return None - def timeout(self, timeout): + def timeout(self, timeout: Timeout) -> None: # TODO: process child policy on child workflows here or in the # triggering function self.execution_status = "CLOSED" @@ -665,7 +698,7 @@ class WorkflowExecution(BaseModel): timeout_type=self.timeout_type, ) - def timeout_decision_task(self, _timeout): + def timeout_decision_task(self, _timeout: Timeout) -> None: task = _timeout.obj task.timeout(_timeout) self._add_event( @@ -676,7 +709,7 @@ class WorkflowExecution(BaseModel): timeout_type=task.timeout_type, ) - def timeout_activity_task(self, _timeout): + def timeout_activity_task(self, _timeout: Timeout) -> None: task = _timeout.obj task.timeout(_timeout) self._add_event( @@ -688,7 +721,7 @@ class WorkflowExecution(BaseModel): timeout_type=task.timeout_type, ) - def record_marker(self, event_id, attributes): + def record_marker(self, event_id: int, attributes: Dict[str, Any]) -> None: self._add_event( "MarkerRecorded", decision_task_completed_event_id=event_id, @@ -696,7 +729,7 @@ class WorkflowExecution(BaseModel): marker_name=attributes["markerName"], ) - def start_timer(self, event_id, attributes): + def start_timer(self, event_id: int, attributes: Dict[str, Any]) -> None: timer_id = attributes["timerId"] existing_timer = self._timers.get(timer_id) if existing_timer and existing_timer.is_alive(): @@ -725,14 +758,14 @@ class WorkflowExecution(BaseModel): self._timers[timer_id] = workflow_timer workflow_timer.start() - def _fire_timer(self, started_event_id, timer_id): + def _fire_timer(self, started_event_id: int, timer_id: str) -> None: self._add_event( "TimerFired", started_event_id=started_event_id, timer_id=timer_id ) self._timers.pop(timer_id) self._schedule_decision_task() - def cancel_timer(self, event_id, timer_id): + def cancel_timer(self, event_id: int, timer_id: str) -> None: requested_timer = self._timers.get(timer_id) if not requested_timer or not requested_timer.is_alive(): # TODO there are 2 failure states @@ -754,5 +787,5 @@ class WorkflowExecution(BaseModel): ) @property - def open(self): + def open(self) -> bool: return self.execution_status == "OPEN" diff --git a/moto/swf/models/workflow_type.py b/moto/swf/models/workflow_type.py index 137f0e221..4427f7f03 100644 --- a/moto/swf/models/workflow_type.py +++ b/moto/swf/models/workflow_type.py @@ -1,9 +1,10 @@ +from typing import List from .generic_type import GenericType class WorkflowType(GenericType): @property - def _configuration_keys(self): + def _configuration_keys(self) -> List[str]: return [ "defaultChildPolicy", "defaultExecutionStartToCloseTimeout", @@ -13,5 +14,5 @@ class WorkflowType(GenericType): ] @property - def kind(self): + def kind(self) -> str: return "workflow" diff --git a/moto/swf/responses.py b/moto/swf/responses.py index b6512560b..c856d6cbb 100644 --- a/moto/swf/responses.py +++ b/moto/swf/responses.py @@ -1,74 +1,75 @@ import json +from typing import Any, List from moto.core.responses import BaseResponse from .exceptions import SWFSerializationException, SWFValidationException -from .models import swf_backends +from .models import swf_backends, SWFBackend, GenericType class SWFResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="swf") @property - def swf_backend(self): + def swf_backend(self) -> SWFBackend: return swf_backends[self.current_account][self.region] # SWF parameters are passed through a JSON body, so let's ease retrieval @property - def _params(self): + def _params(self) -> Any: # type: ignore[misc] return json.loads(self.body) - def _check_int(self, parameter): + def _check_int(self, parameter: Any) -> None: if not isinstance(parameter, int): raise SWFSerializationException(parameter) - def _check_float_or_int(self, parameter): + def _check_float_or_int(self, parameter: Any) -> None: if not isinstance(parameter, float): if not isinstance(parameter, int): raise SWFSerializationException(parameter) - def _check_none_or_string(self, parameter): + def _check_none_or_string(self, parameter: Any) -> None: if parameter is not None: self._check_string(parameter) - def _check_string(self, parameter): + def _check_string(self, parameter: Any) -> None: if not isinstance(parameter, str): raise SWFSerializationException(parameter) - def _check_none_or_list_of_strings(self, parameter): + def _check_none_or_list_of_strings(self, parameter: Any) -> None: if parameter is not None: self._check_list_of_strings(parameter) - def _check_list_of_strings(self, parameter): + def _check_list_of_strings(self, parameter: Any) -> None: if not isinstance(parameter, list): raise SWFSerializationException(parameter) for i in parameter: if not isinstance(i, str): raise SWFSerializationException(parameter) - def _check_exclusivity(self, **kwargs): + def _check_exclusivity(self, **kwargs: Any) -> None: if list(kwargs.values()).count(None) >= len(kwargs) - 1: return keys = kwargs.keys() if len(keys) == 2: - message = f"Cannot specify both a {keys[0]} and a {keys[1]}" + message = f"Cannot specify both a {keys[0]} and a {keys[1]}" # type: ignore else: message = f"Cannot specify more than one exclusive filters in the same query: {keys}" - raise SWFValidationException(message) + raise SWFValidationException(message) - def _list_types(self, kind): + def _list_types(self, kind: str) -> str: domain_name = self._params["domain"] status = self._params["registrationStatus"] reverse_order = self._params.get("reverseOrder", None) self._check_string(domain_name) self._check_string(status) - types = self.swf_backend.list_types( + types: List[GenericType] = self.swf_backend.list_types( kind, domain_name, status, reverse_order=reverse_order ) return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]}) - def _describe_type(self, kind): + def _describe_type(self, kind: str) -> str: domain = self._params["domain"] _type_args = self._params[f"{kind}Type"] name = _type_args["name"] @@ -80,7 +81,7 @@ class SWFResponse(BaseResponse): return json.dumps(_type.to_full_dict()) - def _deprecate_type(self, kind): + def _deprecate_type(self, kind: str) -> str: domain = self._params["domain"] _type_args = self._params[f"{kind}Type"] name = _type_args["name"] @@ -91,7 +92,7 @@ class SWFResponse(BaseResponse): self.swf_backend.deprecate_type(kind, domain, name, version) return "" - def _undeprecate_type(self, kind): + def _undeprecate_type(self, kind: str) -> str: domain = self._params["domain"] _type_args = self._params[f"{kind}Type"] name = _type_args["name"] @@ -103,7 +104,7 @@ class SWFResponse(BaseResponse): return "" # TODO: implement pagination - def list_domains(self): + def list_domains(self) -> str: status = self._params["registrationStatus"] self._check_string(status) reverse_order = self._params.get("reverseOrder", None) @@ -112,7 +113,7 @@ class SWFResponse(BaseResponse): {"domainInfos": [domain.to_short_dict() for domain in domains]} ) - def list_closed_workflow_executions(self): + def list_closed_workflow_executions(self) -> str: domain = self._params["domain"] start_time_filter = self._params.get("startTimeFilter", None) close_time_filter = self._params.get("closeTimeFilter", None) @@ -166,7 +167,7 @@ class SWFResponse(BaseResponse): {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} ) - def list_open_workflow_executions(self): + def list_open_workflow_executions(self) -> str: domain = self._params["domain"] start_time_filter = self._params["startTimeFilter"] execution_filter = self._params.get("executionFilter", None) @@ -204,7 +205,7 @@ class SWFResponse(BaseResponse): {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} ) - def register_domain(self): + def register_domain(self) -> str: name = self._params["name"] retention = self._params["workflowExecutionRetentionPeriodInDays"] description = self._params.get("description") @@ -214,29 +215,29 @@ class SWFResponse(BaseResponse): self.swf_backend.register_domain(name, retention, description=description) return "" - def deprecate_domain(self): + def deprecate_domain(self) -> str: name = self._params["name"] self._check_string(name) self.swf_backend.deprecate_domain(name) return "" - def undeprecate_domain(self): + def undeprecate_domain(self) -> str: name = self._params["name"] self._check_string(name) self.swf_backend.undeprecate_domain(name) return "" - def describe_domain(self): + def describe_domain(self) -> str: name = self._params["name"] self._check_string(name) domain = self.swf_backend.describe_domain(name) - return json.dumps(domain.to_full_dict()) + return json.dumps(domain.to_full_dict()) # type: ignore[union-attr] # TODO: implement pagination - def list_activity_types(self): + def list_activity_types(self) -> str: return self._list_types("activity") - def register_activity_type(self): + def register_activity_type(self) -> str: domain = self._params["domain"] name = self._params["name"] version = self._params["version"] @@ -282,19 +283,19 @@ class SWFResponse(BaseResponse): ) return "" - def deprecate_activity_type(self): + def deprecate_activity_type(self) -> str: return self._deprecate_type("activity") - def undeprecate_activity_type(self): + def undeprecate_activity_type(self) -> str: return self._undeprecate_type("activity") - def describe_activity_type(self): + def describe_activity_type(self) -> str: return self._describe_type("activity") - def list_workflow_types(self): + def list_workflow_types(self) -> str: return self._list_types("workflow") - def register_workflow_type(self): + def register_workflow_type(self) -> str: domain = self._params["domain"] name = self._params["name"] version = self._params["version"] @@ -340,16 +341,16 @@ class SWFResponse(BaseResponse): ) return "" - def deprecate_workflow_type(self): + def deprecate_workflow_type(self) -> str: return self._deprecate_type("workflow") - def undeprecate_workflow_type(self): + def undeprecate_workflow_type(self) -> str: return self._undeprecate_type("workflow") - def describe_workflow_type(self): + def describe_workflow_type(self) -> str: return self._describe_type("workflow") - def start_workflow_execution(self): + def start_workflow_execution(self) -> str: domain = self._params["domain"] workflow_id = self._params["workflowId"] _workflow_type = self._params["workflowType"] @@ -394,7 +395,7 @@ class SWFResponse(BaseResponse): return json.dumps({"runId": wfe.run_id}) - def describe_workflow_execution(self): + def describe_workflow_execution(self) -> str: domain_name = self._params["domain"] _workflow_execution = self._params["execution"] run_id = _workflow_execution["runId"] @@ -407,9 +408,9 @@ class SWFResponse(BaseResponse): wfe = self.swf_backend.describe_workflow_execution( domain_name, run_id, workflow_id ) - return json.dumps(wfe.to_full_dict()) + return json.dumps(wfe.to_full_dict()) # type: ignore[union-attr] - def get_workflow_execution_history(self): + def get_workflow_execution_history(self) -> str: domain_name = self._params["domain"] _workflow_execution = self._params["execution"] run_id = _workflow_execution["runId"] @@ -418,10 +419,10 @@ class SWFResponse(BaseResponse): wfe = self.swf_backend.describe_workflow_execution( domain_name, run_id, workflow_id ) - events = wfe.events(reverse_order=reverse_order) + events = wfe.events(reverse_order=reverse_order) # type: ignore[union-attr] return json.dumps({"events": [evt.to_dict() for evt in events]}) - def poll_for_decision_task(self): + def poll_for_decision_task(self) -> str: domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] identity = self._params.get("identity") @@ -438,7 +439,7 @@ class SWFResponse(BaseResponse): else: return json.dumps({"previousStartedEventId": 0, "startedEventId": 0}) - def count_pending_decision_tasks(self): + def count_pending_decision_tasks(self) -> str: domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] self._check_string(domain_name) @@ -446,7 +447,7 @@ class SWFResponse(BaseResponse): count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) - def respond_decision_task_completed(self): + def respond_decision_task_completed(self) -> str: task_token = self._params["taskToken"] execution_context = self._params.get("executionContext") decisions = self._params.get("decisions") @@ -457,7 +458,7 @@ class SWFResponse(BaseResponse): ) return "" - def poll_for_activity_task(self): + def poll_for_activity_task(self) -> str: domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] identity = self._params.get("identity") @@ -472,7 +473,7 @@ class SWFResponse(BaseResponse): else: return json.dumps({"startedEventId": 0}) - def count_pending_activity_tasks(self): + def count_pending_activity_tasks(self) -> str: domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] self._check_string(domain_name) @@ -480,7 +481,7 @@ class SWFResponse(BaseResponse): count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) - def respond_activity_task_completed(self): + def respond_activity_task_completed(self) -> str: task_token = self._params["taskToken"] result = self._params.get("result") self._check_string(task_token) @@ -488,7 +489,7 @@ class SWFResponse(BaseResponse): self.swf_backend.respond_activity_task_completed(task_token, result=result) return "" - def respond_activity_task_failed(self): + def respond_activity_task_failed(self) -> str: task_token = self._params["taskToken"] reason = self._params.get("reason") details = self._params.get("details") @@ -502,7 +503,7 @@ class SWFResponse(BaseResponse): ) return "" - def terminate_workflow_execution(self): + def terminate_workflow_execution(self) -> str: domain_name = self._params["domain"] workflow_id = self._params["workflowId"] child_policy = self._params.get("childPolicy") @@ -525,7 +526,7 @@ class SWFResponse(BaseResponse): ) return "" - def record_activity_task_heartbeat(self): + def record_activity_task_heartbeat(self) -> str: task_token = self._params["taskToken"] details = self._params.get("details") self._check_string(task_token) @@ -534,7 +535,7 @@ class SWFResponse(BaseResponse): # TODO: make it dynamic when we implement activity tasks cancellation return json.dumps({"cancelRequested": False}) - def signal_workflow_execution(self): + def signal_workflow_execution(self) -> str: domain_name = self._params["domain"] signal_name = self._params["signalName"] workflow_id = self._params["workflowId"] diff --git a/moto/swf/utils.py b/moto/swf/utils.py index 1b85f4ca9..0ff5429d3 100644 --- a/moto/swf/utils.py +++ b/moto/swf/utils.py @@ -1,2 +1,2 @@ -def decapitalize(key): +def decapitalize(key: str) -> str: return key[0].lower() + key[1:] diff --git a/setup.cfg b/setup.cfg index a7008125c..cfaf0380e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf show_column_numbers=True show_error_codes = True disable_error_code=abstract