Move SWF type checks to response object

(suggested in @spulec review)
This commit is contained in:
Jean-Baptiste Barth 2015-11-23 12:41:31 +01:00
parent e3fff8759b
commit 45437368b2
2 changed files with 97 additions and 74 deletions

View File

@ -9,7 +9,6 @@ from ..exceptions import (
SWFUnknownResourceFault, SWFUnknownResourceFault,
SWFDomainAlreadyExistsFault, SWFDomainAlreadyExistsFault,
SWFDomainDeprecatedFault, SWFDomainDeprecatedFault,
SWFSerializationException,
SWFTypeAlreadyExistsFault, SWFTypeAlreadyExistsFault,
SWFTypeDeprecatedFault, SWFTypeDeprecatedFault,
SWFValidationException, SWFValidationException,
@ -50,32 +49,12 @@ class SWFBackend(BaseBackend):
return matching[0] return matching[0]
return None return None
def _check_none_or_string(self, parameter):
if parameter is not None:
self._check_string(parameter)
def _check_string(self, parameter):
if not isinstance(parameter, six.string_types):
raise SWFSerializationException(parameter)
def _check_none_or_list_of_strings(self, parameter):
if parameter is not None:
self._check_list_of_strings(parameter)
def _check_list_of_strings(self, parameter):
if not isinstance(parameter, list):
raise SWFSerializationException(parameter)
for i in parameter:
if not isinstance(i, six.string_types):
raise SWFSerializationException(parameter)
def _process_timeouts(self): def _process_timeouts(self):
for domain in self.domains: for domain in self.domains:
for wfe in domain.workflow_executions: for wfe in domain.workflow_executions:
wfe._process_timeouts() wfe._process_timeouts()
def list_domains(self, status, reverse_order=None): def list_domains(self, status, reverse_order=None):
self._check_string(status)
domains = [domain for domain in self.domains domains = [domain for domain in self.domains
if domain.status == status] if domain.status == status]
domains = sorted(domains, key=lambda domain: domain.name) domains = sorted(domains, key=lambda domain: domain.name)
@ -85,9 +64,6 @@ class SWFBackend(BaseBackend):
def register_domain(self, name, workflow_execution_retention_period_in_days, def register_domain(self, name, workflow_execution_retention_period_in_days,
description=None): description=None):
self._check_string(name)
self._check_string(workflow_execution_retention_period_in_days)
self._check_none_or_string(description)
if self._get_domain(name, ignore_empty=True): if self._get_domain(name, ignore_empty=True):
raise SWFDomainAlreadyExistsFault(name) raise SWFDomainAlreadyExistsFault(name)
domain = Domain(name, workflow_execution_retention_period_in_days, domain = Domain(name, workflow_execution_retention_period_in_days,
@ -95,19 +71,15 @@ class SWFBackend(BaseBackend):
self.domains.append(domain) self.domains.append(domain)
def deprecate_domain(self, name): def deprecate_domain(self, name):
self._check_string(name)
domain = self._get_domain(name) domain = self._get_domain(name)
if domain.status == "DEPRECATED": if domain.status == "DEPRECATED":
raise SWFDomainDeprecatedFault(name) raise SWFDomainDeprecatedFault(name)
domain.status = "DEPRECATED" domain.status = "DEPRECATED"
def describe_domain(self, name): def describe_domain(self, name):
self._check_string(name)
return self._get_domain(name) return self._get_domain(name)
def list_types(self, kind, domain_name, status, reverse_order=None): def list_types(self, kind, domain_name, status, reverse_order=None):
self._check_string(domain_name)
self._check_string(status)
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_types = domain.find_types(kind, status) _types = domain.find_types(kind, status)
_types = sorted(_types, key=lambda domain: domain.name) _types = sorted(_types, key=lambda domain: domain.name)
@ -116,11 +88,6 @@ class SWFBackend(BaseBackend):
return _types return _types
def register_type(self, kind, domain_name, name, version, **kwargs): def register_type(self, kind, domain_name, name, version, **kwargs):
self._check_string(domain_name)
self._check_string(name)
self._check_string(version)
for value in kwargs.values():
self._check_none_or_string(value)
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version, ignore_empty=True) _type = domain.get_type(kind, name, version, ignore_empty=True)
if _type: if _type:
@ -130,9 +97,6 @@ class SWFBackend(BaseBackend):
domain.add_type(_type) domain.add_type(_type)
def deprecate_type(self, kind, domain_name, name, version): def deprecate_type(self, kind, domain_name, name, version):
self._check_string(domain_name)
self._check_string(name)
self._check_string(version)
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
_type = domain.get_type(kind, name, version) _type = domain.get_type(kind, name, version)
if _type.status == "DEPRECATED": if _type.status == "DEPRECATED":
@ -140,23 +104,12 @@ class SWFBackend(BaseBackend):
_type.status = "DEPRECATED" _type.status = "DEPRECATED"
def describe_type(self, kind, domain_name, name, version): def describe_type(self, kind, domain_name, name, version):
self._check_string(domain_name)
self._check_string(name)
self._check_string(version)
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
return domain.get_type(kind, name, version) return domain.get_type(kind, name, version)
def start_workflow_execution(self, domain_name, workflow_id, def start_workflow_execution(self, domain_name, workflow_id,
workflow_name, workflow_version, workflow_name, workflow_version,
tag_list=None, **kwargs): tag_list=None, **kwargs):
self._check_string(domain_name)
self._check_string(workflow_id)
self._check_string(workflow_name)
self._check_string(workflow_version)
self._check_none_or_list_of_strings(tag_list)
for value in kwargs.values():
self._check_none_or_string(value)
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
wf_type = domain.get_type("workflow", workflow_name, workflow_version) wf_type = domain.get_type("workflow", workflow_name, workflow_version)
if wf_type.status == "DEPRECATED": if wf_type.status == "DEPRECATED":
@ -169,17 +122,12 @@ class SWFBackend(BaseBackend):
return wfe return wfe
def describe_workflow_execution(self, domain_name, run_id, workflow_id): def describe_workflow_execution(self, domain_name, run_id, workflow_id):
self._check_string(domain_name)
self._check_string(run_id)
self._check_string(workflow_id)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
return domain.get_workflow_execution(workflow_id, run_id=run_id) return domain.get_workflow_execution(workflow_id, run_id=run_id)
def poll_for_decision_task(self, domain_name, task_list, identity=None): def poll_for_decision_task(self, domain_name, task_list, identity=None):
self._check_string(domain_name)
self._check_string(task_list)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -211,8 +159,6 @@ class SWFBackend(BaseBackend):
return None return None
def count_pending_decision_tasks(self, domain_name, task_list): def count_pending_decision_tasks(self, domain_name, task_list):
self._check_string(domain_name)
self._check_string(task_list)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -225,8 +171,6 @@ class SWFBackend(BaseBackend):
def respond_decision_task_completed(self, task_token, def respond_decision_task_completed(self, task_token,
decisions=None, decisions=None,
execution_context=None): execution_context=None):
self._check_string(task_token)
self._check_none_or_string(execution_context)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
# let's find decision task # let's find decision task
@ -278,8 +222,6 @@ class SWFBackend(BaseBackend):
execution_context=execution_context) execution_context=execution_context)
def poll_for_activity_task(self, domain_name, task_list, identity=None): def poll_for_activity_task(self, domain_name, task_list, identity=None):
self._check_string(domain_name)
self._check_string(task_list)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -311,8 +253,6 @@ class SWFBackend(BaseBackend):
return None return None
def count_pending_activity_tasks(self, domain_name, task_list): def count_pending_activity_tasks(self, domain_name, task_list):
self._check_string(domain_name)
self._check_string(task_list)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -362,8 +302,6 @@ class SWFBackend(BaseBackend):
return activity_task return activity_task
def respond_activity_task_completed(self, task_token, result=None): def respond_activity_task_completed(self, task_token, result=None):
self._check_string(task_token)
self._check_none_or_string(result)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)
@ -371,10 +309,6 @@ class SWFBackend(BaseBackend):
wfe.complete_activity_task(activity_task.task_token, result=result) wfe.complete_activity_task(activity_task.task_token, result=result)
def respond_activity_task_failed(self, task_token, reason=None, details=None): def respond_activity_task_failed(self, task_token, reason=None, details=None):
self._check_string(task_token)
# TODO: implement length limits on reason and details (common pb with client libs)
self._check_none_or_string(reason)
self._check_none_or_string(details)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)
@ -383,12 +317,6 @@ class SWFBackend(BaseBackend):
def terminate_workflow_execution(self, domain_name, workflow_id, child_policy=None, def terminate_workflow_execution(self, domain_name, workflow_id, child_policy=None,
details=None, reason=None, run_id=None): details=None, reason=None, run_id=None):
self._check_string(domain_name)
self._check_string(workflow_id)
self._check_none_or_string(child_policy)
self._check_none_or_string(details)
self._check_none_or_string(reason)
self._check_none_or_string(run_id)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
@ -396,8 +324,6 @@ class SWFBackend(BaseBackend):
wfe.terminate(child_policy=child_policy, details=details, reason=reason) wfe.terminate(child_policy=child_policy, details=details, reason=reason)
def record_activity_task_heartbeat(self, task_token, details=None): def record_activity_task_heartbeat(self, task_token, details=None):
self._check_string(task_token)
self._check_none_or_string(details)
# process timeouts on all objects # process timeouts on all objects
self._process_timeouts() self._process_timeouts()
activity_task = self._find_activity_task_from_token(task_token) activity_task = self._find_activity_task_from_token(task_token)

