diff --git a/moto/s3/models.py b/moto/s3/models.py index d80eec417..2462d59be 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -1,5 +1,7 @@ # from boto.s3.bucket import Bucket # from boto.s3.key import Key +import os +import base64 import md5 from moto.core import BaseBackend @@ -21,10 +23,40 @@ class FakeKey(object): return len(self.value) +class FakeMultipart(object): + def __init__(self, key_name): + self.key_name = key_name + self.parts = {} + self.id = base64.b64encode(os.urandom(43)).replace('=', '') + + def complete(self): + total = bytearray() + + for part_id, index in enumerate(sorted(self.parts.keys()), start=1): + # Make sure part ids are continuous + if part_id != index: + return + + total.extend(self.parts[part_id]) + + if len(total) < 5242880: + return + + return total + + def set_part(self, part_id, value): + if part_id < 1: + return False + + self.parts[part_id] = value + return True + + class FakeBucket(object): def __init__(self, name): self.name = name self.keys = {} + self.multiparts = {} class S3Backend(BaseBackend): @@ -65,6 +97,27 @@ class S3Backend(BaseBackend): if bucket: return bucket.keys.get(key_name) + def initiate_multipart(self, bucket_name, key_name): + bucket = self.buckets[bucket_name] + new_multipart = FakeMultipart(key_name) + bucket.multiparts[new_multipart.id] = new_multipart + + return new_multipart + + def complete_multipart(self, bucket_name, multipart_id): + bucket = self.buckets[bucket_name] + multipart = bucket.multiparts[multipart_id] + value = multipart.complete() + if value is None: + return False + + self.set_key(bucket_name, multipart.key_name, value) + + def set_part(self, bucket_name, multipart_id, part_id, value): + bucket = self.buckets[bucket_name] + multipart = bucket.multiparts[multipart_id] + return multipart.set_part(part_id, value) + def prefix_query(self, bucket, prefix): key_results = set() folder_results = set() diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 80a0a9421..370c7cf5b 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -106,6 +106,20 @@ def key_response(uri_info, method, body, headers): removed_key = s3_backend.delete_key(bucket_name, key_name) template = Template(S3_DELETE_OBJECT_SUCCESS) return template.render(bucket=removed_key), dict(status=204) + elif method == 'POST': + if body == '' and uri_info.query == 'uploads': + multipart = s3_backend.initiate_multipart(bucket_name, key_name) + template = Template(S3_MULTIPART_RESPONSE) + response = template.render( + bucket_name=bucket_name, + key_name=key_name, + multipart_id=multipart.id, + ) + print response + return response, dict() + else: + import pdb; pdb.set_trace() + raise NotImplementedError("POST is only allowed for multipart uploads") else: raise NotImplementedError("Method {} has not been impelemented in the S3 backend yet".format(method)) @@ -202,3 +216,16 @@ S3_OBJECT_COPY_RESPONSE = """ + + {{ bucket_name }} + {{ key_name }} + {{ upload_id }} +""" + +S3_MULTIPART_COMPLETE_RESPONSE = """ +""" + +S3_MULTIPART_ERROR_RESPONSE = """ +""" diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 31e011bfc..1f713fabb 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1,4 +1,5 @@ import urllib2 +from io import BytesIO import boto from boto.exception import S3ResponseError @@ -36,6 +37,26 @@ def test_my_model_save(): conn.get_bucket('mybucket').get_key('steve').get_contents_as_string().should.equal('is awesome') +@mock_s3 +def test_multipart_upload(): + conn = boto.connect_s3('the_key', 'the_secret') + bucket = conn.create_bucket("foobar") + + multipart = bucket.initiate_multipart_upload("the-key") + multipart.upload_part_from_file(BytesIO('hello'), 1) + multipart.upload_part_from_file(BytesIO('world'), 1) + # Multipart with total size under 5MB is refused + multipart.complete_upload().should.throw(S3ResponseError) + + multipart = bucket.initiate_multipart_upload("the-key") + part1 = '0' * 5242880 + multipart.upload_part_from_file(BytesIO('0' * 5242880), 1) + part2 = '1' + multipart.upload_part_from_file(BytesIO('1'), 1) + multipart.complete_upload() + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) + + @mock_s3 def test_missing_key(): conn = boto.connect_s3('the_key', 'the_secret')