Stepfunctions improvements (#3427)

* Implement filtering for stepfunctions:ListExecutions

* Add pagination to Step Functions endpoints

Implements a generalized approach to pagination via a decorator method for the
Step Functions endpoints.  Modeled on the real AWS backend behavior, `nextToken`
is a dictionary of pagination information encoded in an opaque string.

With just a bit of metadata hard-coded (`utils.PAGINATION_MODEL`), backend `list`
methods need only be decorated with `@paginate` and ensure that their returned
entities are sorted to get full pagination support without any duplicated code
polluting the model.

Closes #3137
This commit is contained in:
Brian Pandola 2020-11-01 02:16:41 -08:00 committed by GitHub
parent a3880c4c35
commit 68e3d394ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 284 additions and 26 deletions

View File

@ -46,3 +46,11 @@ class InvalidExecutionInput(AWSError):
class StateMachineDoesNotExist(AWSError):
TYPE = "StateMachineDoesNotExist"
STATUS = 400
class InvalidToken(AWSError):
TYPE = "InvalidToken"
STATUS = 400
def __init__(self, message="Invalid token"):
super(InvalidToken, self).__init__("Invalid Token: {}".format(message))

View File

@ -5,7 +5,7 @@ from datetime import datetime
from boto3 import Session
from moto.core import ACCOUNT_ID, BaseBackend
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.core.utils import iso_8601_datetime_with_milliseconds
from uuid import uuid4
from .exceptions import (
ExecutionAlreadyExists,
@ -15,11 +15,12 @@ from .exceptions import (
InvalidName,
StateMachineDoesNotExist,
)
from .utils import paginate
class StateMachine:
def __init__(self, arn, name, definition, roleArn, tags=None):
self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now())
self.arn = arn
self.name = name
self.definition = definition
@ -43,7 +44,7 @@ class Execution:
)
self.execution_arn = execution_arn
self.name = execution_name
self.start_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.start_date = iso_8601_datetime_with_milliseconds(datetime.now())
self.state_machine_arn = state_machine_arn
self.execution_input = execution_input
self.status = "RUNNING"
@ -51,7 +52,7 @@ class Execution:
def stop(self):
self.status = "ABORTED"
self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now())
class StepFunctionBackend(BaseBackend):
@ -189,8 +190,10 @@ class StepFunctionBackend(BaseBackend):
self.state_machines.append(state_machine)
return state_machine
@paginate
def list_state_machines(self):
return self.state_machines
state_machines = sorted(self.state_machines, key=lambda x: x.creation_date)
return state_machines
def describe_state_machine(self, arn):
self._validate_machine_arn(arn)
@ -233,13 +236,20 @@ class StepFunctionBackend(BaseBackend):
execution.stop()
return execution
def list_executions(self, state_machine_arn):
return [
@paginate
def list_executions(self, state_machine_arn, status_filter=None):
executions = [
execution
for execution in self.executions
if execution.state_machine_arn == state_machine_arn
]
if status_filter:
executions = list(filter(lambda e: e.status == status_filter, executions))
executions = sorted(executions, key=lambda x: x.start_date, reverse=True)
return executions
def describe_execution(self, arn):
self._validate_execution_arn(arn)
exctn = next((x for x in self.executions if x.execution_arn == arn), None)

View File

@ -33,19 +33,22 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id
def list_state_machines(self):
list_all = self.stepfunction_backend.list_state_machines()
list_all = sorted(
[
{
"creationDate": sm.creation_date,
"name": sm.name,
"stateMachineArn": sm.arn,
}
for sm in list_all
],
key=lambda x: x["name"],
max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken")
results, next_token = self.stepfunction_backend.list_state_machines(
max_results=max_results, next_token=next_token
)
response = {"stateMachines": list_all}
state_machines = [
{
"creationDate": sm.creation_date,
"name": sm.name,
"stateMachineArn": sm.arn,
}
for sm in results
]
response = {"stateMachines": state_machines}
if next_token:
response["nextToken"] = next_token
return 200, {}, json.dumps(response)
@amzn_request_id
@ -110,9 +113,20 @@ class StepFunctionResponse(BaseResponse):
@amzn_request_id
def list_executions(self):
max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken")
arn = self._get_param("stateMachineArn")
state_machine = self.stepfunction_backend.describe_state_machine(arn)
executions = self.stepfunction_backend.list_executions(arn)
status_filter = self._get_param("statusFilter")
try:
state_machine = self.stepfunction_backend.describe_state_machine(arn)
results, next_token = self.stepfunction_backend.list_executions(
arn,
status_filter=status_filter,
max_results=max_results,
next_token=next_token,
)
except AWSError as err:
return err.response()
executions = [
{
"executionArn": execution.execution_arn,
@ -121,9 +135,12 @@ class StepFunctionResponse(BaseResponse):
"stateMachineArn": state_machine.arn,
"status": execution.status,
}
for execution in executions
for execution in results
]
return 200, {}, json.dumps({"executions": executions})
response = {"executions": executions}
if next_token:
response["nextToken"] = next_token
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_execution(self):

138
moto/stepfunctions/utils.py Normal file
View File

@ -0,0 +1,138 @@
from functools import wraps
from botocore.paginate import TokenDecoder, TokenEncoder
from six.moves import reduce
from .exceptions import InvalidToken
PAGINATION_MODEL = {
"list_executions": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"page_ending_range_keys": ["start_date", "execution_arn"],
},
"list_state_machines": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"page_ending_range_keys": ["creation_date", "arn"],
},
}
def paginate(original_function=None, pagination_model=None):
def pagination_decorator(func):
@wraps(func)
def pagination_wrapper(*args, **kwargs):
method = func.__name__
model = pagination_model or PAGINATION_MODEL
pagination_config = model.get(method)
if not pagination_config:
raise ValueError(
"No pagination config for backend method: {}".format(method)
)
# We pop the pagination arguments, so the remaining kwargs (if any)
# can be used to compute the optional parameters checksum.
input_token = kwargs.pop(pagination_config.get("input_token"), None)
limit = kwargs.pop(pagination_config.get("limit_key"), None)
paginator = Paginator(
max_results=limit,
max_results_default=pagination_config.get("limit_default"),
starting_token=input_token,
page_ending_range_keys=pagination_config.get("page_ending_range_keys"),
param_values_to_check=kwargs,
)
results = func(*args, **kwargs)
return paginator.paginate(results)
return pagination_wrapper
if original_function:
return pagination_decorator(original_function)
return pagination_decorator
class Paginator(object):
def __init__(
self,
max_results=None,
max_results_default=None,
starting_token=None,
page_ending_range_keys=None,
param_values_to_check=None,
):
self._max_results = max_results if max_results else max_results_default
self._starting_token = starting_token
self._page_ending_range_keys = page_ending_range_keys
self._param_values_to_check = param_values_to_check
self._token_encoder = TokenEncoder()
self._token_decoder = TokenDecoder()
self._param_checksum = self._calculate_parameter_checksum()
self._parsed_token = self._parse_starting_token()
def _parse_starting_token(self):
if self._starting_token is None:
return None
# The starting token is a dict passed as a base64 encoded string.
next_token = self._starting_token
try:
next_token = self._token_decoder.decode(next_token)
except (ValueError, TypeError):
raise InvalidToken("Invalid token")
if next_token.get("parameterChecksum") != self._param_checksum:
raise InvalidToken(
"Input inconsistent with page token: {}".format(str(next_token))
)
return next_token
def _calculate_parameter_checksum(self):
if not self._param_values_to_check:
return None
return reduce(
lambda x, y: x ^ y,
[hash(item) for item in self._param_values_to_check.items()],
)
def _check_predicate(self, item):
page_ending_range_key = self._parsed_token["pageEndingRangeKey"]
predicate_values = page_ending_range_key.split("|")
for (index, attr) in enumerate(self._page_ending_range_keys):
if not getattr(item, attr, None) == predicate_values[index]:
return False
return True
def _build_next_token(self, next_item):
token_dict = {}
if self._param_checksum:
token_dict["parameterChecksum"] = self._param_checksum
range_keys = []
for (index, attr) in enumerate(self._page_ending_range_keys):
range_keys.append(getattr(next_item, attr))
token_dict["pageEndingRangeKey"] = "|".join(range_keys)
return TokenEncoder().encode(token_dict)
def paginate(self, results):
index_start = 0
if self._starting_token:
try:
index_start = next(
index
for (index, result) in enumerate(results)
if self._check_predicate(result)
)
except StopIteration:
raise InvalidToken("Resource not found!")
index_end = index_start + self._max_results
if index_end > len(results):
index_end = len(results)
results_page = results[index_start:index_end]
next_token = None
if results_page and index_end < len(results):
page_ending_result = results[index_end]
next_token = self._build_next_token(page_ending_result)
return results_page, next_token

