Kinesis - support Stream ARNs across all methods (#5893)

This commit is contained in:
Bert Blommers 2023-02-01 15:16:25 -01:00 committed by GitHub
parent 67ecc3b1db
commit 19bfa92dd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 314 additions and 109 deletions

View File

@ -4,6 +4,7 @@ import re
import itertools import itertools
from operator import attrgetter from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -439,12 +440,14 @@ class Stream(CloudFormationModel):
for tag_item in properties.get("Tags", []) for tag_item in properties.get("Tags", [])
} }
backend = kinesis_backends[account_id][region_name] backend: KinesisBackend = kinesis_backends[account_id][region_name]
stream = backend.create_stream( stream = backend.create_stream(
resource_name, shard_count, retention_period_hours=retention_period_hours resource_name, shard_count, retention_period_hours=retention_period_hours
) )
if any(tags): if any(tags):
backend.add_tags_to_stream(stream.stream_name, tags) backend.add_tags_to_stream(
stream_arn=None, stream_name=stream.stream_name, tags=tags
)
return stream return stream
@classmethod @classmethod
@ -489,8 +492,8 @@ class Stream(CloudFormationModel):
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name cls, resource_name, cloudformation_json, account_id, region_name
): ):
backend = kinesis_backends[account_id][region_name] backend: KinesisBackend = kinesis_backends[account_id][region_name]
backend.delete_stream(resource_name) backend.delete_stream(stream_arn=None, stream_name=resource_name)
@staticmethod @staticmethod
def is_replacement_update(properties): def is_replacement_update(properties):
@ -521,7 +524,7 @@ class Stream(CloudFormationModel):
class KinesisBackend(BaseBackend): class KinesisBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name, account_id):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.streams = OrderedDict() self.streams: Dict[str, Stream] = OrderedDict()
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region, zones):
@ -546,38 +549,49 @@ class KinesisBackend(BaseBackend):
self.streams[stream_name] = stream self.streams[stream_name] = stream
return stream return stream
def describe_stream(self, stream_name) -> Stream: def describe_stream(
if stream_name in self.streams: self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
if stream_name and stream_name in self.streams:
return self.streams[stream_name] return self.streams[stream_name]
else: if stream_arn:
raise StreamNotFoundError(stream_name, self.account_id) for stream in self.streams.values():
if stream.arn == stream_arn:
return stream
if stream_arn:
stream_name = stream_arn.split("/")[1]
raise StreamNotFoundError(stream_name, self.account_id)
def describe_stream_summary(self, stream_name): def describe_stream_summary(
return self.describe_stream(stream_name) self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
def list_streams(self): def list_streams(self):
return self.streams.values() return self.streams.values()
def delete_stream(self, stream_name): def delete_stream(
if stream_name in self.streams: self, stream_arn: Optional[str], stream_name: Optional[str]
return self.streams.pop(stream_name) ) -> Stream:
raise StreamNotFoundError(stream_name, self.account_id) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
return self.streams.pop(stream.stream_name)
def get_shard_iterator( def get_shard_iterator(
self, self,
stream_name, stream_arn: Optional[str],
shard_id, stream_name: Optional[str],
shard_iterator_type, shard_id: str,
starting_sequence_number, shard_iterator_type: str,
at_timestamp, starting_sequence_number: str,
at_timestamp: str,
): ):
# Validate params # Validate params
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
try: try:
shard = stream.get_shard(shard_id) shard = stream.get_shard(shard_id)
except ShardNotFoundError: except ShardNotFoundError:
raise ResourceNotFoundError( raise ResourceNotFoundError(
message=f"Shard {shard_id} in stream {stream_name} under account {self.account_id} does not exist" message=f"Shard {shard_id} in stream {stream.stream_name} under account {self.account_id} does not exist"
) )
shard_iterator = compose_new_shard_iterator( shard_iterator = compose_new_shard_iterator(
@ -589,11 +603,13 @@ class KinesisBackend(BaseBackend):
) )
return shard_iterator return shard_iterator
def get_records(self, shard_iterator, limit): def get_records(
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
):
decomposed = decompose_shard_iterator(shard_iterator) decomposed = decompose_shard_iterator(shard_iterator)
stream_name, shard_id, last_sequence_id = decomposed stream_name, shard_id, last_sequence_id = decomposed
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shard = stream.get_shard(shard_id) shard = stream.get_shard(shard_id)
records, last_sequence_id, millis_behind_latest = shard.get_records( records, last_sequence_id, millis_behind_latest = shard.get_records(
@ -608,12 +624,13 @@ class KinesisBackend(BaseBackend):
def put_record( def put_record(
self, self,
stream_arn,
stream_name, stream_name,
partition_key, partition_key,
explicit_hash_key, explicit_hash_key,
data, data,
): ):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
sequence_number, shard_id = stream.put_record( sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, data partition_key, explicit_hash_key, data
@ -621,8 +638,8 @@ class KinesisBackend(BaseBackend):
return sequence_number, shard_id return sequence_number, shard_id
def put_records(self, stream_name, records): def put_records(self, stream_arn, stream_name, records):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
response = {"FailedRecordCount": 0, "Records": []} response = {"FailedRecordCount": 0, "Records": []}
@ -651,8 +668,10 @@ class KinesisBackend(BaseBackend):
return response return response
def split_shard(self, stream_name, shard_to_split, new_starting_hash_key): def split_shard(
stream = self.describe_stream(stream_name) self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if not re.match("[a-zA-Z0-9_.-]+", shard_to_split): if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
raise ValidationException( raise ValidationException(
@ -675,23 +694,27 @@ class KinesisBackend(BaseBackend):
stream.split_shard(shard_to_split, new_starting_hash_key) stream.split_shard(shard_to_split, new_starting_hash_key)
def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge): def merge_shards(
stream = self.describe_stream(stream_name) self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if shard_to_merge not in stream.shards: if shard_to_merge not in stream.shards:
raise ShardNotFoundError( raise ShardNotFoundError(
shard_to_merge, stream=stream_name, account_id=self.account_id shard_to_merge, stream=stream.stream_name, account_id=self.account_id
) )
if adjacent_shard_to_merge not in stream.shards: if adjacent_shard_to_merge not in stream.shards:
raise ShardNotFoundError( raise ShardNotFoundError(
adjacent_shard_to_merge, stream=stream_name, account_id=self.account_id adjacent_shard_to_merge,
stream=stream.stream_name,
account_id=self.account_id,
) )
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge) stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
def update_shard_count(self, stream_name, target_shard_count): def update_shard_count(self, stream_arn, stream_name, target_shard_count):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open]) current_shard_count = len([s for s in stream.shards.values() if s.is_open])
stream.update_shard_count(target_shard_count) stream.update_shard_count(target_shard_count)
@ -699,13 +722,18 @@ class KinesisBackend(BaseBackend):
return current_shard_count return current_shard_count
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_name): def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id) shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
return [shard.to_json() for shard in shards] return [shard.to_json() for shard in shards]
def increase_stream_retention_period(self, stream_name, retention_period_hours): def increase_stream_retention_period(
stream = self.describe_stream(stream_name) self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24: if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True) raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760: if retention_period_hours > 8760:
@ -718,8 +746,13 @@ class KinesisBackend(BaseBackend):
) )
stream.retention_period_hours = retention_period_hours stream.retention_period_hours = retention_period_hours
def decrease_stream_retention_period(self, stream_name, retention_period_hours): def decrease_stream_retention_period(
stream = self.describe_stream(stream_name) self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24: if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True) raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760: if retention_period_hours > 8760:
@ -733,9 +766,9 @@ class KinesisBackend(BaseBackend):
stream.retention_period_hours = retention_period_hours stream.retention_period_hours = retention_period_hours
def list_tags_for_stream( def list_tags_for_stream(
self, stream_name, exclusive_start_tag_key=None, limit=None self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
): ):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
tags = [] tags = []
result = {"HasMoreTags": False, "Tags": tags} result = {"HasMoreTags": False, "Tags": tags}
@ -750,25 +783,47 @@ class KinesisBackend(BaseBackend):
return result return result
def add_tags_to_stream(self, stream_name, tags): def add_tags_to_stream(
stream = self.describe_stream(stream_name) self,
stream_arn: Optional[str],
stream_name: Optional[str],
tags: Dict[str, str],
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.tags.update(tags) stream.tags.update(tags)
def remove_tags_from_stream(self, stream_name, tag_keys): def remove_tags_from_stream(
stream = self.describe_stream(stream_name) self, stream_arn: Optional[str], stream_name: Optional[str], tag_keys: List[str]
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
for key in tag_keys: for key in tag_keys:
if key in stream.tags: if key in stream.tags:
del stream.tags[key] del stream.tags[key]
def enable_enhanced_monitoring(self, stream_name, shard_level_metrics): def enable_enhanced_monitoring(
stream = self.describe_stream(stream_name) self,
stream_arn: Optional[str],
stream_name: Optional[str],
shard_level_metrics: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_level_metrics = stream.shard_level_metrics current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics)) desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
stream.shard_level_metrics = desired_metrics stream.shard_level_metrics = desired_metrics
return current_shard_level_metrics, desired_metrics return (
stream.arn,
stream.stream_name,
current_shard_level_metrics,
desired_metrics,
)
def disable_enhanced_monitoring(self, stream_name, to_be_disabled): def disable_enhanced_monitoring(
stream = self.describe_stream(stream_name) self,
stream_arn: Optional[str],
stream_name: Optional[str],
to_be_disabled: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_metrics = stream.shard_level_metrics current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled: if "ALL" in to_be_disabled:
desired_metrics = [] desired_metrics = []
@ -777,7 +832,7 @@ class KinesisBackend(BaseBackend):
metric for metric in current_metrics if metric not in to_be_disabled metric for metric in current_metrics if metric not in to_be_disabled
] ]
stream.shard_level_metrics = desired_metrics stream.shard_level_metrics = desired_metrics
return current_metrics, desired_metrics return stream.arn, stream.stream_name, current_metrics, desired_metrics
def _find_stream_by_arn(self, stream_arn): def _find_stream_by_arn(self, stream_arn):
for stream in self.streams.values(): for stream in self.streams.values():
@ -826,13 +881,13 @@ class KinesisBackend(BaseBackend):
# It will be a noop for other streams # It will be a noop for other streams
stream.delete_consumer(consumer_arn) stream.delete_consumer(consumer_arn)
def start_stream_encryption(self, stream_name, encryption_type, key_id): def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = encryption_type stream.encryption_type = encryption_type
stream.key_id = key_id stream.key_id = key_id
def stop_stream_encryption(self, stream_name): def stop_stream_encryption(self, stream_arn, stream_name):
stream = self.describe_stream(stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = "NONE" stream.encryption_type = "NONE"
stream.key_id = None stream.key_id = None

View File

@ -1,7 +1,7 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kinesis_backends from .models import kinesis_backends, KinesisBackend
class KinesisResponse(BaseResponse): class KinesisResponse(BaseResponse):
@ -13,7 +13,7 @@ class KinesisResponse(BaseResponse):
return json.loads(self.body) return json.loads(self.body)
@property @property
def kinesis_backend(self): def kinesis_backend(self) -> KinesisBackend:
return kinesis_backends[self.current_account][self.region] return kinesis_backends[self.current_account][self.region]
def create_stream(self): def create_stream(self):
@ -27,13 +27,15 @@ class KinesisResponse(BaseResponse):
def describe_stream(self): def describe_stream(self):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
stream_arn = self.parameters.get("StreamARN")
limit = self.parameters.get("Limit") limit = self.parameters.get("Limit")
stream = self.kinesis_backend.describe_stream(stream_name) stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
return json.dumps(stream.to_json(shard_limit=limit)) return json.dumps(stream.to_json(shard_limit=limit))
def describe_stream_summary(self): def describe_stream_summary(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
stream = self.kinesis_backend.describe_stream_summary(stream_name) stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
return json.dumps(stream.to_json_summary()) return json.dumps(stream.to_json_summary())
def list_streams(self): def list_streams(self):
@ -58,11 +60,13 @@ class KinesisResponse(BaseResponse):
) )
def delete_stream(self): def delete_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
self.kinesis_backend.delete_stream(stream_name) self.kinesis_backend.delete_stream(stream_arn, stream_name)
return "" return ""
def get_shard_iterator(self): def get_shard_iterator(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_id = self.parameters.get("ShardId") shard_id = self.parameters.get("ShardId")
shard_iterator_type = self.parameters.get("ShardIteratorType") shard_iterator_type = self.parameters.get("ShardIteratorType")
@ -70,6 +74,7 @@ class KinesisResponse(BaseResponse):
at_timestamp = self.parameters.get("Timestamp") at_timestamp = self.parameters.get("Timestamp")
shard_iterator = self.kinesis_backend.get_shard_iterator( shard_iterator = self.kinesis_backend.get_shard_iterator(
stream_arn,
stream_name, stream_name,
shard_id, shard_id,
shard_iterator_type, shard_iterator_type,
@ -80,6 +85,7 @@ class KinesisResponse(BaseResponse):
return json.dumps({"ShardIterator": shard_iterator}) return json.dumps({"ShardIterator": shard_iterator})
def get_records(self): def get_records(self):
stream_arn = self.parameters.get("StreamARN")
shard_iterator = self.parameters.get("ShardIterator") shard_iterator = self.parameters.get("ShardIterator")
limit = self.parameters.get("Limit") limit = self.parameters.get("Limit")
@ -87,7 +93,7 @@ class KinesisResponse(BaseResponse):
next_shard_iterator, next_shard_iterator,
records, records,
millis_behind_latest, millis_behind_latest,
) = self.kinesis_backend.get_records(shard_iterator, limit) ) = self.kinesis_backend.get_records(stream_arn, shard_iterator, limit)
return json.dumps( return json.dumps(
{ {
@ -98,12 +104,14 @@ class KinesisResponse(BaseResponse):
) )
def put_record(self): def put_record(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
partition_key = self.parameters.get("PartitionKey") partition_key = self.parameters.get("PartitionKey")
explicit_hash_key = self.parameters.get("ExplicitHashKey") explicit_hash_key = self.parameters.get("ExplicitHashKey")
data = self.parameters.get("Data") data = self.parameters.get("Data")
sequence_number, shard_id = self.kinesis_backend.put_record( sequence_number, shard_id = self.kinesis_backend.put_record(
stream_arn,
stream_name, stream_name,
partition_key, partition_key,
explicit_hash_key, explicit_hash_key,
@ -113,37 +121,44 @@ class KinesisResponse(BaseResponse):
return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id}) return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
def put_records(self): def put_records(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
records = self.parameters.get("Records") records = self.parameters.get("Records")
response = self.kinesis_backend.put_records(stream_name, records) response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
return json.dumps(response) return json.dumps(response)
def split_shard(self): def split_shard(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_to_split = self.parameters.get("ShardToSplit") shard_to_split = self.parameters.get("ShardToSplit")
new_starting_hash_key = self.parameters.get("NewStartingHashKey") new_starting_hash_key = self.parameters.get("NewStartingHashKey")
self.kinesis_backend.split_shard( self.kinesis_backend.split_shard(
stream_name, shard_to_split, new_starting_hash_key stream_arn, stream_name, shard_to_split, new_starting_hash_key
) )
return "" return ""
def merge_shards(self): def merge_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_to_merge = self.parameters.get("ShardToMerge") shard_to_merge = self.parameters.get("ShardToMerge")
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge") adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
self.kinesis_backend.merge_shards( self.kinesis_backend.merge_shards(
stream_name, shard_to_merge, adjacent_shard_to_merge stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
) )
return "" return ""
def list_shards(self): def list_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
next_token = self.parameters.get("NextToken") next_token = self.parameters.get("NextToken")
max_results = self.parameters.get("MaxResults", 10000) max_results = self.parameters.get("MaxResults", 10000)
shards, token = self.kinesis_backend.list_shards( shards, token = self.kinesis_backend.list_shards(
stream_name=stream_name, limit=max_results, next_token=next_token stream_arn=stream_arn,
stream_name=stream_name,
limit=max_results,
next_token=next_token,
) )
res = {"Shards": shards} res = {"Shards": shards}
if token: if token:
@ -151,10 +166,13 @@ class KinesisResponse(BaseResponse):
return json.dumps(res) return json.dumps(res)
def update_shard_count(self): def update_shard_count(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
target_shard_count = self.parameters.get("TargetShardCount") target_shard_count = self.parameters.get("TargetShardCount")
current_shard_count = self.kinesis_backend.update_shard_count( current_shard_count = self.kinesis_backend.update_shard_count(
stream_name=stream_name, target_shard_count=target_shard_count stream_arn=stream_arn,
stream_name=stream_name,
target_shard_count=target_shard_count,
) )
return json.dumps( return json.dumps(
dict( dict(
@ -165,67 +183,80 @@ class KinesisResponse(BaseResponse):
) )
def increase_stream_retention_period(self): def increase_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours") retention_period_hours = self.parameters.get("RetentionPeriodHours")
self.kinesis_backend.increase_stream_retention_period( self.kinesis_backend.increase_stream_retention_period(
stream_name, retention_period_hours stream_arn, stream_name, retention_period_hours
) )
return "" return ""
def decrease_stream_retention_period(self): def decrease_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours") retention_period_hours = self.parameters.get("RetentionPeriodHours")
self.kinesis_backend.decrease_stream_retention_period( self.kinesis_backend.decrease_stream_retention_period(
stream_name, retention_period_hours stream_arn, stream_name, retention_period_hours
) )
return "" return ""
def add_tags_to_stream(self): def add_tags_to_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
tags = self.parameters.get("Tags") tags = self.parameters.get("Tags")
self.kinesis_backend.add_tags_to_stream(stream_name, tags) self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags)
return json.dumps({}) return json.dumps({})
def list_tags_for_stream(self): def list_tags_for_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey")
limit = self.parameters.get("Limit") limit = self.parameters.get("Limit")
response = self.kinesis_backend.list_tags_for_stream( response = self.kinesis_backend.list_tags_for_stream(
stream_name, exclusive_start_tag_key, limit stream_arn, stream_name, exclusive_start_tag_key, limit
) )
return json.dumps(response) return json.dumps(response)
def remove_tags_from_stream(self): def remove_tags_from_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
tag_keys = self.parameters.get("TagKeys") tag_keys = self.parameters.get("TagKeys")
self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys) self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys)
return json.dumps({}) return json.dumps({})
def enable_enhanced_monitoring(self): def enable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics") shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.enable_enhanced_monitoring( arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_name=stream_name, shard_level_metrics=shard_level_metrics stream_arn=stream_arn,
stream_name=stream_name,
shard_level_metrics=shard_level_metrics,
) )
return json.dumps( return json.dumps(
dict( dict(
StreamName=stream_name, StreamName=name,
CurrentShardLevelMetrics=current, CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired, DesiredShardLevelMetrics=desired,
StreamARN=arn,
) )
) )
def disable_enhanced_monitoring(self): def disable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics") shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.disable_enhanced_monitoring( arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_name=stream_name, to_be_disabled=shard_level_metrics stream_arn=stream_arn,
stream_name=stream_name,
to_be_disabled=shard_level_metrics,
) )
return json.dumps( return json.dumps(
dict( dict(
StreamName=stream_name, StreamName=name,
CurrentShardLevelMetrics=current, CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired, DesiredShardLevelMetrics=desired,
StreamARN=arn,
) )
) )
@ -267,17 +298,24 @@ class KinesisResponse(BaseResponse):
return json.dumps(dict()) return json.dumps(dict())
def start_stream_encryption(self): def start_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
encryption_type = self.parameters.get("EncryptionType") encryption_type = self.parameters.get("EncryptionType")
key_id = self.parameters.get("KeyId") key_id = self.parameters.get("KeyId")
self.kinesis_backend.start_stream_encryption( self.kinesis_backend.start_stream_encryption(
stream_name=stream_name, encryption_type=encryption_type, key_id=key_id stream_arn=stream_arn,
stream_name=stream_name,
encryption_type=encryption_type,
key_id=key_id,
) )
return json.dumps(dict()) return json.dumps(dict())
def stop_stream_encryption(self): def stop_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
self.kinesis_backend.stop_stream_encryption(stream_name=stream_name) self.kinesis_backend.stop_stream_encryption(
stream_arn=stream_arn, stream_name=stream_name
)
return json.dumps(dict()) return json.dumps(dict())
def update_stream_mode(self): def update_stream_mode(self):

View File

@ -6,6 +6,8 @@ url_bases = [
# Somewhere around boto3-1.26.31 botocore-1.29.31, AWS started using a new endpoint: # Somewhere around boto3-1.26.31 botocore-1.29.31, AWS started using a new endpoint:
# 111122223333.control-kinesis.us-east-1.amazonaws.com # 111122223333.control-kinesis.us-east-1.amazonaws.com
r"https?://(.+)\.control-kinesis\.(.+)\.amazonaws\.com", r"https?://(.+)\.control-kinesis\.(.+)\.amazonaws\.com",
# When passing in the StreamARN to get_shard_iterator/get_records, this endpoint is called:
r"https?://(.+)\.data-kinesis\.(.+)\.amazonaws\.com",
] ]
url_paths = {"{0}/$": KinesisResponse.dispatch} url_paths = {"{0}/$": KinesisResponse.dispatch}

View File

@ -19,15 +19,17 @@ def test_stream_creation_on_demand():
client.create_stream( client.create_stream(
StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"} StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
) )
# At the same time, test whether we can pass the StreamARN instead of the name
stream_arn = get_stream_arn(client, "my_stream")
# AWS starts with 4 shards by default # AWS starts with 4 shards by default
shard_list = client.list_shards(StreamName="my_stream")["Shards"] shard_list = client.list_shards(StreamARN=stream_arn)["Shards"]
shard_list.should.have.length_of(4) shard_list.should.have.length_of(4)
# Cannot update-shard-count when we're in on-demand mode # Cannot update-shard-count when we're in on-demand mode
with pytest.raises(ClientError) as exc: with pytest.raises(ClientError) as exc:
client.update_shard_count( client.update_shard_count(
StreamName="my_stream", TargetShardCount=3, ScalingType="UNIFORM_SCALING" StreamARN=stream_arn, TargetShardCount=3, ScalingType="UNIFORM_SCALING"
) )
err = exc.value.response["Error"] err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException") err["Code"].should.equal("ValidationException")
@ -39,7 +41,7 @@ def test_stream_creation_on_demand():
@mock_kinesis @mock_kinesis
def test_update_stream_mode(): def test_update_stream_mode():
client = boto3.client("kinesis", region_name="eu-west-1") client = boto3.client("kinesis", region_name="eu-west-1")
resp = client.create_stream( client.create_stream(
StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"} StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
) )
arn = client.describe_stream(StreamName="my_stream")["StreamDescription"][ arn = client.describe_stream(StreamName="my_stream")["StreamDescription"][
@ -56,7 +58,7 @@ def test_update_stream_mode():
@mock_kinesis @mock_kinesis
def test_describe_non_existent_stream_boto3(): def test_describe_non_existent_stream():
client = boto3.client("kinesis", region_name="us-west-2") client = boto3.client("kinesis", region_name="us-west-2")
with pytest.raises(ClientError) as exc: with pytest.raises(ClientError) as exc:
client.describe_stream_summary(StreamName="not-a-stream") client.describe_stream_summary(StreamName="not-a-stream")
@ -68,7 +70,7 @@ def test_describe_non_existent_stream_boto3():
@mock_kinesis @mock_kinesis
def test_list_and_delete_stream_boto3(): def test_list_and_delete_stream():
client = boto3.client("kinesis", region_name="us-west-2") client = boto3.client("kinesis", region_name="us-west-2")
client.list_streams()["StreamNames"].should.have.length_of(0) client.list_streams()["StreamNames"].should.have.length_of(0)
@ -79,6 +81,10 @@ def test_list_and_delete_stream_boto3():
client.delete_stream(StreamName="stream1") client.delete_stream(StreamName="stream1")
client.list_streams()["StreamNames"].should.have.length_of(1) client.list_streams()["StreamNames"].should.have.length_of(1)
stream_arn = get_stream_arn(client, "stream2")
client.delete_stream(StreamARN=stream_arn)
client.list_streams()["StreamNames"].should.have.length_of(0)
@mock_kinesis @mock_kinesis
def test_delete_unknown_stream(): def test_delete_unknown_stream():
@ -128,9 +134,15 @@ def test_describe_stream_summary():
) )
stream["StreamStatus"].should.equal("ACTIVE") stream["StreamStatus"].should.equal("ACTIVE")
stream_arn = get_stream_arn(conn, stream_name)
resp = conn.describe_stream_summary(StreamARN=stream_arn)
stream = resp["StreamDescriptionSummary"]
stream["StreamName"].should.equal(stream_name)
@mock_kinesis @mock_kinesis
def test_basic_shard_iterator_boto3(): def test_basic_shard_iterator():
client = boto3.client("kinesis", region_name="us-west-1") client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream" stream_name = "mystream"
@ -149,7 +161,30 @@ def test_basic_shard_iterator_boto3():
@mock_kinesis @mock_kinesis
def test_get_invalid_shard_iterator_boto3(): def test_basic_shard_iterator_by_stream_arn():
client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream"
client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shard_id = stream["Shards"][0]["ShardId"]
resp = client.get_shard_iterator(
StreamARN=stream["StreamARN"],
ShardId=shard_id,
ShardIteratorType="TRIM_HORIZON",
)
shard_iterator = resp["ShardIterator"]
resp = client.get_records(
StreamARN=stream["StreamARN"], ShardIterator=shard_iterator
)
resp.should.have.key("Records").length_of(0)
resp.should.have.key("MillisBehindLatest").equal(0)
@mock_kinesis
def test_get_invalid_shard_iterator():
client = boto3.client("kinesis", region_name="us-west-1") client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream" stream_name = "mystream"
@ -169,21 +204,22 @@ def test_get_invalid_shard_iterator_boto3():
@mock_kinesis @mock_kinesis
def test_put_records_boto3(): def test_put_records():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
stream_arn = stream["StreamARN"]
shard_id = stream["Shards"][0]["ShardId"] shard_id = stream["Shards"][0]["ShardId"]
data = b"hello world" data = b"hello world"
partition_key = "1234" partition_key = "1234"
response = client.put_record( client.put_records(
StreamName=stream_name, Data=data, PartitionKey=partition_key Records=[{"Data": data, "PartitionKey": partition_key}] * 5,
StreamARN=stream_arn,
) )
response["SequenceNumber"].should.equal("1")
resp = client.get_shard_iterator( resp = client.get_shard_iterator(
StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON" StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@ -191,27 +227,28 @@ def test_put_records_boto3():
shard_iterator = resp["ShardIterator"] shard_iterator = resp["ShardIterator"]
resp = client.get_records(ShardIterator=shard_iterator) resp = client.get_records(ShardIterator=shard_iterator)
resp["Records"].should.have.length_of(1) resp["Records"].should.have.length_of(5)
record = resp["Records"][0] record = resp["Records"][0]
record["Data"].should.equal(b"hello world") record["Data"].should.equal(data)
record["PartitionKey"].should.equal("1234") record["PartitionKey"].should.equal(partition_key)
record["SequenceNumber"].should.equal("1") record["SequenceNumber"].should.equal("1")
@mock_kinesis @mock_kinesis
def test_get_records_limit_boto3(): def test_get_records_limit():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
stream_arn = stream["StreamARN"]
shard_id = stream["Shards"][0]["ShardId"] shard_id = stream["Shards"][0]["ShardId"]
data = b"hello world" data = b"hello world"
for index in range(5): for index in range(5):
client.put_record(StreamName=stream_name, Data=data, PartitionKey=str(index)) client.put_record(StreamARN=stream_arn, Data=data, PartitionKey=str(index))
resp = client.get_shard_iterator( resp = client.get_shard_iterator(
StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON" StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@ -229,7 +266,7 @@ def test_get_records_limit_boto3():
@mock_kinesis @mock_kinesis
def test_get_records_at_sequence_number_boto3(): def test_get_records_at_sequence_number():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
@ -268,7 +305,7 @@ def test_get_records_at_sequence_number_boto3():
@mock_kinesis @mock_kinesis
def test_get_records_after_sequence_number_boto3(): def test_get_records_after_sequence_number():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
@ -308,7 +345,7 @@ def test_get_records_after_sequence_number_boto3():
@mock_kinesis @mock_kinesis
def test_get_records_latest_boto3(): def test_get_records_latest():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
@ -607,6 +644,7 @@ def test_valid_decrease_stream_retention_period():
conn = boto3.client("kinesis", region_name="us-west-2") conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "decrease_stream" stream_name = "decrease_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1) conn.create_stream(StreamName=stream_name, ShardCount=1)
stream_arn = get_stream_arn(conn, stream_name)
conn.increase_stream_retention_period( conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=30 StreamName=stream_name, RetentionPeriodHours=30
@ -618,6 +656,12 @@ def test_valid_decrease_stream_retention_period():
response = conn.describe_stream(StreamName=stream_name) response = conn.describe_stream(StreamName=stream_name)
response["StreamDescription"]["RetentionPeriodHours"].should.equal(25) response["StreamDescription"]["RetentionPeriodHours"].should.equal(25)
conn.increase_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=29)
conn.decrease_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=26)
response = conn.describe_stream(StreamARN=stream_arn)
response["StreamDescription"]["RetentionPeriodHours"].should.equal(26)
@mock_kinesis @mock_kinesis
def test_decrease_stream_retention_period_upwards(): def test_decrease_stream_retention_period_upwards():
@ -671,7 +715,7 @@ def test_decrease_stream_retention_period_too_high():
@mock_kinesis @mock_kinesis
def test_invalid_shard_iterator_type_boto3(): def test_invalid_shard_iterator_type():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
@ -688,10 +732,11 @@ def test_invalid_shard_iterator_type_boto3():
@mock_kinesis @mock_kinesis
def test_add_list_remove_tags_boto3(): def test_add_list_remove_tags():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1) client.create_stream(StreamName=stream_name, ShardCount=1)
stream_arn = get_stream_arn(client, stream_name)
client.add_tags_to_stream( client.add_tags_to_stream(
StreamName=stream_name, StreamName=stream_name,
Tags={"tag1": "val1", "tag2": "val2", "tag3": "val3", "tag4": "val4"}, Tags={"tag1": "val1", "tag2": "val2", "tag3": "val3", "tag4": "val4"},
@ -704,9 +749,9 @@ def test_add_list_remove_tags_boto3():
tags.should.contain({"Key": "tag3", "Value": "val3"}) tags.should.contain({"Key": "tag3", "Value": "val3"})
tags.should.contain({"Key": "tag4", "Value": "val4"}) tags.should.contain({"Key": "tag4", "Value": "val4"})
client.add_tags_to_stream(StreamName=stream_name, Tags={"tag5": "val5"}) client.add_tags_to_stream(StreamARN=stream_arn, Tags={"tag5": "val5"})
tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"] tags = client.list_tags_for_stream(StreamARN=stream_arn)["Tags"]
tags.should.have.length_of(5) tags.should.have.length_of(5)
tags.should.contain({"Key": "tag5", "Value": "val5"}) tags.should.contain({"Key": "tag5", "Value": "val5"})
@ -718,19 +763,33 @@ def test_add_list_remove_tags_boto3():
tags.should.contain({"Key": "tag4", "Value": "val4"}) tags.should.contain({"Key": "tag4", "Value": "val4"})
tags.should.contain({"Key": "tag5", "Value": "val5"}) tags.should.contain({"Key": "tag5", "Value": "val5"})
client.remove_tags_from_stream(StreamARN=stream_arn, TagKeys=["tag4"])
tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"]
tags.should.have.length_of(2)
tags.should.contain({"Key": "tag1", "Value": "val1"})
tags.should.contain({"Key": "tag5", "Value": "val5"})
@mock_kinesis @mock_kinesis
def test_merge_shards_boto3(): def test_merge_shards():
client = boto3.client("kinesis", region_name="eu-west-2") client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4) client.create_stream(StreamName=stream_name, ShardCount=4)
stream_arn = get_stream_arn(client, stream_name)
for index in range(1, 100): for index in range(1, 50):
client.put_record( client.put_record(
StreamName=stream_name, StreamName=stream_name,
Data=f"data_{index}".encode("utf-8"), Data=f"data_{index}".encode("utf-8"),
PartitionKey=str(index), PartitionKey=str(index),
) )
for index in range(51, 100):
client.put_record(
StreamARN=stream_arn,
Data=f"data_{index}".encode("utf-8"),
PartitionKey=str(index),
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"] stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shards = stream["Shards"] shards = stream["Shards"]
@ -757,7 +816,7 @@ def test_merge_shards_boto3():
active_shards.should.have.length_of(3) active_shards.should.have.length_of(3)
client.merge_shards( client.merge_shards(
StreamName=stream_name, StreamARN=stream_arn,
ShardToMerge="shardId-000000000004", ShardToMerge="shardId-000000000004",
AdjacentShardToMerge="shardId-000000000002", AdjacentShardToMerge="shardId-000000000002",
) )
@ -804,3 +863,9 @@ def test_merge_shards_invalid_arg():
err = exc.value.response["Error"] err = exc.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException") err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal("shardId-000000000002") err["Message"].should.equal("shardId-000000000002")
def get_stream_arn(client, stream_name):
return client.describe_stream(StreamName=stream_name)["StreamDescription"][
"StreamARN"
]

View File

@ -4,6 +4,7 @@ import pytest
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_kinesis from moto import mock_kinesis
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from .test_kinesis import get_stream_arn
import sure # noqa # pylint: disable=unused-import import sure # noqa # pylint: disable=unused-import
@ -279,9 +280,10 @@ def test_split_shard():
def test_split_shard_that_was_split_before(): def test_split_shard_that_was_split_before():
client = boto3.client("kinesis", region_name="us-west-2") client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2) client.create_stream(StreamName="my-stream", ShardCount=2)
stream_arn = get_stream_arn(client, "my-stream")
client.split_shard( client.split_shard(
StreamName="my-stream", StreamARN=stream_arn,
ShardToSplit="shardId-000000000001", ShardToSplit="shardId-000000000001",
NewStartingHashKey="170141183460469231731687303715884105829", NewStartingHashKey="170141183460469231731687303715884105829",
) )

View File

@ -1,6 +1,7 @@
import boto3 import boto3
from moto import mock_kinesis from moto import mock_kinesis
from .test_kinesis import get_stream_arn
@mock_kinesis @mock_kinesis
@ -44,3 +45,27 @@ def test_disable_encryption():
desc = resp["StreamDescription"] desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE") desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId") desc.shouldnt.have.key("KeyId")
@mock_kinesis
def test_disable_encryption__using_arns():
client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2)
stream_arn = get_stream_arn(client, "my-stream")
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
client.start_stream_encryption(
StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a"
)
client.stop_stream_encryption(
StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a"
)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId")

View File

@ -1,6 +1,8 @@
import boto3 import boto3
from moto import mock_kinesis from moto import mock_kinesis
from tests import DEFAULT_ACCOUNT_ID
from .test_kinesis import get_stream_arn
@mock_kinesis @mock_kinesis
@ -16,6 +18,9 @@ def test_enable_enhanced_monitoring_all():
resp.should.have.key("StreamName").equals(stream_name) resp.should.have.key("StreamName").equals(stream_name)
resp.should.have.key("CurrentShardLevelMetrics").equals([]) resp.should.have.key("CurrentShardLevelMetrics").equals([])
resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"]) resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"])
resp.should.have.key("StreamARN").equals(
f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
)
@mock_kinesis @mock_kinesis
@ -70,9 +75,10 @@ def test_disable_enhanced_monitoring():
client = boto3.client("kinesis", region_name="us-east-1") client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary" stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4) client.create_stream(StreamName=stream_name, ShardCount=4)
stream_arn = get_stream_arn(client, stream_name)
client.enable_enhanced_monitoring( client.enable_enhanced_monitoring(
StreamName=stream_name, StreamARN=stream_arn,
ShardLevelMetrics=[ ShardLevelMetrics=[
"IncomingBytes", "IncomingBytes",
"OutgoingBytes", "OutgoingBytes",
@ -84,6 +90,11 @@ def test_disable_enhanced_monitoring():
StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"] StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"]
) )
resp.should.have.key("StreamName").equals(stream_name)
resp.should.have.key("StreamARN").equals(
f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3) resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3)
resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes") resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes")
resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes") resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes")
@ -102,6 +113,13 @@ def test_disable_enhanced_monitoring():
metrics.should.contain("IncomingBytes") metrics.should.contain("IncomingBytes")
metrics.should.contain("WriteProvisionedThroughputExceeded") metrics.should.contain("WriteProvisionedThroughputExceeded")
resp = client.disable_enhanced_monitoring(
StreamARN=stream_arn, ShardLevelMetrics=["IncomingBytes"]
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(2)
resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(1)
@mock_kinesis @mock_kinesis
def test_disable_enhanced_monitoring_all(): def test_disable_enhanced_monitoring_all():