S3: Return CORS headers for GET/PUT requests (#6376)
This commit is contained in:
parent
d20da03225
commit
7c702ee33f
@ -353,7 +353,9 @@ class S3Response(BaseResponse):
|
||||
return 404, {}, ""
|
||||
return 200, {"x-amz-bucket-region": bucket.region_name}, ""
|
||||
|
||||
def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None:
|
||||
def _set_cors_headers_options(
|
||||
self, headers: Dict[str, str], bucket: FakeBucket
|
||||
) -> None:
|
||||
"""
|
||||
TODO: smarter way of matching the right CORS rule:
|
||||
See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
|
||||
@ -408,10 +410,57 @@ class S3Response(BaseResponse):
|
||||
# AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
|
||||
return 403, {}, ""
|
||||
|
||||
self._set_cors_headers(headers, bucket)
|
||||
self._set_cors_headers_options(headers, bucket)
|
||||
|
||||
return 200, self.response_headers, ""
|
||||
|
||||
def _get_cors_headers_other(
|
||||
self, headers: Dict[str, str], bucket_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary with the appropriate CORS headers
|
||||
Should be used for non-OPTIONS requests only
|
||||
Applicable if the 'Origin' header matches one of a CORS-rules - returns an empty dictionary otherwise
|
||||
"""
|
||||
response_headers: Dict[str, Any] = dict()
|
||||
try:
|
||||
origin = headers.get("Origin")
|
||||
if not origin:
|
||||
return response_headers
|
||||
bucket = self.backend.get_bucket(bucket_name)
|
||||
|
||||
def _to_string(header: Union[List[str], str]) -> str:
|
||||
# We allow list and strs in header values. Transform lists in comma-separated strings
|
||||
if isinstance(header, list):
|
||||
return ", ".join(header)
|
||||
return header
|
||||
|
||||
for cors_rule in bucket.cors:
|
||||
if cors_rule.allowed_origins is not None:
|
||||
if cors_matches_origin(origin, cors_rule.allowed_origins): # type: ignore
|
||||
response_headers["Access-Control-Allow-Origin"] = origin # type: ignore
|
||||
if cors_rule.allowed_methods is not None:
|
||||
response_headers[
|
||||
"Access-Control-Allow-Methods"
|
||||
] = _to_string(cors_rule.allowed_methods)
|
||||
if cors_rule.allowed_headers is not None:
|
||||
response_headers[
|
||||
"Access-Control-Allow-Headers"
|
||||
] = _to_string(cors_rule.allowed_headers)
|
||||
if cors_rule.exposed_headers is not None:
|
||||
response_headers[
|
||||
"Access-Control-Expose-Headers"
|
||||
] = _to_string(cors_rule.exposed_headers)
|
||||
if cors_rule.max_age_seconds is not None:
|
||||
response_headers["Access-Control-Max-Age"] = _to_string(
|
||||
cors_rule.max_age_seconds
|
||||
)
|
||||
|
||||
return response_headers
|
||||
except S3ClientError:
|
||||
pass
|
||||
return response_headers
|
||||
|
||||
def _bucket_response_get(
|
||||
self, bucket_name: str, querystring: Dict[str, Any]
|
||||
) -> Union[str, TYPE_RESPONSE]:
|
||||
@ -1294,7 +1343,7 @@ class S3Response(BaseResponse):
|
||||
self._set_action("KEY", "GET", query)
|
||||
self._authenticate_and_authorize_s3_action()
|
||||
|
||||
response_headers: Dict[str, Any] = {}
|
||||
response_headers = self._get_cors_headers_other(headers, bucket_name)
|
||||
if query.get("uploadId"):
|
||||
upload_id = query["uploadId"][0]
|
||||
|
||||
@ -1411,7 +1460,7 @@ class S3Response(BaseResponse):
|
||||
self._set_action("KEY", "PUT", query)
|
||||
self._authenticate_and_authorize_s3_action()
|
||||
|
||||
response_headers: Dict[str, Any] = {}
|
||||
response_headers = self._get_cors_headers_other(request.headers, bucket_name)
|
||||
if query.get("uploadId") and query.get("partNumber"):
|
||||
upload_id = query["uploadId"][0]
|
||||
part_number = int(query["partNumber"][0])
|
||||
|
@ -254,10 +254,11 @@ def test_s3_server_post_cors_exposed_header():
|
||||
"""
|
||||
|
||||
test_client = authenticated_client()
|
||||
valid_origin = "https://example.org"
|
||||
preflight_headers = {
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "origin, x-requested-with",
|
||||
"Origin": "https://example.org",
|
||||
"Origin": valid_origin,
|
||||
}
|
||||
# Returns 403 on non existing bucket
|
||||
preflight_response = test_client.options(
|
||||
@ -265,8 +266,9 @@ def test_s3_server_post_cors_exposed_header():
|
||||
)
|
||||
assert preflight_response.status_code == 403
|
||||
|
||||
# Create the bucket
|
||||
# Create the bucket & file
|
||||
test_client.put("/", "http://testcors.localhost:5000/")
|
||||
test_client.put("/test", "http://testcors.localhost:5000/")
|
||||
res = test_client.put(
|
||||
"/?cors", "http://testcors.localhost:5000", data=cors_config_payload
|
||||
)
|
||||
@ -292,6 +294,50 @@ def test_s3_server_post_cors_exposed_header():
|
||||
assert header_name in preflight_response.headers
|
||||
assert preflight_response.headers[header_name] == header_value
|
||||
|
||||
# Test GET key response
|
||||
# A regular GET should not receive any CORS headers
|
||||
resp = test_client.get("/test", "http://testcors.localhost:5000/")
|
||||
assert "Access-Control-Allow-Methods" not in resp.headers
|
||||
assert "Access-Control-Expose-Headers" not in resp.headers
|
||||
|
||||
# A GET with mismatched Origin-header should not receive any CORS headers
|
||||
resp = test_client.get(
|
||||
"/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"}
|
||||
)
|
||||
assert "Access-Control-Allow-Methods" not in resp.headers
|
||||
assert "Access-Control-Expose-Headers" not in resp.headers
|
||||
|
||||
# Only a GET with matching Origin-header should receive CORS headers
|
||||
resp = test_client.get(
|
||||
"/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin}
|
||||
)
|
||||
assert (
|
||||
resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE"
|
||||
)
|
||||
assert resp.headers["Access-Control-Expose-Headers"] == "ETag"
|
||||
|
||||
# Test PUT key response
|
||||
# A regular PUT should not receive any CORS headers
|
||||
resp = test_client.put("/test", "http://testcors.localhost:5000/")
|
||||
assert "Access-Control-Allow-Methods" not in resp.headers
|
||||
assert "Access-Control-Expose-Headers" not in resp.headers
|
||||
|
||||
# A PUT with mismatched Origin-header should not receive any CORS headers
|
||||
resp = test_client.put(
|
||||
"/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"}
|
||||
)
|
||||
assert "Access-Control-Allow-Methods" not in resp.headers
|
||||
assert "Access-Control-Expose-Headers" not in resp.headers
|
||||
|
||||
# Only a PUT with matching Origin-header should receive CORS headers
|
||||
resp = test_client.put(
|
||||
"/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin}
|
||||
)
|
||||
assert (
|
||||
resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE"
|
||||
)
|
||||
assert resp.headers["Access-Control-Expose-Headers"] == "ETag"
|
||||
|
||||
|
||||
def test_s3_server_post_cors_multiple_origins():
|
||||
"""Test that Moto only responds with the Origin that we that hosts the server"""
|
||||
@ -315,7 +361,7 @@ def test_s3_server_post_cors_multiple_origins():
|
||||
|
||||
# Test only our requested origin is returned
|
||||
preflight_response = requests.options(
|
||||
"http://testcors.localhost:6789/test",
|
||||
"http://testcors.localhost:6789/test2",
|
||||
headers={
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Origin": "https://localhost:6789",
|
||||
@ -330,7 +376,7 @@ def test_s3_server_post_cors_multiple_origins():
|
||||
|
||||
# Verify a request with unknown origin fails
|
||||
preflight_response = requests.options(
|
||||
"http://testcors.localhost:6789/test",
|
||||
"http://testcors.localhost:6789/test2",
|
||||
headers={
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Origin": "https://unknown.host",
|
||||
@ -347,7 +393,7 @@ def test_s3_server_post_cors_multiple_origins():
|
||||
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
|
||||
for origin in ["https://sth.google.com", "https://a.google.com"]:
|
||||
preflight_response = requests.options(
|
||||
"http://testcors.localhost:6789/test",
|
||||
"http://testcors.localhost:6789/test2",
|
||||
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
|
||||
)
|
||||
assert preflight_response.status_code == 200
|
||||
@ -355,7 +401,7 @@ def test_s3_server_post_cors_multiple_origins():
|
||||
|
||||
# Non-matching requests throw an error though - it does not act as a full wildcard
|
||||
preflight_response = requests.options(
|
||||
"http://testcors.localhost:6789/test",
|
||||
"http://testcors.localhost:6789/test2",
|
||||
headers={
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Origin": "sth.microsoft.com",
|
||||
@ -372,7 +418,7 @@ def test_s3_server_post_cors_multiple_origins():
|
||||
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
|
||||
for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]:
|
||||
preflight_response = requests.options(
|
||||
"http://testcors.localhost:6789/test",
|
||||
"http://testcors.localhost:6789/test2",
|
||||
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
|
||||
)
|
||||
assert preflight_response.status_code == 200
|
||||
|
Loading…
Reference in New Issue
Block a user