diff --git a/moto/cloudfront/exceptions.py b/moto/cloudfront/exceptions.py index 2f4dc203d..268faed53 100644 --- a/moto/cloudfront/exceptions.py +++ b/moto/cloudfront/exceptions.py @@ -1,4 +1,5 @@ from moto.core.exceptions import RESTError +from typing import Any EXCEPTION_RESPONSE = """ @@ -15,57 +16,52 @@ class CloudFrontException(RESTError): code = 400 - def __init__(self, *args, **kwargs): + def __init__(self, error_type: str, message: str, **kwargs: Any): kwargs.setdefault("template", "cferror") self.templates["cferror"] = EXCEPTION_RESPONSE - super().__init__(*args, **kwargs) + super().__init__(error_type, message, **kwargs) class OriginDoesNotExist(CloudFrontException): code = 404 - def __init__(self, **kwargs): + def __init__(self) -> None: super().__init__( "NoSuchOrigin", message="One or more of your origins or origin groups do not exist.", - **kwargs, ) class InvalidOriginServer(CloudFrontException): - def __init__(self, **kwargs): + def __init__(self) -> None: super().__init__( "InvalidOrigin", message="The specified origin server does not exist or is not valid.", - **kwargs, ) class DomainNameNotAnS3Bucket(CloudFrontException): - def __init__(self, **kwargs): + def __init__(self) -> None: super().__init__( "InvalidArgument", message="The parameter Origin DomainName does not refer to a valid S3 bucket.", - **kwargs, ) class DistributionAlreadyExists(CloudFrontException): - def __init__(self, dist_id, **kwargs): + def __init__(self, dist_id: str): super().__init__( "DistributionAlreadyExists", message=f"The caller reference that you are using to create a distribution is associated with another distribution. Already exists: {dist_id}", - **kwargs, ) class InvalidIfMatchVersion(CloudFrontException): - def __init__(self, **kwargs): + def __init__(self) -> None: super().__init__( "InvalidIfMatchVersion", message="The If-Match version is missing or not valid for the resource.", - **kwargs, ) @@ -73,9 +69,7 @@ class NoSuchDistribution(CloudFrontException): code = 404 - def __init__(self, **kwargs): + def __init__(self) -> None: super().__init__( - "NoSuchDistribution", - message="The specified distribution does not exist.", - **kwargs, + "NoSuchDistribution", message="The specified distribution does not exist." ) diff --git a/moto/cloudfront/models.py b/moto/cloudfront/models.py index 9b8d13ab0..09efe1690 100644 --- a/moto/cloudfront/models.py +++ b/moto/cloudfront/models.py @@ -1,6 +1,7 @@ import string from datetime import datetime +from typing import Any, Dict, Iterable, List, Tuple, Optional from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict, iso_8601_datetime_with_milliseconds from moto.moto_api import state_manager @@ -19,28 +20,28 @@ from .exceptions import ( class ActiveTrustedSigners: - def __init__(self): + def __init__(self) -> None: self.enabled = False self.quantity = 0 - self.signers = [] + self.signers: List[Any] = [] class ActiveTrustedKeyGroups: - def __init__(self): + def __init__(self) -> None: self.enabled = False self.quantity = 0 - self.kg_key_pair_ids = [] + self.kg_key_pair_ids: List[Any] = [] class LambdaFunctionAssociation: - def __init__(self): + def __init__(self) -> None: self.arn = "" self.event_type = "" self.include_body = False class ForwardedValues: - def __init__(self, config): + def __init__(self, config: Dict[str, Any]): self.query_string = config.get("QueryString", "false") self.cookie_forward = config.get("Cookies", {}).get("Forward") or "none" self.whitelisted_names = ( @@ -49,17 +50,17 @@ class ForwardedValues: self.whitelisted_names = self.whitelisted_names.get("Name") or [] if isinstance(self.whitelisted_names, str): self.whitelisted_names = [self.whitelisted_names] - self.headers = [] - self.query_string_cache_keys = [] + self.headers: List[Any] = [] + self.query_string_cache_keys: List[Any] = [] class DefaultCacheBehaviour: - def __init__(self, config): + def __init__(self, config: Dict[str, Any]): self.target_origin_id = config["TargetOriginId"] self.trusted_signers_enabled = False - self.trusted_signers = [] + self.trusted_signers: List[Any] = [] self.trusted_key_groups_enabled = False - self.trusted_key_groups = [] + self.trusted_key_groups: List[Any] = [] self.viewer_protocol_policy = config["ViewerProtocolPolicy"] methods = config.get("AllowedMethods", {}) self.allowed_methods = methods.get("Items", {}).get("Method", ["HEAD", "GET"]) @@ -70,8 +71,8 @@ class DefaultCacheBehaviour: ) self.smooth_streaming = config.get("SmoothStreaming") or True self.compress = config.get("Compress", "true").lower() == "true" - self.lambda_function_associations = [] - self.function_associations = [] + self.lambda_function_associations: List[Any] = [] + self.function_associations: List[Any] = [] self.field_level_encryption_id = "" self.forwarded_values = ForwardedValues(config.get("ForwardedValues", {})) self.min_ttl = config.get("MinTTL") or 0 @@ -80,20 +81,20 @@ class DefaultCacheBehaviour: class Logging: - def __init__(self): + def __init__(self) -> None: self.enabled = False self.include_cookies = False class ViewerCertificate: - def __init__(self): + def __init__(self) -> None: self.cloud_front_default_certificate = True self.min_protocol_version = "TLSv1" self.certificate_source = "cloudfront" class CustomOriginConfig: - def __init__(self, config): + def __init__(self, config: Dict[str, Any]): self.http_port = config.get("HTTPPort") self.https_port = config.get("HTTPSPort") self.keep_alive = config.get("OriginKeepaliveTimeout") @@ -106,10 +107,10 @@ class CustomOriginConfig: class Origin: - def __init__(self, origin): + def __init__(self, origin: Dict[str, Any]): self.id = origin["Id"] self.domain_name = origin["DomainName"] - self.custom_headers = [] + self.custom_headers: List[Any] = [] self.s3_access_identity = "" self.custom_origin = None self.origin_shield = origin.get("OriginShield") @@ -130,14 +131,14 @@ class Origin: class GeoRestrictions: - def __init__(self, config): + def __init__(self, config: Dict[str, Any]): config = config.get("GeoRestriction") or {} self._type = config.get("RestrictionType", "none") self.restrictions = (config.get("Items") or {}).get("Location") or [] class DistributionConfig: - def __init__(self, config): + def __init__(self, config: Dict[str, Any]): self.config = config self.aliases = ((config.get("Aliases") or {}).get("Items") or {}).get( "CNAME" @@ -146,8 +147,8 @@ class DistributionConfig: self.default_cache_behavior = DefaultCacheBehaviour( config["DefaultCacheBehavior"] ) - self.cache_behaviors = [] - self.custom_error_responses = [] + self.cache_behaviors: List[Any] = [] + self.custom_error_responses: List[Any] = [] self.logging = Logging() self.enabled = config.get("Enabled") or False self.viewer_certificate = ViewerCertificate() @@ -172,7 +173,7 @@ class DistributionConfig: class Distribution(BaseModel, ManagedState): @staticmethod - def random_id(uppercase=True): + def random_id(uppercase: bool = True) -> str: ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase chars = list(range(10)) + list(ascii_set) resource_id = random.choice(ascii_set) + "".join( @@ -180,7 +181,7 @@ class Distribution(BaseModel, ManagedState): ) return resource_id - def __init__(self, account_id, config): + def __init__(self, account_id: str, config: Dict[str, Any]): # Configured ManagedState super().__init__( "cloudfront::distribution", transitions=[("InProgress", "Deployed")] @@ -193,8 +194,8 @@ class Distribution(BaseModel, ManagedState): self.distribution_config = DistributionConfig(config) self.active_trusted_signers = ActiveTrustedSigners() self.active_trusted_key_groups = ActiveTrustedKeyGroups() - self.origin_groups = [] - self.alias_icp_recordals = [] + self.origin_groups: List[Any] = [] + self.alias_icp_recordals: List[Any] = [] self.last_modified_time = "2021-11-27T10:34:26.802Z" self.in_progress_invalidation_batches = 0 self.has_active_trusted_key_groups = False @@ -202,13 +203,13 @@ class Distribution(BaseModel, ManagedState): self.etag = Distribution.random_id() @property - def location(self): + def location(self) -> str: return f"https://cloudfront.amazonaws.com/2020-05-31/distribution/{self.distribution_id}" class Invalidation(BaseModel): @staticmethod - def random_id(uppercase=True): + def random_id(uppercase: bool = True) -> str: ascii_set = string.ascii_uppercase if uppercase else string.ascii_lowercase chars = list(range(10)) + list(ascii_set) resource_id = random.choice(ascii_set) + "".join( @@ -216,7 +217,9 @@ class Invalidation(BaseModel): ) return resource_id - def __init__(self, distribution, paths, caller_ref): + def __init__( + self, distribution: Distribution, paths: Dict[str, Any], caller_ref: str + ): self.invalidation_id = Invalidation.random_id() self.create_time = iso_8601_datetime_with_milliseconds(datetime.now()) self.distribution = distribution @@ -226,22 +229,24 @@ class Invalidation(BaseModel): self.caller_ref = caller_ref @property - def location(self): + def location(self) -> str: return self.distribution.location + f"/invalidation/{self.invalidation_id}" class CloudFrontBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.distributions = dict() - self.invalidations = dict() + self.distributions: Dict[str, Distribution] = dict() + self.invalidations: Dict[str, List[Invalidation]] = dict() self.tagger = TaggingService() state_manager.register_default_transition( "cloudfront::distribution", transition={"progression": "manual", "times": 1} ) - def create_distribution(self, distribution_config, tags): + def create_distribution( + self, distribution_config: Dict[str, Any], tags: List[Dict[str, str]] + ) -> Tuple[Distribution, str, str]: """ Not all configuration options are supported yet. Please raise an issue if we're not persisting/returning the correct attributes for your @@ -250,24 +255,26 @@ class CloudFrontBackend(BaseBackend): # We'll always call dist_with_tags, as the incoming request is the same return self.create_distribution_with_tags(distribution_config, tags) - def create_distribution_with_tags(self, distribution_config, tags): + def create_distribution_with_tags( + self, distribution_config: Dict[str, Any], tags: List[Dict[str, str]] + ) -> Tuple[Distribution, str, str]: dist = Distribution(self.account_id, distribution_config) caller_reference = dist.distribution_config.caller_reference existing_dist = self._distribution_with_caller_reference(caller_reference) - if existing_dist: + if existing_dist is not None: raise DistributionAlreadyExists(existing_dist.distribution_id) self.distributions[dist.distribution_id] = dist self.tagger.tag_resource(dist.arn, tags) return dist, dist.location, dist.etag - def get_distribution(self, distribution_id): + def get_distribution(self, distribution_id: str) -> Tuple[Distribution, str]: if distribution_id not in self.distributions: raise NoSuchDistribution dist = self.distributions[distribution_id] dist.advance() return dist, dist.etag - def delete_distribution(self, distribution_id, if_match): + def delete_distribution(self, distribution_id: str, if_match: bool) -> None: """ The IfMatch-value is ignored - any value is considered valid. Calling this function without a value is invalid, per AWS' behaviour @@ -278,7 +285,7 @@ class CloudFrontBackend(BaseBackend): raise NoSuchDistribution del self.distributions[distribution_id] - def list_distributions(self): + def list_distributions(self) -> Iterable[Distribution]: """ Pagination is not supported yet. """ @@ -286,14 +293,18 @@ class CloudFrontBackend(BaseBackend): dist.advance() return self.distributions.values() - def _distribution_with_caller_reference(self, reference): + def _distribution_with_caller_reference( + self, reference: str + ) -> Optional[Distribution]: for dist in self.distributions.values(): config = dist.distribution_config if config.caller_reference == reference: return dist - return False + return None - def update_distribution(self, dist_config, _id, if_match): + def update_distribution( + self, dist_config: Dict[str, Any], _id: str, if_match: bool + ) -> Tuple[Distribution, str, str]: """ The IfMatch-value is ignored - any value is considered valid. Calling this function without a value is invalid, per AWS' behaviour @@ -313,7 +324,9 @@ class CloudFrontBackend(BaseBackend): dist.advance() return dist, dist.location, dist.etag - def create_invalidation(self, dist_id, paths, caller_ref): + def create_invalidation( + self, dist_id: str, paths: Dict[str, Any], caller_ref: str + ) -> Invalidation: dist, _ = self.get_distribution(dist_id) invalidation = Invalidation(dist, paths, caller_ref) try: @@ -323,13 +336,13 @@ class CloudFrontBackend(BaseBackend): return invalidation - def list_invalidations(self, dist_id): + def list_invalidations(self, dist_id: str) -> Iterable[Invalidation]: """ Pagination is not yet implemented """ - return self.invalidations.get(dist_id) or {} + return self.invalidations.get(dist_id) or [] - def list_tags_for_resource(self, resource): + def list_tags_for_resource(self, resource: str) -> Dict[str, List[Dict[str, str]]]: return self.tagger.list_tags_for_resource(resource) diff --git a/moto/cloudfront/responses.py b/moto/cloudfront/responses.py index 4038f1fa3..c3aa45d22 100644 --- a/moto/cloudfront/responses.py +++ b/moto/cloudfront/responses.py @@ -1,54 +1,55 @@ import xmltodict +from typing import Any, Dict from urllib.parse import unquote -from moto.core.responses import BaseResponse -from .models import cloudfront_backends +from moto.core.responses import BaseResponse, TYPE_RESPONSE +from .models import cloudfront_backends, CloudFrontBackend XMLNS = "http://cloudfront.amazonaws.com/doc/2020-05-31/" class CloudFrontResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="cloudfront") - def _get_xml_body(self): + def _get_xml_body(self) -> Dict[str, Any]: return xmltodict.parse(self.body, dict_constructor=dict) @property - def backend(self): + def backend(self) -> CloudFrontBackend: return cloudfront_backends[self.current_account]["global"] - def distributions(self, request, full_url, headers): + def distributions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_distribution() if request.method == "GET": return self.list_distributions() - def invalidation(self, request, full_url, headers): + def invalidation(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self.create_invalidation() if request.method == "GET": return self.list_invalidations() - def tags(self, request, full_url, headers): + def tags(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.list_tags_for_resource() - def create_distribution(self): + def create_distribution(self) -> TYPE_RESPONSE: params = self._get_xml_body() if "DistributionConfigWithTags" in params: config = params.get("DistributionConfigWithTags") - tags = (config.get("Tags", {}).get("Items") or {}).get("Tag", []) + tags = (config.get("Tags", {}).get("Items") or {}).get("Tag", []) # type: ignore[union-attr] if not isinstance(tags, list): tags = [tags] else: config = params tags = [] - distribution_config = config.get("DistributionConfig") + distribution_config = config.get("DistributionConfig") # type: ignore[union-attr] distribution, location, e_tag = self.backend.create_distribution( distribution_config=distribution_config, tags=tags, @@ -58,13 +59,13 @@ class CloudFrontResponse(BaseResponse): headers = {"ETag": e_tag, "Location": location} return 200, headers, response - def list_distributions(self): + def list_distributions(self) -> TYPE_RESPONSE: distributions = self.backend.list_distributions() template = self.response_template(LIST_TEMPLATE) response = template.render(distributions=distributions) return 200, {}, response - def individual_distribution(self, request, full_url, headers): + def individual_distribution(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) distribution_id = full_url.split("/")[-1] if request.method == "DELETE": @@ -77,7 +78,9 @@ class CloudFrontResponse(BaseResponse): response = template.render(distribution=dist, xmlns=XMLNS) return 200, {"ETag": etag}, response - def update_distribution(self, request, full_url, headers): + def update_distribution( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) params = self._get_xml_body() distribution_config = params.get("DistributionConfig") @@ -85,7 +88,7 @@ class CloudFrontResponse(BaseResponse): if_match = headers["If-Match"] dist, location, e_tag = self.backend.update_distribution( - dist_config=distribution_config, + dist_config=distribution_config, # type: ignore[arg-type] _id=dist_id, if_match=if_match, ) @@ -94,19 +97,19 @@ class CloudFrontResponse(BaseResponse): headers = {"ETag": e_tag, "Location": location} return 200, headers, response - def create_invalidation(self): + def create_invalidation(self) -> TYPE_RESPONSE: dist_id = self.path.split("/")[-2] params = self._get_xml_body()["InvalidationBatch"] paths = ((params.get("Paths") or {}).get("Items") or {}).get("Path") or [] caller_ref = params.get("CallerReference") - invalidation = self.backend.create_invalidation(dist_id, paths, caller_ref) + invalidation = self.backend.create_invalidation(dist_id, paths, caller_ref) # type: ignore[arg-type] template = self.response_template(CREATE_INVALIDATION_TEMPLATE) response = template.render(invalidation=invalidation, xmlns=XMLNS) return 200, {"Location": invalidation.location}, response - def list_invalidations(self): + def list_invalidations(self) -> TYPE_RESPONSE: dist_id = self.path.split("/")[-2] invalidations = self.backend.list_invalidations(dist_id) template = self.response_template(INVALIDATIONS_TEMPLATE) @@ -114,7 +117,7 @@ class CloudFrontResponse(BaseResponse): return 200, {}, response - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> TYPE_RESPONSE: resource = unquote(self._get_param("Resource")) tags = self.backend.list_tags_for_resource(resource=resource)["Tags"] template = self.response_template(TAGS_TEMPLATE) diff --git a/setup.cfg b/setup.cfg index 3391d9fdf..0085dd39e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/ce,moto/cloudformation +files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront show_column_numbers=True show_error_codes = True disable_error_code=abstract