S3 - Add test case to showcase bug when downloading large files
parent 1aa99bb405
commit 3802767817
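The test added below drives the failure through a pair of helper classes. For a more direct illustration of the scenario in the commit title, here is a minimal sketch that is not part of this commit: it assumes moto's mock_s3 decorator and boto3's TransferConfig, and the bucket and key names are made up. Once an object crosses boto3's multipart threshold, download_file issues several ranged GETs from worker threads, which is the code path the prints and locks in this diff are probing.

# Repro sketch, not part of the commit. Assumptions: moto's mock_s3 decorator,
# boto3's TransferConfig, and hypothetical bucket/key names.
import io
import os

import boto3
from boto3.s3.transfer import TransferConfig
from moto import mock_s3


@mock_s3
def download_large_object():
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="large-download-repro")  # hypothetical bucket name

    # 16 MiB payload, comfortably above boto3's 8 MiB default multipart threshold.
    s3.upload_fileobj(io.BytesIO(os.urandom(16 * 1024 * 1024)),
                      "large-download-repro", "big.pdf")

    # Lower the threshold and raise concurrency so several ranged GETs
    # hit the mocked backend at the same time.
    config = TransferConfig(multipart_threshold=5 * 1024 * 1024, max_concurrency=4)
    s3.download_file("large-download-repro", "big.pdf", "/tmp/big.pdf", Config=config)


if __name__ == "__main__":
    download_large_object()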
@@ -12,6 +12,7 @@ import codecs
 import random
 import string
 import tempfile
+import threading
 import sys
 import time
 import uuid
@@ -110,6 +111,7 @@ class FakeKey(BaseModel):
         self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
         self._max_buffer_size = max_buffer_size
         self.value = value
+        self.lock = threading.Lock()

     @property
     def version_id(self):
@@ -117,8 +119,14 @@ class FakeKey(BaseModel):

     @property
     def value(self):
+        self.lock.acquire()
+        print("===>value")
         self._value_buffer.seek(0)
-        return self._value_buffer.read()
+        print("===>seek")
+        r = self._value_buffer.read()
+        print("===>read")
+        self.lock.release()
+        return r

     @value.setter
     def value(self, new_value):
@@ -1319,6 +1327,7 @@ class S3Backend(BaseBackend):
         return key

     def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
+        print("get_key("+str(bucket_name)+","+str(key_name)+","+str(version_id)+","+str(part_number)+")")
         key_name = clean_key_name(key_name)
         bucket = self.get_bucket(bucket_name)
         key = None
@@ -2,6 +2,7 @@ from __future__ import unicode_literals

 import re
 import sys
+import threading

 import six
 from botocore.awsrequest import AWSPreparedRequest
@@ -150,6 +151,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
         self.path = ""
         self.data = {}
         self.headers = {}
+        self.lock = threading.Lock()

     @property
     def should_autoescape(self):
@@ -857,6 +859,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
     def _handle_range_header(self, request, headers, response_content):
         response_headers = {}
         length = len(response_content)
+        print("Length: " + str(length) + " Range: " + str(request.headers.get("range")))
         last = length - 1
         _, rspec = request.headers.get("range").split("=")
         if "," in rspec:
@@ -874,6 +877,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
         else:
             return 400, response_headers, ""
         if begin < 0 or end > last or begin > min(end, last):
+            print(str(begin)+ " < 0 or " + str(end) + " > " + str(last) + " or " + str(begin) + " > min("+str(end)+","+str(last)+")")
             return 416, response_headers, ""
         response_headers["content-range"] = "bytes {0}-{1}/{2}".format(
             begin, end, length
@@ -903,14 +907,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
             response_content = response
         else:
             status_code, response_headers, response_content = response
+        print("response received: " + str(len(response_content)))
+        print(request.headers)

         if status_code == 200 and "range" in request.headers:
-            return self._handle_range_header(
+            self.lock.acquire()
+            r = self._handle_range_header(
                 request, response_headers, response_content
             )
+            self.lock.release()
+            return r
         return status_code, response_headers, response_content

     def _control_response(self, request, full_url, headers):
+        print("_control_response")
         parsed_url = urlparse(full_url)
         query = parse_qs(parsed_url.query, keep_blank_values=True)
         method = request.method
@@ -1058,12 +1068,14 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
         )

     def _key_response_get(self, bucket_name, query, key_name, headers):
+        print("_key_response_get("+str(key_name)+","+str(headers)+")")
         self._set_action("KEY", "GET", query)
         self._authenticate_and_authorize_s3_action()

         response_headers = {}
         if query.get("uploadId"):
             upload_id = query["uploadId"][0]
+            print("UploadID: " + str(upload_id))
             parts = self.backend.list_multipart(bucket_name, upload_id)
             template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
             return (
@@ -1095,6 +1107,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):

         response_headers.update(key.metadata)
         response_headers.update(key.response_dict)
+        print("returning 200, " + str(headers) + ", " + str(len(key.value)) + " ( " + str(key_name) + ")")
         return 200, response_headers, key.value

     def _key_response_put(self, request, body, bucket_name, query, key_name, headers):
@@ -104,7 +104,9 @@ class _VersionedKeyStore(dict):
     def get(self, key, default=None):
         try:
             return self[key]
-        except (KeyError, IndexError):
+        except (KeyError, IndexError) as e:
+            print("Error retrieving " + str(key))
+            print(e)
             pass
         return default

@@ -4393,3 +4393,87 @@ def test_s3_config_dict():
     assert not logging_bucket["supplementaryConfiguration"].get(
         "BucketTaggingConfiguration"
     )
+
+
+@mock_s3
+def test_delete_downloaded_file():
+    # SET UP
+    filename = '...'
+    file = open(filename, 'rb')
+    uploader = PdfFileUploader(file)
+    boto3.client('s3').create_bucket(Bucket=uploader.bucket_name())
+    uploader.upload()
+    print("================\nUPLOADED\n=================")
+    # DOWNLOAD
+    # the following two lines are basically
+    # boto3.client('s3').download_file(bucket_name, file_name, local_path)
+    # where bucket_name, file_name and local_path are retrieved from PdfFileUploader
+    # e.g. boto3.client('s3').download_file("bucket_name", "asdf.pdf", "/tmp/asdf.pdf")
+    downloader = PdfFileDownloader(uploader.full_bucket_file_name())
+    downloader.download()
+
+    downloader.delete_downloaded_file()
+
+    print("Done!")
+
+
+from pathlib import Path
+import re
+import os
+class PdfFileDownloader:
+    def __init__(self, full_bucket_file_name):
+        self.bucket_name, self.file_name = self.extract(full_bucket_file_name)
+        self.s3 = boto3.client('s3')
+
+    def download(self):
+        try:
+            self.s3.download_file(self.bucket_name, self.file_name, self.local_path())
+
+            return self.local_path()
+        except ClientError as exc:
+            print("=======")
+            print(exc)
+            raise exc
+
+    def local_path(self):
+        return '/tmp/' + self.file_name.replace('/', '')
+
+    def delete_downloaded_file(self):
+        if Path(self.local_path()).is_file():
+            print("Removing " + str(self.local_path()))
+            os.remove(self.local_path())
+
+    def file(self):
+        return open(self.local_path(), 'rb')
+
+    def extract(self, full_bucket_file_name):
+        match = re.search(r'([\.a-zA-Z_-]+)\/(.*)', full_bucket_file_name)
+
+        if match and len(match.groups()) == 2:
+            return (match.groups()[0], match.groups()[1])
+        else:
+            raise RuntimeError(f"Cannot determine bucket and file name for {full_bucket_file_name}")
+
+
+import binascii
+class PdfFileUploader:
+    def __init__(self, file):
+        self.file = file
+        date = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+        random_hex = binascii.b2a_hex(os.urandom(16)).decode('ascii')
+        self.bucket_file_name = f"{date}_{random_hex}.pdf"
+
+    def upload(self):
+        self.file.seek(0)
+        boto3.client('s3').upload_fileobj(self.file, self.bucket_name(), self.bucket_file_name)
+
+        return (self.original_file_name(), self.full_bucket_file_name())
+
+    def original_file_name(self):
+        return os.path.basename(self.file.name)
+
+    def bucket_name(self):
+        return 'test_bucket' #os.environ['AWS_BUCKET_NAME']
+
+    def full_bucket_file_name(self):
+        return f"{self.bucket_name()}/{self.bucket_file_name}"
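A note on the threading.Lock() additions around FakeKey.value and the range handler: reading the shared SpooledTemporaryFile is a two-step seek(0) then read(), and boto3's transfer manager calls into the mock from several threads at once, so the locks look like an attempt to rule out concurrent access to that shared buffer. The sketch below is an illustration under that assumption, not code from moto: it shows how interleaved seek/read pairs on one buffer can return short or empty data unless the pair is serialized.

# Illustration only: a shared buffer read the way FakeKey.value reads it.
# Without the lock, one thread's read() can run after another thread has
# already consumed the buffer, so it sees far fewer bytes than expected.
import tempfile
import threading

PAYLOAD = b"x" * (10 * 1024 * 1024)
buf = tempfile.SpooledTemporaryFile(max_size=1024)
buf.write(PAYLOAD)
lock = threading.Lock()


def read_all(results, serialize):
    if serialize:
        with lock:                 # mirrors the acquire()/release() added in the diff
            buf.seek(0)
            results.append(len(buf.read()))
    else:
        buf.seek(0)                # another thread may move the position here
        results.append(len(buf.read()))


def run(serialize):
    results = []
    threads = [threading.Thread(target=read_all, args=(results, serialize)) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run(serialize=False))  # may contain short reads when seek/read interleave
print(run(serialize=True))   # every reader gets len(PAYLOAD)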