Finalize implementation of DynamoDB Streams
This commit is contained in:
parent
519899f74f
commit
0f6086f708
@ -1,10 +1,11 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
from collections import defaultdict
|
from collections import defaultdict, namedtuple
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import decimal
|
import decimal
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from moto.compat import OrderedDict
|
from moto.compat import OrderedDict
|
||||||
@ -292,6 +293,44 @@ class Item(BaseModel):
|
|||||||
'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
|
'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
|
||||||
|
|
||||||
|
|
||||||
|
class StreamRecord(BaseModel):
|
||||||
|
def __init__(self, table, stream_type, event_name, old, new, seq):
|
||||||
|
old_a = old.to_json()['Attributes'] if old is not None else {}
|
||||||
|
new_a = new.to_json()['Attributes'] if new is not None else {}
|
||||||
|
|
||||||
|
rec = old if old is not None else new
|
||||||
|
keys = {table.hash_key_attr: rec.hash_key.to_json()}
|
||||||
|
if table.range_key_attr is not None:
|
||||||
|
keys[table.range_key_attr] = rec.range_key.to_json()
|
||||||
|
|
||||||
|
self.record = {
|
||||||
|
'eventID': uuid.uuid4().hex,
|
||||||
|
'eventName': event_name,
|
||||||
|
'eventSource': 'aws:dynamodb',
|
||||||
|
'eventVersion': '1.0',
|
||||||
|
'awsRegion': 'us-east-1',
|
||||||
|
'dynamodb': {
|
||||||
|
'StreamViewType': stream_type,
|
||||||
|
'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(),
|
||||||
|
'SequenceNumber': seq,
|
||||||
|
'SizeBytes': 1,
|
||||||
|
'Keys': keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream_type in ('NEW_IMAGE', 'NEW_AND_OLD_IMAGES'):
|
||||||
|
self.record['dynamodb']['NewImage'] = new_a
|
||||||
|
if stream_type in ('OLD_IMAGE', 'NEW_AND_OLD_IMAGES'):
|
||||||
|
self.record['dynamodb']['OldImage'] = old_a
|
||||||
|
|
||||||
|
# This is a substantial overestimate but it's the easiest to do now
|
||||||
|
self.record['dynamodb']['SizeBytes'] = len(
|
||||||
|
json.dumps(self.record['dynamodb']))
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return self.record
|
||||||
|
|
||||||
|
|
||||||
class StreamShard(BaseModel):
|
class StreamShard(BaseModel):
|
||||||
def __init__(self, table):
|
def __init__(self, table):
|
||||||
self.table = table
|
self.table = table
|
||||||
@ -310,15 +349,22 @@ class StreamShard(BaseModel):
|
|||||||
|
|
||||||
def add(self, old, new):
|
def add(self, old, new):
|
||||||
t = self.table.stream_specification['StreamViewType']
|
t = self.table.stream_specification['StreamViewType']
|
||||||
if t == 'KEYS_ONLY':
|
if old is None:
|
||||||
self.items.append(new.key)
|
event_name = 'INSERT'
|
||||||
elif t == 'NEW_IMAGE':
|
elif new is None:
|
||||||
self.items.append(new)
|
event_name = 'DELETE'
|
||||||
elif t == 'OLD_IMAGE':
|
else:
|
||||||
self.items.append(old)
|
event_name = 'MODIFY'
|
||||||
elif t == 'NEW_AND_OLD_IMAGES':
|
seq = len(self.items) + self.starting_sequence_number
|
||||||
self.items.append((old, new))
|
self.items.append(
|
||||||
|
StreamRecord(self.table, t, event_name, old, new, seq))
|
||||||
|
|
||||||
|
def get(self, start, quantity):
|
||||||
|
start -= self.starting_sequence_number
|
||||||
|
assert start >= 0
|
||||||
|
end = start + quantity
|
||||||
|
return [i.to_json() for i in self.items[start:end]]
|
||||||
|
|
||||||
|
|
||||||
class Table(BaseModel):
|
class Table(BaseModel):
|
||||||
|
|
||||||
@ -428,22 +474,22 @@ class Table(BaseModel):
|
|||||||
else:
|
else:
|
||||||
range_value = None
|
range_value = None
|
||||||
|
|
||||||
|
if expected is None:
|
||||||
|
expected = {}
|
||||||
|
lookup_range_value = range_value
|
||||||
|
else:
|
||||||
|
expected_range_value = expected.get(
|
||||||
|
self.range_key_attr, {}).get("Value")
|
||||||
|
if(expected_range_value is None):
|
||||||
|
lookup_range_value = range_value
|
||||||
|
else:
|
||||||
|
lookup_range_value = DynamoType(expected_range_value)
|
||||||
current = self.get_item(hash_value, lookup_range_value)
|
current = self.get_item(hash_value, lookup_range_value)
|
||||||
|
|
||||||
item = Item(hash_value, self.hash_key_type, range_value,
|
item = Item(hash_value, self.hash_key_type, range_value,
|
||||||
self.range_key_type, item_attrs)
|
self.range_key_type, item_attrs)
|
||||||
|
|
||||||
if not overwrite:
|
if not overwrite:
|
||||||
if expected is None:
|
|
||||||
expected = {}
|
|
||||||
lookup_range_value = range_value
|
|
||||||
else:
|
|
||||||
expected_range_value = expected.get(
|
|
||||||
self.range_key_attr, {}).get("Value")
|
|
||||||
if(expected_range_value is None):
|
|
||||||
lookup_range_value = range_value
|
|
||||||
else:
|
|
||||||
lookup_range_value = DynamoType(expected_range_value)
|
|
||||||
|
|
||||||
if current is None:
|
if current is None:
|
||||||
current_attr = {}
|
current_attr = {}
|
||||||
elif hasattr(current, 'attrs'):
|
elif hasattr(current, 'attrs'):
|
||||||
@ -508,9 +554,14 @@ class Table(BaseModel):
|
|||||||
def delete_item(self, hash_key, range_key):
|
def delete_item(self, hash_key, range_key):
|
||||||
try:
|
try:
|
||||||
if range_key:
|
if range_key:
|
||||||
return self.items[hash_key].pop(range_key)
|
item = self.items[hash_key].pop(range_key)
|
||||||
else:
|
else:
|
||||||
return self.items.pop(hash_key)
|
item = self.items.pop(hash_key)
|
||||||
|
|
||||||
|
if self.stream_shard is not None:
|
||||||
|
self.stream_shard.add(item, None)
|
||||||
|
|
||||||
|
return item
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -11,8 +11,9 @@ from moto.dynamodb2.models import dynamodb_backends
|
|||||||
|
|
||||||
|
|
||||||
class ShardIterator(BaseModel):
|
class ShardIterator(BaseModel):
|
||||||
def __init__(self, stream_shard, shard_iterator_type, sequence_number=None):
|
def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None):
|
||||||
self.id = base64.b64encode(os.urandom(472)).decode('utf-8')
|
self.id = base64.b64encode(os.urandom(472)).decode('utf-8')
|
||||||
|
self.streams_backend = streams_backend
|
||||||
self.stream_shard = stream_shard
|
self.stream_shard = stream_shard
|
||||||
self.shard_iterator_type = shard_iterator_type
|
self.shard_iterator_type = shard_iterator_type
|
||||||
if shard_iterator_type == 'TRIM_HORIZON':
|
if shard_iterator_type == 'TRIM_HORIZON':
|
||||||
@ -24,18 +25,43 @@ class ShardIterator(BaseModel):
|
|||||||
elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER':
|
elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER':
|
||||||
self.sequence_number = sequence_number + 1
|
self.sequence_number = sequence_number + 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arn(self):
|
||||||
|
return '{}/stream/{}|1|{}'.format(
|
||||||
|
self.stream_shard.table.table_arn,
|
||||||
|
self.stream_shard.table.latest_stream_label,
|
||||||
|
self.id)
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return {
|
return {
|
||||||
'ShardIterator': '{}/stream/{}|1|{}'.format(
|
'ShardIterator': self.arn
|
||||||
self.stream_shard.table.table_arn,
|
}
|
||||||
self.stream_shard.table.latest_stream_label,
|
|
||||||
self.id)
|
def get(self, limit=1000):
|
||||||
|
items = self.stream_shard.get(self.sequence_number, limit)
|
||||||
|
try:
|
||||||
|
last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items)
|
||||||
|
new_shard_iterator = ShardIterator(self.streams_backend,
|
||||||
|
self.stream_shard,
|
||||||
|
'AFTER_SEQUENCE_NUMBER',
|
||||||
|
last_sequence_number)
|
||||||
|
except ValueError:
|
||||||
|
new_shard_iterator = ShardIterator(self.streams_backend,
|
||||||
|
self.stream_shard,
|
||||||
|
'AT_SEQUENCE_NUMBER',
|
||||||
|
self.sequence_number)
|
||||||
|
|
||||||
|
self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator
|
||||||
|
return {
|
||||||
|
'NextShardIterator': new_shard_iterator.arn,
|
||||||
|
'Records': items
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DynamoDBStreamsBackend(BaseBackend):
|
class DynamoDBStreamsBackend(BaseBackend):
|
||||||
def __init__(self, region):
|
def __init__(self, region):
|
||||||
self.region = region
|
self.region = region
|
||||||
|
self.shard_iterators = {}
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
region = self.region
|
region = self.region
|
||||||
@ -86,11 +112,17 @@ class DynamoDBStreamsBackend(BaseBackend):
|
|||||||
table = self._get_table_from_arn(arn)
|
table = self._get_table_from_arn(arn)
|
||||||
assert table.stream_shard.id == shard_id
|
assert table.stream_shard.id == shard_id
|
||||||
|
|
||||||
shard_iterator = ShardIterator(table.stream_shard, shard_iterator_type,
|
shard_iterator = ShardIterator(self, table.stream_shard,
|
||||||
|
shard_iterator_type,
|
||||||
sequence_number)
|
sequence_number)
|
||||||
|
self.shard_iterators[shard_iterator.arn] = shard_iterator
|
||||||
|
|
||||||
return json.dumps(shard_iterator.to_json())
|
return json.dumps(shard_iterator.to_json())
|
||||||
|
|
||||||
|
def get_records(self, iterator_arn, limit):
|
||||||
|
shard_iterator = self.shard_iterators[iterator_arn]
|
||||||
|
return json.dumps(shard_iterator.get(limit))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
available_regions = boto3.session.Session().get_available_regions(
|
available_regions = boto3.session.Session().get_available_regions(
|
||||||
|
@ -27,3 +27,10 @@ class DynamoDBStreamsHandler(BaseResponse):
|
|||||||
shard_iterator_type = self._get_param('ShardIteratorType')
|
shard_iterator_type = self._get_param('ShardIteratorType')
|
||||||
return self.backend.get_shard_iterator(arn, shard_id,
|
return self.backend.get_shard_iterator(arn, shard_id,
|
||||||
shard_iterator_type)
|
shard_iterator_type)
|
||||||
|
|
||||||
|
def get_records(self):
|
||||||
|
arn = self._get_param('ShardIterator')
|
||||||
|
limit = self._get_param('Limit')
|
||||||
|
if limit is None:
|
||||||
|
limit = 1000
|
||||||
|
return self.backend.get_records(arn, limit)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
[nosetests]
|
[nosetests]
|
||||||
verbosity=1
|
verbosity=1
|
||||||
detailed-errors=1
|
detailed-errors=1
|
||||||
#with-coverage=1
|
with-coverage=1
|
||||||
cover-package=moto
|
cover-package=moto
|
||||||
|
|
||||||
[bdist_wheel]
|
[bdist_wheel]
|
||||||
|
@ -75,7 +75,7 @@ class TestClass():
|
|||||||
)
|
)
|
||||||
assert 'ShardIterator' in resp
|
assert 'ShardIterator' in resp
|
||||||
|
|
||||||
def test_get_records(self):
|
def test_get_records_empty(self):
|
||||||
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
|
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
|
||||||
|
|
||||||
resp = conn.describe_stream(StreamArn=self.stream_arn)
|
resp = conn.describe_stream(StreamArn=self.stream_arn)
|
||||||
@ -90,7 +90,50 @@ class TestClass():
|
|||||||
|
|
||||||
resp = conn.get_records(ShardIterator=iterator_id)
|
resp = conn.get_records(ShardIterator=iterator_id)
|
||||||
assert 'Records' in resp
|
assert 'Records' in resp
|
||||||
|
assert len(resp['Records']) == 0
|
||||||
|
|
||||||
# TODO: Add tests for inserting records into the stream, and
|
def test_get_records_seq(self):
|
||||||
# the various stream types
|
conn = boto3.client('dynamodb', region_name='us-east-1')
|
||||||
|
|
||||||
|
conn.put_item(
|
||||||
|
TableName='test-streams',
|
||||||
|
Item={
|
||||||
|
'id': {'S': 'entry1'},
|
||||||
|
'first_col': {'S': 'foo'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
conn.put_item(
|
||||||
|
TableName='test-streams',
|
||||||
|
Item={
|
||||||
|
'id': {'S': 'entry1'},
|
||||||
|
'first_col': {'S': 'bar'},
|
||||||
|
'second_col': {'S': 'baz'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
conn.delete_item(
|
||||||
|
TableName='test-streams',
|
||||||
|
Key={'id': {'S': 'entry1'}}
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
|
||||||
|
|
||||||
|
resp = conn.describe_stream(StreamArn=self.stream_arn)
|
||||||
|
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
|
||||||
|
|
||||||
|
resp = conn.get_shard_iterator(
|
||||||
|
StreamArn=self.stream_arn,
|
||||||
|
ShardId=shard_id,
|
||||||
|
ShardIteratorType='TRIM_HORIZON'
|
||||||
|
)
|
||||||
|
iterator_id = resp['ShardIterator']
|
||||||
|
|
||||||
|
resp = conn.get_records(ShardIterator=iterator_id)
|
||||||
|
assert len(resp['Records']) == 3
|
||||||
|
assert resp['Records'][0]['eventName'] == 'INSERT'
|
||||||
|
assert resp['Records'][1]['eventName'] == 'MODIFY'
|
||||||
|
assert resp['Records'][2]['eventName'] == 'DELETE'
|
||||||
|
|
||||||
|
# now try fetching from the next shard iterator, it should be
|
||||||
|
# empty
|
||||||
|
resp = conn.get_records(ShardIterator=resp['NextShardIterator'])
|
||||||
|
assert len(resp['Records']) == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user