diff --git a/moto/server.py b/moto/server.py index bf76095a6..a10dc4e3e 100644 --- a/moto/server.py +++ b/moto/server.py @@ -9,6 +9,7 @@ from threading import Lock import six from flask import Flask +from flask_cors import CORS from flask.testing import FlaskClient from six.moves.urllib.parse import urlencode @@ -205,6 +206,7 @@ def create_backend_app(service): backend_app = Flask(__name__) backend_app.debug = True backend_app.service = service + CORS(backend_app) # Reset view functions to reset the app backend_app.view_functions = {} diff --git a/requirements-dev.txt b/requirements-dev.txt index 8a91eb14f..ad1e30508 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,6 +7,7 @@ coverage==4.5.4 flake8==3.7.8 freezegun flask +flask-cors boto>=2.45.0 boto3>=1.4.4 botocore>=1.15.13 diff --git a/setup.py b/setup.py index 5f6840251..a6adbea6d 100755 --- a/setup.py +++ b/setup.py @@ -99,7 +99,7 @@ all_extra_deps = [ _dep_sshpubkeys_py2, _dep_sshpubkeys_py3, ] -all_server_deps = all_extra_deps + ['flask'] +all_server_deps = all_extra_deps + ['flask', 'flask-cors'] # TODO: do we want to add ALL services here? # i.e. even those without extra dependencies. diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index 56d46de09..9ef1acb11 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -108,3 +108,31 @@ def test_s3_server_post_unicode_bucket_key(): } ) assert backend_app + + +def test_s3_server_post_cors(): + test_client = authenticated_client() + + preflight_headers = { + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "origin, x-requested-with", + "Origin": "https://localhost:9000", + } + + res = test_client.options( + "/", "http://tester.localhost:5000/", headers=preflight_headers + ) + assert res.status_code in [200, 204] + + expected_methods = set(["DELETE", "PATCH", "PUT", "GET", "HEAD", "POST", "OPTIONS"]) + assert set(res.headers["Allow"].split(", ")) == expected_methods + assert ( + set(res.headers["Access-Control-Allow-Methods"].split(", ")) == expected_methods + ) + + res.headers.should.have.key("Access-Control-Allow-Origin").which.should.equal( + "https://localhost:9000" + ) + res.headers.should.have.key("Access-Control-Allow-Headers").which.should.equal( + "origin, x-requested-with" + )