diff --git a/docs/docs/services/cloudfront.rst b/docs/docs/services/cloudfront.rst
index 3dce697b7..ff6ff703e 100644
--- a/docs/docs/services/cloudfront.rst
+++ b/docs/docs/services/cloudfront.rst
@@ -121,7 +121,15 @@ cloudfront
- [ ] untag_resource
- [ ] update_cache_policy
- [ ] update_cloud_front_origin_access_identity
-- [ ] update_distribution
+- [X] update_distribution
+
+ The IfMatch-value is ignored - any value is considered valid.
+ Calling this function without a value is invalid, per AWS' behaviour
+
+ This implementation is immature, and tests the basic
+ functionality of updating an exisint distribution with very
+ simple changes.
+
- [ ] update_field_level_encryption_config
- [ ] update_field_level_encryption_profile
- [ ] update_function
diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py
index 03861c901..9fc087d29 100644
--- a/moto/cloudfront/models.py
+++ b/moto/cloudfront/models.py
@@ -180,8 +180,10 @@ class CloudFrontBackend(BaseBackend):
def create_distribution(self, distribution_config):
"""
- This has been tested against an S3-distribution with the simplest possible configuration.
- Please raise an issue if we're not persisting/returning the correct attributes for your use-case.
+ This has been tested against an S3-distribution with the
+ simplest possible configuration. Please raise an issue if
+ we're not persisting/returning the correct attributes for your
+ use-case.
"""
dist = Distribution(distribution_config)
caller_reference = dist.distribution_config.caller_reference
@@ -224,5 +226,25 @@ class CloudFrontBackend(BaseBackend):
return dist
return False
+ def update_distribution(self, DistributionConfig, Id, IfMatch):
+ """
+ The IfMatch-value is ignored - any value is considered valid.
+ Calling this function without a value is invalid, per AWS' behaviour
+ """
+ if Id not in self.distributions or Id is None:
+ raise NoSuchDistribution
+ if not IfMatch:
+ raise InvalidIfMatchVersion
+ if not DistributionConfig:
+ raise NoSuchDistribution
+ dist = self.distributions[Id]
+
+ aliases = DistributionConfig["Aliases"]["Items"]["CNAME"]
+ dist.distribution_config.config = DistributionConfig
+ dist.distribution_config.aliases = aliases
+ self.distributions[Id] = dist
+ dist.advance()
+ return dist, dist.location, dist.etag
+
cloudfront_backend = CloudFrontBackend()
diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py
index f9e335ddb..4888e00ad 100644
--- a/moto/cloudfront/responses.py
+++ b/moto/cloudfront/responses.py
@@ -48,6 +48,23 @@ class CloudFrontResponse(BaseResponse):
response = template.render(distribution=dist, xmlns=XMLNS)
return 200, {"ETag": etag}, response
+ def update_distribution(self, request, full_url, headers):
+ self.setup_class(request, full_url, headers)
+ params = self._get_xml_body()
+ distribution_config = params.get("DistributionConfig")
+ dist_id = full_url.split("/")[-2]
+ if_match = headers["If-Match"]
+
+ dist, location, e_tag = cloudfront_backend.update_distribution(
+ DistributionConfig=distribution_config,
+ Id=dist_id,
+ IfMatch=if_match,
+ )
+ template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE)
+ response = template.render(distribution=dist, xmlns=XMLNS)
+ headers = {"ETag": e_tag, "Location": location}
+ return 200, headers, response
+
DIST_META_TEMPLATE = """
{{ distribution.distribution_id }}
@@ -497,3 +514,13 @@ LIST_TEMPLATE = (
{% endif %}
"""
)
+
+UPDATE_DISTRIBUTION_TEMPLATE = (
+ """
+
+"""
+ + DISTRIBUTION_TEMPLATE
+ + """
+
+"""
+)
diff --git a/moto/cloudfront/urls.py b/moto/cloudfront/urls.py
index 83613be70..9f165d39b 100644
--- a/moto/cloudfront/urls.py
+++ b/moto/cloudfront/urls.py
@@ -1,13 +1,14 @@
"""cloudfront base URL and path."""
from .responses import CloudFrontResponse
+
response = CloudFrontResponse()
url_bases = [
r"https?://cloudfront\.amazonaws\.com",
]
-
url_paths = {
"{0}/2020-05-31/distribution$": response.distributions,
"{0}/2020-05-31/distribution/(?P[^/]+)$": response.individual_distribution,
+ "{0}/2020-05-31/distribution/(?P[^/]+)/config$": response.update_distribution,
}
diff --git a/tests/test_cloudfront/cloudfront_test_scaffolding.py b/tests/test_cloudfront/cloudfront_test_scaffolding.py
new file mode 100644
index 000000000..690d121d0
--- /dev/null
+++ b/tests/test_cloudfront/cloudfront_test_scaffolding.py
@@ -0,0 +1,27 @@
+# Example distribution config used in tests in both test_cloudfront.py
+# as well as test_cloudfront_distributions.py.
+
+
+def example_distribution_config(ref):
+ """Return a basic example distribution config for use in tests."""
+ return {
+ "CallerReference": ref,
+ "Origins": {
+ "Quantity": 1,
+ "Items": [
+ {
+ "Id": "origin1",
+ "DomainName": "asdf.s3.us-east-1.amazonaws.com",
+ "S3OriginConfig": {"OriginAccessIdentity": ""},
+ }
+ ],
+ },
+ "DefaultCacheBehavior": {
+ "TargetOriginId": "origin1",
+ "ViewerProtocolPolicy": "allow-all",
+ "MinTTL": 10,
+ "ForwardedValues": {"QueryString": False, "Cookies": {"Forward": "none"}},
+ },
+ "Comment": "an optional comment that's not actually optional",
+ "Enabled": False,
+ }
diff --git a/tests/test_cloudfront/test_cloudfront.py b/tests/test_cloudfront/test_cloudfront.py
new file mode 100644
index 000000000..612c775f7
--- /dev/null
+++ b/tests/test_cloudfront/test_cloudfront.py
@@ -0,0 +1,245 @@
+"""Unit tests for cloudfront-supported APIs."""
+import pytest
+import boto3
+from botocore.exceptions import ClientError, ParamValidationError
+from moto import mock_cloudfront
+import sure # noqa # pylint: disable=unused-import
+from moto.core import ACCOUNT_ID
+from . import cloudfront_test_scaffolding as scaffold
+
+# See our Development Tips on writing tests for hints on how to write good tests:
+# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
+
+
+@mock_cloudfront
+def test_update_distribution():
+ client = boto3.client("cloudfront", region_name="us-east-1")
+
+ # Create standard distribution
+ config = scaffold.example_distribution_config(ref="ref")
+ dist = client.create_distribution(DistributionConfig=config)
+ dist_id = dist["Distribution"]["Id"]
+ dist_etag = dist["ETag"]
+
+ dist_config = dist["Distribution"]["DistributionConfig"]
+ aliases = ["alias1", "alias2"]
+ dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases}
+
+ resp = client.update_distribution(
+ DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag
+ )
+
+ resp.should.have.key("Distribution")
+ distribution = resp["Distribution"]
+ distribution.should.have.key("Id")
+ distribution.should.have.key("ARN").equals(
+ f"arn:aws:cloudfront:{ACCOUNT_ID}:distribution/{distribution['Id']}"
+ )
+ distribution.should.have.key("Status").equals("Deployed")
+ distribution.should.have.key("LastModifiedTime")
+ distribution.should.have.key("InProgressInvalidationBatches").equals(0)
+ distribution.should.have.key("DomainName").should.contain(".cloudfront.net")
+
+ distribution.should.have.key("ActiveTrustedSigners")
+ signers = distribution["ActiveTrustedSigners"]
+ signers.should.have.key("Enabled").equals(False)
+ signers.should.have.key("Quantity").equals(0)
+
+ distribution.should.have.key("ActiveTrustedKeyGroups")
+ key_groups = distribution["ActiveTrustedKeyGroups"]
+ key_groups.should.have.key("Enabled").equals(False)
+ key_groups.should.have.key("Quantity").equals(0)
+
+ distribution.should.have.key("DistributionConfig")
+ config = distribution["DistributionConfig"]
+ config.should.have.key("CallerReference").should.equal("ref")
+
+ config.should.have.key("Aliases")
+ config["Aliases"].should.equal(dist_config["Aliases"])
+
+ config.should.have.key("Origins")
+ origins = config["Origins"]
+ origins.should.have.key("Quantity").equals(1)
+ origins.should.have.key("Items").length_of(1)
+ origin = origins["Items"][0]
+ origin.should.have.key("Id").equals("origin1")
+ origin.should.have.key("DomainName").equals("asdf.s3.us-east-1.amazonaws.com")
+ origin.should.have.key("OriginPath").equals("")
+
+ origin.should.have.key("CustomHeaders")
+ origin["CustomHeaders"].should.have.key("Quantity").equals(0)
+
+ origin.should.have.key("ConnectionAttempts").equals(3)
+ origin.should.have.key("ConnectionTimeout").equals(10)
+ origin.should.have.key("OriginShield").equals({"Enabled": False})
+
+ config.should.have.key("OriginGroups").equals({"Quantity": 0})
+
+ config.should.have.key("DefaultCacheBehavior")
+ default_cache = config["DefaultCacheBehavior"]
+ default_cache.should.have.key("TargetOriginId").should.equal("origin1")
+ default_cache.should.have.key("TrustedSigners")
+
+ signers = default_cache["TrustedSigners"]
+ signers.should.have.key("Enabled").equals(False)
+ signers.should.have.key("Quantity").equals(0)
+
+ default_cache.should.have.key("TrustedKeyGroups")
+ groups = default_cache["TrustedKeyGroups"]
+ groups.should.have.key("Enabled").equals(False)
+ groups.should.have.key("Quantity").equals(0)
+
+ default_cache.should.have.key("ViewerProtocolPolicy").equals("allow-all")
+
+ default_cache.should.have.key("AllowedMethods")
+ methods = default_cache["AllowedMethods"]
+ methods.should.have.key("Quantity").equals(2)
+ methods.should.have.key("Items")
+ set(methods["Items"]).should.equal({"HEAD", "GET"})
+
+ methods.should.have.key("CachedMethods")
+ cached_methods = methods["CachedMethods"]
+ cached_methods.should.have.key("Quantity").equals(2)
+ set(cached_methods["Items"]).should.equal({"HEAD", "GET"})
+
+ default_cache.should.have.key("SmoothStreaming").equals(False)
+ default_cache.should.have.key("Compress").equals(True)
+ default_cache.should.have.key("LambdaFunctionAssociations").equals({"Quantity": 0})
+ default_cache.should.have.key("FunctionAssociations").equals({"Quantity": 0})
+ default_cache.should.have.key("FieldLevelEncryptionId").equals("")
+ default_cache.should.have.key("CachePolicyId")
+
+ config.should.have.key("CacheBehaviors").equals({"Quantity": 0})
+ config.should.have.key("CustomErrorResponses").equals({"Quantity": 0})
+ config.should.have.key("Comment").equals(
+ "an optional comment that's not actually optional"
+ )
+
+ config.should.have.key("Logging")
+ logging = config["Logging"]
+ logging.should.have.key("Enabled").equals(False)
+ logging.should.have.key("IncludeCookies").equals(False)
+ logging.should.have.key("Bucket").equals("")
+ logging.should.have.key("Prefix").equals("")
+
+ config.should.have.key("PriceClass").equals("PriceClass_All")
+ config.should.have.key("Enabled").equals(False)
+ config.should.have.key("WebACLId")
+ config.should.have.key("HttpVersion").equals("http2")
+ config.should.have.key("IsIPV6Enabled").equals(True)
+
+ config.should.have.key("ViewerCertificate")
+ cert = config["ViewerCertificate"]
+ cert.should.have.key("CloudFrontDefaultCertificate").equals(True)
+ cert.should.have.key("MinimumProtocolVersion").equals("TLSv1")
+ cert.should.have.key("CertificateSource").equals("cloudfront")
+
+ config.should.have.key("Restrictions")
+ config["Restrictions"].should.have.key("GeoRestriction")
+ restriction = config["Restrictions"]["GeoRestriction"]
+ restriction.should.have.key("RestrictionType").equals("none")
+ restriction.should.have.key("Quantity").equals(0)
+
+
+@mock_cloudfront
+def test_update_distribution_no_such_distId():
+ client = boto3.client("cloudfront", region_name="us-east-1")
+
+ # Create standard distribution
+ config = scaffold.example_distribution_config(ref="ref")
+ dist = client.create_distribution(DistributionConfig=config)
+
+ # Make up a fake dist ID by reversing the actual ID
+ dist_id = dist["Distribution"]["Id"][::-1]
+ dist_etag = dist["ETag"]
+
+ dist_config = dist["Distribution"]["DistributionConfig"]
+ aliases = ["alias1", "alias2"]
+ dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases}
+
+ with pytest.raises(ClientError) as error:
+ client.update_distribution(
+ DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag
+ )
+
+ metadata = error.value.response["ResponseMetadata"]
+ metadata["HTTPStatusCode"].should.equal(404)
+ err = error.value.response["Error"]
+ err["Code"].should.equal("NoSuchDistribution")
+ err["Message"].should.equal("The specified distribution does not exist.")
+
+
+@mock_cloudfront
+def test_update_distribution_distId_is_None():
+ client = boto3.client("cloudfront", region_name="us-east-1")
+
+ # Create standard distribution
+ config = scaffold.example_distribution_config(ref="ref")
+ dist = client.create_distribution(DistributionConfig=config)
+
+ # Make up a fake dist ID by reversing the actual ID
+ dist_id = None
+ dist_etag = dist["ETag"]
+
+ dist_config = dist["Distribution"]["DistributionConfig"]
+ aliases = ["alias1", "alias2"]
+ dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases}
+
+ with pytest.raises(ParamValidationError) as error:
+ client.update_distribution(
+ DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag
+ )
+
+ typename = error.typename
+ typename.should.equal("ParamValidationError")
+ error_str = "botocore.exceptions.ParamValidationError: Parameter validation failed:\nInvalid type for parameter Id, value: None, type: , valid types: "
+ error.exconly().should.equal(error_str)
+
+
+@mock_cloudfront
+def test_update_distribution_IfMatch_not_set():
+ client = boto3.client("cloudfront", region_name="us-east-1")
+
+ # Create standard distribution
+ config = scaffold.example_distribution_config(ref="ref")
+ dist = client.create_distribution(DistributionConfig=config)
+
+ # Make up a fake dist ID by reversing the actual ID
+ dist_id = dist["Distribution"]["Id"]
+
+ dist_config = dist["Distribution"]["DistributionConfig"]
+ aliases = ["alias1", "alias2"]
+ dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases}
+
+ with pytest.raises(ClientError) as error:
+ client.update_distribution(
+ DistributionConfig=dist_config, Id=dist_id, IfMatch=""
+ )
+
+ metadata = error.value.response["ResponseMetadata"]
+ metadata["HTTPStatusCode"].should.equal(400)
+ err = error.value.response["Error"]
+ err["Code"].should.equal("InvalidIfMatchVersion")
+ msg = "The If-Match version is missing or not valid for the resource."
+ err["Message"].should.equal(msg)
+
+
+@mock_cloudfront
+def test_update_distribution_dist_config_not_set():
+ client = boto3.client("cloudfront", region_name="us-east-1")
+
+ # Create standard distribution
+ config = scaffold.example_distribution_config(ref="ref")
+ dist = client.create_distribution(DistributionConfig=config)
+
+ # Make up a fake dist ID by reversing the actual ID
+ dist_id = dist["Distribution"]["Id"]
+ dist_etag = dist["ETag"]
+
+ with pytest.raises(ParamValidationError) as error:
+ client.update_distribution(Id=dist_id, IfMatch=dist_etag)
+
+ typename = error.typename
+ typename.should.equal("ParamValidationError")
+ error_str = 'botocore.exceptions.ParamValidationError: Parameter validation failed:\nMissing required parameter in input: "DistributionConfig"'
+ error.exconly().should.equal(error_str)
diff --git a/tests/test_cloudfront/test_cloudfront_distributions.py b/tests/test_cloudfront/test_cloudfront_distributions.py
index 2986ed2ea..5986a8a9f 100644
--- a/tests/test_cloudfront/test_cloudfront_distributions.py
+++ b/tests/test_cloudfront/test_cloudfront_distributions.py
@@ -1,42 +1,17 @@
import boto3
-
-import pytest
-import sure # noqa # pylint: disable=unused-import
-
from botocore.exceptions import ClientError
from moto import mock_cloudfront
from moto.core import ACCOUNT_ID
-
-
-def example_distribution_config(ref):
- return {
- "CallerReference": ref,
- "Origins": {
- "Quantity": 1,
- "Items": [
- {
- "Id": "origin1",
- "DomainName": "asdf.s3.us-east-1.amazonaws.com",
- "S3OriginConfig": {"OriginAccessIdentity": ""},
- }
- ],
- },
- "DefaultCacheBehavior": {
- "TargetOriginId": "origin1",
- "ViewerProtocolPolicy": "allow-all",
- "MinTTL": 10,
- "ForwardedValues": {"QueryString": False, "Cookies": {"Forward": "none"}},
- },
- "Comment": "an optional comment that's not actually optional",
- "Enabled": False,
- }
+from . import cloudfront_test_scaffolding as scaffold
+import pytest
+import sure # noqa # pylint: disable=unused-import
@mock_cloudfront
def test_create_distribution_s3_minimum():
client = boto3.client("cloudfront", region_name="us-west-1")
+ config = scaffold.example_distribution_config("ref")
- config = example_distribution_config("ref")
resp = client.create_distribution(DistributionConfig=config)
resp.should.have.key("Distribution")
@@ -155,7 +130,7 @@ def test_create_distribution_s3_minimum():
def test_create_distribution_with_additional_fields():
client = boto3.client("cloudfront", region_name="us-west-1")
- config = example_distribution_config("ref")
+ config = scaffold.example_distribution_config("ref")
config["Aliases"] = {"Quantity": 2, "Items": ["alias1", "alias2"]}
resp = client.create_distribution(DistributionConfig=config)
distribution = resp["Distribution"]
@@ -170,7 +145,7 @@ def test_create_distribution_with_additional_fields():
def test_create_distribution_returns_etag():
client = boto3.client("cloudfront", region_name="us-east-1")
- config = example_distribution_config("ref")
+ config = scaffold.example_distribution_config("ref")
resp = client.create_distribution(DistributionConfig=config)
dist_id = resp["Distribution"]["Id"]
@@ -186,7 +161,7 @@ def test_create_distribution_needs_unique_caller_reference():
client = boto3.client("cloudfront", region_name="us-east-1")
# Create standard distribution
- config = example_distribution_config(ref="ref")
+ config = scaffold.example_distribution_config(ref="ref")
dist1 = client.create_distribution(DistributionConfig=config)
dist1_id = dist1["Distribution"]["Id"]
@@ -200,7 +175,7 @@ def test_create_distribution_needs_unique_caller_reference():
)
# Creating another distribution with a different reference
- config = example_distribution_config(ref="ref2")
+ config = scaffold.example_distribution_config(ref="ref2")
dist2 = client.create_distribution(DistributionConfig=config)
dist1_id.shouldnt.equal(dist2["Distribution"]["Id"])
@@ -320,9 +295,9 @@ def test_list_distributions_without_any():
def test_list_distributions():
client = boto3.client("cloudfront", region_name="us-east-1")
- config = example_distribution_config(ref="ref1")
+ config = scaffold.example_distribution_config(ref="ref1")
dist1 = client.create_distribution(DistributionConfig=config)["Distribution"]
- config = example_distribution_config(ref="ref2")
+ config = scaffold.example_distribution_config(ref="ref2")
dist2 = client.create_distribution(DistributionConfig=config)["Distribution"]
resp = client.list_distributions()
@@ -347,7 +322,7 @@ def test_get_distribution():
client = boto3.client("cloudfront", region_name="us-east-1")
# Create standard distribution
- config = example_distribution_config(ref="ref")
+ config = scaffold.example_distribution_config(ref="ref")
dist = client.create_distribution(DistributionConfig=config)
dist_id = dist["Distribution"]["Id"]
@@ -420,7 +395,7 @@ def test_delete_distribution_random_etag():
client = boto3.client("cloudfront", region_name="us-east-1")
# Create standard distribution
- config = example_distribution_config(ref="ref")
+ config = scaffold.example_distribution_config(ref="ref")
dist1 = client.create_distribution(DistributionConfig=config)
dist_id = dist1["Distribution"]["Id"]