diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py index 1756e9a89..737b7d3dc 100644 --- a/moto/cloudfront/models.py +++ b/moto/cloudfront/models.py @@ -278,6 +278,13 @@ class CloudFrontBackend(BaseBackend): dist.advance() return dist, dist.etag + def get_distribution_config(self, distribution_id: str) -> Tuple[Distribution, str]: + if distribution_id not in self.distributions: + raise NoSuchDistribution + dist = self.distributions[distribution_id] + dist.advance() + return dist, dist.etag + def delete_distribution(self, distribution_id: str, if_match: bool) -> None: """ The IfMatch-value is ignored - any value is considered valid. diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py index c3aa45d22..89b223da0 100644 --- a/moto/cloudfront/responses.py +++ b/moto/cloudfront/responses.py @@ -78,24 +78,30 @@ class CloudFrontResponse(BaseResponse): response = template.render(distribution=dist, xmlns=XMLNS) return 200, {"ETag": etag}, response - def update_distribution( + def update_distribution( # type: ignore[return] self, request: Any, full_url: str, headers: Any ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) - params = self._get_xml_body() - distribution_config = params.get("DistributionConfig") dist_id = full_url.split("/")[-2] - if_match = headers["If-Match"] + if request.method == "GET": + distribution_config, etag = self.backend.get_distribution_config(dist_id) + template = self.response_template(GET_DISTRIBUTION_CONFIG_TEMPLATE) + response = template.render(distribution=distribution_config, xmlns=XMLNS) + return 200, {"ETag": etag}, response + if request.method == "PUT": + params = self._get_xml_body() + dist_config = params.get("DistributionConfig") + if_match = headers["If-Match"] - dist, location, e_tag = self.backend.update_distribution( - dist_config=distribution_config, # type: ignore[arg-type] - _id=dist_id, - if_match=if_match, - ) - template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE) - response = template.render(distribution=dist, xmlns=XMLNS) - headers = {"ETag": e_tag, "Location": location} - return 200, headers, response + dist, location, e_tag = self.backend.update_distribution( + dist_config=dist_config, # type: ignore[arg-type] + _id=dist_id, + if_match=if_match, + ) + template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE) + response = template.render(distribution=dist, xmlns=XMLNS) + headers = {"ETag": e_tag, "Location": location} + return 200, headers, response def create_invalidation(self) -> TYPE_RESPONSE: dist_id = self.path.split("/")[-2] @@ -551,6 +557,16 @@ GET_DISTRIBUTION_TEMPLATE = ( """ ) +GET_DISTRIBUTION_CONFIG_TEMPLATE = ( + """ + +""" + + DIST_CONFIG_TEMPLATE + + """ + +""" +) + LIST_TEMPLATE = ( """ diff --git a/tests/test_cloudfront/test_cloudfront_distributions.py b/tests/test_cloudfront/test_cloudfront_distributions.py index 61ac48673..d4017ed13 100644 --- a/tests/test_cloudfront/test_cloudfront_distributions.py +++ b/tests/test_cloudfront/test_cloudfront_distributions.py @@ -370,7 +370,21 @@ def test_create_distribution_needs_unique_caller_reference(): @mock_cloudfront -def test_create_distribution_with_mismatched_originid(): +def test_get_distribution_config_with_unknown_distribution_id(): + client = boto3.client("cloudfront", region_name="us-west-1") + + with pytest.raises(ClientError) as exc: + client.get_distribution_config(Id="unknown") + + metadata = exc.value.response["ResponseMetadata"] + metadata["HTTPStatusCode"].should.equal(404) + err = exc.value.response["Error"] + err["Code"].should.equal("NoSuchDistribution") + err["Message"].should.equal("The specified distribution does not exist.") + + +@mock_cloudfront +def test_get_distribution_config_with_mismatched_originid(): client = boto3.client("cloudfront", region_name="us-west-1") with pytest.raises(ClientError) as exc: @@ -615,3 +629,108 @@ def test_delete_distribution_random_etag(): client.get_distribution(Id=dist_id) err = exc.value.response["Error"] err["Code"].should.equal("NoSuchDistribution") + + +@mock_cloudfront +def test_get_distribution_config(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + dist_id = dist["Distribution"]["Id"] + + resp = client.get_distribution_config(Id=dist_id) + resp.should.have.key("DistributionConfig") + + config = resp["DistributionConfig"] + config.should.have.key("CallerReference").should.equal("ref") + + config.should.have.key("Aliases") + config["Aliases"].should.have.key("Quantity").equals(0) + + config.should.have.key("Origins") + origins = config["Origins"] + origins.should.have.key("Quantity").equals(1) + origins.should.have.key("Items").length_of(1) + origin = origins["Items"][0] + origin.should.have.key("Id").equals("origin1") + origin.should.have.key("DomainName").equals("asdf.s3.us-east-1.amazonaws.com") + origin.should.have.key("OriginPath").equals("") + + origin.should.have.key("CustomHeaders") + origin["CustomHeaders"].should.have.key("Quantity").equals(0) + + origin.should.have.key("ConnectionAttempts").equals(3) + origin.should.have.key("ConnectionTimeout").equals(10) + origin.should.have.key("OriginShield").equals({"Enabled": False}) + + config.should.have.key("OriginGroups").equals({"Quantity": 0}) + + config.should.have.key("DefaultCacheBehavior") + default_cache = config["DefaultCacheBehavior"] + default_cache.should.have.key("TargetOriginId").should.equal("origin1") + default_cache.should.have.key("TrustedSigners") + + signers = default_cache["TrustedSigners"] + signers.should.have.key("Enabled").equals(False) + signers.should.have.key("Quantity").equals(0) + + default_cache.should.have.key("TrustedKeyGroups") + groups = default_cache["TrustedKeyGroups"] + groups.should.have.key("Enabled").equals(False) + groups.should.have.key("Quantity").equals(0) + + default_cache.should.have.key("ViewerProtocolPolicy").equals("allow-all") + + default_cache.should.have.key("AllowedMethods") + methods = default_cache["AllowedMethods"] + methods.should.have.key("Quantity").equals(2) + methods.should.have.key("Items") + set(methods["Items"]).should.equal({"HEAD", "GET"}) + + methods.should.have.key("CachedMethods") + cached_methods = methods["CachedMethods"] + cached_methods.should.have.key("Quantity").equals(2) + set(cached_methods["Items"]).should.equal({"HEAD", "GET"}) + + default_cache.should.have.key("SmoothStreaming").equals(False) + default_cache.should.have.key("Compress").equals(True) + default_cache.should.have.key("LambdaFunctionAssociations").equals({"Quantity": 0}) + default_cache.should.have.key("FunctionAssociations").equals({"Quantity": 0}) + default_cache.should.have.key("FieldLevelEncryptionId").equals("") + default_cache.should.have.key("CachePolicyId") + + config.should.have.key("CacheBehaviors").equals({"Quantity": 0}) + config.should.have.key("CustomErrorResponses").equals({"Quantity": 0}) + config.should.have.key("Comment").equals( + "an optional comment that's not actually optional" + ) + + config.should.have.key("Logging") + logging = config["Logging"] + logging.should.have.key("Enabled").equals(False) + logging.should.have.key("IncludeCookies").equals(False) + logging.should.have.key("Bucket").equals("") + logging.should.have.key("Prefix").equals("") + + config.should.have.key("PriceClass").equals("PriceClass_All") + config.should.have.key("Enabled").equals(False) + config.should.have.key("WebACLId") + config.should.have.key("HttpVersion").equals("http2") + config.should.have.key("IsIPV6Enabled").equals(True) + + config.should.have.key("ViewerCertificate") + cert = config["ViewerCertificate"] + cert.should.have.key("CloudFrontDefaultCertificate").equals(True) + cert.should.have.key("MinimumProtocolVersion").equals("TLSv1") + cert.should.have.key("CertificateSource").equals("cloudfront") + + config.should.have.key("Restrictions") + config["Restrictions"].should.have.key("GeoRestriction") + restriction = config["Restrictions"]["GeoRestriction"] + restriction.should.have.key("RestrictionType").equals("none") + restriction.should.have.key("Quantity").equals(0) + + config.should.have.key("WebACLId") + config.should.have.key("WebACLId").equals("")