Stepfunctions improvements (#3427)
* Implement filtering for stepfunctions:ListExecutions * Add pagination to Step Functions endpoints Implements a generalized approach to pagination via a decorator method for the Step Functions endpoints. Modeled on the real AWS backend behavior, `nextToken` is a dictionary of pagination information encoded in an opaque string. With just a bit of metadata hard-coded (`utils.PAGINATION_MODEL`), backend `list` methods need only be decorated with `@paginate` and ensure that their returned entities are sorted to get full pagination support without any duplicated code polluting the model. Closes #3137
This commit is contained in:
parent
a3880c4c35
commit
68e3d394ab
@ -46,3 +46,11 @@ class InvalidExecutionInput(AWSError):
|
|||||||
class StateMachineDoesNotExist(AWSError):
|
class StateMachineDoesNotExist(AWSError):
|
||||||
TYPE = "StateMachineDoesNotExist"
|
TYPE = "StateMachineDoesNotExist"
|
||||||
STATUS = 400
|
STATUS = 400
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidToken(AWSError):
|
||||||
|
TYPE = "InvalidToken"
|
||||||
|
STATUS = 400
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid token"):
|
||||||
|
super(InvalidToken, self).__init__("Invalid Token: {}".format(message))
|
||||||
|
@ -5,7 +5,7 @@ from datetime import datetime
|
|||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
|
|
||||||
from moto.core import ACCOUNT_ID, BaseBackend
|
from moto.core import ACCOUNT_ID, BaseBackend
|
||||||
from moto.core.utils import iso_8601_datetime_without_milliseconds
|
from moto.core.utils import iso_8601_datetime_with_milliseconds
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
ExecutionAlreadyExists,
|
ExecutionAlreadyExists,
|
||||||
@ -15,11 +15,12 @@ from .exceptions import (
|
|||||||
InvalidName,
|
InvalidName,
|
||||||
StateMachineDoesNotExist,
|
StateMachineDoesNotExist,
|
||||||
)
|
)
|
||||||
|
from .utils import paginate
|
||||||
|
|
||||||
|
|
||||||
class StateMachine:
|
class StateMachine:
|
||||||
def __init__(self, arn, name, definition, roleArn, tags=None):
|
def __init__(self, arn, name, definition, roleArn, tags=None):
|
||||||
self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||||
self.arn = arn
|
self.arn = arn
|
||||||
self.name = name
|
self.name = name
|
||||||
self.definition = definition
|
self.definition = definition
|
||||||
@ -43,7 +44,7 @@ class Execution:
|
|||||||
)
|
)
|
||||||
self.execution_arn = execution_arn
|
self.execution_arn = execution_arn
|
||||||
self.name = execution_name
|
self.name = execution_name
|
||||||
self.start_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
self.start_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||||
self.state_machine_arn = state_machine_arn
|
self.state_machine_arn = state_machine_arn
|
||||||
self.execution_input = execution_input
|
self.execution_input = execution_input
|
||||||
self.status = "RUNNING"
|
self.status = "RUNNING"
|
||||||
@ -51,7 +52,7 @@ class Execution:
|
|||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.status = "ABORTED"
|
self.status = "ABORTED"
|
||||||
self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||||
|
|
||||||
|
|
||||||
class StepFunctionBackend(BaseBackend):
|
class StepFunctionBackend(BaseBackend):
|
||||||
@ -189,8 +190,10 @@ class StepFunctionBackend(BaseBackend):
|
|||||||
self.state_machines.append(state_machine)
|
self.state_machines.append(state_machine)
|
||||||
return state_machine
|
return state_machine
|
||||||
|
|
||||||
|
@paginate
|
||||||
def list_state_machines(self):
|
def list_state_machines(self):
|
||||||
return self.state_machines
|
state_machines = sorted(self.state_machines, key=lambda x: x.creation_date)
|
||||||
|
return state_machines
|
||||||
|
|
||||||
def describe_state_machine(self, arn):
|
def describe_state_machine(self, arn):
|
||||||
self._validate_machine_arn(arn)
|
self._validate_machine_arn(arn)
|
||||||
@ -233,13 +236,20 @@ class StepFunctionBackend(BaseBackend):
|
|||||||
execution.stop()
|
execution.stop()
|
||||||
return execution
|
return execution
|
||||||
|
|
||||||
def list_executions(self, state_machine_arn):
|
@paginate
|
||||||
return [
|
def list_executions(self, state_machine_arn, status_filter=None):
|
||||||
|
executions = [
|
||||||
execution
|
execution
|
||||||
for execution in self.executions
|
for execution in self.executions
|
||||||
if execution.state_machine_arn == state_machine_arn
|
if execution.state_machine_arn == state_machine_arn
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if status_filter:
|
||||||
|
executions = list(filter(lambda e: e.status == status_filter, executions))
|
||||||
|
|
||||||
|
executions = sorted(executions, key=lambda x: x.start_date, reverse=True)
|
||||||
|
return executions
|
||||||
|
|
||||||
def describe_execution(self, arn):
|
def describe_execution(self, arn):
|
||||||
self._validate_execution_arn(arn)
|
self._validate_execution_arn(arn)
|
||||||
exctn = next((x for x in self.executions if x.execution_arn == arn), None)
|
exctn = next((x for x in self.executions if x.execution_arn == arn), None)
|
||||||
|
@ -33,19 +33,22 @@ class StepFunctionResponse(BaseResponse):
|
|||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_state_machines(self):
|
def list_state_machines(self):
|
||||||
list_all = self.stepfunction_backend.list_state_machines()
|
max_results = self._get_int_param("maxResults")
|
||||||
list_all = sorted(
|
next_token = self._get_param("nextToken")
|
||||||
[
|
results, next_token = self.stepfunction_backend.list_state_machines(
|
||||||
|
max_results=max_results, next_token=next_token
|
||||||
|
)
|
||||||
|
state_machines = [
|
||||||
{
|
{
|
||||||
"creationDate": sm.creation_date,
|
"creationDate": sm.creation_date,
|
||||||
"name": sm.name,
|
"name": sm.name,
|
||||||
"stateMachineArn": sm.arn,
|
"stateMachineArn": sm.arn,
|
||||||
}
|
}
|
||||||
for sm in list_all
|
for sm in results
|
||||||
],
|
]
|
||||||
key=lambda x: x["name"],
|
response = {"stateMachines": state_machines}
|
||||||
)
|
if next_token:
|
||||||
response = {"stateMachines": list_all}
|
response["nextToken"] = next_token
|
||||||
return 200, {}, json.dumps(response)
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
@ -110,9 +113,20 @@ class StepFunctionResponse(BaseResponse):
|
|||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def list_executions(self):
|
def list_executions(self):
|
||||||
|
max_results = self._get_int_param("maxResults")
|
||||||
|
next_token = self._get_param("nextToken")
|
||||||
arn = self._get_param("stateMachineArn")
|
arn = self._get_param("stateMachineArn")
|
||||||
|
status_filter = self._get_param("statusFilter")
|
||||||
|
try:
|
||||||
state_machine = self.stepfunction_backend.describe_state_machine(arn)
|
state_machine = self.stepfunction_backend.describe_state_machine(arn)
|
||||||
executions = self.stepfunction_backend.list_executions(arn)
|
results, next_token = self.stepfunction_backend.list_executions(
|
||||||
|
arn,
|
||||||
|
status_filter=status_filter,
|
||||||
|
max_results=max_results,
|
||||||
|
next_token=next_token,
|
||||||
|
)
|
||||||
|
except AWSError as err:
|
||||||
|
return err.response()
|
||||||
executions = [
|
executions = [
|
||||||
{
|
{
|
||||||
"executionArn": execution.execution_arn,
|
"executionArn": execution.execution_arn,
|
||||||
@ -121,9 +135,12 @@ class StepFunctionResponse(BaseResponse):
|
|||||||
"stateMachineArn": state_machine.arn,
|
"stateMachineArn": state_machine.arn,
|
||||||
"status": execution.status,
|
"status": execution.status,
|
||||||
}
|
}
|
||||||
for execution in executions
|
for execution in results
|
||||||
]
|
]
|
||||||
return 200, {}, json.dumps({"executions": executions})
|
response = {"executions": executions}
|
||||||
|
if next_token:
|
||||||
|
response["nextToken"] = next_token
|
||||||
|
return 200, {}, json.dumps(response)
|
||||||
|
|
||||||
@amzn_request_id
|
@amzn_request_id
|
||||||
def describe_execution(self):
|
def describe_execution(self):
|
||||||
|
138
moto/stepfunctions/utils.py
Normal file
138
moto/stepfunctions/utils.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from botocore.paginate import TokenDecoder, TokenEncoder
|
||||||
|
from six.moves import reduce
|
||||||
|
|
||||||
|
from .exceptions import InvalidToken
|
||||||
|
|
||||||
|
PAGINATION_MODEL = {
|
||||||
|
"list_executions": {
|
||||||
|
"input_token": "next_token",
|
||||||
|
"limit_key": "max_results",
|
||||||
|
"limit_default": 100,
|
||||||
|
"page_ending_range_keys": ["start_date", "execution_arn"],
|
||||||
|
},
|
||||||
|
"list_state_machines": {
|
||||||
|
"input_token": "next_token",
|
||||||
|
"limit_key": "max_results",
|
||||||
|
"limit_default": 100,
|
||||||
|
"page_ending_range_keys": ["creation_date", "arn"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def paginate(original_function=None, pagination_model=None):
|
||||||
|
def pagination_decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def pagination_wrapper(*args, **kwargs):
|
||||||
|
method = func.__name__
|
||||||
|
model = pagination_model or PAGINATION_MODEL
|
||||||
|
pagination_config = model.get(method)
|
||||||
|
if not pagination_config:
|
||||||
|
raise ValueError(
|
||||||
|
"No pagination config for backend method: {}".format(method)
|
||||||
|
)
|
||||||
|
# We pop the pagination arguments, so the remaining kwargs (if any)
|
||||||
|
# can be used to compute the optional parameters checksum.
|
||||||
|
input_token = kwargs.pop(pagination_config.get("input_token"), None)
|
||||||
|
limit = kwargs.pop(pagination_config.get("limit_key"), None)
|
||||||
|
paginator = Paginator(
|
||||||
|
max_results=limit,
|
||||||
|
max_results_default=pagination_config.get("limit_default"),
|
||||||
|
starting_token=input_token,
|
||||||
|
page_ending_range_keys=pagination_config.get("page_ending_range_keys"),
|
||||||
|
param_values_to_check=kwargs,
|
||||||
|
)
|
||||||
|
results = func(*args, **kwargs)
|
||||||
|
return paginator.paginate(results)
|
||||||
|
|
||||||
|
return pagination_wrapper
|
||||||
|
|
||||||
|
if original_function:
|
||||||
|
return pagination_decorator(original_function)
|
||||||
|
|
||||||
|
return pagination_decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Paginator(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_results=None,
|
||||||
|
max_results_default=None,
|
||||||
|
starting_token=None,
|
||||||
|
page_ending_range_keys=None,
|
||||||
|
param_values_to_check=None,
|
||||||
|
):
|
||||||
|
self._max_results = max_results if max_results else max_results_default
|
||||||
|
self._starting_token = starting_token
|
||||||
|
self._page_ending_range_keys = page_ending_range_keys
|
||||||
|
self._param_values_to_check = param_values_to_check
|
||||||
|
self._token_encoder = TokenEncoder()
|
||||||
|
self._token_decoder = TokenDecoder()
|
||||||
|
self._param_checksum = self._calculate_parameter_checksum()
|
||||||
|
self._parsed_token = self._parse_starting_token()
|
||||||
|
|
||||||
|
def _parse_starting_token(self):
|
||||||
|
if self._starting_token is None:
|
||||||
|
return None
|
||||||
|
# The starting token is a dict passed as a base64 encoded string.
|
||||||
|
next_token = self._starting_token
|
||||||
|
try:
|
||||||
|
next_token = self._token_decoder.decode(next_token)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise InvalidToken("Invalid token")
|
||||||
|
if next_token.get("parameterChecksum") != self._param_checksum:
|
||||||
|
raise InvalidToken(
|
||||||
|
"Input inconsistent with page token: {}".format(str(next_token))
|
||||||
|
)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
def _calculate_parameter_checksum(self):
|
||||||
|
if not self._param_values_to_check:
|
||||||
|
return None
|
||||||
|
return reduce(
|
||||||
|
lambda x, y: x ^ y,
|
||||||
|
[hash(item) for item in self._param_values_to_check.items()],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_predicate(self, item):
|
||||||
|
page_ending_range_key = self._parsed_token["pageEndingRangeKey"]
|
||||||
|
predicate_values = page_ending_range_key.split("|")
|
||||||
|
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||||
|
if not getattr(item, attr, None) == predicate_values[index]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _build_next_token(self, next_item):
|
||||||
|
token_dict = {}
|
||||||
|
if self._param_checksum:
|
||||||
|
token_dict["parameterChecksum"] = self._param_checksum
|
||||||
|
range_keys = []
|
||||||
|
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||||
|
range_keys.append(getattr(next_item, attr))
|
||||||
|
token_dict["pageEndingRangeKey"] = "|".join(range_keys)
|
||||||
|
return TokenEncoder().encode(token_dict)
|
||||||
|
|
||||||
|
def paginate(self, results):
|
||||||
|
index_start = 0
|
||||||
|
if self._starting_token:
|
||||||
|
try:
|
||||||
|
index_start = next(
|
||||||
|
index
|
||||||
|
for (index, result) in enumerate(results)
|
||||||
|
if self._check_predicate(result)
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
raise InvalidToken("Resource not found!")
|
||||||
|
|
||||||
|
index_end = index_start + self._max_results
|
||||||
|
if index_end > len(results):
|
||||||
|
index_end = len(results)
|
||||||
|
|
||||||
|
results_page = results[index_start:index_end]
|
||||||
|
|
||||||
|
next_token = None
|
||||||
|
if results_page and index_end < len(results):
|
||||||
|
page_ending_result = results[index_end]
|
||||||
|
next_token = self._build_next_token(page_ending_result)
|
||||||
|
return results_page, next_token
|
@ -168,15 +168,15 @@ def test_state_machine_list_returns_empty_list_by_default():
|
|||||||
def test_state_machine_list_returns_created_state_machines():
|
def test_state_machine_list_returns_created_state_machines():
|
||||||
client = boto3.client("stepfunctions", region_name=region)
|
client = boto3.client("stepfunctions", region_name=region)
|
||||||
#
|
#
|
||||||
machine2 = client.create_state_machine(
|
|
||||||
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
|
|
||||||
)
|
|
||||||
machine1 = client.create_state_machine(
|
machine1 = client.create_state_machine(
|
||||||
name="name1",
|
name="name1",
|
||||||
definition=str(simple_definition),
|
definition=str(simple_definition),
|
||||||
roleArn=_get_default_role(),
|
roleArn=_get_default_role(),
|
||||||
tags=[{"key": "tag_key", "value": "tag_value"}],
|
tags=[{"key": "tag_key", "value": "tag_value"}],
|
||||||
)
|
)
|
||||||
|
machine2 = client.create_state_machine(
|
||||||
|
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
|
||||||
|
)
|
||||||
list = client.list_state_machines()
|
list = client.list_state_machines()
|
||||||
#
|
#
|
||||||
list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
||||||
@ -195,6 +195,28 @@ def test_state_machine_list_returns_created_state_machines():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_stepfunctions
|
||||||
|
def test_state_machine_list_pagination():
|
||||||
|
client = boto3.client("stepfunctions", region_name=region)
|
||||||
|
for i in range(25):
|
||||||
|
machine_name = "StateMachine-{}".format(i)
|
||||||
|
client.create_state_machine(
|
||||||
|
name=machine_name,
|
||||||
|
definition=str(simple_definition),
|
||||||
|
roleArn=_get_default_role(),
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = client.list_state_machines()
|
||||||
|
resp.should_not.have.key("nextToken")
|
||||||
|
resp["stateMachines"].should.have.length_of(25)
|
||||||
|
|
||||||
|
paginator = client.get_paginator("list_state_machines")
|
||||||
|
page_iterator = paginator.paginate(maxResults=5)
|
||||||
|
for page in page_iterator:
|
||||||
|
page["stateMachines"].should.have.length_of(5)
|
||||||
|
page["stateMachines"][-1]["name"].should.contain("24")
|
||||||
|
|
||||||
|
|
||||||
@mock_stepfunctions
|
@mock_stepfunctions
|
||||||
@mock_sts
|
@mock_sts
|
||||||
def test_state_machine_creation_is_idempotent_by_name():
|
def test_state_machine_creation_is_idempotent_by_name():
|
||||||
@ -489,6 +511,69 @@ def test_state_machine_list_executions():
|
|||||||
executions["executions"][0].shouldnt.have("stopDate")
|
executions["executions"][0].shouldnt.have("stopDate")
|
||||||
|
|
||||||
|
|
||||||
|
@mock_stepfunctions
|
||||||
|
def test_state_machine_list_executions_with_filter():
|
||||||
|
client = boto3.client("stepfunctions", region_name=region)
|
||||||
|
sm = client.create_state_machine(
|
||||||
|
name="name", definition=str(simple_definition), roleArn=_get_default_role()
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
execution = client.start_execution(stateMachineArn=sm["stateMachineArn"])
|
||||||
|
if not i % 4:
|
||||||
|
client.stop_execution(executionArn=execution["executionArn"])
|
||||||
|
|
||||||
|
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
|
||||||
|
resp["executions"].should.have.length_of(20)
|
||||||
|
|
||||||
|
resp = client.list_executions(
|
||||||
|
stateMachineArn=sm["stateMachineArn"], statusFilter="ABORTED"
|
||||||
|
)
|
||||||
|
resp["executions"].should.have.length_of(5)
|
||||||
|
all([e["status"] == "ABORTED" for e in resp["executions"]]).should.be.true
|
||||||
|
|
||||||
|
|
||||||
|
@mock_stepfunctions
|
||||||
|
def test_state_machine_list_executions_with_pagination():
|
||||||
|
client = boto3.client("stepfunctions", region_name=region)
|
||||||
|
sm = client.create_state_machine(
|
||||||
|
name="name", definition=str(simple_definition), roleArn=_get_default_role()
|
||||||
|
)
|
||||||
|
for _ in range(100):
|
||||||
|
client.start_execution(stateMachineArn=sm["stateMachineArn"])
|
||||||
|
|
||||||
|
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
|
||||||
|
resp.should_not.have.key("nextToken")
|
||||||
|
resp["executions"].should.have.length_of(100)
|
||||||
|
|
||||||
|
paginator = client.get_paginator("list_executions")
|
||||||
|
page_iterator = paginator.paginate(
|
||||||
|
stateMachineArn=sm["stateMachineArn"], maxResults=25
|
||||||
|
)
|
||||||
|
for page in page_iterator:
|
||||||
|
page["executions"].should.have.length_of(25)
|
||||||
|
|
||||||
|
with assert_raises(ClientError) as ex:
|
||||||
|
resp = client.list_executions(
|
||||||
|
stateMachineArn=sm["stateMachineArn"], maxResults=10
|
||||||
|
)
|
||||||
|
client.list_executions(
|
||||||
|
stateMachineArn=sm["stateMachineArn"],
|
||||||
|
maxResults=10,
|
||||||
|
statusFilter="ABORTED",
|
||||||
|
nextToken=resp["nextToken"],
|
||||||
|
)
|
||||||
|
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
|
||||||
|
ex.exception.response["Error"]["Message"].should.contain(
|
||||||
|
"Input inconsistent with page token"
|
||||||
|
)
|
||||||
|
|
||||||
|
with assert_raises(ClientError) as ex:
|
||||||
|
client.list_executions(
|
||||||
|
stateMachineArn=sm["stateMachineArn"], nextToken="invalid"
|
||||||
|
)
|
||||||
|
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
|
||||||
|
|
||||||
|
|
||||||
@mock_stepfunctions
|
@mock_stepfunctions
|
||||||
@mock_sts
|
@mock_sts
|
||||||
def test_state_machine_list_executions_when_none_exist():
|
def test_state_machine_list_executions_when_none_exist():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user