TechDebt: MyPy CloudFront (#5612)

This commit is contained in:
Bert Blommers 2022-10-29 13:26:49 +00:00 committed by GitHub
parent a370362143
commit b17a792f1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 82 deletions

View File

@ -1,4 +1,5 @@
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from typing import Any
EXCEPTION_RESPONSE = """<?xml version="1.0"?> EXCEPTION_RESPONSE = """<?xml version="1.0"?>
<ErrorResponse xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/"> <ErrorResponse xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/">
@ -15,57 +16,52 @@ class CloudFrontException(RESTError):
code = 400 code = 400
def __init__(self, *args, **kwargs): def __init__(self, error_type: str, message: str, **kwargs: Any):
kwargs.setdefault("template", "cferror") kwargs.setdefault("template", "cferror")
self.templates["cferror"] = EXCEPTION_RESPONSE self.templates["cferror"] = EXCEPTION_RESPONSE
super().__init__(*args, **kwargs) super().__init__(error_type, message, **kwargs)
class OriginDoesNotExist(CloudFrontException): class OriginDoesNotExist(CloudFrontException):
code = 404 code = 404
def __init__(self, **kwargs): def __init__(self) -> None:
super().__init__( super().__init__(
"NoSuchOrigin", "NoSuchOrigin",
message="One or more of your origins or origin groups do not exist.", message="One or more of your origins or origin groups do not exist.",
**kwargs,
) )
class InvalidOriginServer(CloudFrontException): class InvalidOriginServer(CloudFrontException):
def __init__(self, **kwargs): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidOrigin", "InvalidOrigin",
message="The specified origin server does not exist or is not valid.", message="The specified origin server does not exist or is not valid.",
**kwargs,
) )
class DomainNameNotAnS3Bucket(CloudFrontException): class DomainNameNotAnS3Bucket(CloudFrontException):
def __init__(self, **kwargs): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidArgument", "InvalidArgument",
message="The parameter Origin DomainName does not refer to a valid S3 bucket.", message="The parameter Origin DomainName does not refer to a valid S3 bucket.",
**kwargs,
) )
class DistributionAlreadyExists(CloudFrontException): class DistributionAlreadyExists(CloudFrontException):
def __init__(self, dist_id, **kwargs): def __init__(self, dist_id: str):
super().__init__( super().__init__(
"DistributionAlreadyExists", "DistributionAlreadyExists",
message=f"The caller reference that you are using to create a distribution is associated with another distribution. Already exists: {dist_id}", message=f"The caller reference that you are using to create a distribution is associated with another distribution. Already exists: {dist_id}",
**kwargs,
) )
class InvalidIfMatchVersion(CloudFrontException): class InvalidIfMatchVersion(CloudFrontException):
def __init__(self, **kwargs): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidIfMatchVersion", "InvalidIfMatchVersion",
message="The If-Match version is missing or not valid for the resource.", message="The If-Match version is missing or not valid for the resource.",
**kwargs,
) )
@ -73,9 +69,7 @@ class NoSuchDistribution(CloudFrontException):
code = 404 code = 404
def __init__(self, **kwargs): def __init__(self) -> None:
super().__init__( super().__init__(
"NoSuchDistribution", "NoSuchDistribution", message="The specified distribution does not exist."
message="The specified distribution does not exist.",
**kwargs,
) )

View File

