S3: Improve Cors AllowedOrigin behaviour (#6007)

This commit is contained in:
Bert Blommers 2023-03-03 21:40:55 -01:00 committed by GitHub
parent 96b8e12d45
commit 8b058d9177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 10 deletions

View File

@ -52,6 +52,13 @@ class InvalidArgumentError(S3ClientError):
super().__init__("InvalidArgument", message, *args, **kwargs) super().__init__("InvalidArgument", message, *args, **kwargs)
class AccessForbidden(S3ClientError):
code = 403
def __init__(self, msg):
super().__init__("AccessForbidden", msg)
class BucketError(S3ClientError): class BucketError(S3ClientError):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "bucket_error") kwargs.setdefault("template", "bucket_error")

View File

@ -50,6 +50,7 @@ from .exceptions import (
PreconditionFailed, PreconditionFailed,
InvalidRange, InvalidRange,
LockNotEnabled, LockNotEnabled,
AccessForbidden,
) )
from .models import s3_backends from .models import s3_backends
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
@ -59,6 +60,7 @@ from .utils import (
parse_region_from_url, parse_region_from_url,
compute_checksum, compute_checksum,
ARCHIVE_STORAGE_CLASSES, ARCHIVE_STORAGE_CLASSES,
cors_matches_origin,
) )
from xml.dom import minidom from xml.dom import minidom
@ -298,7 +300,7 @@ class S3Response(BaseResponse):
elif method == "POST": elif method == "POST":
return self._bucket_response_post(request, bucket_name) return self._bucket_response_post(request, bucket_name)
elif method == "OPTIONS": elif method == "OPTIONS":
return self._response_options(bucket_name) return self._response_options(request.headers, bucket_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Method {method} has not been implemented in the S3 backend yet" f"Method {method} has not been implemented in the S3 backend yet"
@ -343,7 +345,7 @@ class S3Response(BaseResponse):
return 404, {}, "" return 404, {}, ""
return 200, {}, "" return 200, {}, ""
def _set_cors_headers(self, bucket): def _set_cors_headers(self, headers, bucket):
""" """
TODO: smarter way of matching the right CORS rule: TODO: smarter way of matching the right CORS rule:
See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
@ -367,8 +369,12 @@ class S3Response(BaseResponse):
cors_rule.allowed_methods cors_rule.allowed_methods
) )
if cors_rule.allowed_origins is not None: if cors_rule.allowed_origins is not None:
self.response_headers["Access-Control-Allow-Origin"] = _to_string( origin = headers.get("Origin")
cors_rule.allowed_origins if cors_matches_origin(origin, cors_rule.allowed_origins):
self.response_headers["Access-Control-Allow-Origin"] = origin
else:
raise AccessForbidden(
"CORSResponse: This CORS request is not allowed. This is usually because the evalution of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec."
) )
if cors_rule.allowed_headers is not None: if cors_rule.allowed_headers is not None:
self.response_headers["Access-Control-Allow-Headers"] = _to_string( self.response_headers["Access-Control-Allow-Headers"] = _to_string(
@ -383,7 +389,7 @@ class S3Response(BaseResponse):
cors_rule.max_age_seconds cors_rule.max_age_seconds
) )
def _response_options(self, bucket_name): def _response_options(self, headers, bucket_name):
# Return 200 with the headers from the bucket CORS configuration # Return 200 with the headers from the bucket CORS configuration
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
try: try:
@ -395,7 +401,7 @@ class S3Response(BaseResponse):
"", "",
) # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD ) # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
self._set_cors_headers(bucket) self._set_cors_headers(headers, bucket)
return 200, self.response_headers, "" return 200, self.response_headers, ""
@ -1241,7 +1247,7 @@ class S3Response(BaseResponse):
return self._key_response_post(request, body, bucket_name, query, key_name) return self._key_response_post(request, body, bucket_name, query, key_name)
elif method == "OPTIONS": elif method == "OPTIONS":
# OPTIONS response doesn't depend on the key_name: always return 200 with CORS headers # OPTIONS response doesn't depend on the key_name: always return 200 with CORS headers
return self._response_options(bucket_name) return self._response_options(request.headers, bucket_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Method {method} has not been implemented in the S3 backend yet" f"Method {method} has not been implemented in the S3 backend yet"

View File

@ -5,7 +5,7 @@ import re
import hashlib import hashlib
from urllib.parse import urlparse, unquote, quote from urllib.parse import urlparse, unquote, quote
from requests.structures import CaseInsensitiveDict from requests.structures import CaseInsensitiveDict
from typing import Union, Tuple from typing import List, Union, Tuple
import sys import sys
from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME
@ -212,3 +212,14 @@ def _hash(fn, args) -> bytes:
except TypeError: except TypeError:
# The usedforsecurity-parameter is only available as of Python 3.9 # The usedforsecurity-parameter is only available as of Python 3.9
return fn(*args).hexdigest().encode("utf-8") return fn(*args).hexdigest().encode("utf-8")
def cors_matches_origin(origin_header: str, allowed_origins: List[str]) -> bool:
if "*" in allowed_origins:
return True
if origin_header in allowed_origins:
return True
for allowed in allowed_origins:
if re.match(allowed.replace(".", "\\.").replace("*", ".*"), origin_header):
return True
return False

View File

@ -7,6 +7,7 @@ from moto.s3.utils import (
clean_key_name, clean_key_name,
undo_clean_key_name, undo_clean_key_name,
compute_checksum, compute_checksum,
cors_matches_origin,
) )
from unittest.mock import patch from unittest.mock import patch
@ -141,3 +142,19 @@ def test_checksum_crc32():
def test_checksum_crc32c(): def test_checksum_crc32c():
compute_checksum(b"somedata", "CRC32C").should.equal(b"MTM5MzM0Mzk1Mg==") compute_checksum(b"somedata", "CRC32C").should.equal(b"MTM5MzM0Mzk1Mg==")
def test_cors_utils():
"Fancy string matching"
assert cors_matches_origin("a", ["a"])
assert cors_matches_origin("b", ["a", "b"])
assert not cors_matches_origin("c", [])
assert not cors_matches_origin("c", ["a", "b"])
assert cors_matches_origin("http://www.google.com", ["http://*.google.com"])
assert cors_matches_origin("http://www.google.com", ["http://www.*.com"])
assert cors_matches_origin("http://www.google.com", ["http://*"])
assert cors_matches_origin("http://www.google.com", ["*"])
assert not cors_matches_origin("http://www.google.com", ["http://www.*.org"])
assert not cors_matches_origin("http://www.google.com", ["https://*"])

