S3: Improve Cors AllowedOrigin behaviour (#6007)
This commit is contained in:
		
							parent
							
								
									96b8e12d45
								
							
						
					
					
						commit
						8b058d9177
					
				| @ -52,6 +52,13 @@ class InvalidArgumentError(S3ClientError): | |||||||
|         super().__init__("InvalidArgument", message, *args, **kwargs) |         super().__init__("InvalidArgument", message, *args, **kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class AccessForbidden(S3ClientError): | ||||||
|  |     code = 403 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, msg): | ||||||
|  |         super().__init__("AccessForbidden", msg) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class BucketError(S3ClientError): | class BucketError(S3ClientError): | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         kwargs.setdefault("template", "bucket_error") |         kwargs.setdefault("template", "bucket_error") | ||||||
|  | |||||||
| @ -50,6 +50,7 @@ from .exceptions import ( | |||||||
|     PreconditionFailed, |     PreconditionFailed, | ||||||
|     InvalidRange, |     InvalidRange, | ||||||
|     LockNotEnabled, |     LockNotEnabled, | ||||||
|  |     AccessForbidden, | ||||||
| ) | ) | ||||||
| from .models import s3_backends | from .models import s3_backends | ||||||
| from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey | from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey | ||||||
| @ -59,6 +60,7 @@ from .utils import ( | |||||||
|     parse_region_from_url, |     parse_region_from_url, | ||||||
|     compute_checksum, |     compute_checksum, | ||||||
|     ARCHIVE_STORAGE_CLASSES, |     ARCHIVE_STORAGE_CLASSES, | ||||||
|  |     cors_matches_origin, | ||||||
| ) | ) | ||||||
| from xml.dom import minidom | from xml.dom import minidom | ||||||
| 
 | 
 | ||||||
