diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 1a43884e8..a5fa6b849 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os import re import sys @@ -11,7 +12,14 @@ from moto.core.utils import ( py2_strip_unicode_keys, unix_time_millis, ) -from six.moves.urllib.parse import parse_qs, urlparse, unquote, parse_qsl +from six.moves.urllib.parse import ( + parse_qs, + parse_qsl, + urlparse, + unquote, + urlencode, + urlunparse, +) import xmltodict @@ -859,10 +867,28 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if "file" in form: f = form["file"] else: - f = request.files["file"].stream.read() + fobj = request.files["file"] + f = fobj.stream.read() + key = key.replace("${filename}", os.path.basename(fobj.filename)) if "success_action_redirect" in form: - response_headers["Location"] = form["success_action_redirect"] + redirect = form["success_action_redirect"] + parts = urlparse(redirect) + queryargs = parse_qs(parts.query) + queryargs["key"] = key + queryargs["bucket"] = bucket_name + redirect_queryargs = urlencode(queryargs, doseq=True) + newparts = ( + parts.scheme, + parts.netloc, + parts.path, + parts.params, + redirect_queryargs, + parts.fragment, + ) + fixed_redirect = urlunparse(newparts) + + response_headers["Location"] = fixed_redirect if "success_action_status" in form: status_code = form["success_action_status"] diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 630baf1c9..b83091315 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -7,6 +7,7 @@ import os from boto3 import Session from six.moves.urllib.request import urlopen from six.moves.urllib.error import HTTPError +from six.moves.urllib.parse import urlparse, parse_qs from functools import wraps from gzip import GzipFile from io import BytesIO @@ -4814,9 +4815,11 @@ def test_creating_presigned_post(): {"success_action_redirect": success_url}, ] conditions.append(["content-length-range", 1, 30]) + + real_key = "{file_uid}.txt".format(file_uid=file_uid) data = s3.generate_presigned_post( Bucket=bucket, - Key="{file_uid}.txt".format(file_uid=file_uid), + Key=real_key, Fields={ "content-type": "text/plain", "success_action_redirect": success_url, @@ -4828,14 +4831,15 @@ def test_creating_presigned_post(): resp = requests.post( data["url"], data=data["fields"], files={"file": fdata}, allow_redirects=False ) - assert resp.headers["Location"] == success_url assert resp.status_code == 303 - assert ( - s3.get_object(Bucket=bucket, Key="{file_uid}.txt".format(file_uid=file_uid))[ - "Body" - ].read() - == fdata - ) + redirect = resp.headers["Location"] + assert redirect.startswith(success_url) + parts = urlparse(redirect) + args = parse_qs(parts.query) + assert args["key"][0] == real_key + assert args["bucket"][0] == bucket + + assert s3.get_object(Bucket=bucket, Key=real_key)["Body"].read() == fdata @mock_s3 diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index 9ef1acb11..a81600db5 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -1,6 +1,8 @@ # coding=utf-8 from __future__ import unicode_literals +import io +from six.moves.urllib.parse import urlparse, parse_qs import sure # noqa from flask.testing import FlaskClient @@ -80,6 +82,39 @@ def test_s3_server_post_to_bucket(): res.data.should.equal(b"nothing") +def test_s3_server_post_to_bucket_redirect(): + test_client = authenticated_client() + + res = test_client.put("/", "http://tester.localhost:5000/") + res.status_code.should.equal(200) + + redirect_base = "https://redirect.com/success/" + filecontent = "nothing" + filename = "test_filename.txt" + res = test_client.post( + "/", + "https://tester.localhost:5000/", + data={ + "key": "asdf/the-key/${filename}", + "file": (io.BytesIO(filecontent.encode("utf8")), filename), + "success_action_redirect": redirect_base, + }, + ) + real_key = "asdf/the-key/{}".format(filename) + res.status_code.should.equal(303) + redirect = res.headers["location"] + assert redirect.startswith(redirect_base) + + parts = urlparse(redirect) + args = parse_qs(parts.query) + assert args["key"][0] == real_key + assert args["bucket"][0] == "tester" + + res = test_client.get("/{}".format(real_key), "http://tester.localhost:5000/") + res.status_code.should.equal(200) + res.data.should.equal(filecontent.encode("utf8")) + + def test_s3_server_post_without_content_length(): test_client = authenticated_client()