diff --git a/docs/docs/services/cloudfront.rst b/docs/docs/services/cloudfront.rst index 3dce697b7..ff6ff703e 100644 --- a/docs/docs/services/cloudfront.rst +++ b/docs/docs/services/cloudfront.rst @@ -121,7 +121,15 @@ cloudfront - [ ] untag_resource - [ ] update_cache_policy - [ ] update_cloud_front_origin_access_identity -- [ ] update_distribution +- [X] update_distribution + + The IfMatch-value is ignored - any value is considered valid. + Calling this function without a value is invalid, per AWS' behaviour + + This implementation is immature, and tests the basic + functionality of updating an exisint distribution with very + simple changes. + - [ ] update_field_level_encryption_config - [ ] update_field_level_encryption_profile - [ ] update_function diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py index 03861c901..9fc087d29 100644 --- a/moto/cloudfront/models.py +++ b/moto/cloudfront/models.py @@ -180,8 +180,10 @@ class CloudFrontBackend(BaseBackend): def create_distribution(self, distribution_config): """ - This has been tested against an S3-distribution with the simplest possible configuration. - Please raise an issue if we're not persisting/returning the correct attributes for your use-case. + This has been tested against an S3-distribution with the + simplest possible configuration. Please raise an issue if + we're not persisting/returning the correct attributes for your + use-case. """ dist = Distribution(distribution_config) caller_reference = dist.distribution_config.caller_reference @@ -224,5 +226,25 @@ class CloudFrontBackend(BaseBackend): return dist return False + def update_distribution(self, DistributionConfig, Id, IfMatch): + """ + The IfMatch-value is ignored - any value is considered valid. + Calling this function without a value is invalid, per AWS' behaviour + """ + if Id not in self.distributions or Id is None: + raise NoSuchDistribution + if not IfMatch: + raise InvalidIfMatchVersion + if not DistributionConfig: + raise NoSuchDistribution + dist = self.distributions[Id] + + aliases = DistributionConfig["Aliases"]["Items"]["CNAME"] + dist.distribution_config.config = DistributionConfig + dist.distribution_config.aliases = aliases + self.distributions[Id] = dist + dist.advance() + return dist, dist.location, dist.etag + cloudfront_backend = CloudFrontBackend() diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py index f9e335ddb..4888e00ad 100644 --- a/moto/cloudfront/responses.py +++ b/moto/cloudfront/responses.py @@ -48,6 +48,23 @@ class CloudFrontResponse(BaseResponse): response = template.render(distribution=dist, xmlns=XMLNS) return 200, {"ETag": etag}, response + def update_distribution(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + params = self._get_xml_body() + distribution_config = params.get("DistributionConfig") + dist_id = full_url.split("/")[-2] + if_match = headers["If-Match"] + + dist, location, e_tag = cloudfront_backend.update_distribution( + DistributionConfig=distribution_config, + Id=dist_id, + IfMatch=if_match, + ) + template = self.response_template(UPDATE_DISTRIBUTION_TEMPLATE) + response = template.render(distribution=dist, xmlns=XMLNS) + headers = {"ETag": e_tag, "Location": location} + return 200, headers, response + DIST_META_TEMPLATE = """ {{ distribution.distribution_id }} @@ -497,3 +514,13 @@ LIST_TEMPLATE = ( {% endif %} """ ) + +UPDATE_DISTRIBUTION_TEMPLATE = ( + """ + +""" + + DISTRIBUTION_TEMPLATE + + """ + +""" +) diff --git a/moto/cloudfront/urls.py b/moto/cloudfront/urls.py index 83613be70..9f165d39b 100644 --- a/moto/cloudfront/urls.py +++ b/moto/cloudfront/urls.py @@ -1,13 +1,14 @@ """cloudfront base URL and path.""" from .responses import CloudFrontResponse + response = CloudFrontResponse() url_bases = [ r"https?://cloudfront\.amazonaws\.com", ] - url_paths = { "{0}/2020-05-31/distribution$": response.distributions, "{0}/2020-05-31/distribution/(?P[^/]+)$": response.individual_distribution, + "{0}/2020-05-31/distribution/(?P[^/]+)/config$": response.update_distribution, } diff --git a/tests/test_cloudfront/cloudfront_test_scaffolding.py b/tests/test_cloudfront/cloudfront_test_scaffolding.py new file mode 100644 index 000000000..690d121d0 --- /dev/null +++ b/tests/test_cloudfront/cloudfront_test_scaffolding.py @@ -0,0 +1,27 @@ +# Example distribution config used in tests in both test_cloudfront.py +# as well as test_cloudfront_distributions.py. + + +def example_distribution_config(ref): + """Return a basic example distribution config for use in tests.""" + return { + "CallerReference": ref, + "Origins": { + "Quantity": 1, + "Items": [ + { + "Id": "origin1", + "DomainName": "asdf.s3.us-east-1.amazonaws.com", + "S3OriginConfig": {"OriginAccessIdentity": ""}, + } + ], + }, + "DefaultCacheBehavior": { + "TargetOriginId": "origin1", + "ViewerProtocolPolicy": "allow-all", + "MinTTL": 10, + "ForwardedValues": {"QueryString": False, "Cookies": {"Forward": "none"}}, + }, + "Comment": "an optional comment that's not actually optional", + "Enabled": False, + } diff --git a/tests/test_cloudfront/test_cloudfront.py b/tests/test_cloudfront/test_cloudfront.py new file mode 100644 index 000000000..612c775f7 --- /dev/null +++ b/tests/test_cloudfront/test_cloudfront.py @@ -0,0 +1,245 @@ +"""Unit tests for cloudfront-supported APIs.""" +import pytest +import boto3 +from botocore.exceptions import ClientError, ParamValidationError +from moto import mock_cloudfront +import sure # noqa # pylint: disable=unused-import +from moto.core import ACCOUNT_ID +from . import cloudfront_test_scaffolding as scaffold + +# See our Development Tips on writing tests for hints on how to write good tests: +# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html + + +@mock_cloudfront +def test_update_distribution(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + dist_id = dist["Distribution"]["Id"] + dist_etag = dist["ETag"] + + dist_config = dist["Distribution"]["DistributionConfig"] + aliases = ["alias1", "alias2"] + dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases} + + resp = client.update_distribution( + DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag + ) + + resp.should.have.key("Distribution") + distribution = resp["Distribution"] + distribution.should.have.key("Id") + distribution.should.have.key("ARN").equals( + f"arn:aws:cloudfront:{ACCOUNT_ID}:distribution/{distribution['Id']}" + ) + distribution.should.have.key("Status").equals("Deployed") + distribution.should.have.key("LastModifiedTime") + distribution.should.have.key("InProgressInvalidationBatches").equals(0) + distribution.should.have.key("DomainName").should.contain(".cloudfront.net") + + distribution.should.have.key("ActiveTrustedSigners") + signers = distribution["ActiveTrustedSigners"] + signers.should.have.key("Enabled").equals(False) + signers.should.have.key("Quantity").equals(0) + + distribution.should.have.key("ActiveTrustedKeyGroups") + key_groups = distribution["ActiveTrustedKeyGroups"] + key_groups.should.have.key("Enabled").equals(False) + key_groups.should.have.key("Quantity").equals(0) + + distribution.should.have.key("DistributionConfig") + config = distribution["DistributionConfig"] + config.should.have.key("CallerReference").should.equal("ref") + + config.should.have.key("Aliases") + config["Aliases"].should.equal(dist_config["Aliases"]) + + config.should.have.key("Origins") + origins = config["Origins"] + origins.should.have.key("Quantity").equals(1) + origins.should.have.key("Items").length_of(1) + origin = origins["Items"][0] + origin.should.have.key("Id").equals("origin1") + origin.should.have.key("DomainName").equals("asdf.s3.us-east-1.amazonaws.com") + origin.should.have.key("OriginPath").equals("") + + origin.should.have.key("CustomHeaders") + origin["CustomHeaders"].should.have.key("Quantity").equals(0) + + origin.should.have.key("ConnectionAttempts").equals(3) + origin.should.have.key("ConnectionTimeout").equals(10) + origin.should.have.key("OriginShield").equals({"Enabled": False}) + + config.should.have.key("OriginGroups").equals({"Quantity": 0}) + + config.should.have.key("DefaultCacheBehavior") + default_cache = config["DefaultCacheBehavior"] + default_cache.should.have.key("TargetOriginId").should.equal("origin1") + default_cache.should.have.key("TrustedSigners") + + signers = default_cache["TrustedSigners"] + signers.should.have.key("Enabled").equals(False) + signers.should.have.key("Quantity").equals(0) + + default_cache.should.have.key("TrustedKeyGroups") + groups = default_cache["TrustedKeyGroups"] + groups.should.have.key("Enabled").equals(False) + groups.should.have.key("Quantity").equals(0) + + default_cache.should.have.key("ViewerProtocolPolicy").equals("allow-all") + + default_cache.should.have.key("AllowedMethods") + methods = default_cache["AllowedMethods"] + methods.should.have.key("Quantity").equals(2) + methods.should.have.key("Items") + set(methods["Items"]).should.equal({"HEAD", "GET"}) + + methods.should.have.key("CachedMethods") + cached_methods = methods["CachedMethods"] + cached_methods.should.have.key("Quantity").equals(2) + set(cached_methods["Items"]).should.equal({"HEAD", "GET"}) + + default_cache.should.have.key("SmoothStreaming").equals(False) + default_cache.should.have.key("Compress").equals(True) + default_cache.should.have.key("LambdaFunctionAssociations").equals({"Quantity": 0}) + default_cache.should.have.key("FunctionAssociations").equals({"Quantity": 0}) + default_cache.should.have.key("FieldLevelEncryptionId").equals("") + default_cache.should.have.key("CachePolicyId") + + config.should.have.key("CacheBehaviors").equals({"Quantity": 0}) + config.should.have.key("CustomErrorResponses").equals({"Quantity": 0}) + config.should.have.key("Comment").equals( + "an optional comment that's not actually optional" + ) + + config.should.have.key("Logging") + logging = config["Logging"] + logging.should.have.key("Enabled").equals(False) + logging.should.have.key("IncludeCookies").equals(False) + logging.should.have.key("Bucket").equals("") + logging.should.have.key("Prefix").equals("") + + config.should.have.key("PriceClass").equals("PriceClass_All") + config.should.have.key("Enabled").equals(False) + config.should.have.key("WebACLId") + config.should.have.key("HttpVersion").equals("http2") + config.should.have.key("IsIPV6Enabled").equals(True) + + config.should.have.key("ViewerCertificate") + cert = config["ViewerCertificate"] + cert.should.have.key("CloudFrontDefaultCertificate").equals(True) + cert.should.have.key("MinimumProtocolVersion").equals("TLSv1") + cert.should.have.key("CertificateSource").equals("cloudfront") + + config.should.have.key("Restrictions") + config["Restrictions"].should.have.key("GeoRestriction") + restriction = config["Restrictions"]["GeoRestriction"] + restriction.should.have.key("RestrictionType").equals("none") + restriction.should.have.key("Quantity").equals(0) + + +@mock_cloudfront +def test_update_distribution_no_such_distId(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + + # Make up a fake dist ID by reversing the actual ID + dist_id = dist["Distribution"]["Id"][::-1] + dist_etag = dist["ETag"] + + dist_config = dist["Distribution"]["DistributionConfig"] + aliases = ["alias1", "alias2"] + dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases} + + with pytest.raises(ClientError) as error: + client.update_distribution( + DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag + ) + + metadata = error.value.response["ResponseMetadata"] + metadata["HTTPStatusCode"].should.equal(404) + err = error.value.response["Error"] + err["Code"].should.equal("NoSuchDistribution") + err["Message"].should.equal("The specified distribution does not exist.") + + +@mock_cloudfront +def test_update_distribution_distId_is_None(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + + # Make up a fake dist ID by reversing the actual ID + dist_id = None + dist_etag = dist["ETag"] + + dist_config = dist["Distribution"]["DistributionConfig"] + aliases = ["alias1", "alias2"] + dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases} + + with pytest.raises(ParamValidationError) as error: + client.update_distribution( + DistributionConfig=dist_config, Id=dist_id, IfMatch=dist_etag + ) + + typename = error.typename + typename.should.equal("ParamValidationError") + error_str = "botocore.exceptions.ParamValidationError: Parameter validation failed:\nInvalid type for parameter Id, value: None, type: , valid types: " + error.exconly().should.equal(error_str) + + +@mock_cloudfront +def test_update_distribution_IfMatch_not_set(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + + # Make up a fake dist ID by reversing the actual ID + dist_id = dist["Distribution"]["Id"] + + dist_config = dist["Distribution"]["DistributionConfig"] + aliases = ["alias1", "alias2"] + dist_config["Aliases"] = {"Quantity": len(aliases), "Items": aliases} + + with pytest.raises(ClientError) as error: + client.update_distribution( + DistributionConfig=dist_config, Id=dist_id, IfMatch="" + ) + + metadata = error.value.response["ResponseMetadata"] + metadata["HTTPStatusCode"].should.equal(400) + err = error.value.response["Error"] + err["Code"].should.equal("InvalidIfMatchVersion") + msg = "The If-Match version is missing or not valid for the resource." + err["Message"].should.equal(msg) + + +@mock_cloudfront +def test_update_distribution_dist_config_not_set(): + client = boto3.client("cloudfront", region_name="us-east-1") + + # Create standard distribution + config = scaffold.example_distribution_config(ref="ref") + dist = client.create_distribution(DistributionConfig=config) + + # Make up a fake dist ID by reversing the actual ID + dist_id = dist["Distribution"]["Id"] + dist_etag = dist["ETag"] + + with pytest.raises(ParamValidationError) as error: + client.update_distribution(Id=dist_id, IfMatch=dist_etag) + + typename = error.typename + typename.should.equal("ParamValidationError") + error_str = 'botocore.exceptions.ParamValidationError: Parameter validation failed:\nMissing required parameter in input: "DistributionConfig"' + error.exconly().should.equal(error_str) diff --git a/tests/test_cloudfront/test_cloudfront_distributions.py b/tests/test_cloudfront/test_cloudfront_distributions.py index 2986ed2ea..5986a8a9f 100644 --- a/tests/test_cloudfront/test_cloudfront_distributions.py +++ b/tests/test_cloudfront/test_cloudfront_distributions.py @@ -1,42 +1,17 @@ import boto3 - -import pytest -import sure # noqa # pylint: disable=unused-import - from botocore.exceptions import ClientError from moto import mock_cloudfront from moto.core import ACCOUNT_ID - - -def example_distribution_config(ref): - return { - "CallerReference": ref, - "Origins": { - "Quantity": 1, - "Items": [ - { - "Id": "origin1", - "DomainName": "asdf.s3.us-east-1.amazonaws.com", - "S3OriginConfig": {"OriginAccessIdentity": ""}, - } - ], - }, - "DefaultCacheBehavior": { - "TargetOriginId": "origin1", - "ViewerProtocolPolicy": "allow-all", - "MinTTL": 10, - "ForwardedValues": {"QueryString": False, "Cookies": {"Forward": "none"}}, - }, - "Comment": "an optional comment that's not actually optional", - "Enabled": False, - } +from . import cloudfront_test_scaffolding as scaffold +import pytest +import sure # noqa # pylint: disable=unused-import @mock_cloudfront def test_create_distribution_s3_minimum(): client = boto3.client("cloudfront", region_name="us-west-1") + config = scaffold.example_distribution_config("ref") - config = example_distribution_config("ref") resp = client.create_distribution(DistributionConfig=config) resp.should.have.key("Distribution") @@ -155,7 +130,7 @@ def test_create_distribution_s3_minimum(): def test_create_distribution_with_additional_fields(): client = boto3.client("cloudfront", region_name="us-west-1") - config = example_distribution_config("ref") + config = scaffold.example_distribution_config("ref") config["Aliases"] = {"Quantity": 2, "Items": ["alias1", "alias2"]} resp = client.create_distribution(DistributionConfig=config) distribution = resp["Distribution"] @@ -170,7 +145,7 @@ def test_create_distribution_with_additional_fields(): def test_create_distribution_returns_etag(): client = boto3.client("cloudfront", region_name="us-east-1") - config = example_distribution_config("ref") + config = scaffold.example_distribution_config("ref") resp = client.create_distribution(DistributionConfig=config) dist_id = resp["Distribution"]["Id"] @@ -186,7 +161,7 @@ def test_create_distribution_needs_unique_caller_reference(): client = boto3.client("cloudfront", region_name="us-east-1") # Create standard distribution - config = example_distribution_config(ref="ref") + config = scaffold.example_distribution_config(ref="ref") dist1 = client.create_distribution(DistributionConfig=config) dist1_id = dist1["Distribution"]["Id"] @@ -200,7 +175,7 @@ def test_create_distribution_needs_unique_caller_reference(): ) # Creating another distribution with a different reference - config = example_distribution_config(ref="ref2") + config = scaffold.example_distribution_config(ref="ref2") dist2 = client.create_distribution(DistributionConfig=config) dist1_id.shouldnt.equal(dist2["Distribution"]["Id"]) @@ -320,9 +295,9 @@ def test_list_distributions_without_any(): def test_list_distributions(): client = boto3.client("cloudfront", region_name="us-east-1") - config = example_distribution_config(ref="ref1") + config = scaffold.example_distribution_config(ref="ref1") dist1 = client.create_distribution(DistributionConfig=config)["Distribution"] - config = example_distribution_config(ref="ref2") + config = scaffold.example_distribution_config(ref="ref2") dist2 = client.create_distribution(DistributionConfig=config)["Distribution"] resp = client.list_distributions() @@ -347,7 +322,7 @@ def test_get_distribution(): client = boto3.client("cloudfront", region_name="us-east-1") # Create standard distribution - config = example_distribution_config(ref="ref") + config = scaffold.example_distribution_config(ref="ref") dist = client.create_distribution(DistributionConfig=config) dist_id = dist["Distribution"]["Id"] @@ -420,7 +395,7 @@ def test_delete_distribution_random_etag(): client = boto3.client("cloudfront", region_name="us-east-1") # Create standard distribution - config = example_distribution_config(ref="ref") + config = scaffold.example_distribution_config(ref="ref") dist1 = client.create_distribution(DistributionConfig=config) dist_id = dist1["Distribution"]["Id"]