| @ -298,7 +300,7 @@ class S3Response(BaseResponse): | |||||||
|         elif method == "POST": |         elif method == "POST": | ||||||
|             return self._bucket_response_post(request, bucket_name) |             return self._bucket_response_post(request, bucket_name) | ||||||
|         elif method == "OPTIONS": |         elif method == "OPTIONS": | ||||||
|             return self._response_options(bucket_name) |             return self._response_options(request.headers, bucket_name) | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError( |             raise NotImplementedError( | ||||||
|                 f"Method {method} has not been implemented in the S3 backend yet" |                 f"Method {method} has not been implemented in the S3 backend yet" | ||||||
| @ -343,7 +345,7 @@ class S3Response(BaseResponse): | |||||||
|             return 404, {}, "" |             return 404, {}, "" | ||||||
|         return 200, {}, "" |         return 200, {}, "" | ||||||
| 
 | 
 | ||||||
|     def _set_cors_headers(self, bucket): |     def _set_cors_headers(self, headers, bucket): | ||||||
|         """ |         """ | ||||||
|         TODO: smarter way of matching the right CORS rule: |         TODO: smarter way of matching the right CORS rule: | ||||||
|         See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html |         See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html | ||||||
| @ -367,8 +369,12 @@ class S3Response(BaseResponse): | |||||||
|                     cors_rule.allowed_methods |                     cors_rule.allowed_methods | ||||||
|                 ) |                 ) | ||||||
|             if cors_rule.allowed_origins is not None: |             if cors_rule.allowed_origins is not None: | ||||||
|                 self.response_headers["Access-Control-Allow-Origin"] = _to_string( |                 origin = headers.get("Origin") | ||||||
|                     cors_rule.allowed_origins |                 if cors_matches_origin(origin, cors_rule.allowed_origins): | ||||||
|  |                     self.response_headers["Access-Control-Allow-Origin"] = origin | ||||||
|  |                 else: | ||||||
|  |                     raise AccessForbidden( | ||||||
|  |                         "CORSResponse: This CORS request is not allowed. This is usually because the evalution of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec." | ||||||
|                     ) |                     ) | ||||||
|             if cors_rule.allowed_headers is not None: |             if cors_rule.allowed_headers is not None: | ||||||
|                 self.response_headers["Access-Control-Allow-Headers"] = _to_string( |                 self.response_headers["Access-Control-Allow-Headers"] = _to_string( | ||||||
| @ -383,7 +389,7 @@ class S3Response(BaseResponse): | |||||||
|                     cors_rule.max_age_seconds |                     cors_rule.max_age_seconds | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|     def _response_options(self, bucket_name): |     def _response_options(self, headers, bucket_name): | ||||||
|         # Return 200 with the headers from the bucket CORS configuration |         # Return 200 with the headers from the bucket CORS configuration | ||||||
|         self._authenticate_and_authorize_s3_action() |         self._authenticate_and_authorize_s3_action() | ||||||
|         try: |         try: | ||||||
| @ -395,7 +401,7 @@ class S3Response(BaseResponse): | |||||||
|                 "", |                 "", | ||||||
|             )  # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD |             )  # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD | ||||||
| 
 | 
 | ||||||
|         self._set_cors_headers(bucket) |         self._set_cors_headers(headers, bucket) | ||||||
| 
 | 
 | ||||||
|         return 200, self.response_headers, "" |         return 200, self.response_headers, "" | ||||||
| 
 | 
 | ||||||
| @ -1241,7 +1247,7 @@ class S3Response(BaseResponse): | |||||||
|             return self._key_response_post(request, body, bucket_name, query, key_name) |             return self._key_response_post(request, body, bucket_name, query, key_name) | ||||||
|         elif method == "OPTIONS": |         elif method == "OPTIONS": | ||||||
|             # OPTIONS response doesn't depend on the key_name: always return 200 with CORS headers |             # OPTIONS response doesn't depend on the key_name: always return 200 with CORS headers | ||||||
|             return self._response_options(bucket_name) |             return self._response_options(request.headers, bucket_name) | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError( |             raise NotImplementedError( | ||||||
|                 f"Method {method} has not been implemented in the S3 backend yet" |                 f"Method {method} has not been implemented in the S3 backend yet" | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ import re | |||||||
| import hashlib | import hashlib | ||||||
| from urllib.parse import urlparse, unquote, quote | from urllib.parse import urlparse, unquote, quote | ||||||
| from requests.structures import CaseInsensitiveDict | from requests.structures import CaseInsensitiveDict | ||||||
| from typing import Union, Tuple | from typing import List, Union, Tuple | ||||||
| import sys | import sys | ||||||
| from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME | from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME | ||||||
| 
 | 
 | ||||||
| @ -212,3 +212,14 @@ def _hash(fn, args) -> bytes: | |||||||
|     except TypeError: |     except TypeError: | ||||||
|         # The usedforsecurity-parameter is only available as of Python 3.9 |         # The usedforsecurity-parameter is only available as of Python 3.9 | ||||||
|         return fn(*args).hexdigest().encode("utf-8") |         return fn(*args).hexdigest().encode("utf-8") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def cors_matches_origin(origin_header: str, allowed_origins: List[str]) -> bool: | ||||||
|  |     if "*" in allowed_origins: | ||||||
|  |         return True | ||||||
|  |     if origin_header in allowed_origins: | ||||||
|  |         return True | ||||||
|  |     for allowed in allowed_origins: | ||||||
|  |         if re.match(allowed.replace(".", "\\.").replace("*", ".*"), origin_header): | ||||||
|  |             return True | ||||||
|  |     return False | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ from moto.s3.utils import ( | |||||||
|     clean_key_name, |     clean_key_name, | ||||||
|     undo_clean_key_name, |     undo_clean_key_name, | ||||||
|     compute_checksum, |     compute_checksum, | ||||||
|  |     cors_matches_origin, | ||||||
| ) | ) | ||||||
| from unittest.mock import patch | from unittest.mock import patch | ||||||
| 
 | 
 | ||||||
| @ -141,3 +142,19 @@ def test_checksum_crc32(): | |||||||
| 
 | 
 | ||||||
| def test_checksum_crc32c(): | def test_checksum_crc32c(): | ||||||
|     compute_checksum(b"somedata", "CRC32C").should.equal(b"MTM5MzM0Mzk1Mg==") |     compute_checksum(b"somedata", "CRC32C").should.equal(b"MTM5MzM0Mzk1Mg==") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_cors_utils(): | ||||||
|  |     "Fancy string matching" | ||||||
|  |     assert cors_matches_origin("a", ["a"]) | ||||||
|  |     assert cors_matches_origin("b", ["a", "b"]) | ||||||
|  |     assert not cors_matches_origin("c", []) | ||||||
|  |     assert not cors_matches_origin("c", ["a", "b"]) | ||||||
|  | 
 | ||||||
|  |     assert cors_matches_origin("http://www.google.com", ["http://*.google.com"]) | ||||||
|  |     assert cors_matches_origin("http://www.google.com", ["http://www.*.com"]) | ||||||
|  |     assert cors_matches_origin("http://www.google.com", ["http://*"]) | ||||||
|  |     assert cors_matches_origin("http://www.google.com", ["*"]) | ||||||
|  | 
 | ||||||
|  |     assert not cors_matches_origin("http://www.google.com", ["http://www.*.org"]) | ||||||
|  |     assert not cors_matches_origin("http://www.google.com", ["https://*"]) | ||||||
|  | |||||||
| @ -1,11 +1,13 @@ | |||||||
| import io | import io | ||||||
| from urllib.parse import urlparse, parse_qs | from urllib.parse import urlparse, parse_qs | ||||||
| import sure  # noqa # pylint: disable=unused-import | import sure  # noqa # pylint: disable=unused-import | ||||||
|  | import requests | ||||||
| import pytest | import pytest | ||||||
| import xmltodict | import xmltodict | ||||||
| 
 | 
 | ||||||
| from flask.testing import FlaskClient | from flask.testing import FlaskClient | ||||||
| import moto.server as server | import moto.server as server | ||||||
|  | from moto.moto_server.threaded_moto_server import ThreadedMotoServer | ||||||
| from unittest.mock import patch | from unittest.mock import patch | ||||||
| 
 | 
 | ||||||
| """ | """ | ||||||
| @ -223,7 +225,7 @@ def test_s3_server_post_cors_exposed_header(): | |||||||
|     preflight_headers = { |     preflight_headers = { | ||||||
|         "Access-Control-Request-Method": "POST", |         "Access-Control-Request-Method": "POST", | ||||||
|         "Access-Control-Request-Headers": "origin, x-requested-with", |         "Access-Control-Request-Headers": "origin, x-requested-with", | ||||||
|         "Origin": "https://localhost:9000", |         "Origin": "https://example.org", | ||||||
|     } |     } | ||||||
|     # Returns 403 on non existing bucket |     # Returns 403 on non existing bucket | ||||||
|     preflight_response = test_client.options( |     preflight_response = test_client.options( | ||||||
| @ -257,3 +259,91 @@ def test_s3_server_post_cors_exposed_header(): | |||||||
|         for header_name, header_value in expected_cors_headers.items(): |         for header_name, header_value in expected_cors_headers.items(): | ||||||
|             assert header_name in preflight_response.headers |             assert header_name in preflight_response.headers | ||||||
|             assert preflight_response.headers[header_name] == header_value |             assert preflight_response.headers[header_name] == header_value | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_s3_server_post_cors_multiple_origins(): | ||||||
|  |     """Test that Moto only responds with the Origin that we that hosts the server""" | ||||||
|  |     # github.com/getmoto/moto/issues/6003 | ||||||
|  | 
 | ||||||
|  |     cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"> | ||||||
|  |   <CORSRule> | ||||||
|  |     <AllowedOrigin>https://example.org</AllowedOrigin> | ||||||
|  |     <AllowedOrigin>https://localhost:6789</AllowedOrigin> | ||||||
|  |     <AllowedMethod>POST</AllowedMethod> | ||||||
|  |   </CORSRule> | ||||||
|  | </CORSConfiguration> | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     thread = ThreadedMotoServer(port="6789", verbose=False) | ||||||
|  |     thread.start() | ||||||
|  | 
 | ||||||
|  |     # Create the bucket | ||||||
|  |     requests.put("http://testcors.localhost:6789/") | ||||||
|  |     requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) | ||||||
|  | 
 | ||||||
|  |     # Test only our requested origin is returned | ||||||
|  |     preflight_response = requests.options( | ||||||
|  |         "http://testcors.localhost:6789/test", | ||||||
|  |         headers={ | ||||||
|  |             "Access-Control-Request-Method": "POST", | ||||||
|  |             "Origin": "https://localhost:6789", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     assert preflight_response.status_code == 200 | ||||||
|  |     assert ( | ||||||
|  |         preflight_response.headers["Access-Control-Allow-Origin"] | ||||||
|  |         == "https://localhost:6789" | ||||||
|  |     ) | ||||||
|  |     assert preflight_response.content == b"" | ||||||
|  | 
 | ||||||
|  |     # Verify a request with unknown origin fails | ||||||
|  |     preflight_response = requests.options( | ||||||
|  |         "http://testcors.localhost:6789/test", | ||||||
|  |         headers={ | ||||||
|  |             "Access-Control-Request-Method": "POST", | ||||||
|  |             "Origin": "https://unknown.host", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     assert preflight_response.status_code == 403 | ||||||
|  |     assert b"<Code>AccessForbidden</Code>" in preflight_response.content | ||||||
|  | 
 | ||||||
|  |     # Verify we can use a wildcard anywhere in the origin | ||||||
|  |     cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule> | ||||||
|  |             <AllowedOrigin>https://*.google.com</AllowedOrigin> | ||||||
|  |             <AllowedMethod>POST</AllowedMethod> | ||||||
|  |           </CORSRule></CORSConfiguration>""" | ||||||
|  |     requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) | ||||||
|  |     for origin in ["https://sth.google.com", "https://a.google.com"]: | ||||||
|  |         preflight_response = requests.options( | ||||||
|  |             "http://testcors.localhost:6789/test", | ||||||
|  |             headers={"Access-Control-Request-Method": "POST", "Origin": origin}, | ||||||
|  |         ) | ||||||
|  |         assert preflight_response.status_code == 200 | ||||||
|  |         assert preflight_response.headers["Access-Control-Allow-Origin"] == origin | ||||||
|  | 
 | ||||||
|  |     # Non-matching requests throw an error though - it does not act as a full wildcard | ||||||
|  |     preflight_response = requests.options( | ||||||
|  |         "http://testcors.localhost:6789/test", | ||||||
|  |         headers={ | ||||||
|  |             "Access-Control-Request-Method": "POST", | ||||||
|  |             "Origin": "sth.microsoft.com", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     assert preflight_response.status_code == 403 | ||||||
|  |     assert b"<Code>AccessForbidden</Code>" in preflight_response.content | ||||||
|  | 
 | ||||||
|  |     # Verify we can use a wildcard as the origin | ||||||
|  |     cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule> | ||||||
|  |                 <AllowedOrigin>*</AllowedOrigin> | ||||||
|  |                 <AllowedMethod>POST</AllowedMethod> | ||||||
|  |               </CORSRule></CORSConfiguration>""" | ||||||
|  |     requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload) | ||||||
|  |     for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]: | ||||||
|  |         preflight_response = requests.options( | ||||||
|  |             "http://testcors.localhost:6789/test", | ||||||
|  |             headers={"Access-Control-Request-Method": "POST", "Origin": origin}, | ||||||
|  |         ) | ||||||
|  |         assert preflight_response.status_code == 200 | ||||||
|  |         assert preflight_response.headers["Access-Control-Allow-Origin"] == origin | ||||||
|  | 
 | ||||||
|  |     thread.stop() | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user