diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index 4abb6a8af..b5fd2ddb9 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -46,3 +46,11 @@ class InvalidExecutionInput(AWSError): class StateMachineDoesNotExist(AWSError): TYPE = "StateMachineDoesNotExist" STATUS = 400 + + +class InvalidToken(AWSError): + TYPE = "InvalidToken" + STATUS = 400 + + def __init__(self, message="Invalid token"): + super(InvalidToken, self).__init__("Invalid Token: {}".format(message)) diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 03cbcf320..3731539f8 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -5,7 +5,7 @@ from datetime import datetime from boto3 import Session from moto.core import ACCOUNT_ID, BaseBackend -from moto.core.utils import iso_8601_datetime_without_milliseconds +from moto.core.utils import iso_8601_datetime_with_milliseconds from uuid import uuid4 from .exceptions import ( ExecutionAlreadyExists, @@ -15,11 +15,12 @@ from .exceptions import ( InvalidName, StateMachineDoesNotExist, ) +from .utils import paginate class StateMachine: def __init__(self, arn, name, definition, roleArn, tags=None): - self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.arn = arn self.name = name self.definition = definition @@ -43,7 +44,7 @@ class Execution: ) self.execution_arn = execution_arn self.name = execution_name - self.start_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.start_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.state_machine_arn = state_machine_arn self.execution_input = execution_input self.status = "RUNNING" @@ -51,7 +52,7 @@ class Execution: def stop(self): self.status = "ABORTED" - self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now()) class StepFunctionBackend(BaseBackend): @@ -189,8 +190,10 @@ class StepFunctionBackend(BaseBackend): self.state_machines.append(state_machine) return state_machine + @paginate def list_state_machines(self): - return self.state_machines + state_machines = sorted(self.state_machines, key=lambda x: x.creation_date) + return state_machines def describe_state_machine(self, arn): self._validate_machine_arn(arn) @@ -233,13 +236,20 @@ class StepFunctionBackend(BaseBackend): execution.stop() return execution - def list_executions(self, state_machine_arn): - return [ + @paginate + def list_executions(self, state_machine_arn, status_filter=None): + executions = [ execution for execution in self.executions if execution.state_machine_arn == state_machine_arn ] + if status_filter: + executions = list(filter(lambda e: e.status == status_filter, executions)) + + executions = sorted(executions, key=lambda x: x.start_date, reverse=True) + return executions + def describe_execution(self, arn): self._validate_execution_arn(arn) exctn = next((x for x in self.executions if x.execution_arn == arn), None) diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index d9e438892..7106d81d0 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -33,19 +33,22 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def list_state_machines(self): - list_all = self.stepfunction_backend.list_state_machines() - list_all = sorted( - [ - { - "creationDate": sm.creation_date, - "name": sm.name, - "stateMachineArn": sm.arn, - } - for sm in list_all - ], - key=lambda x: x["name"], + max_results = self._get_int_param("maxResults") + next_token = self._get_param("nextToken") + results, next_token = self.stepfunction_backend.list_state_machines( + max_results=max_results, next_token=next_token ) - response = {"stateMachines": list_all} + state_machines = [ + { + "creationDate": sm.creation_date, + "name": sm.name, + "stateMachineArn": sm.arn, + } + for sm in results + ] + response = {"stateMachines": state_machines} + if next_token: + response["nextToken"] = next_token return 200, {}, json.dumps(response) @amzn_request_id @@ -110,9 +113,20 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def list_executions(self): + max_results = self._get_int_param("maxResults") + next_token = self._get_param("nextToken") arn = self._get_param("stateMachineArn") - state_machine = self.stepfunction_backend.describe_state_machine(arn) - executions = self.stepfunction_backend.list_executions(arn) + status_filter = self._get_param("statusFilter") + try: + state_machine = self.stepfunction_backend.describe_state_machine(arn) + results, next_token = self.stepfunction_backend.list_executions( + arn, + status_filter=status_filter, + max_results=max_results, + next_token=next_token, + ) + except AWSError as err: + return err.response() executions = [ { "executionArn": execution.execution_arn, @@ -121,9 +135,12 @@ class StepFunctionResponse(BaseResponse): "stateMachineArn": state_machine.arn, "status": execution.status, } - for execution in executions + for execution in results ] - return 200, {}, json.dumps({"executions": executions}) + response = {"executions": executions} + if next_token: + response["nextToken"] = next_token + return 200, {}, json.dumps(response) @amzn_request_id def describe_execution(self): diff --git a/moto/stepfunctions/utils.py b/moto/stepfunctions/utils.py new file mode 100644 index 000000000..cf6b58c8a --- /dev/null +++ b/moto/stepfunctions/utils.py @@ -0,0 +1,138 @@ +from functools import wraps + +from botocore.paginate import TokenDecoder, TokenEncoder +from six.moves import reduce + +from .exceptions import InvalidToken + +PAGINATION_MODEL = { + "list_executions": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "page_ending_range_keys": ["start_date", "execution_arn"], + }, + "list_state_machines": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "page_ending_range_keys": ["creation_date", "arn"], + }, +} + + +def paginate(original_function=None, pagination_model=None): + def pagination_decorator(func): + @wraps(func) + def pagination_wrapper(*args, **kwargs): + method = func.__name__ + model = pagination_model or PAGINATION_MODEL + pagination_config = model.get(method) + if not pagination_config: + raise ValueError( + "No pagination config for backend method: {}".format(method) + ) + # We pop the pagination arguments, so the remaining kwargs (if any) + # can be used to compute the optional parameters checksum. + input_token = kwargs.pop(pagination_config.get("input_token"), None) + limit = kwargs.pop(pagination_config.get("limit_key"), None) + paginator = Paginator( + max_results=limit, + max_results_default=pagination_config.get("limit_default"), + starting_token=input_token, + page_ending_range_keys=pagination_config.get("page_ending_range_keys"), + param_values_to_check=kwargs, + ) + results = func(*args, **kwargs) + return paginator.paginate(results) + + return pagination_wrapper + + if original_function: + return pagination_decorator(original_function) + + return pagination_decorator + + +class Paginator(object): + def __init__( + self, + max_results=None, + max_results_default=None, + starting_token=None, + page_ending_range_keys=None, + param_values_to_check=None, + ): + self._max_results = max_results if max_results else max_results_default + self._starting_token = starting_token + self._page_ending_range_keys = page_ending_range_keys + self._param_values_to_check = param_values_to_check + self._token_encoder = TokenEncoder() + self._token_decoder = TokenDecoder() + self._param_checksum = self._calculate_parameter_checksum() + self._parsed_token = self._parse_starting_token() + + def _parse_starting_token(self): + if self._starting_token is None: + return None + # The starting token is a dict passed as a base64 encoded string. + next_token = self._starting_token + try: + next_token = self._token_decoder.decode(next_token) + except (ValueError, TypeError): + raise InvalidToken("Invalid token") + if next_token.get("parameterChecksum") != self._param_checksum: + raise InvalidToken( + "Input inconsistent with page token: {}".format(str(next_token)) + ) + return next_token + + def _calculate_parameter_checksum(self): + if not self._param_values_to_check: + return None + return reduce( + lambda x, y: x ^ y, + [hash(item) for item in self._param_values_to_check.items()], + ) + + def _check_predicate(self, item): + page_ending_range_key = self._parsed_token["pageEndingRangeKey"] + predicate_values = page_ending_range_key.split("|") + for (index, attr) in enumerate(self._page_ending_range_keys): + if not getattr(item, attr, None) == predicate_values[index]: + return False + return True + + def _build_next_token(self, next_item): + token_dict = {} + if self._param_checksum: + token_dict["parameterChecksum"] = self._param_checksum + range_keys = [] + for (index, attr) in enumerate(self._page_ending_range_keys): + range_keys.append(getattr(next_item, attr)) + token_dict["pageEndingRangeKey"] = "|".join(range_keys) + return TokenEncoder().encode(token_dict) + + def paginate(self, results): + index_start = 0 + if self._starting_token: + try: + index_start = next( + index + for (index, result) in enumerate(results) + if self._check_predicate(result) + ) + except StopIteration: + raise InvalidToken("Resource not found!") + + index_end = index_start + self._max_results + if index_end > len(results): + index_end = len(results) + + results_page = results[index_start:index_end] + + next_token = None + if results_page and index_end < len(results): + page_ending_result = results[index_end] + next_token = self._build_next_token(page_ending_result) + return results_page, next_token diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 36b08487c..e6592c2ff 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -168,15 +168,15 @@ def test_state_machine_list_returns_empty_list_by_default(): def test_state_machine_list_returns_created_state_machines(): client = boto3.client("stepfunctions", region_name=region) # - machine2 = client.create_state_machine( - name="name2", definition=str(simple_definition), roleArn=_get_default_role() - ) machine1 = client.create_state_machine( name="name1", definition=str(simple_definition), roleArn=_get_default_role(), tags=[{"key": "tag_key", "value": "tag_value"}], ) + machine2 = client.create_state_machine( + name="name2", definition=str(simple_definition), roleArn=_get_default_role() + ) list = client.list_state_machines() # list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @@ -195,6 +195,28 @@ def test_state_machine_list_returns_created_state_machines(): ) +@mock_stepfunctions +def test_state_machine_list_pagination(): + client = boto3.client("stepfunctions", region_name=region) + for i in range(25): + machine_name = "StateMachine-{}".format(i) + client.create_state_machine( + name=machine_name, + definition=str(simple_definition), + roleArn=_get_default_role(), + ) + + resp = client.list_state_machines() + resp.should_not.have.key("nextToken") + resp["stateMachines"].should.have.length_of(25) + + paginator = client.get_paginator("list_state_machines") + page_iterator = paginator.paginate(maxResults=5) + for page in page_iterator: + page["stateMachines"].should.have.length_of(5) + page["stateMachines"][-1]["name"].should.contain("24") + + @mock_stepfunctions @mock_sts def test_state_machine_creation_is_idempotent_by_name(): @@ -489,6 +511,69 @@ def test_state_machine_list_executions(): executions["executions"][0].shouldnt.have("stopDate") +@mock_stepfunctions +def test_state_machine_list_executions_with_filter(): + client = boto3.client("stepfunctions", region_name=region) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + for i in range(20): + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + if not i % 4: + client.stop_execution(executionArn=execution["executionArn"]) + + resp = client.list_executions(stateMachineArn=sm["stateMachineArn"]) + resp["executions"].should.have.length_of(20) + + resp = client.list_executions( + stateMachineArn=sm["stateMachineArn"], statusFilter="ABORTED" + ) + resp["executions"].should.have.length_of(5) + all([e["status"] == "ABORTED" for e in resp["executions"]]).should.be.true + + +@mock_stepfunctions +def test_state_machine_list_executions_with_pagination(): + client = boto3.client("stepfunctions", region_name=region) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + for _ in range(100): + client.start_execution(stateMachineArn=sm["stateMachineArn"]) + + resp = client.list_executions(stateMachineArn=sm["stateMachineArn"]) + resp.should_not.have.key("nextToken") + resp["executions"].should.have.length_of(100) + + paginator = client.get_paginator("list_executions") + page_iterator = paginator.paginate( + stateMachineArn=sm["stateMachineArn"], maxResults=25 + ) + for page in page_iterator: + page["executions"].should.have.length_of(25) + + with assert_raises(ClientError) as ex: + resp = client.list_executions( + stateMachineArn=sm["stateMachineArn"], maxResults=10 + ) + client.list_executions( + stateMachineArn=sm["stateMachineArn"], + maxResults=10, + statusFilter="ABORTED", + nextToken=resp["nextToken"], + ) + ex.exception.response["Error"]["Code"].should.equal("InvalidToken") + ex.exception.response["Error"]["Message"].should.contain( + "Input inconsistent with page token" + ) + + with assert_raises(ClientError) as ex: + client.list_executions( + stateMachineArn=sm["stateMachineArn"], nextToken="invalid" + ) + ex.exception.response["Error"]["Code"].should.equal("InvalidToken") + + @mock_stepfunctions @mock_sts def test_state_machine_list_executions_when_none_exist():