View File

@ -1,11 +1,13 @@
import io import io
from urllib.parse import urlparse, parse_qs from urllib.parse import urlparse, parse_qs
import sure # noqa # pylint: disable=unused-import import sure # noqa # pylint: disable=unused-import
import requests
import pytest import pytest
import xmltodict import xmltodict
from flask.testing import FlaskClient from flask.testing import FlaskClient
import moto.server as server import moto.server as server
from moto.moto_server.threaded_moto_server import ThreadedMotoServer
from unittest.mock import patch from unittest.mock import patch
""" """
@ -223,7 +225,7 @@ def test_s3_server_post_cors_exposed_header():
preflight_headers = { preflight_headers = {
"Access-Control-Request-Method": "POST", "Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "origin, x-requested-with", "Access-Control-Request-Headers": "origin, x-requested-with",
"Origin": "https://localhost:9000", "Origin": "https://example.org",
} }
# Returns 403 on non existing bucket # Returns 403 on non existing bucket
preflight_response = test_client.options( preflight_response = test_client.options(
@ -257,3 +259,91 @@ def test_s3_server_post_cors_exposed_header():
for header_name, header_value in expected_cors_headers.items(): for header_name, header_value in expected_cors_headers.items():
assert header_name in preflight_response.headers assert header_name in preflight_response.headers
assert preflight_response.headers[header_name] == header_value assert preflight_response.headers[header_name] == header_value
def test_s3_server_post_cors_multiple_origins():
"""Test that Moto only responds with the Origin that we that hosts the server"""
# github.com/getmoto/moto/issues/6003
cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<CORSRule>
<AllowedOrigin>https://example.org</AllowedOrigin>
<AllowedOrigin>https://localhost:6789</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule>
</CORSConfiguration>
"""
thread = ThreadedMotoServer(port="6789", verbose=False)
thread.start()
# Create the bucket
requests.put("http://testcors.localhost:6789/")
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
# Test only our requested origin is returned
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "https://localhost:6789",
},
)
assert preflight_response.status_code == 200
assert (
preflight_response.headers["Access-Control-Allow-Origin"]
== "https://localhost:6789"
)
assert preflight_response.content == b""
# Verify a request with unknown origin fails
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "https://unknown.host",
},
)
assert preflight_response.status_code == 403
assert b"<Code>AccessForbidden</Code>" in preflight_response.content
# Verify we can use a wildcard anywhere in the origin
cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule>
<AllowedOrigin>https://*.google.com</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule></CORSConfiguration>"""
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://sth.google.com", "https://a.google.com"]:
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
)
assert preflight_response.status_code == 200
assert preflight_response.headers["Access-Control-Allow-Origin"] == origin
# Non-matching requests throw an error though - it does not act as a full wildcard
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "sth.microsoft.com",
},
)
assert preflight_response.status_code == 403
assert b"<Code>AccessForbidden</Code>" in preflight_response.content
# Verify we can use a wildcard as the origin
cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule>
<AllowedOrigin>*</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule></CORSConfiguration>"""
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]:
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
)
assert preflight_response.status_code == 200
assert preflight_response.headers["Access-Control-Allow-Origin"] == origin
thread.stop()