Merge pull request #475 from silveregg/master

Add Kinesis API and fix some typos
This commit is contained in:
Steve Pulec 2015-12-05 21:04:41 -05:00
commit be3291b758
3 changed files with 338 additions and 27 deletions

View File

@ -2,19 +2,27 @@ from __future__ import unicode_literals
import datetime import datetime
import time import time
import boto.kinesis import boto.kinesis
import re
import six
import itertools
from operator import attrgetter
from hashlib import md5
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend from moto.core import BaseBackend
from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \
ResourceNotFoundError, InvalidArgumentError
from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator
class Record(object): class Record(object):
def __init__(self, partition_key, data, sequence_number): def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
self.partition_key = partition_key self.partition_key = partition_key
self.data = data self.data = data
self.sequence_number = sequence_number self.sequence_number = sequence_number
self.explicit_hash_key = explicit_hash_key
def to_json(self): def to_json(self):
return { return {
@ -25,10 +33,16 @@ class Record(object):
class Shard(object): class Shard(object):
def __init__(self, shard_id): def __init__(self, shard_id, starting_hash, ending_hash):
self.shard_id = shard_id self._shard_id = shard_id
self.starting_hash = starting_hash
self.ending_hash = ending_hash
self.records = OrderedDict() self.records = OrderedDict()
@property
def shard_id(self):
return "shardId-{0}".format(str(self._shard_id).zfill(12))
def get_records(self, last_sequence_id, limit): def get_records(self, last_sequence_id, limit):
last_sequence_id = int(last_sequence_id) last_sequence_id = int(last_sequence_id)
results = [] results = []
@ -43,14 +57,14 @@ class Shard(object):
return results, last_sequence_id return results, last_sequence_id
def put_record(self, partition_key, data): def put_record(self, partition_key, data, explicit_hash_key):
# Note: this function is not safe for concurrency # Note: this function is not safe for concurrency
if self.records: if self.records:
last_sequence_number = self.get_max_sequence_number() last_sequence_number = self.get_max_sequence_number()
else: else:
last_sequence_number = 0 last_sequence_number = 0
sequence_number = last_sequence_number + 1 sequence_number = last_sequence_number + 1
self.records[sequence_number] = Record(partition_key, data, sequence_number) self.records[sequence_number] = Record(partition_key, data, sequence_number, explicit_hash_key)
return sequence_number return sequence_number
def get_min_sequence_number(self): def get_min_sequence_number(self):
@ -66,8 +80,8 @@ class Shard(object):
def to_json(self): def to_json(self):
return { return {
"HashKeyRange": { "HashKeyRange": {
"EndingHashKey": "113427455640312821154458202477256070484", "EndingHashKey": str(self.ending_hash),
"StartingHashKey": "0" "StartingHashKey": str(self.starting_hash)
}, },
"SequenceNumberRange": { "SequenceNumberRange": {
"EndingSequenceNumber": self.get_max_sequence_number(), "EndingSequenceNumber": self.get_max_sequence_number(),
@ -78,16 +92,26 @@ class Shard(object):
class Stream(object): class Stream(object):
def __init__(self, stream_name, shard_count, region): def __init__(self, stream_name, shard_count, region):
self.stream_name = stream_name self.stream_name = stream_name
self.shard_count = shard_count self.shard_count = shard_count
self.region = region self.region = region
self.account_number = "123456789012" self.account_number = "123456789012"
self.shards = {} self.shards = {}
self.tags = {}
for index in range(shard_count): if six.PY3:
shard_id = "shardId-{0}".format(str(index).zfill(12)) izip_longest = itertools.zip_longest
self.shards[shard_id] = Shard(shard_id) else:
izip_longest = itertools.izip_longest
for index, start, end in izip_longest(range(shard_count),
range(0,2**128,2**128//shard_count),
range(2**128//shard_count,2**128,2**128//shard_count),
fillvalue=2**128):
shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard
@property @property
def arn(self): def arn(self):
@ -103,16 +127,32 @@ class Stream(object):
else: else:
raise ShardNotFoundError(shard_id) raise ShardNotFoundError(shard_id)
def get_shard_for_key(self, partition_key): def get_shard_for_key(self, partition_key, explicit_hash_key):
# TODO implement sharding if not isinstance(partition_key, six.string_types):
shard = list(self.shards.values())[0] raise InvalidArgumentError("partition_key")
if len(partition_key) > 256:
raise InvalidArgumentError("partition_key")
if explicit_hash_key:
if not isinstance(explicit_hash_key, six.string_types):
raise InvalidArgumentError("explicit_hash_key")
key = int(explicit_hash_key)
if key >= 2**128:
raise InvalidArgumentError("explicit_hash_key")
else:
key = int(md5(partition_key.encode('utf-8')).hexdigest(), 16)
for shard in self.shards.values():
if shard.starting_hash <= key < shard.ending_hash:
return shard return shard
def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data): def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
partition_key = explicit_hash_key if explicit_hash_key else partition_key shard = self.get_shard_for_key(partition_key, explicit_hash_key)
shard = self.get_shard_for_key(partition_key)
sequence_number = shard.put_record(partition_key, data) sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
return sequence_number, shard.shard_id return sequence_number, shard.shard_id
def to_json(self): def to_json(self):
@ -200,9 +240,10 @@ class KinesisBackend(BaseBackend):
self.streams = {} self.streams = {}
self.delivery_streams = {} self.delivery_streams = {}
def create_stream(self, stream_name, shard_count, region): def create_stream(self, stream_name, shard_count, region):
if stream_name in self.streams: if stream_name in self.streams:
return ResourceInUseError(stream_name) raise ResourceInUseError(stream_name)
stream = Stream(stream_name, shard_count, region) stream = Stream(stream_name, shard_count, region)
self.streams[stream_name] = stream self.streams[stream_name] = stream
return stream return stream
@ -264,7 +305,7 @@ class KinesisBackend(BaseBackend):
for record in records: for record in records:
partition_key = record.get("PartitionKey") partition_key = record.get("PartitionKey")
explicit_hash_key = record.get("ExplicitHashKey") explicit_hash_key = record.get("ExplicitHashKey")
data = record.get("data") data = record.get("Data")
sequence_number, shard_id = stream.put_record( sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, None, data partition_key, explicit_hash_key, None, data
@ -276,6 +317,60 @@ class KinesisBackend(BaseBackend):
return response return response
def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
    """Split one shard of *stream_name* into two at *new_starting_hash_key*.

    Raises ResourceNotFoundError if the shard does not exist, and
    InvalidArgumentError if the hash key is malformed or lies outside the
    shard's open hash-key interval.  The split shard's records are re-put
    through the stream so each one is routed to the correct child shard.
    """
    stream = self.describe_stream(stream_name)

    if shard_to_split not in stream.shards:
        raise ResourceNotFoundError(shard_to_split)

    # Anchor the pattern: re.match() only anchors at the start, so without
    # the trailing "$" a value such as "123abc" passed validation and then
    # crashed in int() with ValueError instead of InvalidArgumentError.
    if not re.match(r'(0|[1-9]\d{0,38})$', new_starting_hash_key):
        raise InvalidArgumentError(new_starting_hash_key)
    new_starting_hash_key = int(new_starting_hash_key)

    shard = stream.shards[shard_to_split]

    # New shard id is one past the highest numeric id currently in use.
    last_id = sorted(stream.shards.values(), key=attrgetter('_shard_id'))[-1]._shard_id

    if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
        new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash)
        shard.ending_hash = new_starting_hash_key
        stream.shards[new_shard.shard_id] = new_shard
    else:
        raise InvalidArgumentError(new_starting_hash_key)

    # Redistribute the old records: put_record() hashes each record again,
    # so those above the split point land in the new shard.
    records = shard.records
    shard.records = OrderedDict()
    for sequence_number in records:
        record = records[sequence_number]
        stream.put_record(
            record.partition_key, record.explicit_hash_key, None, record.data
        )
def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
    """Merge two adjacent shards of *stream_name* into the first one.

    Raises ResourceNotFoundError for an unknown shard id and
    InvalidArgumentError when the shards do not share a hash-range edge.
    """
    stream = self.describe_stream(stream_name)

    for shard_id in (shard_to_merge, adjacent_shard_to_merge):
        if shard_id not in stream.shards:
            raise ResourceNotFoundError(shard_id)

    survivor = stream.shards[shard_to_merge]
    absorbed = stream.shards[adjacent_shard_to_merge]

    # Adjacency means the two hash ranges touch on exactly one side.
    if survivor.ending_hash == absorbed.starting_hash:
        survivor.ending_hash = absorbed.ending_hash
    elif absorbed.ending_hash == survivor.starting_hash:
        survivor.starting_hash = absorbed.starting_hash
    else:
        raise InvalidArgumentError(adjacent_shard_to_merge)

    del stream.shards[absorbed.shard_id]
    # Move the absorbed shard's records into the surviving shard; they are
    # assigned fresh sequence numbers there.
    for sequence_number in absorbed.records:
        record = absorbed.records[sequence_number]
        survivor.put_record(record.partition_key, record.data, record.explicit_hash_key)
''' Firehose ''' ''' Firehose '''
def create_delivery_stream(self, stream_name, **stream_kwargs): def create_delivery_stream(self, stream_name, **stream_kwargs):
stream = DeliveryStream(stream_name, **stream_kwargs) stream = DeliveryStream(stream_name, **stream_kwargs)
@ -299,6 +394,39 @@ class KinesisBackend(BaseBackend):
record = stream.put_record(record_data) record = stream.put_record(record_data)
return record return record
def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None):
    """Return the stream's tags in key order as {'HasMoreTags', 'Tags'}.

    `Tags` is a list of {'Key': ..., 'Value': ...} dicts.  When *limit* is
    given and more tags remain, 'HasMoreTags' is True.  When
    *exclusive_start_tag_key* is given, listing starts strictly after that
    key (AWS "exclusive" semantics).
    """
    stream = self.describe_stream(stream_name)

    tags = []
    result = {
        'HasMoreTags': False,
        'Tags': tags
    }
    for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
        # Bug fix: the original referenced the undefined names `res` and
        # `exexclusive_start_tag_key`, raising NameError whenever a limit
        # or start key was supplied.
        if limit and len(tags) >= limit:
            result['HasMoreTags'] = True
            break
        # Bug fix: the start key itself is excluded (was `<`, which
        # included it, contradicting "exclusive").
        if exclusive_start_tag_key and key <= exclusive_start_tag_key:
            continue
        tags.append({
            'Key': key,
            'Value': val
        })
    return result
def add_tags_to_stream(self, stream_name, tags):
    """Attach *tags* (a dict) to the stream, overwriting duplicate keys."""
    self.describe_stream(stream_name).tags.update(tags)
def remove_tags_from_stream(self, stream_name, tag_keys):
    """Delete each key in *tag_keys* from the stream; absent keys are ignored."""
    stream = self.describe_stream(stream_name)
    for tag_key in tag_keys:
        # pop() with a default is a no-op for keys that are not present,
        # matching the original membership check.
        stream.tags.pop(tag_key, None)
kinesis_backends = {} kinesis_backends = {}
for region in boto.kinesis.regions(): for region in boto.kinesis.regions():
kinesis_backends[region.name] = KinesisBackend() kinesis_backends[region.name] = KinesisBackend()

View File

@ -4,6 +4,7 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kinesis_backends from .models import kinesis_backends
from werkzeug.exceptions import BadRequest
class KinesisResponse(BaseResponse): class KinesisResponse(BaseResponse):
@ -43,7 +44,6 @@ class KinesisResponse(BaseResponse):
def delete_stream(self): def delete_stream(self):
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
self.kinesis_backend.delete_stream(stream_name) self.kinesis_backend.delete_stream(stream_name)
return "" return ""
def get_shard_iterator(self): def get_shard_iterator(self):
@ -91,7 +91,7 @@ class KinesisResponse(BaseResponse):
def put_records(self): def put_records(self):
if self.is_firehose: if self.is_firehose:
return self.firehose_put_record() return self.put_record_batch()
stream_name = self.parameters.get("StreamName") stream_name = self.parameters.get("StreamName")
records = self.parameters.get("Records") records = self.parameters.get("Records")
@ -101,6 +101,24 @@ class KinesisResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
def split_shard(self):
    """Handle the SplitShard API call.

    The backend performs the split (or raises); the AWS SplitShard
    response body is empty, so an empty string is returned.
    """
    stream_name = self.parameters.get("StreamName")
    shard_to_split = self.parameters.get("ShardToSplit")
    new_starting_hash_key = self.parameters.get("NewStartingHashKey")
    # The backend returns nothing useful; the original bound it to an
    # unused `response` local, removed here.
    self.kinesis_backend.split_shard(
        stream_name, shard_to_split, new_starting_hash_key
    )
    return ""
def merge_shards(self):
    """Handle the MergeShards API call.

    The backend performs the merge (or raises); the AWS MergeShards
    response body is empty, so an empty string is returned.
    """
    stream_name = self.parameters.get("StreamName")
    shard_to_merge = self.parameters.get("ShardToMerge")
    adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
    # The backend returns nothing useful; the original bound it to an
    # unused `response` local, removed here.
    self.kinesis_backend.merge_shards(
        stream_name, shard_to_merge, adjacent_shard_to_merge
    )
    return ""
''' Firehose ''' ''' Firehose '''
def create_delivery_stream(self): def create_delivery_stream(self):
stream_name = self.parameters['DeliveryStreamName'] stream_name = self.parameters['DeliveryStreamName']
@ -168,3 +186,22 @@ class KinesisResponse(BaseResponse):
"FailedPutCount": 0, "FailedPutCount": 0,
"RequestResponses": request_responses, "RequestResponses": request_responses,
}) })
def add_tags_to_stream(self):
    """Handle the AddTagsToStream API call; returns an empty JSON object."""
    self.kinesis_backend.add_tags_to_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('Tags'),
    )
    return json.dumps({})
def list_tags_for_stream(self):
    """Handle the ListTagsForStream API call; returns the backend result as JSON."""
    result = self.kinesis_backend.list_tags_for_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('ExclusiveStartTagKey'),
        self.parameters.get('Limit'),
    )
    return json.dumps(result)
