Improved support for Kinesis (#4795)

Bert Blommers 2022-01-26 18:41:04 -01:00 committed by GitHub
parent a2467b7c3f
commit 76605e30a0
11 changed files with 903 additions and 121 deletions


@@ -3161,36 +3161,36 @@
## kinesis
<details>
<summary>58% implemented</summary>
<summary>89% implemented</summary>
- [X] add_tags_to_stream
- [X] create_stream
- [X] decrease_stream_retention_period
- [X] delete_stream
- [ ] deregister_stream_consumer
- [X] deregister_stream_consumer
- [ ] describe_limits
- [X] describe_stream
- [ ] describe_stream_consumer
- [X] describe_stream_consumer
- [X] describe_stream_summary
- [ ] disable_enhanced_monitoring
- [ ] enable_enhanced_monitoring
- [X] disable_enhanced_monitoring
- [X] enable_enhanced_monitoring
- [X] get_records
- [X] get_shard_iterator
- [X] increase_stream_retention_period
- [X] list_shards
- [ ] list_stream_consumers
- [X] list_stream_consumers
- [X] list_streams
- [X] list_tags_for_stream
- [X] merge_shards
- [X] put_record
- [X] put_records
- [ ] register_stream_consumer
- [X] register_stream_consumer
- [X] remove_tags_from_stream
- [X] split_shard
- [ ] start_stream_encryption
- [ ] stop_stream_encryption
- [X] start_stream_encryption
- [X] stop_stream_encryption
- [ ] subscribe_to_shard
- [ ] update_shard_count
- [X] update_shard_count
- [ ] update_stream_mode
</details>


@@ -29,29 +29,33 @@ kinesis
- [X] create_stream
- [X] decrease_stream_retention_period
- [X] delete_stream
- [ ] deregister_stream_consumer
- [X] deregister_stream_consumer
- [ ] describe_limits
- [X] describe_stream
- [ ] describe_stream_consumer
- [X] describe_stream_consumer
- [X] describe_stream_summary
- [ ] disable_enhanced_monitoring
- [ ] enable_enhanced_monitoring
- [X] disable_enhanced_monitoring
- [X] enable_enhanced_monitoring
- [X] get_records
- [X] get_shard_iterator
- [X] increase_stream_retention_period
- [X] list_shards
- [ ] list_stream_consumers
- [X] list_stream_consumers
Pagination is not yet implemented
- [X] list_streams
- [X] list_tags_for_stream
- [X] merge_shards
- [X] put_record
- [X] put_records
- [ ] register_stream_consumer
- [X] register_stream_consumer
- [X] remove_tags_from_stream
- [X] split_shard
- [ ] start_stream_encryption
- [ ] stop_stream_encryption
- [X] start_stream_encryption
- [X] stop_stream_encryption
- [ ] subscribe_to_shard
- [ ] update_shard_count
- [X] update_shard_count
- [ ] update_stream_mode


@@ -33,6 +33,11 @@ class ShardNotFoundError(ResourceNotFoundError):
)
class ConsumerNotFound(ResourceNotFoundError):
def __init__(self, consumer):
super().__init__(f"Consumer {consumer}, account {ACCOUNT_ID} not found.")
class InvalidArgumentError(BadRequest):
def __init__(self, message):
super().__init__()
@@ -41,6 +46,27 @@ class InvalidArgumentError(BadRequest):
)
class InvalidRetentionPeriod(InvalidArgumentError):
def __init__(self, hours, too_short):
if too_short:
msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short."
else:
msg = f"Maximum allowed retention period is 8760 hours. Requested retention period ({hours} hours) is too long."
super().__init__(msg)
class InvalidDecreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API."
super().__init__(msg)
class InvalidIncreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API."
super().__init__(msg)
class ValidationException(BadRequest):
def __init__(self, value, position, regex_to_match):
super().__init__()
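For orientation, a minimal sketch of how the new retention exceptions surface to a boto3 caller; the stream name and region are illustrative, not part of the commit:

```python
import boto3
from botocore.exceptions import ClientError
from moto import mock_kinesis

@mock_kinesis
def show_retention_error():
    client = boto3.client("kinesis", region_name="us-east-1")
    client.create_stream(StreamName="retention-demo", ShardCount=1)
    try:
        # 20 hours is below the 24-hour minimum, so moto raises
        # InvalidRetentionPeriod(hours=20, too_short=True)
        client.increase_stream_retention_period(
            StreamName="retention-demo", RetentionPeriodHours=20
        )
    except ClientError as err:
        # "Minimum allowed retention period is 24 hours. Requested retention
        # period (20 hours) is too short."
        print(err.response["Error"]["Message"])

show_retention_error()
```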


@@ -11,11 +11,15 @@ from moto.core.utils import unix_time, BackendDict
from moto.core import ACCOUNT_ID
from moto.utilities.paginator import paginate
from .exceptions import (
ConsumerNotFound,
StreamNotFoundError,
ShardNotFoundError,
ResourceInUseError,
ResourceNotFoundError,
InvalidArgumentError,
InvalidRetentionPeriod,
InvalidDecreaseRetention,
InvalidIncreaseRetention,
ValidationException,
)
from .utils import (
@@ -26,6 +30,26 @@
)
class Consumer(BaseModel):
def __init__(self, consumer_name, region_name, stream_arn):
self.consumer_name = consumer_name
self.created = unix_time()
self.stream_arn = stream_arn
stream_name = stream_arn.split("/")[-1]
self.consumer_arn = f"arn:aws:kinesis:{region_name}:{ACCOUNT_ID}:stream/{stream_name}/consumer/{consumer_name}"
def to_json(self, include_stream_arn=False):
resp = {
"ConsumerName": self.consumer_name,
"ConsumerARN": self.consumer_arn,
"ConsumerStatus": "ACTIVE",
"ConsumerCreationTimestamp": self.created,
}
if include_stream_arn:
resp["StreamARN"] = self.stream_arn
return resp
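A minimal sketch of the consumer round trip that this model backs, via boto3; stream and consumer names are illustrative:

```python
import boto3
from moto import mock_kinesis

@mock_kinesis
def consumer_roundtrip():
    client = boto3.client("kinesis", region_name="eu-west-1")
    client.create_stream(StreamName="demo", ShardCount=1)
    stream_arn = client.describe_stream(StreamName="demo")[
        "StreamDescription"
    ]["StreamARN"]
    consumer = client.register_stream_consumer(
        StreamARN=stream_arn, ConsumerName="c1"
    )["Consumer"]
    # The consumer ARN is derived from the stream ARN, as in Consumer.__init__
    print(consumer["ConsumerARN"])

consumer_roundtrip()
```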
class Record(BaseModel):
def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
self.partition_key = partition_key
@@ -45,13 +69,16 @@
class Shard(BaseModel):
def __init__(self, shard_id, starting_hash, ending_hash, parent=None):
def __init__(
self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None
):
self._shard_id = shard_id
self.starting_hash = starting_hash
self.ending_hash = ending_hash
self.records = OrderedDict()
self.is_open = True
self.parent = parent
self.adjacent_parent = adjacent_parent
@property
def shard_id(self):
@@ -127,6 +154,8 @@ class Shard(BaseModel):
}
if self.parent:
response["ParentShardId"] = self.parent
if self.adjacent_parent:
response["AdjacentParentShardId"] = self.adjacent_parent
if not self.is_open:
response["SequenceNumberRange"]["EndingSequenceNumber"] = str(
self.get_max_sequence_number()
@@ -137,36 +166,175 @@ class Shard(BaseModel):
class Stream(CloudFormationModel):
def __init__(self, stream_name, shard_count, retention_period_hours, region_name):
self.stream_name = stream_name
self.creation_datetime = datetime.datetime.now()
self.creation_datetime = datetime.datetime.now().strftime(
"%Y-%m-%dT%H:%M:%S.%f000"
)
self.region = region_name
self.account_number = ACCOUNT_ID
self.shards = {}
self.tags = {}
self.status = "ACTIVE"
self.shard_count = None
self.update_shard_count(shard_count)
self.init_shards(shard_count)
self.retention_period_hours = (
retention_period_hours if retention_period_hours else 24
)
self.enhanced_monitoring = [{"ShardLevelMetrics": []}]
self.shard_level_metrics = []
self.encryption_type = "NONE"
self.key_id = None
self.consumers = []
def update_shard_count(self, shard_count):
# ToDo: This was extracted from init. It's only accurate for new streams.
# It doesn't (yet) try to accurately mimic the more complex re-sharding behavior.
# It re-creates the shards as if the stream had been created with this number of shards.
# Logically consistent, but not what AWS does.
def delete_consumer(self, consumer_arn):
self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn]
def get_consumer_by_arn(self, consumer_arn):
return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None)
def init_shards(self, shard_count):
self.shard_count = shard_count
step = 2 ** 128 // shard_count
hash_ranges = itertools.chain(
map(lambda i: (i, i * step, (i + 1) * step), range(shard_count - 1)),
map(lambda i: (i, i * step, (i + 1) * step - 1), range(shard_count - 1)),
[(shard_count - 1, (shard_count - 1) * step, 2 ** 128)],
)
for index, start, end in hash_ranges:
shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard
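To make the partitioning concrete, a standalone sketch of the hash-range arithmetic in init_shards (pure Python; the helper name is mine):

```python
import itertools

def shard_hash_ranges(shard_count):
    # Divide the 128-bit hash key space into contiguous, inclusive ranges;
    # the last shard absorbs the rounding remainder, as in init_shards.
    step = 2 ** 128 // shard_count
    return list(itertools.chain(
        ((i, i * step, (i + 1) * step - 1) for i in range(shard_count - 1)),
        [(shard_count - 1, (shard_count - 1) * step, 2 ** 128)],
    ))

for index, start, end in shard_hash_ranges(3):
    print(index, start, end)
# Each range now ends one below the next range's start, which is what the
# updated test_list_shards assertion further down verifies.
```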
def split_shard(self, shard_to_split, new_starting_hash_key):
new_starting_hash_key = int(new_starting_hash_key)
shard = self.shards[shard_to_split]
if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
pass
else:
raise InvalidArgumentError(
message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {ACCOUNT_ID} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
)
if not shard.is_open:
raise InvalidArgumentError(
message=f"Shard {shard.shard_id} in stream {self.stream_name} under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting."
)
last_id = sorted(self.shards.values(), key=attrgetter("_shard_id"))[
-1
]._shard_id
# Create two new shards
new_shard_1 = Shard(
last_id + 1,
starting_hash=shard.starting_hash,
ending_hash=new_starting_hash_key - 1,
parent=shard.shard_id,
)
new_shard_2 = Shard(
last_id + 2,
starting_hash=new_starting_hash_key,
ending_hash=shard.ending_hash,
parent=shard.shard_id,
)
self.shards[new_shard_1.shard_id] = new_shard_1
self.shards[new_shard_2.shard_id] = new_shard_2
shard.is_open = False
records = shard.records
shard.records = OrderedDict()
for index in records:
record = records[index]
self.put_record(
record.partition_key, record.explicit_hash_key, None, record.data
)
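From a client's perspective, splitting a single-shard stream down the middle looks like this; a hedged sketch with illustrative names:

```python
import boto3
from moto import mock_kinesis

@mock_kinesis
def split_down_the_middle():
    client = boto3.client("kinesis", region_name="us-east-1")
    client.create_stream(StreamName="split-demo", ShardCount=1)
    shard = client.describe_stream(StreamName="split-demo")[
        "StreamDescription"
    ]["Shards"][0]
    start = int(shard["HashKeyRange"]["StartingHashKey"])
    end = int(shard["HashKeyRange"]["EndingHashKey"])
    client.split_shard(
        StreamName="split-demo",
        ShardToSplit=shard["ShardId"],
        NewStartingHashKey=str((start + end) // 2),
    )
    shards = client.describe_stream(StreamName="split-demo")[
        "StreamDescription"
    ]["Shards"]
    # The parent is closed and two children are created: three shards
    # in total, two of them open.
    print([s["ShardId"] for s in shards])

split_down_the_middle()
```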
def merge_shards(self, shard_to_merge, adjacent_shard_to_merge):
shard1 = self.shards[shard_to_merge]
shard2 = self.shards[adjacent_shard_to_merge]
# Validate the two shards are adjacent
if shard1.ending_hash == (shard2.starting_hash - 1):
pass
elif shard2.ending_hash == (shard1.starting_hash - 1):
pass
else:
raise InvalidArgumentError(adjacent_shard_to_merge)
# Create a new shard
last_id = sorted(self.shards.values(), key=attrgetter("_shard_id"))[
-1
]._shard_id
new_shard = Shard(
last_id + 1,
starting_hash=shard1.starting_hash,
ending_hash=shard2.ending_hash,
parent=shard1.shard_id,
adjacent_parent=shard2.shard_id,
)
self.shards[new_shard.shard_id] = new_shard
# Close the merged shards
shard1.is_open = False
shard2.is_open = False
# Move all data across
for record in shard1.records.values():
new_shard.put_record(
record.partition_key, record.data, record.explicit_hash_key
)
for record in shard2.records.values():
new_shard.put_record(
record.partition_key, record.data, record.explicit_hash_key
)
def update_shard_count(self, target_shard_count):
current_shard_count = len([s for s in self.shards.values() if s.is_open])
if current_shard_count == target_shard_count:
return
# Split shards until we have enough shards
# AWS seems to split until we have (current * 2) shards, and then merge until we reach the target
# That's what's observable, at least - the actual algorithm is probably more advanced
#
if current_shard_count < target_shard_count:
open_shards = [
(shard_id, shard)
for shard_id, shard in self.shards.items()
if shard.is_open
]
for shard_id, shard in open_shards:
# Split the current shard
new_starting_hash_key = str(
int((shard.ending_hash + shard.starting_hash) / 2)
)
self.split_shard(shard_id, new_starting_hash_key)
current_shard_count = len([s for s in self.shards.values() if s.is_open])
# If we need to reduce the shard count, merge shards until we get there
while current_shard_count > target_shard_count:
# Keep track of how often we need to merge to get to the target shard count
required_shard_merges = current_shard_count - target_shard_count
# Get a list of pairs of adjacent shards
shard_list = sorted(
[s for s in self.shards.values() if s.is_open],
key=lambda x: x.starting_hash,
)
adjacent_shards = zip(
[s for s in shard_list[0:-1:2]], [s for s in shard_list[1::2]]
)
for (shard, adjacent) in adjacent_shards:
self.merge_shards(shard.shard_id, adjacent.shard_id)
required_shard_merges -= 1
if required_shard_merges == 0:
break
current_shard_count = len([s for s in self.shards.values() if s.is_open])
self.shard_count = target_shard_count
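The split-then-merge bookkeeping above can be modelled in a few lines; a sketch (the helper is mine, not part of the commit) that reproduces the totals asserted in test_update_shard_count further down:

```python
def expected_shard_layout(initial, target):
    # Model of Stream.update_shard_count: when scaling up, every open shard
    # is split (each split closes one shard and adds two); then adjacent
    # pairs are merged (each merge closes two shards and adds one) until
    # the open count matches the target.
    open_count, total = initial, initial
    if open_count < target:
        total += 2 * open_count
        open_count *= 2
    while open_count > target:
        merges = min(open_count - target, open_count // 2)
        total += merges
        open_count -= merges
    return open_count, total

# Mirrors the parametrized cases verified against AWS in the tests below:
for initial, target, expected_total in [
    (2, 4, 6), (4, 5, 15), (10, 13, 37), (4, 2, 6), (5, 3, 7), (10, 3, 17)
]:
    assert expected_shard_layout(initial, target) == (target, expected_total)
```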
@property
def arn(self):
return "arn:aws:kinesis:{region}:{account_number}:stream/{stream_name}".format(
@@ -218,12 +386,13 @@ class Stream(CloudFormationModel):
"StreamDescription": {
"StreamARN": self.arn,
"StreamName": self.stream_name,
"StreamCreationTimestamp": str(self.creation_datetime),
"StreamCreationTimestamp": self.creation_datetime,
"StreamStatus": self.status,
"HasMoreShards": len(requested_shards) != len(all_shards),
"RetentionPeriodHours": self.retention_period_hours,
"EnhancedMonitoring": self.enhanced_monitoring,
"EnhancedMonitoring": [{"ShardLevelMetrics": self.shard_level_metrics}],
"EncryptionType": self.encryption_type,
"KeyId": self.key_id,
"Shards": [shard.to_json() for shard in requested_shards],
}
}
@@ -234,7 +403,7 @@ class Stream(CloudFormationModel):
"StreamARN": self.arn,
"StreamName": self.stream_name,
"StreamStatus": self.status,
"StreamCreationTimestamp": str(self.creation_datetime),
"StreamCreationTimestamp": self.creation_datetime,
"OpenShardCount": self.shard_count,
}
}
@@ -262,7 +431,7 @@ class Stream(CloudFormationModel):
backend = kinesis_backends[region_name]
stream = backend.create_stream(
resource_name, shard_count, retention_period_hours, region_name
resource_name, shard_count, retention_period_hours
)
if any(tags):
backend.add_tags_to_stream(stream.stream_name, tags)
@@ -335,8 +504,14 @@ class Stream(CloudFormationModel):
class KinesisBackend(BaseBackend):
def __init__(self, region=None):
def __init__(self, region):
self.streams = OrderedDict()
self.region_name = region
def reset(self):
region = self.region_name
self.__dict__ = {}
self.__init__(region)
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
@@ -345,12 +520,12 @@ class KinesisBackend(BaseBackend):
service_region, zones, "kinesis", special_service_name="kinesis-streams"
)
def create_stream(
self, stream_name, shard_count, retention_period_hours, region_name
):
def create_stream(self, stream_name, shard_count, retention_period_hours):
if stream_name in self.streams:
raise ResourceInUseError(stream_name)
stream = Stream(stream_name, shard_count, retention_period_hours, region_name)
stream = Stream(
stream_name, shard_count, retention_period_hours, self.region_name
)
self.streams[stream_name] = stream
return stream
@@ -468,51 +643,8 @@ class KinesisBackend(BaseBackend):
position="newStartingHashKey",
regex_to_match=r"0|([1-9]\d{0,38})",
)
new_starting_hash_key = int(new_starting_hash_key)
shard = stream.shards[shard_to_split]
if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
pass
else:
raise InvalidArgumentError(
message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {stream_name} under account {ACCOUNT_ID} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash-1)}."
)
if not shard.is_open:
raise InvalidArgumentError(
message=f"Shard {shard.shard_id} in stream {stream_name} under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting."
)
last_id = sorted(stream.shards.values(), key=attrgetter("_shard_id"))[
-1
]._shard_id
# Create two new shards
new_shard_1 = Shard(
last_id + 1,
starting_hash=shard.starting_hash,
ending_hash=new_starting_hash_key - 1,
parent=shard.shard_id,
)
new_shard_2 = Shard(
last_id + 2,
starting_hash=new_starting_hash_key,
ending_hash=shard.ending_hash,
parent=shard.shard_id,
)
stream.shards[new_shard_1.shard_id] = new_shard_1
stream.shards[new_shard_2.shard_id] = new_shard_2
shard.is_open = False
records = shard.records
shard.records = OrderedDict()
for index in records:
record = records[index]
stream.put_record(
record.partition_key, record.explicit_hash_key, None, record.data
)
stream.split_shard(shard_to_split, new_starting_hash_key)
def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
stream = self.describe_stream(stream_name)
@@ -523,22 +655,15 @@ class KinesisBackend(BaseBackend):
if adjacent_shard_to_merge not in stream.shards:
raise ShardNotFoundError(adjacent_shard_to_merge, stream=stream_name)
shard1 = stream.shards[shard_to_merge]
shard2 = stream.shards[adjacent_shard_to_merge]
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
if shard1.ending_hash == shard2.starting_hash:
shard1.ending_hash = shard2.ending_hash
elif shard2.ending_hash == shard1.starting_hash:
shard1.starting_hash = shard2.starting_hash
else:
raise InvalidArgumentError(adjacent_shard_to_merge)
def update_shard_count(self, stream_name, target_shard_count):
stream = self.describe_stream(stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open])
del stream.shards[shard2.shard_id]
for index in shard2.records:
record = shard2.records[index]
shard1.put_record(
record.partition_key, record.data, record.explicit_hash_key
)
stream.update_shard_count(target_shard_count)
return current_shard_count
@paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_name):
@@ -548,22 +673,30 @@ class KinesisBackend(BaseBackend):
def increase_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
if (
retention_period_hours <= stream.retention_period_hours
or retention_period_hours < 24
or retention_period_hours > 8760
):
raise InvalidArgumentError(retention_period_hours)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
raise InvalidRetentionPeriod(retention_period_hours, too_short=False)
if retention_period_hours < stream.retention_period_hours:
raise InvalidIncreaseRetention(
name=stream_name,
requested=retention_period_hours,
existing=stream.retention_period_hours,
)
stream.retention_period_hours = retention_period_hours
def decrease_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
if (
retention_period_hours >= stream.retention_period_hours
or retention_period_hours < 24
or retention_period_hours > 8760
):
raise InvalidArgumentError(retention_period_hours)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
raise InvalidRetentionPeriod(retention_period_hours, too_short=False)
if retention_period_hours > stream.retention_period_hours:
raise InvalidDecreaseRetention(
name=stream_name,
requested=retention_period_hours,
existing=stream.retention_period_hours,
)
stream.retention_period_hours = retention_period_hours
def list_tags_for_stream(
@@ -594,5 +727,77 @@ class KinesisBackend(BaseBackend):
if key in stream.tags:
del stream.tags[key]
def enable_enhanced_monitoring(self, stream_name, shard_level_metrics):
stream = self.describe_stream(stream_name)
current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
stream.shard_level_metrics = desired_metrics
return current_shard_level_metrics, desired_metrics
def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
stream = self.describe_stream(stream_name)
current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled:
desired_metrics = []
else:
desired_metrics = [
metric for metric in current_metrics if metric not in to_be_disabled
]
stream.shard_level_metrics = desired_metrics
return current_metrics, desired_metrics
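The monitoring bookkeeping is plain set logic; a standalone illustration:

```python
# enable_enhanced_monitoring unions the requested metrics into the current
# set; disable_enhanced_monitoring subtracts them, with "ALL" clearing
# everything, as the methods above implement.
current = ["IncomingBytes"]
enabled = list(set(current + ["OutgoingBytes"]))                # after enable
remaining = [m for m in enabled if m not in ["OutgoingBytes"]]  # after disable
assert remaining == ["IncomingBytes"]
```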
def _find_stream_by_arn(self, stream_arn):
for stream in self.streams.values():
if stream.arn == stream_arn:
return stream
def list_stream_consumers(self, stream_arn):
"""
Pagination is not yet implemented
"""
stream = self._find_stream_by_arn(stream_arn)
return stream.consumers
def register_stream_consumer(self, stream_arn, consumer_name):
consumer = Consumer(consumer_name, self.region_name, stream_arn)
stream = self._find_stream_by_arn(stream_arn)
stream.consumers.append(consumer)
return consumer
def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
if stream_arn:
stream = self._find_stream_by_arn(stream_arn)
for consumer in stream.consumers:
if consumer_name and consumer.consumer_name == consumer_name:
return consumer
if consumer_arn:
for stream in self.streams.values():
consumer = stream.get_consumer_by_arn(consumer_arn)
if consumer:
return consumer
raise ConsumerNotFound(consumer=consumer_name or consumer_arn)
def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
if stream_arn:
stream = self._find_stream_by_arn(stream_arn)
stream.consumers = [
c for c in stream.consumers if c.consumer_name != consumer_name
]
if consumer_arn:
for stream in self.streams.values():
# Only one stream will actually have this consumer
# It will be a noop for other streams
stream.delete_consumer(consumer_arn)
def start_stream_encryption(self, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_name)
stream.encryption_type = encryption_type
stream.key_id = key_id
def stop_stream_encryption(self, stream_name):
stream = self.describe_stream(stream_name)
stream.encryption_type = "NONE"
stream.key_id = None
kinesis_backends = BackendDict(KinesisBackend, "kinesis")


@@ -18,7 +18,7 @@ class KinesisResponse(BaseResponse):
shard_count = self.parameters.get("ShardCount")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
self.kinesis_backend.create_stream(
stream_name, shard_count, retention_period_hours, self.region
stream_name, shard_count, retention_period_hours
)
return ""
@@ -149,6 +149,20 @@
res["NextToken"] = token
return json.dumps(res)
def update_shard_count(self):
stream_name = self.parameters.get("StreamName")
target_shard_count = self.parameters.get("TargetShardCount")
current_shard_count = self.kinesis_backend.update_shard_count(
stream_name=stream_name, target_shard_count=target_shard_count,
)
return json.dumps(
dict(
StreamName=stream_name,
CurrentShardCount=current_shard_count,
TargetShardCount=target_shard_count,
)
)
def increase_stream_retention_period(self):
stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
@@ -185,3 +199,82 @@
tag_keys = self.parameters.get("TagKeys")
self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys)
return json.dumps({})
def enable_enhanced_monitoring(self):
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_name=stream_name, shard_level_metrics=shard_level_metrics,
)
return json.dumps(
dict(
StreamName=stream_name,
CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired,
)
)
def disable_enhanced_monitoring(self):
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_name=stream_name, to_be_disabled=shard_level_metrics,
)
return json.dumps(
dict(
StreamName=stream_name,
CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired,
)
)
def list_stream_consumers(self):
stream_arn = self.parameters.get("StreamARN")
consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn)
return json.dumps(dict(Consumers=[c.to_json() for c in consumers]))
def register_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
consumer = self.kinesis_backend.register_stream_consumer(
stream_arn=stream_arn, consumer_name=consumer_name,
)
return json.dumps(dict(Consumer=consumer.to_json()))
def describe_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN")
consumer = self.kinesis_backend.describe_stream_consumer(
stream_arn=stream_arn,
consumer_name=consumer_name,
consumer_arn=consumer_arn,
)
return json.dumps(
dict(ConsumerDescription=consumer.to_json(include_stream_arn=True))
)
def deregister_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN")
self.kinesis_backend.deregister_stream_consumer(
stream_arn=stream_arn,
consumer_name=consumer_name,
consumer_arn=consumer_arn,
)
return json.dumps(dict())
def start_stream_encryption(self):
stream_name = self.parameters.get("StreamName")
encryption_type = self.parameters.get("EncryptionType")
key_id = self.parameters.get("KeyId")
self.kinesis_backend.start_stream_encryption(
stream_name=stream_name, encryption_type=encryption_type, key_id=key_id
)
return json.dumps(dict())
def stop_stream_encryption(self):
stream_name = self.parameters.get("StreamName")
self.kinesis_backend.stop_stream_encryption(stream_name=stream_name,)
return json.dumps(dict())


@@ -66,6 +66,7 @@ TestAccAWSIAMGroupPolicyAttachment
TestAccAWSIAMRole
TestAccAWSIAMUserPolicy
TestAccAWSIPRanges
TestAccAWSKinesisStream
TestAccAWSKmsAlias
TestAccAWSKmsSecretDataSource
TestAccAWSMq


@@ -518,10 +518,46 @@ def test_invalid_increase_stream_retention_period():
)
with pytest.raises(ClientError) as ex:
conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=20
StreamName=stream_name, RetentionPeriodHours=25
)
ex.value.response["Error"]["Code"].should.equal("InvalidArgumentException")
ex.value.response["Error"]["Message"].should.equal(20)
ex.value.response["Error"]["Message"].should.equal(
"Requested retention period (25 hours) for stream my_stream can not be shorter than existing retention period (30 hours). Use DecreaseRetentionPeriod API."
)
@mock_kinesis
def test_invalid_increase_stream_retention_too_low():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "my_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
with pytest.raises(ClientError) as ex:
conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=20
)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal(
"Minimum allowed retention period is 24 hours. Requested retention period (20 hours) is too short."
)
@mock_kinesis
def test_invalid_increase_stream_retention_too_high():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "my_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
with pytest.raises(ClientError) as ex:
conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=9999
)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal(
"Maximum allowed retention period is 8760 hours. Requested retention period (9999 hours) is too long."
)
@mock_kinesis
@@ -542,20 +578,54 @@ def test_valid_decrease_stream_retention_period():
@mock_kinesis
def test_invalid_decrease_stream_retention_period():
def test_decrease_stream_retention_period_upwards():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "decrease_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=30
)
with pytest.raises(ClientError) as ex:
conn.decrease_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=20
StreamName=stream_name, RetentionPeriodHours=40
)
ex.value.response["Error"]["Code"].should.equal("InvalidArgumentException")
ex.value.response["Error"]["Message"].should.equal(20)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal(
"Requested retention period (40 hours) for stream decrease_stream can not be longer than existing retention period (24 hours). Use IncreaseRetentionPeriod API."
)
@mock_kinesis
def test_decrease_stream_retention_period_too_low():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "decrease_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
with pytest.raises(ClientError) as ex:
conn.decrease_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=4
)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal(
"Minimum allowed retention period is 24 hours. Requested retention period (4 hours) is too short."
)
@mock_kinesis
def test_decrease_stream_retention_period_too_high():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "decrease_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
with pytest.raises(ClientError) as ex:
conn.decrease_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=9999
)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal(
"Maximum allowed retention period is 8760 hours. Requested retention period (9999 hours) is too long."
)
@mock_kinesis
@@ -632,6 +702,11 @@ def test_merge_shards_boto3():
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shards = stream["Shards"]
# Old shards still exist, but are closed. A new shard is created out of the old ones
shards.should.have.length_of(5)
# Only three shards are active - the two merged shards are closed
active_shards = [
shard
for shard in shards
@@ -641,12 +716,13 @@ def test_merge_shards_boto3():
client.merge_shards(
StreamName=stream_name,
ShardToMerge="shardId-000000000002",
AdjacentShardToMerge="shardId-000000000000",
ShardToMerge="shardId-000000000004",
AdjacentShardToMerge="shardId-000000000002",
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shards = stream["Shards"]
active_shards = [
shard
for shard in shards
@@ -654,6 +730,22 @@
]
active_shards.should.have.length_of(2)
for shard in active_shards:
del shard["HashKeyRange"]
del shard["SequenceNumberRange"]
# Original shard #3 is still active (0, 1 and 2 have been merged and closed)
active_shards.should.contain({"ShardId": "shardId-000000000003"})
# Shard #4 was the child of #0 and #1
# Shard #5 is the child of #4 (parent) and #2 (adjacent-parent)
active_shards.should.contain(
{
"ShardId": "shardId-000000000005",
"ParentShardId": "shardId-000000000004",
"AdjacentParentShardId": "shardId-000000000002",
}
)
@mock_kinesis
def test_merge_shards_invalid_arg():


@@ -63,7 +63,7 @@ def test_list_shards():
shard["HashKeyRange"].should.have.key("StartingHashKey")
shard["HashKeyRange"].should.have.key("EndingHashKey")
shard_list[0]["HashKeyRange"]["EndingHashKey"].should.equal(
shard_list[1]["HashKeyRange"]["StartingHashKey"]
str(int(shard_list[1]["HashKeyRange"]["StartingHashKey"]) - 1)
)
# Verify sequence numbers
for shard in shard_list:
@@ -297,3 +297,45 @@ def test_split_shard_that_was_split_before():
err["Message"].should.equal(
f"Shard shardId-000000000001 in stream my-stream under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting."
)
@mock_kinesis
@pytest.mark.parametrize(
"initial,target,expected_total",
[(2, 4, 6), (4, 5, 15), (10, 13, 37), (4, 2, 6), (5, 3, 7), (10, 3, 17)],
)
def test_update_shard_count(initial, target, expected_total):
"""
Test that we update the shard count in a similar manner to AWS
Assert on: initial_shard_count, target_shard_count and total_shard_count
total_shard_count gives an idea of the number of splits/merges required to reach the target
These numbers have been verified against AWS
"""
client = boto3.client("kinesis", region_name="eu-west-1")
client.create_stream(StreamName="my-stream", ShardCount=initial)
resp = client.update_shard_count(
StreamName="my-stream", TargetShardCount=target, ScalingType="UNIFORM_SCALING"
)
resp.should.have.key("StreamName").equals("my-stream")
resp.should.have.key("CurrentShardCount").equals(initial)
resp.should.have.key("TargetShardCount").equals(target)
stream = client.describe_stream(StreamName="my-stream")["StreamDescription"]
stream["StreamStatus"].should.equal("ACTIVE")
stream["Shards"].should.have.length_of(expected_total)
active_shards = [
shard
for shard in stream["Shards"]
if "EndingSequenceNumber" not in shard["SequenceNumberRange"]
]
active_shards.should.have.length_of(target)
resp = client.describe_stream_summary(StreamName="my-stream")
stream = resp["StreamDescriptionSummary"]
stream["OpenShardCount"].should.equal(target)


@@ -0,0 +1,46 @@
import boto3
from moto import mock_kinesis
@mock_kinesis
def test_enable_encryption():
client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId")
client.start_stream_encryption(
StreamName="my-stream", EncryptionType="KMS", KeyId="n/a"
)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("KMS")
desc.should.have.key("KeyId").equals("n/a")
@mock_kinesis
def test_disable_encryption():
client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
client.start_stream_encryption(
StreamName="my-stream", EncryptionType="KMS", KeyId="n/a"
)
client.stop_stream_encryption(
StreamName="my-stream", EncryptionType="KMS", KeyId="n/a"
)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId")


@@ -0,0 +1,127 @@
import boto3
from moto import mock_kinesis
@mock_kinesis
def test_enable_enhanced_monitoring_all():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
resp = client.enable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["ALL"]
)
resp.should.have.key("StreamName").equals(stream_name)
resp.should.have.key("CurrentShardLevelMetrics").equals([])
resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"])
@mock_kinesis
def test_enable_enhanced_monitoring_is_persisted():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
client.enable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["IncomingBytes", "OutgoingBytes"]
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"]
set(metrics).should.equal({"IncomingBytes", "OutgoingBytes"})
@mock_kinesis
def test_enable_enhanced_monitoring_in_steps():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
client.enable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["IncomingBytes", "OutgoingBytes"]
)
resp = client.enable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["WriteProvisionedThroughputExceeded"]
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(2)
resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes")
resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes")
resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(3)
resp["DesiredShardLevelMetrics"].should.contain("IncomingBytes")
resp["DesiredShardLevelMetrics"].should.contain("OutgoingBytes")
resp["DesiredShardLevelMetrics"].should.contain(
"WriteProvisionedThroughputExceeded"
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"]
metrics.should.have.length_of(3)
metrics.should.contain("IncomingBytes")
metrics.should.contain("OutgoingBytes")
metrics.should.contain("WriteProvisionedThroughputExceeded")
@mock_kinesis
def test_disable_enhanced_monitoring():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
client.enable_enhanced_monitoring(
StreamName=stream_name,
ShardLevelMetrics=[
"IncomingBytes",
"OutgoingBytes",
"WriteProvisionedThroughputExceeded",
],
)
resp = client.disable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"]
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3)
resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes")
resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes")
resp["CurrentShardLevelMetrics"].should.contain(
"WriteProvisionedThroughputExceeded"
)
resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(2)
resp["DesiredShardLevelMetrics"].should.contain("IncomingBytes")
resp["DesiredShardLevelMetrics"].should.contain(
"WriteProvisionedThroughputExceeded"
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"]
metrics.should.have.length_of(2)
metrics.should.contain("IncomingBytes")
metrics.should.contain("WriteProvisionedThroughputExceeded")
@mock_kinesis
def test_disable_enhanced_monitoring_all():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
client.enable_enhanced_monitoring(
StreamName=stream_name,
ShardLevelMetrics=[
"IncomingBytes",
"OutgoingBytes",
"WriteProvisionedThroughputExceeded",
],
)
client.disable_enhanced_monitoring(
StreamName=stream_name, ShardLevelMetrics=["ALL"]
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"]
metrics.should.equal([])


@@ -0,0 +1,146 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_kinesis
from moto.core import ACCOUNT_ID
def create_stream(client):
stream_name = "my-stream"
client.create_stream(StreamName=stream_name, ShardCount=4)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
return stream["StreamARN"]
@mock_kinesis
def test_list_stream_consumers():
client = boto3.client("kinesis", region_name="eu-west-1")
stream_arn = create_stream(client)
resp = client.list_stream_consumers(StreamARN=stream_arn)
resp.should.have.key("Consumers").equals([])
@mock_kinesis
def test_register_stream_consumer():
client = boto3.client("kinesis", region_name="eu-west-1")
stream_arn = create_stream(client)
resp = client.register_stream_consumer(
StreamARN=stream_arn, ConsumerName="newconsumer"
)
resp.should.have.key("Consumer")
consumer = resp["Consumer"]
consumer.should.have.key("ConsumerName").equals("newconsumer")
consumer.should.have.key("ConsumerARN").equals(
f"arn:aws:kinesis:eu-west-1:{ACCOUNT_ID}:stream/my-stream/consumer/newconsumer"
)
consumer.should.have.key("ConsumerStatus").equals("ACTIVE")
consumer.should.have.key("ConsumerCreationTimestamp")
resp = client.list_stream_consumers(StreamARN=stream_arn)
resp.should.have.key("Consumers").length_of(1)
consumer = resp["Consumers"][0]
consumer.should.have.key("ConsumerName").equals("newconsumer")
consumer.should.have.key("ConsumerARN").equals(
f"arn:aws:kinesis:eu-west-1:{ACCOUNT_ID}:stream/my-stream/consumer/newconsumer"
)
consumer.should.have.key("ConsumerStatus").equals("ACTIVE")
consumer.should.have.key("ConsumerCreationTimestamp")
@mock_kinesis
def test_describe_stream_consumer_by_name():
client = boto3.client("kinesis", region_name="us-east-2")
stream_arn = create_stream(client)
client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="newconsumer")
resp = client.describe_stream_consumer(
StreamARN=stream_arn, ConsumerName="newconsumer"
)
resp.should.have.key("ConsumerDescription")
consumer = resp["ConsumerDescription"]
consumer.should.have.key("ConsumerName").equals("newconsumer")
consumer.should.have.key("ConsumerARN")
consumer.should.have.key("ConsumerStatus").equals("ACTIVE")
consumer.should.have.key("ConsumerCreationTimestamp")
consumer.should.have.key("StreamARN").equals(stream_arn)
@mock_kinesis
def test_describe_stream_consumer_by_arn():
client = boto3.client("kinesis", region_name="us-east-2")
stream_arn = create_stream(client)
resp = client.register_stream_consumer(
StreamARN=stream_arn, ConsumerName="newconsumer"
)
consumer_arn = resp["Consumer"]["ConsumerARN"]
resp = client.describe_stream_consumer(ConsumerARN=consumer_arn)
resp.should.have.key("ConsumerDescription")
consumer = resp["ConsumerDescription"]
consumer.should.have.key("ConsumerName").equals("newconsumer")
consumer.should.have.key("ConsumerARN")
consumer.should.have.key("ConsumerStatus").equals("ACTIVE")
consumer.should.have.key("ConsumerCreationTimestamp")
consumer.should.have.key("StreamARN").equals(stream_arn)
@mock_kinesis
def test_describe_stream_consumer_unknown():
client = boto3.client("kinesis", region_name="us-east-2")
create_stream(client)
with pytest.raises(ClientError) as exc:
client.describe_stream_consumer(ConsumerARN="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ResourceNotFoundException")
err["Message"].should.equal(f"Consumer unknown, account {ACCOUNT_ID} not found.")
@mock_kinesis
def test_deregister_stream_consumer_by_name():
client = boto3.client("kinesis", region_name="ap-southeast-1")
stream_arn = create_stream(client)
client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer1")
client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer2")
client.list_stream_consumers(StreamARN=stream_arn)[
"Consumers"
].should.have.length_of(2)
client.deregister_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer1")
client.list_stream_consumers(StreamARN=stream_arn)[
"Consumers"
].should.have.length_of(1)
@mock_kinesis
def test_deregister_stream_consumer_by_arn():
client = boto3.client("kinesis", region_name="ap-southeast-1")
stream_arn = create_stream(client)
resp = client.register_stream_consumer(
StreamARN=stream_arn, ConsumerName="consumer1"
)
consumer1_arn = resp["Consumer"]["ConsumerARN"]
client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer2")
client.list_stream_consumers(StreamARN=stream_arn)[
"Consumers"
].should.have.length_of(2)
client.deregister_stream_consumer(ConsumerARN=consumer1_arn)
client.list_stream_consumers(StreamARN=stream_arn)[
"Consumers"
].should.have.length_of(1)