CloudFront: create_distribution() now supports CustomHeaders (#7371)

This commit is contained in:
Bert Blommers 2024-02-20 21:29:21 +00:00 committed by GitHub
parent ce074824a4
commit 56d11d841c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 3 deletions

View File

@ -113,7 +113,6 @@ class Origin:
self.id = origin["Id"]
self.domain_name = origin["DomainName"]
self.origin_path = origin.get("OriginPath") or ""
self.custom_headers: List[Any] = []
self.s3_access_identity = ""
self.custom_origin = None
self.origin_shield = origin.get("OriginShield")
@ -129,6 +128,14 @@ class Origin:
if "CustomOriginConfig" in origin:
self.custom_origin = CustomOriginConfig(origin["CustomOriginConfig"])
custom_headers = origin.get("CustomHeaders") or {}
custom_headers = custom_headers.get("Items") or {}
custom_headers = custom_headers.get("OriginCustomHeader") or []
if isinstance(custom_headers, dict):
# Happens if user only sends a single header
custom_headers = [custom_headers]
self.custom_headers = custom_headers
class GeoRestrictions:
def __init__(self, config: Dict[str, Any]):

View File

@ -212,8 +212,10 @@ DIST_CONFIG_TEMPLATE = """
<Quantity>{{ origin.custom_headers|length }}</Quantity>
<Items>
{% for header in origin.custom_headers %}
<HeaderName>{{ header.header_name }}</HeaderName>
<HeaderValue>{{ header.header_value }}</HeaderValue>
<OriginCustomHeader>
<HeaderName>{{ header['HeaderName'] }}</HeaderName>
<HeaderValue>{{ header['HeaderValue'] }}</HeaderValue>
</OriginCustomHeader>
{% endfor %}
</Items>
</CustomHeaders>

View File

@ -238,6 +238,26 @@ def test_create_distribution_with_origins():
assert origin["OriginShield"] == {"Enabled": True, "OriginShieldRegion": "east"}
@mock_aws
@pytest.mark.parametrize("nr_of_headers", [1, 2])
def test_create_distribution_with_custom_headers(nr_of_headers):
client = boto3.client("cloudfront", region_name="us-west-1")
config = scaffold.example_distribution_config("ref")
headers = [
{"HeaderName": f"X-Custom-Header{i}", "HeaderValue": f"v{i}"}
for i in range(nr_of_headers)
]
config["Origins"]["Items"][0]["CustomHeaders"] = {
"Quantity": nr_of_headers,
"Items": headers,
}
dist = client.create_distribution(DistributionConfig=config)["Distribution"]
origin = dist["DistributionConfig"]["Origins"]["Items"][0]
assert origin["CustomHeaders"] == {"Quantity": nr_of_headers, "Items": headers}
@mock_aws
@pytest.mark.parametrize("compress", [True, False])
@pytest.mark.parametrize("qs", [True, False])