from base64 import b64encode, b64decode
from collections import OrderedDict
from gzip import GzipFile

import datetime
import io
import json
import re
import itertools

from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple, Iterable

from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random as random
from moto.utilities.paginator import paginate
from moto.utilities.utils import md5_hash
from .exceptions import (
    ConsumerNotFound,
    StreamNotFoundError,
    StreamCannotBeUpdatedError,
    ShardNotFoundError,
    ResourceInUseError,
    ResourceNotFoundError,
    InvalidArgumentError,
    InvalidRetentionPeriod,
    InvalidDecreaseRetention,
    InvalidIncreaseRetention,
    ValidationException,
    RecordSizeExceedsLimit,
    TotalRecordsSizeExceedsLimit,
    TooManyRecords,
)
from .utils import (
    compose_shard_iterator,
    compose_new_shard_iterator,
    decompose_shard_iterator,
    PAGINATION_MODEL,
)


class Consumer(BaseModel):
    """A consumer registered against a single stream."""

    def __init__(
        self, consumer_name: str, account_id: str, region_name: str, stream_arn: str
    ):
        self.consumer_name = consumer_name
        self.created = unix_time()
        self.stream_arn = stream_arn
        # The stream name is the last path segment of the stream ARN.
        stream_name = stream_arn.split("/")[-1]
        self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}"

    def to_json(self, include_stream_arn: bool = False) -> Dict[str, Any]:
        """Serialise the consumer; ``StreamARN`` is only added on request."""
        resp = {
            "ConsumerName": self.consumer_name,
            "ConsumerARN": self.consumer_arn,
            "ConsumerStatus": "ACTIVE",
            "ConsumerCreationTimestamp": self.created,
        }
        if include_stream_arn:
            resp["StreamARN"] = self.stream_arn
        return resp


class Record(BaseModel):
    """A single data record stored in a shard."""

    def __init__(
        self,
        partition_key: str,
        data: str,
        sequence_number: int,
        explicit_hash_key: str,
    ):
        self.partition_key = partition_key
        self.data = data
        self.sequence_number = sequence_number
        self.explicit_hash_key = explicit_hash_key
        self.created_at_datetime = datetime.datetime.utcnow()
        self.created_at = unix_time(self.created_at_datetime)

    def to_json(self) -> Dict[str, Any]:
        """Serialise the record; the sequence number is emitted as a string."""
        return {
            "Data": self.data,
            "PartitionKey": self.partition_key,
            "SequenceNumber": str(self.sequence_number),
            "ApproximateArrivalTimestamp": self.created_at,
        }


class Shard(BaseModel):
    """A hash-key range of a stream together with the records stored in it."""

    def __init__(
        self,
        shard_id: int,
        starting_hash: int,
        ending_hash: int,
        parent: Optional[str] = None,
        adjacent_parent: Optional[str] = None,
    ):
        self._shard_id = shard_id
        self.starting_hash = starting_hash
        self.ending_hash = ending_hash
        # Records keyed by sequence number, kept in insertion order.
        self.records: Dict[int, Record] = OrderedDict()
        self.is_open = True
        self.parent = parent
        self.adjacent_parent = adjacent_parent

    @property
    def shard_id(self) -> str:
        """The numeric id rendered as ``shardId-`` plus 12 zero-padded digits."""
        return f"shardId-{str(self._shard_id).zfill(12)}"

    def get_records(
        self, last_sequence_id: str, limit: Optional[int]
    ) -> Tuple[List[Record], int, int]:
        """
        Return the records whose sequence number is greater than
        ``last_sequence_id``, stopping once ``limit`` records are collected.

        Returns a tuple of (records, sequence number of the last returned
        record - or the input value if nothing matched, milliseconds between
        the last returned record and the newest record in the shard).
        """
        last_sequence_int = int(last_sequence_id)
        results = []
        secs_behind_latest = 0.0

        for sequence_number, record in self.records.items():
            if sequence_number > last_sequence_int:
                results.append(record)
                last_sequence_int = sequence_number

                very_last_record = self.records[next(reversed(self.records))]
                secs_behind_latest = very_last_record.created_at - record.created_at

            if len(results) == limit:
                break

        millis_behind_latest = int(secs_behind_latest * 1000)
        return results, last_sequence_int, millis_behind_latest

    def put_record(self, partition_key: str, data: str, explicit_hash_key: str) -> str:
        """Append a record and return its sequence number as a string."""
        # Note: this function is not safe for concurrency
        sequence_number = self.get_max_sequence_number() + 1
        self.records[sequence_number] = Record(
            partition_key, data, sequence_number, explicit_hash_key
        )
        return str(sequence_number)

    def get_min_sequence_number(self) -> int:
        """Return the first stored sequence number, or 0 for an empty shard."""
        if self.records:
            return next(iter(self.records))
        return 0

    def get_max_sequence_number(self) -> int:
        """Return the last stored sequence number, or 0 for an empty shard."""
        if self.records:
            return next(reversed(self.records))
        return 0

    def get_sequence_number_at(self, at_timestamp: float) -> int:
        """
        Return the sequence number of the newest record created strictly
        before ``at_timestamp``, or 0 when there is no such record.
        """
        # Walk newest-to-oldest so the first hit is the latest matching record.
        for record in reversed(self.records.values()):
            if record.created_at < at_timestamp:
                return record.sequence_number
        # Empty shard, or every record arrived at/after the timestamp
        # (including the case where it equals the first record's arrival time).
        return 0

    def to_json(self) -> Dict[str, Any]:
        """
        Serialise the shard. Parent ids are only included when set, and
        ``EndingSequenceNumber`` only once the shard is closed.
        """
        response: Dict[str, Any] = {
            "HashKeyRange": {
                "EndingHashKey": str(self.ending_hash),
                "StartingHashKey": str(self.starting_hash),
            },
            "SequenceNumberRange": {
                "StartingSequenceNumber": str(self.get_min_sequence_number()),
            },
            "ShardId": self.shard_id,
        }
        if self.parent:
            response["ParentShardId"] = self.parent
        if self.adjacent_parent:
            response["AdjacentParentShardId"] = self.adjacent_parent
        if not self.is_open:
            response["SequenceNumberRange"]["EndingSequenceNumber"] = str(
                self.get_max_sequence_number()
            )
        return response


