Stepfunctions improvements (#3427)
* Implement filtering for stepfunctions:ListExecutions * Add pagination to Step Functions endpoints Implements a generalized approach to pagination via a decorator method for the Step Functions endpoints. Modeled on the real AWS backend behavior, `nextToken` is a dictionary of pagination information encoded in an opaque string. With just a bit of metadata hard-coded (`utils.PAGINATION_MODEL`), backend `list` methods need only be decorated with `@paginate` and ensure that their returned entities are sorted to get full pagination support without any duplicated code polluting the model. Closes #3137
This commit is contained in:
parent
a3880c4c35
commit
68e3d394ab
@ -46,3 +46,11 @@ class InvalidExecutionInput(AWSError):
|
||||
class StateMachineDoesNotExist(AWSError):
|
||||
TYPE = "StateMachineDoesNotExist"
|
||||
STATUS = 400
|
||||
|
||||
|
||||
class InvalidToken(AWSError):
|
||||
TYPE = "InvalidToken"
|
||||
STATUS = 400
|
||||
|
||||
def __init__(self, message="Invalid token"):
|
||||
super(InvalidToken, self).__init__("Invalid Token: {}".format(message))
|
||||
|
@ -5,7 +5,7 @@ from datetime import datetime
|
||||
from boto3 import Session
|
||||
|
||||
from moto.core import ACCOUNT_ID, BaseBackend
|
||||
from moto.core.utils import iso_8601_datetime_without_milliseconds
|
||||
from moto.core.utils import iso_8601_datetime_with_milliseconds
|
||||
from uuid import uuid4
|
||||
from .exceptions import (
|
||||
ExecutionAlreadyExists,
|
||||
@ -15,11 +15,12 @@ from .exceptions import (
|
||||
InvalidName,
|
||||
StateMachineDoesNotExist,
|
||||
)
|
||||
from .utils import paginate
|
||||
|
||||
|
||||
class StateMachine:
|
||||
def __init__(self, arn, name, definition, roleArn, tags=None):
|
||||
self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
||||
self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||
self.arn = arn
|
||||
self.name = name
|
||||
self.definition = definition
|
||||
@ -43,7 +44,7 @@ class Execution:
|
||||
)
|
||||
self.execution_arn = execution_arn
|
||||
self.name = execution_name
|
||||
self.start_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
||||
self.start_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||
self.state_machine_arn = state_machine_arn
|
||||
self.execution_input = execution_input
|
||||
self.status = "RUNNING"
|
||||
@ -51,7 +52,7 @@ class Execution:
|
||||
|
||||
def stop(self):
|
||||
self.status = "ABORTED"
|
||||
self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now())
|
||||
self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now())
|
||||
|
||||
|
||||
class StepFunctionBackend(BaseBackend):
|
||||
@ -189,8 +190,10 @@ class StepFunctionBackend(BaseBackend):
|
||||
self.state_machines.append(state_machine)
|
||||
return state_machine
|
||||
|
||||
@paginate
|
||||
def list_state_machines(self):
|
||||
return self.state_machines
|
||||
state_machines = sorted(self.state_machines, key=lambda x: x.creation_date)
|
||||
return state_machines
|
||||
|
||||
def describe_state_machine(self, arn):
|
||||
self._validate_machine_arn(arn)
|
||||
@ -233,13 +236,20 @@ class StepFunctionBackend(BaseBackend):
|
||||
execution.stop()
|
||||
return execution
|
||||
|
||||
def list_executions(self, state_machine_arn):
|
||||
return [
|
||||
@paginate
|
||||
def list_executions(self, state_machine_arn, status_filter=None):
|
||||
executions = [
|
||||
execution
|
||||
for execution in self.executions
|
||||
if execution.state_machine_arn == state_machine_arn
|
||||
]
|
||||
|
||||
if status_filter:
|
||||
executions = list(filter(lambda e: e.status == status_filter, executions))
|
||||
|
||||
executions = sorted(executions, key=lambda x: x.start_date, reverse=True)
|
||||
return executions
|
||||
|
||||
def describe_execution(self, arn):
|
||||
self._validate_execution_arn(arn)
|
||||
exctn = next((x for x in self.executions if x.execution_arn == arn), None)
|
||||
|
@ -33,19 +33,22 @@ class StepFunctionResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def list_state_machines(self):
|
||||
list_all = self.stepfunction_backend.list_state_machines()
|
||||
list_all = sorted(
|
||||
[
|
||||
{
|
||||
"creationDate": sm.creation_date,
|
||||
"name": sm.name,
|
||||
"stateMachineArn": sm.arn,
|
||||
}
|
||||
for sm in list_all
|
||||
],
|
||||
key=lambda x: x["name"],
|
||||
max_results = self._get_int_param("maxResults")
|
||||
next_token = self._get_param("nextToken")
|
||||
results, next_token = self.stepfunction_backend.list_state_machines(
|
||||
max_results=max_results, next_token=next_token
|
||||
)
|
||||
response = {"stateMachines": list_all}
|
||||
state_machines = [
|
||||
{
|
||||
"creationDate": sm.creation_date,
|
||||
"name": sm.name,
|
||||
"stateMachineArn": sm.arn,
|
||||
}
|
||||
for sm in results
|
||||
]
|
||||
response = {"stateMachines": state_machines}
|
||||
if next_token:
|
||||
response["nextToken"] = next_token
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
@ -110,9 +113,20 @@ class StepFunctionResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def list_executions(self):
|
||||
max_results = self._get_int_param("maxResults")
|
||||
next_token = self._get_param("nextToken")
|
||||
arn = self._get_param("stateMachineArn")
|
||||
state_machine = self.stepfunction_backend.describe_state_machine(arn)
|
||||
executions = self.stepfunction_backend.list_executions(arn)
|
||||
status_filter = self._get_param("statusFilter")
|
||||
try:
|
||||
state_machine = self.stepfunction_backend.describe_state_machine(arn)
|
||||
results, next_token = self.stepfunction_backend.list_executions(
|
||||
arn,
|
||||
status_filter=status_filter,
|
||||
max_results=max_results,
|
||||
next_token=next_token,
|
||||
)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
executions = [
|
||||
{
|
||||
"executionArn": execution.execution_arn,
|
||||
@ -121,9 +135,12 @@ class StepFunctionResponse(BaseResponse):
|
||||
"stateMachineArn": state_machine.arn,
|
||||
"status": execution.status,
|
||||
}
|
||||
for execution in executions
|
||||
for execution in results
|
||||
]
|
||||
return 200, {}, json.dumps({"executions": executions})
|
||||
response = {"executions": executions}
|
||||
if next_token:
|
||||
response["nextToken"] = next_token
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def describe_execution(self):
|
||||
|
138
moto/stepfunctions/utils.py
Normal file
138
moto/stepfunctions/utils.py
Normal file
@ -0,0 +1,138 @@
|
||||
from functools import wraps
|
||||
|
||||
from botocore.paginate import TokenDecoder, TokenEncoder
|
||||
from six.moves import reduce
|
||||
|
||||
from .exceptions import InvalidToken
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_executions": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 100,
|
||||
"page_ending_range_keys": ["start_date", "execution_arn"],
|
||||
},
|
||||
"list_state_machines": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 100,
|
||||
"page_ending_range_keys": ["creation_date", "arn"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def paginate(original_function=None, pagination_model=None):
|
||||
def pagination_decorator(func):
|
||||
@wraps(func)
|
||||
def pagination_wrapper(*args, **kwargs):
|
||||
method = func.__name__
|
||||
model = pagination_model or PAGINATION_MODEL
|
||||
pagination_config = model.get(method)
|
||||
if not pagination_config:
|
||||
raise ValueError(
|
||||
"No pagination config for backend method: {}".format(method)
|
||||
)
|
||||
# We pop the pagination arguments, so the remaining kwargs (if any)
|
||||
# can be used to compute the optional parameters checksum.
|
||||
input_token = kwargs.pop(pagination_config.get("input_token"), None)
|
||||
limit = kwargs.pop(pagination_config.get("limit_key"), None)
|
||||
paginator = Paginator(
|
||||
max_results=limit,
|
||||
max_results_default=pagination_config.get("limit_default"),
|
||||
starting_token=input_token,
|
||||
page_ending_range_keys=pagination_config.get("page_ending_range_keys"),
|
||||
param_values_to_check=kwargs,
|
||||
)
|
||||
results = func(*args, **kwargs)
|
||||
return paginator.paginate(results)
|
||||
|
||||
return pagination_wrapper
|
||||
|
||||
if original_function:
|
||||
return pagination_decorator(original_function)
|
||||
|
||||
return pagination_decorator
|
||||
|
||||
|
||||
class Paginator(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_results=None,
|
||||
max_results_default=None,
|
||||
starting_token=None,
|
||||
page_ending_range_keys=None,
|
||||
param_values_to_check=None,
|
||||
):
|
||||
self._max_results = max_results if max_results else max_results_default
|
||||
self._starting_token = starting_token
|
||||
self._page_ending_range_keys = page_ending_range_keys
|
||||
self._param_values_to_check = param_values_to_check
|
||||
self._token_encoder = TokenEncoder()
|
||||
self._token_decoder = TokenDecoder()
|
||||
self._param_checksum = self._calculate_parameter_checksum()
|
||||
self._parsed_token = self._parse_starting_token()
|
||||
|
||||
def _parse_starting_token(self):
|
||||
if self._starting_token is None:
|
||||
return None
|
||||
# The starting token is a dict passed as a base64 encoded string.
|
||||
next_token = self._starting_token
|
||||
try:
|
||||
next_token = self._token_decoder.decode(next_token)
|
||||
except (ValueError, TypeError):
|
||||
raise InvalidToken("Invalid token")
|
||||
if next_token.get("parameterChecksum") != self._param_checksum:
|
||||
raise InvalidToken(
|
||||
"Input inconsistent with page token: {}".format(str(next_token))
|
||||
)
|
||||
return next_token
|
||||
|
||||
def _calculate_parameter_checksum(self):
|
||||
if not self._param_values_to_check:
|
||||
return None
|
||||
return reduce(
|
||||
lambda x, y: x ^ y,
|
||||
[hash(item) for item in self._param_values_to_check.items()],
|
||||
)
|
||||
|
||||
def _check_predicate(self, item):
|
||||
page_ending_range_key = self._parsed_token["pageEndingRangeKey"]
|
||||
predicate_values = page_ending_range_key.split("|")
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
if not getattr(item, attr, None) == predicate_values[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _build_next_token(self, next_item):
|
||||
token_dict = {}
|
||||
if self._param_checksum:
|
||||
token_dict["parameterChecksum"] = self._param_checksum
|
||||
range_keys = []
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
range_keys.append(getattr(next_item, attr))
|
||||
token_dict["pageEndingRangeKey"] = "|".join(range_keys)
|
||||
return TokenEncoder().encode(token_dict)
|
||||
|
||||
def paginate(self, results):
|
||||
index_start = 0
|
||||
if self._starting_token:
|
||||
try:
|
||||
index_start = next(
|
||||
index
|
||||
for (index, result) in enumerate(results)
|
||||
if self._check_predicate(result)
|
||||
)
|
||||
except StopIteration:
|
||||
raise InvalidToken("Resource not found!")
|
||||
|
||||
index_end = index_start + self._max_results
|
||||
if index_end > len(results):
|
||||
index_end = len(results)
|
||||
|
||||
results_page = results[index_start:index_end]
|
||||
|
||||
next_token = None
|
||||
if results_page and index_end < len(results):
|
||||
page_ending_result = results[index_end]
|
||||
next_token = self._build_next_token(page_ending_result)
|
||||
return results_page, next_token
|
@ -168,15 +168,15 @@ def test_state_machine_list_returns_empty_list_by_default():
|
||||
def test_state_machine_list_returns_created_state_machines():
|
||||
client = boto3.client("stepfunctions", region_name=region)
|
||||
#
|
||||
machine2 = client.create_state_machine(
|
||||
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
|
||||
)
|
||||
machine1 = client.create_state_machine(
|
||||
name="name1",
|
||||
definition=str(simple_definition),
|
||||
roleArn=_get_default_role(),
|
||||
tags=[{"key": "tag_key", "value": "tag_value"}],
|
||||
)
|
||||
machine2 = client.create_state_machine(
|
||||
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
|
||||
)
|
||||
list = client.list_state_machines()
|
||||
#
|
||||
list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
||||
@ -195,6 +195,28 @@ def test_state_machine_list_returns_created_state_machines():
|
||||
)
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
def test_state_machine_list_pagination():
|
||||
client = boto3.client("stepfunctions", region_name=region)
|
||||
for i in range(25):
|
||||
machine_name = "StateMachine-{}".format(i)
|
||||
client.create_state_machine(
|
||||
name=machine_name,
|
||||
definition=str(simple_definition),
|
||||
roleArn=_get_default_role(),
|
||||
)
|
||||
|
||||
resp = client.list_state_machines()
|
||||
resp.should_not.have.key("nextToken")
|
||||
resp["stateMachines"].should.have.length_of(25)
|
||||
|
||||
paginator = client.get_paginator("list_state_machines")
|
||||
page_iterator = paginator.paginate(maxResults=5)
|
||||
for page in page_iterator:
|
||||
page["stateMachines"].should.have.length_of(5)
|
||||
page["stateMachines"][-1]["name"].should.contain("24")
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
@mock_sts
|
||||
def test_state_machine_creation_is_idempotent_by_name():
|
||||
@ -489,6 +511,69 @@ def test_state_machine_list_executions():
|
||||
executions["executions"][0].shouldnt.have("stopDate")
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
def test_state_machine_list_executions_with_filter():
|
||||
client = boto3.client("stepfunctions", region_name=region)
|
||||
sm = client.create_state_machine(
|
||||
name="name", definition=str(simple_definition), roleArn=_get_default_role()
|
||||
)
|
||||
for i in range(20):
|
||||
execution = client.start_execution(stateMachineArn=sm["stateMachineArn"])
|
||||
if not i % 4:
|
||||
client.stop_execution(executionArn=execution["executionArn"])
|
||||
|
||||
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
|
||||
resp["executions"].should.have.length_of(20)
|
||||
|
||||
resp = client.list_executions(
|
||||
stateMachineArn=sm["stateMachineArn"], statusFilter="ABORTED"
|
||||
)
|
||||
resp["executions"].should.have.length_of(5)
|
||||
all([e["status"] == "ABORTED" for e in resp["executions"]]).should.be.true
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
def test_state_machine_list_executions_with_pagination():
|
||||
client = boto3.client("stepfunctions", region_name=region)
|
||||
sm = client.create_state_machine(
|
||||
name="name", definition=str(simple_definition), roleArn=_get_default_role()
|
||||
)
|
||||
for _ in range(100):
|
||||
client.start_execution(stateMachineArn=sm["stateMachineArn"])
|
||||
|
||||
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
|
||||
resp.should_not.have.key("nextToken")
|
||||
resp["executions"].should.have.length_of(100)
|
||||
|
||||
paginator = client.get_paginator("list_executions")
|
||||
page_iterator = paginator.paginate(
|
||||
stateMachineArn=sm["stateMachineArn"], maxResults=25
|
||||
)
|
||||
for page in page_iterator:
|
||||
page["executions"].should.have.length_of(25)
|
||||
|
||||
with assert_raises(ClientError) as ex:
|
||||
resp = client.list_executions(
|
||||
stateMachineArn=sm["stateMachineArn"], maxResults=10
|
||||
)
|
||||
client.list_executions(
|
||||
stateMachineArn=sm["stateMachineArn"],
|
||||
maxResults=10,
|
||||
statusFilter="ABORTED",
|
||||
nextToken=resp["nextToken"],
|
||||
)
|
||||
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
|
||||
ex.exception.response["Error"]["Message"].should.contain(
|
||||
"Input inconsistent with page token"
|
||||
)
|
||||
|
||||
with assert_raises(ClientError) as ex:
|
||||
client.list_executions(
|
||||
stateMachineArn=sm["stateMachineArn"], nextToken="invalid"
|
||||
)
|
||||
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
@mock_sts
|
||||
def test_state_machine_list_executions_when_none_exist():
|
||||
|
Loading…
Reference in New Issue
Block a user