def remove_tags_from_stream(self):
    """Handle the RemoveTagsFromStream API call; returns an empty JSON object."""
    self.kinesis_backend.remove_tags_from_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('TagKeys'),
    )
    return json.dumps({})

View File

@ -85,6 +85,10 @@ def test_put_records():
data = "hello world" data = "hello world"
partition_key = "1234" partition_key = "1234"
conn.put_record.when.called_with(
stream_name, data, 1234).should.throw(InvalidArgumentException)
conn.put_record(stream_name, data, partition_key) conn.put_record(stream_name, data, partition_key)
response = conn.describe_stream(stream_name) response = conn.describe_stream(stream_name)
@ -112,8 +116,9 @@ def test_get_records_limit():
# Create some data # Create some data
data = "hello world" data = "hello world"
for index in range(5): for index in range(5):
conn.put_record(stream_name, data, index) conn.put_record(stream_name, data, str(index))
# Get a shard iterator # Get a shard iterator
response = conn.describe_stream(stream_name) response = conn.describe_stream(stream_name)
@ -140,7 +145,7 @@ def test_get_records_at_sequence_number():
# Create some data # Create some data
for index in range(1, 5): for index in range(1, 5):
conn.put_record(stream_name, str(index), index) conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator # Get a shard iterator
response = conn.describe_stream(stream_name) response = conn.describe_stream(stream_name)
@ -171,7 +176,7 @@ def test_get_records_after_sequence_number():
# Create some data # Create some data
for index in range(1, 5): for index in range(1, 5):
conn.put_record(stream_name, str(index), index) conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator # Get a shard iterator
response = conn.describe_stream(stream_name) response = conn.describe_stream(stream_name)
@ -201,7 +206,7 @@ def test_get_records_latest():
# Create some data # Create some data
for index in range(1, 5): for index in range(1, 5):
conn.put_record(stream_name, str(index), index) conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator # Get a shard iterator
response = conn.describe_stream(stream_name) response = conn.describe_stream(stream_name)
@ -237,3 +242,144 @@ def test_invalid_shard_iterator_type():
shard_id = response['StreamDescription']['Shards'][0]['ShardId'] shard_id = response['StreamDescription']['Shards'][0]['ShardId']
response = conn.get_shard_iterator.when.called_with( response = conn.get_shard_iterator.when.called_with(
stream_name, shard_id, 'invalid-type').should.throw(InvalidArgumentException) stream_name, shard_id, 'invalid-type').should.throw(InvalidArgumentException)
@mock_kinesis
def test_add_tags():
    """Adding tags — including overwriting existing keys — must not raise."""
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)
    conn.describe_stream(stream_name)

    for tags in ({'tag1': 'val1'}, {'tag2': 'val2'},
                 {'tag1': 'val3'}, {'tag2': 'val4'}):
        conn.add_tags_to_stream(stream_name, tags)
