Techdebt: MyPy SWF (#6256)

This commit is contained in:
Bert Blommers 2023-04-26 10:56:53 +00:00 committed by GitHub
parent da39d2103c
commit 9b969f7e3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 421 additions and 282 deletions

View File

@ -1,12 +1,16 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
if TYPE_CHECKING:
from .models.generic_type import GenericType
class SWFClientError(JsonRESTError): class SWFClientError(JsonRESTError):
code = 400 code = 400
class SWFUnknownResourceFault(SWFClientError): class SWFUnknownResourceFault(SWFClientError):
def __init__(self, resource_type, resource_name=None): def __init__(self, resource_type: str, resource_name: Optional[str] = None):
if resource_name: if resource_name:
message = f"Unknown {resource_type}: {resource_name}" message = f"Unknown {resource_type}: {resource_name}"
else: else:
@ -15,21 +19,21 @@ class SWFUnknownResourceFault(SWFClientError):
class SWFDomainAlreadyExistsFault(SWFClientError): class SWFDomainAlreadyExistsFault(SWFClientError):
def __init__(self, domain_name): def __init__(self, domain_name: str):
super().__init__( super().__init__(
"com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name
) )
class SWFDomainDeprecatedFault(SWFClientError): class SWFDomainDeprecatedFault(SWFClientError):
def __init__(self, domain_name): def __init__(self, domain_name: str):
super().__init__( super().__init__(
"com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name "com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name
) )
class SWFSerializationException(SWFClientError): class SWFSerializationException(SWFClientError):
def __init__(self, value): def __init__(self, value: Any):
message = "class java.lang.Foo can not be converted to an String " message = "class java.lang.Foo can not be converted to an String "
message += f" (not a real SWF exception ; happened on: {value})" message += f" (not a real SWF exception ; happened on: {value})"
__type = "com.amazonaws.swf.base.model#SerializationException" __type = "com.amazonaws.swf.base.model#SerializationException"
@ -37,7 +41,7 @@ class SWFSerializationException(SWFClientError):
class SWFTypeAlreadyExistsFault(SWFClientError): class SWFTypeAlreadyExistsFault(SWFClientError):
def __init__(self, _type): def __init__(self, _type: "GenericType"):
super().__init__( super().__init__(
"com.amazonaws.swf.base.model#TypeAlreadyExistsFault", "com.amazonaws.swf.base.model#TypeAlreadyExistsFault",
f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]", f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]",
@ -45,7 +49,7 @@ class SWFTypeAlreadyExistsFault(SWFClientError):
class SWFTypeDeprecatedFault(SWFClientError): class SWFTypeDeprecatedFault(SWFClientError):
def __init__(self, _type): def __init__(self, _type: "GenericType"):
super().__init__( super().__init__(
"com.amazonaws.swf.base.model#TypeDeprecatedFault", "com.amazonaws.swf.base.model#TypeDeprecatedFault",
f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]", f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]",
@ -53,7 +57,7 @@ class SWFTypeDeprecatedFault(SWFClientError):
class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault",
"Already Started", "Already Started",
@ -61,7 +65,7 @@ class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError):
class SWFDefaultUndefinedFault(SWFClientError): class SWFDefaultUndefinedFault(SWFClientError):
def __init__(self, key): def __init__(self, key: str):
# TODO: move that into moto.core.utils maybe? # TODO: move that into moto.core.utils maybe?
words = key.split("_") words = key.split("_")
key_camel_case = words.pop(0) key_camel_case = words.pop(0)
@ -73,12 +77,12 @@ class SWFDefaultUndefinedFault(SWFClientError):
class SWFValidationException(SWFClientError): class SWFValidationException(SWFClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("com.amazon.coral.validate#ValidationException", message) super().__init__("com.amazon.coral.validate#ValidationException", message)
class SWFDecisionValidationException(SWFClientError): class SWFDecisionValidationException(SWFClientError):
def __init__(self, problems): def __init__(self, problems: List[Dict[str, Any]]):
# messages # messages
messages = [] messages = []
for pb in problems: for pb in problems:
@ -106,5 +110,5 @@ class SWFDecisionValidationException(SWFClientError):
class SWFWorkflowExecutionClosedError(Exception): class SWFWorkflowExecutionClosedError(Exception):
def __str__(self): def __str__(self) -> str:
return repr("Cannot change this object because the WorkflowExecution is closed") return repr("Cannot change this object because the WorkflowExecution is closed")

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict from moto.core import BaseBackend, BackendDict
from ..exceptions import ( from ..exceptions import (
@ -12,7 +13,7 @@ from .activity_task import ActivityTask # noqa
from .activity_type import ActivityType # noqa from .activity_type import ActivityType # noqa
from .decision_task import DecisionTask # noqa from .decision_task import DecisionTask # noqa
from .domain import Domain # noqa from .domain import Domain # noqa
from .generic_type import GenericType # noqa from .generic_type import GenericType, TGenericType # noqa
from .history_event import HistoryEvent # noqa from .history_event import HistoryEvent # noqa
from .timeout import Timeout # noqa from .timeout import Timeout # noqa
from .timer import Timer # noqa from .timer import Timer # noqa
@ -24,33 +25,39 @@ KNOWN_SWF_TYPES = {"activity": ActivityType, "workflow": WorkflowType}
class SWFBackend(BaseBackend): class SWFBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.domains = [] self.domains: List[Domain] = []
def _get_domain(self, name, ignore_empty=False): def _get_domain(self, name: str, ignore_empty: bool = False) -> Domain:
matching = [domain for domain in self.domains if domain.name == name] matching = [domain for domain in self.domains if domain.name == name]
if not matching and not ignore_empty: if not matching and not ignore_empty:
raise SWFUnknownResourceFault("domain", name) raise SWFUnknownResourceFault("domain", name)
if matching: if matching:
return matching[0] return matching[0]
return None return None # type: ignore
def _process_timeouts(self): def _process_timeouts(self) -> None:
for domain in self.domains: for domain in self.domains:
for wfe in domain.workflow_executions: for wfe in domain.workflow_executions:
wfe._process_timeouts() wfe._process_timeouts()
def list_domains(self, status, reverse_order=None): def list_domains(
self, status: str, reverse_order: Optional[bool] = None
) -> List[Domain]:
domains = [domain for domain in self.domains if domain.status == status] domains = [domain for domain in self.domains if domain.status == status]
domains = sorted(domains, key=lambda domain: domain.name) domains = sorted(domains, key=lambda domain: domain.name)
if reverse_order: if reverse_order:
domains = reversed(domains) domains = reversed(domains) # type: ignore[assignment]
return domains return domains
def list_open_workflow_executions( def list_open_workflow_executions(
self, domain_name, maximum_page_size, tag_filter, reverse_order self,
): domain_name: str,
maximum_page_size: int,
tag_filter: Dict[str, str],
reverse_order: bool,
) -> List[WorkflowExecution]:
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
if domain.status == "DEPRECATED": if domain.status == "DEPRECATED":
@ -64,17 +71,17 @@ class SWFBackend(BaseBackend):
if tag_filter["tag"] not in open_wfe.tag_list: if tag_filter["tag"] not in open_wfe.tag_list:
open_wfes.remove(open_wfe) open_wfes.remove(open_wfe)
if reverse_order: if reverse_order:
open_wfes = reversed(open_wfes) open_wfes = reversed(open_wfes) # type: ignore[assignment]
return open_wfes[0:maximum_page_size] return open_wfes[0:maximum_page_size]
def list_closed_workflow_executions( def list_closed_workflow_executions(
self, self,
domain_name, domain_name: str,
tag_filter, tag_filter: Dict[str, str],
close_status_filter, close_status_filter: Dict[str, str],
maximum_page_size, maximum_page_size: int,
reverse_order, reverse_order: bool,
): ) -> List[WorkflowExecution]:
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
if domain.status == "DEPRECATED": if domain.status == "DEPRECATED":
@ -90,15 +97,18 @@ class SWFBackend(BaseBackend):
closed_wfes.remove(closed_wfe) closed_wfes.remove(closed_wfe)
if close_status_filter: if close_status_filter:
for closed_wfe in closed_wfes: for closed_wfe in closed_wfes:
if close_status_filter != closed_wfe.close_status: if close_status_filter != closed_wfe.close_status: # type: ignore
closed_wfes.remove(closed_wfe) closed_wfes.remove(closed_wfe)
if reverse_order: if reverse_order:
closed_wfes = reversed(closed_wfes) closed_wfes = reversed(closed_wfes) # type: ignore[assignment]
return closed_wfes[0:maximum_page_size] return closed_wfes[0:maximum_page_size]
def register_domain( def register_domain(
self, name, workflow_execution_retention_period_in_days, description=None self,
): name: str,
workflow_execution_retention_period_in_days: int,
description: Optional[str] = None,
) -> None:
if self._get_domain(name, ignore_empty=True): if self._get_domain(name, ignore_empty=True):
raise SWFDomainAlreadyExistsFault(name) raise SWFDomainAlreadyExistsFault(name)
domain = Domain( domain = Domain(
@ -110,68 +120,82 @@ class SWFBackend(BaseBackend):
) )
self.domains.append(domain) self.domains.append(domain)
def deprecate_domain(self, name): def deprecate_domain(self, name: str) -> None:
domain = self._get_domain(name) domain = self._get_domain(name)
if domain.status == "DEPRECATED": if domain.status == "DEPRECATED":
raise SWFDomainDeprecatedFault(name) raise SWFDomainDeprecatedFault(name)
domain.status = "DEPRECATED" domain.status = "DEPRECATED"
def undeprecate_domain(self, name): def undeprecate_domain(self, name: str) -> None:
domain = self._get_domain(name) domain = self._get_domain(name)
if domain.status == "REGISTERED": if domain.status == "REGISTERED":
raise SWFDomainAlreadyExistsFault(name) raise SWFDomainAlreadyExistsFault(name)
domain.status = "REGISTERED" domain.status = "REGISTERED"
def describe_domain(self, name): def describe_domain(self, name: str) -> Optional[Domain]:
return self._get_domain(name) return self._get_domain(name)
def list_types(self, kind, domain_name, status, reverse_order=None): def list_types(
self,
kind: str,
domain_name: str,
status: str,
reverse_order: Optional[bool] = None,
) -> List[GenericType]:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_types = domain.find_types(kind, status) _types: List[GenericType] = domain.find_types(kind, status)
_types = sorted(_types, key=lambda domain: domain.name) _types = sorted(_types, key=lambda domain: domain.name)
if reverse_order: if reverse_order:
_types = reversed(_types) _types = reversed(_types) # type: ignore
return _types return _types
def register_type(self, kind, domain_name, name, version, **kwargs): def register_type(
self, kind: str, domain_name: str, name: str, version: str, **kwargs: Any
) -> None:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version, ignore_empty=True) _type: GenericType = domain.get_type(kind, name, version, ignore_empty=True)
if _type: if _type:
raise SWFTypeAlreadyExistsFault(_type) raise SWFTypeAlreadyExistsFault(_type)
_class = KNOWN_SWF_TYPES[kind] _class = KNOWN_SWF_TYPES[kind]
_type = _class(name, version, **kwargs) _type = _class(name, version, **kwargs)
domain.add_type(_type) domain.add_type(_type)
def deprecate_type(self, kind, domain_name, name, version): def deprecate_type(
self, kind: str, domain_name: str, name: str, version: str
) -> None:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version) _type: GenericType = domain.get_type(kind, name, version)
if _type.status == "DEPRECATED": if _type.status == "DEPRECATED":
raise SWFTypeDeprecatedFault(_type) raise SWFTypeDeprecatedFault(_type)
_type.status = "DEPRECATED" _type.status = "DEPRECATED"
def undeprecate_type(self, kind, domain_name, name, version): def undeprecate_type(
self, kind: str, domain_name: str, name: str, version: str
) -> None:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version) _type: GenericType = domain.get_type(kind, name, version)
if _type.status == "REGISTERED": if _type.status == "REGISTERED":
raise SWFTypeAlreadyExistsFault(_type) raise SWFTypeAlreadyExistsFault(_type)
_type.status = "REGISTERED" _type.status = "REGISTERED"
def describe_type(self, kind, domain_name, name, version): def describe_type(
self, kind: str, domain_name: str, name: str, version: str
) -> GenericType:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
return domain.get_type(kind, name, version) return domain.get_type(kind, name, version)
def start_workflow_execution( def start_workflow_execution(
self, self,
domain_name, domain_name: str,
workflow_id, workflow_id: str,
workflow_name, workflow_name: str,
workflow_version, workflow_version: str,
tag_list=None, tag_list: Optional[Dict[str, str]] = None,
workflow_input=None, workflow_input: Optional[str] = None,
**kwargs, **kwargs: Any,
): ) -> WorkflowExecution:
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
wf_type = domain.get_type("workflow", workflow_name, workflow_version) wf_type: WorkflowType = domain.get_type("workflow", workflow_name, workflow_version) # type: ignore
if wf_type.status == "DEPRECATED": if wf_type.status == "DEPRECATED":
raise SWFTypeDeprecatedFault(wf_type) raise SWFTypeDeprecatedFault(wf_type)
wfe = WorkflowExecution( wfe = WorkflowExecution(
@ -187,13 +211,17 @@ class SWFBackend(BaseBackend):
return wfe return wfe
def describe_workflow_execution(self, domain_name, run_id, workflow_id): def describe_workflow_execution(
self, domain_name: str, run_id: str, workflow_id: str
) -> Optional[WorkflowExecution]:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
return domain.get_workflow_execution(workflow_id, run_id=run_id) return domain.get_workflow_execution(workflow_id, run_id=run_id)
def poll_for_decision_task(self, domain_name, task_list, identity=None): def poll_for_decision_task(
self, domain_name: str, task_list: List[str], identity: Optional[str] = None
) -> Optional[DecisionTask]:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -245,7 +273,9 @@ class SWFBackend(BaseBackend):
sleep(1) sleep(1)
return None return None
def count_pending_decision_tasks(self, domain_name, task_list): def count_pending_decision_tasks(
self, domain_name: str, task_list: List[str]
) -> int:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -256,8 +286,11 @@ class SWFBackend(BaseBackend):
return count return count
def respond_decision_task_completed( def respond_decision_task_completed(
self, task_token, decisions=None, execution_context=None self,
): task_token: str,
decisions: Optional[List[Dict[str, Any]]] = None,
execution_context: Optional[str] = None,
) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
# let's find decision task # let's find decision task
@ -308,7 +341,9 @@ class SWFBackend(BaseBackend):
execution_context=execution_context, execution_context=execution_context,
) )
def poll_for_activity_task(self, domain_name, task_list, identity=None): def poll_for_activity_task(
self, domain_name: str, task_list: List[str], identity: Optional[str] = None
) -> Optional[ActivityTask]:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -342,7 +377,9 @@ class SWFBackend(BaseBackend):
sleep(1) sleep(1)
return None return None
def count_pending_activity_tasks(self, domain_name, task_list): def count_pending_activity_tasks(
self, domain_name: str, task_list: List[str]
) -> int:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -353,7 +390,7 @@ class SWFBackend(BaseBackend):
count += len(pending) count += len(pending)
return count return count
def _find_activity_task_from_token(self, task_token): def _find_activity_task_from_token(self, task_token: str) -> ActivityTask:
activity_task = None activity_task = None
for domain in self.domains: for domain in self.domains:
for wfe in domain.workflow_executions: for wfe in domain.workflow_executions:
@ -389,14 +426,18 @@ class SWFBackend(BaseBackend):
# everything's good # everything's good
return activity_task return activity_task
def respond_activity_task_completed(self, task_token, result=None): def respond_activity_task_completed(
self, task_token: str, result: Any = None
) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)
wfe = activity_task.workflow_execution wfe = activity_task.workflow_execution
wfe.complete_activity_task(activity_task.task_token, result=result) wfe.complete_activity_task(activity_task.task_token, result=result)
def respond_activity_task_failed(self, task_token, reason=None, details=None): def respond_activity_task_failed(
self, task_token: str, reason: Optional[str] = None, details: Any = None
) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)
@ -405,22 +446,24 @@ class SWFBackend(BaseBackend):
def terminate_workflow_execution( def terminate_workflow_execution(
self, self,
domain_name, domain_name: str,
workflow_id, workflow_id: str,
child_policy=None, child_policy: Any = None,
details=None, details: Any = None,
reason=None, reason: Optional[str] = None,
run_id=None, run_id: Optional[str] = None,
): ) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
wfe = domain.get_workflow_execution( wfe = domain.get_workflow_execution(
workflow_id, run_id=run_id, raise_if_closed=True workflow_id, run_id=run_id, raise_if_closed=True
) )
wfe.terminate(child_policy=child_policy, details=details, reason=reason) wfe.terminate(child_policy=child_policy, details=details, reason=reason) # type: ignore[union-attr]
def record_activity_task_heartbeat(self, task_token, details=None): def record_activity_task_heartbeat(
self, task_token: str, details: Any = None
) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)
@ -429,15 +472,20 @@ class SWFBackend(BaseBackend):
activity_task.details = details activity_task.details = details
def signal_workflow_execution( def signal_workflow_execution(
self, domain_name, signal_name, workflow_id, workflow_input=None, run_id=None self,
): domain_name: str,
signal_name: str,
workflow_id: str,
workflow_input: Any = None,
run_id: Optional[str] = None,
) -> None:
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
wfe = domain.get_workflow_execution( wfe = domain.get_workflow_execution(
workflow_id, run_id=run_id, raise_if_closed=True workflow_id, run_id=run_id, raise_if_closed=True
) )
wfe.signal(signal_name, workflow_input) wfe.signal(signal_name, workflow_input) # type: ignore[union-attr]
swf_backends = BackendDict(SWFBackend, "swf") swf_backends = BackendDict(SWFBackend, "swf")

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -7,16 +8,20 @@ from ..exceptions import SWFWorkflowExecutionClosedError
from .timeout import Timeout from .timeout import Timeout
if TYPE_CHECKING:
from .activity_type import ActivityType
from .workflow_execution import WorkflowExecution
class ActivityTask(BaseModel): class ActivityTask(BaseModel):
def __init__( def __init__(
self, self,
activity_id, activity_id: str,
activity_type, activity_type: "ActivityType",
scheduled_event_id, scheduled_event_id: int,
workflow_execution, workflow_execution: "WorkflowExecution",
timeouts, timeouts: Dict[str, Any],
workflow_input=None, workflow_input: Any = None,
): ):
self.activity_id = activity_id self.activity_id = activity_id
self.activity_type = activity_type self.activity_type = activity_type
@ -24,26 +29,26 @@ class ActivityTask(BaseModel):
self.input = workflow_input self.input = workflow_input
self.last_heartbeat_timestamp = unix_time() self.last_heartbeat_timestamp = unix_time()
self.scheduled_event_id = scheduled_event_id self.scheduled_event_id = scheduled_event_id
self.started_event_id = None self.started_event_id: Optional[int] = None
self.state = "SCHEDULED" self.state = "SCHEDULED"
self.task_token = str(mock_random.uuid4()) self.task_token = str(mock_random.uuid4())
self.timeouts = timeouts self.timeouts = timeouts
self.timeout_type = None self.timeout_type: Optional[str] = None
self.workflow_execution = workflow_execution self.workflow_execution = workflow_execution
# this is *not* necessarily coherent with workflow execution history, # this is *not* necessarily coherent with workflow execution history,
# but that shouldn't be a problem for tests # but that shouldn't be a problem for tests
self.scheduled_at = datetime.utcnow() self.scheduled_at = datetime.utcnow()
def _check_workflow_execution_open(self): def _check_workflow_execution_open(self) -> None:
if not self.workflow_execution.open: if not self.workflow_execution.open:
raise SWFWorkflowExecutionClosedError() raise SWFWorkflowExecutionClosedError()
@property @property
def open(self): def open(self) -> bool:
return self.state in ["SCHEDULED", "STARTED"] return self.state in ["SCHEDULED", "STARTED"]
def to_full_dict(self): def to_full_dict(self) -> Dict[str, Any]:
hsh = { hsh: Dict[str, Any] = {
"activityId": self.activity_id, "activityId": self.activity_id,
"activityType": self.activity_type.to_short_dict(), "activityType": self.activity_type.to_short_dict(),
"taskToken": self.task_token, "taskToken": self.task_token,
@ -54,22 +59,22 @@ class ActivityTask(BaseModel):
hsh["input"] = self.input hsh["input"] = self.input
return hsh return hsh
def start(self, started_event_id): def start(self, started_event_id: int) -> None:
self.state = "STARTED" self.state = "STARTED"
self.started_event_id = started_event_id self.started_event_id = started_event_id
def complete(self): def complete(self) -> None:
self._check_workflow_execution_open() self._check_workflow_execution_open()
self.state = "COMPLETED" self.state = "COMPLETED"
def fail(self): def fail(self) -> None:
self._check_workflow_execution_open() self._check_workflow_execution_open()
self.state = "FAILED" self.state = "FAILED"
def reset_heartbeat_clock(self): def reset_heartbeat_clock(self) -> None:
self.last_heartbeat_timestamp = unix_time() self.last_heartbeat_timestamp = unix_time()
def first_timeout(self): def first_timeout(self) -> Optional[Timeout]:
if not self.open or not self.workflow_execution.open: if not self.open or not self.workflow_execution.open:
return None return None
@ -82,13 +87,14 @@ class ActivityTask(BaseModel):
_timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT") _timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT")
if _timeout.reached: if _timeout.reached:
return _timeout return _timeout
return None
def process_timeouts(self): def process_timeouts(self) -> None:
_timeout = self.first_timeout() _timeout = self.first_timeout()
if _timeout: if _timeout:
self.timeout(_timeout) self.timeout(_timeout)
def timeout(self, _timeout): def timeout(self, _timeout: Timeout) -> None:
self._check_workflow_execution_open() self._check_workflow_execution_open()
self.state = "TIMED_OUT" self.state = "TIMED_OUT"
self.timeout_type = _timeout.kind self.timeout_type = _timeout.kind

View File

@ -1,9 +1,10 @@
from typing import List
from .generic_type import GenericType from .generic_type import GenericType
class ActivityType(GenericType): class ActivityType(GenericType):
@property @property
def _configuration_keys(self): def _configuration_keys(self) -> List[str]:
return [ return [
"defaultTaskHeartbeatTimeout", "defaultTaskHeartbeatTimeout",
"defaultTaskScheduleToCloseTimeout", "defaultTaskScheduleToCloseTimeout",
@ -12,5 +13,5 @@ class ActivityType(GenericType):
] ]
@property @property
def kind(self): def kind(self) -> str:
return "activity" return "activity"

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -7,16 +8,21 @@ from ..exceptions import SWFWorkflowExecutionClosedError
from .timeout import Timeout from .timeout import Timeout
if TYPE_CHECKING:
from .workflow_execution import WorkflowExecution
class DecisionTask(BaseModel): class DecisionTask(BaseModel):
def __init__(self, workflow_execution, scheduled_event_id): def __init__(
self, workflow_execution: "WorkflowExecution", scheduled_event_id: int
):
self.workflow_execution = workflow_execution self.workflow_execution = workflow_execution
self.workflow_type = workflow_execution.workflow_type self.workflow_type = workflow_execution.workflow_type
self.task_token = str(mock_random.uuid4()) self.task_token = str(mock_random.uuid4())
self.scheduled_event_id = scheduled_event_id self.scheduled_event_id = scheduled_event_id
self.previous_started_event_id = None self.previous_started_event_id: Optional[int] = None
self.started_event_id = None self.started_event_id: Optional[int] = None
self.started_timestamp = None self.started_timestamp: Optional[float] = None
self.start_to_close_timeout = ( self.start_to_close_timeout = (
self.workflow_execution.task_start_to_close_timeout self.workflow_execution.task_start_to_close_timeout
) )
@ -24,19 +30,19 @@ class DecisionTask(BaseModel):
# this is *not* necessarily coherent with workflow execution history, # this is *not* necessarily coherent with workflow execution history,
# but that shouldn't be a problem for tests # but that shouldn't be a problem for tests
self.scheduled_at = datetime.utcnow() self.scheduled_at = datetime.utcnow()
self.timeout_type = None self.timeout_type: Optional[str] = None
@property @property
def started(self): def started(self) -> bool:
return self.state == "STARTED" return self.state == "STARTED"
def _check_workflow_execution_open(self): def _check_workflow_execution_open(self) -> None:
if not self.workflow_execution.open: if not self.workflow_execution.open:
raise SWFWorkflowExecutionClosedError() raise SWFWorkflowExecutionClosedError()
def to_full_dict(self, reverse_order=False): def to_full_dict(self, reverse_order: bool = False) -> Dict[str, Any]:
events = self.workflow_execution.events(reverse_order=reverse_order) events = self.workflow_execution.events(reverse_order=reverse_order)
hsh = { hsh: Dict[str, Any] = {
"events": [evt.to_dict() for evt in events], "events": [evt.to_dict() for evt in events],
"taskToken": self.task_token, "taskToken": self.task_token,
"workflowExecution": self.workflow_execution.to_short_dict(), "workflowExecution": self.workflow_execution.to_short_dict(),
@ -48,31 +54,34 @@ class DecisionTask(BaseModel):
hsh["startedEventId"] = self.started_event_id hsh["startedEventId"] = self.started_event_id
return hsh return hsh
def start(self, started_event_id, previous_started_event_id=None): def start(
self, started_event_id: int, previous_started_event_id: Optional[int] = None
) -> None:
self.state = "STARTED" self.state = "STARTED"
self.started_timestamp = unix_time() self.started_timestamp = unix_time()
self.started_event_id = started_event_id self.started_event_id = started_event_id
self.previous_started_event_id = previous_started_event_id self.previous_started_event_id = previous_started_event_id
def complete(self): def complete(self) -> None:
self._check_workflow_execution_open() self._check_workflow_execution_open()
self.state = "COMPLETED" self.state = "COMPLETED"
def first_timeout(self): def first_timeout(self) -> Optional[Timeout]:
if not self.started or not self.workflow_execution.open: if not self.started or not self.workflow_execution.open:
return None return None
# TODO: handle the "NONE" case # TODO: handle the "NONE" case
start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) # type: ignore
_timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE")
if _timeout.reached: if _timeout.reached:
return _timeout return _timeout
return None
def process_timeouts(self): def process_timeouts(self) -> None:
_timeout = self.first_timeout() _timeout = self.first_timeout()
if _timeout: if _timeout:
self.timeout(_timeout) self.timeout(_timeout)
def timeout(self, _timeout): def timeout(self, _timeout: Timeout) -> None:
self._check_workflow_execution_open() self._check_workflow_execution_open()
self.state = "TIMED_OUT" self.state = "TIMED_OUT"
self.timeout_type = _timeout.kind self.timeout_type = _timeout.kind

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from moto.core import BaseModel from moto.core import BaseModel
from ..exceptions import ( from ..exceptions import (
@ -6,29 +7,45 @@ from ..exceptions import (
SWFWorkflowExecutionAlreadyStartedFault, SWFWorkflowExecutionAlreadyStartedFault,
) )
if TYPE_CHECKING:
from .activity_task import ActivityTask
from .decision_task import DecisionTask
from .generic_type import GenericType, TGenericType
from .workflow_execution import WorkflowExecution
class Domain(BaseModel): class Domain(BaseModel):
def __init__(self, name, retention, account_id, region_name, description=None): def __init__(
self,
name: str,
retention: int,
account_id: str,
region_name: str,
description: Optional[str] = None,
):
self.name = name self.name = name
self.retention = retention self.retention = retention
self.account_id = account_id self.account_id = account_id
self.region_name = region_name self.region_name = region_name
self.description = description self.description = description
self.status = "REGISTERED" self.status = "REGISTERED"
self.types = {"activity": defaultdict(dict), "workflow": defaultdict(dict)} self.types: Dict[str, Dict[str, Dict[str, GenericType]]] = {
"activity": defaultdict(dict),
"workflow": defaultdict(dict),
}
# Workflow executions have an id, which unicity is guaranteed # Workflow executions have an id, which unicity is guaranteed
# at domain level (not super clear in the docs, but I checked # at domain level (not super clear in the docs, but I checked
# that against SWF API) ; hence the storage method as a dict # that against SWF API) ; hence the storage method as a dict
# of "workflow_id (client determined)" => WorkflowExecution() # of "workflow_id (client determined)" => WorkflowExecution()
# here. # here.
self.workflow_executions = [] self.workflow_executions: List["WorkflowExecution"] = []
self.activity_task_lists = {} self.activity_task_lists: Dict[List[str], List["ActivityTask"]] = {}
self.decision_task_lists = {} self.decision_task_lists: Dict[str, List["DecisionTask"]] = {}
def __repr__(self): def __repr__(self) -> str:
return f"Domain(name: {self.name}, status: {self.status})" return f"Domain(name: {self.name}, status: {self.status})"
def to_short_dict(self): def to_short_dict(self) -> Dict[str, str]:
hsh = {"name": self.name, "status": self.status} hsh = {"name": self.name, "status": self.status}
if self.description: if self.description:
hsh["description"] = self.description hsh["description"] = self.description
@ -37,13 +54,13 @@ class Domain(BaseModel):
] = f"arn:aws:swf:{self.region_name}:{self.account_id}:/domain/{self.name}" ] = f"arn:aws:swf:{self.region_name}:{self.account_id}:/domain/{self.name}"
return hsh return hsh
def to_full_dict(self): def to_full_dict(self) -> Dict[str, Any]:
return { return {
"domainInfo": self.to_short_dict(), "domainInfo": self.to_short_dict(),
"configuration": {"workflowExecutionRetentionPeriodInDays": self.retention}, "configuration": {"workflowExecutionRetentionPeriodInDays": self.retention},
} }
def get_type(self, kind, name, version, ignore_empty=False): def get_type(self, kind: str, name: str, version: str, ignore_empty: bool = False) -> "GenericType": # type: ignore
try: try:
return self.types[kind][name][version] return self.types[kind][name][version]
except KeyError: except KeyError:
@ -53,26 +70,30 @@ class Domain(BaseModel):
f"{kind.capitalize()}Type=[name={name}, version={version}]", f"{kind.capitalize()}Type=[name={name}, version={version}]",
) )
def add_type(self, _type): def add_type(self, _type: "TGenericType") -> None:
self.types[_type.kind][_type.name][_type.version] = _type self.types[_type.kind][_type.name][_type.version] = _type
def find_types(self, kind, status): def find_types(self, kind: str, status: str) -> List["GenericType"]:
_all = [] _all = []
for family in self.types[kind].values(): for family in self.types[kind].values():
for _type in family.values(): for _type in family.values():
if _type.status == status: if _type.status == status: # type: ignore
_all.append(_type) _all.append(_type)
return _all return _all
def add_workflow_execution(self, workflow_execution): def add_workflow_execution(self, workflow_execution: "WorkflowExecution") -> None:
_id = workflow_execution.workflow_id _id = workflow_execution.workflow_id
if self.get_workflow_execution(_id, raise_if_none=False): if self.get_workflow_execution(_id, raise_if_none=False):
raise SWFWorkflowExecutionAlreadyStartedFault() raise SWFWorkflowExecutionAlreadyStartedFault()
self.workflow_executions.append(workflow_execution) self.workflow_executions.append(workflow_execution)
def get_workflow_execution( def get_workflow_execution(
self, workflow_id, run_id=None, raise_if_none=True, raise_if_closed=False self,
): workflow_id: str,
run_id: Optional[str] = None,
raise_if_none: bool = True,
raise_if_closed: bool = False,
) -> Optional["WorkflowExecution"]:
# query # query
if run_id: if run_id:
_all = [ _all = [
@ -103,26 +124,28 @@ class Domain(BaseModel):
# at last return workflow execution # at last return workflow execution
return wfe return wfe
def add_to_activity_task_list(self, task_list, obj): def add_to_activity_task_list(
self, task_list: List[str], obj: "ActivityTask"
) -> None:
if task_list not in self.activity_task_lists: if task_list not in self.activity_task_lists:
self.activity_task_lists[task_list] = [] self.activity_task_lists[task_list] = []
self.activity_task_lists[task_list].append(obj) self.activity_task_lists[task_list].append(obj)
@property @property
def activity_tasks(self): def activity_tasks(self) -> List["ActivityTask"]:
_all = [] _all: List["ActivityTask"] = []
for tasks in self.activity_task_lists.values(): for tasks in self.activity_task_lists.values():
_all += tasks _all += tasks
return _all return _all
def add_to_decision_task_list(self, task_list, obj): def add_to_decision_task_list(self, task_list: str, obj: "DecisionTask") -> None:
if task_list not in self.decision_task_lists: if task_list not in self.decision_task_lists:
self.decision_task_lists[task_list] = [] self.decision_task_lists[task_list] = []
self.decision_task_lists[task_list].append(obj) self.decision_task_lists[task_list].append(obj)
@property @property
def decision_tasks(self): def decision_tasks(self) -> List["DecisionTask"]:
_all = [] _all: List["DecisionTask"] = []
for tasks in self.decision_task_lists.values(): for tasks in self.decision_task_lists.values():
_all += tasks _all += tasks
return _all return _all

View File

@ -1,9 +1,10 @@
from typing import Any, Dict, List, TypeVar
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
class GenericType(BaseModel): class GenericType(BaseModel):
def __init__(self, name, version, **kwargs): def __init__(self, name: str, version: str, **kwargs: Any):
self.name = name self.name = name
self.version = version self.version = version
self.status = "REGISTERED" self.status = "REGISTERED"
@ -19,24 +20,24 @@ class GenericType(BaseModel):
if not hasattr(self, "task_list"): if not hasattr(self, "task_list"):
self.task_list = None self.task_list = None
def __repr__(self): def __repr__(self) -> str:
cls = self.__class__.__name__ cls = self.__class__.__name__
attrs = f"name: {self.name}, version: {self.version}, status: {self.status}" attrs = f"name: {self.name}, version: {self.version}, status: {self.status}"
return f"{cls}({attrs})" return f"{cls}({attrs})"
@property @property
def kind(self): def kind(self) -> str:
raise NotImplementedError() raise NotImplementedError()
@property @property
def _configuration_keys(self): def _configuration_keys(self) -> List[str]:
raise NotImplementedError() raise NotImplementedError()
def to_short_dict(self): def to_short_dict(self) -> Dict[str, str]:
return {"name": self.name, "version": self.version} return {"name": self.name, "version": self.version}
def to_medium_dict(self): def to_medium_dict(self) -> Dict[str, Any]:
hsh = { hsh: Dict[str, Any] = {
f"{self.kind}Type": self.to_short_dict(), f"{self.kind}Type": self.to_short_dict(),
"creationDate": 1420066800, "creationDate": 1420066800,
"status": self.status, "status": self.status,
@ -47,8 +48,8 @@ class GenericType(BaseModel):
hsh["description"] = self.description hsh["description"] = self.description
return hsh return hsh
def to_full_dict(self): def to_full_dict(self) -> Dict[str, Any]:
hsh = {"typeInfo": self.to_medium_dict(), "configuration": {}} hsh: Dict[str, Any] = {"typeInfo": self.to_medium_dict(), "configuration": {}}
if self.task_list: if self.task_list:
hsh["configuration"]["defaultTaskList"] = {"name": self.task_list} hsh["configuration"]["defaultTaskList"] = {"name": self.task_list}
for key in self._configuration_keys: for key in self._configuration_keys:
@ -57,3 +58,6 @@ class GenericType(BaseModel):
continue continue
hsh["configuration"][key] = getattr(self, attr) hsh["configuration"][key] = getattr(self, attr)
return hsh return hsh
TGenericType = TypeVar("TGenericType", bound=GenericType)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, Optional
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import underscores_to_camelcase, unix_time from moto.core.utils import underscores_to_camelcase, unix_time
@ -36,7 +37,13 @@ SUPPORTED_HISTORY_EVENT_TYPES = (
class HistoryEvent(BaseModel): class HistoryEvent(BaseModel):
def __init__(self, event_id, event_type, event_timestamp=None, **kwargs): def __init__(
self,
event_id: int,
event_type: str,
event_timestamp: Optional[float] = None,
**kwargs: Any,
):
if event_type not in SUPPORTED_HISTORY_EVENT_TYPES: if event_type not in SUPPORTED_HISTORY_EVENT_TYPES:
raise NotImplementedError( raise NotImplementedError(
f"HistoryEvent does not implement attributes for type '{event_type}'" f"HistoryEvent does not implement attributes for type '{event_type}'"
@ -60,7 +67,7 @@ class HistoryEvent(BaseModel):
value = value.to_short_dict() value = value.to_short_dict()
self.event_attributes[camel_key] = value self.event_attributes[camel_key] = value
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"eventId": self.event_id, "eventId": self.event_id,
"eventType": self.event_type, "eventType": self.event_type,
@ -68,6 +75,6 @@ class HistoryEvent(BaseModel):
self._attributes_key(): self.event_attributes, self._attributes_key(): self.event_attributes,
} }
def _attributes_key(self): def _attributes_key(self) -> str:
key = f"{self.event_type}EventAttributes" key = f"{self.event_type}EventAttributes"
return decapitalize(key) return decapitalize(key)

View File

@ -1,13 +1,14 @@
from typing import Any
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
class Timeout(BaseModel): class Timeout(BaseModel):
def __init__(self, obj, timestamp, kind): def __init__(self, obj: Any, timestamp: float, kind: str):
self.obj = obj self.obj = obj
self.timestamp = timestamp self.timestamp = timestamp
self.kind = kind self.kind = kind
@property @property
def reached(self): def reached(self) -> bool:
return unix_time() >= self.timestamp return unix_time() >= self.timestamp

View File

@ -1,16 +1,17 @@
from threading import Timer as ThreadingTimer
from moto.core import BaseModel from moto.core import BaseModel
class Timer(BaseModel): class Timer(BaseModel):
def __init__(self, background_timer, started_event_id): def __init__(self, background_timer: ThreadingTimer, started_event_id: int):
self.background_timer = background_timer self.background_timer = background_timer
self.started_event_id = started_event_id self.started_event_id = started_event_id
def start(self): def start(self) -> None:
return self.background_timer.start() return self.background_timer.start()
def is_alive(self): def is_alive(self) -> bool:
return self.background_timer.is_alive() return self.background_timer.is_alive()
def cancel(self): def cancel(self) -> None:
return self.background_timer.cancel() return self.background_timer.cancel()

View File

@ -1,4 +1,5 @@
from threading import Timer as ThreadingTimer, Lock from threading import Timer as ThreadingTimer, Lock
from typing import Any, Dict, Iterable, List, Optional
from moto.core import BaseModel from moto.core import BaseModel
from moto.core.utils import camelcase_to_underscores, unix_time from moto.core.utils import camelcase_to_underscores, unix_time
@ -14,9 +15,11 @@ from ..utils import decapitalize
from .activity_task import ActivityTask from .activity_task import ActivityTask
from .activity_type import ActivityType from .activity_type import ActivityType
from .decision_task import DecisionTask from .decision_task import DecisionTask
from .domain import Domain
from .history_event import HistoryEvent from .history_event import HistoryEvent
from .timeout import Timeout from .timeout import Timeout
from .timer import Timer from .timer import Timer
from .workflow_type import WorkflowType
# TODO: extract decision related logic into a Decision class # TODO: extract decision related logic into a Decision class
@ -40,7 +43,13 @@ class WorkflowExecution(BaseModel):
"CancelWorkflowExecution", "CancelWorkflowExecution",
] ]
def __init__(self, domain, workflow_type, workflow_id, **kwargs): def __init__(
self,
domain: Domain,
workflow_type: "WorkflowType",
workflow_id: str,
**kwargs: Any,
):
self.domain = domain self.domain = domain
self.workflow_id = workflow_id self.workflow_id = workflow_id
self.run_id = mock_random.uuid4().hex self.run_id = mock_random.uuid4().hex
@ -49,27 +58,33 @@ class WorkflowExecution(BaseModel):
# TODO: check valid values among: # TODO: check valid values among:
# COMPLETED | FAILED | CANCELED | TERMINATED | CONTINUED_AS_NEW | TIMED_OUT # COMPLETED | FAILED | CANCELED | TERMINATED | CONTINUED_AS_NEW | TIMED_OUT
# TODO: implement them all # TODO: implement them all
self.close_cause = None self.close_cause: Optional[str] = None
self.close_status = None self.close_status: Optional[str] = None
self.close_timestamp = None self.close_timestamp: Optional[float] = None
self.execution_status = "OPEN" self.execution_status = "OPEN"
self.latest_activity_task_timestamp = None self.latest_activity_task_timestamp: Optional[float] = None
self.latest_execution_context = None self.latest_execution_context: Optional[str] = None
self.parent = None self.parent = None
self.start_timestamp = None self.start_timestamp: Optional[float] = None
self.tag_list = kwargs.get("tag_list", None) or [] self.tag_list = kwargs.get("tag_list", None) or []
self.timeout_type = None self.timeout_type: Optional[str] = None
self.workflow_type = workflow_type self.workflow_type = workflow_type
# args processing # args processing
# NB: the order follows boto/SWF order of exceptions appearance (if no # NB: the order follows boto/SWF order of exceptions appearance (if no
# param is set, # SWF will raise DefaultUndefinedFault errors in the # param is set, # SWF will raise DefaultUndefinedFault errors in the
# same order as the few lines that follow) # same order as the few lines that follow)
self._set_from_kwargs_or_workflow_type( self.execution_start_to_close_timeout = self._get_from_kwargs_or_workflow_type(
kwargs, "execution_start_to_close_timeout" kwargs, "execution_start_to_close_timeout"
) )
self._set_from_kwargs_or_workflow_type(kwargs, "task_list", "task_list") self.task_list = self._get_from_kwargs_or_workflow_type(
self._set_from_kwargs_or_workflow_type(kwargs, "task_start_to_close_timeout") kwargs, "task_list", "task_list"
self._set_from_kwargs_or_workflow_type(kwargs, "child_policy") )
self.task_start_to_close_timeout = self._get_from_kwargs_or_workflow_type(
kwargs, "task_start_to_close_timeout"
)
self.child_policy = self._get_from_kwargs_or_workflow_type(
kwargs, "child_policy"
)
self.input = kwargs.get("workflow_input") self.input = kwargs.get("workflow_input")
# counters # counters
self.open_counts = { self.open_counts = {
@ -80,20 +95,23 @@ class WorkflowExecution(BaseModel):
"openLambdaFunctions": 0, "openLambdaFunctions": 0,
} }
# events # events
self._events = [] self._events: List[HistoryEvent] = []
# child workflows # child workflows
self.child_workflow_executions = [] self.child_workflow_executions: List[WorkflowExecution] = []
self._previous_started_event_id = None self._previous_started_event_id: Optional[int] = None
# timers/thread utils # timers/thread utils
self.threading_lock = Lock() self.threading_lock = Lock()
self._timers = {} self._timers: Dict[str, Timer] = {}
def __repr__(self): def __repr__(self) -> str:
return f"WorkflowExecution(run_id: {self.run_id})" return f"WorkflowExecution(run_id: {self.run_id})"
def _set_from_kwargs_or_workflow_type( def _get_from_kwargs_or_workflow_type(
self, kwargs, local_key, workflow_type_key=None self,
): kwargs: Dict[str, Any],
local_key: str,
workflow_type_key: Optional[str] = None,
) -> Any:
if workflow_type_key is None: if workflow_type_key is None:
workflow_type_key = "default_" + local_key workflow_type_key = "default_" + local_key
value = kwargs.get(local_key) value = kwargs.get(local_key)
@ -101,10 +119,10 @@ class WorkflowExecution(BaseModel):
value = getattr(self.workflow_type, workflow_type_key) value = getattr(self.workflow_type, workflow_type_key)
if not value: if not value:
raise SWFDefaultUndefinedFault(local_key) raise SWFDefaultUndefinedFault(local_key)
setattr(self, local_key, value) return value
@property @property
def _configuration_keys(self): def _configuration_keys(self) -> List[str]:
return [ return [
"executionStartToCloseTimeout", "executionStartToCloseTimeout",
"childPolicy", "childPolicy",
@ -112,11 +130,11 @@ class WorkflowExecution(BaseModel):
"taskStartToCloseTimeout", "taskStartToCloseTimeout",
] ]
def to_short_dict(self): def to_short_dict(self) -> Dict[str, str]:
return {"workflowId": self.workflow_id, "runId": self.run_id} return {"workflowId": self.workflow_id, "runId": self.run_id}
def to_medium_dict(self): def to_medium_dict(self) -> Dict[str, Any]:
hsh = { hsh: Dict[str, Any] = {
"execution": self.to_short_dict(), "execution": self.to_short_dict(),
"workflowType": self.workflow_type.to_short_dict(), "workflowType": self.workflow_type.to_short_dict(),
"startTimestamp": 1420066800.123, "startTimestamp": 1420066800.123,
@ -127,8 +145,8 @@ class WorkflowExecution(BaseModel):
hsh["tagList"] = self.tag_list hsh["tagList"] = self.tag_list
return hsh return hsh
def to_full_dict(self): def to_full_dict(self) -> Dict[str, Any]:
hsh = { hsh: Dict[str, Any] = {
"executionInfo": self.to_medium_dict(), "executionInfo": self.to_medium_dict(),
"executionConfiguration": {"taskList": {"name": self.task_list}}, "executionConfiguration": {"taskList": {"name": self.task_list}},
} }
@ -153,8 +171,8 @@ class WorkflowExecution(BaseModel):
hsh["latestActivityTaskTimestamp"] = self.latest_activity_task_timestamp hsh["latestActivityTaskTimestamp"] = self.latest_activity_task_timestamp
return hsh return hsh
def to_list_dict(self): def to_list_dict(self) -> Dict[str, Any]:
hsh = { hsh: Dict[str, Any] = {
"execution": {"workflowId": self.workflow_id, "runId": self.run_id}, "execution": {"workflowId": self.workflow_id, "runId": self.run_id},
"workflowType": self.workflow_type.to_short_dict(), "workflowType": self.workflow_type.to_short_dict(),
"startTimestamp": self.start_timestamp, "startTimestamp": self.start_timestamp,
@ -171,7 +189,7 @@ class WorkflowExecution(BaseModel):
hsh["closeTimestamp"] = self.close_timestamp hsh["closeTimestamp"] = self.close_timestamp
return hsh return hsh
def _process_timeouts(self): def _process_timeouts(self) -> None:
""" """
SWF timeouts can happen on different objects (workflow executions, SWF timeouts can happen on different objects (workflow executions,
activity tasks, decision tasks) and should be processed in order. activity tasks, decision tasks) and should be processed in order.
@ -187,26 +205,24 @@ class WorkflowExecution(BaseModel):
triggered, process it, then make the workflow state progress and repeat triggered, process it, then make the workflow state progress and repeat
the whole process. the whole process.
""" """
timeout_candidates = []
# workflow execution timeout # workflow execution timeout
timeout_candidates.append(self.first_timeout()) timeout_candidates_or_none = [self.first_timeout()]
# decision tasks timeouts # decision tasks timeouts
for task in self.decision_tasks: for d_task in self.decision_tasks:
timeout_candidates.append(task.first_timeout()) timeout_candidates_or_none.append(d_task.first_timeout())
# activity tasks timeouts # activity tasks timeouts
for task in self.activity_tasks: for a_task in self.activity_tasks:
timeout_candidates.append(task.first_timeout()) timeout_candidates_or_none.append(a_task.first_timeout())
# remove blank values (foo.first_timeout() is a Timeout or None) # remove blank values (foo.first_timeout() is a Timeout or None)
timeout_candidates = list(filter(None, timeout_candidates)) timeout_candidates = list(filter(None, timeout_candidates_or_none))
# now find the first timeout to process # now find the first timeout to process
first_timeout = None first_timeout = None
if timeout_candidates: if timeout_candidates:
first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) # type: ignore
if first_timeout: if first_timeout:
should_schedule_decision_next = False should_schedule_decision_next = False
@ -229,17 +245,17 @@ class WorkflowExecution(BaseModel):
# timeout should be processed # timeout should be processed
self._process_timeouts() self._process_timeouts()
def events(self, reverse_order=False): def events(self, reverse_order: bool = False) -> Iterable[HistoryEvent]:
if reverse_order: if reverse_order:
return reversed(self._events) return reversed(self._events)
else: else:
return self._events return self._events
def next_event_id(self): def next_event_id(self) -> int:
event_ids = [evt.event_id for evt in self._events] event_ids = [evt.event_id for evt in self._events]
return max(event_ids or [0]) + 1 return max(event_ids or [0]) + 1
def _add_event(self, *args, **kwargs): def _add_event(self, *args: Any, **kwargs: Any) -> HistoryEvent:
# lock here because the fire_timer function is called # lock here because the fire_timer function is called
# async, and want to ensure uniqueness in event ids # async, and want to ensure uniqueness in event ids
with self.threading_lock: with self.threading_lock:
@ -247,7 +263,7 @@ class WorkflowExecution(BaseModel):
self._events.append(evt) self._events.append(evt)
return evt return evt
def start(self): def start(self) -> None:
self.start_timestamp = unix_time() self.start_timestamp = unix_time()
self._add_event( self._add_event(
"WorkflowExecutionStarted", "WorkflowExecutionStarted",
@ -262,7 +278,7 @@ class WorkflowExecution(BaseModel):
) )
self.schedule_decision_task() self.schedule_decision_task()
def _schedule_decision_task(self): def _schedule_decision_task(self) -> None:
has_scheduled_task = False has_scheduled_task = False
has_started_task = False has_started_task = False
for task in self.decision_tasks: for task in self.decision_tasks:
@ -285,30 +301,32 @@ class WorkflowExecution(BaseModel):
) )
self.open_counts["openDecisionTasks"] += 1 self.open_counts["openDecisionTasks"] += 1
def schedule_decision_task(self): def schedule_decision_task(self) -> None:
self._schedule_decision_task() self._schedule_decision_task()
# Shortcut for tests: helps having auto-starting decision tasks when needed # Shortcut for tests: helps having auto-starting decision tasks when needed
def schedule_and_start_decision_task(self, identity=None): def schedule_and_start_decision_task(self, identity: Optional[str] = None) -> None:
self._schedule_decision_task() self._schedule_decision_task()
decision_task = self.decision_tasks[-1] decision_task = self.decision_tasks[-1]
self.start_decision_task(decision_task.task_token, identity=identity) self.start_decision_task(decision_task.task_token, identity=identity)
@property @property
def decision_tasks(self): def decision_tasks(self) -> List[DecisionTask]:
return [t for t in self.domain.decision_tasks if t.workflow_execution == self] return [t for t in self.domain.decision_tasks if t.workflow_execution == self]
@property @property
def activity_tasks(self): def activity_tasks(self) -> List[ActivityTask]:
return [t for t in self.domain.activity_tasks if t.workflow_execution == self] return [t for t in self.domain.activity_tasks if t.workflow_execution == self]
def _find_decision_task(self, task_token): def _find_decision_task(self, task_token: str) -> DecisionTask:
for dt in self.decision_tasks: for dt in self.decision_tasks:
if dt.task_token == task_token: if dt.task_token == task_token:
return dt return dt
raise ValueError(f"No decision task with token: {task_token}") raise ValueError(f"No decision task with token: {task_token}")
def start_decision_task(self, task_token, identity=None): def start_decision_task(
self, task_token: str, identity: Optional[str] = None
) -> None:
dt = self._find_decision_task(task_token) dt = self._find_decision_task(task_token)
evt = self._add_event( evt = self._add_event(
"DecisionTaskStarted", "DecisionTaskStarted",
@ -319,8 +337,11 @@ class WorkflowExecution(BaseModel):
self._previous_started_event_id = evt.event_id self._previous_started_event_id = evt.event_id
def complete_decision_task( def complete_decision_task(
self, task_token, decisions=None, execution_context=None self,
): task_token: str,
decisions: Optional[List[Dict[str, Any]]] = None,
execution_context: Optional[str] = None,
) -> None:
# 'decisions' can be None per boto.swf defaults, so replace it with something iterable # 'decisions' can be None per boto.swf defaults, so replace it with something iterable
if not decisions: if not decisions:
decisions = [] decisions = []
@ -341,7 +362,9 @@ class WorkflowExecution(BaseModel):
self.schedule_decision_task() self.schedule_decision_task()
self.latest_execution_context = execution_context self.latest_execution_context = execution_context
def _check_decision_attributes(self, kind, value, decision_id): def _check_decision_attributes(
self, kind: str, value: Dict[str, Any], decision_id: int
) -> List[Dict[str, str]]:
problems = [] problems = []
constraints = DECISIONS_FIELDS.get(kind, {}) constraints = DECISIONS_FIELDS.get(kind, {})
for key, constraint in constraints.items(): for key, constraint in constraints.items():
@ -354,7 +377,7 @@ class WorkflowExecution(BaseModel):
) )
return problems return problems
def validate_decisions(self, decisions): def validate_decisions(self, decisions: List[Dict[str, Any]]) -> None:
""" """
Performs some basic validations on decisions. The real SWF service Performs some basic validations on decisions. The real SWF service
seems to break early and *not* process any decision if there's a seems to break early and *not* process any decision if there's a
@ -404,7 +427,7 @@ class WorkflowExecution(BaseModel):
if any(problems): if any(problems):
raise SWFDecisionValidationException(problems) raise SWFDecisionValidationException(problems)
def handle_decisions(self, event_id, decisions): def handle_decisions(self, event_id: int, decisions: List[Dict[str, Any]]) -> None:
""" """
Handles a Decision according to SWF docs. Handles a Decision according to SWF docs.
See: http://docs.aws.amazon.com/amazonswf/latest/apireference/API_Decision.html See: http://docs.aws.amazon.com/amazonswf/latest/apireference/API_Decision.html
@ -440,7 +463,7 @@ class WorkflowExecution(BaseModel):
# finally decrement counter if and only if everything went well # finally decrement counter if and only if everything went well
self.open_counts["openDecisionTasks"] -= 1 self.open_counts["openDecisionTasks"] -= 1
def complete(self, event_id, result=None): def complete(self, event_id: int, result: Any = None) -> None:
self.execution_status = "CLOSED" self.execution_status = "CLOSED"
self.close_status = "COMPLETED" self.close_status = "COMPLETED"
self.close_timestamp = unix_time() self.close_timestamp = unix_time()
@ -450,7 +473,9 @@ class WorkflowExecution(BaseModel):
result=result, result=result,
) )
def fail(self, event_id, details=None, reason=None): def fail(
self, event_id: int, details: Any = None, reason: Optional[str] = None
) -> None:
# TODO: implement length constraints on details/reason # TODO: implement length constraints on details/reason
self.execution_status = "CLOSED" self.execution_status = "CLOSED"
self.close_status = "FAILED" self.close_status = "FAILED"
@ -462,7 +487,7 @@ class WorkflowExecution(BaseModel):
reason=reason, reason=reason,
) )
def cancel(self, event_id, details=None): def cancel(self, event_id: int, details: Any = None) -> None:
# TODO: implement length constraints on details # TODO: implement length constraints on details
self.cancel_requested = True self.cancel_requested = True
# Can only cancel if there are no other pending desicion tasks # Can only cancel if there are no other pending desicion tasks
@ -483,9 +508,9 @@ class WorkflowExecution(BaseModel):
details=details, details=details,
) )
def schedule_activity_task(self, event_id, attributes): def schedule_activity_task(self, event_id: int, attributes: Dict[str, Any]) -> None:
# Helper function to avoid repeating ourselves in the next sections # Helper function to avoid repeating ourselves in the next sections
def fail_schedule_activity_task(_type, _cause): def fail_schedule_activity_task(_type: "ActivityType", _cause: str) -> None:
# TODO: implement other possible failure mode: OPEN_ACTIVITIES_LIMIT_EXCEEDED # TODO: implement other possible failure mode: OPEN_ACTIVITIES_LIMIT_EXCEEDED
# NB: some failure modes are not implemented and probably won't be implemented in # NB: some failure modes are not implemented and probably won't be implemented in
# the future, such as ACTIVITY_CREATION_RATE_EXCEEDED or # the future, such as ACTIVITY_CREATION_RATE_EXCEEDED or
@ -499,7 +524,7 @@ class WorkflowExecution(BaseModel):
) )
self.should_schedule_decision_next = True self.should_schedule_decision_next = True
activity_type = self.domain.get_type( activity_type: ActivityType = self.domain.get_type( # type: ignore[assignment]
"activity", "activity",
attributes["activityType"]["name"], attributes["activityType"]["name"],
attributes["activityType"]["version"], attributes["activityType"]["version"],
@ -576,13 +601,13 @@ class WorkflowExecution(BaseModel):
self.open_counts["openActivityTasks"] += 1 self.open_counts["openActivityTasks"] += 1
self.latest_activity_task_timestamp = unix_time() self.latest_activity_task_timestamp = unix_time()
def _find_activity_task(self, task_token): def _find_activity_task(self, task_token: str) -> ActivityTask:
for task in self.activity_tasks: for task in self.activity_tasks:
if task.task_token == task_token: if task.task_token == task_token:
return task return task
raise ValueError(f"No activity task with token: {task_token}") raise ValueError(f"No activity task with token: {task_token}")
def start_activity_task(self, task_token, identity=None): def start_activity_task(self, task_token: str, identity: Any = None) -> None:
task = self._find_activity_task(task_token) task = self._find_activity_task(task_token)
evt = self._add_event( evt = self._add_event(
"ActivityTaskStarted", "ActivityTaskStarted",
@ -591,7 +616,7 @@ class WorkflowExecution(BaseModel):
) )
task.start(evt.event_id) task.start(evt.event_id)
def complete_activity_task(self, task_token, result=None): def complete_activity_task(self, task_token: str, result: Any = None) -> None:
task = self._find_activity_task(task_token) task = self._find_activity_task(task_token)
self._add_event( self._add_event(
"ActivityTaskCompleted", "ActivityTaskCompleted",
@ -604,7 +629,9 @@ class WorkflowExecution(BaseModel):
# TODO: ensure we don't schedule multiple decisions at the same time! # TODO: ensure we don't schedule multiple decisions at the same time!
self.schedule_decision_task() self.schedule_decision_task()
def fail_activity_task(self, task_token, reason=None, details=None): def fail_activity_task(
self, task_token: str, reason: Optional[str] = None, details: Any = None
) -> None:
task = self._find_activity_task(task_token) task = self._find_activity_task(task_token)
self._add_event( self._add_event(
"ActivityTaskFailed", "ActivityTaskFailed",
@ -618,7 +645,12 @@ class WorkflowExecution(BaseModel):
# TODO: ensure we don't schedule multiple decisions at the same time! # TODO: ensure we don't schedule multiple decisions at the same time!
self.schedule_decision_task() self.schedule_decision_task()
def terminate(self, child_policy=None, details=None, reason=None): def terminate(
self,
child_policy: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
reason: Optional[str] = None,
) -> None:
# TODO: handle child policy for child workflows here # TODO: handle child policy for child workflows here
# TODO: handle cause="CHILD_POLICY_APPLIED" # TODO: handle cause="CHILD_POLICY_APPLIED"
# Until this, we set cause manually to "OPERATOR_INITIATED" # Until this, we set cause manually to "OPERATOR_INITIATED"
@ -636,13 +668,13 @@ class WorkflowExecution(BaseModel):
self.close_status = "TERMINATED" self.close_status = "TERMINATED"
self.close_cause = "OPERATOR_INITIATED" self.close_cause = "OPERATOR_INITIATED"
def signal(self, signal_name, workflow_input): def signal(self, signal_name: str, workflow_input: Dict[str, Any]) -> None:
self._add_event( self._add_event(
"WorkflowExecutionSignaled", signal_name=signal_name, input=workflow_input "WorkflowExecutionSignaled", signal_name=signal_name, input=workflow_input
) )
self.schedule_decision_task() self.schedule_decision_task()
def first_timeout(self): def first_timeout(self) -> Optional[Timeout]:
if not self.open or not self.start_timestamp: if not self.open or not self.start_timestamp:
return None return None
start_to_close_at = self.start_timestamp + int( start_to_close_at = self.start_timestamp + int(
@ -651,8 +683,9 @@ class WorkflowExecution(BaseModel):
_timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE")
if _timeout.reached: if _timeout.reached:
return _timeout return _timeout
return None
def timeout(self, timeout): def timeout(self, timeout: Timeout) -> None:
# TODO: process child policy on child workflows here or in the # TODO: process child policy on child workflows here or in the
# triggering function # triggering function
self.execution_status = "CLOSED" self.execution_status = "CLOSED"
@ -665,7 +698,7 @@ class WorkflowExecution(BaseModel):
timeout_type=self.timeout_type, timeout_type=self.timeout_type,
) )
def timeout_decision_task(self, _timeout): def timeout_decision_task(self, _timeout: Timeout) -> None:
task = _timeout.obj task = _timeout.obj
task.timeout(_timeout) task.timeout(_timeout)
self._add_event( self._add_event(
@ -676,7 +709,7 @@ class WorkflowExecution(BaseModel):
timeout_type=task.timeout_type, timeout_type=task.timeout_type,
) )
def timeout_activity_task(self, _timeout): def timeout_activity_task(self, _timeout: Timeout) -> None:
task = _timeout.obj task = _timeout.obj
task.timeout(_timeout) task.timeout(_timeout)
self._add_event( self._add_event(
@ -688,7 +721,7 @@ class WorkflowExecution(BaseModel):
timeout_type=task.timeout_type, timeout_type=task.timeout_type,
) )
def record_marker(self, event_id, attributes): def record_marker(self, event_id: int, attributes: Dict[str, Any]) -> None:
self._add_event( self._add_event(
"MarkerRecorded", "MarkerRecorded",
decision_task_completed_event_id=event_id, decision_task_completed_event_id=event_id,
@ -696,7 +729,7 @@ class WorkflowExecution(BaseModel):
marker_name=attributes["markerName"], marker_name=attributes["markerName"],
) )
def start_timer(self, event_id, attributes): def start_timer(self, event_id: int, attributes: Dict[str, Any]) -> None:
timer_id = attributes["timerId"] timer_id = attributes["timerId"]
existing_timer = self._timers.get(timer_id) existing_timer = self._timers.get(timer_id)
if existing_timer and existing_timer.is_alive(): if existing_timer and existing_timer.is_alive():
@ -725,14 +758,14 @@ class WorkflowExecution(BaseModel):
self._timers[timer_id] = workflow_timer self._timers[timer_id] = workflow_timer
workflow_timer.start() workflow_timer.start()
def _fire_timer(self, started_event_id, timer_id): def _fire_timer(self, started_event_id: int, timer_id: str) -> None:
self._add_event( self._add_event(
"TimerFired", started_event_id=started_event_id, timer_id=timer_id "TimerFired", started_event_id=started_event_id, timer_id=timer_id
) )
self._timers.pop(timer_id) self._timers.pop(timer_id)
self._schedule_decision_task() self._schedule_decision_task()
def cancel_timer(self, event_id, timer_id): def cancel_timer(self, event_id: int, timer_id: str) -> None:
requested_timer = self._timers.get(timer_id) requested_timer = self._timers.get(timer_id)
if not requested_timer or not requested_timer.is_alive(): if not requested_timer or not requested_timer.is_alive():
# TODO there are 2 failure states # TODO there are 2 failure states
@ -754,5 +787,5 @@ class WorkflowExecution(BaseModel):
) )
@property @property
def open(self): def open(self) -> bool:
return self.execution_status == "OPEN" return self.execution_status == "OPEN"

View File

@ -1,9 +1,10 @@
from typing import List
from .generic_type import GenericType from .generic_type import GenericType
class WorkflowType(GenericType): class WorkflowType(GenericType):
@property @property
def _configuration_keys(self): def _configuration_keys(self) -> List[str]:
return [ return [
"defaultChildPolicy", "defaultChildPolicy",
"defaultExecutionStartToCloseTimeout", "defaultExecutionStartToCloseTimeout",
@ -13,5 +14,5 @@ class WorkflowType(GenericType):
] ]
@property @property
def kind(self): def kind(self) -> str:
return "workflow" return "workflow"

View File

@ -1,74 +1,75 @@
import json import json
from typing import Any, List
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import SWFSerializationException, SWFValidationException from .exceptions import SWFSerializationException, SWFValidationException
from .models import swf_backends from .models import swf_backends, SWFBackend, GenericType
class SWFResponse(BaseResponse): class SWFResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="swf") super().__init__(service_name="swf")
@property @property
def swf_backend(self): def swf_backend(self) -> SWFBackend:
return swf_backends[self.current_account][self.region] return swf_backends[self.current_account][self.region]
# SWF parameters are passed through a JSON body, so let's ease retrieval # SWF parameters are passed through a JSON body, so let's ease retrieval
@property @property
def _params(self): def _params(self) -> Any: # type: ignore[misc]
return json.loads(self.body) return json.loads(self.body)
def _check_int(self, parameter): def _check_int(self, parameter: Any) -> None:
if not isinstance(parameter, int): if not isinstance(parameter, int):
raise SWFSerializationException(parameter) raise SWFSerializationException(parameter)
def _check_float_or_int(self, parameter): def _check_float_or_int(self, parameter: Any) -> None:
if not isinstance(parameter, float): if not isinstance(parameter, float):
if not isinstance(parameter, int): if not isinstance(parameter, int):
raise SWFSerializationException(parameter) raise SWFSerializationException(parameter)
def _check_none_or_string(self, parameter): def _check_none_or_string(self, parameter: Any) -> None:
if parameter is not None: if parameter is not None:
self._check_string(parameter) self._check_string(parameter)
def _check_string(self, parameter): def _check_string(self, parameter: Any) -> None:
if not isinstance(parameter, str): if not isinstance(parameter, str):
raise SWFSerializationException(parameter) raise SWFSerializationException(parameter)
def _check_none_or_list_of_strings(self, parameter): def _check_none_or_list_of_strings(self, parameter: Any) -> None:
if parameter is not None: if parameter is not None:
self._check_list_of_strings(parameter) self._check_list_of_strings(parameter)
def _check_list_of_strings(self, parameter): def _check_list_of_strings(self, parameter: Any) -> None:
if not isinstance(parameter, list): if not isinstance(parameter, list):
raise SWFSerializationException(parameter) raise SWFSerializationException(parameter)
for i in parameter: for i in parameter:
if not isinstance(i, str): if not isinstance(i, str):
raise SWFSerializationException(parameter) raise SWFSerializationException(parameter)
def _check_exclusivity(self, **kwargs): def _check_exclusivity(self, **kwargs: Any) -> None:
if list(kwargs.values()).count(None) >= len(kwargs) - 1: if list(kwargs.values()).count(None) >= len(kwargs) - 1:
return return
keys = kwargs.keys() keys = kwargs.keys()
if len(keys) == 2: if len(keys) == 2:
message = f"Cannot specify both a {keys[0]} and a {keys[1]}" message = f"Cannot specify both a {keys[0]} and a {keys[1]}" # type: ignore
else: else:
message = f"Cannot specify more than one exclusive filters in the same query: {keys}" message = f"Cannot specify more than one exclusive filters in the same query: {keys}"
raise SWFValidationException(message) raise SWFValidationException(message)
def _list_types(self, kind): def _list_types(self, kind: str) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
status = self._params["registrationStatus"] status = self._params["registrationStatus"]
reverse_order = self._params.get("reverseOrder", None) reverse_order = self._params.get("reverseOrder", None)
self._check_string(domain_name) self._check_string(domain_name)
self._check_string(status) self._check_string(status)
types = self.swf_backend.list_types( types: List[GenericType] = self.swf_backend.list_types(
kind, domain_name, status, reverse_order=reverse_order kind, domain_name, status, reverse_order=reverse_order
) )
return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]}) return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]})
def _describe_type(self, kind): def _describe_type(self, kind: str) -> str:
domain = self._params["domain"] domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"] _type_args = self._params[f"{kind}Type"]
name = _type_args["name"] name = _type_args["name"]
@ -80,7 +81,7 @@ class SWFResponse(BaseResponse):
return json.dumps(_type.to_full_dict()) return json.dumps(_type.to_full_dict())
def _deprecate_type(self, kind): def _deprecate_type(self, kind: str) -> str:
domain = self._params["domain"] domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"] _type_args = self._params[f"{kind}Type"]
name = _type_args["name"] name = _type_args["name"]
@ -91,7 +92,7 @@ class SWFResponse(BaseResponse):
self.swf_backend.deprecate_type(kind, domain, name, version) self.swf_backend.deprecate_type(kind, domain, name, version)
return "" return ""
def _undeprecate_type(self, kind): def _undeprecate_type(self, kind: str) -> str:
domain = self._params["domain"] domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"] _type_args = self._params[f"{kind}Type"]
name = _type_args["name"] name = _type_args["name"]
@ -103,7 +104,7 @@ class SWFResponse(BaseResponse):
return "" return ""
# TODO: implement pagination # TODO: implement pagination
def list_domains(self): def list_domains(self) -> str:
status = self._params["registrationStatus"] status = self._params["registrationStatus"]
self._check_string(status) self._check_string(status)
reverse_order = self._params.get("reverseOrder", None) reverse_order = self._params.get("reverseOrder", None)
@ -112,7 +113,7 @@ class SWFResponse(BaseResponse):
{"domainInfos": [domain.to_short_dict() for domain in domains]} {"domainInfos": [domain.to_short_dict() for domain in domains]}
) )
def list_closed_workflow_executions(self): def list_closed_workflow_executions(self) -> str:
domain = self._params["domain"] domain = self._params["domain"]
start_time_filter = self._params.get("startTimeFilter", None) start_time_filter = self._params.get("startTimeFilter", None)
close_time_filter = self._params.get("closeTimeFilter", None) close_time_filter = self._params.get("closeTimeFilter", None)
@ -166,7 +167,7 @@ class SWFResponse(BaseResponse):
{"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]}
) )
def list_open_workflow_executions(self): def list_open_workflow_executions(self) -> str:
domain = self._params["domain"] domain = self._params["domain"]
start_time_filter = self._params["startTimeFilter"] start_time_filter = self._params["startTimeFilter"]
execution_filter = self._params.get("executionFilter", None) execution_filter = self._params.get("executionFilter", None)
@ -204,7 +205,7 @@ class SWFResponse(BaseResponse):
{"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]}
) )
def register_domain(self): def register_domain(self) -> str:
name = self._params["name"] name = self._params["name"]
retention = self._params["workflowExecutionRetentionPeriodInDays"] retention = self._params["workflowExecutionRetentionPeriodInDays"]
description = self._params.get("description") description = self._params.get("description")
@ -214,29 +215,29 @@ class SWFResponse(BaseResponse):
self.swf_backend.register_domain(name, retention, description=description) self.swf_backend.register_domain(name, retention, description=description)
return "" return ""
def deprecate_domain(self): def deprecate_domain(self) -> str:
name = self._params["name"] name = self._params["name"]
self._check_string(name) self._check_string(name)
self.swf_backend.deprecate_domain(name) self.swf_backend.deprecate_domain(name)
return "" return ""
def undeprecate_domain(self): def undeprecate_domain(self) -> str:
name = self._params["name"] name = self._params["name"]
self._check_string(name) self._check_string(name)
self.swf_backend.undeprecate_domain(name) self.swf_backend.undeprecate_domain(name)
return "" return ""
def describe_domain(self): def describe_domain(self) -> str:
name = self._params["name"] name = self._params["name"]
self._check_string(name) self._check_string(name)
domain = self.swf_backend.describe_domain(name) domain = self.swf_backend.describe_domain(name)
return json.dumps(domain.to_full_dict()) return json.dumps(domain.to_full_dict()) # type: ignore[union-attr]
# TODO: implement pagination # TODO: implement pagination
def list_activity_types(self): def list_activity_types(self) -> str:
return self._list_types("activity") return self._list_types("activity")
def register_activity_type(self): def register_activity_type(self) -> str:
domain = self._params["domain"] domain = self._params["domain"]
name = self._params["name"] name = self._params["name"]
version = self._params["version"] version = self._params["version"]
@ -282,19 +283,19 @@ class SWFResponse(BaseResponse):
) )
return "" return ""
def deprecate_activity_type(self): def deprecate_activity_type(self) -> str:
return self._deprecate_type("activity") return self._deprecate_type("activity")
def undeprecate_activity_type(self): def undeprecate_activity_type(self) -> str:
return self._undeprecate_type("activity") return self._undeprecate_type("activity")
def describe_activity_type(self): def describe_activity_type(self) -> str:
return self._describe_type("activity") return self._describe_type("activity")
def list_workflow_types(self): def list_workflow_types(self) -> str:
return self._list_types("workflow") return self._list_types("workflow")
def register_workflow_type(self): def register_workflow_type(self) -> str:
domain = self._params["domain"] domain = self._params["domain"]
name = self._params["name"] name = self._params["name"]
version = self._params["version"] version = self._params["version"]
@ -340,16 +341,16 @@ class SWFResponse(BaseResponse):
) )
return "" return ""
def deprecate_workflow_type(self): def deprecate_workflow_type(self) -> str:
return self._deprecate_type("workflow") return self._deprecate_type("workflow")
def undeprecate_workflow_type(self): def undeprecate_workflow_type(self) -> str:
return self._undeprecate_type("workflow") return self._undeprecate_type("workflow")
def describe_workflow_type(self): def describe_workflow_type(self) -> str:
return self._describe_type("workflow") return self._describe_type("workflow")
def start_workflow_execution(self): def start_workflow_execution(self) -> str:
domain = self._params["domain"] domain = self._params["domain"]
workflow_id = self._params["workflowId"] workflow_id = self._params["workflowId"]
_workflow_type = self._params["workflowType"] _workflow_type = self._params["workflowType"]
@ -394,7 +395,7 @@ class SWFResponse(BaseResponse):
return json.dumps({"runId": wfe.run_id}) return json.dumps({"runId": wfe.run_id})
def describe_workflow_execution(self): def describe_workflow_execution(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
_workflow_execution = self._params["execution"] _workflow_execution = self._params["execution"]
run_id = _workflow_execution["runId"] run_id = _workflow_execution["runId"]
@ -407,9 +408,9 @@ class SWFResponse(BaseResponse):
wfe = self.swf_backend.describe_workflow_execution( wfe = self.swf_backend.describe_workflow_execution(
domain_name, run_id, workflow_id domain_name, run_id, workflow_id
) )
return json.dumps(wfe.to_full_dict()) return json.dumps(wfe.to_full_dict()) # type: ignore[union-attr]
def get_workflow_execution_history(self): def get_workflow_execution_history(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
_workflow_execution = self._params["execution"] _workflow_execution = self._params["execution"]
run_id = _workflow_execution["runId"] run_id = _workflow_execution["runId"]
@ -418,10 +419,10 @@ class SWFResponse(BaseResponse):
wfe = self.swf_backend.describe_workflow_execution( wfe = self.swf_backend.describe_workflow_execution(
domain_name, run_id, workflow_id domain_name, run_id, workflow_id
) )
events = wfe.events(reverse_order=reverse_order) events = wfe.events(reverse_order=reverse_order) # type: ignore[union-attr]
return json.dumps({"events": [evt.to_dict() for evt in events]}) return json.dumps({"events": [evt.to_dict() for evt in events]})
def poll_for_decision_task(self): def poll_for_decision_task(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
identity = self._params.get("identity") identity = self._params.get("identity")
@ -438,7 +439,7 @@ class SWFResponse(BaseResponse):
else: else:
return json.dumps({"previousStartedEventId": 0, "startedEventId": 0}) return json.dumps({"previousStartedEventId": 0, "startedEventId": 0})
def count_pending_decision_tasks(self): def count_pending_decision_tasks(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
self._check_string(domain_name) self._check_string(domain_name)
@ -446,7 +447,7 @@ class SWFResponse(BaseResponse):
count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False}) return json.dumps({"count": count, "truncated": False})
def respond_decision_task_completed(self): def respond_decision_task_completed(self) -> str:
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
execution_context = self._params.get("executionContext") execution_context = self._params.get("executionContext")
decisions = self._params.get("decisions") decisions = self._params.get("decisions")
@ -457,7 +458,7 @@ class SWFResponse(BaseResponse):
) )
return "" return ""
def poll_for_activity_task(self): def poll_for_activity_task(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
identity = self._params.get("identity") identity = self._params.get("identity")
@ -472,7 +473,7 @@ class SWFResponse(BaseResponse):
else: else:
return json.dumps({"startedEventId": 0}) return json.dumps({"startedEventId": 0})
def count_pending_activity_tasks(self): def count_pending_activity_tasks(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
self._check_string(domain_name) self._check_string(domain_name)
@ -480,7 +481,7 @@ class SWFResponse(BaseResponse):
count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False}) return json.dumps({"count": count, "truncated": False})
def respond_activity_task_completed(self): def respond_activity_task_completed(self) -> str:
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
result = self._params.get("result") result = self._params.get("result")
self._check_string(task_token) self._check_string(task_token)
@ -488,7 +489,7 @@ class SWFResponse(BaseResponse):
self.swf_backend.respond_activity_task_completed(task_token, result=result) self.swf_backend.respond_activity_task_completed(task_token, result=result)
return "" return ""
def respond_activity_task_failed(self): def respond_activity_task_failed(self) -> str:
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
reason = self._params.get("reason") reason = self._params.get("reason")
details = self._params.get("details") details = self._params.get("details")
@ -502,7 +503,7 @@ class SWFResponse(BaseResponse):
) )
return "" return ""
def terminate_workflow_execution(self): def terminate_workflow_execution(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
workflow_id = self._params["workflowId"] workflow_id = self._params["workflowId"]
child_policy = self._params.get("childPolicy") child_policy = self._params.get("childPolicy")
@ -525,7 +526,7 @@ class SWFResponse(BaseResponse):
) )
return "" return ""
def record_activity_task_heartbeat(self): def record_activity_task_heartbeat(self) -> str:
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
details = self._params.get("details") details = self._params.get("details")
self._check_string(task_token) self._check_string(task_token)
@ -534,7 +535,7 @@ class SWFResponse(BaseResponse):
# TODO: make it dynamic when we implement activity tasks cancellation # TODO: make it dynamic when we implement activity tasks cancellation
return json.dumps({"cancelRequested": False}) return json.dumps({"cancelRequested": False})
def signal_workflow_execution(self): def signal_workflow_execution(self) -> str:
domain_name = self._params["domain"] domain_name = self._params["domain"]
signal_name = self._params["signalName"] signal_name = self._params["signalName"]
workflow_id = self._params["workflowId"] workflow_id = self._params["workflowId"]

View File

@ -1,2 +1,2 @@
def decapitalize(key): def decapitalize(key: str) -> str:
return key[0].lower() + key[1:] return key[0].lower() + key[1:]

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract