From 7c702ee33fb12b3a89208ca224263160a61d66f0 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 7 Jun 2023 22:28:40 +0000 Subject: [PATCH] S3: Return CORS headers for GET/PUT requests (#6376) --- moto/s3/responses.py | 57 +++++++++++++++++++++++++++++++--- tests/test_s3/test_server.py | 60 +++++++++++++++++++++++++++++++----- 2 files changed, 106 insertions(+), 11 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index c263039e7..3d8a5a5ec 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -353,7 +353,9 @@ class S3Response(BaseResponse): return 404, {}, "" return 200, {"x-amz-bucket-region": bucket.region_name}, "" - def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None: + def _set_cors_headers_options( + self, headers: Dict[str, str], bucket: FakeBucket + ) -> None: """ TODO: smarter way of matching the right CORS rule: See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html @@ -408,10 +410,57 @@ class S3Response(BaseResponse): # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD return 403, {}, "" - self._set_cors_headers(headers, bucket) + self._set_cors_headers_options(headers, bucket) return 200, self.response_headers, "" + def _get_cors_headers_other( + self, headers: Dict[str, str], bucket_name: str + ) -> Dict[str, Any]: + """ + Returns a dictionary with the appropriate CORS headers + Should be used for non-OPTIONS requests only + Applicable if the 'Origin' header matches one of a CORS-rules - returns an empty dictionary otherwise + """ + response_headers: Dict[str, Any] = dict() + try: + origin = headers.get("Origin") + if not origin: + return response_headers + bucket = self.backend.get_bucket(bucket_name) + + def _to_string(header: Union[List[str], str]) -> str: + # We allow list and strs in header values. Transform lists in comma-separated strings + if isinstance(header, list): + return ", ".join(header) + return header + + for cors_rule in bucket.cors: + if cors_rule.allowed_origins is not None: + if cors_matches_origin(origin, cors_rule.allowed_origins): # type: ignore + response_headers["Access-Control-Allow-Origin"] = origin # type: ignore + if cors_rule.allowed_methods is not None: + response_headers[ + "Access-Control-Allow-Methods" + ] = _to_string(cors_rule.allowed_methods) + if cors_rule.allowed_headers is not None: + response_headers[ + "Access-Control-Allow-Headers" + ] = _to_string(cors_rule.allowed_headers) + if cors_rule.exposed_headers is not None: + response_headers[ + "Access-Control-Expose-Headers" + ] = _to_string(cors_rule.exposed_headers) + if cors_rule.max_age_seconds is not None: + response_headers["Access-Control-Max-Age"] = _to_string( + cors_rule.max_age_seconds + ) + + return response_headers + except S3ClientError: + pass + return response_headers + def _bucket_response_get( self, bucket_name: str, querystring: Dict[str, Any] ) -> Union[str, TYPE_RESPONSE]: @@ -1294,7 +1343,7 @@ class S3Response(BaseResponse): self._set_action("KEY", "GET", query) self._authenticate_and_authorize_s3_action() - response_headers: Dict[str, Any] = {} + response_headers = self._get_cors_headers_other(headers, bucket_name) if query.get("uploadId"): upload_id = query["uploadId"][0] @@ -1411,7 +1460,7 @@ class S3Response(BaseResponse): self._set_action("KEY", "PUT", query) self._authenticate_and_authorize_s3_action() - response_headers: Dict[str, Any] = {} + response_headers = self._get_cors_headers_other(request.headers, bucket_name) if query.get("uploadId") and query.get("partNumber"): upload_id = query["uploadId"][0] part_number = int(query["partNumber"][0]) diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index 73d02f546..8017b86f2 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -254,10 +254,11 @@ def test_s3_server_post_cors_exposed_header(): """ test_client = authenticated_client() + valid_origin = "https://example.org" preflight_headers = { "Access-Control-Request-Method": "POST", "Access-Control-Request-Headers": "origin, x-requested-with", - "Origin": "https://example.org", + "Origin": valid_origin, } # Returns 403 on non existing bucket preflight_response = test_client.options( @@ -265,8 +266,9 @@ def test_s3_server_post_cors_exposed_header(): ) assert preflight_response.status_code == 403 - # Create the bucket + # Create the bucket & file test_client.put("/", "http://testcors.localhost:5000/") + test_client.put("/test", "http://testcors.localhost:5000/") res = test_client.put( "/?cors", "http://testcors.localhost:5000", data=cors_config_payload ) @@ -292,6 +294,50 @@ def test_s3_server_post_cors_exposed_header(): assert header_name in preflight_response.headers assert preflight_response.headers[header_name] == header_value + # Test GET key response + # A regular GET should not receive any CORS headers + resp = test_client.get("/test", "http://testcors.localhost:5000/") + assert "Access-Control-Allow-Methods" not in resp.headers + assert "Access-Control-Expose-Headers" not in resp.headers + + # A GET with mismatched Origin-header should not receive any CORS headers + resp = test_client.get( + "/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"} + ) + assert "Access-Control-Allow-Methods" not in resp.headers + assert "Access-Control-Expose-Headers" not in resp.headers + + # Only a GET with matching Origin-header should receive CORS headers + resp = test_client.get( + "/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin} + ) + assert ( + resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE" + ) + assert resp.headers["Access-Control-Expose-Headers"] == "ETag" + + # Test PUT key response + # A regular PUT should not receive any CORS headers + resp = test_client.put("/test", "http://testcors.localhost:5000/") + assert "Access-Control-Allow-Methods" not in resp.headers + assert "Access-Control-Expose-Headers" not in resp.headers + + # A PUT with mismatched Origin-header should not receive any CORS headers + resp = test_client.put( + "/test", "http://testcors.localhost:5000/", headers={"Origin": "something.com"} + ) + assert "Access-Control-Allow-Methods" not in resp.headers + assert "Access-Control-Expose-Headers" not in resp.headers + + # Only a PUT with matching Origin-header should receive CORS headers + resp = test_client.put( + "/test", "http://testcors.localhost:5000/", headers={"Origin": valid_origin} + ) + assert ( + resp.headers["Access-Control-Allow-Methods"] == "HEAD, GET, PUT, POST, DELETE" + ) + assert resp.headers["Access-Control-Expose-Headers"] == "ETag" + def test_s3_server_post_cors_multiple_origins(): """Test that Moto only responds with the Origin that we that hosts the server""" @@ -315,7 +361,7 @@ def test_s3_server_post_cors_multiple_origins(): # Test only our requested origin is returned preflight_response = requests.options( - "http://testcors.localhost:6789/test", + "http://testcors.localhost:6789/test2", headers={ "Access-Control-Request-Method": "POST", "Origin": "https://localhost:6789", @@ -330,7 +376,7 @@ def test_s3_server_post_cors_multiple_origins(): # Verify a request with unknown origin fails preflight_response = requests.options( - "http://testcors.localhost:6789/test", + "http://testcors.localhost:6789/test2", headers={ "Access-Control-Request-Method": "POST", "Origin": "https://unknown.host", @@ -347,7 +393,7 @@ def test_s3_server_post_cors_multiple_origins(): requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) for origin in ["https://sth.google.com", "https://a.google.com"]: preflight_response = requests.options( - "http://testcors.localhost:6789/test", + "http://testcors.localhost:6789/test2", headers={"Access-Control-Request-Method": "POST", "Origin": origin}, ) assert preflight_response.status_code == 200 @@ -355,7 +401,7 @@ def test_s3_server_post_cors_multiple_origins(): # Non-matching requests throw an error though - it does not act as a full wildcard preflight_response = requests.options( - "http://testcors.localhost:6789/test", + "http://testcors.localhost:6789/test2", headers={ "Access-Control-Request-Method": "POST", "Origin": "sth.microsoft.com", @@ -372,7 +418,7 @@ def test_s3_server_post_cors_multiple_origins(): requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]: preflight_response = requests.options( - "http://testcors.localhost:6789/test", + "http://testcors.localhost:6789/test2", headers={"Access-Control-Request-Method": "POST", "Origin": origin}, ) assert preflight_response.status_code == 200