@mock_kinesis
def test_list_tags():
    """Every AddTagsToStream call must be visible via ListTagsForStream."""
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)
    conn.describe_stream(stream_name)

    def current_tags():
        # Flatten the [{'Key': ..., 'Value': ...}] response into a dict.
        response = conn.list_tags_for_stream(stream_name)
        return dict((tag['Key'], tag['Value']) for tag in response['Tags'])

    for key, val in (('tag1', 'val1'), ('tag2', 'val2'),
                     ('tag1', 'val3'), ('tag2', 'val4')):
        conn.add_tags_to_stream(stream_name, {key: val})
        current_tags().get(key).should.equal(val)
@mock_kinesis
def test_remove_tags():
    """A removed tag must no longer appear in ListTagsForStream."""
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)
    conn.describe_stream(stream_name)

    def current_tags():
        # Flatten the [{'Key': ..., 'Value': ...}] response into a dict.
        response = conn.list_tags_for_stream(stream_name)
        return dict((tag['Key'], tag['Value']) for tag in response['Tags'])

    for key, val in (('tag1', 'val1'), ('tag2', 'val2')):
        conn.add_tags_to_stream(stream_name, {key: val})
        current_tags().get(key).should.equal(val)
        conn.remove_tags_from_stream(stream_name, [key])
        current_tags().get(key).should.equal(None)
