Merge pull request #475 from silveregg/master

Add Kinesis API and fix some typos
This commit is contained in:
Steve Pulec 2015-12-05 21:04:41 -05:00
commit be3291b758
3 changed files with 338 additions and 27 deletions

View File

@ -2,19 +2,27 @@ from __future__ import unicode_literals
import datetime
import time
import boto.kinesis
import re
import six
import itertools
from operator import attrgetter
from hashlib import md5
from moto.compat import OrderedDict
from moto.core import BaseBackend
from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError
from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \
ResourceNotFoundError, InvalidArgumentError
from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator
class Record(object):
def __init__(self, partition_key, data, sequence_number):
def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
self.partition_key = partition_key
self.data = data
self.sequence_number = sequence_number
self.explicit_hash_key = explicit_hash_key
def to_json(self):
return {
@ -25,10 +33,16 @@ class Record(object):
class Shard(object):
def __init__(self, shard_id):
self.shard_id = shard_id
def __init__(self, shard_id, starting_hash, ending_hash):
self._shard_id = shard_id
self.starting_hash = starting_hash
self.ending_hash = ending_hash
self.records = OrderedDict()
@property
def shard_id(self):
return "shardId-{0}".format(str(self._shard_id).zfill(12))
def get_records(self, last_sequence_id, limit):
last_sequence_id = int(last_sequence_id)
results = []
@ -43,14 +57,14 @@ class Shard(object):
return results, last_sequence_id
def put_record(self, partition_key, data):
def put_record(self, partition_key, data, explicit_hash_key):
# Note: this function is not safe for concurrency
if self.records:
last_sequence_number = self.get_max_sequence_number()
else:
last_sequence_number = 0
sequence_number = last_sequence_number + 1
self.records[sequence_number] = Record(partition_key, data, sequence_number)
self.records[sequence_number] = Record(partition_key, data, sequence_number, explicit_hash_key)
return sequence_number
def get_min_sequence_number(self):
@ -66,8 +80,8 @@ class Shard(object):
def to_json(self):
return {
"HashKeyRange": {
"EndingHashKey": "113427455640312821154458202477256070484",
"StartingHashKey": "0"
"EndingHashKey": str(self.ending_hash),
"StartingHashKey": str(self.starting_hash)
},
"SequenceNumberRange": {
"EndingSequenceNumber": self.get_max_sequence_number(),
@ -78,16 +92,26 @@ class Shard(object):
class Stream(object):
def __init__(self, stream_name, shard_count, region):
self.stream_name = stream_name
self.shard_count = shard_count
self.region = region
self.account_number = "123456789012"
self.shards = {}
self.tags = {}
for index in range(shard_count):
shard_id = "shardId-{0}".format(str(index).zfill(12))
self.shards[shard_id] = Shard(shard_id)
if six.PY3:
izip_longest = itertools.zip_longest
else:
izip_longest = itertools.izip_longest
for index, start, end in izip_longest(range(shard_count),
range(0,2**128,2**128//shard_count),
range(2**128//shard_count,2**128,2**128//shard_count),
fillvalue=2**128):
shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard
@property
def arn(self):
@ -103,16 +127,32 @@ class Stream(object):
else:
raise ShardNotFoundError(shard_id)
def get_shard_for_key(self, partition_key):
# TODO implement sharding
shard = list(self.shards.values())[0]
return shard
def get_shard_for_key(self, partition_key, explicit_hash_key):
    """Map a record to the shard whose hash-key range covers it.

    The routing hash is either the caller-supplied explicit hash key or
    the MD5 of the partition key (as real Kinesis does).  Raises
    InvalidArgumentError for a non-string or oversized partition key, or
    an explicit hash key that is not a string / is out of range.
    """
    if not isinstance(partition_key, six.string_types) or len(partition_key) > 256:
        raise InvalidArgumentError("partition_key")

    if explicit_hash_key:
        if not isinstance(explicit_hash_key, six.string_types):
            raise InvalidArgumentError("explicit_hash_key")
        hash_key = int(explicit_hash_key)
        # Kinesis hash keys live in [0, 2**128).
        if hash_key >= 2**128:
            raise InvalidArgumentError("explicit_hash_key")
    else:
        hash_key = int(md5(partition_key.encode('utf-8')).hexdigest(), 16)

    for candidate in self.shards.values():
        if candidate.starting_hash <= hash_key < candidate.ending_hash:
            return candidate
def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
partition_key = explicit_hash_key if explicit_hash_key else partition_key
shard = self.get_shard_for_key(partition_key)
shard = self.get_shard_for_key(partition_key, explicit_hash_key)
sequence_number = shard.put_record(partition_key, data)
sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
return sequence_number, shard.shard_id
def to_json(self):
@ -200,9 +240,10 @@ class KinesisBackend(BaseBackend):
self.streams = {}
self.delivery_streams = {}
def create_stream(self, stream_name, shard_count, region):
if stream_name in self.streams:
return ResourceInUseError(stream_name)
raise ResourceInUseError(stream_name)
stream = Stream(stream_name, shard_count, region)
self.streams[stream_name] = stream
return stream
@ -264,7 +305,7 @@ class KinesisBackend(BaseBackend):
for record in records:
partition_key = record.get("PartitionKey")
explicit_hash_key = record.get("ExplicitHashKey")
data = record.get("data")
data = record.get("Data")
sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, None, data
@ -276,6 +317,60 @@ class KinesisBackend(BaseBackend):
return response
def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
    """Split a shard of *stream_name* in two at *new_starting_hash_key*.

    Mirrors the Kinesis SplitShard API: the original shard keeps the
    lower half of its hash-key range, a new shard (with the next shard
    id) takes the upper half, and the split shard's records are
    re-routed through the stream so each lands in its new owner.

    Raises ResourceNotFoundError if the shard does not exist and
    InvalidArgumentError if the new hash key is malformed or falls
    outside the shard's open interval.
    """
    stream = self.describe_stream(stream_name)

    if shard_to_split not in stream.shards:
        raise ResourceNotFoundError(shard_to_split)

    # Bug fix: anchor the pattern with '$' so trailing garbage
    # (e.g. "12abc") is rejected here with InvalidArgumentError instead
    # of blowing up with an unhandled ValueError in int() below.
    if not re.match(r'(0|[1-9]\d{0,38})$', new_starting_hash_key):
        raise InvalidArgumentError(new_starting_hash_key)
    new_starting_hash_key = int(new_starting_hash_key)

    shard = stream.shards[shard_to_split]

    # New shard id = highest existing numeric id + 1.
    last_id = sorted(stream.shards.values(), key=attrgetter('_shard_id'))[-1]._shard_id

    if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
        new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash)
        shard.ending_hash = new_starting_hash_key
        stream.shards[new_shard.shard_id] = new_shard
    else:
        raise InvalidArgumentError(new_starting_hash_key)

    # Re-route the split shard's records through the stream so each one
    # is stored in whichever shard now owns its hash key.
    records = shard.records
    shard.records = OrderedDict()
    for sequence_number in records:
        record = records[sequence_number]
        stream.put_record(
            record.partition_key, record.explicit_hash_key, None, record.data
        )
def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
    """Merge two adjacent shards into one (the Kinesis MergeShards API).

    The surviving shard (*shard_to_merge*) absorbs the hash-key range
    and the records of its neighbour, which is then removed from the
    stream.  Raises ResourceNotFoundError for an unknown shard id and
    InvalidArgumentError if the two shards are not adjacent.
    """
    stream = self.describe_stream(stream_name)

    for shard_id in (shard_to_merge, adjacent_shard_to_merge):
        if shard_id not in stream.shards:
            raise ResourceNotFoundError(shard_id)

    survivor = stream.shards[shard_to_merge]
    absorbed = stream.shards[adjacent_shard_to_merge]

    # The two ranges must share a boundary; extend the survivor over
    # whichever side the neighbour sits on.
    if survivor.ending_hash == absorbed.starting_hash:
        survivor.ending_hash = absorbed.ending_hash
    elif absorbed.ending_hash == survivor.starting_hash:
        survivor.starting_hash = absorbed.starting_hash
    else:
        raise InvalidArgumentError(adjacent_shard_to_merge)

    del stream.shards[absorbed.shard_id]
    for sequence_number in absorbed.records:
        record = absorbed.records[sequence_number]
        survivor.put_record(record.partition_key, record.data, record.explicit_hash_key)
''' Firehose '''
def create_delivery_stream(self, stream_name, **stream_kwargs):
stream = DeliveryStream(stream_name, **stream_kwargs)
@ -299,6 +394,39 @@ class KinesisBackend(BaseBackend):
record = stream.put_record(record_data)
return record
def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None):
    """Return the tags of *stream_name* as a ListTagsForStream response.

    Tags are listed in key order.  *exclusive_start_tag_key* skips keys
    that sort before it, and *limit* caps the number of tags returned;
    when the cap is hit, 'HasMoreTags' is set to True in the result.
    """
    stream = self.describe_stream(stream_name)

    tags = []
    result = {
        'HasMoreTags': False,
        'Tags': tags
    }
    for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
        # Bug fix: was `len(res)` (NameError) — count the tags collected so far.
        if limit and len(tags) >= limit:
            result['HasMoreTags'] = True
            break
        # Bug fix: parameter name was typo'd `exexclusive_start_tag_key` (NameError).
        if exclusive_start_tag_key and key < exclusive_start_tag_key:
            continue
        tags.append({
            'Key': key,
            'Value': val
        })
    return result
def add_tags_to_stream(self, stream_name, tags):
    """Attach (or overwrite) the given key/value tags on a stream."""
    self.describe_stream(stream_name).tags.update(tags)
def remove_tags_from_stream(self, stream_name, tag_keys):
    """Delete the listed tag keys from a stream; unknown keys are ignored."""
    stream = self.describe_stream(stream_name)
    for tag_key in tag_keys:
        stream.tags.pop(tag_key, None)
kinesis_backends = {}
for region in boto.kinesis.regions():
kinesis_backends[region.name] = KinesisBackend()

View File

@ -4,6 +4,7 @@ import json
from moto.core.responses import BaseResponse
from .models import kinesis_backends
from werkzeug.exceptions import BadRequest
class KinesisResponse(BaseResponse):
@ -43,7 +44,6 @@ class KinesisResponse(BaseResponse):
def delete_stream(self):
    """Handle the DeleteStream action; Kinesis returns an empty body."""
    self.kinesis_backend.delete_stream(self.parameters.get("StreamName"))
    return ""
def get_shard_iterator(self):
@ -91,7 +91,7 @@ class KinesisResponse(BaseResponse):
def put_records(self):
if self.is_firehose:
return self.firehose_put_record()
return self.put_record_batch()
stream_name = self.parameters.get("StreamName")
records = self.parameters.get("Records")
@ -101,6 +101,24 @@ class KinesisResponse(BaseResponse):
return json.dumps(response)
def split_shard(self):
    """Handle the SplitShard action; Kinesis returns an empty body.

    Validation errors are raised by the backend; SplitShard itself has
    no response payload, so the backend's return value (previously
    captured in an unused local) is simply dropped.
    """
    stream_name = self.parameters.get("StreamName")
    shard_to_split = self.parameters.get("ShardToSplit")
    new_starting_hash_key = self.parameters.get("NewStartingHashKey")
    self.kinesis_backend.split_shard(
        stream_name, shard_to_split, new_starting_hash_key
    )
    return ""
def merge_shards(self):
    """Handle the MergeShards action; Kinesis returns an empty body.

    Validation errors are raised by the backend; MergeShards itself has
    no response payload, so the backend's return value (previously
    captured in an unused local) is simply dropped.
    """
    stream_name = self.parameters.get("StreamName")
    shard_to_merge = self.parameters.get("ShardToMerge")
    adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
    self.kinesis_backend.merge_shards(
        stream_name, shard_to_merge, adjacent_shard_to_merge
    )
    return ""
''' Firehose '''
def create_delivery_stream(self):
stream_name = self.parameters['DeliveryStreamName']
@ -168,3 +186,22 @@ class KinesisResponse(BaseResponse):
"FailedPutCount": 0,
"RequestResponses": request_responses,
})
def add_tags_to_stream(self):
    """Handle the AddTagsToStream action; returns an empty JSON object."""
    self.kinesis_backend.add_tags_to_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('Tags'),
    )
    return json.dumps({})
