Fix requests mock for custom S3 endpoints (#7445)

This commit is contained in:
David H. Irving 2024-03-12 02:46:11 -07:00 committed by GitHub
parent 3d3f1c969e
commit ed3f77fd77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 27 deletions

View File

@ -1,15 +1,14 @@
import re import re
from io import BytesIO from io import BytesIO
from typing import Any, Optional, Union from typing import Any, Optional, Union
from urllib.parse import urlparse
from botocore.awsrequest import AWSResponse from botocore.awsrequest import AWSResponse
import moto.backend_index as backend_index import moto.backend_index as backend_index
from moto import settings
from moto.core.base_backend import BackendDict from moto.core.base_backend import BackendDict
from moto.core.common_types import TYPE_RESPONSE from moto.core.common_types import TYPE_RESPONSE
from moto.core.config import passthrough_service, passthrough_url from moto.core.config import passthrough_service, passthrough_url
from moto.core.utils import get_equivalent_url_in_aws_domain
class MockRawResponse(BytesIO): class MockRawResponse(BytesIO):
@ -43,30 +42,11 @@ class BotocoreStubber:
return response return response
def process_request(self, request: Any) -> Optional[TYPE_RESPONSE]: def process_request(self, request: Any) -> Optional[TYPE_RESPONSE]:
# Handle non-standard AWS endpoint hostnames from ISO regions or custom
# S3 endpoints.
parsed_url, _ = get_equivalent_url_in_aws_domain(request.url)
# Remove the querystring from the URL, as we'll never match on that # Remove the querystring from the URL, as we'll never match on that
x = urlparse(request.url) clean_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
host = x.netloc
# https://github.com/getmoto/moto/pull/6412
# Support ISO regions
iso_region_domains = [
"amazonaws.com.cn",
"c2s.ic.gov",
"sc2s.sgov.gov",
"cloud.adc-e.uk",
"csp.hci.ic.gov",
]
for domain in iso_region_domains:
if host.endswith(domain):
host = host.replace(domain, "amazonaws.com")
# https://github.com/getmoto/moto/issues/2993
# Support S3-compatible tools (Ceph, Digital Ocean, etc)
for custom_endpoint in settings.get_s3_custom_endpoints():
if host == custom_endpoint or host == custom_endpoint.split("://")[-1]:
host = "s3.amazonaws.com"
clean_url = f"{x.scheme}://{host}{x.path}"
if passthrough_url(clean_url): if passthrough_url(clean_url):
return None return None

View File

@ -1,10 +1,12 @@
# This will only exist in responses >= 0.17 # This will only exist in responses >= 0.17
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlunparse
import responses import responses
from .custom_responses_mock import CallbackResponse, not_implemented_callback from .custom_responses_mock import CallbackResponse, not_implemented_callback
from .utils import get_equivalent_url_in_aws_domain
class CustomRegistry(responses.registries.FirstMatchRegistry): class CustomRegistry(responses.registries.FirstMatchRegistry):
@ -41,8 +43,20 @@ class CustomRegistry(responses.registries.FirstMatchRegistry):
) )
found = [] found = []
match_failed_reasons = [] match_failed_reasons = []
# Handle non-standard AWS endpoint hostnames from ISO regions or custom S3 endpoints.
parsed_url, url_was_modified = get_equivalent_url_in_aws_domain(request.url)
if url_was_modified:
url_with_standard_aws_domain = urlunparse(parsed_url)
request_with_standard_aws_domain = request.copy()
request_with_standard_aws_domain.prepare_url(
url_with_standard_aws_domain, {}
)
else:
request_with_standard_aws_domain = request
for response in all_possibles: for response in all_possibles:
match_result, reason = response.matches(request) match_result, reason = response.matches(request_with_standard_aws_domain)
if match_result: if match_result:
found.append(response) found.append(response)
else: else:

View File

@ -3,10 +3,11 @@ import inspect
import re import re
from gzip import decompress from gzip import decompress
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import ParseResult, urlparse
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from ..settings import get_s3_custom_endpoints
from .common_types import TYPE_RESPONSE from .common_types import TYPE_RESPONSE
from .versions import PYTHON_311 from .versions import PYTHON_311
@ -398,3 +399,47 @@ def get_partition_from_region(region_name: str) -> str:
if region_name.startswith("cn-"): if region_name.startswith("cn-"):
return "aws-cn" return "aws-cn"
return "aws" return "aws"
def get_equivalent_url_in_aws_domain(url: str) -> Tuple[ParseResult, bool]:
"""Parses a URL and converts non-standard AWS endpoint hostnames (from ISO
regions or custom S3 endpoints) to the equivalent standard AWS domain.
Returns a tuple: (parsed URL, was URL modified).
"""
parsed = urlparse(url)
original_host = parsed.netloc
host = original_host
# https://github.com/getmoto/moto/pull/6412
# Support ISO regions
iso_region_domains = [
"amazonaws.com.cn",
"c2s.ic.gov",
"sc2s.sgov.gov",
"cloud.adc-e.uk",
"csp.hci.ic.gov",
]
for domain in iso_region_domains:
if host.endswith(domain):
host = host.replace(domain, "amazonaws.com")
# https://github.com/getmoto/moto/issues/2993
# Support S3-compatible tools (Ceph, Digital Ocean, etc)
for custom_endpoint in get_s3_custom_endpoints():
if host == custom_endpoint or host == custom_endpoint.split("://")[-1]:
host = "s3.amazonaws.com"
if host == original_host:
return (parsed, False)
else:
result = ParseResult(
scheme=parsed.scheme,
netloc=host,
path=parsed.path,
params=parsed.params,
query=parsed.query,
fragment=parsed.fragment,
)
return (result, True)

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import boto3 import boto3
import pytest import pytest
import requests
from moto import mock_aws, settings from moto import mock_aws, settings
@ -96,3 +97,27 @@ def test_put_and_list_objects(url):
contents = s3_client.list_objects(Bucket=bucket)["Contents"] contents = s3_client.list_objects(Bucket=bucket)["Contents"]
assert len(contents) == 3 assert len(contents) == 3
assert "two" in [c["Key"] for c in contents] assert "two" in [c["Key"] for c in contents]
@pytest.mark.parametrize("url", [CUSTOM_ENDPOINT, CUSTOM_ENDPOINT_2])
def test_get_presigned_url(url):
if not settings.TEST_DECORATOR_MODE:
raise SkipTest("Unable to set ENV VAR in ServerMode")
with patch.dict(os.environ, {"MOTO_S3_CUSTOM_ENDPOINTS": url}):
with mock_aws():
bucket = "mybucket"
key = "file.txt"
contents = b"file contents"
conn = boto3.resource(
"s3", endpoint_url=url, region_name=DEFAULT_REGION_NAME
)
conn.create_bucket(Bucket=bucket)
s3_client = boto3.client("s3", endpoint_url=url)
s3_client.put_object(Bucket=bucket, Key=key, Body=contents)
signed_url = s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=86400
)
response = requests.get(signed_url, stream=False)
assert contents == response.content