@mock_kinesis
def test_split_shard():
    """Splitting a shard adds one shard and preserves the record count."""
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = 'my_stream'
    conn.create_stream(stream_name, 2)

    # Create some data
    for index in range(1, 100):
        conn.put_record(stream_name, str(index), str(index))

    def get_shards():
        return conn.describe_stream(stream_name)["StreamDescription"]['Shards']

    def total_records(shards):
        # Sequence numbers count from 1 per shard, so the per-shard maxima
        # sum to the total number of records in the stream.
        return sum(shard['SequenceNumberRange']['EndingSequenceNumber']
                   for shard in shards)

    def split_at_midpoint(shard):
        hash_range = shard['HashKeyRange']
        midpoint = (int(hash_range['EndingHashKey']) +
                    int(hash_range['StartingHashKey'])) // 2
        conn.split_shard("my_stream", shard['ShardId'], str(midpoint))

    shards = get_shards()
    shards.should.have.length_of(2)
    total_records(shards).should.equal(99)

    split_at_midpoint(shards[0])
    shards = get_shards()
    shards.should.have.length_of(3)
    total_records(shards).should.equal(99)

    split_at_midpoint(shards[2])
    shards = get_shards()
    shards.should.have.length_of(4)
    total_records(shards).should.equal(99)
