Finalize implementation of DynamoDB Streams

This commit is contained in:
Karl Gutwin 2018-11-08 10:54:54 -05:00
parent 519899f74f
commit 0f6086f708
5 changed files with 166 additions and 33 deletions

View File

@ -1,10 +1,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from collections import defaultdict from collections import defaultdict, namedtuple
import copy import copy
import datetime import datetime
import decimal import decimal
import json import json
import re import re
import uuid
import boto3 import boto3
from moto.compat import OrderedDict from moto.compat import OrderedDict
@ -292,6 +293,44 @@ class Item(BaseModel):
'ADD not supported for %s' % ', '.join(update_action['Value'].keys())) 'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
class StreamRecord(BaseModel):
def __init__(self, table, stream_type, event_name, old, new, seq):
old_a = old.to_json()['Attributes'] if old is not None else {}
new_a = new.to_json()['Attributes'] if new is not None else {}
rec = old if old is not None else new
keys = {table.hash_key_attr: rec.hash_key.to_json()}
if table.range_key_attr is not None:
keys[table.range_key_attr] = rec.range_key.to_json()
self.record = {
'eventID': uuid.uuid4().hex,
'eventName': event_name,
'eventSource': 'aws:dynamodb',
'eventVersion': '1.0',
'awsRegion': 'us-east-1',
'dynamodb': {
'StreamViewType': stream_type,
'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(),
'SequenceNumber': seq,
'SizeBytes': 1,
'Keys': keys
}
}
if stream_type in ('NEW_IMAGE', 'NEW_AND_OLD_IMAGES'):
self.record['dynamodb']['NewImage'] = new_a
if stream_type in ('OLD_IMAGE', 'NEW_AND_OLD_IMAGES'):
self.record['dynamodb']['OldImage'] = old_a
# This is a substantial overestimate but it's the easiest to do now
self.record['dynamodb']['SizeBytes'] = len(
json.dumps(self.record['dynamodb']))
def to_json(self):
return self.record
class StreamShard(BaseModel): class StreamShard(BaseModel):
def __init__(self, table): def __init__(self, table):
self.table = table self.table = table
@ -310,15 +349,22 @@ class StreamShard(BaseModel):
def add(self, old, new): def add(self, old, new):
t = self.table.stream_specification['StreamViewType'] t = self.table.stream_specification['StreamViewType']
if t == 'KEYS_ONLY': if old is None:
self.items.append(new.key) event_name = 'INSERT'
elif t == 'NEW_IMAGE': elif new is None:
self.items.append(new) event_name = 'DELETE'
elif t == 'OLD_IMAGE': else:
self.items.append(old) event_name = 'MODIFY'
elif t == 'NEW_AND_OLD_IMAGES': seq = len(self.items) + self.starting_sequence_number
self.items.append((old, new)) self.items.append(
StreamRecord(self.table, t, event_name, old, new, seq))
def get(self, start, quantity):
start -= self.starting_sequence_number
assert start >= 0
end = start + quantity
return [i.to_json() for i in self.items[start:end]]
class Table(BaseModel): class Table(BaseModel):
@ -428,22 +474,22 @@ class Table(BaseModel):
else: else:
range_value = None range_value = None
if expected is None:
expected = {}
lookup_range_value = range_value
else:
expected_range_value = expected.get(
self.range_key_attr, {}).get("Value")
if(expected_range_value is None):
lookup_range_value = range_value
else:
lookup_range_value = DynamoType(expected_range_value)
current = self.get_item(hash_value, lookup_range_value) current = self.get_item(hash_value, lookup_range_value)
item = Item(hash_value, self.hash_key_type, range_value, item = Item(hash_value, self.hash_key_type, range_value,
self.range_key_type, item_attrs) self.range_key_type, item_attrs)
if not overwrite: if not overwrite:
if expected is None:
expected = {}
lookup_range_value = range_value
else:
expected_range_value = expected.get(
self.range_key_attr, {}).get("Value")
if(expected_range_value is None):
lookup_range_value = range_value
else:
lookup_range_value = DynamoType(expected_range_value)
if current is None: if current is None:
current_attr = {} current_attr = {}
elif hasattr(current, 'attrs'): elif hasattr(current, 'attrs'):
@ -508,9 +554,14 @@ class Table(BaseModel):
def delete_item(self, hash_key, range_key): def delete_item(self, hash_key, range_key):
try: try:
if range_key: if range_key:
return self.items[hash_key].pop(range_key) item = self.items[hash_key].pop(range_key)
else: else:
return self.items.pop(hash_key) item = self.items.pop(hash_key)
if self.stream_shard is not None:
self.stream_shard.add(item, None)
return item
except KeyError: except KeyError:
return None return None

View File

