Make keys pickleable

This commit is contained in:
Diego Argueta 2018-12-20 11:15:15 -08:00
parent f15f006f78
commit 191ad6d778
2 changed files with 45 additions and 14 deletions

View File

@ -23,7 +23,7 @@ from .utils import clean_key_name, _VersionedKeyStore
UPLOAD_ID_BYTES = 43 UPLOAD_ID_BYTES = 43
UPLOAD_PART_MIN_SIZE = 5242880 UPLOAD_PART_MIN_SIZE = 5242880
STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA"] STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA"]
DEFAULT_KEY_BUFFER_SIZE = 2 ** 24 DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024
DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
@ -60,7 +60,8 @@ class FakeKey(BaseModel):
self._is_versioned = is_versioned self._is_versioned = is_versioned
self._tagging = FakeTagging() self._tagging = FakeTagging()
self.value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size) self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size
self.value = value self.value = value
@property @property
@ -69,19 +70,19 @@ class FakeKey(BaseModel):
@property @property
def value(self): def value(self):
self.value_buffer.seek(0) self._value_buffer.seek(0)
return self.value_buffer.read() return self._value_buffer.read()
@value.setter @value.setter
def value(self, new_value): def value(self, new_value):
self.value_buffer.seek(0) self._value_buffer.seek(0)
self.value_buffer.truncate() self._value_buffer.truncate()
# Hack for working around moto's own unit tests; this probably won't # Hack for working around moto's own unit tests; this probably won't
# actually get hit in normal use. # actually get hit in normal use.
if isinstance(new_value, six.text_type): if isinstance(new_value, six.text_type):
new_value = new_value.encode(DEFAULT_TEXT_ENCODING) new_value = new_value.encode(DEFAULT_TEXT_ENCODING)
self.value_buffer.write(new_value) self._value_buffer.write(new_value)
def copy(self, new_name=None): def copy(self, new_name=None):
r = copy.deepcopy(self) r = copy.deepcopy(self)
@ -106,8 +107,8 @@ class FakeKey(BaseModel):
self.acl = acl self.acl = acl
def append_to_value(self, value): def append_to_value(self, value):
self.value_buffer.seek(0, os.SEEK_END) self._value_buffer.seek(0, os.SEEK_END)
self.value_buffer.write(value) self._value_buffer.write(value)
self.last_modified = datetime.datetime.utcnow() self.last_modified = datetime.datetime.utcnow()
self._etag = None # must recalculate etag self._etag = None # must recalculate etag
@ -126,10 +127,9 @@ class FakeKey(BaseModel):
def etag(self): def etag(self):
if self._etag is None: if self._etag is None:
value_md5 = hashlib.md5() value_md5 = hashlib.md5()
self._value_buffer.seek(0)
self.value_buffer.seek(0)
while True: while True:
block = self.value_buffer.read(DEFAULT_KEY_BUFFER_SIZE) block = self._value_buffer.read(DEFAULT_KEY_BUFFER_SIZE)
if not block: if not block:
break break
value_md5.update(block) value_md5.update(block)
@ -178,8 +178,8 @@ class FakeKey(BaseModel):
@property @property
def size(self): def size(self):
self.value_buffer.seek(0, os.SEEK_END) self._value_buffer.seek(0, os.SEEK_END)
return self.value_buffer.tell() return self._value_buffer.tell()
@property @property
def storage_class(self): def storage_class(self):
@ -190,6 +190,26 @@ class FakeKey(BaseModel):
if self._expiry is not None: if self._expiry is not None:
return self._expiry.strftime("%a, %d %b %Y %H:%M:%S GMT") return self._expiry.strftime("%a, %d %b %Y %H:%M:%S GMT")
# Keys need to be pickleable due to some implementation details of boto3.
# Since file objects aren't pickleable, we need to override the default
# behavior. The following is adapted from the Python docs:
# https://docs.python.org/3/library/pickle.html#handling-stateful-objects
def __getstate__(self):
state = self.__dict__.copy()
state['value'] = self.value
del state['_value_buffer']
return state
def __setstate__(self, state):
self.__dict__.update({
k: v for k, v in six.iteritems(state)
if k != 'value'
})
self._value_buffer = \
tempfile.SpooledTemporaryFile(max_size=self._max_buffer_size)
self.value = state['value']
class FakeMultipart(BaseModel): class FakeMultipart(BaseModel):

View File

@ -8,6 +8,7 @@ from functools import wraps
from gzip import GzipFile from gzip import GzipFile
from io import BytesIO from io import BytesIO
import zlib import zlib
import pickle
import json import json
import boto import boto
@ -65,6 +66,16 @@ class MyModel(object):
s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value) s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value)
@mock_s3
def test_keys_are_pickleable():
"""Keys must be pickleable due to boto3 implementation details."""
key = s3model.FakeKey('name', b'data!')
pickled = pickle.dumps(key)
loaded = pickle.loads(pickled)
assert loaded.value == key.value
@mock_s3 @mock_s3
def test_my_model_save(): def test_my_model_save():
# Create Bucket so that test can run # Create Bucket so that test can run