diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py
index a68103f26..40c8fec5d 100644
--- a/moto/kinesis/models.py
+++ b/moto/kinesis/models.py
@@ -3,18 +3,25 @@ from __future__ import unicode_literals
 import datetime
 import time
 import boto.kinesis
+import re
+
+from operator import attrgetter
+from hashlib import md5
+from itertools import izip_longest
 
 from moto.compat import OrderedDict
 from moto.core import BaseBackend
-from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError
+from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \
+    ResourceNotFoundError, InvalidArgumentError
 from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator
 
 
 class Record(object):
-    def __init__(self, partition_key, data, sequence_number):
+    def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
         self.partition_key = partition_key
         self.data = data
         self.sequence_number = sequence_number
+        self.explicit_hash_key = explicit_hash_key
 
     def to_json(self):
         return {
@@ -25,10 +32,16 @@ class Record(object):
 
 
 class Shard(object):
-    def __init__(self, shard_id):
-        self.shard_id = shard_id
+    def __init__(self, shard_id, starting_hash, ending_hash):
+        self._shard_id = shard_id
+        self.starting_hash = starting_hash
+        self.ending_hash = ending_hash
         self.records = OrderedDict()
 
+    @property
+    def shard_id(self):
+        return "shardId-{0}".format(str(self._shard_id).zfill(12))
+
     def get_records(self, last_sequence_id, limit):
         last_sequence_id = int(last_sequence_id)
         results = []
@@ -43,14 +56,14 @@ class Shard(object):
 
         return results, last_sequence_id
 
-    def put_record(self, partition_key, data):
+    def put_record(self, partition_key, data, explicit_hash_key):
         # Note: this function is not safe for concurrency
         if self.records:
             last_sequence_number = self.get_max_sequence_number()
         else:
             last_sequence_number = 0
         sequence_number = last_sequence_number + 1
-        self.records[sequence_number] = Record(partition_key, data, sequence_number)
+        self.records[sequence_number] = Record(partition_key, data, sequence_number, explicit_hash_key)
         return sequence_number
 
     def get_min_sequence_number(self):
@@ -66,8 +79,8 @@ class Shard(object):
     def to_json(self):
         return {
             "HashKeyRange": {
-                "EndingHashKey": "113427455640312821154458202477256070484",
-                "StartingHashKey": "0"
+                "EndingHashKey": str(self.ending_hash),
+                "StartingHashKey": str(self.starting_hash)
             },
             "SequenceNumberRange": {
                 "EndingSequenceNumber": self.get_max_sequence_number(),
@@ -78,6 +91,7 @@ class Shard(object):
 
 
 class Stream(object):
+
     def __init__(self, stream_name, shard_count, region):
         self.stream_name = stream_name
         self.shard_count = shard_count
@@ -86,9 +100,12 @@ class Stream(object):
         self.shards = {}
         self.tags = {}
 
-        for index in range(shard_count):
-            shard_id = "shardId-{0}".format(str(index).zfill(12))
-            self.shards[shard_id] = Shard(shard_id)
+        # Last shard is padded to 2**128 so uneven shard counts still tile the key space.
+        starts = [index * (2**128 // shard_count) for index in range(shard_count)]
+        ends = starts[1:]
+        for index, (start, end) in enumerate(izip_longest(starts, ends, fillvalue=2**128)):
+            shard = Shard(index, start, end)
+            self.shards[shard.shard_id] = shard
 
     @property
     def arn(self):
@@ -104,16 +121,32 @@ class Stream(object):
         else:
             raise ShardNotFoundError(shard_id)
 
-    def get_shard_for_key(self, partition_key):
-        # TODO implement sharding
-        shard = list(self.shards.values())[0]
-        return shard
+    def get_shard_for_key(self, partition_key, explicit_hash_key):
+        if not isinstance(partition_key, basestring):
+            raise InvalidArgumentError("partition_key")
+        if len(partition_key) > 256:
+            raise InvalidArgumentError("partition_key")
+
+        if explicit_hash_key:
+            if not isinstance(explicit_hash_key, basestring):
+                raise InvalidArgumentError("explicit_hash_key")
+
+            key = int(explicit_hash_key)
+
+            if not 0 <= key < 2**128:
+                raise InvalidArgumentError("explicit_hash_key")
+
+        else:
+            key = int(md5(partition_key).hexdigest(), 16)
+
+        for shard in self.shards.values():
+            if shard.starting_hash <= key < shard.ending_hash:
+                return shard
 
     def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
-        partition_key = explicit_hash_key if explicit_hash_key else partition_key
-        shard = self.get_shard_for_key(partition_key)
+        shard = self.get_shard_for_key(partition_key, explicit_hash_key)
 
-        sequence_number = shard.put_record(partition_key, data)
+        sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
         return sequence_number, shard.shard_id
 
     def to_json(self):
@@ -201,6 +234,7 @@ class KinesisBackend(BaseBackend):
         self.streams = {}
         self.delivery_streams = {}
 
+
     def create_stream(self, stream_name, shard_count, region):
         if stream_name in self.streams:
             raise ResourceInUseError(stream_name)
@@ -277,6 +311,60 @@ class KinesisBackend(BaseBackend):
 
         return response
 
+    def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
+        stream = self.describe_stream(stream_name)
+
+        if shard_to_split not in stream.shards:
+            raise ResourceNotFoundError(shard_to_split)
+
+        if not re.match(r'(0|[1-9]\d{0,38})$', new_starting_hash_key):
+            raise InvalidArgumentError(new_starting_hash_key)
+        new_starting_hash_key = int(new_starting_hash_key)
+
+        shard = stream.shards[shard_to_split]
+
+        last_id = sorted(stream.shards.values(), key=attrgetter('_shard_id'))[-1]._shard_id
+
+        if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
+            new_shard = Shard(last_id+1, new_starting_hash_key, shard.ending_hash)
+            shard.ending_hash = new_starting_hash_key
+            stream.shards[new_shard.shard_id] = new_shard
+        else:
+            raise InvalidArgumentError(new_starting_hash_key)
+
+        records = shard.records
+        shard.records = OrderedDict()
+
+        for index in records:
+            record = records[index]
+            stream.put_record(
+                record.partition_key, record.explicit_hash_key, None, record.data
+            )
+
+    def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
+        stream = self.describe_stream(stream_name)
+
+        if shard_to_merge not in stream.shards:
+            raise ResourceNotFoundError(shard_to_merge)
+
+        if adjacent_shard_to_merge not in stream.shards:
+            raise ResourceNotFoundError(adjacent_shard_to_merge)
+
+        shard1 = stream.shards[shard_to_merge]
+        shard2 = stream.shards[adjacent_shard_to_merge]
+
+        if shard1.ending_hash == shard2.starting_hash:
+            shard1.ending_hash = shard2.ending_hash
+        elif shard2.ending_hash == shard1.starting_hash:
+            shard1.starting_hash = shard2.starting_hash
+        else:
+            raise InvalidArgumentError(adjacent_shard_to_merge)
+
+        del stream.shards[shard2.shard_id]
+        for index in shard2.records:
+            record = shard2.records[index]
+            shard1.put_record(record.partition_key, record.data, record.explicit_hash_key)
+
     ''' Firehose '''
     def create_delivery_stream(self, stream_name, **stream_kwargs):
         stream = DeliveryStream(stream_name, **stream_kwargs)
diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py
index dd364c49a..b52bdedf0 100644
--- a/moto/kinesis/responses.py
+++ b/moto/kinesis/responses.py
@@ -101,6 +101,24 @@ class KinesisResponse(BaseResponse):
 
         return json.dumps(response)
 
+    def split_shard(self):
+        stream_name = self.parameters.get("StreamName")
+        shard_to_split = self.parameters.get("ShardToSplit")
+        new_starting_hash_key = self.parameters.get("NewStartingHashKey")
+        self.kinesis_backend.split_shard(
+            stream_name, shard_to_split, new_starting_hash_key
+        )
+        return ""
+
+    def merge_shards(self):
+        stream_name = self.parameters.get("StreamName")
+        shard_to_merge = self.parameters.get("ShardToMerge")
+        adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
+        self.kinesis_backend.merge_shards(
+            stream_name, shard_to_merge, adjacent_shard_to_merge
+        )
+        return ""
+
     ''' Firehose '''
     def create_delivery_stream(self):
         stream_name = self.parameters['DeliveryStreamName']
diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py
index 75f5acd7a..590bd025a 100644
--- a/tests/test_kinesis/test_kinesis.py
+++ b/tests/test_kinesis/test_kinesis.py
@@ -85,6 +85,10 @@ def test_put_records():
 
     data = "hello world"
     partition_key = "1234"
+
+    conn.put_record.when.called_with(
+        stream_name, data, 1234).should.throw(InvalidArgumentException)
+
     conn.put_record(stream_name, data, partition_key)
 
     response = conn.describe_stream(stream_name)
@@ -112,8 +116,9 @@ def test_get_records_limit():
 
     # Create some data
     data = "hello world"
+
     for index in range(5):
-        conn.put_record(stream_name, data, index)
+        conn.put_record(stream_name, data, str(index))
 
     # Get a shard iterator
     response = conn.describe_stream(stream_name)
@@ -140,7 +145,7 @@ def test_get_records_at_sequence_number():
 
     # Create some data
     for index in range(1, 5):
-        conn.put_record(stream_name, str(index), index)
+        conn.put_record(stream_name, str(index), str(index))
 
     # Get a shard iterator
     response = conn.describe_stream(stream_name)
@@ -171,7 +176,7 @@ def test_get_records_after_sequence_number():
 
     # Create some data
     for index in range(1, 5):
-        conn.put_record(stream_name, str(index), index)
+        conn.put_record(stream_name, str(index), str(index))
 
     # Get a shard iterator
     response = conn.describe_stream(stream_name)
@@ -201,7 +206,7 @@ def test_get_records_latest():
 
     # Create some data
     for index in range(1, 5):
-        conn.put_record(stream_name, str(index), index)
+        conn.put_record(stream_name, str(index), str(index))
 
     # Get a shard iterator
     response = conn.describe_stream(stream_name)
@@ -261,16 +266,16 @@ def test_list_tags():
     conn.describe_stream(stream_name)
     conn.add_tags_to_stream(stream_name, {'tag1':'val1'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag1') == 'val1'
+    tags.get('tag1').should.equal('val1')
     conn.add_tags_to_stream(stream_name, {'tag2':'val2'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag2') == 'val2'
+    tags.get('tag2').should.equal('val2')
     conn.add_tags_to_stream(stream_name, {'tag1':'val3'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag1') == 'val3'
+    tags.get('tag1').should.equal('val3')
     conn.add_tags_to_stream(stream_name, {'tag2':'val4'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag2') == 'val4'
+    tags.get('tag2').should.equal('val4')
 
 
 @mock_kinesis
@@ -282,14 +287,99 @@ def test_remove_tags():
     conn.describe_stream(stream_name)
     conn.add_tags_to_stream(stream_name, {'tag1':'val1'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag1') == 'val1'
+    tags.get('tag1').should.equal('val1')
     conn.remove_tags_from_stream(stream_name, ['tag1'])
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag1') is None
+    tags.get('tag1').should.equal(None)
     conn.add_tags_to_stream(stream_name, {'tag2':'val2'})
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag2') == 'val2'
+    tags.get('tag2').should.equal('val2')
     conn.remove_tags_from_stream(stream_name, ['tag2'])
     tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
-    assert tags.get('tag2') is None
+    tags.get('tag2').should.equal(None)
+
+
+@mock_kinesis
+def test_split_shard():
+    conn = boto.kinesis.connect_to_region("us-west-2")
+    stream_name = 'my_stream'
+
+    conn.create_stream(stream_name, 2)
+
+    # Create some data
+    for index in range(1, 100):
+        conn.put_record(stream_name, str(index), str(index))
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(2)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
+
+    shard_range = shards[0]['HashKeyRange']
+    new_starting_hash = (int(shard_range['EndingHashKey'])+int(shard_range['StartingHashKey'])) // 2
+    conn.split_shard("my_stream", shards[0]['ShardId'], str(new_starting_hash))
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(3)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
+
+    shard_range = shards[2]['HashKeyRange']
+    new_starting_hash = (int(shard_range['EndingHashKey'])+int(shard_range['StartingHashKey'])) // 2
+    conn.split_shard("my_stream", shards[2]['ShardId'], str(new_starting_hash))
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(4)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
+
+
+@mock_kinesis
+def test_merge_shards():
+    conn = boto.kinesis.connect_to_region("us-west-2")
+    stream_name = 'my_stream'
+
+    conn.create_stream(stream_name, 4)
+
+    # Create some data
+    for index in range(1, 100):
+        conn.put_record(stream_name, str(index), str(index))
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(4)
+
+    conn.merge_shards.when.called_with(stream_name, 'shardId-000000000000', 'shardId-000000000002').should.throw(InvalidArgumentException)
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(4)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
+
+    conn.merge_shards(stream_name, 'shardId-000000000000', 'shardId-000000000001')
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(3)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
+    conn.merge_shards(stream_name, 'shardId-000000000002', 'shardId-000000000000')
+
+    stream_response = conn.describe_stream(stream_name)
+
+    stream = stream_response["StreamDescription"]
+    shards = stream['Shards']
+    shards.should.have.length_of(2)
+    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
 