def list_tags_for_stream(self):
    """Handle the ListTagsForStream action, passing pagination options through."""
    result = self.kinesis_backend.list_tags_for_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('ExclusiveStartTagKey'),
        self.parameters.get('Limit'),
    )
    return json.dumps(result)
def remove_tags_from_stream(self):
    """Handle the RemoveTagsFromStream action; returns an empty JSON object."""
    self.kinesis_backend.remove_tags_from_stream(
        self.parameters.get('StreamName'),
        self.parameters.get('TagKeys'),
    )
    return json.dumps({})

View File

@ -85,6 +85,10 @@ def test_put_records():
data = "hello world"
partition_key = "1234"
conn.put_record.when.called_with(
stream_name, data, 1234).should.throw(InvalidArgumentException)
conn.put_record(stream_name, data, partition_key)
response = conn.describe_stream(stream_name)
@ -112,8 +116,9 @@ def test_get_records_limit():
# Create some data
data = "hello world"
for index in range(5):
conn.put_record(stream_name, data, index)
conn.put_record(stream_name, data, str(index))
# Get a shard iterator
response = conn.describe_stream(stream_name)
@ -140,7 +145,7 @@ def test_get_records_at_sequence_number():
# Create some data
for index in range(1, 5):
conn.put_record(stream_name, str(index), index)
conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator
response = conn.describe_stream(stream_name)
@ -171,7 +176,7 @@ def test_get_records_after_sequence_number():
# Create some data
for index in range(1, 5):
conn.put_record(stream_name, str(index), index)
conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator
response = conn.describe_stream(stream_name)
@ -201,7 +206,7 @@ def test_get_records_latest():
# Create some data
for index in range(1, 5):
conn.put_record(stream_name, str(index), index)
conn.put_record(stream_name, str(index), str(index))
# Get a shard iterator
response = conn.describe_stream(stream_name)
@ -237,3 +242,144 @@ def test_invalid_shard_iterator_type():
shard_id = response['StreamDescription']['Shards'][0]['ShardId']
response = conn.get_shard_iterator.when.called_with(
stream_name, shard_id, 'invalid-type').should.throw(InvalidArgumentException)
@mock_kinesis
def test_add_tags():
    # AddTagsToStream must accept new keys and repeated calls that
    # overwrite existing keys; this test only checks the calls succeed
    # (stored values are verified in test_list_tags).
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)

    conn.describe_stream(stream_name)
    conn.add_tags_to_stream(stream_name, {'tag1':'val1'})
    conn.add_tags_to_stream(stream_name, {'tag2':'val2'})
    conn.add_tags_to_stream(stream_name, {'tag1':'val3'})
    conn.add_tags_to_stream(stream_name, {'tag2':'val4'})
