Cloudfront: Origin Access Control (#6453)

This commit is contained in:
Bert Blommers 2023-06-28 13:37:45 +00:00 committed by GitHub
parent 5c12416492
commit 3056ba95b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 247 additions and 14 deletions

View File

@ -738,7 +738,7 @@
## cloudfront
<details>
<summary>9% implemented</summary>
<summary>14% implemented</summary>
- [ ] associate_alias
- [ ] copy_distribution
@ -753,7 +753,7 @@
- [X] create_invalidation
- [ ] create_key_group
- [ ] create_monitoring_subscription
- [ ] create_origin_access_control
- [X] create_origin_access_control
- [ ] create_origin_request_policy
- [ ] create_public_key
- [ ] create_realtime_log_config
@ -769,7 +769,7 @@
- [ ] delete_function
- [ ] delete_key_group
- [ ] delete_monitoring_subscription
- [ ] delete_origin_access_control
- [X] delete_origin_access_control
- [ ] delete_origin_request_policy
- [ ] delete_public_key
- [ ] delete_realtime_log_config
@ -793,7 +793,7 @@
- [ ] get_key_group
- [ ] get_key_group_config
- [ ] get_monitoring_subscription
- [ ] get_origin_access_control
- [X] get_origin_access_control
- [ ] get_origin_access_control_config
- [ ] get_origin_request_policy
- [ ] get_origin_request_policy_config
@ -820,7 +820,7 @@
- [ ] list_functions
- [X] list_invalidations
- [ ] list_key_groups
- [ ] list_origin_access_controls
- [X] list_origin_access_controls
- [ ] list_origin_request_policies
- [ ] list_public_keys
- [ ] list_realtime_log_configs
@ -840,7 +840,7 @@
- [ ] update_field_level_encryption_profile
- [ ] update_function
- [ ] update_key_group
- [ ] update_origin_access_control
- [X] update_origin_access_control
- [ ] update_origin_request_policy
- [ ] update_public_key
- [ ] update_realtime_log_config

View File

@ -44,7 +44,7 @@ cloudfront
- [X] create_invalidation
- [ ] create_key_group
- [ ] create_monitoring_subscription
- [ ] create_origin_access_control
- [X] create_origin_access_control
- [ ] create_origin_request_policy
- [ ] create_public_key
- [ ] create_realtime_log_config
@ -65,7 +65,11 @@ cloudfront
- [ ] delete_function
- [ ] delete_key_group
- [ ] delete_monitoring_subscription
- [ ] delete_origin_access_control
- [X] delete_origin_access_control
The IfMatch-parameter is not yet implemented
- [ ] delete_origin_request_policy
- [ ] delete_public_key
- [ ] delete_realtime_log_config
@ -89,7 +93,7 @@ cloudfront
- [ ] get_key_group
- [ ] get_key_group_config
- [ ] get_monitoring_subscription
- [ ] get_origin_access_control
- [X] get_origin_access_control
- [ ] get_origin_access_control_config
- [ ] get_origin_request_policy
- [ ] get_origin_request_policy_config
@ -124,7 +128,11 @@ cloudfront
- [ ] list_key_groups
- [ ] list_origin_access_controls
- [X] list_origin_access_controls
Pagination is not yet implemented
- [ ] list_origin_request_policies
- [ ] list_public_keys
- [ ] list_realtime_log_configs
@ -149,7 +157,11 @@ cloudfront
- [ ] update_field_level_encryption_profile
- [ ] update_function
- [ ] update_key_group
- [ ] update_origin_access_control
- [X] update_origin_access_control
The IfMatch-parameter is not yet implemented
- [ ] update_origin_request_policy
- [ ] update_public_key
- [ ] update_realtime_log_config

View File

@ -13,7 +13,6 @@ EXCEPTION_RESPONSE = """<?xml version="1.0"?>
class CloudFrontException(RESTError):
code = 400
def __init__(self, error_type: str, message: str, **kwargs: Any):
@ -23,7 +22,6 @@ class CloudFrontException(RESTError):
class OriginDoesNotExist(CloudFrontException):
code = 404
def __init__(self) -> None:
@ -66,10 +64,19 @@ class InvalidIfMatchVersion(CloudFrontException):
class NoSuchDistribution(CloudFrontException):
code = 404
def __init__(self) -> None:
super().__init__(
"NoSuchDistribution", message="The specified distribution does not exist."
)
class NoSuchOriginAccessControl(CloudFrontException):
code = 404
def __init__(self) -> None:
super().__init__(
"NoSuchOriginAccessControl",
message="The specified origin access control does not exist.",
)

View File

@ -16,6 +16,7 @@ from .exceptions import (
DistributionAlreadyExists,
InvalidIfMatchVersion,
NoSuchDistribution,
NoSuchOriginAccessControl,
)
@ -212,6 +213,29 @@ class Distribution(BaseModel, ManagedState):
return f"https://cloudfront.amazonaws.com/2020-05-31/distribution/{self.distribution_id}"
class OriginAccessControl(BaseModel):
def __init__(self, config_dict: Dict[str, str]):
self.id = Invalidation.random_id()
self.name = config_dict.get("Name")
self.description = config_dict.get("Description")
self.signing_protocol = config_dict.get("SigningProtocol")
self.signing_behaviour = config_dict.get("SigningBehavior")
self.origin_type = config_dict.get("OriginAccessControlOriginType")
self.etag = Invalidation.random_id()
def update(self, config: Dict[str, str]) -> None:
if "Name" in config:
self.name = config["Name"]
if "Description" in config:
self.description = config["Description"]
if "SigningProtocol" in config:
self.signing_protocol = config["SigningProtocol"]
if "SigningBehavior" in config:
self.signing_behaviour = config["SigningBehavior"]
if "OriginAccessControlOriginType" in config:
self.origin_type = config["OriginAccessControlOriginType"]
class Invalidation(BaseModel):
@staticmethod
def random_id(uppercase: bool = True) -> str:
@ -243,6 +267,7 @@ class CloudFrontBackend(BaseBackend):
super().__init__(region_name, account_id)
self.distributions: Dict[str, Distribution] = dict()
self.invalidations: Dict[str, List[Invalidation]] = dict()
self.origin_access_controls: Dict[str, OriginAccessControl] = dict()
self.tagger = TaggingService()
state_manager.register_default_transition(
@ -363,6 +388,40 @@ class CloudFrontBackend(BaseBackend):
def list_tags_for_resource(self, resource: str) -> Dict[str, List[Dict[str, str]]]:
return self.tagger.list_tags_for_resource(resource)
def create_origin_access_control(
self, config_dict: Dict[str, str]
) -> OriginAccessControl:
control = OriginAccessControl(config_dict)
self.origin_access_controls[control.id] = control
return control
def get_origin_access_control(self, control_id: str) -> OriginAccessControl:
if control_id not in self.origin_access_controls:
raise NoSuchOriginAccessControl
return self.origin_access_controls[control_id]
def update_origin_access_control(
self, control_id: str, config: Dict[str, str]
) -> OriginAccessControl:
"""
The IfMatch-parameter is not yet implemented
"""
control = self.get_origin_access_control(control_id)
control.update(config)
return control
def list_origin_access_controls(self) -> Iterable[OriginAccessControl]:
"""
Pagination is not yet implemented
"""
return self.origin_access_controls.values()
def delete_origin_access_control(self, control_id: str) -> None:
"""
The IfMatch-parameter is not yet implemented
"""
self.origin_access_controls.pop(control_id)
cloudfront_backends = BackendDict(
CloudFrontBackend,

View File

@ -39,6 +39,22 @@ class CloudFrontResponse(BaseResponse):
if request.method == "GET":
return self.list_tags_for_resource()
def origin_access_controls(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self.create_origin_access_control()
if request.method == "GET":
return self.list_origin_access_controls()
def origin_access_control(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.get_origin_access_control()
if request.method == "PUT":
return self.update_origin_access_control()
if request.method == "DELETE":
return self.delete_origin_access_control()
def create_distribution(self) -> TYPE_RESPONSE:
params = self._get_xml_body()
if "DistributionConfigWithTags" in params:
@ -130,6 +146,37 @@ class CloudFrontResponse(BaseResponse):
response = template.render(tags=tags, xmlns=XMLNS)
return 200, {}, response
def create_origin_access_control(self) -> TYPE_RESPONSE:
config = self._get_xml_body().get("OriginAccessControlConfig", {})
config.pop("@xmlns", None)
control = self.backend.create_origin_access_control(config)
template = self.response_template(ORIGIN_ACCESS_CONTROl)
return 200, {}, template.render(control=control)
def get_origin_access_control(self) -> TYPE_RESPONSE:
control_id = self.path.split("/")[-1]
control = self.backend.get_origin_access_control(control_id)
template = self.response_template(ORIGIN_ACCESS_CONTROl)
return 200, {"ETag": control.etag}, template.render(control=control)
def list_origin_access_controls(self) -> TYPE_RESPONSE:
controls = self.backend.list_origin_access_controls()
template = self.response_template(LIST_ORIGIN_ACCESS_CONTROl)
return 200, {}, template.render(controls=controls)
def update_origin_access_control(self) -> TYPE_RESPONSE:
control_id = self.path.split("/")[-2]
config = self._get_xml_body().get("OriginAccessControlConfig", {})
config.pop("@xmlns", None)
control = self.backend.update_origin_access_control(control_id, config)
template = self.response_template(ORIGIN_ACCESS_CONTROl)
return 200, {"ETag": control.etag}, template.render(control=control)
def delete_origin_access_control(self) -> TYPE_RESPONSE:
control_id = self.path.split("/")[-1]
self.backend.delete_origin_access_control(control_id)
return 200, {}, "{}"
DIST_META_TEMPLATE = """
<Id>{{ distribution.distribution_id }}</Id>
@ -651,3 +698,39 @@ TAGS_TEMPLATE = """<?xml version="1.0"?>
</Items>
</Tags>
"""
ORIGIN_ACCESS_CONTROl = """<?xml version="1.0"?>
<OriginAccessControl>
<Id>{{ control.id }}</Id>
<OriginAccessControlConfig>
<Name>{{ control.name }}</Name>
{% if control.description %}
<Description>{{ control.description }}</Description>
{% endif %}
<SigningProtocol>{{ control.signing_protocol }}</SigningProtocol>
<SigningBehavior>{{ control.signing_behaviour }}</SigningBehavior>
<OriginAccessControlOriginType>{{ control.origin_type }}</OriginAccessControlOriginType>
</OriginAccessControlConfig>
</OriginAccessControl>
"""
LIST_ORIGIN_ACCESS_CONTROl = """<?xml version="1.0"?>
<OriginAccessControlList>
<Items>
{% for control in controls %}
<OriginAccessControlSummary>
<Id>{{ control.id }}</Id>
<Name>{{ control.name }}</Name>
{% if control.description %}
<Description>{{ control.description }}</Description>
{% endif %}
<SigningProtocol>{{ control.signing_protocol }}</SigningProtocol>
<SigningBehavior>{{ control.signing_behaviour }}</SigningBehavior>
<OriginAccessControlOriginType>{{ control.origin_type }}</OriginAccessControlOriginType>
</OriginAccessControlSummary>
{% endfor %}
</Items>
</OriginAccessControlList>
"""

View File

@ -13,4 +13,7 @@ url_paths = {
"{0}/2020-05-31/distribution/(?P<distribution_id>[^/]+)/config$": response.update_distribution,
"{0}/2020-05-31/distribution/(?P<distribution_id>[^/]+)/invalidation": response.invalidation,
"{0}/2020-05-31/tagging$": response.tags,
"{0}/2020-05-31/origin-access-control$": response.origin_access_controls,
"{0}/2020-05-31/origin-access-control/(?P<oac_id>[^/]+)$": response.origin_access_control,
"{0}/2020-05-31/origin-access-control/(?P<oac_id>[^/]+)/config$": response.origin_access_control,
}

View File

@ -85,6 +85,7 @@ cloudformation:
cloudfront:
- TestAccCloudFrontDistributionDataSource_basic
- TestAccCloudFrontDistribution_isIPV6Enabled
- TestAccCloudFrontOriginAccessControl_
cloudtrail:
- TestAccCloudTrailServiceAccount
cloudwatch:

View File

@ -0,0 +1,68 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_cloudfront
@mock_cloudfront
def test_create_origin_access_control():
cf = boto3.client("cloudfront", "us-east-1")
oac_list = cf.list_origin_access_controls()["OriginAccessControlList"]
assert oac_list["Items"] == []
oac_input = {
"Name": "my_oac",
"SigningProtocol": "sigv4",
"SigningBehavior": "always",
"OriginAccessControlOriginType": "s3",
}
resp = cf.create_origin_access_control(OriginAccessControlConfig=oac_input)[
"OriginAccessControl"
]
control_id = resp.pop("Id")
assert control_id is not None
assert resp["OriginAccessControlConfig"] == oac_input
resp = cf.get_origin_access_control(Id=control_id)["OriginAccessControl"]
assert resp.pop("Id") is not None
assert resp["OriginAccessControlConfig"] == oac_input
oac_list = cf.list_origin_access_controls()["OriginAccessControlList"]
assert oac_list["Items"][0].pop("Id") == control_id
assert oac_list["Items"][0] == oac_input
cf.delete_origin_access_control(Id=control_id)
oac_list = cf.list_origin_access_controls()["OriginAccessControlList"]
assert oac_list["Items"] == []
with pytest.raises(ClientError) as exc:
cf.get_origin_access_control(Id=control_id)
err = exc.value.response["Error"]
assert err["Code"] == "NoSuchOriginAccessControl"
assert err["Message"] == "The specified origin access control does not exist."
@mock_cloudfront
def test_update_origin_access_control():
# http://localhost:5000/2020-05-31/origin-access-control/DE53MREVCPIFL/config
cf = boto3.client("cloudfront", "us-east-1")
oac_input = {
"Name": "my_oac",
"SigningProtocol": "sigv4",
"SigningBehavior": "always",
"OriginAccessControlOriginType": "s3",
}
resp = cf.create_origin_access_control(OriginAccessControlConfig=oac_input)[
"OriginAccessControl"
]
control_id = resp.pop("Id")
oac_input["Description"] = "updated"
control = cf.update_origin_access_control(
Id=control_id, OriginAccessControlConfig=oac_input
)["OriginAccessControl"]
assert control["Id"] == control_id
assert control["OriginAccessControlConfig"] == oac_input