feat(CloudFront): Include get_distribution_config (#5660)

This commit is contained in:
Pepe Fagoaga 2022-11-12 00:32:00 +01:00 committed by GitHub
parent 4b946ce208
commit df64b7b777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 14 deletions

View File

@ -278,6 +278,13 @@ class CloudFrontBackend(BaseBackend):
dist.advance()
return dist, dist.etag
def get_distribution_config(self, distribution_id: str) -> Tuple[Distribution, str]:
if distribution_id not in self.distributions:
raise NoSuchDistribution
dist = self.distributions[distribution_id]
dist.advance()
return dist, dist.etag
def delete_distribution(self, distribution_id: str, if_match: bool) -> None:
"""
The IfMatch-value is ignored - any value is considered valid.

View File

@ -78,24 +78,30 @@ class CloudFrontResponse(BaseResponse):
response = template.render(distribution=dist, xmlns=XMLNS)
return 200, {"ETag": etag}, response
def update_distribution(
def update_distribution( # type: ignore[return]
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
params = self._get_xml_body()
distribution_config = params.get("DistributionConfig")
dist_id = full_url.split("/")[-2]
if_match = headers["If-Match"]
if request.method == "GET":
distribution_config, etag = self.backend.get_distribution_config(dist_id)
template = self.response_template(GET_DISTRIBUTION_CONFIG_TEMPLATE)
response = template.render(distribution=distribution_config, xmlns=XMLNS)
return 200, {"ETag": etag}, response
if request.method == "PUT":
params = self._get_xml_body()
dist_config = params.get("DistributionConfig")
if_match = headers["If-Match"]
dist, location, e_tag = self.backend.update_distribution(
dist_config=distribution_config, # type: ignore[arg-type]
_id=dist_id,
if_match=if_match,
)
template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS)
headers = {"ETag": e_tag, "Location": location}
return 200, headers, response
dist, location, e_tag = self.backend.update_distribution(
dist_config=dist_config, # type: ignore[arg-type]
_id=dist_id,
if_match=if_match,
)
template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
response = template.render(distribution=dist, xmlns=XMLNS)
headers = {"ETag": e_tag, "Location": location}
return 200, headers, response
def create_invalidation(self) -> TYPE_RESPONSE:
dist_id = self.path.split("/")[-2]
@ -551,6 +557,16 @@ GET_DISTRIBUTION_TEMPLATE = (
"""
)
GET_DISTRIBUTION_CONFIG_TEMPLATE = (
"""<?xml version="1.0"?>
<DistributionConfig>
"""
+ DIST_CONFIG_TEMPLATE
+ """
</DistributionConfig>
"""
)
LIST_TEMPLATE = (
"""<?xml version="1.0"?>

View File

@ -370,7 +370,21 @@ def test_create_distribution_needs_unique_caller_reference():
@mock_cloudfront
def test_create_distribution_with_mismatched_originid():
def test_get_distribution_config_with_unknown_distribution_id():
client = boto3.client("cloudfront", region_name="us-west-1")
with pytest.raises(ClientError) as exc:
client.get_distribution_config(Id="unknown")
metadata = exc.value.response["ResponseMetadata"]
metadata["HTTPStatusCode"].should.equal(404)
err = exc.value.response["Error"]
err["Code"].should.equal("NoSuchDistribution")
err["Message"].should.equal("The specified distribution does not exist.")
@mock_cloudfront
def test_get_distribution_config_with_mismatched_originid():
client = boto3.client("cloudfront", region_name="us-west-1")
with pytest.raises(ClientError) as exc:
@ -615,3 +629,108 @@ def test_delete_distribution_random_etag():
client.get_distribution(Id=dist_id)
err = exc.value.response["Error"]
err["Code"].should.equal("NoSuchDistribution")
@mock_cloudfront
def test_get_distribution_config():
client = boto3.client("cloudfront", region_name="us-east-1")
# Create standard distribution
config = scaffold.example_distribution_config(ref="ref")
dist = client.create_distribution(DistributionConfig=config)
dist_id = dist["Distribution"]["Id"]
resp = client.get_distribution_config(Id=dist_id)
resp.should.have.key("DistributionConfig")
config = resp["DistributionConfig"]
config.should.have.key("CallerReference").should.equal("ref")
config.should.have.key("Aliases")
config["Aliases"].should.have.key("Quantity").equals(0)
config.should.have.key("Origins")
origins = config["Origins"]
origins.should.have.key("Quantity").equals(1)
origins.should.have.key("Items").length_of(1)
origin = origins["Items"][0]
origin.should.have.key("Id").equals("origin1")
origin.should.have.key("DomainName").equals("asdf.s3.us-east-1.amazonaws.com")
origin.should.have.key("OriginPath").equals("")
origin.should.have.key("CustomHeaders")
origin["CustomHeaders"].should.have.key("Quantity").equals(0)
origin.should.have.key("ConnectionAttempts").equals(3)
origin.should.have.key("ConnectionTimeout").equals(10)
origin.should.have.key("OriginShield").equals({"Enabled": False})
config.should.have.key("OriginGroups").equals({"Quantity": 0})
config.should.have.key("DefaultCacheBehavior")
default_cache = config["DefaultCacheBehavior"]
default_cache.should.have.key("TargetOriginId").should.equal("origin1")
default_cache.should.have.key("TrustedSigners")
signers = default_cache["TrustedSigners"]
signers.should.have.key("Enabled").equals(False)
signers.should.have.key("Quantity").equals(0)
default_cache.should.have.key("TrustedKeyGroups")
groups = default_cache["TrustedKeyGroups"]
groups.should.have.key("Enabled").equals(False)
groups.should.have.key("Quantity").equals(0)
default_cache.should.have.key("ViewerProtocolPolicy").equals("allow-all")
default_cache.should.have.key("AllowedMethods")
methods = default_cache["AllowedMethods"]
methods.should.have.key("Quantity").equals(2)
methods.should.have.key("Items")
set(methods["Items"]).should.equal({"HEAD", "GET"})
methods.should.have.key("CachedMethods")
cached_methods = methods["CachedMethods"]
cached_methods.should.have.key("Quantity").equals(2)
set(cached_methods["Items"]).should.equal({"HEAD", "GET"})
default_cache.should.have.key("SmoothStreaming").equals(False)
default_cache.should.have.key("Compress").equals(True)
default_cache.should.have.key("LambdaFunctionAssociations").equals({"Quantity": 0})
default_cache.should.have.key("FunctionAssociations").equals({"Quantity": 0})
default_cache.should.have.key("FieldLevelEncryptionId").equals("")
default_cache.should.have.key("CachePolicyId")
config.should.have.key("CacheBehaviors").equals({"Quantity": 0})
config.should.have.key("CustomErrorResponses").equals({"Quantity": 0})
config.should.have.key("Comment").equals(
"an optional comment that's not actually optional"
)
config.should.have.key("Logging")
logging = config["Logging"]
logging.should.have.key("Enabled").equals(False)
logging.should.have.key("IncludeCookies").equals(False)
logging.should.have.key("Bucket").equals("")
logging.should.have.key("Prefix").equals("")
config.should.have.key("PriceClass").equals("PriceClass_All")
config.should.have.key("Enabled").equals(False)
config.should.have.key("WebACLId")
config.should.have.key("HttpVersion").equals("http2")
config.should.have.key("IsIPV6Enabled").equals(True)
config.should.have.key("ViewerCertificate")
cert = config["ViewerCertificate"]
cert.should.have.key("CloudFrontDefaultCertificate").equals(True)
cert.should.have.key("MinimumProtocolVersion").equals("TLSv1")
cert.should.have.key("CertificateSource").equals("cloudfront")
config.should.have.key("Restrictions")
config["Restrictions"].should.have.key("GeoRestriction")
restriction = config["Restrictions"]["GeoRestriction"]
restriction.should.have.key("RestrictionType").equals("none")
restriction.should.have.key("Quantity").equals(0)
config.should.have.key("WebACLId")
config.should.have.key("WebACLId").equals("")