@mock_kinesis
def test_merge_shards():
    """Merging adjacent shards shrinks the shard list and keeps all records;
    merging non-adjacent shards raises InvalidArgumentException."""
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = 'my_stream'
    conn.create_stream(stream_name, 4)

    # Create some data
    for index in range(1, 100):
        conn.put_record(stream_name, str(index), str(index))

    def get_shards():
        return conn.describe_stream(stream_name)["StreamDescription"]['Shards']

    def total_records(shards):
        # Per-shard maximum sequence numbers sum to the stream's record count.
        return sum(shard['SequenceNumberRange']['EndingSequenceNumber']
                   for shard in shards)

    get_shards().should.have.length_of(4)

    # Non-adjacent shards cannot be merged.
    conn.merge_shards.when.called_with(
        stream_name, 'shardId-000000000000', 'shardId-000000000002'
    ).should.throw(InvalidArgumentException)
    shards = get_shards()
    shards.should.have.length_of(4)
    total_records(shards).should.equal(99)

    conn.merge_shards(stream_name, 'shardId-000000000000', 'shardId-000000000001')
    shards = get_shards()
    shards.should.have.length_of(3)
    total_records(shards).should.equal(99)

    conn.merge_shards(stream_name, 'shardId-000000000002', 'shardId-000000000000')
    shards = get_shards()
    shards.should.have.length_of(2)
    total_records(shards).should.equal(99)