class Stream(CloudFormationModel):
    """A Kinesis stream: its shards, tags, consumers and settings."""

    def __init__(
        self,
        stream_name: str,
        shard_count: int,
        stream_mode: Optional[Dict[str, str]],
        retention_period_hours: Optional[int],
        account_id: str,
        region_name: str,
    ):
        self.stream_name = stream_name
        self.creation_datetime = datetime.datetime.now().strftime(
            "%Y-%m-%dT%H:%M:%S.%f000"
        )
        self.region = region_name
        self.account_id = account_id
        self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}"
        self.shards: Dict[str, Shard] = {}
        self.tags: Dict[str, str] = {}
        self.status = "ACTIVE"
        self.shard_count: Optional[int] = None
        self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"}
        # On-demand streams ignore the requested count and start with 4 shards.
        if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
            shard_count = 4
        self.init_shards(shard_count)
        self.retention_period_hours = retention_period_hours or 24
        self.shard_level_metrics: List[str] = []
        self.encryption_type = "NONE"
        self.key_id: Optional[str] = None
        self.consumers: List[Consumer] = []

    def delete_consumer(self, consumer_arn: str) -> None:
        """Remove the consumer with this ARN; a no-op if it is not registered."""
        self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn]

    def get_consumer_by_arn(self, consumer_arn: str) -> Optional[Consumer]:
        """Return the consumer with this ARN, or None."""
        return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None)

    def init_shards(self, shard_count: int) -> None:
        """Create ``shard_count`` shards that evenly divide the 128-bit key space."""
        self.shard_count = shard_count

        step = 2**128 // shard_count
        # Every shard but the last covers [i * step, (i + 1) * step - 1];
        # the last one absorbs the remainder and ends at 2**128.
        hash_ranges = itertools.chain(
            map(lambda i: (i, i * step, (i + 1) * step - 1), range(shard_count - 1)),
            [(shard_count - 1, (shard_count - 1) * step, 2**128)],
        )
        for index, start, end in hash_ranges:
            shard = Shard(index, start, end)
            self.shards[shard.shard_id] = shard

    def split_shard(self, shard_to_split: str, new_starting_hash_key: str) -> None:
        """
        Close ``shard_to_split`` and create two child shards that divide its
        hash range at ``new_starting_hash_key``.

        Raises InvalidArgumentError if the key is not strictly inside the
        shard's range, or if the shard is already closed.
        """
        new_starting_hash_int = int(new_starting_hash_key)

        shard = self.shards[shard_to_split]

        if not shard.starting_hash < new_starting_hash_int < shard.ending_hash:
            raise InvalidArgumentError(
                message=f"NewStartingHashKey {new_starting_hash_int} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
            )

        if not shard.is_open:
            raise InvalidArgumentError(
                message=f"Shard {shard.shard_id} in stream {self.stream_name} under account {self.account_id} has already been merged or split, and thus is not eligible for merging or splitting."
            )

        last_id = max(self.shards.values(), key=attrgetter("_shard_id"))._shard_id

        # Create two new shards
        new_shard_1 = Shard(
            last_id + 1,
            starting_hash=shard.starting_hash,
            ending_hash=new_starting_hash_int - 1,
            parent=shard.shard_id,
        )
        new_shard_2 = Shard(
            last_id + 2,
            starting_hash=new_starting_hash_int,
            ending_hash=shard.ending_hash,
            parent=shard.shard_id,
        )
        self.shards[new_shard_1.shard_id] = new_shard_1
        self.shards[new_shard_2.shard_id] = new_shard_2
        shard.is_open = False

        records = shard.records
        shard.records = OrderedDict()

        # Re-insert the parent's records through the normal routing.
        # NOTE(review): get_shard_for_key() does not skip closed shards, so
        # these records may be routed straight back to the closed parent -
        # confirm whether that is intended.
        for record in records.values():
            self.put_record(record.partition_key, record.explicit_hash_key, record.data)

    def merge_shards(self, shard_to_merge: str, adjacent_shard_to_merge: str) -> None:
        """
        Close two adjacent shards and create one shard covering both ranges,
        copying the records of both into it.

        Raises InvalidArgumentError if the shards are not adjacent.
        """
        shard1 = self.shards[shard_to_merge]
        shard2 = self.shards[adjacent_shard_to_merge]

        # Validate the two shards are adjacent - accepted in either order
        if shard1.ending_hash == (shard2.starting_hash - 1):
            lower, upper = shard1, shard2
        elif shard2.ending_hash == (shard1.starting_hash - 1):
            lower, upper = shard2, shard1
        else:
            raise InvalidArgumentError(adjacent_shard_to_merge)

        # Create a new shard spanning from the lower start to the upper end
        last_id = max(self.shards.values(), key=attrgetter("_shard_id"))._shard_id
        new_shard = Shard(
            last_id + 1,
            starting_hash=lower.starting_hash,
            ending_hash=upper.ending_hash,
            parent=shard1.shard_id,
            adjacent_parent=shard2.shard_id,
        )
        self.shards[new_shard.shard_id] = new_shard

        # Close the merged shards
        shard1.is_open = False
        shard2.is_open = False

        # Move all data across
        for source_shard in (shard1, shard2):
            for record in source_shard.records.values():
                new_shard.put_record(
                    record.partition_key, record.data, record.explicit_hash_key
                )

    def _open_shard_count(self) -> int:
        """Return the number of shards that are still open."""
        return sum(1 for s in self.shards.values() if s.is_open)

    def update_shard_count(self, target_shard_count: int) -> None:
        """
        Split and/or merge shards until ``target_shard_count`` shards are open.

        Raises StreamCannotBeUpdatedError for ON_DEMAND streams.
        """
        if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
            raise StreamCannotBeUpdatedError(
                stream_name=self.stream_name, account_id=self.account_id
            )
        current_shard_count = self._open_shard_count()
        if current_shard_count == target_shard_count:
            return

        # Split shards until we have enough shards
        # AWS seems to split until we have (current * 2) shards, and then merge until we reach the target
        # That's what observable at least - the actual algorithm is probably more advanced
        #
        if current_shard_count < target_shard_count:
            open_shards = [
                (shard_id, shard)
                for shard_id, shard in self.shards.items()
                if shard.is_open
            ]
            for shard_id, shard in open_shards:
                # Split the current shard at its midpoint. Integer division
                # keeps the result exact; hash keys do not fit in a float.
                new_starting_hash_key = str(
                    (shard.ending_hash + shard.starting_hash) // 2
                )
                self.split_shard(shard_id, new_starting_hash_key)

        current_shard_count = self._open_shard_count()

        # If we need to reduce the shard count, merge shards until we get there
        while current_shard_count > target_shard_count:
            # Keep track of how often we need to merge to get to the target shard count
            required_shard_merges = current_shard_count - target_shard_count
            # Get a list of pairs of adjacent shards
            shard_list = sorted(
                [s for s in self.shards.values() if s.is_open],
                key=lambda x: x.starting_hash,
            )
            adjacent_shards = zip(shard_list[0:-1:2], shard_list[1::2])

            for shard, adjacent in adjacent_shards:
                self.merge_shards(shard.shard_id, adjacent.shard_id)
                required_shard_merges -= 1
                if required_shard_merges == 0:
                    break

            current_shard_count = self._open_shard_count()

        self.shard_count = target_shard_count

    def get_shard(self, shard_id: str) -> Shard:
        """Return the shard with this id, or raise ShardNotFoundError."""
        if shard_id in self.shards:
            return self.shards[shard_id]
        raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id)

    def get_shard_for_key(
        self, partition_key: str, explicit_hash_key: str
    ) -> Optional[Shard]:
        """
        Return the first shard whose hash range contains the record's hash key.

        The hash key is ``explicit_hash_key`` when given, otherwise the MD5 of
        the partition key. Returns None if no shard matches.

        Raises InvalidArgumentError for a non-string or over-long partition
        key, or for a non-string or out-of-range explicit hash key.
        """
        if not isinstance(partition_key, str):
            raise InvalidArgumentError("partition_key")
        if len(partition_key) > 256:
            raise InvalidArgumentError("partition_key")

        if explicit_hash_key:
            if not isinstance(explicit_hash_key, str):
                raise InvalidArgumentError("explicit_hash_key")

            int_key = int(explicit_hash_key)

            if int_key >= 2**128:
                raise InvalidArgumentError("explicit_hash_key")
        else:
            int_key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)

        for shard in self.shards.values():
            # ending_hash is inclusive: neighbouring shards start at
            # ending_hash + 1, so the boundary key belongs to this shard.
            if shard.starting_hash <= int_key <= shard.ending_hash:
                return shard
        return None

    def put_record(
        self, partition_key: str, explicit_hash_key: str, data: str
    ) -> Tuple[str, str]:
        """Store a record in the matching shard; return (sequence number, shard id)."""
        shard = self.get_shard_for_key(partition_key, explicit_hash_key)

        sequence_number = shard.put_record(partition_key, data, explicit_hash_key)  # type: ignore
        return sequence_number, shard.shard_id  # type: ignore

    def to_json(self, shard_limit: Optional[int] = None) -> Dict[str, Any]:
        """Serialise the stream, listing at most ``shard_limit`` shards when set."""
        all_shards = list(self.shards.values())
        requested_shards = all_shards[0 : shard_limit or len(all_shards)]
        return {
            "StreamDescription": {
                "StreamARN": self.arn,
                "StreamName": self.stream_name,
                "StreamCreationTimestamp": self.creation_datetime,
                "StreamStatus": self.status,
                "HasMoreShards": len(requested_shards) != len(all_shards),
                "RetentionPeriodHours": self.retention_period_hours,
                "EnhancedMonitoring": [{"ShardLevelMetrics": self.shard_level_metrics}],
                "EncryptionType": self.encryption_type,
                "KeyId": self.key_id,
                "Shards": [shard.to_json() for shard in requested_shards],
            }
        }

    def to_json_summary(self) -> Dict[str, Any]:
        """Serialise the stream summary (no per-shard details)."""
        return {
            "StreamDescriptionSummary": {
                "StreamARN": self.arn,
                "StreamName": self.stream_name,
                "StreamStatus": self.status,
                "StreamModeDetails": self.stream_mode,
                "RetentionPeriodHours": self.retention_period_hours,
                "StreamCreationTimestamp": self.creation_datetime,
                "EnhancedMonitoring": [{"ShardLevelMetrics": self.shard_level_metrics}],
                "OpenShardCount": self.shard_count,
                "EncryptionType": self.encryption_type,
                "KeyId": self.key_id,
            }
        }

    @staticmethod
    def cloudformation_name_type() -> str:
        return "Name"

    @staticmethod
    def cloudformation_type() -> str:
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html
        return "AWS::Kinesis::Stream"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "Stream":
        """Create a stream (and its tags) from a CloudFormation resource."""
        properties = cloudformation_json.get("Properties", {})
        shard_count = properties.get("ShardCount", 1)
        # None lets Stream.__init__ apply its own default retention period.
        retention_period_hours = properties.get("RetentionPeriodHours", None)
        tags = {
            tag_item["Key"]: tag_item["Value"]
            for tag_item in properties.get("Tags", [])
        }

        backend: KinesisBackend = kinesis_backends[account_id][region_name]
        stream = backend.create_stream(
            resource_name, shard_count, retention_period_hours=retention_period_hours
        )
        if any(tags):
            backend.add_tags_to_stream(
                stream_arn=None, stream_name=stream.stream_name, tags=tags
            )
        return stream

    @classmethod
    def update_from_cloudformation_json(  # type: ignore[misc]
        cls,
        original_resource: Any,
        new_resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> "Stream":
        """
        Apply a CloudFormation update: either replace the stream, or update
        shard count, retention period and tags in place.
        """
        properties = cloudformation_json["Properties"]

        if Stream.is_replacement_update(properties):
            resource_name_property = cls.cloudformation_name_type()
            if resource_name_property not in properties:
                properties[resource_name_property] = new_resource_name
            new_resource = cls.create_from_cloudformation_json(
                resource_name=properties[resource_name_property],
                cloudformation_json=cloudformation_json,
                account_id=account_id,
                region_name=region_name,
            )
            # Stream exposes its name as `stream_name`.
            properties[resource_name_property] = original_resource.stream_name
            cls.delete_from_cloudformation_json(
                resource_name=original_resource.stream_name,
                cloudformation_json=cloudformation_json,
                account_id=account_id,
                region_name=region_name,
            )
            return new_resource

        # No Interruption
        if "ShardCount" in properties:
            original_resource.update_shard_count(properties["ShardCount"])
        if "RetentionPeriodHours" in properties:
            original_resource.retention_period_hours = properties[
                "RetentionPeriodHours"
            ]
        if "Tags" in properties:
            original_resource.tags = {
                tag_item["Key"]: tag_item["Value"]
                for tag_item in properties.get("Tags", [])
            }
        return original_resource

    @classmethod
    def delete_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Any,
        account_id: str,
        region_name: str,
    ) -> None:
        """Delete the stream named ``resource_name``."""
        backend: KinesisBackend = kinesis_backends[account_id][region_name]
        backend.delete_stream(stream_arn=None, stream_name=resource_name)

    @staticmethod
    def is_replacement_update(properties: List[str]) -> bool:
        """Return True if any property requiring replacement is present."""
        # NOTE(review): these look like S3 bucket properties rather than
        # Kinesis stream properties - confirm the intended list.
        properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
        return any(
            property_requiring_replacement in properties
            for property_requiring_replacement in properties_requiring_replacement_update
        )

    @classmethod
    def has_cfn_attr(cls, attr: str) -> bool:
        return attr in ["Arn"]

    def get_cfn_attribute(self, attribute_name: str) -> str:
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Arn":
            return self.arn
        raise UnformattedGetAttTemplateException()

    @property
    def physical_resource_id(self) -> str:
        return self.stream_name


