S3: Return CORS headers for GET/PUT requests (#6376)

This commit is contained in:
Bert Blommers 2023-06-07 22:28:40 +00:00 committed by GitHub
parent d20da03225
commit 7c702ee33f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 11 deletions

View File

@ -353,7 +353,9 @@ class S3Response(BaseResponse):
return 404, {}, "" return 404, {}, ""
return 200, {"x-amz-bucket-region": bucket.region_name}, "" return 200, {"x-amz-bucket-region": bucket.region_name}, ""
def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None: def _set_cors_headers_options(
self, headers: Dict[str, str], bucket: FakeBucket
) -> None:
""" """
TODO: smarter way of matching the right CORS rule: TODO: smarter way of matching the right CORS rule:
See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
@ -408,10 +410,57 @@ class S3Response(BaseResponse):
# AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
return 403, {}, "" return 403, {}, ""
self._set_cors_headers(headers, bucket) self._set_cors_headers_options(headers, bucket)
return 200, self.response_headers, "" return 200, self.response_headers, ""
def _get_cors_headers_other(
self, headers: Dict[str, str], bucket_name: str
) -> Dict[str, Any]:
"""
Returns a dictionary with the appropriate CORS headers
Should be used for non-OPTIONS requests only
Applicable if the 'Origin' header matches one of a CORS-rules - returns an empty dictionary otherwise
"""
response_headers: Dict[str, Any] = dict()
try:
origin = headers.get("Origin")
if not origin:
return response_headers
bucket = self.backend.get_bucket(bucket_name)
def _to_string(header: Union[List[str], str]) -> str:
# We allow list and strs in header values. Transform lists in comma-separated strings
if isinstance(header, list):
return ", ".join(header)
return header
for cors_rule in bucket.cors:
if cors_rule.allowed_origins is not None:
if cors_matches_origin(origin, cors_rule.allowed_origins): # type: ignore
response_headers["Access-Control-Allow-Origin"] = origin # type: ignore
if cors_rule.allowed_methods is not None:
response_headers[
"Access-Control-Allow-Methods"
] = _to_string(cors_rule.allowed_methods)
if cors_rule.allowed_headers is not None:
response_headers[
"Access-Control-Allow-Headers"
] = _to_string(cors_rule.allowed_headers)
if cors_rule.exposed_headers is not None:
response_headers[
"Access-Control-Expose-Headers"
] = _to_string(cors_rule.exposed_headers)
if cors_rule.max_age_seconds is not None:
response_headers["Access-Control-Max-Age"] = _to_string(
cors_rule.max_age_seconds
)
return response_headers
except S3ClientError:
pass
return response_headers
def _bucket_response_get( def _bucket_response_get(
self, bucket_name: str, querystring: Dict[str, Any] self, bucket_name: str, querystring: Dict[str, Any]
) -> Union[str, TYPE_RESPONSE]: ) -> Union[str, TYPE_RESPONSE]:
@ -1294,7 +1343,7 @@ class S3Response(BaseResponse):
self._set_action("KEY", "GET", query) self._set_action("KEY", "GET", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
response_headers: Dict[str, Any] = {} response_headers = self._get_cors_headers_other(headers, bucket_name)
if query.get("uploadId"): if query.get("uploadId"):
upload_id = query["uploadId"][0] upload_id = query["uploadId"][0]
@ -1411,7 +1460,7 @@ class S3Response(BaseResponse):
self._set_action("KEY", "PUT", query) self._set_action("KEY", "PUT", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
response_headers: Dict[str, Any] = {} response_headers = self._get_cors_headers_other(request.headers, bucket_name)
if query.get("uploadId") and query.get("partNumber"): if query.get("uploadId") and query.get("partNumber"):
upload_id = query["uploadId"][0] upload_id = query["uploadId"][0]
part_number = int(query["partNumber"][0]) part_number = int(query["partNumber"][0])

View File

@ -254,10 +254,11 @@ def test_s3_server_post_cors_exposed_header():
""" """
test_client = authenticated_client() test_client = authenticated_client()
valid_origin = "https://example.org"
preflight_headers = { preflight_headers = {
"Access-Control-Request-Method": "POST", "Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "origin, x-requested-with", "Access-Control-Request-Headers": "origin, x-requested-with",
"Origin": "https://example.org", "Origin": valid_origin,
} }
# Returns 403 on non existing bucket # Returns 403 on non existing bucket
preflight_response = test_client.options( preflight_response = test_client.options(
@ -265,8 +266,9 @@ def test_s3_server_post_cors_exposed_header():
) )
assert preflight_response.status_code == 403 assert preflight_response.status_code == 403
# Create the bucket # Create the bucket & file
test_client.put("/", "http://testcors.localhost:5000/") test_client.put("/", "http://testcors.localhost:5000/")
test_client.put("/test", "http://testcors.localhost:5000/")
res = test_client.put( res = test_client.put(
"/?cors", "http://testcors.localhost:5000", data=cors_config_payload "/?cors", "http://testcors.localhost:5000", data=cors_config_payload
) )
@ -292,6 +294,50 @@ def test_s3_server_post_cors_exposed_header():
assert header_name in preflight_response.headers assert header_name in preflight_response.headers
assert preflight_response.headers[header_name] == header_value assert preflight_response.headers[header_name] == header_value
# Test GET key response
# A regular GET should not receive any CORS headers
resp = test_client.get("/test", "http://testcors.localhost:5000/")
assert "Access-Control-Allow-Methods" not in resp.headers
assert "Access-Control-Expose-Headers" not in resp.headers
# A GET with mismatched Origin-header should not receive any CORS headers
resp = test_client.get(
"/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"}
)
assert "Access-Control-Allow-Methods" not in resp.headers
assert "Access-Control-Expose-Headers" not in resp.headers
# Only a GET with matching Origin-header should receive CORS headers
resp = test_client.get(
"/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin}
)
assert (
resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE"
)
assert resp.headers["Access-Control-Expose-Headers"] == "ETag"
# Test PUT key response
# A regular PUT should not receive any CORS headers
resp = test_client.put("/test", "http://testcors.localhost:5000/")
assert "Access-Control-Allow-Methods" not in resp.headers
assert "Access-Control-Expose-Headers" not in resp.headers
# A PUT with mismatched Origin-header should not receive any CORS headers
resp = test_client.put(
"/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"}
)
assert "Access-Control-Allow-Methods" not in resp.headers
assert "Access-Control-Expose-Headers" not in resp.headers
# Only a PUT with matching Origin-header should receive CORS headers
resp = test_client.put(
"/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin}
)
assert (
resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE"
)
assert resp.headers["Access-Control-Expose-Headers"] == "ETag"
def test_s3_server_post_cors_multiple_origins(): def test_s3_server_post_cors_multiple_origins():
"""Test that Moto only responds with the Origin that we that hosts the server""" """Test that Moto only responds with the Origin that we that hosts the server"""
@ -315,7 +361,7 @@ def test_s3_server_post_cors_multiple_origins():
# Test only our requested origin is returned # Test only our requested origin is returned
preflight_response = requests.options( preflight_response = requests.options(
"http://testcors.localhost:6789/test", "http://testcors.localhost:6789/test2",
headers={ headers={
"Access-Control-Request-Method": "POST", "Access-Control-Request-Method": "POST",
"Origin": "https://localhost:6789", "Origin": "https://localhost:6789",
@ -330,7 +376,7 @@ def test_s3_server_post_cors_multiple_origins():
# Verify a request with unknown origin fails # Verify a request with unknown origin fails
preflight_response = requests.options( preflight_response = requests.options(
"http://testcors.localhost:6789/test", "http://testcors.localhost:6789/test2",
headers={ headers={
"Access-Control-Request-Method": "POST", "Access-Control-Request-Method": "POST",
"Origin": "https://unknown.host", "Origin": "https://unknown.host",
@ -347,7 +393,7 @@ def test_s3_server_post_cors_multiple_origins():
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://sth.google.com", "https://a.google.com"]: for origin in ["https://sth.google.com", "https://a.google.com"]:
preflight_response = requests.options( preflight_response = requests.options(
"http://testcors.localhost:6789/test", "http://testcors.localhost:6789/test2",
headers={"Access-Control-Request-Method": "POST", "Origin": origin}, headers={"Access-Control-Request-Method": "POST", "Origin": origin},
) )
assert preflight_response.status_code == 200 assert preflight_response.status_code == 200
@ -355,7 +401,7 @@ def test_s3_server_post_cors_multiple_origins():
# Non-matching requests throw an error though - it does not act as a full wildcard # Non-matching requests throw an error though - it does not act as a full wildcard
preflight_response = requests.options( preflight_response = requests.options(
"http://testcors.localhost:6789/test", "http://testcors.localhost:6789/test2",
headers={ headers={
"Access-Control-Request-Method": "POST", "Access-Control-Request-Method": "POST",
"Origin": "sth.microsoft.com", "Origin": "sth.microsoft.com",
@ -372,7 +418,7 @@ def test_s3_server_post_cors_multiple_origins():
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]: for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]:
preflight_response = requests.options( preflight_response = requests.options(
"http://testcors.localhost:6789/test", "http://testcors.localhost:6789/test2",
headers={"Access-Control-Request-Method": "POST", "Origin": origin}, headers={"Access-Control-Request-Method": "POST", "Origin": origin},
) )
assert preflight_response.status_code == 200 assert preflight_response.status_code == 200