@ -1,6 +1,7 @@
import string import string
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Iterable, List, Tuple, Optional
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import BackendDict, iso_8601_datetime_with_milliseconds from moto.core.utils import BackendDict, iso_8601_datetime_with_milliseconds
from moto.moto_api import state_manager from moto.moto_api import state_manager
@ -19,28 +20,28 @@ from .exceptions import (
class ActiveTrustedSigners: class ActiveTrustedSigners:
def __init__(self): def __init__(self) -> None:
self.enabled = False self.enabled = False
self.quantity = 0 self.quantity = 0
self.signers = [] self.signers: List[Any] = []
class ActiveTrustedKeyGroups: class ActiveTrustedKeyGroups:
def __init__(self): def __init__(self) -> None:
self.enabled = False self.enabled = False
self.quantity = 0 self.quantity = 0
self.kg_key_pair_ids = [] self.kg_key_pair_ids: List[Any] = []
class LambdaFunctionAssociation: class LambdaFunctionAssociation:
def __init__(self): def __init__(self) -> None:
self.arn = "" self.arn = ""
self.event_type = "" self.event_type = ""
self.include_body = False self.include_body = False
class ForwardedValues: class ForwardedValues:
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
self.query_string = config.get("QueryString", "false") self.query_string = config.get("QueryString", "false")
self.cookie_forward = config.get("Cookies", {}).get("Forward") or "none" self.cookie_forward = config.get("Cookies", {}).get("Forward") or "none"
self.whitelisted_names = ( self.whitelisted_names = (
@ -49,17 +50,17 @@ class ForwardedValues:
self.whitelisted_names = self.whitelisted_names.get("Name") or [] self.whitelisted_names = self.whitelisted_names.get("Name") or []
if isinstance(self.whitelisted_names, str): if isinstance(self.whitelisted_names, str):
self.whitelisted_names = [self.whitelisted_names] self.whitelisted_names = [self.whitelisted_names]
self.headers = [] self.headers: List[Any] = []
self.query_string_cache_keys = [] self.query_string_cache_keys: List[Any] = []
class DefaultCacheBehaviour: class DefaultCacheBehaviour:
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
self.target_origin_id = config["TargetOriginId"] self.target_origin_id = config["TargetOriginId"]
self.trusted_signers_enabled = False self.trusted_signers_enabled = False
self.trusted_signers = [] self.trusted_signers: List[Any] = []
self.trusted_key_groups_enabled = False self.trusted_key_groups_enabled = False
self.trusted_key_groups = [] self.trusted_key_groups: List[Any] = []
self.viewer_protocol_policy = config["ViewerProtocolPolicy"] self.viewer_protocol_policy = config["ViewerProtocolPolicy"]
methods = config.get("AllowedMethods", {}) methods = config.get("AllowedMethods", {})
self.allowed_methods = methods.get("Items", {}).get("Method", ["HEAD", "GET"]) self.allowed_methods = methods.get("Items", {}).get("Method", ["HEAD", "GET"])
@ -70,8 +71,8 @@ class DefaultCacheBehaviour:
) )
self.smooth_streaming = config.get("SmoothStreaming") or True self.smooth_streaming = config.get("SmoothStreaming") or True
self.compress = config.get("Compress", "true").lower() == "true" self.compress = config.get("Compress", "true").lower() == "true"
self.lambda_function_associations = [] self.lambda_function_associations: List[Any] = []
self.function_associations = [] self.function_associations: List[Any] = []
self.field_level_encryption_id = "" self.field_level_encryption_id = ""
self.forwarded_values = ForwardedValues(config.get("ForwardedValues", {})) self.forwarded_values = ForwardedValues(config.get("ForwardedValues", {}))
self.min_ttl = config.get("MinTTL") or 0 self.min_ttl = config.get("MinTTL") or 0
@ -80,20 +81,20 @@ class DefaultCacheBehaviour:
class Logging: class Logging:
def __init__(self): def __init__(self) -> None:
self.enabled = False self.enabled = False
self.include_cookies = False self.include_cookies = False
class ViewerCertificate: class ViewerCertificate:
def __init__(self): def __init__(self) -> None:
self.cloud_front_default_certificate = True self.cloud_front_default_certificate = True
self.min_protocol_version = "TLSv1" self.min_protocol_version = "TLSv1"
self.certificate_source = "cloudfront" self.certificate_source = "cloudfront"
class CustomOriginConfig: class CustomOriginConfig:
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
self.http_port = config.get("HTTPPort") self.http_port = config.get("HTTPPort")
self.https_port = config.get("HTTPSPort") self.https_port = config.get("HTTPSPort")
self.keep_alive = config.get("OriginKeepaliveTimeout") self.keep_alive = config.get("OriginKeepaliveTimeout")
@ -106,10 +107,10 @@ class CustomOriginConfig:
class Origin: class Origin:
def __init__(self, origin): def __init__(self, origin: Dict[str, Any]):
self.id = origin["Id"] self.id = origin["Id"]
self.domain_name = origin["DomainName"] self.domain_name = origin["DomainName"]
self.custom_headers = [] self.custom_headers: List[Any] = []
self.s3_access_identity = "" self.s3_access_identity = ""
self.custom_origin = None self.custom_origin = None
self.origin_shield = origin.get("OriginShield") self.origin_shield = origin.get("OriginShield")
@ -130,14 +131,14 @@ class Origin:
class GeoRestrictions: class GeoRestrictions:
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
config = config.get("GeoRestriction") or {} config = config.get("GeoRestriction") or {}
self._type = config.get("RestrictionType", "none") self._type = config.get("RestrictionType", "none")
self.restrictions = (config.get("Items") or {}).get("Location") or [] self.restrictions = (config.get("Items") or {}).get("Location") or []
class DistributionConfig: class DistributionConfig:
def __init__(self, config): def __init__(self, config: Dict[str, Any]):
self.config = config self.config = config
self.aliases = ((config.get("Aliases") or {}).get("Items") or {}).get( self.aliases = ((config.get("Aliases") or {}).get("Items") or {}).get(
"CNAME" "CNAME"
@ -146,8 +147,8 @@ class DistributionConfig:
self.default_cache_behavior = DefaultCacheBehaviour( self.default_cache_behavior = DefaultCacheBehaviour(
config["DefaultCacheBehavior"] config["DefaultCacheBehavior"]
) )
self.cache_behaviors = [] self.cache_behaviors: List[Any] = []
self.custom_error_responses = [] self.custom_error_responses: List[Any] = []
self.logging = Logging() self.logging = Logging()
self.enabled = config.get("Enabled") or False self.enabled = config.get("Enabled") or False
self.viewer_certificate = ViewerCertificate() self.viewer_certificate = ViewerCertificate()
@ -172,7 +173,7 @@ class DistributionConfig:
class Distribution(BaseModel, ManagedState): class Distribution(BaseModel, ManagedState):
@staticmethod @staticmethod
def random_id(uppercase=True): def random_id(uppercase: bool = True) -> str:
ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase
chars = list(range(10)) + list(ascii_set) chars = list(range(10)) + list(ascii_set)
resource_id = random.choice(ascii_set) + "".join( resource_id = random.choice(ascii_set) + "".join(
@ -180,7 +181,7 @@ class Distribution(BaseModel, ManagedState):
) )
return resource_id return resource_id
def __init__(self, account_id, config): def __init__(self, account_id: str, config: Dict[str, Any]):
# Configured ManagedState # Configured ManagedState
super().__init__( super().__init__(
"cloudfront::distribution", transitions=[("InProgress", "Deployed")] "cloudfront::distribution", transitions=[("InProgress", "Deployed")]
@ -193,8 +194,8 @@ class Distribution(BaseModel, ManagedState):
self.distribution_config = DistributionConfig(config) self.distribution_config = DistributionConfig(config)
self.active_trusted_signers = ActiveTrustedSigners() self.active_trusted_signers = ActiveTrustedSigners()
self.active_trusted_key_groups = ActiveTrustedKeyGroups() self.active_trusted_key_groups = ActiveTrustedKeyGroups()
self.origin_groups = [] self.origin_groups: List[Any] = []
self.alias_icp_recordals = [] self.alias_icp_recordals: List[Any] = []
self.last_modified_time = "2021-11-27T10:34:26.802Z" self.last_modified_time = "2021-11-27T10:34:26.802Z"
self.in_progress_invalidation_batches = 0 self.in_progress_invalidation_batches = 0
self.has_active_trusted_key_groups = False self.has_active_trusted_key_groups = False
@ -202,13 +203,13 @@ class Distribution(BaseModel, ManagedState):
self.etag = Distribution.random_id() self.etag = Distribution.random_id()
@property @property
def location(self): def location(self) -> str:
return f"https://cloudfront.amazonaws.com/2020-05-31/distribution/{self.distribution_id}" return f"https://cloudfront.amazonaws.com/2020-05-31/distribution/{self.distribution_id}"
class Invalidation(BaseModel): class Invalidation(BaseModel):
@staticmethod @staticmethod
def random_id(uppercase=True): def random_id(uppercase: bool = True) -> str:
ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase
chars = list(range(10)) + list(ascii_set) chars = list(range(10)) + list(ascii_set)
resource_id = random.choice(ascii_set) + "".join( resource_id = random.choice(ascii_set) + "".join(
@ -216,7 +217,9 @@ class Invalidation(BaseModel):
) )
return resource_id return resource_id
def __init__(self, distribution, paths, caller_ref): def __init__(
self, distribution: Distribution, paths: Dict[str, Any], caller_ref: str
):
self.invalidation_id = Invalidation.random_id() self.invalidation_id = Invalidation.random_id()
self.create_time = iso_8601_datetime_with_milliseconds(datetime.now()) self.create_time = iso_8601_datetime_with_milliseconds(datetime.now())
self.distribution = distribution self.distribution = distribution
@ -226,22 +229,24 @@ class Invalidation(BaseModel):
self.caller_ref = caller_ref self.caller_ref = caller_ref
@property @property
def location(self): def location(self) -> str:
return self.distribution.location + f"/invalidation/{self.invalidation_id}" return self.distribution.location + f"/invalidation/{self.invalidation_id}"
class CloudFrontBackend(BaseBackend): class CloudFrontBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.distributions = dict() self.distributions: Dict[str, Distribution] = dict()
self.invalidations = dict() self.invalidations: Dict[str, List[Invalidation]] = dict()
self.tagger = TaggingService() self.tagger = TaggingService()
state_manager.register_default_transition( state_manager.register_default_transition(
"cloudfront::distribution", transition={"progression": "manual", "times": 1} "cloudfront::distribution", transition={"progression": "manual", "times": 1}
) )
def create_distribution(self, distribution_config, tags): def create_distribution(
self, distribution_config: Dict[str, Any], tags: List[Dict[str, str]]
) -> Tuple[Distribution, str, str]:
""" """
Not all configuration options are supported yet. Please raise an issue if Not all configuration options are supported yet. Please raise an issue if
we're not persisting/returning the correct attributes for your we're not persisting/returning the correct attributes for your
@ -250,24 +255,26 @@ class CloudFrontBackend(BaseBackend):
# We'll always call dist_with_tags, as the incoming request is the same # We'll always call dist_with_tags, as the incoming request is the same
return self.create_distribution_with_tags(distribution_config, tags) return self.create_distribution_with_tags(distribution_config, tags)
def create_distribution_with_tags(self, distribution_config, tags): def create_distribution_with_tags(
self, distribution_config: Dict[str, Any], tags: List[Dict[str, str]]
) -> Tuple[Distribution, str, str]:
dist = Distribution(self.account_id, distribution_config) dist = Distribution(self.account_id, distribution_config)
caller_reference = dist.distribution_config.caller_reference caller_reference = dist.distribution_config.caller_reference
existing_dist = self._distribution_with_caller_reference(caller_reference) existing_dist = self._distribution_with_caller_reference(caller_reference)
if existing_dist: if existing_dist is not None:
raise DistributionAlreadyExists(existing_dist.distribution_id) raise DistributionAlreadyExists(existing_dist.distribution_id)
self.distributions[dist.distribution_id] = dist self.distributions[dist.distribution_id] = dist
self.tagger.tag_resource(dist.arn, tags) self.tagger.tag_resource(dist.arn, tags)
return dist, dist.location, dist.etag return dist, dist.location, dist.etag
def get_distribution(self, distribution_id): def get_distribution(self, distribution_id: str) -> Tuple[Distribution, str]:
if distribution_id not in self.distributions: if distribution_id not in self.distributions:
raise NoSuchDistribution raise NoSuchDistribution
dist = self.distributions[distribution_id] dist = self.distributions[distribution_id]
dist.advance() dist.advance()
return dist, dist.etag return dist, dist.etag
def delete_distribution(self, distribution_id, if_match): def delete_distribution(self, distribution_id: str, if_match: bool) -> None:
""" """
The IfMatch-value is ignored - any value is considered valid. The IfMatch-value is ignored - any value is considered valid.
Calling this function without a value is invalid, per AWS' behaviour Calling this function without a value is invalid, per AWS' behaviour
@ -278,7 +285,7 @@ class CloudFrontBackend(BaseBackend):
raise NoSuchDistribution raise NoSuchDistribution
del self.distributions[distribution_id] del self.distributions[distribution_id]
def list_distributions(self): def list_distributions(self) -> Iterable[Distribution]:
""" """
Pagination is not supported yet. Pagination is not supported yet.
""" """
@ -286,14 +293,18 @@ class CloudFrontBackend(BaseBackend):
dist.advance() dist.advance()
return self.distributions.values() return self.distributions.values()
def _distribution_with_caller_reference(self, reference): def _distribution_with_caller_reference(
self, reference: str
) -> Optional[Distribution]:
for dist in self.distributions.values(): for dist in self.distributions.values():
config = dist.distribution_config config = dist.distribution_config
if config.caller_reference == reference: if config.caller_reference == reference:
return dist return dist
return False return None
def update_distribution(self, dist_config, _id, if_match): def update_distribution(
self, dist_config: Dict[str, Any], _id: str, if_match: bool
) -> Tuple[Distribution, str, str]:
""" """
The IfMatch-value is ignored - any value is considered valid. The IfMatch-value is ignored - any value is considered valid.
Calling this function without a value is invalid, per AWS' behaviour Calling this function without a value is invalid, per AWS' behaviour
@ -313,7 +324,9 @@ class CloudFrontBackend(BaseBackend):
dist.advance() dist.advance()
return dist, dist.location, dist.etag return dist, dist.location, dist.etag
def create_invalidation(self, dist_id, paths, caller_ref): def create_invalidation(
self, dist_id: str, paths: Dict[str, Any], caller_ref: str
) -> Invalidation:
dist, _ = self.get_distribution(dist_id) dist, _ = self.get_distribution(dist_id)
invalidation = Invalidation(dist, paths, caller_ref) invalidation = Invalidation(dist, paths, caller_ref)
try: try:
@ -323,13 +336,13 @@ class CloudFrontBackend(BaseBackend):
return invalidation return invalidation
def list_invalidations(self, dist_id): def list_invalidations(self, dist_id: str) -> Iterable[Invalidation]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
return self.invalidations.get(dist_id) or {} return self.invalidations.get(dist_id) or []
def list_tags_for_resource(self, resource): def list_tags_for_resource(self, resource: str) -> Dict[str, List[Dict[str, str]]]:
return self.tagger.list_tags_for_resource(resource) return self.tagger.list_tags_for_resource(resource)

View File

@ -1,54 +1,55 @@
import xmltodict import xmltodict
from typing import Any, Dict
from urllib.parse import unquote from urllib.parse import unquote
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse, TYPE_RESPONSE
from .models import cloudfront_backends from .models import cloudfront_backends, CloudFrontBackend
XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/"
class CloudFrontResponse(BaseResponse): class CloudFrontResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="cloudfront") super().__init__(service_name="cloudfront")
def _get_xml_body(self): def _get_xml_body(self) -> Dict[str, Any]:
return xmltodict.parse(self.body, dict_constructor=dict) return xmltodict.parse(self.body, dict_constructor=dict)
@property @property
def backend(self): def backend(self) -> CloudFrontBackend:
return cloudfront_backends[self.current_account]["global"] return cloudfront_backends[self.current_account]["global"]
def distributions(self, request, full_url, headers): def distributions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_distribution() return self.create_distribution()
if request.method == "GET": if request.method == "GET":
return self.list_distributions() return self.list_distributions()
def invalidation(self, request, full_url, headers): def invalidation(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
return self.create_invalidation() return self.create_invalidation()
if request.method == "GET": if request.method == "GET":
return self.list_invalidations() return self.list_invalidations()
def tags(self, request, full_url, headers): def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self.list_tags_for_resource() return self.list_tags_for_resource()
def create_distribution(self): def create_distribution(self) -> TYPE_RESPONSE:
params = self._get_xml_body() params = self._get_xml_body()
if "DistributionConfigWithTags" in params: if "DistributionConfigWithTags" in params:
config = params.get("DistributionConfigWithTags") config = params.get("DistributionConfigWithTags")
tags = (config.get("Tags", {}).get("Items") or {}).get("Tag", []) tags = (config.get("Tags", {}).get("Items") or {}).get("Tag", []) # type: ignore[union-attr]
if not isinstance(tags, list): if not isinstance(tags, list):
tags = [tags] tags = [tags]
else: else:
config = params config = params
tags = [] tags = []
distribution_config = config.get("DistributionConfig") distribution_config = config.get("DistributionConfig") # type: ignore[union-attr]
distribution, location, e_tag = self.backend.create_distribution( distribution, location, e_tag = self.backend.create_distribution(
distribution_config=distribution_config, distribution_config=distribution_config,
tags=tags, tags=tags,
@ -58,13 +59,13 @@ class CloudFrontResponse(BaseResponse):
headers = {"ETag": e_tag, "Location": location} headers = {"ETag": e_tag, "Location": location}
return 200, headers, response return 200, headers, response
def list_distributions(self): def list_distributions(self) -> TYPE_RESPONSE:
distributions = self.backend.list_distributions() distributions = self.backend.list_distributions()
template = self.response_template(LIST_TEMPLATE) template = self.response_template(LIST_TEMPLATE)
response = template.render(distributions=distributions) response = template.render(distributions=distributions)
return 200, {}, response return 200, {}, response
def individual_distribution(self, request, full_url, headers): def individual_distribution(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
distribution_id = full_url.split("/")[-1] distribution_id = full_url.split("/")[-1]
if request.method == "DELETE": if request.method == "DELETE":
@ -77,7 +78,9 @@ class CloudFrontResponse(BaseResponse):
response = template.render(distribution=dist, xmlns=XMLNS) response = template.render(distribution=dist, xmlns=XMLNS)
return 200, {"ETag": etag}, response return 200, {"ETag": etag}, response
def update_distribution(self, request, full_url, headers): def update_distribution(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
params = self._get_xml_body() params = self._get_xml_body()
distribution_config = params.get("DistributionConfig") distribution_config = params.get("DistributionConfig")
@ -85,7 +88,7 @@ class CloudFrontResponse(BaseResponse):
if_match = headers["If-Match"] if_match = headers["If-Match"]
dist, location, e_tag = self.backend.update_distribution( dist, location, e_tag = self.backend.update_distribution(
dist_config=distribution_config, dist_config=distribution_config, # type: ignore[arg-type]
_id=dist_id, _id=dist_id,
if_match=if_match, if_match=if_match,
) )
@ -94,19 +97,19 @@ class CloudFrontResponse(BaseResponse):
headers = {"ETag": e_tag, "Location": location} headers = {"ETag": e_tag, "Location": location}
return 200, headers, response return 200, headers, response
def create_invalidation(self): def create_invalidation(self) -> TYPE_RESPONSE:
dist_id = self.path.split("/")[-2] dist_id = self.path.split("/")[-2]
params = self._get_xml_body()["InvalidationBatch"] params = self._get_xml_body()["InvalidationBatch"]
paths = ((params.get("Paths") or {}).get("Items") or {}).get("Path") or [] paths = ((params.get("Paths") or {}).get("Items") or {}).get("Path") or []
caller_ref = params.get("CallerReference") caller_ref = params.get("CallerReference")
invalidation = self.backend.create_invalidation(dist_id, paths, caller_ref) invalidation = self.backend.create_invalidation(dist_id, paths, caller_ref) # type: ignore[arg-type]
template = self.response_template(CREATE_INVALIDATION_TEMPLATE) template = self.response_template(CREATE_INVALIDATION_TEMPLATE)
response = template.render(invalidation=invalidation, xmlns=XMLNS) response = template.render(invalidation=invalidation, xmlns=XMLNS)
return 200, {"Location": invalidation.location}, response return 200, {"Location": invalidation.location}, response
def list_invalidations(self): def list_invalidations(self) -> TYPE_RESPONSE:
dist_id = self.path.split("/")[-2] dist_id = self.path.split("/")[-2]
invalidations = self.backend.list_invalidations(dist_id) invalidations = self.backend.list_invalidations(dist_id)
template = self.response_template(INVALIDATIONS_TEMPLATE) template = self.response_template(INVALIDATIONS_TEMPLATE)
@ -114,7 +117,7 @@ class CloudFrontResponse(BaseResponse):
return 200, {}, response return 200, {}, response
def list_tags_for_resource(self): def list_tags_for_resource(self) -> TYPE_RESPONSE:
resource = unquote(self._get_param("Resource")) resource = unquote(self._get_param("Resource"))
tags = self.backend.list_tags_for_resource(resource=resource)["Tags"] tags = self.backend.list_tags_for_resource(resource=resource)["Tags"]
template = self.response_template(TAGS_TEMPLATE) template = self.response_template(TAGS_TEMPLATE)

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/ce,moto/cloudformation files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract