Techdebt: MyPy SWF (#6256)

This commit is contained in:
Bert Blommers 2023-04-26 10:56:53 +00:00 committed by GitHub
parent da39d2103c
commit 9b969f7e3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 421 additions and 282 deletions

View File

@ -1,12 +1,16 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from moto.core.exceptions import JsonRESTError
if TYPE_CHECKING:
from .models.generic_type import GenericType
class SWFClientError(JsonRESTError):
code = 400
class SWFUnknownResourceFault(SWFClientError):
def __init__(self, resource_type, resource_name=None):
def __init__(self, resource_type: str, resource_name: Optional[str] = None):
if resource_name:
message = f"Unknown {resource_type}: {resource_name}"
else:
@ -15,21 +19,21 @@ class SWFUnknownResourceFault(SWFClientError):
class SWFDomainAlreadyExistsFault(SWFClientError):
def __init__(self, domain_name):
def __init__(self, domain_name: str):
super().__init__(
"com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name
)
class SWFDomainDeprecatedFault(SWFClientError):
def __init__(self, domain_name):
def __init__(self, domain_name: str):
super().__init__(
"com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name
)
class SWFSerializationException(SWFClientError):
def __init__(self, value):
def __init__(self, value: Any):
message = "class java.lang.Foo can not be converted to an String "
message += f" (not a real SWF exception ; happened on: {value})"
__type = "com.amazonaws.swf.base.model#SerializationException"
@ -37,7 +41,7 @@ class SWFSerializationException(SWFClientError):
class SWFTypeAlreadyExistsFault(SWFClientError):
def __init__(self, _type):
def __init__(self, _type: "GenericType"):
super().__init__(
"com.amazonaws.swf.base.model#TypeAlreadyExistsFault",
f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]",
@ -45,7 +49,7 @@ class SWFTypeAlreadyExistsFault(SWFClientError):
class SWFTypeDeprecatedFault(SWFClientError):
def __init__(self, _type):
def __init__(self, _type: "GenericType"):
super().__init__(
"com.amazonaws.swf.base.model#TypeDeprecatedFault",
f"{_type.__class__.__name__}=[name={_type.name}, version={_type.version}]",
@ -53,7 +57,7 @@ class SWFTypeDeprecatedFault(SWFClientError):
class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault",
"Already Started",
@ -61,7 +65,7 @@ class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError):
class SWFDefaultUndefinedFault(SWFClientError):
def __init__(self, key):
def __init__(self, key: str):
# TODO: move that into moto.core.utils maybe?
words = key.split("_")
key_camel_case = words.pop(0)
@ -73,12 +77,12 @@ class SWFDefaultUndefinedFault(SWFClientError):
class SWFValidationException(SWFClientError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("com.amazon.coral.validate#ValidationException", message)
class SWFDecisionValidationException(SWFClientError):
def __init__(self, problems):
def __init__(self, problems: List[Dict[str, Any]]):
# messages
messages = []
for pb in problems:
@ -106,5 +110,5 @@ class SWFDecisionValidationException(SWFClientError):
class SWFWorkflowExecutionClosedError(Exception):
def __str__(self):
def __str__(self) -> str:
return repr("Cannot change this object because the WorkflowExecution is closed")

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict
from ..exceptions import (
@ -12,7 +13,7 @@ from .activity_task import ActivityTask # noqa
from .activity_type import ActivityType # noqa
from .decision_task import DecisionTask # noqa
from .domain import Domain # noqa
from .generic_type import GenericType # noqa
from .generic_type import GenericType, TGenericType # noqa
from .history_event import HistoryEvent # noqa
from .timeout import Timeout # noqa
from .timer import Timer # noqa
@ -24,33 +25,39 @@ KNOWN_SWF_TYPES = {"activity": ActivityType, "workflow": WorkflowType}
class SWFBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.domains = []
self.domains: List[Domain] = []
def _get_domain(self, name, ignore_empty=False):
def _get_domain(self, name: str, ignore_empty: bool = False) -> Domain:
matching = [domain for domain in self.domains if domain.name == name]
if not matching and not ignore_empty:
raise SWFUnknownResourceFault("domain", name)
if matching:
return matching[0]
return None
return None # type: ignore
def _process_timeouts(self):
def _process_timeouts(self) -> None:
for domain in self.domains:
for wfe in domain.workflow_executions:
wfe._process_timeouts()
def list_domains(self, status, reverse_order=None):
def list_domains(
self, status: str, reverse_order: Optional[bool] = None
) -> List[Domain]:
domains = [domain for domain in self.domains if domain.status == status]
domains = sorted(domains, key=lambda domain: domain.name)
if reverse_order:
domains = reversed(domains)
domains = reversed(domains) # type: ignore[assignment]
return domains
def list_open_workflow_executions(
self, domain_name, maximum_page_size, tag_filter, reverse_order
):
self,
domain_name: str,
maximum_page_size: int,
tag_filter: Dict[str, str],
reverse_order: bool,
) -> List[WorkflowExecution]:
self._process_timeouts()
domain = self._get_domain(domain_name)
if domain.status == "DEPRECATED":
@ -64,17 +71,17 @@ class SWFBackend(BaseBackend):
if tag_filter["tag"] not in open_wfe.tag_list:
open_wfes.remove(open_wfe)
if reverse_order:
open_wfes = reversed(open_wfes)
open_wfes = reversed(open_wfes) # type: ignore[assignment]
return open_wfes[0:maximum_page_size]
def list_closed_workflow_executions(
self,
domain_name,
tag_filter,
close_status_filter,
maximum_page_size,
reverse_order,
):
domain_name: str,
tag_filter: Dict[str, str],
close_status_filter: Dict[str, str],
maximum_page_size: int,
reverse_order: bool,
) -> List[WorkflowExecution]:
self._process_timeouts()
domain = self._get_domain(domain_name)
if domain.status == "DEPRECATED":
@ -90,15 +97,18 @@ class SWFBackend(BaseBackend):
closed_wfes.remove(closed_wfe)
if close_status_filter:
for closed_wfe in closed_wfes:
if close_status_filter != closed_wfe.close_status:
if close_status_filter != closed_wfe.close_status: # type: ignore
closed_wfes.remove(closed_wfe)
if reverse_order:
closed_wfes = reversed(closed_wfes)
closed_wfes = reversed(closed_wfes) # type: ignore[assignment]
return closed_wfes[0:maximum_page_size]
def register_domain(
self, name, workflow_execution_retention_period_in_days, description=None
):
self,
name: str,
workflow_execution_retention_period_in_days: int,
description: Optional[str] = None,
) -> None:
if self._get_domain(name, ignore_empty=True):
raise SWFDomainAlreadyExistsFault(name)
domain = Domain(
@ -110,68 +120,82 @@ class SWFBackend(BaseBackend):
)
self.domains.append(domain)
def deprecate_domain(self, name):
def deprecate_domain(self, name: str) -> None:
domain = self._get_domain(name)
if domain.status == "DEPRECATED":
raise SWFDomainDeprecatedFault(name)
domain.status = "DEPRECATED"
def undeprecate_domain(self, name):
def undeprecate_domain(self, name: str) -> None:
domain = self._get_domain(name)
if domain.status == "REGISTERED":
raise SWFDomainAlreadyExistsFault(name)
domain.status = "REGISTERED"
def describe_domain(self, name):
def describe_domain(self, name: str) -> Optional[Domain]:
return self._get_domain(name)
def list_types(self, kind, domain_name, status, reverse_order=None):
def list_types(
self,
kind: str,
domain_name: str,
status: str,
reverse_order: Optional[bool] = None,
) -> List[GenericType]:
domain = self._get_domain(domain_name)
_types = domain.find_types(kind, status)
_types: List[GenericType] = domain.find_types(kind, status)
_types = sorted(_types, key=lambda domain: domain.name)
if reverse_order:
_types = reversed(_types)
_types = reversed(_types) # type: ignore
return _types
def register_type(self, kind, domain_name, name, version, **kwargs):
def register_type(
self, kind: str, domain_name: str, name: str, version: str, **kwargs: Any
) -> None:
domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version, ignore_empty=True)
_type: GenericType = domain.get_type(kind, name, version, ignore_empty=True)
if _type:
raise SWFTypeAlreadyExistsFault(_type)
_class = KNOWN_SWF_TYPES[kind]
_type = _class(name, version, **kwargs)
domain.add_type(_type)
def deprecate_type(self, kind, domain_name, name, version):
def deprecate_type(
self, kind: str, domain_name: str, name: str, version: str
) -> None:
domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version)
_type: GenericType = domain.get_type(kind, name, version)
if _type.status == "DEPRECATED":
raise SWFTypeDeprecatedFault(_type)
_type.status = "DEPRECATED"
def undeprecate_type(self, kind, domain_name, name, version):
def undeprecate_type(
self, kind: str, domain_name: str, name: str, version: str
) -> None:
domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version)
_type: GenericType = domain.get_type(kind, name, version)
if _type.status == "REGISTERED":
raise SWFTypeAlreadyExistsFault(_type)
_type.status = "REGISTERED"
def describe_type(self, kind, domain_name, name, version):
def describe_type(
self, kind: str, domain_name: str, name: str, version: str
) -> GenericType:
domain = self._get_domain(domain_name)
return domain.get_type(kind, name, version)
def start_workflow_execution(
self,
domain_name,
workflow_id,
workflow_name,
workflow_version,
tag_list=None,
workflow_input=None,
**kwargs,
):
domain_name: str,
workflow_id: str,
workflow_name: str,
workflow_version: str,
tag_list: Optional[Dict[str, str]] = None,
workflow_input: Optional[str] = None,
**kwargs: Any,
) -> WorkflowExecution:
domain = self._get_domain(domain_name)
wf_type = domain.get_type("workflow", workflow_name, workflow_version)
wf_type: WorkflowType = domain.get_type("workflow", workflow_name, workflow_version) # type: ignore
if wf_type.status == "DEPRECATED":
raise SWFTypeDeprecatedFault(wf_type)
wfe = WorkflowExecution(
@ -187,13 +211,17 @@ class SWFBackend(BaseBackend):
return wfe
def describe_workflow_execution(self, domain_name, run_id, workflow_id):
def describe_workflow_execution(
self, domain_name: str, run_id: str, workflow_id: str
) -> Optional[WorkflowExecution]:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
return domain.get_workflow_execution(workflow_id, run_id=run_id)
def poll_for_decision_task(self, domain_name, task_list, identity=None):
def poll_for_decision_task(
self, domain_name: str, task_list: List[str], identity: Optional[str] = None
) -> Optional[DecisionTask]:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
@ -245,7 +273,9 @@ class SWFBackend(BaseBackend):
sleep(1)
return None
def count_pending_decision_tasks(self, domain_name, task_list):
def count_pending_decision_tasks(
self, domain_name: str, task_list: List[str]
) -> int:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
@ -256,8 +286,11 @@ class SWFBackend(BaseBackend):
return count
def respond_decision_task_completed(
self, task_token, decisions=None, execution_context=None
):
self,
task_token: str,
decisions: Optional[List[Dict[str, Any]]] = None,
execution_context: Optional[str] = None,
) -> None:
# process timeouts on all objects
self._process_timeouts()
# let's find decision task
@ -308,7 +341,9 @@ class SWFBackend(BaseBackend):
execution_context=execution_context,
)
def poll_for_activity_task(self, domain_name, task_list, identity=None):
def poll_for_activity_task(
self, domain_name: str, task_list: List[str], identity: Optional[str] = None
) -> Optional[ActivityTask]:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
@ -342,7 +377,9 @@ class SWFBackend(BaseBackend):
sleep(1)
return None
def count_pending_activity_tasks(self, domain_name, task_list):
def count_pending_activity_tasks(
self, domain_name: str, task_list: List[str]
) -> int:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
@ -353,7 +390,7 @@ class SWFBackend(BaseBackend):
count += len(pending)
return count
def _find_activity_task_from_token(self, task_token):
def _find_activity_task_from_token(self, task_token: str) -> ActivityTask:
activity_task = None
for domain in self.domains:
for wfe in domain.workflow_executions:
@ -389,14 +426,18 @@ class SWFBackend(BaseBackend):
# everything's good
return activity_task
def respond_activity_task_completed(self, task_token, result=None):
def respond_activity_task_completed(
self, task_token: str, result: Any = None
) -> None:
# process timeouts on all objects
self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token)
wfe = activity_task.workflow_execution
wfe.complete_activity_task(activity_task.task_token, result=result)
def respond_activity_task_failed(self, task_token, reason=None, details=None):
def respond_activity_task_failed(
self, task_token: str, reason: Optional[str] = None, details: Any = None
) -> None:
# process timeouts on all objects
self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token)
@ -405,22 +446,24 @@ class SWFBackend(BaseBackend):
def terminate_workflow_execution(
self,
domain_name,
workflow_id,
child_policy=None,
details=None,
reason=None,
run_id=None,
):
domain_name: str,
workflow_id: str,
child_policy: Any = None,
details: Any = None,
reason: Optional[str] = None,
run_id: Optional[str] = None,
) -> None:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
wfe = domain.get_workflow_execution(
workflow_id, run_id=run_id, raise_if_closed=True
)
wfe.terminate(child_policy=child_policy, details=details, reason=reason)
wfe.terminate(child_policy=child_policy, details=details, reason=reason) # type: ignore[union-attr]
def record_activity_task_heartbeat(self, task_token, details=None):
def record_activity_task_heartbeat(
self, task_token: str, details: Any = None
) -> None:
# process timeouts on all objects
self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token)
@ -429,15 +472,20 @@ class SWFBackend(BaseBackend):
activity_task.details = details
def signal_workflow_execution(
self, domain_name, signal_name, workflow_id, workflow_input=None, run_id=None
):
self,
domain_name: str,
signal_name: str,
workflow_id: str,
workflow_input: Any = None,
run_id: Optional[str] = None,
) -> None:
# process timeouts on all objects
self._process_timeouts()
domain = self._get_domain(domain_name)
wfe = domain.get_workflow_execution(
workflow_id, run_id=run_id, raise_if_closed=True
)
wfe.signal(signal_name, workflow_input)
wfe.signal(signal_name, workflow_input) # type: ignore[union-attr]
swf_backends = BackendDict(SWFBackend, "swf")

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from moto.core import BaseModel
from moto.core.utils import unix_time
@ -7,16 +8,20 @@ from ..exceptions import SWFWorkflowExecutionClosedError
from .timeout import Timeout
if TYPE_CHECKING:
from .activity_type import ActivityType
from .workflow_execution import WorkflowExecution
class ActivityTask(BaseModel):
def __init__(
self,
activity_id,
activity_type,
scheduled_event_id,
workflow_execution,
timeouts,
workflow_input=None,
activity_id: str,
activity_type: "ActivityType",
scheduled_event_id: int,
workflow_execution: "WorkflowExecution",
timeouts: Dict[str, Any],
workflow_input: Any = None,
):
self.activity_id = activity_id
self.activity_type = activity_type
@ -24,26 +29,26 @@ class ActivityTask(BaseModel):
self.input = workflow_input
self.last_heartbeat_timestamp = unix_time()
self.scheduled_event_id = scheduled_event_id
self.started_event_id = None
self.started_event_id: Optional[int] = None
self.state = "SCHEDULED"
self.task_token = str(mock_random.uuid4())
self.timeouts = timeouts
self.timeout_type = None
self.timeout_type: Optional[str] = None
self.workflow_execution = workflow_execution
# this is *not* necessarily coherent with workflow execution history,
# but that shouldn't be a problem for tests
self.scheduled_at = datetime.utcnow()
def _check_workflow_execution_open(self):
def _check_workflow_execution_open(self) -> None:
if not self.workflow_execution.open:
raise SWFWorkflowExecutionClosedError()
@property
def open(self):
def open(self) -> bool:
return self.state in ["SCHEDULED", "STARTED"]
def to_full_dict(self):
hsh = {
def to_full_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {
"activityId": self.activity_id,
"activityType": self.activity_type.to_short_dict(),
"taskToken": self.task_token,
@ -54,22 +59,22 @@ class ActivityTask(BaseModel):
hsh["input"] = self.input
return hsh
def start(self, started_event_id):
def start(self, started_event_id: int) -> None:
self.state = "STARTED"
self.started_event_id = started_event_id
def complete(self):
def complete(self) -> None:
self._check_workflow_execution_open()
self.state = "COMPLETED"
def fail(self):
def fail(self) -> None:
self._check_workflow_execution_open()
self.state = "FAILED"
def reset_heartbeat_clock(self):
def reset_heartbeat_clock(self) -> None:
self.last_heartbeat_timestamp = unix_time()
def first_timeout(self):
def first_timeout(self) -> Optional[Timeout]:
if not self.open or not self.workflow_execution.open:
return None
@ -82,13 +87,14 @@ class ActivityTask(BaseModel):
_timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT")
if _timeout.reached:
return _timeout
return None
def process_timeouts(self):
def process_timeouts(self) -> None:
_timeout = self.first_timeout()
if _timeout:
self.timeout(_timeout)
def timeout(self, _timeout):
def timeout(self, _timeout: Timeout) -> None:
self._check_workflow_execution_open()
self.state = "TIMED_OUT"
self.timeout_type = _timeout.kind

View File

@ -1,9 +1,10 @@
from typing import List
from .generic_type import GenericType
class ActivityType(GenericType):
@property
def _configuration_keys(self):
def _configuration_keys(self) -> List[str]:
return [
"defaultTaskHeartbeatTimeout",
"defaultTaskScheduleToCloseTimeout",
@ -12,5 +13,5 @@ class ActivityType(GenericType):
]
@property
def kind(self):
def kind(self) -> str:
return "activity"

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from moto.core import BaseModel
from moto.core.utils import unix_time
@ -7,16 +8,21 @@ from ..exceptions import SWFWorkflowExecutionClosedError
from .timeout import Timeout
if TYPE_CHECKING:
from .workflow_execution import WorkflowExecution
class DecisionTask(BaseModel):
def __init__(self, workflow_execution, scheduled_event_id):
def __init__(
self, workflow_execution: "WorkflowExecution", scheduled_event_id: int
):
self.workflow_execution = workflow_execution
self.workflow_type = workflow_execution.workflow_type
self.task_token = str(mock_random.uuid4())
self.scheduled_event_id = scheduled_event_id
self.previous_started_event_id = None
self.started_event_id = None
self.started_timestamp = None
self.previous_started_event_id: Optional[int] = None
self.started_event_id: Optional[int] = None
self.started_timestamp: Optional[float] = None
self.start_to_close_timeout = (
self.workflow_execution.task_start_to_close_timeout
)
@ -24,19 +30,19 @@ class DecisionTask(BaseModel):
# this is *not* necessarily coherent with workflow execution history,
# but that shouldn't be a problem for tests
self.scheduled_at = datetime.utcnow()
self.timeout_type = None
self.timeout_type: Optional[str] = None
@property
def started(self):
def started(self) -> bool:
return self.state == "STARTED"
def _check_workflow_execution_open(self):
def _check_workflow_execution_open(self) -> None:
if not self.workflow_execution.open:
raise SWFWorkflowExecutionClosedError()
def to_full_dict(self, reverse_order=False):
def to_full_dict(self, reverse_order: bool = False) -> Dict[str, Any]:
events = self.workflow_execution.events(reverse_order=reverse_order)
hsh = {
hsh: Dict[str, Any] = {
"events": [evt.to_dict() for evt in events],
"taskToken": self.task_token,
"workflowExecution": self.workflow_execution.to_short_dict(),
@ -48,31 +54,34 @@ class DecisionTask(BaseModel):
hsh["startedEventId"] = self.started_event_id
return hsh
def start(self, started_event_id, previous_started_event_id=None):
def start(
self, started_event_id: int, previous_started_event_id: Optional[int] = None
) -> None:
self.state = "STARTED"
self.started_timestamp = unix_time()
self.started_event_id = started_event_id
self.previous_started_event_id = previous_started_event_id
def complete(self):
def complete(self) -> None:
self._check_workflow_execution_open()
self.state = "COMPLETED"
def first_timeout(self):
def first_timeout(self) -> Optional[Timeout]:
if not self.started or not self.workflow_execution.open:
return None
# TODO: handle the "NONE" case
start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout)
start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) # type: ignore
_timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE")
if _timeout.reached:
return _timeout
return None
def process_timeouts(self):
def process_timeouts(self) -> None:
_timeout = self.first_timeout()
if _timeout:
self.timeout(_timeout)
def timeout(self, _timeout):
def timeout(self, _timeout: Timeout) -> None:
self._check_workflow_execution_open()
self.state = "TIMED_OUT"
self.timeout_type = _timeout.kind

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from moto.core import BaseModel
from ..exceptions import (
@ -6,29 +7,45 @@ from ..exceptions import (
SWFWorkflowExecutionAlreadyStartedFault,
)
if TYPE_CHECKING:
from .activity_task import ActivityTask
from .decision_task import DecisionTask
from .generic_type import GenericType, TGenericType
from .workflow_execution import WorkflowExecution
class Domain(BaseModel):
def __init__(self, name, retention, account_id, region_name, description=None):
def __init__(
self,
name: str,
retention: int,
account_id: str,
region_name: str,
description: Optional[str] = None,
):
self.name = name
self.retention = retention
self.account_id = account_id
self.region_name = region_name
self.description = description
self.status = "REGISTERED"
self.types = {"activity": defaultdict(dict), "workflow": defaultdict(dict)}
self.types: Dict[str, Dict[str, Dict[str, GenericType]]] = {
"activity": defaultdict(dict),
"workflow": defaultdict(dict),
}
# Workflow executions have an id, which unicity is guaranteed
# at domain level (not super clear in the docs, but I checked
# that against SWF API) ; hence the storage method as a dict
# of "workflow_id (client determined)" => WorkflowExecution()
# here.
self.workflow_executions = []
self.activity_task_lists = {}
self.decision_task_lists = {}
self.workflow_executions: List["WorkflowExecution"] = []
self.activity_task_lists: Dict[List[str], List["ActivityTask"]] = {}
self.decision_task_lists: Dict[str, List["DecisionTask"]] = {}
def __repr__(self):
def __repr__(self) -> str:
return f"Domain(name: {self.name}, status: {self.status})"
def to_short_dict(self):
def to_short_dict(self) -> Dict[str, str]:
hsh = {"name": self.name, "status": self.status}
if self.description:
hsh["description"] = self.description
@ -37,13 +54,13 @@ class Domain(BaseModel):
] = f"arn:aws:swf:{self.region_name}:{self.account_id}:/domain/{self.name}"
return hsh
def to_full_dict(self):
def to_full_dict(self) -> Dict[str, Any]:
return {
"domainInfo": self.to_short_dict(),
"configuration": {"workflowExecutionRetentionPeriodInDays": self.retention},
}
def get_type(self, kind, name, version, ignore_empty=False):
def get_type(self, kind: str, name: str, version: str, ignore_empty: bool = False) -> "GenericType": # type: ignore
try:
return self.types[kind][name][version]
except KeyError:
@ -53,26 +70,30 @@ class Domain(BaseModel):
f"{kind.capitalize()}Type=[name={name}, version={version}]",
)
def add_type(self, _type):
def add_type(self, _type: "TGenericType") -> None:
self.types[_type.kind][_type.name][_type.version] = _type
def find_types(self, kind, status):
def find_types(self, kind: str, status: str) -> List["GenericType"]:
_all = []
for family in self.types[kind].values():
for _type in family.values():
if _type.status == status:
if _type.status == status: # type: ignore
_all.append(_type)
return _all
def add_workflow_execution(self, workflow_execution):
def add_workflow_execution(self, workflow_execution: "WorkflowExecution") -> None:
_id = workflow_execution.workflow_id
if self.get_workflow_execution(_id, raise_if_none=False):
raise SWFWorkflowExecutionAlreadyStartedFault()
self.workflow_executions.append(workflow_execution)
def get_workflow_execution(
self, workflow_id, run_id=None, raise_if_none=True, raise_if_closed=False
):
self,
workflow_id: str,
run_id: Optional[str] = None,
raise_if_none: bool = True,
raise_if_closed: bool = False,
) -> Optional["WorkflowExecution"]:
# query
if run_id:
_all = [
@ -103,26 +124,28 @@ class Domain(BaseModel):
# at last return workflow execution
return wfe
def add_to_activity_task_list(self, task_list, obj):
def add_to_activity_task_list(
self, task_list: List[str], obj: "ActivityTask"
) -> None:
if task_list not in self.activity_task_lists:
self.activity_task_lists[task_list] = []
self.activity_task_lists[task_list].append(obj)
@property
def activity_tasks(self):
_all = []
def activity_tasks(self) -> List["ActivityTask"]:
_all: List["ActivityTask"] = []
for tasks in self.activity_task_lists.values():
_all += tasks
return _all
def add_to_decision_task_list(self, task_list, obj):
def add_to_decision_task_list(self, task_list: str, obj: "DecisionTask") -> None:
if task_list not in self.decision_task_lists:
self.decision_task_lists[task_list] = []
self.decision_task_lists[task_list].append(obj)
@property
def decision_tasks(self):
_all = []
def decision_tasks(self) -> List["DecisionTask"]:
_all: List["DecisionTask"] = []
for tasks in self.decision_task_lists.values():
_all += tasks
return _all

View File

@ -1,9 +1,10 @@
from typing import Any, Dict, List, TypeVar
from moto.core import BaseModel
from moto.core.utils import camelcase_to_underscores
class GenericType(BaseModel):
def __init__(self, name, version, **kwargs):
def __init__(self, name: str, version: str, **kwargs: Any):
self.name = name
self.version = version
self.status = "REGISTERED"
@ -19,24 +20,24 @@ class GenericType(BaseModel):
if not hasattr(self, "task_list"):
self.task_list = None
def __repr__(self):
def __repr__(self) -> str:
cls = self.__class__.__name__
attrs = f"name: {self.name}, version: {self.version}, status: {self.status}"
return f"{cls}({attrs})"
@property
def kind(self):
def kind(self) -> str:
raise NotImplementedError()
@property
def _configuration_keys(self):
def _configuration_keys(self) -> List[str]:
raise NotImplementedError()
def to_short_dict(self):
def to_short_dict(self) -> Dict[str, str]:
return {"name": self.name, "version": self.version}
def to_medium_dict(self):
hsh = {
def to_medium_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {
f"{self.kind}Type": self.to_short_dict(),
"creationDate": 1420066800,
"status": self.status,
@ -47,8 +48,8 @@ class GenericType(BaseModel):
hsh["description"] = self.description
return hsh
def to_full_dict(self):
hsh = {"typeInfo": self.to_medium_dict(), "configuration": {}}
def to_full_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {"typeInfo": self.to_medium_dict(), "configuration": {}}
if self.task_list:
hsh["configuration"]["defaultTaskList"] = {"name": self.task_list}
for key in self._configuration_keys:
@ -57,3 +58,6 @@ class GenericType(BaseModel):
continue
hsh["configuration"][key] = getattr(self, attr)
return hsh
TGenericType = TypeVar("TGenericType", bound=GenericType)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, Optional
from moto.core import BaseModel
from moto.core.utils import underscores_to_camelcase, unix_time
@ -36,7 +37,13 @@ SUPPORTED_HISTORY_EVENT_TYPES = (
class HistoryEvent(BaseModel):
def __init__(self, event_id, event_type, event_timestamp=None, **kwargs):
def __init__(
self,
event_id: int,
event_type: str,
event_timestamp: Optional[float] = None,
**kwargs: Any,
):
if event_type not in SUPPORTED_HISTORY_EVENT_TYPES:
raise NotImplementedError(
f"HistoryEvent does not implement attributes for type '{event_type}'"
@ -60,7 +67,7 @@ class HistoryEvent(BaseModel):
value = value.to_short_dict()
self.event_attributes[camel_key] = value
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"eventId": self.event_id,
"eventType": self.event_type,
@ -68,6 +75,6 @@ class HistoryEvent(BaseModel):
self._attributes_key(): self.event_attributes,
}
def _attributes_key(self):
def _attributes_key(self) -> str:
key = f"{self.event_type}EventAttributes"
return decapitalize(key)

View File

@ -1,13 +1,14 @@
from typing import Any
from moto.core import BaseModel
from moto.core.utils import unix_time
class Timeout(BaseModel):
def __init__(self, obj, timestamp, kind):
def __init__(self, obj: Any, timestamp: float, kind: str):
self.obj = obj
self.timestamp = timestamp
self.kind = kind
@property
def reached(self):
def reached(self) -> bool:
return unix_time() >= self.timestamp

View File

@ -1,16 +1,17 @@
from threading import Timer as ThreadingTimer
from moto.core import BaseModel
class Timer(BaseModel):
def __init__(self, background_timer, started_event_id):
def __init__(self, background_timer: ThreadingTimer, started_event_id: int):
self.background_timer = background_timer
self.started_event_id = started_event_id
def start(self):
def start(self) -> None:
return self.background_timer.start()
def is_alive(self):
def is_alive(self) -> bool:
return self.background_timer.is_alive()
def cancel(self):
def cancel(self) -> None:
return self.background_timer.cancel()

View File

@ -1,4 +1,5 @@
from threading import Timer as ThreadingTimer, Lock
from typing import Any, Dict, Iterable, List, Optional
from moto.core import BaseModel
from moto.core.utils import camelcase_to_underscores, unix_time
@ -14,9 +15,11 @@ from ..utils import decapitalize
from .activity_task import ActivityTask
from .activity_type import ActivityType
from .decision_task import DecisionTask
from .domain import Domain
from .history_event import HistoryEvent
from .timeout import Timeout
from .timer import Timer
from .workflow_type import WorkflowType
# TODO: extract decision related logic into a Decision class
@ -40,7 +43,13 @@ class WorkflowExecution(BaseModel):
"CancelWorkflowExecution",
]
def __init__(self, domain, workflow_type, workflow_id, **kwargs):
def __init__(
self,
domain: Domain,
workflow_type: "WorkflowType",
workflow_id: str,
**kwargs: Any,
):
self.domain = domain
self.workflow_id = workflow_id
self.run_id = mock_random.uuid4().hex
@ -49,27 +58,33 @@ class WorkflowExecution(BaseModel):
# TODO: check valid values among:
# COMPLETED | FAILED | CANCELED | TERMINATED | CONTINUED_AS_NEW | TIMED_OUT
# TODO: implement them all
self.close_cause = None
self.close_status = None
self.close_timestamp = None
self.close_cause: Optional[str] = None
self.close_status: Optional[str] = None
self.close_timestamp: Optional[float] = None
self.execution_status = "OPEN"
self.latest_activity_task_timestamp = None
self.latest_execution_context = None
self.latest_activity_task_timestamp: Optional[float] = None
self.latest_execution_context: Optional[str] = None
self.parent = None
self.start_timestamp = None
self.start_timestamp: Optional[float] = None
self.tag_list = kwargs.get("tag_list", None) or []
self.timeout_type = None
self.timeout_type: Optional[str] = None
self.workflow_type = workflow_type
# args processing
# NB: the order follows boto/SWF order of exceptions appearance (if no
# param is set, # SWF will raise DefaultUndefinedFault errors in the
# same order as the few lines that follow)
self._set_from_kwargs_or_workflow_type(
self.execution_start_to_close_timeout = self._get_from_kwargs_or_workflow_type(
kwargs, "execution_start_to_close_timeout"
)
self._set_from_kwargs_or_workflow_type(kwargs, "task_list", "task_list")
self._set_from_kwargs_or_workflow_type(kwargs, "task_start_to_close_timeout")
self._set_from_kwargs_or_workflow_type(kwargs, "child_policy")
self.task_list = self._get_from_kwargs_or_workflow_type(
kwargs, "task_list", "task_list"
)
self.task_start_to_close_timeout = self._get_from_kwargs_or_workflow_type(
kwargs, "task_start_to_close_timeout"
)
self.child_policy = self._get_from_kwargs_or_workflow_type(
kwargs, "child_policy"
)
self.input = kwargs.get("workflow_input")
# counters
self.open_counts = {
@ -80,20 +95,23 @@ class WorkflowExecution(BaseModel):
"openLambdaFunctions": 0,
}
# events
self._events = []
self._events: List[HistoryEvent] = []
# child workflows
self.child_workflow_executions = []
self._previous_started_event_id = None
self.child_workflow_executions: List[WorkflowExecution] = []
self._previous_started_event_id: Optional[int] = None
# timers/thread utils
self.threading_lock = Lock()
self._timers = {}
self._timers: Dict[str, Timer] = {}
def __repr__(self):
def __repr__(self) -> str:
return f"WorkflowExecution(run_id: {self.run_id})"
def _set_from_kwargs_or_workflow_type(
self, kwargs, local_key, workflow_type_key=None
):
def _get_from_kwargs_or_workflow_type(
self,
kwargs: Dict[str, Any],
local_key: str,
workflow_type_key: Optional[str] = None,
) -> Any:
if workflow_type_key is None:
workflow_type_key = "default_" + local_key
value = kwargs.get(local_key)
@ -101,10 +119,10 @@ class WorkflowExecution(BaseModel):
value = getattr(self.workflow_type, workflow_type_key)
if not value:
raise SWFDefaultUndefinedFault(local_key)
setattr(self, local_key, value)
return value
@property
def _configuration_keys(self):
def _configuration_keys(self) -> List[str]:
return [
"executionStartToCloseTimeout",
"childPolicy",
@ -112,11 +130,11 @@ class WorkflowExecution(BaseModel):
"taskStartToCloseTimeout",
]
def to_short_dict(self):
def to_short_dict(self) -> Dict[str, str]:
return {"workflowId": self.workflow_id, "runId": self.run_id}
def to_medium_dict(self):
hsh = {
def to_medium_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {
"execution": self.to_short_dict(),
"workflowType": self.workflow_type.to_short_dict(),
"startTimestamp": 1420066800.123,
@ -127,8 +145,8 @@ class WorkflowExecution(BaseModel):
hsh["tagList"] = self.tag_list
return hsh
def to_full_dict(self):
hsh = {
def to_full_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {
"executionInfo": self.to_medium_dict(),
"executionConfiguration": {"taskList": {"name": self.task_list}},
}
@ -153,8 +171,8 @@ class WorkflowExecution(BaseModel):
hsh["latestActivityTaskTimestamp"] = self.latest_activity_task_timestamp
return hsh
def to_list_dict(self):
hsh = {
def to_list_dict(self) -> Dict[str, Any]:
hsh: Dict[str, Any] = {
"execution": {"workflowId": self.workflow_id, "runId": self.run_id},
"workflowType": self.workflow_type.to_short_dict(),
"startTimestamp": self.start_timestamp,
@ -171,7 +189,7 @@ class WorkflowExecution(BaseModel):
hsh["closeTimestamp"] = self.close_timestamp
return hsh
def _process_timeouts(self):
def _process_timeouts(self) -> None:
"""
SWF timeouts can happen on different objects (workflow executions,
activity tasks, decision tasks) and should be processed in order.
@ -187,26 +205,24 @@ class WorkflowExecution(BaseModel):
triggered, process it, then make the workflow state progress and repeat
the whole process.
"""
timeout_candidates = []
# workflow execution timeout
timeout_candidates.append(self.first_timeout())
timeout_candidates_or_none = [self.first_timeout()]
# decision tasks timeouts
for task in self.decision_tasks:
timeout_candidates.append(task.first_timeout())
for d_task in self.decision_tasks:
timeout_candidates_or_none.append(d_task.first_timeout())
# activity tasks timeouts
for task in self.activity_tasks:
timeout_candidates.append(task.first_timeout())
for a_task in self.activity_tasks:
timeout_candidates_or_none.append(a_task.first_timeout())
# remove blank values (foo.first_timeout() is a Timeout or None)
timeout_candidates = list(filter(None, timeout_candidates))
timeout_candidates = list(filter(None, timeout_candidates_or_none))
# now find the first timeout to process
first_timeout = None
if timeout_candidates:
first_timeout = min(timeout_candidates, key=lambda t: t.timestamp)
first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) # type: ignore
if first_timeout:
should_schedule_decision_next = False
@ -229,17 +245,17 @@ class WorkflowExecution(BaseModel):
# timeout should be processed
self._process_timeouts()
def events(self, reverse_order=False):
def events(self, reverse_order: bool = False) -> Iterable[HistoryEvent]:
if reverse_order:
return reversed(self._events)
else:
return self._events
def next_event_id(self):
def next_event_id(self) -> int:
event_ids = [evt.event_id for evt in self._events]
return max(event_ids or [0]) + 1
def _add_event(self, *args, **kwargs):
def _add_event(self, *args: Any, **kwargs: Any) -> HistoryEvent:
# lock here because the fire_timer function is called
# async, and want to ensure uniqueness in event ids
with self.threading_lock:
@ -247,7 +263,7 @@ class WorkflowExecution(BaseModel):
self._events.append(evt)
return evt
def start(self):
def start(self) -> None:
self.start_timestamp = unix_time()
self._add_event(
"WorkflowExecutionStarted",
@ -262,7 +278,7 @@ class WorkflowExecution(BaseModel):
)
self.schedule_decision_task()
def _schedule_decision_task(self):
def _schedule_decision_task(self) -> None:
has_scheduled_task = False
has_started_task = False
for task in self.decision_tasks:
@ -285,30 +301,32 @@ class WorkflowExecution(BaseModel):
)
self.open_counts["openDecisionTasks"] += 1
def schedule_decision_task(self):
def schedule_decision_task(self) -> None:
self._schedule_decision_task()
# Shortcut for tests: helps having auto-starting decision tasks when needed
def schedule_and_start_decision_task(self, identity=None):
def schedule_and_start_decision_task(self, identity: Optional[str] = None) -> None:
self._schedule_decision_task()
decision_task = self.decision_tasks[-1]
self.start_decision_task(decision_task.task_token, identity=identity)
@property
def decision_tasks(self):
def decision_tasks(self) -> List[DecisionTask]:
return [t for t in self.domain.decision_tasks if t.workflow_execution == self]
@property
def activity_tasks(self):
def activity_tasks(self) -> List[ActivityTask]:
return [t for t in self.domain.activity_tasks if t.workflow_execution == self]
def _find_decision_task(self, task_token):
def _find_decision_task(self, task_token: str) -> DecisionTask:
for dt in self.decision_tasks:
if dt.task_token == task_token:
return dt
raise ValueError(f"No decision task with token: {task_token}")
def start_decision_task(self, task_token, identity=None):
def start_decision_task(
self, task_token: str, identity: Optional[str] = None
) -> None:
dt = self._find_decision_task(task_token)
evt = self._add_event(
"DecisionTaskStarted",
@ -319,8 +337,11 @@ class WorkflowExecution(BaseModel):
self._previous_started_event_id = evt.event_id
def complete_decision_task(
self, task_token, decisions=None, execution_context=None
):
self,
task_token: str,
decisions: Optional[List[Dict[str, Any]]] = None,
execution_context: Optional[str] = None,
) -> None:
# 'decisions' can be None per boto.swf defaults, so replace it with something iterable
if not decisions:
decisions = []
@ -341,7 +362,9 @@ class WorkflowExecution(BaseModel):
self.schedule_decision_task()
self.latest_execution_context = execution_context
def _check_decision_attributes(self, kind, value, decision_id):
def _check_decision_attributes(
self, kind: str, value: Dict[str, Any], decision_id: int
) -> List[Dict[str, str]]:
problems = []
constraints = DECISIONS_FIELDS.get(kind, {})
for key, constraint in constraints.items():
@ -354,7 +377,7 @@ class WorkflowExecution(BaseModel):
)
return problems
def validate_decisions(self, decisions):
def validate_decisions(self, decisions: List[Dict[str, Any]]) -> None:
"""
Performs some basic validations on decisions. The real SWF service
seems to break early and *not* process any decision if there's a
@ -404,7 +427,7 @@ class WorkflowExecution(BaseModel):
if any(problems):
raise SWFDecisionValidationException(problems)
def handle_decisions(self, event_id, decisions):
def handle_decisions(self, event_id: int, decisions: List[Dict[str, Any]]) -> None:
"""
Handles a Decision according to SWF docs.
See: http://docs.aws.amazon.com/amazonswf/latest/apireference/API_Decision.html
@ -440,7 +463,7 @@ class WorkflowExecution(BaseModel):
# finally decrement counter if and only if everything went well
self.open_counts["openDecisionTasks"] -= 1
def complete(self, event_id, result=None):
def complete(self, event_id: int, result: Any = None) -> None:
self.execution_status = "CLOSED"
self.close_status = "COMPLETED"
self.close_timestamp = unix_time()
@ -450,7 +473,9 @@ class WorkflowExecution(BaseModel):
result=result,
)
def fail(self, event_id, details=None, reason=None):
def fail(
self, event_id: int, details: Any = None, reason: Optional[str] = None
) -> None:
# TODO: implement length constraints on details/reason
self.execution_status = "CLOSED"
self.close_status = "FAILED"
@ -462,7 +487,7 @@ class WorkflowExecution(BaseModel):
reason=reason,
)
def cancel(self, event_id, details=None):
def cancel(self, event_id: int, details: Any = None) -> None:
# TODO: implement length constraints on details
self.cancel_requested = True
# Can only cancel if there are no other pending desicion tasks
@ -483,9 +508,9 @@ class WorkflowExecution(BaseModel):
details=details,
)
def schedule_activity_task(self, event_id, attributes):
def schedule_activity_task(self, event_id: int, attributes: Dict[str, Any]) -> None:
# Helper function to avoid repeating ourselves in the next sections
def fail_schedule_activity_task(_type, _cause):
def fail_schedule_activity_task(_type: "ActivityType", _cause: str) -> None:
# TODO: implement other possible failure mode: OPEN_ACTIVITIES_LIMIT_EXCEEDED
# NB: some failure modes are not implemented and probably won't be implemented in
# the future, such as ACTIVITY_CREATION_RATE_EXCEEDED or
@ -499,7 +524,7 @@ class WorkflowExecution(BaseModel):
)
self.should_schedule_decision_next = True
activity_type = self.domain.get_type(
activity_type: ActivityType = self.domain.get_type( # type: ignore[assignment]
"activity",
attributes["activityType"]["name"],
attributes["activityType"]["version"],
@ -576,13 +601,13 @@ class WorkflowExecution(BaseModel):
self.open_counts["openActivityTasks"] += 1
self.latest_activity_task_timestamp = unix_time()
def _find_activity_task(self, task_token):
def _find_activity_task(self, task_token: str) -> ActivityTask:
for task in self.activity_tasks:
if task.task_token == task_token:
return task
raise ValueError(f"No activity task with token: {task_token}")
def start_activity_task(self, task_token, identity=None):
def start_activity_task(self, task_token: str, identity: Any = None) -> None:
task = self._find_activity_task(task_token)
evt = self._add_event(
"ActivityTaskStarted",
@ -591,7 +616,7 @@ class WorkflowExecution(BaseModel):
)
task.start(evt.event_id)
def complete_activity_task(self, task_token, result=None):
def complete_activity_task(self, task_token: str, result: Any = None) -> None:
task = self._find_activity_task(task_token)
self._add_event(
"ActivityTaskCompleted",
@ -604,7 +629,9 @@ class WorkflowExecution(BaseModel):
# TODO: ensure we don't schedule multiple decisions at the same time!
self.schedule_decision_task()
def fail_activity_task(self, task_token, reason=None, details=None):
def fail_activity_task(
self, task_token: str, reason: Optional[str] = None, details: Any = None
) -> None:
task = self._find_activity_task(task_token)
self._add_event(
"ActivityTaskFailed",
@ -618,7 +645,12 @@ class WorkflowExecution(BaseModel):
# TODO: ensure we don't schedule multiple decisions at the same time!
self.schedule_decision_task()
def terminate(self, child_policy=None, details=None, reason=None):
def terminate(
self,
child_policy: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
reason: Optional[str] = None,
) -> None:
# TODO: handle child policy for child workflows here
# TODO: handle cause="CHILD_POLICY_APPLIED"
# Until this, we set cause manually to "OPERATOR_INITIATED"
@ -636,13 +668,13 @@ class WorkflowExecution(BaseModel):
self.close_status = "TERMINATED"
self.close_cause = "OPERATOR_INITIATED"
def signal(self, signal_name, workflow_input):
def signal(self, signal_name: str, workflow_input: Dict[str, Any]) -> None:
self._add_event(
"WorkflowExecutionSignaled", signal_name=signal_name, input=workflow_input
)
self.schedule_decision_task()
def first_timeout(self):
def first_timeout(self) -> Optional[Timeout]:
if not self.open or not self.start_timestamp:
return None
start_to_close_at = self.start_timestamp + int(
@ -651,8 +683,9 @@ class WorkflowExecution(BaseModel):
_timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE")
if _timeout.reached:
return _timeout
return None
def timeout(self, timeout):
def timeout(self, timeout: Timeout) -> None:
# TODO: process child policy on child workflows here or in the
# triggering function
self.execution_status = "CLOSED"
@ -665,7 +698,7 @@ class WorkflowExecution(BaseModel):
timeout_type=self.timeout_type,
)
def timeout_decision_task(self, _timeout):
def timeout_decision_task(self, _timeout: Timeout) -> None:
task = _timeout.obj
task.timeout(_timeout)
self._add_event(
@ -676,7 +709,7 @@ class WorkflowExecution(BaseModel):
timeout_type=task.timeout_type,
)
def timeout_activity_task(self, _timeout):
def timeout_activity_task(self, _timeout: Timeout) -> None:
task = _timeout.obj
task.timeout(_timeout)
self._add_event(
@ -688,7 +721,7 @@ class WorkflowExecution(BaseModel):
timeout_type=task.timeout_type,
)
def record_marker(self, event_id, attributes):
def record_marker(self, event_id: int, attributes: Dict[str, Any]) -> None:
self._add_event(
"MarkerRecorded",
decision_task_completed_event_id=event_id,
@ -696,7 +729,7 @@ class WorkflowExecution(BaseModel):
marker_name=attributes["markerName"],
)
def start_timer(self, event_id, attributes):
def start_timer(self, event_id: int, attributes: Dict[str, Any]) -> None:
timer_id = attributes["timerId"]
existing_timer = self._timers.get(timer_id)
if existing_timer and existing_timer.is_alive():
@ -725,14 +758,14 @@ class WorkflowExecution(BaseModel):
self._timers[timer_id] = workflow_timer
workflow_timer.start()
def _fire_timer(self, started_event_id, timer_id):
def _fire_timer(self, started_event_id: int, timer_id: str) -> None:
self._add_event(
"TimerFired", started_event_id=started_event_id, timer_id=timer_id
)
self._timers.pop(timer_id)
self._schedule_decision_task()
def cancel_timer(self, event_id, timer_id):
def cancel_timer(self, event_id: int, timer_id: str) -> None:
requested_timer = self._timers.get(timer_id)
if not requested_timer or not requested_timer.is_alive():
# TODO there are 2 failure states
@ -754,5 +787,5 @@ class WorkflowExecution(BaseModel):
)
@property
def open(self):
def open(self) -> bool:
return self.execution_status == "OPEN"

View File

@ -1,9 +1,10 @@
from typing import List
from .generic_type import GenericType
class WorkflowType(GenericType):
@property
def _configuration_keys(self):
def _configuration_keys(self) -> List[str]:
return [
"defaultChildPolicy",
"defaultExecutionStartToCloseTimeout",
@ -13,5 +14,5 @@ class WorkflowType(GenericType):
]
@property
def kind(self):
def kind(self) -> str:
return "workflow"

View File

@ -1,74 +1,75 @@
import json
from typing import Any, List
from moto.core.responses import BaseResponse
from .exceptions import SWFSerializationException, SWFValidationException
from .models import swf_backends
from .models import swf_backends, SWFBackend, GenericType
class SWFResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="swf")
@property
def swf_backend(self):
def swf_backend(self) -> SWFBackend:
return swf_backends[self.current_account][self.region]
# SWF parameters are passed through a JSON body, so let's ease retrieval
@property
def _params(self):
def _params(self) -> Any: # type: ignore[misc]
return json.loads(self.body)
def _check_int(self, parameter):
def _check_int(self, parameter: Any) -> None:
if not isinstance(parameter, int):
raise SWFSerializationException(parameter)
def _check_float_or_int(self, parameter):
def _check_float_or_int(self, parameter: Any) -> None:
if not isinstance(parameter, float):
if not isinstance(parameter, int):
raise SWFSerializationException(parameter)
def _check_none_or_string(self, parameter):
def _check_none_or_string(self, parameter: Any) -> None:
if parameter is not None:
self._check_string(parameter)
def _check_string(self, parameter):
def _check_string(self, parameter: Any) -> None:
if not isinstance(parameter, str):
raise SWFSerializationException(parameter)
def _check_none_or_list_of_strings(self, parameter):
def _check_none_or_list_of_strings(self, parameter: Any) -> None:
if parameter is not None:
self._check_list_of_strings(parameter)
def _check_list_of_strings(self, parameter):
def _check_list_of_strings(self, parameter: Any) -> None:
if not isinstance(parameter, list):
raise SWFSerializationException(parameter)
for i in parameter:
if not isinstance(i, str):
raise SWFSerializationException(parameter)
def _check_exclusivity(self, **kwargs):
def _check_exclusivity(self, **kwargs: Any) -> None:
if list(kwargs.values()).count(None) >= len(kwargs) - 1:
return
keys = kwargs.keys()
if len(keys) == 2:
message = f"Cannot specify both a {keys[0]} and a {keys[1]}"
message = f"Cannot specify both a {keys[0]} and a {keys[1]}" # type: ignore
else:
message = f"Cannot specify more than one exclusive filters in the same query: {keys}"
raise SWFValidationException(message)
def _list_types(self, kind):
def _list_types(self, kind: str) -> str:
domain_name = self._params["domain"]
status = self._params["registrationStatus"]
reverse_order = self._params.get("reverseOrder", None)
self._check_string(domain_name)
self._check_string(status)
types = self.swf_backend.list_types(
types: List[GenericType] = self.swf_backend.list_types(
kind, domain_name, status, reverse_order=reverse_order
)
return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]})
def _describe_type(self, kind):
def _describe_type(self, kind: str) -> str:
domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"]
name = _type_args["name"]
@ -80,7 +81,7 @@ class SWFResponse(BaseResponse):
return json.dumps(_type.to_full_dict())
def _deprecate_type(self, kind):
def _deprecate_type(self, kind: str) -> str:
domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"]
name = _type_args["name"]
@ -91,7 +92,7 @@ class SWFResponse(BaseResponse):
self.swf_backend.deprecate_type(kind, domain, name, version)
return ""
def _undeprecate_type(self, kind):
def _undeprecate_type(self, kind: str) -> str:
domain = self._params["domain"]
_type_args = self._params[f"{kind}Type"]
name = _type_args["name"]
@ -103,7 +104,7 @@ class SWFResponse(BaseResponse):
return ""
# TODO: implement pagination
def list_domains(self):
def list_domains(self) -> str:
status = self._params["registrationStatus"]
self._check_string(status)
reverse_order = self._params.get("reverseOrder", None)
@ -112,7 +113,7 @@ class SWFResponse(BaseResponse):
{"domainInfos": [domain.to_short_dict() for domain in domains]}
)
def list_closed_workflow_executions(self):
def list_closed_workflow_executions(self) -> str:
domain = self._params["domain"]
start_time_filter = self._params.get("startTimeFilter", None)
close_time_filter = self._params.get("closeTimeFilter", None)
@ -166,7 +167,7 @@ class SWFResponse(BaseResponse):
{"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]}
)
def list_open_workflow_executions(self):
def list_open_workflow_executions(self) -> str:
domain = self._params["domain"]
start_time_filter = self._params["startTimeFilter"]
execution_filter = self._params.get("executionFilter", None)
@ -204,7 +205,7 @@ class SWFResponse(BaseResponse):
{"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]}
)
def register_domain(self):
def register_domain(self) -> str:
name = self._params["name"]
retention = self._params["workflowExecutionRetentionPeriodInDays"]
description = self._params.get("description")
@ -214,29 +215,29 @@ class SWFResponse(BaseResponse):
self.swf_backend.register_domain(name, retention, description=description)
return ""
def deprecate_domain(self):
def deprecate_domain(self) -> str:
name = self._params["name"]
self._check_string(name)
self.swf_backend.deprecate_domain(name)
return ""
def undeprecate_domain(self):
def undeprecate_domain(self) -> str:
name = self._params["name"]
self._check_string(name)
self.swf_backend.undeprecate_domain(name)
return ""
def describe_domain(self):
def describe_domain(self) -> str:
name = self._params["name"]
self._check_string(name)
domain = self.swf_backend.describe_domain(name)
return json.dumps(domain.to_full_dict())
return json.dumps(domain.to_full_dict()) # type: ignore[union-attr]
# TODO: implement pagination
def list_activity_types(self):
def list_activity_types(self) -> str:
return self._list_types("activity")
def register_activity_type(self):
def register_activity_type(self) -> str:
domain = self._params["domain"]
name = self._params["name"]
version = self._params["version"]
@ -282,19 +283,19 @@ class SWFResponse(BaseResponse):
)
return ""
def deprecate_activity_type(self):
def deprecate_activity_type(self) -> str:
return self._deprecate_type("activity")
def undeprecate_activity_type(self):
def undeprecate_activity_type(self) -> str:
return self._undeprecate_type("activity")
def describe_activity_type(self):
def describe_activity_type(self) -> str:
return self._describe_type("activity")
def list_workflow_types(self):
def list_workflow_types(self) -> str:
return self._list_types("workflow")
def register_workflow_type(self):
def register_workflow_type(self) -> str:
domain = self._params["domain"]
name = self._params["name"]
version = self._params["version"]
@ -340,16 +341,16 @@ class SWFResponse(BaseResponse):
)
return ""
def deprecate_workflow_type(self):
def deprecate_workflow_type(self) -> str:
return self._deprecate_type("workflow")
def undeprecate_workflow_type(self):
def undeprecate_workflow_type(self) -> str:
return self._undeprecate_type("workflow")
def describe_workflow_type(self):
def describe_workflow_type(self) -> str:
return self._describe_type("workflow")
def start_workflow_execution(self):
def start_workflow_execution(self) -> str:
domain = self._params["domain"]
workflow_id = self._params["workflowId"]
_workflow_type = self._params["workflowType"]
@ -394,7 +395,7 @@ class SWFResponse(BaseResponse):
return json.dumps({"runId": wfe.run_id})
def describe_workflow_execution(self):
def describe_workflow_execution(self) -> str:
domain_name = self._params["domain"]
_workflow_execution = self._params["execution"]
run_id = _workflow_execution["runId"]
@ -407,9 +408,9 @@ class SWFResponse(BaseResponse):
wfe = self.swf_backend.describe_workflow_execution(
domain_name, run_id, workflow_id
)
return json.dumps(wfe.to_full_dict())
return json.dumps(wfe.to_full_dict()) # type: ignore[union-attr]
def get_workflow_execution_history(self):
def get_workflow_execution_history(self) -> str:
domain_name = self._params["domain"]
_workflow_execution = self._params["execution"]
run_id = _workflow_execution["runId"]
@ -418,10 +419,10 @@ class SWFResponse(BaseResponse):
wfe = self.swf_backend.describe_workflow_execution(
domain_name, run_id, workflow_id
)
events = wfe.events(reverse_order=reverse_order)
events = wfe.events(reverse_order=reverse_order) # type: ignore[union-attr]
return json.dumps({"events": [evt.to_dict() for evt in events]})
def poll_for_decision_task(self):
def poll_for_decision_task(self) -> str:
domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"]
identity = self._params.get("identity")
@ -438,7 +439,7 @@ class SWFResponse(BaseResponse):
else:
return json.dumps({"previousStartedEventId": 0, "startedEventId": 0})
def count_pending_decision_tasks(self):
def count_pending_decision_tasks(self) -> str:
domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"]
self._check_string(domain_name)
@ -446,7 +447,7 @@ class SWFResponse(BaseResponse):
count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False})
def respond_decision_task_completed(self):
def respond_decision_task_completed(self) -> str:
task_token = self._params["taskToken"]
execution_context = self._params.get("executionContext")
decisions = self._params.get("decisions")
@ -457,7 +458,7 @@ class SWFResponse(BaseResponse):
)
return ""
def poll_for_activity_task(self):
def poll_for_activity_task(self) -> str:
domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"]
identity = self._params.get("identity")
@ -472,7 +473,7 @@ class SWFResponse(BaseResponse):
else:
return json.dumps({"startedEventId": 0})
def count_pending_activity_tasks(self):
def count_pending_activity_tasks(self) -> str:
domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"]
self._check_string(domain_name)
@ -480,7 +481,7 @@ class SWFResponse(BaseResponse):
count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False})
def respond_activity_task_completed(self):
def respond_activity_task_completed(self) -> str:
task_token = self._params["taskToken"]
result = self._params.get("result")
self._check_string(task_token)
@ -488,7 +489,7 @@ class SWFResponse(BaseResponse):
self.swf_backend.respond_activity_task_completed(task_token, result=result)
return ""
def respond_activity_task_failed(self):
def respond_activity_task_failed(self) -> str:
task_token = self._params["taskToken"]
reason = self._params.get("reason")
details = self._params.get("details")
@ -502,7 +503,7 @@ class SWFResponse(BaseResponse):
)
return ""
def terminate_workflow_execution(self):
def terminate_workflow_execution(self) -> str:
domain_name = self._params["domain"]
workflow_id = self._params["workflowId"]
child_policy = self._params.get("childPolicy")
@ -525,7 +526,7 @@ class SWFResponse(BaseResponse):
)
return ""
def record_activity_task_heartbeat(self):
def record_activity_task_heartbeat(self) -> str:
task_token = self._params["taskToken"]
details = self._params.get("details")
self._check_string(task_token)
@ -534,7 +535,7 @@ class SWFResponse(BaseResponse):
# TODO: make it dynamic when we implement activity tasks cancellation
return json.dumps({"cancelRequested": False})
def signal_workflow_execution(self):
def signal_workflow_execution(self) -> str:
domain_name = self._params["domain"]
signal_name = self._params["signalName"]
workflow_id = self._params["workflowId"]

View File

@ -1,2 +1,2 @@
def decapitalize(key):
def decapitalize(key: str) -> str:
return key[0].lower() + key[1:]

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract