Feature: Kinesis - list_shards() (#3752)
This commit is contained in:
parent
df05b608b0
commit
c642e8b4a7
@ -1,9 +1,7 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@ -22,7 +20,8 @@ from .exceptions import (
|
||||
UserNotConfirmedException,
|
||||
InvalidParameterException,
|
||||
)
|
||||
from .utils import create_id, check_secret_hash
|
||||
from .utils import create_id, check_secret_hash, PAGINATION_MODEL
|
||||
from moto.utilities.paginator import paginate
|
||||
|
||||
UserStatus = {
|
||||
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
|
||||
@ -31,43 +30,6 @@ UserStatus = {
|
||||
}
|
||||
|
||||
|
||||
def paginate(limit, start_arg="next_token", limit_arg="max_results"):
|
||||
"""Returns a limited result list, and an offset into list of remaining items
|
||||
|
||||
Takes the next_token, and max_results kwargs given to a function and handles
|
||||
the slicing of the results. The kwarg `next_token` is the offset into the
|
||||
list to begin slicing from. `max_results` is the size of the result required
|
||||
|
||||
If the max_results is not supplied then the `limit` parameter is used as a
|
||||
default
|
||||
|
||||
:param limit_arg: the name of argument in the decorated function that
|
||||
controls amount of items returned
|
||||
:param start_arg: the name of the argument in the decorated that provides
|
||||
the starting offset
|
||||
:param limit: A default maximum items to return
|
||||
:return: a tuple containing a list of items, and the offset into the list
|
||||
"""
|
||||
default_start = 0
|
||||
|
||||
def outer_wrapper(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start = int(
|
||||
default_start if kwargs.get(start_arg) is None else kwargs[start_arg]
|
||||
)
|
||||
lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg])
|
||||
stop = start + lim
|
||||
result = func(*args, **kwargs)
|
||||
limited_results = list(itertools.islice(result, start, stop))
|
||||
next_token = stop if stop < len(result) else None
|
||||
return limited_results, next_token
|
||||
|
||||
return wrapper
|
||||
|
||||
return outer_wrapper
|
||||
|
||||
|
||||
class CognitoIdpUserPool(BaseModel):
|
||||
def __init__(self, region, name, extended_config):
|
||||
self.region = region
|
||||
@ -414,9 +376,9 @@ class CognitoIdpBackend(BaseBackend):
|
||||
"MfaConfiguration": user_pool.mfa_config,
|
||||
}
|
||||
|
||||
@paginate(60)
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_user_pools(self, max_results=None, next_token=None):
|
||||
return self.user_pools.values()
|
||||
return list(self.user_pools.values())
|
||||
|
||||
def describe_user_pool(self, user_pool_id):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
@ -474,13 +436,13 @@ class CognitoIdpBackend(BaseBackend):
|
||||
user_pool.clients[user_pool_client.id] = user_pool_client
|
||||
return user_pool_client
|
||||
|
||||
@paginate(60)
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_user_pool_clients(self, user_pool_id, max_results=None, next_token=None):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
if not user_pool:
|
||||
raise ResourceNotFoundError(user_pool_id)
|
||||
|
||||
return user_pool.clients.values()
|
||||
return list(user_pool.clients.values())
|
||||
|
||||
def describe_user_pool_client(self, user_pool_id, client_id):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
@ -525,13 +487,13 @@ class CognitoIdpBackend(BaseBackend):
|
||||
user_pool.identity_providers[name] = identity_provider
|
||||
return identity_provider
|
||||
|
||||
@paginate(60)
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_identity_providers(self, user_pool_id, max_results=None, next_token=None):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
if not user_pool:
|
||||
raise ResourceNotFoundError(user_pool_id)
|
||||
|
||||
return user_pool.identity_providers.values()
|
||||
return list(user_pool.identity_providers.values())
|
||||
|
||||
def describe_identity_provider(self, user_pool_id, name):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
@ -684,13 +646,13 @@ class CognitoIdpBackend(BaseBackend):
|
||||
return user
|
||||
raise NotAuthorizedError("Invalid token")
|
||||
|
||||
@paginate(60, "pagination_token", "limit")
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_users(self, user_pool_id, pagination_token=None, limit=None):
|
||||
user_pool = self.user_pools.get(user_pool_id)
|
||||
if not user_pool:
|
||||
raise ResourceNotFoundError(user_pool_id)
|
||||
|
||||
return user_pool.users.values()
|
||||
return list(user_pool.users.values())
|
||||
|
||||
def admin_disable_user(self, user_pool_id, username):
|
||||
user = self.admin_get_user(user_pool_id, username)
|
||||
|
@ -57,7 +57,7 @@ class CognitoIdpResponse(BaseResponse):
|
||||
|
||||
def list_user_pools(self):
|
||||
max_results = self._get_param("MaxResults")
|
||||
next_token = self._get_param("NextToken", "0")
|
||||
next_token = self._get_param("NextToken")
|
||||
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
|
||||
max_results=max_results, next_token=next_token
|
||||
)
|
||||
@ -128,7 +128,7 @@ class CognitoIdpResponse(BaseResponse):
|
||||
def list_user_pool_clients(self):
|
||||
user_pool_id = self._get_param("UserPoolId")
|
||||
max_results = self._get_param("MaxResults")
|
||||
next_token = self._get_param("NextToken", "0")
|
||||
next_token = self._get_param("NextToken")
|
||||
user_pool_clients, next_token = cognitoidp_backends[
|
||||
self.region
|
||||
].list_user_pool_clients(
|
||||
@ -181,7 +181,7 @@ class CognitoIdpResponse(BaseResponse):
|
||||
def list_identity_providers(self):
|
||||
user_pool_id = self._get_param("UserPoolId")
|
||||
max_results = self._get_param("MaxResults")
|
||||
next_token = self._get_param("NextToken", "0")
|
||||
next_token = self._get_param("NextToken")
|
||||
identity_providers, next_token = cognitoidp_backends[
|
||||
self.region
|
||||
].list_identity_providers(
|
||||
|
@ -6,6 +6,34 @@ import hmac
|
||||
import base64
|
||||
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_user_pools": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 60,
|
||||
"page_ending_range_keys": ["arn"],
|
||||
},
|
||||
"list_user_pool_clients": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 60,
|
||||
"page_ending_range_keys": ["id"],
|
||||
},
|
||||
"list_identity_providers": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 60,
|
||||
"page_ending_range_keys": ["name"],
|
||||
},
|
||||
"list_users": {
|
||||
"input_token": "pagination_token",
|
||||
"limit_key": "limit",
|
||||
"limit_default": 60,
|
||||
"page_ending_range_keys": ["id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_id():
|
||||
size = 26
|
||||
chars = list(range(10)) + list(string.ascii_lowercase)
|
||||
|
@ -162,3 +162,11 @@ class InvalidNextTokenException(JsonRESTError):
|
||||
super(InvalidNextTokenException, self).__init__(
|
||||
"InvalidNextTokenException", "The nextToken provided is invalid"
|
||||
)
|
||||
|
||||
|
||||
class InvalidToken(AWSError):
|
||||
TYPE = "InvalidToken"
|
||||
STATUS = 400
|
||||
|
||||
def __init__(self, message="Invalid token"):
|
||||
super(InvalidToken, self).__init__("Invalid Token: {}".format(message))
|
||||
|
@ -13,6 +13,7 @@ from boto3 import Session
|
||||
from moto.core import BaseBackend, BaseModel, CloudFormationModel
|
||||
from moto.core.utils import unix_time
|
||||
from moto.core import ACCOUNT_ID
|
||||
from moto.utilities.paginator import paginate
|
||||
from .exceptions import (
|
||||
StreamNotFoundError,
|
||||
ShardNotFoundError,
|
||||
@ -24,6 +25,7 @@ from .utils import (
|
||||
compose_shard_iterator,
|
||||
compose_new_shard_iterator,
|
||||
decompose_shard_iterator,
|
||||
PAGINATION_MODEL,
|
||||
)
|
||||
|
||||
|
||||
@ -489,6 +491,12 @@ class KinesisBackend(BaseBackend):
|
||||
record.partition_key, record.data, record.explicit_hash_key
|
||||
)
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_shards(self, stream_name, limit=None, next_token=None):
|
||||
stream = self.describe_stream(stream_name)
|
||||
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
|
||||
return [shard.to_json() for shard in shards]
|
||||
|
||||
def increase_stream_retention_period(self, stream_name, retention_period_hours):
|
||||
stream = self.describe_stream(stream_name)
|
||||
if (
|
||||
|
@ -138,6 +138,20 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
return ""
|
||||
|
||||
def list_shards(self):
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
next_token = self.parameters.get("NextToken")
|
||||
start_id = self.parameters.get("ExclusiveStartShardId") # noqa
|
||||
max = self.parameters.get("MaxResults", 10000)
|
||||
start_timestamp = self.parameters.get("StreamCreationTimestamp") # noqa
|
||||
shards, token = self.kinesis_backend.list_shards(
|
||||
stream_name=stream_name, limit=max, next_token=next_token
|
||||
)
|
||||
res = {"Shards": shards}
|
||||
if token:
|
||||
res["NextToken"] = token
|
||||
return json.dumps(res)
|
||||
|
||||
def increase_stream_retention_period(self):
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
retention_period_hours = self.parameters.get("RetentionPeriodHours")
|
||||
|
@ -1,17 +1,21 @@
|
||||
import sys
|
||||
import base64
|
||||
|
||||
from .exceptions import InvalidArgumentError
|
||||
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
encode_method = base64.encodestring
|
||||
decode_method = base64.decodestring
|
||||
elif sys.version_info[0] == 3:
|
||||
encode_method = base64.encodebytes
|
||||
decode_method = base64.decodebytes
|
||||
else:
|
||||
raise Exception("Python version is not supported")
|
||||
encode_method = base64.encodebytes
|
||||
decode_method = base64.decodebytes
|
||||
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_shards": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "limit",
|
||||
"limit_default": 10000,
|
||||
"page_ending_range_keys": ["ShardId"],
|
||||
"fail_on_invalid_token": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def compose_new_shard_iterator(
|
||||
|
@ -5,6 +5,7 @@ from boto3 import Session
|
||||
from moto import core as moto_core
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.core.utils import unix_time_millis
|
||||
from moto.utilities.paginator import paginate
|
||||
from moto.logs.metric_filters import MetricFilters
|
||||
from moto.logs.exceptions import (
|
||||
ResourceNotFoundException,
|
||||
@ -12,6 +13,7 @@ from moto.logs.exceptions import (
|
||||
InvalidParameterException,
|
||||
LimitExceededException,
|
||||
)
|
||||
from .utils import PAGINATION_MODEL
|
||||
|
||||
MAX_RESOURCE_POLICIES_PER_REGION = 10
|
||||
|
||||
@ -306,11 +308,11 @@ class LogGroup(BaseModel):
|
||||
def describe_log_streams(
|
||||
self,
|
||||
descending,
|
||||
limit,
|
||||
log_group_name,
|
||||
log_stream_name_prefix,
|
||||
next_token,
|
||||
order_by,
|
||||
next_token=None,
|
||||
limit=None,
|
||||
):
|
||||
# responses only log_stream_name, creation_time, arn, stored_bytes when no events are stored.
|
||||
|
||||
@ -323,7 +325,7 @@ class LogGroup(BaseModel):
|
||||
def sorter(item):
|
||||
return (
|
||||
item[0]
|
||||
if order_by == "logStreamName"
|
||||
if order_by == "LogStreamName"
|
||||
else item[1].get("lastEventTimestamp", 0)
|
||||
)
|
||||
|
||||
@ -582,13 +584,8 @@ class LogsBackend(BaseBackend):
|
||||
raise ResourceNotFoundException()
|
||||
del self.groups[log_group_name]
|
||||
|
||||
def describe_log_groups(self, limit, log_group_name_prefix, next_token):
|
||||
if limit > 50:
|
||||
raise InvalidParameterException(
|
||||
constraint="Member must have value less than or equal to 50",
|
||||
parameter="limit",
|
||||
value=limit,
|
||||
)
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def describe_log_groups(self, log_group_name_prefix, limit=None, next_token=None):
|
||||
if log_group_name_prefix is None:
|
||||
log_group_name_prefix = ""
|
||||
|
||||
@ -599,33 +596,7 @@ class LogsBackend(BaseBackend):
|
||||
]
|
||||
groups = sorted(groups, key=lambda x: x["logGroupName"])
|
||||
|
||||
index_start = 0
|
||||
if next_token:
|
||||
try:
|
||||
index_start = (
|
||||
next(
|
||||
index
|
||||
for (index, d) in enumerate(groups)
|
||||
if d["logGroupName"] == next_token
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
except StopIteration:
|
||||
index_start = 0
|
||||
# AWS returns an empty list if it receives an invalid token.
|
||||
groups = []
|
||||
|
||||
index_end = index_start + limit
|
||||
if index_end > len(groups):
|
||||
index_end = len(groups)
|
||||
|
||||
groups_page = groups[index_start:index_end]
|
||||
|
||||
next_token = None
|
||||
if groups_page and index_end < len(groups):
|
||||
next_token = groups_page[-1]["logGroupName"]
|
||||
|
||||
return groups_page, next_token
|
||||
return groups
|
||||
|
||||
def create_log_stream(self, log_group_name, log_stream_name):
|
||||
if log_group_name not in self.groups:
|
||||
@ -668,12 +639,12 @@ class LogsBackend(BaseBackend):
|
||||
)
|
||||
log_group = self.groups[log_group_name]
|
||||
return log_group.describe_log_streams(
|
||||
descending,
|
||||
limit,
|
||||
log_group_name,
|
||||
log_stream_name_prefix,
|
||||
next_token,
|
||||
order_by,
|
||||
descending=descending,
|
||||
limit=limit,
|
||||
log_group_name=log_group_name,
|
||||
log_stream_name_prefix=log_stream_name_prefix,
|
||||
next_token=next_token,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
def put_log_events(
|
||||
|
@ -162,8 +162,16 @@ class LogsResponse(BaseResponse):
|
||||
log_group_name_prefix = self._get_param("logGroupNamePrefix")
|
||||
next_token = self._get_param("nextToken")
|
||||
limit = self._get_param("limit", 50)
|
||||
if limit > 50:
|
||||
raise InvalidParameterException(
|
||||
constraint="Member must have value less than or equal to 50",
|
||||
parameter="limit",
|
||||
value=limit,
|
||||
)
|
||||
groups, next_token = self.logs_backend.describe_log_groups(
|
||||
limit, log_group_name_prefix, next_token
|
||||
limit=limit,
|
||||
log_group_name_prefix=log_group_name_prefix,
|
||||
next_token=next_token,
|
||||
)
|
||||
result = {"logGroups": groups}
|
||||
if next_token:
|
||||
|
15
moto/logs/utils.py
Normal file
15
moto/logs/utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
PAGINATION_MODEL = {
|
||||
"describe_log_groups": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "limit",
|
||||
"limit_default": 50,
|
||||
"page_ending_range_keys": ["arn"],
|
||||
"fail_on_invalid_token": False,
|
||||
},
|
||||
"describe_log_streams": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "limit",
|
||||
"limit_default": 50,
|
||||
"page_ending_range_keys": ["arn"],
|
||||
},
|
||||
}
|
@ -17,8 +17,9 @@ from .exceptions import (
|
||||
ResourceNotFound,
|
||||
StateMachineDoesNotExist,
|
||||
)
|
||||
from .utils import paginate, api_to_cfn_tags, cfn_to_api_tags
|
||||
from .utils import api_to_cfn_tags, cfn_to_api_tags, PAGINATION_MODEL
|
||||
from moto import settings
|
||||
from moto.utilities.paginator import paginate
|
||||
|
||||
|
||||
class StateMachine(CloudFormationModel):
|
||||
@ -457,7 +458,7 @@ class StepFunctionBackend(BaseBackend):
|
||||
self.state_machines.append(state_machine)
|
||||
return state_machine
|
||||
|
||||
@paginate
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_state_machines(self):
|
||||
state_machines = sorted(self.state_machines, key=lambda x: x.creation_date)
|
||||
return state_machines
|
||||
@ -501,7 +502,7 @@ class StepFunctionBackend(BaseBackend):
|
||||
state_machine = self._get_state_machine_for_execution(execution_arn)
|
||||
return state_machine.stop_execution(execution_arn)
|
||||
|
||||
@paginate
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_executions(self, state_machine_arn, status_filter=None):
|
||||
executions = self.describe_state_machine(state_machine_arn).executions
|
||||
|
||||
|
@ -1,10 +1,3 @@
|
||||
from functools import wraps
|
||||
|
||||
from botocore.paginate import TokenDecoder, TokenEncoder
|
||||
from functools import reduce
|
||||
|
||||
from .exceptions import InvalidToken
|
||||
|
||||
PAGINATION_MODEL = {
|
||||
"list_executions": {
|
||||
"input_token": "next_token",
|
||||
@ -21,123 +14,6 @@ PAGINATION_MODEL = {
|
||||
}
|
||||
|
||||
|
||||
def paginate(original_function=None, pagination_model=None):
|
||||
def pagination_decorator(func):
|
||||
@wraps(func)
|
||||
def pagination_wrapper(*args, **kwargs):
|
||||
method = func.__name__
|
||||
model = pagination_model or PAGINATION_MODEL
|
||||
pagination_config = model.get(method)
|
||||
if not pagination_config:
|
||||
raise ValueError(
|
||||
"No pagination config for backend method: {}".format(method)
|
||||
)
|
||||
# We pop the pagination arguments, so the remaining kwargs (if any)
|
||||
# can be used to compute the optional parameters checksum.
|
||||
input_token = kwargs.pop(pagination_config.get("input_token"), None)
|
||||
limit = kwargs.pop(pagination_config.get("limit_key"), None)
|
||||
paginator = Paginator(
|
||||
max_results=limit,
|
||||
max_results_default=pagination_config.get("limit_default"),
|
||||
starting_token=input_token,
|
||||
page_ending_range_keys=pagination_config.get("page_ending_range_keys"),
|
||||
param_values_to_check=kwargs,
|
||||
)
|
||||
results = func(*args, **kwargs)
|
||||
return paginator.paginate(results)
|
||||
|
||||
return pagination_wrapper
|
||||
|
||||
if original_function:
|
||||
return pagination_decorator(original_function)
|
||||
|
||||
return pagination_decorator
|
||||
|
||||
|
||||
class Paginator(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_results=None,
|
||||
max_results_default=None,
|
||||
starting_token=None,
|
||||
page_ending_range_keys=None,
|
||||
param_values_to_check=None,
|
||||
):
|
||||
self._max_results = max_results if max_results else max_results_default
|
||||
self._starting_token = starting_token
|
||||
self._page_ending_range_keys = page_ending_range_keys
|
||||
self._param_values_to_check = param_values_to_check
|
||||
self._token_encoder = TokenEncoder()
|
||||
self._token_decoder = TokenDecoder()
|
||||
self._param_checksum = self._calculate_parameter_checksum()
|
||||
self._parsed_token = self._parse_starting_token()
|
||||
|
||||
def _parse_starting_token(self):
|
||||
if self._starting_token is None:
|
||||
return None
|
||||
# The starting token is a dict passed as a base64 encoded string.
|
||||
next_token = self._starting_token
|
||||
try:
|
||||
next_token = self._token_decoder.decode(next_token)
|
||||
except (ValueError, TypeError):
|
||||
raise InvalidToken("Invalid token")
|
||||
if next_token.get("parameterChecksum") != self._param_checksum:
|
||||
raise InvalidToken(
|
||||
"Input inconsistent with page token: {}".format(str(next_token))
|
||||
)
|
||||
return next_token
|
||||
|
||||
def _calculate_parameter_checksum(self):
|
||||
if not self._param_values_to_check:
|
||||
return None
|
||||
return reduce(
|
||||
lambda x, y: x ^ y,
|
||||
[hash(item) for item in self._param_values_to_check.items()],
|
||||
)
|
||||
|
||||
def _check_predicate(self, item):
|
||||
page_ending_range_key = self._parsed_token["pageEndingRangeKey"]
|
||||
predicate_values = page_ending_range_key.split("|")
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
if not getattr(item, attr, None) == predicate_values[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _build_next_token(self, next_item):
|
||||
token_dict = {}
|
||||
if self._param_checksum:
|
||||
token_dict["parameterChecksum"] = self._param_checksum
|
||||
range_keys = []
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
range_keys.append(getattr(next_item, attr))
|
||||
token_dict["pageEndingRangeKey"] = "|".join(range_keys)
|
||||
return TokenEncoder().encode(token_dict)
|
||||
|
||||
def paginate(self, results):
|
||||
index_start = 0
|
||||
if self._starting_token:
|
||||
try:
|
||||
index_start = next(
|
||||
index
|
||||
for (index, result) in enumerate(results)
|
||||
if self._check_predicate(result)
|
||||
)
|
||||
except StopIteration:
|
||||
raise InvalidToken("Resource not found!")
|
||||
|
||||
index_end = index_start + self._max_results
|
||||
if index_end > len(results):
|
||||
index_end = len(results)
|
||||
|
||||
results_page = results[index_start:index_end]
|
||||
|
||||
next_token = None
|
||||
if results_page and index_end < len(results):
|
||||
page_ending_result = results[index_end]
|
||||
next_token = self._build_next_token(page_ending_result)
|
||||
return results_page, next_token
|
||||
|
||||
|
||||
def cfn_to_api_tags(cfn_tags_entry):
|
||||
api_tags = [{k.lower(): v for k, v in d.items()} for d in cfn_tags_entry]
|
||||
return api_tags
|
||||
|
139
moto/utilities/paginator.py
Normal file
139
moto/utilities/paginator.py
Normal file
@ -0,0 +1,139 @@
|
||||
from functools import wraps
|
||||
|
||||
from botocore.paginate import TokenDecoder, TokenEncoder
|
||||
from six.moves import reduce
|
||||
|
||||
from moto.core.exceptions import InvalidToken
|
||||
|
||||
|
||||
def paginate(pagination_model, original_function=None):
|
||||
def pagination_decorator(func):
|
||||
@wraps(func)
|
||||
def pagination_wrapper(*args, **kwargs):
|
||||
|
||||
method = func.__name__
|
||||
model = pagination_model
|
||||
pagination_config = model.get(method)
|
||||
if not pagination_config:
|
||||
raise ValueError(
|
||||
"No pagination config for backend method: {}".format(method)
|
||||
)
|
||||
# We pop the pagination arguments, so the remaining kwargs (if any)
|
||||
# can be used to compute the optional parameters checksum.
|
||||
input_token = kwargs.pop(pagination_config.get("input_token"), None)
|
||||
limit = kwargs.pop(pagination_config.get("limit_key"), None)
|
||||
fail_on_invalid_token = pagination_config.get("fail_on_invalid_token", True)
|
||||
paginator = Paginator(
|
||||
max_results=limit,
|
||||
max_results_default=pagination_config.get("limit_default"),
|
||||
starting_token=input_token,
|
||||
page_ending_range_keys=pagination_config.get("page_ending_range_keys"),
|
||||
param_values_to_check=kwargs,
|
||||
fail_on_invalid_token=fail_on_invalid_token,
|
||||
)
|
||||
results = func(*args, **kwargs)
|
||||
return paginator.paginate(results)
|
||||
|
||||
return pagination_wrapper
|
||||
|
||||
if original_function:
|
||||
return pagination_decorator(original_function)
|
||||
|
||||
return pagination_decorator
|
||||
|
||||
|
||||
class Paginator(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_results=None,
|
||||
max_results_default=None,
|
||||
starting_token=None,
|
||||
page_ending_range_keys=None,
|
||||
param_values_to_check=None,
|
||||
fail_on_invalid_token=True,
|
||||
):
|
||||
self._max_results = max_results if max_results else max_results_default
|
||||
self._starting_token = starting_token
|
||||
self._page_ending_range_keys = page_ending_range_keys
|
||||
self._param_values_to_check = param_values_to_check
|
||||
self._fail_on_invalid_token = fail_on_invalid_token
|
||||
self._token_encoder = TokenEncoder()
|
||||
self._token_decoder = TokenDecoder()
|
||||
self._param_checksum = self._calculate_parameter_checksum()
|
||||
self._parsed_token = self._parse_starting_token()
|
||||
|
||||
def _parse_starting_token(self):
|
||||
if self._starting_token is None:
|
||||
return None
|
||||
# The starting token is a dict passed as a base64 encoded string.
|
||||
next_token = self._starting_token
|
||||
try:
|
||||
next_token = self._token_decoder.decode(next_token)
|
||||
except (ValueError, TypeError, UnicodeDecodeError):
|
||||
if self._fail_on_invalid_token:
|
||||
raise InvalidToken("Invalid token")
|
||||
return None
|
||||
if next_token.get("parameterChecksum") != self._param_checksum:
|
||||
raise InvalidToken(
|
||||
"Input inconsistent with page token: {}".format(str(next_token))
|
||||
)
|
||||
return next_token
|
||||
|
||||
def _calculate_parameter_checksum(self):
|
||||
if not self._param_values_to_check:
|
||||
return None
|
||||
return reduce(
|
||||
lambda x, y: x ^ y,
|
||||
[hash(item) for item in self._param_values_to_check.items()],
|
||||
)
|
||||
|
||||
def _check_predicate(self, item):
|
||||
if self._parsed_token is None:
|
||||
return False
|
||||
page_ending_range_key = self._parsed_token["pageEndingRangeKey"]
|
||||
predicate_values = page_ending_range_key.split("|")
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
curr_val = item[attr] if type(item) == dict else getattr(item, attr, None)
|
||||
if not curr_val == predicate_values[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _build_next_token(self, next_item):
|
||||
token_dict = {}
|
||||
if self._param_checksum:
|
||||
token_dict["parameterChecksum"] = self._param_checksum
|
||||
range_keys = []
|
||||
for (index, attr) in enumerate(self._page_ending_range_keys):
|
||||
if type(next_item) == dict:
|
||||
range_keys.append(next_item[attr])
|
||||
else:
|
||||
range_keys.append(getattr(next_item, attr))
|
||||
token_dict["pageEndingRangeKey"] = "|".join(range_keys)
|
||||
return self._token_encoder.encode(token_dict)
|
||||
|
||||
def paginate(self, results):
|
||||
index_start = 0
|
||||
if self._starting_token:
|
||||
try:
|
||||
index_start = next(
|
||||
index
|
||||
for (index, result) in enumerate(results)
|
||||
if self._check_predicate(result)
|
||||
)
|
||||
except StopIteration:
|
||||
if self._fail_on_invalid_token:
|
||||
raise InvalidToken("Resource not found!")
|
||||
else:
|
||||
return [], None
|
||||
|
||||
index_end = index_start + self._max_results
|
||||
if index_end > len(results):
|
||||
index_end = len(results)
|
||||
|
||||
results_page = results[index_start:index_end]
|
||||
|
||||
next_token = None
|
||||
if results_page and index_end < len(results):
|
||||
page_ending_result = results[index_end]
|
||||
next_token = self._build_next_token(page_ending_result)
|
||||
return results_page, next_token
|
140
tests/test_kinesis/test_kinesis_boto3.py
Normal file
140
tests/test_kinesis/test_kinesis_boto3.py
Normal file
@ -0,0 +1,140 @@
|
||||
import boto3
|
||||
|
||||
from moto import mock_kinesis
|
||||
|
||||
import sure # noqa
|
||||
|
||||
|
||||
@mock_kinesis
|
||||
def test_split_shard():
|
||||
conn = boto3.client("kinesis", region_name="us-west-2")
|
||||
stream_name = "my_stream"
|
||||
|
||||
conn.create_stream(StreamName=stream_name, ShardCount=2)
|
||||
|
||||
# Create some data
|
||||
for index in range(1, 100):
|
||||
conn.put_record(
|
||||
StreamName=stream_name, Data="data:" + str(index), PartitionKey=str(index)
|
||||
)
|
||||
|
||||
stream_response = conn.describe_stream(StreamName=stream_name)
|
||||
|
||||
stream = stream_response["StreamDescription"]
|
||||
shards = stream["Shards"]
|
||||
shards.should.have.length_of(2)
|
||||
|
||||
shard_range = shards[0]["HashKeyRange"]
|
||||
new_starting_hash = (
|
||||
int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"])
|
||||
) // 2
|
||||
conn.split_shard(
|
||||
StreamName=stream_name,
|
||||
ShardToSplit=shards[0]["ShardId"],
|
||||
NewStartingHashKey=str(new_starting_hash),
|
||||
)
|
||||
|
||||
stream_response = conn.describe_stream(StreamName=stream_name)
|
||||
|
||||
stream = stream_response["StreamDescription"]
|
||||
shards = stream["Shards"]
|
||||
shards.should.have.length_of(3)
|
||||
|
||||
shard_range = shards[2]["HashKeyRange"]
|
||||
new_starting_hash = (
|
||||
int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"])
|
||||
) // 2
|
||||
conn.split_shard(
|
||||
StreamName=stream_name,
|
||||
ShardToSplit=shards[2]["ShardId"],
|
||||
NewStartingHashKey=str(new_starting_hash),
|
||||
)
|
||||
|
||||
stream_response = conn.describe_stream(StreamName=stream_name)
|
||||
|
||||
stream = stream_response["StreamDescription"]
|
||||
shards = stream["Shards"]
|
||||
shards.should.have.length_of(4)
|
||||
|
||||
|
||||
@mock_kinesis
|
||||
def test_list_shards():
|
||||
conn = boto3.client("kinesis", region_name="us-west-2")
|
||||
stream_name = "my_stream"
|
||||
|
||||
conn.create_stream(StreamName=stream_name, ShardCount=2)
|
||||
|
||||
# Create some data
|
||||
for index in range(1, 100):
|
||||
conn.put_record(
|
||||
StreamName=stream_name, Data="data:" + str(index), PartitionKey=str(index)
|
||||
)
|
||||
|
||||
shard_list = conn.list_shards(StreamName=stream_name)["Shards"]
|
||||
shard_list.should.have.length_of(2)
|
||||
# Verify IDs
|
||||
[s["ShardId"] for s in shard_list].should.equal(
|
||||
["shardId-000000000000", "shardId-000000000001"]
|
||||
)
|
||||
# Verify hash range
|
||||
for shard in shard_list:
|
||||
shard.should.have.key("HashKeyRange")
|
||||
shard["HashKeyRange"].should.have.key("StartingHashKey")
|
||||
shard["HashKeyRange"].should.have.key("EndingHashKey")
|
||||
shard_list[0]["HashKeyRange"]["EndingHashKey"].should.equal(
|
||||
shard_list[1]["HashKeyRange"]["StartingHashKey"]
|
||||
)
|
||||
# Verify sequence numbers
|
||||
for shard in shard_list:
|
||||
shard.should.have.key("SequenceNumberRange")
|
||||
shard["SequenceNumberRange"].should.have.key("StartingSequenceNumber")
|
||||
|
||||
|
||||
@mock_kinesis
|
||||
def test_list_shards_paging():
|
||||
client = boto3.client("kinesis", region_name="us-west-2")
|
||||
stream_name = "my_stream"
|
||||
client.create_stream(StreamName=stream_name, ShardCount=10)
|
||||
|
||||
# Get shard 1-10
|
||||
shard_list = client.list_shards(StreamName=stream_name)
|
||||
shard_list["Shards"].should.have.length_of(10)
|
||||
shard_list.should_not.have.key("NextToken")
|
||||
|
||||
# Get shard 1-4
|
||||
resp = client.list_shards(StreamName=stream_name, MaxResults=4)
|
||||
resp["Shards"].should.have.length_of(4)
|
||||
[s["ShardId"] for s in resp["Shards"]].should.equal(
|
||||
[
|
||||
"shardId-000000000000",
|
||||
"shardId-000000000001",
|
||||
"shardId-000000000002",
|
||||
"shardId-000000000003",
|
||||
]
|
||||
)
|
||||
resp.should.have.key("NextToken")
|
||||
|
||||
# Get shard 4-8
|
||||
resp = client.list_shards(
|
||||
StreamName=stream_name, MaxResults=4, NextToken=str(resp["NextToken"])
|
||||
)
|
||||
resp["Shards"].should.have.length_of(4)
|
||||
[s["ShardId"] for s in resp["Shards"]].should.equal(
|
||||
[
|
||||
"shardId-000000000004",
|
||||
"shardId-000000000005",
|
||||
"shardId-000000000006",
|
||||
"shardId-000000000007",
|
||||
]
|
||||
)
|
||||
resp.should.have.key("NextToken")
|
||||
|
||||
# Get shard 8-10
|
||||
resp = client.list_shards(
|
||||
StreamName=stream_name, MaxResults=4, NextToken=str(resp["NextToken"])
|
||||
)
|
||||
resp["Shards"].should.have.length_of(2)
|
||||
[s["ShardId"] for s in resp["Shards"]].should.equal(
|
||||
["shardId-000000000008", "shardId-000000000009"]
|
||||
)
|
||||
resp.should_not.have.key("NextToken")
|
@ -368,8 +368,6 @@ def test_exceptions():
|
||||
with pytest.raises(ClientError):
|
||||
conn.create_log_group(logGroupName=log_group_name)
|
||||
|
||||
# descrine_log_groups is not implemented yet
|
||||
|
||||
conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name)
|
||||
with pytest.raises(ClientError):
|
||||
conn.create_log_stream(
|
||||
@ -995,11 +993,11 @@ def test_describe_log_groups_paging():
|
||||
|
||||
resp = client.describe_log_groups(limit=2)
|
||||
resp["logGroups"].should.have.length_of(2)
|
||||
resp["nextToken"].should.equal("/aws/lambda/FileMonitoring")
|
||||
resp.should.have.key("nextToken")
|
||||
|
||||
resp = client.describe_log_groups(nextToken=resp["nextToken"], limit=1)
|
||||
resp["logGroups"].should.have.length_of(1)
|
||||
resp["nextToken"].should.equal("/aws/lambda/fileAvailable")
|
||||
resp.should.have.key("nextToken")
|
||||
|
||||
resp = client.describe_log_groups(nextToken=resp["nextToken"])
|
||||
resp["logGroups"].should.have.length_of(1)
|
||||
@ -1011,6 +1009,51 @@ def test_describe_log_groups_paging():
|
||||
resp.should_not.have.key("nextToken")
|
||||
|
||||
|
||||
@mock_logs
|
||||
def test_describe_log_streams_simple_paging():
|
||||
client = boto3.client("logs", "us-east-1")
|
||||
|
||||
group_name = "/aws/lambda/lowercase-dev"
|
||||
|
||||
client.create_log_group(logGroupName=group_name)
|
||||
stream_names = ["stream" + str(i) for i in range(0, 10)]
|
||||
for name in stream_names:
|
||||
client.create_log_stream(logGroupName=group_name, logStreamName=name)
|
||||
|
||||
# Get stream 1-10
|
||||
resp = client.describe_log_streams(logGroupName=group_name)
|
||||
resp["logStreams"].should.have.length_of(10)
|
||||
resp.should_not.have.key("nextToken")
|
||||
|
||||
# Get stream 1-4
|
||||
resp = client.describe_log_streams(logGroupName=group_name, limit=4)
|
||||
resp["logStreams"].should.have.length_of(4)
|
||||
[l["logStreamName"] for l in resp["logStreams"]].should.equal(
|
||||
["stream0", "stream1", "stream2", "stream3"]
|
||||
)
|
||||
resp.should.have.key("nextToken")
|
||||
|
||||
# Get stream 4-8
|
||||
resp = client.describe_log_streams(
|
||||
logGroupName=group_name, limit=4, nextToken=str(resp["nextToken"])
|
||||
)
|
||||
resp["logStreams"].should.have.length_of(4)
|
||||
[l["logStreamName"] for l in resp["logStreams"]].should.equal(
|
||||
["stream4", "stream5", "stream6", "stream7"]
|
||||
)
|
||||
resp.should.have.key("nextToken")
|
||||
|
||||
# Get stream 8-10
|
||||
resp = client.describe_log_streams(
|
||||
logGroupName=group_name, limit=4, nextToken=str(resp["nextToken"])
|
||||
)
|
||||
resp["logStreams"].should.have.length_of(2)
|
||||
[l["logStreamName"] for l in resp["logStreams"]].should.equal(
|
||||
["stream8", "stream9"]
|
||||
)
|
||||
resp.should_not.have.key("nextToken")
|
||||
|
||||
|
||||
@mock_logs
|
||||
def test_describe_log_streams_paging():
|
||||
client = boto3.client("logs", "us-east-1")
|
||||
|
@ -226,6 +226,7 @@ def test_state_machine_list_returns_created_state_machines():
|
||||
|
||||
|
||||
@mock_stepfunctions
|
||||
@mock_sts
|
||||
def test_state_machine_list_pagination():
|
||||
client = boto3.client("stepfunctions", region_name=region)
|
||||
for i in range(25):
|
||||
|
Loading…
Reference in New Issue
Block a user