class KinesisBackend(BaseBackend):
    """Per-account, per-region store of Kinesis streams."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self.streams: Dict[str, Stream] = OrderedDict()

    @staticmethod
    def default_vpc_endpoint_service(
        service_region: str, zones: List[str]
    ) -> List[Dict[str, str]]:
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "kinesis", special_service_name="kinesis-streams"
        )

    def create_stream(
        self,
        stream_name: str,
        shard_count: int,
        stream_mode: Optional[Dict[str, str]] = None,
        retention_period_hours: Optional[int] = None,
    ) -> Stream:
        """Create and register a stream; raise ResourceInUseError if the name is taken."""
        if stream_name in self.streams:
            raise ResourceInUseError(stream_name)
        stream = Stream(
            stream_name,
            shard_count,
            stream_mode=stream_mode,
            retention_period_hours=retention_period_hours,
            account_id=self.account_id,
            region_name=self.region_name,
        )
        self.streams[stream_name] = stream
        return stream

    def describe_stream(
        self, stream_arn: Optional[str], stream_name: Optional[str]
    ) -> Stream:
        """
        Look a stream up by name first, then by ARN.

        Raises StreamNotFoundError if neither matches.
        """
        if stream_name and stream_name in self.streams:
            return self.streams[stream_name]
        if stream_arn:
            for stream in self.streams.values():
                if stream.arn == stream_arn:
                    return stream
            # Report the name embedded in the ARN; fall back to the raw value
            # when the ARN has no "/" so we still raise the right exception.
            arn_parts = stream_arn.split("/")
            stream_name = arn_parts[1] if len(arn_parts) > 1 else stream_arn
        raise StreamNotFoundError(stream_name, self.account_id)  # type: ignore

    def describe_stream_summary(
        self, stream_arn: Optional[str], stream_name: Optional[str]
    ) -> Stream:
        return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

    def list_streams(self) -> Iterable[Stream]:
        return self.streams.values()

    def delete_stream(
        self, stream_arn: Optional[str], stream_name: Optional[str]
    ) -> Stream:
        """Remove and return the stream; raises StreamNotFoundError if unknown."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        return self.streams.pop(stream.stream_name)

    def get_shard_iterator(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        shard_id: str,
        shard_iterator_type: str,
        starting_sequence_number: int,
        at_timestamp: datetime.datetime,
    ) -> str:
        """
        Build a shard iterator for the given shard.

        Raises ResourceNotFoundError if the shard does not exist.
        """
        # Validate params
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        try:
            shard = stream.get_shard(shard_id)
        except ShardNotFoundError:
            raise ResourceNotFoundError(
                message=f"Shard {shard_id} in stream {stream.stream_name} under account {self.account_id} does not exist"
            )

        # Use the resolved name: `stream_name` is None when the caller
        # identified the stream by ARN only.
        shard_iterator = compose_new_shard_iterator(
            stream.stream_name,
            shard,
            shard_iterator_type,
            starting_sequence_number,
            at_timestamp,
        )
        return shard_iterator

    def get_records(
        self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
    ) -> Tuple[str, List[Record], int]:
        """Return (next shard iterator, records, milliseconds behind latest)."""
        decomposed = decompose_shard_iterator(shard_iterator)
        stream_name, shard_id, last_sequence_id = decomposed

        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        shard = stream.get_shard(shard_id)

        records, last_sequence_id, millis_behind_latest = shard.get_records(  # type: ignore
            last_sequence_id, limit
        )

        next_shard_iterator = compose_shard_iterator(
            stream_name, shard, last_sequence_id  # type: ignore
        )

        return next_shard_iterator, records, millis_behind_latest

    def put_record(
        self,
        stream_arn: str,
        stream_name: str,
        partition_key: str,
        explicit_hash_key: str,
        data: str,
    ) -> Tuple[str, str]:
        """Store one record; return (sequence number, shard id)."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

        sequence_number, shard_id = stream.put_record(
            partition_key, explicit_hash_key, data
        )

        return sequence_number, shard_id

    def put_records(
        self, stream_arn: str, stream_name: str, records: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Store a batch of records.

        Raises TooManyRecords for more than 500 records,
        TotalRecordsSizeExceedsLimit when the batch exceeds 5242880 bytes, and
        RecordSizeExceedsLimit when a single record exceeds 1048576 bytes.
        """
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

        response: Dict[str, Any] = {"FailedRecordCount": 0, "Records": []}

        if len(records) > 500:
            raise TooManyRecords
        # Size of a record = decoded payload plus partition key.
        data_sizes = [
            len(b64decode(r.get("Data", ""))) + len(r.get("PartitionKey", ""))
            for r in records
        ]
        if sum(data_sizes) > 5242880:
            raise TotalRecordsSizeExceedsLimit
        idx_over_limit = next(
            (idx for idx, x in enumerate(data_sizes) if x > 1048576), None
        )
        if idx_over_limit is not None:
            raise RecordSizeExceedsLimit(position=idx_over_limit + 1)

        for record in records:
            partition_key = record.get("PartitionKey")
            explicit_hash_key = record.get("ExplicitHashKey")
            data = record.get("Data")

            sequence_number, shard_id = stream.put_record(
                partition_key, explicit_hash_key, data  # type: ignore[arg-type]
            )
            response["Records"].append(
                {"SequenceNumber": sequence_number, "ShardId": shard_id}
            )

        return response

    def split_shard(
        self,
        stream_arn: str,
        stream_name: str,
        shard_to_split: str,
        new_starting_hash_key: str,
    ) -> None:
        """
        Validate the request and split the shard.

        Raises ValidationException for a malformed shard id or hash key, and
        ShardNotFoundError for an unknown shard.
        """
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

        # fullmatch: the whole value must satisfy the pattern, not just a prefix
        if not re.fullmatch("[a-zA-Z0-9_.-]+", shard_to_split):
            raise ValidationException(
                value=shard_to_split,
                position="shardToSplit",
                regex_to_match="[a-zA-Z0-9_.-]+",
            )

        if shard_to_split not in stream.shards:
            raise ShardNotFoundError(
                shard_id=shard_to_split,
                stream=stream.stream_name,
                account_id=self.account_id,
            )

        if not re.fullmatch(r"0|([1-9]\d{0,38})", new_starting_hash_key):
            raise ValidationException(
                value=new_starting_hash_key,
                position="newStartingHashKey",
                regex_to_match=r"0|([1-9]\d{0,38})",
            )

        stream.split_shard(shard_to_split, new_starting_hash_key)

    def merge_shards(
        self,
        stream_arn: str,
        stream_name: str,
        shard_to_merge: str,
        adjacent_shard_to_merge: str,
    ) -> None:
        """Merge two shards; raises ShardNotFoundError if either is unknown."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

        if shard_to_merge not in stream.shards:
            raise ShardNotFoundError(
                shard_to_merge, stream=stream.stream_name, account_id=self.account_id
            )

        if adjacent_shard_to_merge not in stream.shards:
            raise ShardNotFoundError(
                adjacent_shard_to_merge,
                stream=stream.stream_name,
                account_id=self.account_id,
            )

        stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)

    def update_shard_count(
        self, stream_arn: str, stream_name: str, target_shard_count: int
    ) -> int:
        """Resize the stream; return the open shard count from before the update."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        current_shard_count = len([s for s in stream.shards.values() if s.is_open])

        stream.update_shard_count(target_shard_count)

        return current_shard_count

    @paginate(pagination_model=PAGINATION_MODEL)  # type: ignore[misc]
    def list_shards(
        self, stream_arn: Optional[str], stream_name: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Return the stream's shards as JSON, ordered by shard id."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
        return [shard.to_json() for shard in shards]

    def increase_stream_retention_period(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        retention_period_hours: int,
    ) -> None:
        """
        Raise the retention period.

        Raises InvalidRetentionPeriod outside 24..8760 hours and
        InvalidIncreaseRetention if the value is below the current one.
        """
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        if retention_period_hours < 24:
            raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
        if retention_period_hours > 8760:
            raise InvalidRetentionPeriod(retention_period_hours, too_short=False)
        if retention_period_hours < stream.retention_period_hours:
            raise InvalidIncreaseRetention(
                name=stream.stream_name,
                requested=retention_period_hours,
                existing=stream.retention_period_hours,
            )
        stream.retention_period_hours = retention_period_hours

    def decrease_stream_retention_period(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        retention_period_hours: int,
    ) -> None:
        """
        Lower the retention period.

        Raises InvalidRetentionPeriod outside 24..8760 hours and
        InvalidDecreaseRetention if the value is above the current one.
        """
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        if retention_period_hours < 24:
            raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
        if retention_period_hours > 8760:
            raise InvalidRetentionPeriod(retention_period_hours, too_short=False)
        if retention_period_hours > stream.retention_period_hours:
            raise InvalidDecreaseRetention(
                name=stream.stream_name,
                requested=retention_period_hours,
                existing=stream.retention_period_hours,
            )
        stream.retention_period_hours = retention_period_hours

    def list_tags_for_stream(
        self,
        stream_arn: str,
        stream_name: str,
        exclusive_start_tag_key: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Return the stream's tags sorted by key.

        Only tags that sort after ``exclusive_start_tag_key`` are returned,
        and ``HasMoreTags`` is set when ``limit`` cut the listing short.
        """
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)

        tags: List[Dict[str, str]] = []
        result: Dict[str, Any] = {"HasMoreTags": False, "Tags": tags}
        for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
            if limit and len(tags) >= limit:
                result["HasMoreTags"] = True
                break
            # The start key is exclusive: skip it along with everything before it.
            if exclusive_start_tag_key and key <= exclusive_start_tag_key:
                continue

            tags.append({"Key": key, "Value": val})

        return result

    def add_tags_to_stream(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        tags: Dict[str, str],
    ) -> None:
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        stream.tags.update(tags)

    def remove_tags_from_stream(
        self, stream_arn: Optional[str], stream_name: Optional[str], tag_keys: List[str]
    ) -> None:
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        for key in tag_keys:
            # Unknown keys are ignored.
            stream.tags.pop(key, None)

    def enable_enhanced_monitoring(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        shard_level_metrics: List[str],
    ) -> Tuple[str, str, List[str], List[str]]:
        """Add metrics; return (arn, name, previous metrics, new metrics)."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        current_shard_level_metrics = stream.shard_level_metrics
        desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
        stream.shard_level_metrics = desired_metrics
        return (
            stream.arn,
            stream.stream_name,
            current_shard_level_metrics,
            desired_metrics,
        )

    def disable_enhanced_monitoring(
        self,
        stream_arn: Optional[str],
        stream_name: Optional[str],
        to_be_disabled: List[str],
    ) -> Tuple[str, str, List[str], List[str]]:
        """Remove metrics ("ALL" clears them); return (arn, name, previous, new)."""
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        current_metrics = stream.shard_level_metrics
        if "ALL" in to_be_disabled:
            desired_metrics = []
        else:
            desired_metrics = [
                metric for metric in current_metrics if metric not in to_be_disabled
            ]
        stream.shard_level_metrics = desired_metrics
        return stream.arn, stream.stream_name, current_metrics, desired_metrics

    def _find_stream_by_arn(self, stream_arn: str) -> Stream:
        """Return the stream with this ARN, or raise StreamNotFoundError."""
        return self.describe_stream(stream_arn=stream_arn, stream_name=None)

    def list_stream_consumers(self, stream_arn: str) -> List[Consumer]:
        """
        Pagination is not yet implemented
        """
        stream = self._find_stream_by_arn(stream_arn)
        return stream.consumers

    def register_stream_consumer(self, stream_arn: str, consumer_name: str) -> Consumer:
        """Create a consumer and attach it to the stream identified by ARN."""
        stream = self._find_stream_by_arn(stream_arn)
        consumer = Consumer(
            consumer_name, self.account_id, self.region_name, stream_arn
        )
        stream.consumers.append(consumer)
        return consumer

    def describe_stream_consumer(
        self, stream_arn: str, consumer_name: str, consumer_arn: str
    ) -> Consumer:
        """
        Find a consumer by (stream ARN + name), or by consumer ARN.

        Raises ConsumerNotFound if no consumer matches.
        """
        if stream_arn:
            stream = self._find_stream_by_arn(stream_arn)
            for consumer in stream.consumers:
                if consumer_name and consumer.consumer_name == consumer_name:
                    return consumer
        if consumer_arn:
            for stream in self.streams.values():
                _consumer = stream.get_consumer_by_arn(consumer_arn)
                if _consumer:
                    return _consumer
        raise ConsumerNotFound(
            consumer=consumer_name or consumer_arn, account_id=self.account_id
        )

    def deregister_stream_consumer(
        self, stream_arn: str, consumer_name: str, consumer_arn: str
    ) -> None:
        """Remove a consumer identified by (stream ARN + name) or by consumer ARN."""
        if stream_arn:
            stream = self._find_stream_by_arn(stream_arn)
            # Keep every consumer except the one being deregistered.
            stream.consumers = [
                c for c in stream.consumers if c.consumer_name != consumer_name
            ]
        if consumer_arn:
            for stream in self.streams.values():
                # Only one stream will actually have this consumer
                # It will be a noop for other streams
                stream.delete_consumer(consumer_arn)

    def start_stream_encryption(
        self, stream_arn: str, stream_name: str, encryption_type: str, key_id: str
    ) -> None:
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        stream.encryption_type = encryption_type
        stream.key_id = key_id

    def stop_stream_encryption(self, stream_arn: str, stream_name: str) -> None:
        stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
        stream.encryption_type = "NONE"
        stream.key_id = None

    def update_stream_mode(self, stream_arn: str, stream_mode: Dict[str, str]) -> None:
        stream = self._find_stream_by_arn(stream_arn)
        stream.stream_mode = stream_mode

    def send_log_event(
        self,
        delivery_stream_arn: str,
        filter_name: str,
        log_group_name: str,
        log_stream_name: str,
        log_events: List[Dict[str, Any]],
    ) -> None:
        """Send log events to a Stream after encoding and gzipping it."""
        data = {
            "logEvents": log_events,
            "logGroup": log_group_name,
            "logStream": log_stream_name,
            "messageType": "DATA_MESSAGE",
            "owner": self.account_id,
            "subscriptionFilters": [filter_name],
        }

        output = io.BytesIO()
        with GzipFile(fileobj=output, mode="w") as fhandle:
            fhandle.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
        gzipped_payload = b64encode(output.getvalue()).decode("UTF-8")

        stream = self.describe_stream(stream_arn=delivery_stream_arn, stream_name=None)
        # A random partition key spreads the events across the shards.
        random_partition_key = random.get_random_string(length=32, lower_case=True)
        stream.put_record(
            partition_key=random_partition_key,
            data=gzipped_payload,
            explicit_hash_key="",
        )


kinesis_backends = BackendDict(KinesisBackend, "kinesis")