View File

@ -5,6 +5,7 @@ from moto.core.responses import BaseResponse
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from moto.core.utils import camelcase_to_underscores, method_names_from_class from moto.core.utils import camelcase_to_underscores, method_names_from_class
from .exceptions import SWFSerializationException
from .models import swf_backends from .models import swf_backends
@ -19,10 +20,31 @@ class SWFResponse(BaseResponse):
def _params(self): def _params(self):
return json.loads(self.body.decode("utf-8")) return json.loads(self.body.decode("utf-8"))
def _check_none_or_string(self, parameter):
if parameter is not None:
self._check_string(parameter)
def _check_string(self, parameter):
if not isinstance(parameter, six.string_types):
raise SWFSerializationException(parameter)
def _check_none_or_list_of_strings(self, parameter):
if parameter is not None:
self._check_list_of_strings(parameter)
def _check_list_of_strings(self, parameter):
if not isinstance(parameter, list):
raise SWFSerializationException(parameter)
for i in parameter:
if not isinstance(i, six.string_types):
raise SWFSerializationException(parameter)
def _list_types(self, kind): def _list_types(self, kind):
domain_name = self._params["domain"] domain_name = self._params["domain"]
status = self._params["registrationStatus"] status = self._params["registrationStatus"]
reverse_order = self._params.get("reverseOrder", None) reverse_order = self._params.get("reverseOrder", None)
self._check_string(domain_name)
self._check_string(status)
types = self.swf_backend.list_types(kind, domain_name, status, reverse_order=reverse_order) types = self.swf_backend.list_types(kind, domain_name, status, reverse_order=reverse_order)
return json.dumps({ return json.dumps({
"typeInfos": [_type.to_medium_dict() for _type in types] "typeInfos": [_type.to_medium_dict() for _type in types]
@ -33,6 +55,9 @@ class SWFResponse(BaseResponse):
_type_args = self._params["{0}Type".format(kind)] _type_args = self._params["{0}Type".format(kind)]
name = _type_args["name"] name = _type_args["name"]
version = _type_args["version"] version = _type_args["version"]
self._check_string(domain)
self._check_string(name)
self._check_string(version)
_type = self.swf_backend.describe_type(kind, domain, name, version) _type = self.swf_backend.describe_type(kind, domain, name, version)
return json.dumps(_type.to_full_dict()) return json.dumps(_type.to_full_dict())
@ -42,12 +67,16 @@ class SWFResponse(BaseResponse):
_type_args = self._params["{0}Type".format(kind)] _type_args = self._params["{0}Type".format(kind)]
name = _type_args["name"] name = _type_args["name"]
version = _type_args["version"] version = _type_args["version"]
self._check_string(domain)
self._check_string(name)
self._check_string(version)
self.swf_backend.deprecate_type(kind, domain, name, version) self.swf_backend.deprecate_type(kind, domain, name, version)
return "" return ""
# TODO: implement pagination # TODO: implement pagination
def list_domains(self): def list_domains(self):
status = self._params["registrationStatus"] status = self._params["registrationStatus"]
self._check_string(status)
reverse_order = self._params.get("reverseOrder", None) reverse_order = self._params.get("reverseOrder", None)
domains = self.swf_backend.list_domains(status, reverse_order=reverse_order) domains = self.swf_backend.list_domains(status, reverse_order=reverse_order)
return json.dumps({ return json.dumps({
@ -58,17 +87,22 @@ class SWFResponse(BaseResponse):
name = self._params["name"] name = self._params["name"]
retention = self._params["workflowExecutionRetentionPeriodInDays"] retention = self._params["workflowExecutionRetentionPeriodInDays"]
description = self._params.get("description") description = self._params.get("description")
self._check_string(retention)
self._check_string(name)
self._check_none_or_string(description)
domain = self.swf_backend.register_domain(name, retention, domain = self.swf_backend.register_domain(name, retention,
description=description) description=description)
return "" return ""
def deprecate_domain(self): def deprecate_domain(self):
name = self._params["name"] name = self._params["name"]
self._check_string(name)
domain = self.swf_backend.deprecate_domain(name) domain = self.swf_backend.deprecate_domain(name)
return "" return ""
def describe_domain(self): def describe_domain(self):
name = self._params["name"] name = self._params["name"]
self._check_string(name)
domain = self.swf_backend.describe_domain(name) domain = self.swf_backend.describe_domain(name)
return json.dumps(domain.to_full_dict()) return json.dumps(domain.to_full_dict())
@ -90,6 +124,17 @@ class SWFResponse(BaseResponse):
default_task_schedule_to_start_timeout = self._params.get("defaultTaskScheduleToStartTimeout") default_task_schedule_to_start_timeout = self._params.get("defaultTaskScheduleToStartTimeout")
default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout") default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout")
description = self._params.get("description") description = self._params.get("description")
self._check_string(domain)
self._check_string(name)
self._check_string(version)
self._check_none_or_string(task_list)
self._check_none_or_string(default_task_heartbeat_timeout)
self._check_none_or_string(default_task_schedule_to_close_timeout)
self._check_none_or_string(default_task_schedule_to_start_timeout)
self._check_none_or_string(default_task_start_to_close_timeout)
self._check_none_or_string(description)
# TODO: add defaultTaskPriority when boto gets to support it # TODO: add defaultTaskPriority when boto gets to support it
activity_type = self.swf_backend.register_type( activity_type = self.swf_backend.register_type(
"activity", domain, name, version, task_list=task_list, "activity", domain, name, version, task_list=task_list,
@ -123,6 +168,16 @@ class SWFResponse(BaseResponse):
default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout") default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout")
default_execution_start_to_close_timeout = self._params.get("defaultExecutionStartToCloseTimeout") default_execution_start_to_close_timeout = self._params.get("defaultExecutionStartToCloseTimeout")
description = self._params.get("description") description = self._params.get("description")
self._check_string(domain)
self._check_string(name)
self._check_string(version)
self._check_none_or_string(task_list)
self._check_none_or_string(default_child_policy)
self._check_none_or_string(default_task_start_to_close_timeout)
self._check_none_or_string(default_execution_start_to_close_timeout)
self._check_none_or_string(description)
# TODO: add defaultTaskPriority when boto gets to support it # TODO: add defaultTaskPriority when boto gets to support it
# TODO: add defaultLambdaRole when boto gets to support it # TODO: add defaultLambdaRole when boto gets to support it
workflow_type = self.swf_backend.register_type( workflow_type = self.swf_backend.register_type(
@ -157,6 +212,17 @@ class SWFResponse(BaseResponse):
tag_list = self._params.get("tagList") tag_list = self._params.get("tagList")
task_start_to_close_timeout = self._params.get("taskStartToCloseTimeout") task_start_to_close_timeout = self._params.get("taskStartToCloseTimeout")
self._check_string(domain)
self._check_string(workflow_id)
self._check_string(workflow_name)
self._check_string(workflow_version)
self._check_none_or_string(task_list)
self._check_none_or_string(child_policy)
self._check_none_or_string(execution_start_to_close_timeout)
self._check_none_or_string(input_)
self._check_none_or_list_of_strings(tag_list)
self._check_none_or_string(task_start_to_close_timeout)
wfe = self.swf_backend.start_workflow_execution( wfe = self.swf_backend.start_workflow_execution(
domain, workflow_id, workflow_name, workflow_version, domain, workflow_id, workflow_name, workflow_version,
task_list=task_list, child_policy=child_policy, task_list=task_list, child_policy=child_policy,
@ -175,6 +241,10 @@ class SWFResponse(BaseResponse):
run_id = _workflow_execution["runId"] run_id = _workflow_execution["runId"]
workflow_id = _workflow_execution["workflowId"] workflow_id = _workflow_execution["workflowId"]
self._check_string(domain_name)
self._check_string(run_id)
self._check_string(workflow_id)
wfe = self.swf_backend.describe_workflow_execution(domain_name, run_id, workflow_id) wfe = self.swf_backend.describe_workflow_execution(domain_name, run_id, workflow_id)
return json.dumps(wfe.to_full_dict()) return json.dumps(wfe.to_full_dict())
@ -195,6 +265,10 @@ class SWFResponse(BaseResponse):
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
identity = self._params.get("identity") identity = self._params.get("identity")
reverse_order = self._params.get("reverseOrder", None) reverse_order = self._params.get("reverseOrder", None)
self._check_string(domain_name)
self._check_string(task_list)
decision = self.swf_backend.poll_for_decision_task( decision = self.swf_backend.poll_for_decision_task(
domain_name, task_list, identity=identity domain_name, task_list, identity=identity
) )
@ -208,6 +282,8 @@ class SWFResponse(BaseResponse):
def count_pending_decision_tasks(self): def count_pending_decision_tasks(self):
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
self._check_string(domain_name)
self._check_string(task_list)
count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False}) return json.dumps({"count": count, "truncated": False})
@ -215,6 +291,8 @@ class SWFResponse(BaseResponse):
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
execution_context = self._params.get("executionContext") execution_context = self._params.get("executionContext")
decisions = self._params.get("decisions") decisions = self._params.get("decisions")
self._check_string(task_token)
self._check_none_or_string(execution_context)
self.swf_backend.respond_decision_task_completed( self.swf_backend.respond_decision_task_completed(
task_token, decisions=decisions, execution_context=execution_context task_token, decisions=decisions, execution_context=execution_context
) )
@ -224,6 +302,9 @@ class SWFResponse(BaseResponse):
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
identity = self._params.get("identity") identity = self._params.get("identity")
self._check_string(domain_name)
self._check_string(task_list)
self._check_none_or_string(identity)
activity_task = self.swf_backend.poll_for_activity_task( activity_task = self.swf_backend.poll_for_activity_task(
domain_name, task_list, identity=identity domain_name, task_list, identity=identity
) )
@ -237,12 +318,16 @@ class SWFResponse(BaseResponse):
def count_pending_activity_tasks(self): def count_pending_activity_tasks(self):
domain_name = self._params["domain"] domain_name = self._params["domain"]
task_list = self._params["taskList"]["name"] task_list = self._params["taskList"]["name"]
self._check_string(domain_name)
self._check_string(task_list)
count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list)
return json.dumps({"count": count, "truncated": False}) return json.dumps({"count": count, "truncated": False})
def respond_activity_task_completed(self): def respond_activity_task_completed(self):
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
result = self._params.get("result") result = self._params.get("result")
self._check_string(task_token)
self._check_none_or_string(result)
self.swf_backend.respond_activity_task_completed( self.swf_backend.respond_activity_task_completed(
task_token, result=result task_token, result=result
) )
@ -252,6 +337,10 @@ class SWFResponse(BaseResponse):
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
reason = self._params.get("reason") reason = self._params.get("reason")
details = self._params.get("details") details = self._params.get("details")
self._check_string(task_token)
# TODO: implement length limits on reason and details (common pb with client libs)
self._check_none_or_string(reason)
self._check_none_or_string(details)
self.swf_backend.respond_activity_task_failed( self.swf_backend.respond_activity_task_failed(
task_token, reason=reason, details=details task_token, reason=reason, details=details
) )
@ -264,6 +353,12 @@ class SWFResponse(BaseResponse):
details = self._params.get("details") details = self._params.get("details")
reason = self._params.get("reason") reason = self._params.get("reason")
run_id = self._params.get("runId") run_id = self._params.get("runId")
self._check_string(domain_name)
self._check_string(workflow_id)
self._check_none_or_string(child_policy)
self._check_none_or_string(details)
self._check_none_or_string(reason)
self._check_none_or_string(run_id)
self.swf_backend.terminate_workflow_execution( self.swf_backend.terminate_workflow_execution(
domain_name, workflow_id, child_policy=child_policy, domain_name, workflow_id, child_policy=child_policy,
details=details, reason=reason, run_id=run_id details=details, reason=reason, run_id=run_id
@ -273,6 +368,8 @@ class SWFResponse(BaseResponse):
def record_activity_task_heartbeat(self): def record_activity_task_heartbeat(self):
task_token = self._params["taskToken"] task_token = self._params["taskToken"]
details = self._params.get("details") details = self._params.get("details")
self._check_string(task_token)
self._check_none_or_string(details)
self.swf_backend.record_activity_task_heartbeat( self.swf_backend.record_activity_task_heartbeat(
task_token, details=details task_token, details=details
) )