View File

@ -168,15 +168,15 @@ def test_state_machine_list_returns_empty_list_by_default():
def test_state_machine_list_returns_created_state_machines():
client = boto3.client("stepfunctions", region_name=region)
#
machine2 = client.create_state_machine(
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
)
machine1 = client.create_state_machine(
name="name1",
definition=str(simple_definition),
roleArn=_get_default_role(),
tags=[{"key": "tag_key", "value": "tag_value"}],
)
machine2 = client.create_state_machine(
name="name2", definition=str(simple_definition), roleArn=_get_default_role()
)
list = client.list_state_machines()
#
list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
@ -195,6 +195,28 @@ def test_state_machine_list_returns_created_state_machines():
)
@mock_stepfunctions
def test_state_machine_list_pagination():
client = boto3.client("stepfunctions", region_name=region)
for i in range(25):
machine_name = "StateMachine-{}".format(i)
client.create_state_machine(
name=machine_name,
definition=str(simple_definition),
roleArn=_get_default_role(),
)
resp = client.list_state_machines()
resp.should_not.have.key("nextToken")
resp["stateMachines"].should.have.length_of(25)
paginator = client.get_paginator("list_state_machines")
page_iterator = paginator.paginate(maxResults=5)
for page in page_iterator:
page["stateMachines"].should.have.length_of(5)
page["stateMachines"][-1]["name"].should.contain("24")
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_is_idempotent_by_name():
@ -489,6 +511,69 @@ def test_state_machine_list_executions():
executions["executions"][0].shouldnt.have("stopDate")
@mock_stepfunctions
def test_state_machine_list_executions_with_filter():
client = boto3.client("stepfunctions", region_name=region)
sm = client.create_state_machine(
name="name", definition=str(simple_definition), roleArn=_get_default_role()
)
for i in range(20):
execution = client.start_execution(stateMachineArn=sm["stateMachineArn"])
if not i % 4:
client.stop_execution(executionArn=execution["executionArn"])
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
resp["executions"].should.have.length_of(20)
resp = client.list_executions(
stateMachineArn=sm["stateMachineArn"], statusFilter="ABORTED"
)
resp["executions"].should.have.length_of(5)
all([e["status"] == "ABORTED" for e in resp["executions"]]).should.be.true
@mock_stepfunctions
def test_state_machine_list_executions_with_pagination():
client = boto3.client("stepfunctions", region_name=region)
sm = client.create_state_machine(
name="name", definition=str(simple_definition), roleArn=_get_default_role()
)
for _ in range(100):
client.start_execution(stateMachineArn=sm["stateMachineArn"])
resp = client.list_executions(stateMachineArn=sm["stateMachineArn"])
resp.should_not.have.key("nextToken")
resp["executions"].should.have.length_of(100)
paginator = client.get_paginator("list_executions")
page_iterator = paginator.paginate(
stateMachineArn=sm["stateMachineArn"], maxResults=25
)
for page in page_iterator:
page["executions"].should.have.length_of(25)
with assert_raises(ClientError) as ex:
resp = client.list_executions(
stateMachineArn=sm["stateMachineArn"], maxResults=10
)
client.list_executions(
stateMachineArn=sm["stateMachineArn"],
maxResults=10,
statusFilter="ABORTED",
nextToken=resp["nextToken"],
)
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
ex.exception.response["Error"]["Message"].should.contain(
"Input inconsistent with page token"
)
with assert_raises(ClientError) as ex:
client.list_executions(
stateMachineArn=sm["stateMachineArn"], nextToken="invalid"
)
ex.exception.response["Error"]["Code"].should.equal("InvalidToken")
@mock_stepfunctions
@mock_sts
def test_state_machine_list_executions_when_none_exist():