moto/moto/kinesis/models.py
from __future__ import unicode_literals

import datetime
import time
import boto.kinesis
import re
import six
import itertools
from operator import attrgetter
from hashlib import md5

from moto.compat import OrderedDict
from moto.core import BaseBackend
from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \
    ResourceNotFoundError, InvalidArgumentError
from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator


class Record(object):
    def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
        self.partition_key = partition_key
        self.data = data
        self.sequence_number = sequence_number
        self.explicit_hash_key = explicit_hash_key

    def to_json(self):
        return {
            "Data": self.data,
            "PartitionKey": self.partition_key,
            "SequenceNumber": str(self.sequence_number),
        }


class Shard(object):
    def __init__(self, shard_id, starting_hash, ending_hash):
        self._shard_id = shard_id
        self.starting_hash = starting_hash
        self.ending_hash = ending_hash
        self.records = OrderedDict()

    @property
    def shard_id(self):
        return "shardId-{0}".format(str(self._shard_id).zfill(12))

    def get_records(self, last_sequence_id, limit):
        last_sequence_id = int(last_sequence_id)
        results = []

        for sequence_number, record in self.records.items():
            if sequence_number > last_sequence_id:
                results.append(record)
                last_sequence_id = sequence_number

                if len(results) == limit:
                    break

        return results, last_sequence_id

    def put_record(self, partition_key, data, explicit_hash_key):
        # Note: this function is not safe for concurrency
        if self.records:
            last_sequence_number = self.get_max_sequence_number()
        else:
            last_sequence_number = 0
        sequence_number = last_sequence_number + 1
        self.records[sequence_number] = Record(partition_key, data, sequence_number, explicit_hash_key)
        return sequence_number

    def get_min_sequence_number(self):
        if self.records:
            return list(self.records.keys())[0]
        return 0

    def get_max_sequence_number(self):
        if self.records:
            return list(self.records.keys())[-1]
        return 0

    def to_json(self):
        return {
            "HashKeyRange": {
                "EndingHashKey": str(self.ending_hash),
                "StartingHashKey": str(self.starting_hash)
            },
            "SequenceNumberRange": {
                "EndingSequenceNumber": self.get_max_sequence_number(),
                "StartingSequenceNumber": self.get_min_sequence_number(),
            },
            "ShardId": self.shard_id
        }
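

# Illustrative sketch (not part of the module): sequence numbers in this mock
# are plain per-shard integers rather than the opaque strings real Kinesis
# returns, so a put followed by a read behaves like:
#
#   shard = Shard(0, 0, 2 ** 128)
#   seq = shard.put_record("key", "data", None)     # seq == 1
#   records, last = shard.get_records(0, limit=10)  # one Record back, last == 1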


class Stream(object):
    def __init__(self, stream_name, shard_count, region):
        self.stream_name = stream_name
        self.shard_count = shard_count
        self.region = region
        self.account_number = "123456789012"
        self.shards = {}
        self.tags = {}

        if six.PY3:
            izip_longest = itertools.zip_longest
        else:
            izip_longest = itertools.izip_longest

        # Split the 2**128 hash key space evenly across the shards. The end
        # range is one element short, so the last shard's ending hash is
        # supplied by fillvalue; bounding both ranges at shard_count * step
        # (rather than 2**128) avoids a spurious extra shard when shard_count
        # does not divide 2**128 evenly.
        step = 2 ** 128 // shard_count
        for index, start, end in izip_longest(range(shard_count),
                                              range(0, shard_count * step, step),
                                              range(step, shard_count * step, step),
                                              fillvalue=2 ** 128):
            shard = Shard(index, start, end)
            self.shards[shard.shard_id] = shard
    @property
    def arn(self):
        return "arn:aws:kinesis:{region}:{account_number}:stream/{stream_name}".format(
            region=self.region,
            account_number=self.account_number,
            stream_name=self.stream_name
        )

    def get_shard(self, shard_id):
        if shard_id in self.shards:
            return self.shards[shard_id]
        else:
            raise ShardNotFoundError(shard_id)
    def get_shard_for_key(self, partition_key, explicit_hash_key):
        if not isinstance(partition_key, six.string_types):
            raise InvalidArgumentError("partition_key")
        if len(partition_key) > 256:
            raise InvalidArgumentError("partition_key")

        if explicit_hash_key:
            if not isinstance(explicit_hash_key, six.string_types):
                raise InvalidArgumentError("explicit_hash_key")

            key = int(explicit_hash_key)
            if key >= 2 ** 128:
                raise InvalidArgumentError("explicit_hash_key")
        else:
            # Route on the MD5 of the partition key, as Kinesis does; encode
            # first so this also works on Python 3, where md5 requires bytes.
            key = int(md5(partition_key.encode('utf-8')).hexdigest(), 16)

        for shard in self.shards.values():
            if shard.starting_hash <= key < shard.ending_hash:
                return shard

    def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
        shard = self.get_shard_for_key(partition_key, explicit_hash_key)

        sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
        return sequence_number, shard.shard_id
    def to_json(self):
        return {
            "StreamDescription": {
                "StreamARN": self.arn,
                "StreamName": self.stream_name,
                "StreamStatus": "ACTIVE",
                "HasMoreShards": False,
                "Shards": [shard.to_json() for shard in self.shards.values()],
            }
        }
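

# Illustrative sketch (not part of the module): with two shards the hash key
# space [0, 2**128) splits at 2**127, and get_shard_for_key routes a record by
# the MD5 of its partition key:
#
#   stream = Stream("my-stream", 2, "us-east-1")
#   key = int(md5("user-42".encode('utf-8')).hexdigest(), 16)
#   shard = stream.get_shard_for_key("user-42", explicit_hash_key=None)
#   assert shard.starting_hash <= key < shard.ending_hash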


class FirehoseRecord(object):
    def __init__(self, record_data):
        self.record_id = 12345678
        self.record_data = record_data


class DeliveryStream(object):
    def __init__(self, stream_name, **stream_kwargs):
        self.name = stream_name
        self.redshift_username = stream_kwargs['redshift_username']
        self.redshift_password = stream_kwargs['redshift_password']
        self.redshift_jdbc_url = stream_kwargs['redshift_jdbc_url']
        self.redshift_role_arn = stream_kwargs['redshift_role_arn']
        self.redshift_copy_command = stream_kwargs['redshift_copy_command']

        self.redshift_s3_role_arn = stream_kwargs['redshift_s3_role_arn']
        self.redshift_s3_bucket_arn = stream_kwargs['redshift_s3_bucket_arn']
        self.redshift_s3_prefix = stream_kwargs['redshift_s3_prefix']
        self.redshift_s3_compression_format = stream_kwargs.get('redshift_s3_compression_format', 'UNCOMPRESSED')
        self.redshift_s3_buffering_hints = stream_kwargs['redshift_s3_buffering_hints']

        self.records = []
        self.status = 'ACTIVE'
        self.created_at = datetime.datetime.utcnow()
        self.last_updated = datetime.datetime.utcnow()

    @property
    def arn(self):
        return 'arn:aws:firehose:us-east-1:123456789012:deliverystream/{0}'.format(self.name)
    def to_dict(self):
        return {
            "DeliveryStreamDescription": {
                "CreateTimestamp": time.mktime(self.created_at.timetuple()),
                "DeliveryStreamARN": self.arn,
                "DeliveryStreamName": self.name,
                "DeliveryStreamStatus": self.status,
                "Destinations": [
                    {
                        "DestinationId": "string",
                        "RedshiftDestinationDescription": {
                            "ClusterJDBCURL": self.redshift_jdbc_url,
                            "CopyCommand": self.redshift_copy_command,
                            "RoleARN": self.redshift_role_arn,
                            "S3DestinationDescription": {
                                "BucketARN": self.redshift_s3_bucket_arn,
                                "BufferingHints": self.redshift_s3_buffering_hints,
                                "CompressionFormat": self.redshift_s3_compression_format,
                                "Prefix": self.redshift_s3_prefix,
                                "RoleARN": self.redshift_s3_role_arn
                            },
                            "Username": self.redshift_username,
                        },
                    }
                ],
                "HasMoreDestinations": False,
                "LastUpdateTimestamp": time.mktime(self.last_updated.timetuple()),
                "VersionId": "string",
            }
        }

    def put_record(self, record_data):
        record = FirehoseRecord(record_data)
        self.records.append(record)
        return record
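

# Illustrative sketch (not part of the module; values are placeholders):
# DeliveryStream expects its Redshift destination settings as keyword
# arguments, for example:
#
#   stream = DeliveryStream(
#       'my-firehose',
#       redshift_username='user', redshift_password='password',
#       redshift_jdbc_url='jdbc:redshift://host:5439/db',
#       redshift_role_arn='arn:aws:iam::123456789012:role/firehose',
#       redshift_copy_command='COPY mytable FROM ...',
#       redshift_s3_role_arn='arn:aws:iam::123456789012:role/firehose',
#       redshift_s3_bucket_arn='arn:aws:s3:::my-bucket',
#       redshift_s3_prefix='logs/',
#       redshift_s3_buffering_hints={'SizeInMBs': 128, 'IntervalInSeconds': 300},
#   )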


class KinesisBackend(BaseBackend):

    def __init__(self):
        self.streams = {}
        self.delivery_streams = {}

    def create_stream(self, stream_name, shard_count, region):
        if stream_name in self.streams:
            raise ResourceInUseError(stream_name)
        stream = Stream(stream_name, shard_count, region)
        self.streams[stream_name] = stream
        return stream

    def describe_stream(self, stream_name):
        if stream_name in self.streams:
            return self.streams[stream_name]
        else:
            raise StreamNotFoundError(stream_name)

    def list_streams(self):
        return self.streams.values()

    def delete_stream(self, stream_name):
        if stream_name in self.streams:
            return self.streams.pop(stream_name)
        raise StreamNotFoundError(stream_name)
    def get_shard_iterator(self, stream_name, shard_id, shard_iterator_type, starting_sequence_number):
        # Validate params
        stream = self.describe_stream(stream_name)
        shard = stream.get_shard(shard_id)

        shard_iterator = compose_new_shard_iterator(
            stream_name, shard, shard_iterator_type, starting_sequence_number
        )
        return shard_iterator

    def get_records(self, shard_iterator, limit):
        decomposed = decompose_shard_iterator(shard_iterator)
        stream_name, shard_id, last_sequence_id = decomposed

        stream = self.describe_stream(stream_name)
        shard = stream.get_shard(shard_id)

        records, last_sequence_id = shard.get_records(last_sequence_id, limit)

        next_shard_iterator = compose_shard_iterator(stream_name, shard, last_sequence_id)

        return next_shard_iterator, records

    def put_record(self, stream_name, partition_key, explicit_hash_key, sequence_number_for_ordering, data):
        stream = self.describe_stream(stream_name)

        sequence_number, shard_id = stream.put_record(
            partition_key, explicit_hash_key, sequence_number_for_ordering, data
        )

        return sequence_number, shard_id
    def put_records(self, stream_name, records):
        stream = self.describe_stream(stream_name)

        response = {
            "FailedRecordCount": 0,
            "Records": []
        }

        for record in records:
            partition_key = record.get("PartitionKey")
            explicit_hash_key = record.get("ExplicitHashKey")
            data = record.get("Data")

            sequence_number, shard_id = stream.put_record(
                partition_key, explicit_hash_key, None, data
            )
            response['Records'].append({
                "SequenceNumber": sequence_number,
                "ShardId": shard_id
            })

        return response
    def split_shard(self, stream_name, shard_to_split, new_starting_hash_key):
        stream = self.describe_stream(stream_name)

        if shard_to_split not in stream.shards:
            raise ResourceNotFoundError(shard_to_split)

        # Anchor the pattern so strings with trailing garbage are rejected too.
        if not re.match(r'(0|[1-9]\d{0,38})$', new_starting_hash_key):
            raise InvalidArgumentError(new_starting_hash_key)
        new_starting_hash_key = int(new_starting_hash_key)

        shard = stream.shards[shard_to_split]
        last_id = sorted(stream.shards.values(), key=attrgetter('_shard_id'))[-1]._shard_id

        if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
            new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash)
            shard.ending_hash = new_starting_hash_key
            stream.shards[new_shard.shard_id] = new_shard
        else:
            raise InvalidArgumentError(new_starting_hash_key)

        # Re-route the split shard's records across the new hash ranges.
        records = shard.records
        shard.records = OrderedDict()
        for index in records:
            record = records[index]
            stream.put_record(
                record.partition_key, record.explicit_hash_key, None, record.data
            )
    def merge_shards(self, stream_name, shard_to_merge, adjacent_shard_to_merge):
        stream = self.describe_stream(stream_name)

        if shard_to_merge not in stream.shards:
            raise ResourceNotFoundError(shard_to_merge)

        if adjacent_shard_to_merge not in stream.shards:
            raise ResourceNotFoundError(adjacent_shard_to_merge)

        shard1 = stream.shards[shard_to_merge]
        shard2 = stream.shards[adjacent_shard_to_merge]

        if shard1.ending_hash == shard2.starting_hash:
            shard1.ending_hash = shard2.ending_hash
        elif shard2.ending_hash == shard1.starting_hash:
            shard1.starting_hash = shard2.starting_hash
        else:
            raise InvalidArgumentError(adjacent_shard_to_merge)

        del stream.shards[shard2.shard_id]
        for index in shard2.records:
            record = shard2.records[index]
            shard1.put_record(record.partition_key, record.data, record.explicit_hash_key)
    ''' Firehose '''

    def create_delivery_stream(self, stream_name, **stream_kwargs):
        stream = DeliveryStream(stream_name, **stream_kwargs)
        self.delivery_streams[stream_name] = stream
        return stream

    def get_delivery_stream(self, stream_name):
        if stream_name in self.delivery_streams:
            return self.delivery_streams[stream_name]
        else:
            raise StreamNotFoundError(stream_name)

    def list_delivery_streams(self):
        return self.delivery_streams.values()

    def delete_delivery_stream(self, stream_name):
        self.delivery_streams.pop(stream_name)

    def put_firehose_record(self, stream_name, record_data):
        stream = self.get_delivery_stream(stream_name)
        record = stream.put_record(record_data)
        return record
    def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None):
        stream = self.describe_stream(stream_name)

        tags = []
        result = {
            'HasMoreTags': False,
            'Tags': tags
        }
        for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
            if limit and len(tags) >= limit:
                result['HasMoreTags'] = True
                break
            # The start key is exclusive, so skip it along with everything
            # before it.
            if exclusive_start_tag_key and key <= exclusive_start_tag_key:
                continue

            tags.append({
                'Key': key,
                'Value': val
            })

        return result
    def add_tags_to_stream(self, stream_name, tags):
        stream = self.describe_stream(stream_name)
        stream.tags.update(tags)

    def remove_tags_from_stream(self, stream_name, tag_keys):
        stream = self.describe_stream(stream_name)
        for key in tag_keys:
            if key in stream.tags:
                del stream.tags[key]


kinesis_backends = {}
for region in boto.kinesis.regions():
    kinesis_backends[region.name] = KinesisBackend()
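
# Illustrative usage sketch (not part of the module): 'us-east-1' is assumed
# to be among boto's Kinesis regions.
#
#   backend = kinesis_backends['us-east-1']
#   backend.create_stream('my-stream', shard_count=2, region='us-east-1')
#   seq, shard_id = backend.put_record('my-stream', 'key', None, None, 'data')
#   iterator = backend.get_shard_iterator('my-stream', shard_id,
#                                         'TRIM_HORIZON', None)
#   next_iterator, records = backend.get_records(iterator, limit=10)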