Kinesis - support Stream ARNs across all methods (#5893)

This commit is contained in:
Bert Blommers 2023-02-01 15:16:25 -01:00 committed by GitHub
parent 67ecc3b1db
commit 19bfa92dd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 314 additions and 109 deletions

View File

@ -4,6 +4,7 @@ import re
import itertools
from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
@ -439,12 +440,14 @@ class Stream(CloudFormationModel):
for tag_item in properties.get("Tags", [])
}
backend = kinesis_backends[account_id][region_name]
backend: KinesisBackend = kinesis_backends[account_id][region_name]
stream = backend.create_stream(
resource_name, shard_count, retention_period_hours=retention_period_hours
)
if any(tags):
backend.add_tags_to_stream(stream.stream_name, tags)
backend.add_tags_to_stream(
stream_arn=None, stream_name=stream.stream_name, tags=tags
)
return stream
@classmethod
@ -489,8 +492,8 @@ class Stream(CloudFormationModel):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
backend = kinesis_backends[account_id][region_name]
backend.delete_stream(resource_name)
backend: KinesisBackend = kinesis_backends[account_id][region_name]
backend.delete_stream(stream_arn=None, stream_name=resource_name)
@staticmethod
def is_replacement_update(properties):
@ -521,7 +524,7 @@ class Stream(CloudFormationModel):
class KinesisBackend(BaseBackend):
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.streams = OrderedDict()
self.streams: Dict[str, Stream] = OrderedDict()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
@ -546,38 +549,49 @@ class KinesisBackend(BaseBackend):
self.streams[stream_name] = stream
return stream
def describe_stream(self, stream_name) -> Stream:
if stream_name in self.streams:
def describe_stream(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
if stream_name and stream_name in self.streams:
return self.streams[stream_name]
else:
raise StreamNotFoundError(stream_name, self.account_id)
if stream_arn:
for stream in self.streams.values():
if stream.arn == stream_arn:
return stream
if stream_arn:
stream_name = stream_arn.split("/")[1]
raise StreamNotFoundError(stream_name, self.account_id)
def describe_stream_summary(self, stream_name):
return self.describe_stream(stream_name)
def describe_stream_summary(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
def list_streams(self):
return self.streams.values()
def delete_stream(self, stream_name):
if stream_name in self.streams:
return self.streams.pop(stream_name)
raise StreamNotFoundError(stream_name, self.account_id)
def delete_stream(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
return self.streams.pop(stream.stream_name)
def get_shard_iterator(
self,
stream_name,
shard_id,
shard_iterator_type,
starting_sequence_number,
at_timestamp,
stream_arn: Optional[str],
stream_name: Optional[str],
shard_id: str,
shard_iterator_type: str,
starting_sequence_number: str,
at_timestamp: str,
):
# Validate params
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
try:
shard = stream.get_shard(shard_id)
except ShardNotFoundError:
raise ResourceNotFoundError(
message=f"Shard {shard_id} in stream {stream_name} under account {self.account_id} does not exist"
message=f"Shard {shard_id} in stream {stream.stream_name} under account {self.account_id} does not exist"
)
shard_iterator = compose_new_shard_iterator(
@ -589,11 +603,13 @@ class KinesisBackend(BaseBackend):
)
return shard_iterator
def get_records(self, shard_iterator, limit):
def get_records(
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
):
decomposed = decompose_shard_iterator(shard_iterator)
stream_name, shard_id, last_sequence_id = decomposed
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shard = stream.get_shard(shard_id)
records, last_sequence_id, millis_behind_latest = shard.get_records(
@ -608,12 +624,13 @@ class KinesisBackend(BaseBackend):
def put_record(
self,
stream_arn,
stream_name,
partition_key,
explicit_hash_key,
data,
):
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, data
@ -621,8 +638,8 @@ class KinesisBackend(BaseBackend):
return sequence_number, shard_id
def put_records(self, stream_name, records):
stream = self.describe_stream(stream_name)
def put_records(self, stream_arn, stream_name, records):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
response = {"FailedRecordCount": 0, "Records": []}
@ -651,8 +668,10 @@ class KinesisBackend(BaseBackend):
return response
def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
stream = self.describe_stream(stream_name)
def split_shard(
self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
raise ValidationException(
@ -675,23 +694,27 @@ class KinesisBackend(BaseBackend):
stream.split_shard(shard_to_split, new_starting_hash_key)
def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
stream = self.describe_stream(stream_name)
def merge_shards(
self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if shard_to_merge not in stream.shards:
raise ShardNotFoundError(
shard_to_merge, stream=stream_name, account_id=self.account_id
shard_to_merge, stream=stream.stream_name, account_id=self.account_id
)
if adjacent_shard_to_merge not in stream.shards:
raise ShardNotFoundError(
adjacent_shard_to_merge, stream=stream_name, account_id=self.account_id
adjacent_shard_to_merge,
stream=stream.stream_name,
account_id=self.account_id,
)
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
def update_shard_count(self, stream_name, target_shard_count):
stream = self.describe_stream(stream_name)
def update_shard_count(self, stream_arn, stream_name, target_shard_count):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open])
stream.update_shard_count(target_shard_count)
@ -699,13 +722,18 @@ class KinesisBackend(BaseBackend):
return current_shard_count
@paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_name):
stream = self.describe_stream(stream_name)
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
return [shard.to_json() for shard in shards]
def increase_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
def increase_stream_retention_period(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
@ -718,8 +746,13 @@ class KinesisBackend(BaseBackend):
)
stream.retention_period_hours = retention_period_hours
def decrease_stream_retention_period(self, stream_name, retention_period_hours):
stream = self.describe_stream(stream_name)
def decrease_stream_retention_period(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
retention_period_hours: int,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if retention_period_hours < 24:
raise InvalidRetentionPeriod(retention_period_hours, too_short=True)
if retention_period_hours > 8760:
@ -733,9 +766,9 @@ class KinesisBackend(BaseBackend):
stream.retention_period_hours = retention_period_hours
def list_tags_for_stream(
self, stream_name, exclusive_start_tag_key=None, limit=None
self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
):
stream = self.describe_stream(stream_name)
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
tags = []
result = {"HasMoreTags": False, "Tags": tags}
@ -750,25 +783,47 @@ class KinesisBackend(BaseBackend):
return result
def add_tags_to_stream(self, stream_name, tags):
stream = self.describe_stream(stream_name)
def add_tags_to_stream(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
tags: Dict[str, str],
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.tags.update(tags)
def remove_tags_from_stream(self, stream_name, tag_keys):
stream = self.describe_stream(stream_name)
def remove_tags_from_stream(
self, stream_arn: Optional[str], stream_name: Optional[str], tag_keys: List[str]
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
for key in tag_keys:
if key in stream.tags:
del stream.tags[key]
def enable_enhanced_monitoring(self, stream_name, shard_level_metrics):
stream = self.describe_stream(stream_name)
def enable_enhanced_monitoring(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
shard_level_metrics: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
stream.shard_level_metrics = desired_metrics
return current_shard_level_metrics, desired_metrics
return (
stream.arn,
stream.stream_name,
current_shard_level_metrics,
desired_metrics,
)
def disable_enhanced_monitoring(self, stream_name, to_be_disabled):
stream = self.describe_stream(stream_name)
def disable_enhanced_monitoring(
self,
stream_arn: Optional[str],
stream_name: Optional[str],
to_be_disabled: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled:
desired_metrics = []
@ -777,7 +832,7 @@ class KinesisBackend(BaseBackend):
metric for metric in current_metrics if metric not in to_be_disabled
]
stream.shard_level_metrics = desired_metrics
return current_metrics, desired_metrics
return stream.arn, stream.stream_name, current_metrics, desired_metrics
def _find_stream_by_arn(self, stream_arn):
for stream in self.streams.values():
@ -826,13 +881,13 @@ class KinesisBackend(BaseBackend):
# It will be a noop for other streams
stream.delete_consumer(consumer_arn)
def start_stream_encryption(self, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_name)
def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = encryption_type
stream.key_id = key_id
def stop_stream_encryption(self, stream_name):
stream = self.describe_stream(stream_name)
def stop_stream_encryption(self, stream_arn, stream_name):
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = "NONE"
stream.key_id = None

View File

@ -1,7 +1,7 @@
import json
from moto.core.responses import BaseResponse
from .models import kinesis_backends
from .models import kinesis_backends, KinesisBackend
class KinesisResponse(BaseResponse):
@ -13,7 +13,7 @@ class KinesisResponse(BaseResponse):
return json.loads(self.body)
@property
def kinesis_backend(self):
def kinesis_backend(self) -> KinesisBackend:
return kinesis_backends[self.current_account][self.region]
def create_stream(self):
@ -27,13 +27,15 @@ class KinesisResponse(BaseResponse):
def describe_stream(self):
stream_name = self.parameters.get("StreamName")
stream_arn = self.parameters.get("StreamARN")
limit = self.parameters.get("Limit")
stream = self.kinesis_backend.describe_stream(stream_name)
stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
return json.dumps(stream.to_json(shard_limit=limit))
def describe_stream_summary(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
stream = self.kinesis_backend.describe_stream_summary(stream_name)
stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
return json.dumps(stream.to_json_summary())
def list_streams(self):
@ -58,11 +60,13 @@ class KinesisResponse(BaseResponse):
)
def delete_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
self.kinesis_backend.delete_stream(stream_name)
self.kinesis_backend.delete_stream(stream_arn, stream_name)
return ""
def get_shard_iterator(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_id = self.parameters.get("ShardId")
shard_iterator_type = self.parameters.get("ShardIteratorType")
@ -70,6 +74,7 @@ class KinesisResponse(BaseResponse):
at_timestamp = self.parameters.get("Timestamp")
shard_iterator = self.kinesis_backend.get_shard_iterator(
stream_arn,
stream_name,
shard_id,
shard_iterator_type,
@ -80,6 +85,7 @@ class KinesisResponse(BaseResponse):
return json.dumps({"ShardIterator": shard_iterator})
def get_records(self):
stream_arn = self.parameters.get("StreamARN")
shard_iterator = self.parameters.get("ShardIterator")
limit = self.parameters.get("Limit")
@ -87,7 +93,7 @@ class KinesisResponse(BaseResponse):
next_shard_iterator,
records,
millis_behind_latest,
) = self.kinesis_backend.get_records(shard_iterator, limit)
) = self.kinesis_backend.get_records(stream_arn, shard_iterator, limit)
return json.dumps(
{
@ -98,12 +104,14 @@ class KinesisResponse(BaseResponse):
)
def put_record(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
partition_key = self.parameters.get("PartitionKey")
explicit_hash_key = self.parameters.get("ExplicitHashKey")
data = self.parameters.get("Data")
sequence_number, shard_id = self.kinesis_backend.put_record(
stream_arn,
stream_name,
partition_key,
explicit_hash_key,
@ -113,37 +121,44 @@ class KinesisResponse(BaseResponse):
return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
def put_records(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
records = self.parameters.get("Records")
response = self.kinesis_backend.put_records(stream_name, records)
response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
return json.dumps(response)
def split_shard(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_to_split = self.parameters.get("ShardToSplit")
new_starting_hash_key = self.parameters.get("NewStartingHashKey")
self.kinesis_backend.split_shard(
stream_name, shard_to_split, new_starting_hash_key
stream_arn, stream_name, shard_to_split, new_starting_hash_key
)
return ""
def merge_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_to_merge = self.parameters.get("ShardToMerge")
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
self.kinesis_backend.merge_shards(
stream_name, shard_to_merge, adjacent_shard_to_merge
stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
)
return ""
def list_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
next_token = self.parameters.get("NextToken")
max_results = self.parameters.get("MaxResults", 10000)
shards, token = self.kinesis_backend.list_shards(
stream_name=stream_name, limit=max_results, next_token=next_token
stream_arn=stream_arn,
stream_name=stream_name,
limit=max_results,
next_token=next_token,
)
res = {"Shards": shards}
if token:
@ -151,10 +166,13 @@ class KinesisResponse(BaseResponse):
return json.dumps(res)
def update_shard_count(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
target_shard_count = self.parameters.get("TargetShardCount")
current_shard_count = self.kinesis_backend.update_shard_count(
stream_name=stream_name, target_shard_count=target_shard_count
stream_arn=stream_arn,
stream_name=stream_name,
target_shard_count=target_shard_count,
)
return json.dumps(
dict(
@ -165,67 +183,80 @@ class KinesisResponse(BaseResponse):
)
def increase_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
self.kinesis_backend.increase_stream_retention_period(
stream_name, retention_period_hours
stream_arn, stream_name, retention_period_hours
)
return ""
def decrease_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
self.kinesis_backend.decrease_stream_retention_period(
stream_name, retention_period_hours
stream_arn, stream_name, retention_period_hours
)
return ""
def add_tags_to_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
tags = self.parameters.get("Tags")
self.kinesis_backend.add_tags_to_stream(stream_name, tags)
self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags)
return json.dumps({})
def list_tags_for_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey")
limit = self.parameters.get("Limit")
response = self.kinesis_backend.list_tags_for_stream(
stream_name, exclusive_start_tag_key, limit
stream_arn, stream_name, exclusive_start_tag_key, limit
)
return json.dumps(response)
def remove_tags_from_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
tag_keys = self.parameters.get("TagKeys")
self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys)
self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys)
return json.dumps({})
def enable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_name=stream_name, shard_level_metrics=shard_level_metrics
arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_arn=stream_arn,
stream_name=stream_name,
shard_level_metrics=shard_level_metrics,
)
return json.dumps(
dict(
StreamName=stream_name,
StreamName=name,
CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired,
StreamARN=arn,
)
)
def disable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_name=stream_name, to_be_disabled=shard_level_metrics
arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_arn=stream_arn,
stream_name=stream_name,
to_be_disabled=shard_level_metrics,
)
return json.dumps(
dict(
StreamName=stream_name,
StreamName=name,
CurrentShardLevelMetrics=current,
DesiredShardLevelMetrics=desired,
StreamARN=arn,
)
)
@ -267,17 +298,24 @@ class KinesisResponse(BaseResponse):
return json.dumps(dict())
def start_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
encryption_type = self.parameters.get("EncryptionType")
key_id = self.parameters.get("KeyId")
self.kinesis_backend.start_stream_encryption(
stream_name=stream_name, encryption_type=encryption_type, key_id=key_id
stream_arn=stream_arn,
stream_name=stream_name,
encryption_type=encryption_type,
key_id=key_id,
)
return json.dumps(dict())
def stop_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
self.kinesis_backend.stop_stream_encryption(stream_name=stream_name)
self.kinesis_backend.stop_stream_encryption(
stream_arn=stream_arn, stream_name=stream_name
)
return json.dumps(dict())
def update_stream_mode(self):

View File

@ -6,6 +6,8 @@ url_bases = [
# Somewhere around boto3-1.26.31 botocore-1.29.31, AWS started using a new endpoint:
# 111122223333.control-kinesis.us-east-1.amazonaws.com
r"https?://(.+)\.control-kinesis\.(.+)\.amazonaws\.com",
# When passing in the StreamARN to get_shard_iterator/get_records, this endpoint is called:
r"https?://(.+)\.data-kinesis\.(.+)\.amazonaws\.com",
]
url_paths = {"{0}/$": KinesisResponse.dispatch}

View File

@ -19,15 +19,17 @@ def test_stream_creation_on_demand():
client.create_stream(
StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
)
# At the same time, test whether we can pass the StreamARN instead of the name
stream_arn = get_stream_arn(client, "my_stream")
# AWS starts with 4 shards by default
shard_list = client.list_shards(StreamName="my_stream")["Shards"]
shard_list = client.list_shards(StreamARN=stream_arn)["Shards"]
shard_list.should.have.length_of(4)
# Cannot update-shard-count when we're in on-demand mode
with pytest.raises(ClientError) as exc:
client.update_shard_count(
StreamName="my_stream", TargetShardCount=3, ScalingType="UNIFORM_SCALING"
StreamARN=stream_arn, TargetShardCount=3, ScalingType="UNIFORM_SCALING"
)
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
@ -39,7 +41,7 @@ def test_stream_creation_on_demand():
@mock_kinesis
def test_update_stream_mode():
client = boto3.client("kinesis", region_name="eu-west-1")
resp = client.create_stream(
client.create_stream(
StreamName="my_stream", StreamModeDetails={"StreamMode": "ON_DEMAND"}
)
arn = client.describe_stream(StreamName="my_stream")["StreamDescription"][
@ -56,7 +58,7 @@ def test_update_stream_mode():
@mock_kinesis
def test_describe_non_existent_stream_boto3():
def test_describe_non_existent_stream():
client = boto3.client("kinesis", region_name="us-west-2")
with pytest.raises(ClientError) as exc:
client.describe_stream_summary(StreamName="not-a-stream")
@ -68,7 +70,7 @@ def test_describe_non_existent_stream_boto3():
@mock_kinesis
def test_list_and_delete_stream_boto3():
def test_list_and_delete_stream():
client = boto3.client("kinesis", region_name="us-west-2")
client.list_streams()["StreamNames"].should.have.length_of(0)
@ -79,6 +81,10 @@ def test_list_and_delete_stream_boto3():
client.delete_stream(StreamName="stream1")
client.list_streams()["StreamNames"].should.have.length_of(1)
stream_arn = get_stream_arn(client, "stream2")
client.delete_stream(StreamARN=stream_arn)
client.list_streams()["StreamNames"].should.have.length_of(0)
@mock_kinesis
def test_delete_unknown_stream():
@ -128,9 +134,15 @@ def test_describe_stream_summary():
)
stream["StreamStatus"].should.equal("ACTIVE")
stream_arn = get_stream_arn(conn, stream_name)
resp = conn.describe_stream_summary(StreamARN=stream_arn)
stream = resp["StreamDescriptionSummary"]
stream["StreamName"].should.equal(stream_name)
@mock_kinesis
def test_basic_shard_iterator_boto3():
def test_basic_shard_iterator():
client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream"
@ -149,7 +161,30 @@ def test_basic_shard_iterator_boto3():
@mock_kinesis
def test_get_invalid_shard_iterator_boto3():
def test_basic_shard_iterator_by_stream_arn():
client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream"
client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shard_id = stream["Shards"][0]["ShardId"]
resp = client.get_shard_iterator(
StreamARN=stream["StreamARN"],
ShardId=shard_id,
ShardIteratorType="TRIM_HORIZON",
)
shard_iterator = resp["ShardIterator"]
resp = client.get_records(
StreamARN=stream["StreamARN"], ShardIterator=shard_iterator
)
resp.should.have.key("Records").length_of(0)
resp.should.have.key("MillisBehindLatest").equal(0)
@mock_kinesis
def test_get_invalid_shard_iterator():
client = boto3.client("kinesis", region_name="us-west-1")
stream_name = "mystream"
@ -169,21 +204,22 @@ def test_get_invalid_shard_iterator_boto3():
@mock_kinesis
def test_put_records_boto3():
def test_put_records():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
stream_arn = stream["StreamARN"]
shard_id = stream["Shards"][0]["ShardId"]
data = b"hello world"
partition_key = "1234"
response = client.put_record(
StreamName=stream_name, Data=data, PartitionKey=partition_key
client.put_records(
Records=[{"Data": data, "PartitionKey": partition_key}] * 5,
StreamARN=stream_arn,
)
response["SequenceNumber"].should.equal("1")
resp = client.get_shard_iterator(
StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@ -191,27 +227,28 @@ def test_put_records_boto3():
shard_iterator = resp["ShardIterator"]
resp = client.get_records(ShardIterator=shard_iterator)
resp["Records"].should.have.length_of(1)
resp["Records"].should.have.length_of(5)
record = resp["Records"][0]
record["Data"].should.equal(b"hello world")
record["PartitionKey"].should.equal("1234")
record["Data"].should.equal(data)
record["PartitionKey"].should.equal(partition_key)
record["SequenceNumber"].should.equal("1")
@mock_kinesis
def test_get_records_limit_boto3():
def test_get_records_limit():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
stream_arn = stream["StreamARN"]
shard_id = stream["Shards"][0]["ShardId"]
data = b"hello world"
for index in range(5):
client.put_record(StreamName=stream_name, Data=data, PartitionKey=str(index))
client.put_record(StreamARN=stream_arn, Data=data, PartitionKey=str(index))
resp = client.get_shard_iterator(
StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON"
@ -229,7 +266,7 @@ def test_get_records_limit_boto3():
@mock_kinesis
def test_get_records_at_sequence_number_boto3():
def test_get_records_at_sequence_number():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
@ -268,7 +305,7 @@ def test_get_records_at_sequence_number_boto3():
@mock_kinesis
def test_get_records_after_sequence_number_boto3():
def test_get_records_after_sequence_number():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
@ -308,7 +345,7 @@ def test_get_records_after_sequence_number_boto3():
@mock_kinesis
def test_get_records_latest_boto3():
def test_get_records_latest():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
@ -607,6 +644,7 @@ def test_valid_decrease_stream_retention_period():
conn = boto3.client("kinesis", region_name="us-west-2")
stream_name = "decrease_stream"
conn.create_stream(StreamName=stream_name, ShardCount=1)
stream_arn = get_stream_arn(conn, stream_name)
conn.increase_stream_retention_period(
StreamName=stream_name, RetentionPeriodHours=30
@ -618,6 +656,12 @@ def test_valid_decrease_stream_retention_period():
response = conn.describe_stream(StreamName=stream_name)
response["StreamDescription"]["RetentionPeriodHours"].should.equal(25)
conn.increase_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=29)
conn.decrease_stream_retention_period(StreamARN=stream_arn, RetentionPeriodHours=26)
response = conn.describe_stream(StreamARN=stream_arn)
response["StreamDescription"]["RetentionPeriodHours"].should.equal(26)
@mock_kinesis
def test_decrease_stream_retention_period_upwards():
@ -671,7 +715,7 @@ def test_decrease_stream_retention_period_too_high():
@mock_kinesis
def test_invalid_shard_iterator_type_boto3():
def test_invalid_shard_iterator_type():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
@ -688,10 +732,11 @@ def test_invalid_shard_iterator_type_boto3():
@mock_kinesis
def test_add_list_remove_tags_boto3():
def test_add_list_remove_tags():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=1)
stream_arn = get_stream_arn(client, stream_name)
client.add_tags_to_stream(
StreamName=stream_name,
Tags={"tag1": "val1", "tag2": "val2", "tag3": "val3", "tag4": "val4"},
@ -704,9 +749,9 @@ def test_add_list_remove_tags_boto3():
tags.should.contain({"Key": "tag3", "Value": "val3"})
tags.should.contain({"Key": "tag4", "Value": "val4"})
client.add_tags_to_stream(StreamName=stream_name, Tags={"tag5": "val5"})
client.add_tags_to_stream(StreamARN=stream_arn, Tags={"tag5": "val5"})
tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"]
tags = client.list_tags_for_stream(StreamARN=stream_arn)["Tags"]
tags.should.have.length_of(5)
tags.should.contain({"Key": "tag5", "Value": "val5"})
@ -718,19 +763,33 @@ def test_add_list_remove_tags_boto3():
tags.should.contain({"Key": "tag4", "Value": "val4"})
tags.should.contain({"Key": "tag5", "Value": "val5"})
client.remove_tags_from_stream(StreamARN=stream_arn, TagKeys=["tag4"])
tags = client.list_tags_for_stream(StreamName=stream_name)["Tags"]
tags.should.have.length_of(2)
tags.should.contain({"Key": "tag1", "Value": "val1"})
tags.should.contain({"Key": "tag5", "Value": "val5"})
@mock_kinesis
def test_merge_shards_boto3():
def test_merge_shards():
client = boto3.client("kinesis", region_name="eu-west-2")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
stream_arn = get_stream_arn(client, stream_name)
for index in range(1, 100):
for index in range(1, 50):
client.put_record(
StreamName=stream_name,
Data=f"data_{index}".encode("utf-8"),
PartitionKey=str(index),
)
for index in range(51, 100):
client.put_record(
StreamARN=stream_arn,
Data=f"data_{index}".encode("utf-8"),
PartitionKey=str(index),
)
stream = client.describe_stream(StreamName=stream_name)["StreamDescription"]
shards = stream["Shards"]
@ -757,7 +816,7 @@ def test_merge_shards_boto3():
active_shards.should.have.length_of(3)
client.merge_shards(
StreamName=stream_name,
StreamARN=stream_arn,
ShardToMerge="shardId-000000000004",
AdjacentShardToMerge="shardId-000000000002",
)
@ -804,3 +863,9 @@ def test_merge_shards_invalid_arg():
err = exc.value.response["Error"]
err["Code"].should.equal("InvalidArgumentException")
err["Message"].should.equal("shardId-000000000002")
def get_stream_arn(client, stream_name):
return client.describe_stream(StreamName=stream_name)["StreamDescription"][
"StreamARN"
]

View File

@ -4,6 +4,7 @@ import pytest
from botocore.exceptions import ClientError
from moto import mock_kinesis
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
from .test_kinesis import get_stream_arn
import sure # noqa # pylint: disable=unused-import
@ -279,9 +280,10 @@ def test_split_shard():
def test_split_shard_that_was_split_before():
client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2)
stream_arn = get_stream_arn(client, "my-stream")
client.split_shard(
StreamName="my-stream",
StreamARN=stream_arn,
ShardToSplit="shardId-000000000001",
NewStartingHashKey="170141183460469231731687303715884105829",
)

View File

@ -1,6 +1,7 @@
import boto3
from moto import mock_kinesis
from .test_kinesis import get_stream_arn
@mock_kinesis
@ -44,3 +45,27 @@ def test_disable_encryption():
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId")
@mock_kinesis
def test_disable_encryption__using_arns():
client = boto3.client("kinesis", region_name="us-west-2")
client.create_stream(StreamName="my-stream", ShardCount=2)
stream_arn = get_stream_arn(client, "my-stream")
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
client.start_stream_encryption(
StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a"
)
client.stop_stream_encryption(
StreamARN=stream_arn, EncryptionType="KMS", KeyId="n/a"
)
resp = client.describe_stream(StreamName="my-stream")
desc = resp["StreamDescription"]
desc.should.have.key("EncryptionType").should.equal("NONE")
desc.shouldnt.have.key("KeyId")

View File

@ -1,6 +1,8 @@
import boto3
from moto import mock_kinesis
from tests import DEFAULT_ACCOUNT_ID
from .test_kinesis import get_stream_arn
@mock_kinesis
@ -16,6 +18,9 @@ def test_enable_enhanced_monitoring_all():
resp.should.have.key("StreamName").equals(stream_name)
resp.should.have.key("CurrentShardLevelMetrics").equals([])
resp.should.have.key("DesiredShardLevelMetrics").equals(["ALL"])
resp.should.have.key("StreamARN").equals(
f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
)
@mock_kinesis
@ -70,9 +75,10 @@ def test_disable_enhanced_monitoring():
client = boto3.client("kinesis", region_name="us-east-1")
stream_name = "my_stream_summary"
client.create_stream(StreamName=stream_name, ShardCount=4)
stream_arn = get_stream_arn(client, stream_name)
client.enable_enhanced_monitoring(
StreamName=stream_name,
StreamARN=stream_arn,
ShardLevelMetrics=[
"IncomingBytes",
"OutgoingBytes",
@ -84,6 +90,11 @@ def test_disable_enhanced_monitoring():
StreamName=stream_name, ShardLevelMetrics=["OutgoingBytes"]
)
resp.should.have.key("StreamName").equals(stream_name)
resp.should.have.key("StreamARN").equals(
f"arn:aws:kinesis:us-east-1:{DEFAULT_ACCOUNT_ID}:stream/{stream_name}"
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(3)
resp["CurrentShardLevelMetrics"].should.contain("IncomingBytes")
resp["CurrentShardLevelMetrics"].should.contain("OutgoingBytes")
@ -102,6 +113,13 @@ def test_disable_enhanced_monitoring():
metrics.should.contain("IncomingBytes")
metrics.should.contain("WriteProvisionedThroughputExceeded")
resp = client.disable_enhanced_monitoring(
StreamARN=stream_arn, ShardLevelMetrics=["IncomingBytes"]
)
resp.should.have.key("CurrentShardLevelMetrics").should.have.length_of(2)
resp.should.have.key("DesiredShardLevelMetrics").should.have.length_of(1)
@mock_kinesis
def test_disable_enhanced_monitoring_all():