@mock_kinesis
def test_list_tags():
    # After each AddTagsToStream call, ListTagsForStream must reflect
    # the latest value for every key (later values overwrite earlier ones).
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)

    conn.describe_stream(stream_name)
    conn.add_tags_to_stream(stream_name, {'tag1':'val1'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag1').should.equal('val1')
    conn.add_tags_to_stream(stream_name, {'tag2':'val2'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag2').should.equal('val2')
    # Re-adding an existing key overwrites its value.
    conn.add_tags_to_stream(stream_name, {'tag1':'val3'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag1').should.equal('val3')
    conn.add_tags_to_stream(stream_name, {'tag2':'val4'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag2').should.equal('val4')
@mock_kinesis
def test_remove_tags():
    # RemoveTagsFromStream must delete exactly the named key: the tag is
    # visible before removal and absent (get() -> None) afterwards.
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = "my_stream"
    conn.create_stream(stream_name, 1)

    conn.describe_stream(stream_name)
    conn.add_tags_to_stream(stream_name, {'tag1':'val1'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag1').should.equal('val1')
    conn.remove_tags_from_stream(stream_name, ['tag1'])
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag1').should.equal(None)

    conn.add_tags_to_stream(stream_name, {'tag2':'val2'})
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag2').should.equal('val2')
    conn.remove_tags_from_stream(stream_name, ['tag2'])
    tags = dict([(tag['Key'], tag['Value']) for tag in conn.list_tags_for_stream(stream_name)['Tags']])
    tags.get('tag2').should.equal(None)
@mock_kinesis
def test_split_shard():
    # Each SplitShard call should add exactly one shard while leaving
    # the total record count across all shards unchanged.
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = 'my_stream'

    conn.create_stream(stream_name, 2)

    # Create some data
    for index in range(1, 100):
        conn.put_record(stream_name, str(index), str(index))

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(2)
    # 99 records were written; sequence numbers are per-shard counters
    # starting at 1, so the per-shard maxima must sum to 99.
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)

    # Split the first shard at the midpoint of its hash-key range.
    shard_range = shards[0]['HashKeyRange']
    new_starting_hash = (int(shard_range['EndingHashKey'])+int(shard_range['StartingHashKey'])) // 2
    conn.split_shard("my_stream", shards[0]['ShardId'], str(new_starting_hash))

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(3)
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)

    # Split the newly created shard as well.
    shard_range = shards[2]['HashKeyRange']
    new_starting_hash = (int(shard_range['EndingHashKey'])+int(shard_range['StartingHashKey'])) // 2
    conn.split_shard("my_stream", shards[2]['ShardId'], str(new_starting_hash))

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(4)
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)
@mock_kinesis
def test_merge_shards():
    # MergeShards must reject non-adjacent shards, and each successful
    # merge removes one shard while preserving all records.
    conn = boto.kinesis.connect_to_region("us-west-2")
    stream_name = 'my_stream'

    conn.create_stream(stream_name, 4)

    # Create some data
    for index in range(1, 100):
        conn.put_record(stream_name, str(index), str(index))

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(4)

    # Shards 0 and 2 are not adjacent, so the merge must be rejected.
    conn.merge_shards.when.called_with(stream_name, 'shardId-000000000000', 'shardId-000000000002').should.throw(InvalidArgumentException)

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(4)
    # 99 records were written; per-shard max sequence numbers sum to 99.
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)

    conn.merge_shards(stream_name, 'shardId-000000000000', 'shardId-000000000001')

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(3)
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)

    # Merging also works when the absorbed shard is given first by id.
    conn.merge_shards(stream_name, 'shardId-000000000002', 'shardId-000000000000')

    stream_response = conn.describe_stream(stream_name)

    stream = stream_response["StreamDescription"]
    shards = stream['Shards']
    shards.should.have.length_of(2)
    sum([shard['SequenceNumberRange']['EndingSequenceNumber'] for shard in shards]).should.equal(99)