From 76605e30a0c3440ed28e5cef43c21d39198ec826 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 26 Jan 2022 18:41:04 -0100 Subject: [PATCH] Improved support for Kinesis (#4795) --- IMPLEMENTATION_COVERAGE.md | 20 +- docs/docs/services/kinesis.rst | 22 +- moto/kinesis/exceptions.py | 26 ++ moto/kinesis/models.py | 383 ++++++++++++++---- moto/kinesis/responses.py | 95 ++++- tests/terraform-tests.success.txt | 1 + tests/test_kinesis/test_kinesis.py | 114 +++++- tests/test_kinesis/test_kinesis_boto3.py | 44 +- tests/test_kinesis/test_kinesis_encryption.py | 46 +++ tests/test_kinesis/test_kinesis_monitoring.py | 127 ++++++ .../test_kinesis_stream_consumers.py | 146 +++++++ 11 files changed, 903 insertions(+), 121 deletions(-) create mode 100644 tests/test_kinesis/test_kinesis_encryption.py create mode 100644 tests/test_kinesis/test_kinesis_monitoring.py create mode 100644 tests/test_kinesis/test_kinesis_stream_consumers.py diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index c0cb820eb..4762a4cbf 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -3161,36 +3161,36 @@ ## kinesis
-58% implemented +89% implemented - [X] add_tags_to_stream - [X] create_stream - [X] decrease_stream_retention_period - [X] delete_stream -- [ ] deregister_stream_consumer +- [X] deregister_stream_consumer - [ ] describe_limits - [X] describe_stream -- [ ] describe_stream_consumer +- [X] describe_stream_consumer - [X] describe_stream_summary -- [ ] disable_enhanced_monitoring -- [ ] enable_enhanced_monitoring +- [X] disable_enhanced_monitoring +- [X] enable_enhanced_monitoring - [X] get_records - [X] get_shard_iterator - [X] increase_stream_retention_period - [X] list_shards -- [ ] list_stream_consumers +- [X] list_stream_consumers - [X] list_streams - [X] list_tags_for_stream - [X] merge_shards - [X] put_record - [X] put_records -- [ ] register_stream_consumer +- [X] register_stream_consumer - [X] remove_tags_from_stream - [X] split_shard -- [ ] start_stream_encryption -- [ ] stop_stream_encryption +- [X] start_stream_encryption +- [X] stop_stream_encryption - [ ] subscribe_to_shard -- [ ] update_shard_count +- [X] update_shard_count - [ ] update_stream_mode
diff --git a/docs/docs/services/kinesis.rst b/docs/docs/services/kinesis.rst index 3f2fc1c28..4905e4169 100644 --- a/docs/docs/services/kinesis.rst +++ b/docs/docs/services/kinesis.rst @@ -29,29 +29,33 @@ kinesis - [X] create_stream - [X] decrease_stream_retention_period - [X] delete_stream -- [ ] deregister_stream_consumer +- [X] deregister_stream_consumer - [ ] describe_limits - [X] describe_stream -- [ ] describe_stream_consumer +- [X] describe_stream_consumer - [X] describe_stream_summary -- [ ] disable_enhanced_monitoring -- [ ] enable_enhanced_monitoring +- [X] disable_enhanced_monitoring +- [X] enable_enhanced_monitoring - [X] get_records - [X] get_shard_iterator - [X] increase_stream_retention_period - [X] list_shards -- [ ] list_stream_consumers +- [X] list_stream_consumers + + Pagination is not yet implemented + + - [X] list_streams - [X] list_tags_for_stream - [X] merge_shards - [X] put_record - [X] put_records -- [ ] register_stream_consumer +- [X] register_stream_consumer - [X] remove_tags_from_stream - [X] split_shard -- [ ] start_stream_encryption -- [ ] stop_stream_encryption +- [X] start_stream_encryption +- [X] stop_stream_encryption - [ ] subscribe_to_shard -- [ ] update_shard_count +- [X] update_shard_count - [ ] update_stream_mode diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index f57523111..89653d9b2 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -33,6 +33,11 @@ class ShardNotFoundError(ResourceNotFoundError): ) +class ConsumerNotFound(ResourceNotFoundError): + def __init__(self, consumer): + super().__init__(f"Consumer {consumer}, account {ACCOUNT_ID} not found.") + + class InvalidArgumentError(BadRequest): def __init__(self, message): super().__init__() @@ -41,6 +46,27 @@ class InvalidArgumentError(BadRequest): ) +class InvalidRetentionPeriod(InvalidArgumentError): + def __init__(self, hours, too_short): + if too_short: + msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short." + else: + msg = f"Maximum allowed retention period is 8760 hours. Requested retention period ({hours} hours) is too long." + super().__init__(msg) + + +class InvalidDecreaseRetention(InvalidArgumentError): + def __init__(self, name, requested, existing): + msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API." + super().__init__(msg) + + +class InvalidIncreaseRetention(InvalidArgumentError): + def __init__(self, name, requested, existing): + msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API." + super().__init__(msg) + + class ValidationException(BadRequest): def __init__(self, value, position, regex_to_match): super().__init__() diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index 1b2b3173c..7c3378448 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -11,11 +11,15 @@ from moto.core.utils import unix_time, BackendDict from moto.core import ACCOUNT_ID from moto.utilities.paginator import paginate from .exceptions import ( + ConsumerNotFound, StreamNotFoundError, ShardNotFoundError, ResourceInUseError, ResourceNotFoundError, InvalidArgumentError, + InvalidRetentionPeriod, + InvalidDecreaseRetention, + InvalidIncreaseRetention, ValidationException, ) from .utils import ( @@ -26,6 +30,26 @@ from .utils import ( ) +class Consumer(BaseModel): + def __init__(self, consumer_name, region_name, stream_arn): + self.consumer_name = consumer_name + self.created = unix_time() + self.stream_arn = stream_arn + stream_name = stream_arn.split("/")[-1] + self.consumer_arn = f"arn:aws:kinesis:{region_name}:{ACCOUNT_ID}:stream/{stream_name}/consumer/{consumer_name}" + + def to_json(self, include_stream_arn=False): + resp = { + "ConsumerName": self.consumer_name, + "ConsumerARN": self.consumer_arn, + "ConsumerStatus": "ACTIVE", + "ConsumerCreationTimestamp": self.created, + } + if include_stream_arn: + resp["StreamARN"] = self.stream_arn + return resp + + class Record(BaseModel): def __init__(self, partition_key, data, sequence_number, explicit_hash_key): self.partition_key = partition_key @@ -45,13 +69,16 @@ class Record(BaseModel): class Shard(BaseModel): - def __init__(self, shard_id, starting_hash, ending_hash, parent=None): + def __init__( + self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None + ): self._shard_id = shard_id self.starting_hash = starting_hash self.ending_hash = ending_hash self.records = OrderedDict() self.is_open = True self.parent = parent + self.adjacent_parent = adjacent_parent @property def shard_id(self): @@ -127,6 +154,8 @@ class Shard(BaseModel): } if self.parent: response["ParentShardId"] = self.parent + if self.adjacent_parent: + response["AdjacentParentShardId"] = self.adjacent_parent if not self.is_open: response["SequenceNumberRange"]["EndingSequenceNumber"] = str( self.get_max_sequence_number() @@ -137,36 +166,175 @@ class Shard(BaseModel): class Stream(CloudFormationModel): def __init__(self, stream_name, shard_count, retention_period_hours, region_name): self.stream_name = stream_name - self.creation_datetime = datetime.datetime.now() + self.creation_datetime = datetime.datetime.now().strftime( + "%Y-%m-%dT%H:%M:%S.%f000" + ) self.region = region_name self.account_number = ACCOUNT_ID self.shards = {} self.tags = {} self.status = "ACTIVE" self.shard_count = None - self.update_shard_count(shard_count) + self.init_shards(shard_count) self.retention_period_hours = ( retention_period_hours if retention_period_hours else 24 ) - self.enhanced_monitoring = [{"ShardLevelMetrics": []}] + self.shard_level_metrics = [] self.encryption_type = "NONE" + self.key_id = None + self.consumers = [] - def update_shard_count(self, shard_count): - # ToDo: This was extracted from init. It's only accurate for new streams. - # It doesn't (yet) try to accurately mimic the more complex re-sharding behavior. - # It makes the stream as if it had been created with this number of shards. - # Logically consistent, but not what AWS does. + def delete_consumer(self, consumer_arn): + self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn] + + def get_consumer_by_arn(self, consumer_arn): + return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None) + + def init_shards(self, shard_count): self.shard_count = shard_count step = 2 ** 128 // shard_count hash_ranges = itertools.chain( - map(lambda i: (i, i * step, (i + 1) * step), range(shard_count - 1)), + map(lambda i: (i, i * step, (i + 1) * step - 1), range(shard_count - 1)), [(shard_count - 1, (shard_count - 1) * step, 2 ** 128)], ) for index, start, end in hash_ranges: shard = Shard(index, start, end) self.shards[shard.shard_id] = shard + def split_shard(self, shard_to_split, new_starting_hash_key): + new_starting_hash_key = int(new_starting_hash_key) + + shard = self.shards[shard_to_split] + + if shard.starting_hash < new_starting_hash_key < shard.ending_hash: + pass + else: + raise InvalidArgumentError( + message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {ACCOUNT_ID} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." + ) + + if not shard.is_open: + raise InvalidArgumentError( + message=f"Shard {shard.shard_id} in stream {self.stream_name} under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting." + ) + + last_id = sorted(self.shards.values(), key=attrgetter("_shard_id"))[ + -1 + ]._shard_id + + # Create two new shards + new_shard_1 = Shard( + last_id + 1, + starting_hash=shard.starting_hash, + ending_hash=new_starting_hash_key - 1, + parent=shard.shard_id, + ) + new_shard_2 = Shard( + last_id + 2, + starting_hash=new_starting_hash_key, + ending_hash=shard.ending_hash, + parent=shard.shard_id, + ) + self.shards[new_shard_1.shard_id] = new_shard_1 + self.shards[new_shard_2.shard_id] = new_shard_2 + shard.is_open = False + + records = shard.records + shard.records = OrderedDict() + + for index in records: + record = records[index] + self.put_record( + record.partition_key, record.explicit_hash_key, None, record.data + ) + + def merge_shards(self, shard_to_merge, adjacent_shard_to_merge): + shard1 = self.shards[shard_to_merge] + shard2 = self.shards[adjacent_shard_to_merge] + + # Validate the two shards are adjacent + if shard1.ending_hash == (shard2.starting_hash - 1): + pass + elif shard2.ending_hash == (shard1.starting_hash + 1): + pass + else: + raise InvalidArgumentError(adjacent_shard_to_merge) + + # Create a new shard + last_id = sorted(self.shards.values(), key=attrgetter("_shard_id"))[ + -1 + ]._shard_id + new_shard = Shard( + last_id + 1, + starting_hash=shard1.starting_hash, + ending_hash=shard2.ending_hash, + parent=shard1.shard_id, + adjacent_parent=shard2.shard_id, + ) + self.shards[new_shard.shard_id] = new_shard + + # Close the merged shards + shard1.is_open = False + shard2.is_open = False + + # Move all data across + for record in shard1.records.values(): + new_shard.put_record( + record.partition_key, record.data, record.explicit_hash_key + ) + for record in shard2.records.values(): + new_shard.put_record( + record.partition_key, record.data, record.explicit_hash_key + ) + + def update_shard_count(self, target_shard_count): + current_shard_count = len([s for s in self.shards.values() if s.is_open]) + if current_shard_count == target_shard_count: + return + + # Split shards until we have enough shards + # AWS seems to split until we have (current * 2) shards, and then merge until we reach the target + # That's what observable at least - the actual algorithm is probably more advanced + # + if current_shard_count < target_shard_count: + open_shards = [ + (shard_id, shard) + for shard_id, shard in self.shards.items() + if shard.is_open + ] + for shard_id, shard in open_shards: + # Split the current shard + new_starting_hash_key = str( + int((shard.ending_hash + shard.starting_hash) / 2) + ) + self.split_shard(shard_id, new_starting_hash_key) + + current_shard_count = len([s for s in self.shards.values() if s.is_open]) + + # If we need to reduce the shard count, merge shards until we get there + while current_shard_count > target_shard_count: + # Keep track of how often we need to merge to get to the target shard count + required_shard_merges = current_shard_count - target_shard_count + # Get a list of pairs of adjacent shards + shard_list = sorted( + [s for s in self.shards.values() if s.is_open], + key=lambda x: x.starting_hash, + ) + adjacent_shards = zip( + [s for s in shard_list[0:-1:2]], [s for s in shard_list[1::2]] + ) + + for (shard, adjacent) in adjacent_shards: + self.merge_shards(shard.shard_id, adjacent.shard_id) + required_shard_merges -= 1 + if required_shard_merges == 0: + break + + current_shard_count = len([s for s in self.shards.values() if s.is_open]) + + self.shard_count = target_shard_count + @property def arn(self): return "arn:aws:kinesis:{region}:{account_number}:stream/{stream_name}".format( @@ -218,12 +386,13 @@ class Stream(CloudFormationModel): "StreamDescription": { "StreamARN": self.arn, "StreamName": self.stream_name, - "StreamCreationTimestamp": str(self.creation_datetime), + "StreamCreationTimestamp": self.creation_datetime, "StreamStatus": self.status, "HasMoreShards": len(requested_shards) != len(all_shards), "RetentionPeriodHours": self.retention_period_hours, - "EnhancedMonitoring": self.enhanced_monitoring, + "EnhancedMonitoring": [{"ShardLevelMetrics": self.shard_level_metrics}], "EncryptionType": self.encryption_type, + "KeyId": self.key_id, "Shards": [shard.to_json() for shard in requested_shards], } } @@ -234,7 +403,7 @@ class Stream(CloudFormationModel): "StreamARN": self.arn, "StreamName": self.stream_name, "StreamStatus": self.status, - "StreamCreationTimestamp": str(self.creation_datetime), + "StreamCreationTimestamp": self.creation_datetime, "OpenShardCount": self.shard_count, } } @@ -262,7 +431,7 @@ class Stream(CloudFormationModel): backend = kinesis_backends[region_name] stream = backend.create_stream( - resource_name, shard_count, retention_period_hours, region_name + resource_name, shard_count, retention_period_hours ) if any(tags): backend.add_tags_to_stream(stream.stream_name, tags) @@ -335,8 +504,14 @@ class Stream(CloudFormationModel): class KinesisBackend(BaseBackend): - def __init__(self, region=None): + def __init__(self, region): self.streams = OrderedDict() + self.region_name = region + + def reset(self): + region = self.region_name + self.__dict__ = {} + self.__init__(region) @staticmethod def default_vpc_endpoint_service(service_region, zones): @@ -345,12 +520,12 @@ class KinesisBackend(BaseBackend): service_region, zones, "kinesis", special_service_name="kinesis-streams" ) - def create_stream( - self, stream_name, shard_count, retention_period_hours, region_name - ): + def create_stream(self, stream_name, shard_count, retention_period_hours): if stream_name in self.streams: raise ResourceInUseError(stream_name) - stream = Stream(stream_name, shard_count, retention_period_hours, region_name) + stream = Stream( + stream_name, shard_count, retention_period_hours, self.region_name + ) self.streams[stream_name] = stream return stream @@ -468,51 +643,8 @@ class KinesisBackend(BaseBackend): position="newStartingHashKey", regex_to_match=r"0|([1-9]\d{0,38})", ) - new_starting_hash_key = int(new_starting_hash_key) - shard = stream.shards[shard_to_split] - - if shard.starting_hash < new_starting_hash_key < shard.ending_hash: - pass - else: - raise InvalidArgumentError( - message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {stream_name} under account {ACCOUNT_ID} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash-1)}." - ) - - if not shard.is_open: - raise InvalidArgumentError( - message=f"Shard {shard.shard_id} in stream {stream_name} under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting." - ) - - last_id = sorted(stream.shards.values(), key=attrgetter("_shard_id"))[ - -1 - ]._shard_id - - # Create two new shards - new_shard_1 = Shard( - last_id + 1, - starting_hash=shard.starting_hash, - ending_hash=new_starting_hash_key - 1, - parent=shard.shard_id, - ) - new_shard_2 = Shard( - last_id + 2, - starting_hash=new_starting_hash_key, - ending_hash=shard.ending_hash, - parent=shard.shard_id, - ) - stream.shards[new_shard_1.shard_id] = new_shard_1 - stream.shards[new_shard_2.shard_id] = new_shard_2 - shard.is_open = False - - records = shard.records - shard.records = OrderedDict() - - for index in records: - record = records[index] - stream.put_record( - record.partition_key, record.explicit_hash_key, None, record.data - ) + stream.split_shard(shard_to_split, new_starting_hash_key) def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge): stream = self.describe_stream(stream_name) @@ -523,22 +655,15 @@ class KinesisBackend(BaseBackend): if adjacent_shard_to_merge not in stream.shards: raise ShardNotFoundError(adjacent_shard_to_merge, stream=stream_name) - shard1 = stream.shards[shard_to_merge] - shard2 = stream.shards[adjacent_shard_to_merge] + stream.merge_shards(shard_to_merge, adjacent_shard_to_merge) - if shard1.ending_hash == shard2.starting_hash: - shard1.ending_hash = shard2.ending_hash - elif shard2.ending_hash == shard1.starting_hash: - shard1.starting_hash = shard2.starting_hash - else: - raise InvalidArgumentError(adjacent_shard_to_merge) + def update_shard_count(self, stream_name, target_shard_count): + stream = self.describe_stream(stream_name) + current_shard_count = len([s for s in stream.shards.values() if s.is_open]) - del stream.shards[shard2.shard_id] - for index in shard2.records: - record = shard2.records[index] - shard1.put_record( - record.partition_key, record.data, record.explicit_hash_key - ) + stream.update_shard_count(target_shard_count) + + return current_shard_count @paginate(pagination_model=PAGINATION_MODEL) def list_shards(self, stream_name): @@ -548,22 +673,30 @@ class KinesisBackend(BaseBackend): def increase_stream_retention_period(self, stream_name, retention_period_hours): stream = self.describe_stream(stream_name) - if ( - retention_period_hours <= stream.retention_period_hours - or retention_period_hours < 24 - or retention_period_hours > 8760 - ): - raise InvalidArgumentError(retention_period_hours) + if retention_period_hours < 24: + raise InvalidRetentionPeriod(retention_period_hours, too_short=True) + if retention_period_hours > 8760: + raise InvalidRetentionPeriod(retention_period_hours, too_short=False) + if retention_period_hours < stream.retention_period_hours: + raise InvalidIncreaseRetention( + name=stream_name, + requested=retention_period_hours, + existing=stream.retention_period_hours, + ) stream.retention_period_hours = retention_period_hours def decrease_stream_retention_period(self, stream_name, retention_period_hours): stream = self.describe_stream(stream_name) - if ( - retention_period_hours >= stream.retention_period_hours - or retention_period_hours < 24 - or retention_period_hours > 8760 - ): - raise InvalidArgumentError(retention_period_hours) + if retention_period_hours < 24: + raise InvalidRetentionPeriod(retention_period_hours, too_short=True) + if retention_period_hours > 8760: + raise InvalidRetentionPeriod(retention_period_hours, too_short=False) + if retention_period_hours > stream.retention_period_hours: + raise InvalidDecreaseRetention( + name=stream_name, + requested=retention_period_hours, + existing=stream.retention_period_hours, + ) stream.retention_period_hours = retention_period_hours def list_tags_for_stream( @@ -594,5 +727,77 @@ class KinesisBackend(BaseBackend): if key in stream.tags: del stream.tags[key] + def enable_enhanced_monitoring(self, stream_name, shard_level_metrics): + stream = self.describe_stream(stream_name) + current_shard_level_metrics = stream.shard_level_metrics + desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics)) + stream.shard_level_metrics = desired_metrics + return current_shard_level_metrics, desired_metrics + + def disable_enhanced_monitoring(self, stream_name, to_be_disabled): + stream = self.describe_stream(stream_name) + current_metrics = stream.shard_level_metrics + if "ALL" in to_be_disabled: + desired_metrics = [] + else: + desired_metrics = [ + metric for metric in current_metrics if metric not in to_be_disabled + ] + stream.shard_level_metrics = desired_metrics + return current_metrics, desired_metrics + + def _find_stream_by_arn(self, stream_arn): + for stream in self.streams.values(): + if stream.arn == stream_arn: + return stream + + def list_stream_consumers(self, stream_arn): + """ + Pagination is not yet implemented + """ + stream = self._find_stream_by_arn(stream_arn) + return stream.consumers + + def register_stream_consumer(self, stream_arn, consumer_name): + consumer = Consumer(consumer_name, self.region_name, stream_arn) + stream = self._find_stream_by_arn(stream_arn) + stream.consumers.append(consumer) + return consumer + + def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn): + if stream_arn: + stream = self._find_stream_by_arn(stream_arn) + for consumer in stream.consumers: + if consumer_name and consumer.consumer_name == consumer_name: + return consumer + if consumer_arn: + for stream in self.streams.values(): + consumer = stream.get_consumer_by_arn(consumer_arn) + if consumer: + return consumer + raise ConsumerNotFound(consumer=consumer_name or consumer_arn) + + def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn): + if stream_arn: + stream = self._find_stream_by_arn(stream_arn) + stream.consumers = [ + c for c in stream.consumers if c.consumer_name == consumer_name + ] + if consumer_arn: + for stream in self.streams.values(): + # Only one stream will actually have this consumer + # It will be a noop for other streams + stream.delete_consumer(consumer_arn) + + def start_stream_encryption(self, stream_name, encryption_type, key_id): + stream = self.describe_stream(stream_name) + stream.encryption_type = encryption_type + stream.key_id = key_id + + def stop_stream_encryption(self, stream_name): + stream = self.describe_stream(stream_name) + stream.encryption_type = "NONE" + stream.key_id = None + kinesis_backends = BackendDict(KinesisBackend, "kinesis") diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index a29bc4630..c150e80ad 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -18,7 +18,7 @@ class KinesisResponse(BaseResponse): shard_count = self.parameters.get("ShardCount") retention_period_hours = self.parameters.get("RetentionPeriodHours") self.kinesis_backend.create_stream( - stream_name, shard_count, retention_period_hours, self.region + stream_name, shard_count, retention_period_hours ) return "" @@ -149,6 +149,20 @@ class KinesisResponse(BaseResponse): res["NextToken"] = token return json.dumps(res) + def update_shard_count(self): + stream_name = self.parameters.get("StreamName") + target_shard_count = self.parameters.get("TargetShardCount") + current_shard_count = self.kinesis_backend.update_shard_count( + stream_name=stream_name, target_shard_count=target_shard_count, + ) + return json.dumps( + dict( + StreamName=stream_name, + CurrentShardCount=current_shard_count, + TargetShardCount=target_shard_count, + ) + ) + def increase_stream_retention_period(self): stream_name = self.parameters.get("StreamName") retention_period_hours = self.parameters.get("RetentionPeriodHours") @@ -185,3 +199,82 @@ class KinesisResponse(BaseResponse): tag_keys = self.parameters.get("TagKeys") self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys) return json.dumps({}) + + def enable_enhanced_monitoring(self): + stream_name = self.parameters.get("StreamName") + shard_level_metrics = self.parameters.get("ShardLevelMetrics") + current, desired = self.kinesis_backend.enable_enhanced_monitoring( + stream_name=stream_name, shard_level_metrics=shard_level_metrics, + ) + return json.dumps( + dict( + StreamName=stream_name, + CurrentShardLevelMetrics=current, + DesiredShardLevelMetrics=desired, + ) + ) + + def disable_enhanced_monitoring(self): + stream_name = self.parameters.get("StreamName") + shard_level_metrics = self.parameters.get("ShardLevelMetrics") + current, desired = self.kinesis_backend.disable_enhanced_monitoring( + stream_name=stream_name, to_be_disabled=shard_level_metrics, + ) + return json.dumps( + dict( + StreamName=stream_name, + CurrentShardLevelMetrics=current, + DesiredShardLevelMetrics=desired, + ) + ) + + def list_stream_consumers(self): + stream_arn = self.parameters.get("StreamARN") + consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn) + return json.dumps(dict(Consumers=[c.to_json() for c in consumers])) + + def register_stream_consumer(self): + stream_arn = self.parameters.get("StreamARN") + consumer_name = self.parameters.get("ConsumerName") + consumer = self.kinesis_backend.register_stream_consumer( + stream_arn=stream_arn, consumer_name=consumer_name, + ) + return json.dumps(dict(Consumer=consumer.to_json())) + + def describe_stream_consumer(self): + stream_arn = self.parameters.get("StreamARN") + consumer_name = self.parameters.get("ConsumerName") + consumer_arn = self.parameters.get("ConsumerARN") + consumer = self.kinesis_backend.describe_stream_consumer( + stream_arn=stream_arn, + consumer_name=consumer_name, + consumer_arn=consumer_arn, + ) + return json.dumps( + dict(ConsumerDescription=consumer.to_json(include_stream_arn=True)) + ) + + def deregister_stream_consumer(self): + stream_arn = self.parameters.get("StreamARN") + consumer_name = self.parameters.get("ConsumerName") + consumer_arn = self.parameters.get("ConsumerARN") + self.kinesis_backend.deregister_stream_consumer( + stream_arn=stream_arn, + consumer_name=consumer_name, + consumer_arn=consumer_arn, + ) + return json.dumps(dict()) + + def start_stream_encryption(self): + stream_name = self.parameters.get("StreamName") + encryption_type = self.parameters.get("EncryptionType") + key_id = self.parameters.get("KeyId") + self.kinesis_backend.start_stream_encryption( + stream_name=stream_name, encryption_type=encryption_type, key_id=key_id + ) + return json.dumps(dict()) + + def stop_stream_encryption(self): + stream_name = self.parameters.get("StreamName") + self.kinesis_backend.stop_stream_encryption(stream_name=stream_name,) + return json.dumps(dict()) diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index a9b884435..62bd1aef2 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -66,6 +66,7 @@ TestAccAWSIAMGroupPolicyAttachment TestAccAWSIAMRole TestAccAWSIAMUserPolicy TestAccAWSIPRanges +TestAccAWSKinesisStream TestAccAWSKmsAlias TestAccAWSKmsSecretDataSource TestAccAWSMq diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py index f361611f7..7e93b8a55 100644 --- a/tests/test_kinesis/test_kinesis.py +++ b/tests/test_kinesis/test_kinesis.py @@ -518,10 +518,46 @@ def test_invalid_increase_stream_retention_period(): ) with pytest.raises(ClientError) as ex: conn.increase_stream_retention_period( - StreamName=stream_name, RetentionPeriodHours=20 + StreamName=stream_name, RetentionPeriodHours=25 ) ex.value.response["Error"]["Code"].should.equal("InvalidArgumentException") - ex.value.response["Error"]["Message"].should.equal(20) + ex.value.response["Error"]["Message"].should.equal( + "Requested retention period (25 hours) for stream my_stream can not be shorter than existing retention period (30 hours). Use DecreaseRetentionPeriod API." + ) + + +@mock_kinesis +def test_invalid_increase_stream_retention_too_low(): + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "my_stream" + conn.create_stream(StreamName=stream_name, ShardCount=1) + + with pytest.raises(ClientError) as ex: + conn.increase_stream_retention_period( + StreamName=stream_name, RetentionPeriodHours=20 + ) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidArgumentException") + err["Message"].should.equal( + "Minimum allowed retention period is 24 hours. Requested retention period (20 hours) is too short." + ) + + +@mock_kinesis +def test_invalid_increase_stream_retention_too_high(): + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "my_stream" + conn.create_stream(StreamName=stream_name, ShardCount=1) + + with pytest.raises(ClientError) as ex: + conn.increase_stream_retention_period( + StreamName=stream_name, RetentionPeriodHours=9999 + ) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidArgumentException") + err["Message"].should.equal( + "Maximum allowed retention period is 8760 hours. Requested retention period (9999 hours) is too long." + ) @mock_kinesis @@ -542,20 +578,54 @@ def test_valid_decrease_stream_retention_period(): @mock_kinesis -def test_invalid_decrease_stream_retention_period(): +def test_decrease_stream_retention_period_upwards(): conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "decrease_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) - conn.increase_stream_retention_period( - StreamName=stream_name, RetentionPeriodHours=30 - ) with pytest.raises(ClientError) as ex: conn.decrease_stream_retention_period( - StreamName=stream_name, RetentionPeriodHours=20 + StreamName=stream_name, RetentionPeriodHours=40 ) - ex.value.response["Error"]["Code"].should.equal("InvalidArgumentException") - ex.value.response["Error"]["Message"].should.equal(20) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidArgumentException") + err["Message"].should.equal( + "Requested retention period (40 hours) for stream decrease_stream can not be longer than existing retention period (24 hours). Use IncreaseRetentionPeriod API." + ) + + +@mock_kinesis +def test_decrease_stream_retention_period_too_low(): + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "decrease_stream" + conn.create_stream(StreamName=stream_name, ShardCount=1) + + with pytest.raises(ClientError) as ex: + conn.decrease_stream_retention_period( + StreamName=stream_name, RetentionPeriodHours=4 + ) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidArgumentException") + err["Message"].should.equal( + "Minimum allowed retention period is 24 hours. Requested retention period (4 hours) is too short." + ) + + +@mock_kinesis +def test_decrease_stream_retention_period_too_high(): + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "decrease_stream" + conn.create_stream(StreamName=stream_name, ShardCount=1) + + with pytest.raises(ClientError) as ex: + conn.decrease_stream_retention_period( + StreamName=stream_name, RetentionPeriodHours=9999 + ) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidArgumentException") + err["Message"].should.equal( + "Maximum allowed retention period is 8760 hours. Requested retention period (9999 hours) is too long." + ) @mock_kinesis @@ -632,6 +702,11 @@ def test_merge_shards_boto3(): stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] shards = stream["Shards"] + + # Old shards still exist, but are closed. A new shard is created out of the old one + shards.should.have.length_of(5) + + # Only three shards are active - the two merged shards are closed active_shards = [ shard for shard in shards @@ -641,12 +716,13 @@ def test_merge_shards_boto3(): client.merge_shards( StreamName=stream_name, - ShardToMerge="shardId-000000000002", - AdjacentShardToMerge="shardId-000000000000", + ShardToMerge="shardId-000000000004", + AdjacentShardToMerge="shardId-000000000002", ) stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] shards = stream["Shards"] + active_shards = [ shard for shard in shards @@ -654,6 +730,22 @@ def test_merge_shards_boto3(): ] active_shards.should.have.length_of(2) + for shard in active_shards: + del shard["HashKeyRange"] + del shard["SequenceNumberRange"] + + # Original shard #3 is still active (0,1,2 have been merged and closed + active_shards.should.contain({"ShardId": "shardId-000000000003"}) + # Shard #4 was the child of #0 and #1 + # Shard #5 is the child of #4 (parent) and #2 (adjacent-parent) + active_shards.should.contain( + { + "ShardId": "shardId-000000000005", + "ParentShardId": "shardId-000000000004", + "AdjacentParentShardId": "shardId-000000000002", + } + ) + @mock_kinesis def test_merge_shards_invalid_arg(): diff --git a/tests/test_kinesis/test_kinesis_boto3.py b/tests/test_kinesis/test_kinesis_boto3.py index fbc3b65dd..e8ea13753 100644 --- a/tests/test_kinesis/test_kinesis_boto3.py +++ b/tests/test_kinesis/test_kinesis_boto3.py @@ -63,7 +63,7 @@ def test_list_shards(): shard["HashKeyRange"].should.have.key("StartingHashKey") shard["HashKeyRange"].should.have.key("EndingHashKey") shard_list[0]["HashKeyRange"]["EndingHashKey"].should.equal( - shard_list[1]["HashKeyRange"]["StartingHashKey"] + str(int(shard_list[1]["HashKeyRange"]["StartingHashKey"]) - 1) ) # Verify sequence numbers for shard in shard_list: @@ -297,3 +297,45 @@ def test_split_shard_that_was_split_before(): err["Message"].should.equal( f"Shard shardId-000000000001 in stream my-stream under account {ACCOUNT_ID} has already been merged or split, and thus is not eligible for merging or splitting." ) + + +@mock_kinesis +@pytest.mark.parametrize( + "initial,target,expected_total", + [(2, 4, 6), (4, 5, 15), (10, 13, 37), (4, 2, 6), (5, 3, 7), (10, 3, 17)], +) +def test_update_shard_count(initial, target, expected_total): + """ + Test that we update the shard count in a similar manner to AWS + Assert on: initial_shard_count, target_shard_count and total_shard_count + + total_shard_count gives an idea of the number of splits/merges required to reach the target + + These numbers have been verified against AWS + """ + client = boto3.client("kinesis", region_name="eu-west-1") + client.create_stream(StreamName="my-stream", ShardCount=initial) + + resp = client.update_shard_count( + StreamName="my-stream", TargetShardCount=target, ScalingType="UNIFORM_SCALING" + ) + + resp.should.have.key("StreamName").equals("my-stream") + resp.should.have.key("CurrentShardCount").equals(initial) + resp.should.have.key("TargetShardCount").equals(target) + + stream = client.describe_stream(StreamName="my-stream")["StreamDescription"] + stream["StreamStatus"].should.equal("ACTIVE") + stream["Shards"].should.have.length_of(expected_total) + + active_shards = [ + shard + for shard in stream["Shards"] + if "EndingSequenceNumber" not in shard["SequenceNumberRange"] + ] + active_shards.should.have.length_of(target) + + resp = client.describe_stream_summary(StreamName="my-stream") + stream = resp["StreamDescriptionSummary"] + + stream["OpenShardCount"].should.equal(target) diff --git a/tests/test_kinesis/test_kinesis_encryption.py b/tests/test_kinesis/test_kinesis_encryption.py new file mode 100644 index 000000000..2d3d5e4d7 --- /dev/null +++ b/tests/test_kinesis/test_kinesis_encryption.py @@ -0,0 +1,46 @@ +import boto3 + +from moto import mock_kinesis + + +@mock_kinesis +def test_enable_encryption(): + client = boto3.client("kinesis", region_name="us-west-2") + client.create_stream(StreamName="my-stream", ShardCount=2) + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("NONE") + desc.shouldnt.have.key("KeyId") + + client.start_stream_encryption( + StreamName="my-stream", EncryptionType="KMS", KeyId="n/a" + ) + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("KMS") + desc.should.have.key("KeyId").equals("n/a") + + +@mock_kinesis +def test_disable_encryption(): + client = boto3.client("kinesis", region_name="us-west-2") + client.create_stream(StreamName="my-stream", ShardCount=2) + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("NONE") + + client.start_stream_encryption( + StreamName="my-stream", EncryptionType="KMS", KeyId="n/a" + ) + + client.stop_stream_encryption( + StreamName="my-stream", EncryptionType="KMS", KeyId="n/a" + ) + + resp = client.describe_stream(StreamName="my-stream") + desc = resp["StreamDescription"] + desc.should.have.key("EncryptionType").should.equal("NONE") + desc.shouldnt.have.key("KeyId") diff --git a/tests/test_kinesis/test_kinesis_monitoring.py b/tests/test_kinesis/test_kinesis_monitoring.py new file mode 100644 index 000000000..21e0eaff3 --- /dev/null +++ b/tests/test_kinesis/test_kinesis_monitoring.py @@ -0,0 +1,127 @@ +import boto3 + +from moto import mock_kinesis + + +@mock_kinesis +def test_enable_enhanced_monitoring_all(): + client = boto3.client("kinesis", region_name="us-east-1") + stream_name = "my_stream_summary" + client.create_stream(StreamName=stream_name, ShardCount=4) + + resp = client.enable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["ALL"] + ) + + resp.should.have.key("StreamName").equals(stream_name) + resp.should.have.key("CurrentShardLevelMetrics").equals([]) + resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"]) + + +@mock_kinesis +def test_enable_enhanced_monitoring_is_persisted(): + client = boto3.client("kinesis", region_name="us-east-1") + stream_name = "my_stream_summary" + client.create_stream(StreamName=stream_name, ShardCount=4) + + client.enable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["IncomingBytes", "OutgoingBytes"] + ) + + stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] + metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"] + set(metrics).should.equal({"IncomingBytes", "OutgoingBytes"}) + + +@mock_kinesis +def test_enable_enhanced_monitoring_in_steps(): + client = boto3.client("kinesis", region_name="us-east-1") + stream_name = "my_stream_summary" + client.create_stream(StreamName=stream_name, ShardCount=4) + + client.enable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["IncomingBytes", "OutgoingBytes"] + ) + + resp = client.enable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["WriteProvisionedThroughputExceeded"] + ) + + resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(2) + resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes") + resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes") + resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(3) + resp["DesiredShardLevelMetrics"].should.contain("IncomingBytes") + resp["DesiredShardLevelMetrics"].should.contain("OutgoingBytes") + resp["DesiredShardLevelMetrics"].should.contain( + "WriteProvisionedThroughputExceeded" + ) + + stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] + metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"] + metrics.should.have.length_of(3) + metrics.should.contain("IncomingBytes") + metrics.should.contain("OutgoingBytes") + metrics.should.contain("WriteProvisionedThroughputExceeded") + + +@mock_kinesis +def test_disable_enhanced_monitoring(): + client = boto3.client("kinesis", region_name="us-east-1") + stream_name = "my_stream_summary" + client.create_stream(StreamName=stream_name, ShardCount=4) + + client.enable_enhanced_monitoring( + StreamName=stream_name, + ShardLevelMetrics=[ + "IncomingBytes", + "OutgoingBytes", + "WriteProvisionedThroughputExceeded", + ], + ) + + resp = client.disable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"] + ) + + resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3) + resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes") + resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes") + resp["CurrentShardLevelMetrics"].should.contain( + "WriteProvisionedThroughputExceeded" + ) + resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(2) + resp["DesiredShardLevelMetrics"].should.contain("IncomingBytes") + resp["DesiredShardLevelMetrics"].should.contain( + "WriteProvisionedThroughputExceeded" + ) + + stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] + metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"] + metrics.should.have.length_of(2) + metrics.should.contain("IncomingBytes") + metrics.should.contain("WriteProvisionedThroughputExceeded") + + +@mock_kinesis +def test_disable_enhanced_monitoring_all(): + client = boto3.client("kinesis", region_name="us-east-1") + stream_name = "my_stream_summary" + client.create_stream(StreamName=stream_name, ShardCount=4) + + client.enable_enhanced_monitoring( + StreamName=stream_name, + ShardLevelMetrics=[ + "IncomingBytes", + "OutgoingBytes", + "WriteProvisionedThroughputExceeded", + ], + ) + + client.disable_enhanced_monitoring( + StreamName=stream_name, ShardLevelMetrics=["ALL"] + ) + + stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] + metrics = stream["EnhancedMonitoring"][0]["ShardLevelMetrics"] + metrics.should.equal([]) diff --git a/tests/test_kinesis/test_kinesis_stream_consumers.py b/tests/test_kinesis/test_kinesis_stream_consumers.py new file mode 100644 index 000000000..9b8ca2962 --- /dev/null +++ b/tests/test_kinesis/test_kinesis_stream_consumers.py @@ -0,0 +1,146 @@ +import boto3 +import pytest + +from botocore.exceptions import ClientError +from moto import mock_kinesis +from moto.core import ACCOUNT_ID + + +def create_stream(client): + stream_name = "my-stream" + client.create_stream(StreamName=stream_name, ShardCount=4) + stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] + return stream["StreamARN"] + + +@mock_kinesis +def test_list_stream_consumers(): + client = boto3.client("kinesis", region_name="eu-west-1") + stream_arn = create_stream(client) + + resp = client.list_stream_consumers(StreamARN=stream_arn) + + resp.should.have.key("Consumers").equals([]) + + +@mock_kinesis +def test_register_stream_consumer(): + client = boto3.client("kinesis", region_name="eu-west-1") + stream_arn = create_stream(client) + + resp = client.register_stream_consumer( + StreamARN=stream_arn, ConsumerName="newconsumer" + ) + resp.should.have.key("Consumer") + + consumer = resp["Consumer"] + + consumer.should.have.key("ConsumerName").equals("newconsumer") + consumer.should.have.key("ConsumerARN").equals( + f"arn:aws:kinesis:eu-west-1:{ACCOUNT_ID}:stream/my-stream/consumer/newconsumer" + ) + consumer.should.have.key("ConsumerStatus").equals("ACTIVE") + consumer.should.have.key("ConsumerCreationTimestamp") + + resp = client.list_stream_consumers(StreamARN=stream_arn) + + resp.should.have.key("Consumers").length_of(1) + consumer = resp["Consumers"][0] + consumer.should.have.key("ConsumerName").equals("newconsumer") + consumer.should.have.key("ConsumerARN").equals( + f"arn:aws:kinesis:eu-west-1:{ACCOUNT_ID}:stream/my-stream/consumer/newconsumer" + ) + consumer.should.have.key("ConsumerStatus").equals("ACTIVE") + consumer.should.have.key("ConsumerCreationTimestamp") + + +@mock_kinesis +def test_describe_stream_consumer_by_name(): + client = boto3.client("kinesis", region_name="us-east-2") + stream_arn = create_stream(client) + client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="newconsumer") + + resp = client.describe_stream_consumer( + StreamARN=stream_arn, ConsumerName="newconsumer" + ) + resp.should.have.key("ConsumerDescription") + + consumer = resp["ConsumerDescription"] + consumer.should.have.key("ConsumerName").equals("newconsumer") + consumer.should.have.key("ConsumerARN") + consumer.should.have.key("ConsumerStatus").equals("ACTIVE") + consumer.should.have.key("ConsumerCreationTimestamp") + consumer.should.have.key("StreamARN").equals(stream_arn) + + +@mock_kinesis +def test_describe_stream_consumer_by_arn(): + client = boto3.client("kinesis", region_name="us-east-2") + stream_arn = create_stream(client) + resp = client.register_stream_consumer( + StreamARN=stream_arn, ConsumerName="newconsumer" + ) + consumer_arn = resp["Consumer"]["ConsumerARN"] + + resp = client.describe_stream_consumer(ConsumerARN=consumer_arn) + resp.should.have.key("ConsumerDescription") + + consumer = resp["ConsumerDescription"] + consumer.should.have.key("ConsumerName").equals("newconsumer") + consumer.should.have.key("ConsumerARN") + consumer.should.have.key("ConsumerStatus").equals("ACTIVE") + consumer.should.have.key("ConsumerCreationTimestamp") + consumer.should.have.key("StreamARN").equals(stream_arn) + + +@mock_kinesis +def test_describe_stream_consumer_unknown(): + client = boto3.client("kinesis", region_name="us-east-2") + create_stream(client) + + with pytest.raises(ClientError) as exc: + client.describe_stream_consumer(ConsumerARN="unknown") + err = exc.value.response["Error"] + err["Code"].should.equal("ResourceNotFoundException") + err["Message"].should.equal(f"Consumer unknown, account {ACCOUNT_ID} not found.") + + +@mock_kinesis +def test_deregister_stream_consumer_by_name(): + client = boto3.client("kinesis", region_name="ap-southeast-1") + stream_arn = create_stream(client) + + client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer1") + client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer2") + + client.list_stream_consumers(StreamARN=stream_arn)[ + "Consumers" + ].should.have.length_of(2) + + client.deregister_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer1") + + client.list_stream_consumers(StreamARN=stream_arn)[ + "Consumers" + ].should.have.length_of(1) + + +@mock_kinesis +def test_deregister_stream_consumer_by_arn(): + client = boto3.client("kinesis", region_name="ap-southeast-1") + stream_arn = create_stream(client) + + resp = client.register_stream_consumer( + StreamARN=stream_arn, ConsumerName="consumer1" + ) + consumer1_arn = resp["Consumer"]["ConsumerARN"] + client.register_stream_consumer(StreamARN=stream_arn, ConsumerName="consumer2") + + client.list_stream_consumers(StreamARN=stream_arn)[ + "Consumers" + ].should.have.length_of(2) + + client.deregister_stream_consumer(ConsumerARN=consumer1_arn) + + client.list_stream_consumers(StreamARN=stream_arn)[ + "Consumers" + ].should.have.length_of(1)