@ -11,8 +11,9 @@ from moto.dynamodb2.models import dynamodb_backends
class ShardIterator(BaseModel): class ShardIterator(BaseModel):
def __init__(self, stream_shard, shard_iterator_type, sequence_number=None): def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None):
self.id = base64.b64encode(os.urandom(472)).decode('utf-8') self.id = base64.b64encode(os.urandom(472)).decode('utf-8')
self.streams_backend = streams_backend
self.stream_shard = stream_shard self.stream_shard = stream_shard
self.shard_iterator_type = shard_iterator_type self.shard_iterator_type = shard_iterator_type
if shard_iterator_type == 'TRIM_HORIZON': if shard_iterator_type == 'TRIM_HORIZON':
@ -24,18 +25,43 @@ class ShardIterator(BaseModel):
elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER': elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER':
self.sequence_number = sequence_number + 1 self.sequence_number = sequence_number + 1
@property
def arn(self):
return '{}/stream/{}|1|{}'.format(
self.stream_shard.table.table_arn,
self.stream_shard.table.latest_stream_label,
self.id)
def to_json(self): def to_json(self):
return { return {
'ShardIterator': '{}/stream/{}|1|{}'.format( 'ShardIterator': self.arn
self.stream_shard.table.table_arn, }
self.stream_shard.table.latest_stream_label,
self.id) def get(self, limit=1000):
items = self.stream_shard.get(self.sequence_number, limit)
try:
last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items)
new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard,
'AFTER_SEQUENCE_NUMBER',
last_sequence_number)
except ValueError:
new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard,
'AT_SEQUENCE_NUMBER',
self.sequence_number)
self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator
return {
'NextShardIterator': new_shard_iterator.arn,
'Records': items
} }
class DynamoDBStreamsBackend(BaseBackend): class DynamoDBStreamsBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
self.region = region self.region = region
self.shard_iterators = {}
def reset(self): def reset(self):
region = self.region region = self.region
@ -86,11 +112,17 @@ class DynamoDBStreamsBackend(BaseBackend):
table = self._get_table_from_arn(arn) table = self._get_table_from_arn(arn)
assert table.stream_shard.id == shard_id assert table.stream_shard.id == shard_id
shard_iterator = ShardIterator(table.stream_shard, shard_iterator_type, shard_iterator = ShardIterator(self, table.stream_shard,
shard_iterator_type,
sequence_number) sequence_number)
self.shard_iterators[shard_iterator.arn] = shard_iterator
return json.dumps(shard_iterator.to_json()) return json.dumps(shard_iterator.to_json())
def get_records(self, iterator_arn, limit):
shard_iterator = self.shard_iterators[iterator_arn]
return json.dumps(shard_iterator.get(limit))
available_regions = boto3.session.Session().get_available_regions( available_regions = boto3.session.Session().get_available_regions(

View File

@ -27,3 +27,10 @@ class DynamoDBStreamsHandler(BaseResponse):
shard_iterator_type = self._get_param('ShardIteratorType') shard_iterator_type = self._get_param('ShardIteratorType')
return self.backend.get_shard_iterator(arn, shard_id, return self.backend.get_shard_iterator(arn, shard_id,
shard_iterator_type) shard_iterator_type)
def get_records(self):
arn = self._get_param('ShardIterator')
limit = self._get_param('Limit')
if limit is None:
limit = 1000
return self.backend.get_records(arn, limit)

View File

@ -1,7 +1,7 @@
[nosetests] [nosetests]
verbosity=1 verbosity=1
detailed-errors=1 detailed-errors=1
#with-coverage=1 with-coverage=1
cover-package=moto cover-package=moto
[bdist_wheel] [bdist_wheel]

View File

@ -75,7 +75,7 @@ class TestClass():
) )
assert 'ShardIterator' in resp assert 'ShardIterator' in resp
def test_get_records(self): def test_get_records_empty(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1') conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn) resp = conn.describe_stream(StreamArn=self.stream_arn)
@ -90,7 +90,50 @@ class TestClass():
resp = conn.get_records(ShardIterator=iterator_id) resp = conn.get_records(ShardIterator=iterator_id)
assert 'Records' in resp assert 'Records' in resp
assert len(resp['Records']) == 0
# TODO: Add tests for inserting records into the stream, and def test_get_records_seq(self):
# the various stream types conn = boto3.client('dynamodb', region_name='us-east-1')
conn.put_item(
TableName='test-streams',
Item={
'id': {'S': 'entry1'},
'first_col': {'S': 'foo'}
}
)
conn.put_item(
TableName='test-streams',
Item={
'id': {'S': 'entry1'},
'first_col': {'S': 'bar'},
'second_col': {'S': 'baz'}
}
)
conn.delete_item(
TableName='test-streams',
Key={'id': {'S': 'entry1'}}
)
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn)
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='TRIM_HORIZON'
)
iterator_id = resp['ShardIterator']
resp = conn.get_records(ShardIterator=iterator_id)
assert len(resp['Records']) == 3
assert resp['Records'][0]['eventName'] == 'INSERT'
assert resp['Records'][1]['eventName'] == 'MODIFY'
assert resp['Records'][2]['eventName'] == 'DELETE'
# now try fetching from the next shard iterator, it should be
# empty
resp = conn.get_records(ShardIterator=resp['NextShardIterator'])
assert len(resp['Records']) == 0