Techdebt: MyPy R (#6229)

This commit is contained in:
Bert Blommers 2023-04-19 10:25:48 +00:00 committed by GitHub
parent 3e21ddd606
commit f1286506be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 434 additions and 340 deletions

View File

@ -1,11 +1,12 @@
"""Exceptions raised by the Route53 service.""" """Exceptions raised by the Route53 service."""
from typing import Any
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
class Route53ClientError(RESTError): class Route53ClientError(RESTError):
"""Base class for Route53 errors.""" """Base class for Route53 errors."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "wrapped_single_error") kwargs.setdefault("template", "wrapped_single_error")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -20,9 +21,7 @@ class InvalidInput(Route53ClientError):
class InvalidCloudWatchArn(InvalidInput): class InvalidCloudWatchArn(InvalidInput):
def __init__( def __init__(self) -> None:
self,
):
message = "The ARN for the CloudWatch Logs log group is invalid" message = "The ARN for the CloudWatch Logs log group is invalid"
super().__init__(message) super().__init__(message)
@ -41,7 +40,7 @@ class InvalidPaginationToken(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = ( message = (
"Route 53 can't get the next page of query logging configurations " "Route 53 can't get the next page of query logging configurations "
"because the specified value for NextToken is invalid." "because the specified value for NextToken is invalid."
@ -54,7 +53,7 @@ class InvalidVPCId(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = "Invalid or missing VPC Id." message = "Invalid or missing VPC Id."
super().__init__("InvalidVPCId", message) super().__init__("InvalidVPCId", message)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -65,7 +64,7 @@ class NoSuchCloudWatchLogsLogGroup(Route53ClientError):
code = 404 code = 404
def __init__(self): def __init__(self) -> None:
message = "The specified CloudWatch Logs log group doesn't exist." message = "The specified CloudWatch Logs log group doesn't exist."
super().__init__("NoSuchCloudWatchLogsLogGroup", message) super().__init__("NoSuchCloudWatchLogsLogGroup", message)
@ -75,7 +74,7 @@ class NoSuchHostedZone(Route53ClientError):
code = 404 code = 404
def __init__(self, host_zone_id): def __init__(self, host_zone_id: str):
message = f"No hosted zone found with ID: {host_zone_id}" message = f"No hosted zone found with ID: {host_zone_id}"
super().__init__("NoSuchHostedZone", message) super().__init__("NoSuchHostedZone", message)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -86,7 +85,7 @@ class NoSuchHealthCheck(Route53ClientError):
code = 404 code = 404
def __init__(self, health_check_id): def __init__(self, health_check_id: str):
message = f"A health check with id {health_check_id} does not exist." message = f"A health check with id {health_check_id} does not exist."
super().__init__("NoSuchHealthCheck", message) super().__init__("NoSuchHealthCheck", message)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -97,7 +96,7 @@ class HostedZoneNotEmpty(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = ( message = (
"The hosted zone contains resource records that are not SOA or NS records." "The hosted zone contains resource records that are not SOA or NS records."
) )
@ -110,7 +109,7 @@ class PublicZoneVPCAssociation(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = "You're trying to associate a VPC with a public hosted zone. Amazon Route 53 doesn't support associating a VPC with a public hosted zone." message = "You're trying to associate a VPC with a public hosted zone. Amazon Route 53 doesn't support associating a VPC with a public hosted zone."
super().__init__("PublicZoneVPCAssociation", message) super().__init__("PublicZoneVPCAssociation", message)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -121,7 +120,7 @@ class LastVPCAssociation(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = "The VPC that you're trying to disassociate from the private hosted zone is the last VPC that is associated with the hosted zone. Amazon Route 53 doesn't support disassociating the last VPC from a hosted zone." message = "The VPC that you're trying to disassociate from the private hosted zone is the last VPC that is associated with the hosted zone. Amazon Route 53 doesn't support disassociating the last VPC from a hosted zone."
super().__init__("LastVPCAssociation", message) super().__init__("LastVPCAssociation", message)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -132,7 +131,7 @@ class NoSuchQueryLoggingConfig(Route53ClientError):
code = 404 code = 404
def __init__(self): def __init__(self) -> None:
message = "The query logging configuration does not exist" message = "The query logging configuration does not exist"
super().__init__("NoSuchQueryLoggingConfig", message) super().__init__("NoSuchQueryLoggingConfig", message)
@ -142,7 +141,7 @@ class QueryLoggingConfigAlreadyExists(Route53ClientError):
code = 409 code = 409
def __init__(self): def __init__(self) -> None:
message = "A query logging configuration already exists for this hosted zone" message = "A query logging configuration already exists for this hosted zone"
super().__init__("QueryLoggingConfigAlreadyExists", message) super().__init__("QueryLoggingConfigAlreadyExists", message)
@ -151,7 +150,7 @@ class InvalidChangeBatch(Route53ClientError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
message = "Number of records limit of 1000 exceeded." message = "Number of records limit of 1000 exceeded."
super().__init__("InvalidChangeBatch", message) super().__init__("InvalidChangeBatch", message)
@ -159,7 +158,7 @@ class InvalidChangeBatch(Route53ClientError):
class NoSuchDelegationSet(Route53ClientError): class NoSuchDelegationSet(Route53ClientError):
code = 400 code = 400
def __init__(self, delegation_set_id): def __init__(self, delegation_set_id: str):
super().__init__("NoSuchDelegationSet", delegation_set_id) super().__init__("NoSuchDelegationSet", delegation_set_id)
self.content_type = "text/xml" self.content_type = "text/xml"
@ -167,7 +166,7 @@ class NoSuchDelegationSet(Route53ClientError):
class DnsNameInvalidForZone(Route53ClientError): class DnsNameInvalidForZone(Route53ClientError):
code = 400 code = 400
def __init__(self, name, zone_name): def __init__(self, name: str, zone_name: str):
error_msg = ( error_msg = (
f"""RRSet with DNS name {name} is not permitted in zone {zone_name}""" f"""RRSet with DNS name {name} is not permitted in zone {zone_name}"""
) )

View File

@ -5,6 +5,7 @@ import re
import string import string
from collections import defaultdict from collections import defaultdict
from jinja2 import Template from jinja2 import Template
from typing import Any, Dict, List, Optional, Tuple
from moto.route53.exceptions import ( from moto.route53.exceptions import (
HostedZoneNotEmpty, HostedZoneNotEmpty,
@ -30,13 +31,18 @@ from .utils import PAGINATION_MODEL
ROUTE53_ID_CHOICE = string.ascii_uppercase + string.digits ROUTE53_ID_CHOICE = string.ascii_uppercase + string.digits
def create_route53_zone_id(): def create_route53_zone_id() -> str:
# New ID's look like this Z1RWWTK7Y8UDDQ # New ID's look like this Z1RWWTK7Y8UDDQ
return "".join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)]) return "".join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)])
class DelegationSet(BaseModel): class DelegationSet(BaseModel):
def __init__(self, caller_reference, name_servers, delegation_set_id): def __init__(
self,
caller_reference: str,
name_servers: Optional[List[str]],
delegation_set_id: Optional[str],
):
self.caller_reference = caller_reference self.caller_reference = caller_reference
self.name_servers = name_servers or [ self.name_servers = name_servers or [
"ns-2048.awsdns-64.com", "ns-2048.awsdns-64.com",
@ -51,7 +57,12 @@ class DelegationSet(BaseModel):
class HealthCheck(CloudFormationModel): class HealthCheck(CloudFormationModel):
def __init__(self, health_check_id, caller_reference, health_check_args): def __init__(
self,
health_check_id: str,
caller_reference: str,
health_check_args: Dict[str, Any],
):
self.id = health_check_id self.id = health_check_id
self.ip_address = health_check_args.get("ip_address") self.ip_address = health_check_args.get("ip_address")
self.port = health_check_args.get("port") or 80 self.port = health_check_args.get("port") or 80
@ -70,35 +81,40 @@ class HealthCheck(CloudFormationModel):
self.children = None self.children = None
self.regions = None self.regions = None
def set_children(self, children): def set_children(self, children: Any) -> None:
if children and isinstance(children, list): if children and isinstance(children, list):
self.children = children self.children = children # type: ignore
elif children and isinstance(children, str): elif children and isinstance(children, str):
self.children = [children] self.children = [children] # type: ignore
def set_regions(self, regions): def set_regions(self, regions: Any) -> None:
if regions and isinstance(regions, list): if regions and isinstance(regions, list):
self.regions = regions self.regions = regions # type: ignore
elif regions and isinstance(regions, str): elif regions and isinstance(regions, str):
self.regions = [regions] self.regions = [regions] # type: ignore
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-healthcheck.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-healthcheck.html
return "AWS::Route53::HealthCheck" return "AWS::Route53::HealthCheck"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "HealthCheck":
properties = cloudformation_json["Properties"]["HealthCheckConfig"] properties = cloudformation_json["Properties"]["HealthCheckConfig"]
health_check_args = { health_check_args = {
"ip_address": properties.get("IPAddress"), "ip_address": properties.get("IPAddress"),
@ -116,7 +132,7 @@ class HealthCheck(CloudFormationModel):
) )
return health_check return health_check
def to_xml(self): def to_xml(self) -> str:
template = Template( template = Template(
"""<HealthCheck> """<HealthCheck>
<Id>{{ health_check.id }}</Id> <Id>{{ health_check.id }}</Id>
@ -169,8 +185,8 @@ class HealthCheck(CloudFormationModel):
class RecordSet(CloudFormationModel): class RecordSet(CloudFormationModel):
def __init__(self, kwargs): def __init__(self, kwargs: Dict[str, Any]):
self.name = kwargs.get("Name") self.name = kwargs.get("Name", "")
self.type_ = kwargs.get("Type") self.type_ = kwargs.get("Type")
self.ttl = kwargs.get("TTL", 0) self.ttl = kwargs.get("TTL", 0)
self.records = kwargs.get("ResourceRecords", []) self.records = kwargs.get("ResourceRecords", [])
@ -185,18 +201,23 @@ class RecordSet(CloudFormationModel):
self.geo_location = kwargs.get("GeoLocation", []) self.geo_location = kwargs.get("GeoLocation", [])
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "Name" return "Name"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-recordset.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-recordset.html
return "AWS::Route53::RecordSet" return "AWS::Route53::RecordSet"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "RecordSet":
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
zone_name = properties.get("HostedZoneName") zone_name = properties.get("HostedZoneName")
@ -209,14 +230,14 @@ class RecordSet(CloudFormationModel):
return record_set return record_set
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource, original_resource: Any,
new_resource_name, new_resource_name: str,
cloudformation_json, cloudformation_json: Any,
account_id, account_id: str,
region_name, region_name: str,
): ) -> "RecordSet":
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, account_id, region_name original_resource.name, cloudformation_json, account_id, region_name
) )
@ -225,9 +246,13 @@ class RecordSet(CloudFormationModel):
) )
@classmethod @classmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
# this will break if you changed the zone the record is in, # this will break if you changed the zone the record is in,
# unfortunately # unfortunately
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -245,10 +270,12 @@ class RecordSet(CloudFormationModel):
pass pass
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.name return self.name
def delete(self, account_id, region): # pylint: disable=unused-argument def delete(
self, account_id: str, region: str # pylint: disable=unused-argument
) -> None:
"""Not exposed as part of the Route 53 API - used for CloudFormation""" """Not exposed as part of the Route 53 API - used for CloudFormation"""
backend = route53_backends[account_id]["global"] backend = route53_backends[account_id]["global"]
hosted_zone = backend.get_hosted_zone_by_name(self.hosted_zone_name) hosted_zone = backend.get_hosted_zone_by_name(self.hosted_zone_name)
@ -257,22 +284,22 @@ class RecordSet(CloudFormationModel):
hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_}) hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_})
def reverse_domain_name(domain_name): def reverse_domain_name(domain_name: str) -> str:
if domain_name.endswith("."): # normalize without trailing dot if domain_name.endswith("."): # normalize without trailing dot
domain_name = domain_name[:-1] domain_name = domain_name[:-1]
return ".".join(reversed(domain_name.split("."))) return ".".join(reversed(domain_name.split(".")))
class ChangeList(list): class ChangeList(List[Dict[str, Any]]):
""" """
Contains a 'clean' list of ResourceRecordChangeSets Contains a 'clean' list of ResourceRecordChangeSets
""" """
def append(self, item) -> None: def append(self, item: Any) -> None:
item["ResourceRecordSet"]["Name"] = item["ResourceRecordSet"]["Name"].strip(".") item["ResourceRecordSet"]["Name"] = item["ResourceRecordSet"]["Name"].strip(".")
super().append(item) super().append(item)
def __contains__(self, item): def __contains__(self, item: Any) -> bool:
item["ResourceRecordSet"]["Name"] = item["ResourceRecordSet"]["Name"].strip(".") item["ResourceRecordSet"]["Name"] = item["ResourceRecordSet"]["Name"].strip(".")
return super().__contains__(item) return super().__contains__(item)
@ -280,28 +307,28 @@ class ChangeList(list):
class FakeZone(CloudFormationModel): class FakeZone(CloudFormationModel):
def __init__( def __init__(
self, self,
name, name: str,
id_, id_: str,
private_zone, private_zone: bool,
comment=None, comment: Optional[str] = None,
delegation_set=None, delegation_set: Optional[DelegationSet] = None,
): ):
self.name = name self.name = name
self.id = id_ self.id = id_
self.vpcs = [] self.vpcs: List[Dict[str, Any]] = []
if comment is not None: if comment is not None:
self.comment = comment self.comment = comment
self.private_zone = private_zone self.private_zone = private_zone
self.rrsets = [] self.rrsets: List[RecordSet] = []
self.delegation_set = delegation_set self.delegation_set = delegation_set
self.rr_changes = ChangeList() self.rr_changes = ChangeList()
def add_rrset(self, record_set): def add_rrset(self, record_set: Dict[str, Any]) -> RecordSet:
record_set = RecordSet(record_set) record_set_obj = RecordSet(record_set)
self.rrsets.append(record_set) self.rrsets.append(record_set_obj)
return record_set return record_set_obj
def upsert_rrset(self, record_set): def upsert_rrset(self, record_set: Dict[str, Any]) -> RecordSet:
new_rrset = RecordSet(record_set) new_rrset = RecordSet(record_set)
for i, rrset in enumerate(self.rrsets): for i, rrset in enumerate(self.rrsets):
if ( if (
@ -315,7 +342,7 @@ class FakeZone(CloudFormationModel):
self.rrsets.append(new_rrset) self.rrsets.append(new_rrset)
return new_rrset return new_rrset
def delete_rrset(self, rrset): def delete_rrset(self, rrset: Dict[str, Any]) -> None:
self.rrsets = [ self.rrsets = [
record_set record_set
for record_set in self.rrsets for record_set in self.rrsets
@ -323,14 +350,16 @@ class FakeZone(CloudFormationModel):
or (rrset.get("Type") is not None and record_set.type_ != rrset["Type"]) or (rrset.get("Type") is not None and record_set.type_ != rrset["Type"])
] ]
def delete_rrset_by_id(self, set_identifier): def delete_rrset_by_id(self, set_identifier: str) -> None:
self.rrsets = [ self.rrsets = [
record_set record_set
for record_set in self.rrsets for record_set in self.rrsets
if record_set.set_identifier != set_identifier if record_set.set_identifier != set_identifier
] ]
def add_vpc(self, vpc_id, vpc_region): def add_vpc(
self, vpc_id: Optional[str], vpc_region: Optional[str]
) -> Dict[str, Any]:
vpc = {} vpc = {}
if vpc_id is not None: if vpc_id is not None:
vpc["vpc_id"] = vpc_id vpc["vpc_id"] = vpc_id
@ -340,15 +369,15 @@ class FakeZone(CloudFormationModel):
self.vpcs.append(vpc) self.vpcs.append(vpc)
return vpc return vpc
def delete_vpc(self, vpc_id): def delete_vpc(self, vpc_id: str) -> None:
self.vpcs = [vpc for vpc in self.vpcs if vpc["vpc_id"] != vpc_id] self.vpcs = [vpc for vpc in self.vpcs if vpc["vpc_id"] != vpc_id]
def get_record_sets(self, start_type, start_name): def get_record_sets(self, start_type: str, start_name: str) -> List[RecordSet]:
def predicate(rrset): def predicate(rrset: RecordSet) -> bool:
rrset_name_reversed = reverse_domain_name(rrset.name) rrset_name_reversed = reverse_domain_name(rrset.name)
start_name_reversed = reverse_domain_name(start_name) start_name_reversed = reverse_domain_name(start_name)
return rrset_name_reversed < start_name_reversed or ( return rrset_name_reversed < start_name_reversed or (
rrset_name_reversed == start_name_reversed and rrset.type_ < start_type rrset_name_reversed == start_name_reversed and rrset.type_ < start_type # type: ignore
) )
record_sets = sorted( record_sets = sorted(
@ -358,27 +387,32 @@ class FakeZone(CloudFormationModel):
if start_name: if start_name:
start_type = start_type or "" start_type = start_type or ""
record_sets = itertools.dropwhile(predicate, record_sets) record_sets = itertools.dropwhile(predicate, record_sets) # type: ignore
return record_sets return record_sets
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "Name" return "Name"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-hostedzone.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-hostedzone.html
return "AWS::Route53::HostedZone" return "AWS::Route53::HostedZone"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "FakeZone":
hosted_zone = route53_backends[account_id]["global"].create_hosted_zone( hosted_zone = route53_backends[account_id]["global"].create_hosted_zone(
resource_name, private_zone=False resource_name, private_zone=False
) )
@ -386,27 +420,32 @@ class FakeZone(CloudFormationModel):
class RecordSetGroup(CloudFormationModel): class RecordSetGroup(CloudFormationModel):
def __init__(self, hosted_zone_id, record_sets): def __init__(self, hosted_zone_id: str, record_sets: List[str]):
self.hosted_zone_id = hosted_zone_id self.hosted_zone_id = hosted_zone_id
self.record_sets = record_sets self.record_sets = record_sets
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return f"arn:aws:route53:::hostedzone/{self.hosted_zone_id}" return f"arn:aws:route53:::hostedzone/{self.hosted_zone_id}"
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-recordsetgroup.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-route53-recordsetgroup.html
return "AWS::Route53::RecordSetGroup" return "AWS::Route53::RecordSetGroup"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "RecordSetGroup":
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
zone_name = properties.get("HostedZoneName") zone_name = properties.get("HostedZoneName")
@ -428,14 +467,17 @@ class QueryLoggingConfig(BaseModel):
"""QueryLoggingConfig class; this object isn't part of Cloudformation.""" """QueryLoggingConfig class; this object isn't part of Cloudformation."""
def __init__( def __init__(
self, query_logging_config_id, hosted_zone_id, cloudwatch_logs_log_group_arn self,
query_logging_config_id: str,
hosted_zone_id: str,
cloudwatch_logs_log_group_arn: str,
): ):
self.hosted_zone_id = hosted_zone_id self.hosted_zone_id = hosted_zone_id
self.cloudwatch_logs_log_group_arn = cloudwatch_logs_log_group_arn self.cloudwatch_logs_log_group_arn = cloudwatch_logs_log_group_arn
self.query_logging_config_id = query_logging_config_id self.query_logging_config_id = query_logging_config_id
self.location = f"https://route53.amazonaws.com/2013-04-01/queryloggingconfig/{self.query_logging_config_id}" self.location = f"https://route53.amazonaws.com/2013-04-01/queryloggingconfig/{self.query_logging_config_id}"
def to_xml(self): def to_xml(self) -> str:
template = Template( template = Template(
"""<QueryLoggingConfig> """<QueryLoggingConfig>
<CloudWatchLogsLogGroupArn>{{ query_logging_config.cloudwatch_logs_log_group_arn }}</CloudWatchLogsLogGroupArn> <CloudWatchLogsLogGroupArn>{{ query_logging_config.cloudwatch_logs_log_group_arn }}</CloudWatchLogsLogGroupArn>
@ -449,23 +491,23 @@ class QueryLoggingConfig(BaseModel):
class Route53Backend(BaseBackend): class Route53Backend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.zones = {} self.zones: Dict[str, FakeZone] = {}
self.health_checks = {} self.health_checks: Dict[str, HealthCheck] = {}
self.resource_tags = defaultdict(dict) self.resource_tags: Dict[str, Any] = defaultdict(dict)
self.query_logging_configs = {} self.query_logging_configs: Dict[str, QueryLoggingConfig] = {}
self.delegation_sets = dict() self.delegation_sets: Dict[str, DelegationSet] = dict()
def create_hosted_zone( def create_hosted_zone(
self, self,
name, name: str,
private_zone, private_zone: bool,
vpcid=None, vpcid: Optional[str] = None,
vpcregion=None, vpcregion: Optional[str] = None,
comment=None, comment: Optional[str] = None,
delegation_set_id=None, delegation_set_id: Optional[str] = None,
): ) -> FakeZone:
new_id = create_route53_zone_id() new_id = create_route53_zone_id()
delegation_set = self.create_reusable_delegation_set( delegation_set = self.create_reusable_delegation_set(
caller_reference=f"DelSet_{name}", delegation_set_id=delegation_set_id caller_reference=f"DelSet_{name}", delegation_set_id=delegation_set_id
@ -506,25 +548,27 @@ class Route53Backend(BaseBackend):
self.zones[new_id] = new_zone self.zones[new_id] = new_zone
return new_zone return new_zone
def get_dnssec(self, zone_id): def get_dnssec(self, zone_id: str) -> None:
# check if hosted zone exists # check if hosted zone exists
self.get_hosted_zone(zone_id) self.get_hosted_zone(zone_id)
def associate_vpc_with_hosted_zone(self, zone_id, vpcid, vpcregion): def associate_vpc_with_hosted_zone(
self, zone_id: str, vpcid: str, vpcregion: str
) -> FakeZone:
zone = self.get_hosted_zone(zone_id) zone = self.get_hosted_zone(zone_id)
if not zone.private_zone: if not zone.private_zone:
raise PublicZoneVPCAssociation() raise PublicZoneVPCAssociation()
zone.add_vpc(vpcid, vpcregion) zone.add_vpc(vpcid, vpcregion)
return zone return zone
def disassociate_vpc_from_hosted_zone(self, zone_id, vpcid): def disassociate_vpc_from_hosted_zone(self, zone_id: str, vpcid: str) -> FakeZone:
zone = self.get_hosted_zone(zone_id) zone = self.get_hosted_zone(zone_id)
if len(zone.vpcs) <= 1: if len(zone.vpcs) <= 1:
raise LastVPCAssociation() raise LastVPCAssociation()
zone.delete_vpc(vpcid) zone.delete_vpc(vpcid)
return zone return zone
def change_tags_for_resource(self, resource_id, tags): def change_tags_for_resource(self, resource_id: str, tags: Any) -> None:
if "Tag" in tags: if "Tag" in tags:
if isinstance(tags["Tag"], list): if isinstance(tags["Tag"], list):
for tag in tags["Tag"]: for tag in tags["Tag"]:
@ -540,12 +584,14 @@ class Route53Backend(BaseBackend):
else: else:
del self.resource_tags[resource_id][tags["Key"]] del self.resource_tags[resource_id][tags["Key"]]
def list_tags_for_resource(self, resource_id): def list_tags_for_resource(self, resource_id: str) -> Dict[str, str]:
if resource_id in self.resource_tags: if resource_id in self.resource_tags:
return self.resource_tags[resource_id] return self.resource_tags[resource_id]
return {} return {}
def list_resource_record_sets(self, zone_id, start_type, start_name, max_items): def list_resource_record_sets(
self, zone_id: str, start_type: str, start_name: str, max_items: int
) -> Tuple[List[RecordSet], Optional[str], Optional[str], bool]:
""" """
The StartRecordIdentifier-parameter is not yet implemented The StartRecordIdentifier-parameter is not yet implemented
""" """
@ -558,7 +604,9 @@ class Route53Backend(BaseBackend):
is_truncated = next_record is not None is_truncated = next_record is not None
return records, next_start_name, next_start_type, is_truncated return records, next_start_name, next_start_type, is_truncated
def change_resource_record_sets(self, zoneid, change_list) -> None: def change_resource_record_sets(
self, zoneid: str, change_list: List[Dict[str, Any]]
) -> None:
the_zone = self.get_hosted_zone(zoneid) the_zone = self.get_hosted_zone(zoneid)
for value in change_list: for value in change_list:
@ -622,20 +670,23 @@ class Route53Backend(BaseBackend):
the_zone.delete_rrset(record_set) the_zone.delete_rrset(record_set)
the_zone.rr_changes.append(original_change) the_zone.rr_changes.append(original_change)
def list_hosted_zones(self): def list_hosted_zones(self) -> List[FakeZone]:
return self.zones.values() return list(self.zones.values())
def list_hosted_zones_by_name(self, dnsname): def list_hosted_zones_by_name(
if dnsname: self, dnsnames: Optional[List[str]]
dnsname = dnsname[0] ) -> Tuple[Optional[str], List[FakeZone]]:
if dnsnames:
dnsname = dnsnames[0] # type: ignore
if dnsname[-1] != ".": if dnsname[-1] != ".":
dnsname += "." dnsname += "."
zones = [zone for zone in self.list_hosted_zones() if zone.name == dnsname] zones = [zone for zone in self.list_hosted_zones() if zone.name == dnsname] # type: ignore
else: else:
dnsname = None
# sort by names, but with domain components reversed # sort by names, but with domain components reversed
# see http://boto3.readthedocs.io/en/latest/reference/services/route53.html#Route53.Client.list_hosted_zones_by_name # see http://boto3.readthedocs.io/en/latest/reference/services/route53.html#Route53.Client.list_hosted_zones_by_name
def sort_key(zone): def sort_key(zone: FakeZone) -> str:
domains = zone.name.split(".") domains = zone.name.split(".")
if domains[-1] == "": if domains[-1] == "":
domains = domains[-1:] + domains[:-1] domains = domains[-1:] + domains[:-1]
@ -643,9 +694,9 @@ class Route53Backend(BaseBackend):
zones = self.list_hosted_zones() zones = self.list_hosted_zones()
zones = sorted(zones, key=sort_key) zones = sorted(zones, key=sort_key)
return dnsname, zones return dnsname, zones # type: ignore
def list_hosted_zones_by_vpc(self, vpc_id): def list_hosted_zones_by_vpc(self, vpc_id: str) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -665,22 +716,22 @@ class Route53Backend(BaseBackend):
return zone_list return zone_list
def get_hosted_zone(self, id_) -> FakeZone: def get_hosted_zone(self, id_: str) -> FakeZone:
the_zone = self.zones.get(id_.replace("/hostedzone/", "")) the_zone = self.zones.get(id_.replace("/hostedzone/", ""))
if not the_zone: if not the_zone:
raise NoSuchHostedZone(id_) raise NoSuchHostedZone(id_)
return the_zone return the_zone
def get_hosted_zone_count(self): def get_hosted_zone_count(self) -> int:
return len(self.list_hosted_zones()) return len(self.list_hosted_zones())
def get_hosted_zone_by_name(self, name): def get_hosted_zone_by_name(self, name: str) -> Optional[FakeZone]:
for zone in self.list_hosted_zones(): for zone in self.list_hosted_zones():
if zone.name == name: if zone.name == name:
return zone return zone
return None return None
def delete_hosted_zone(self, id_): def delete_hosted_zone(self, id_: str) -> Optional[FakeZone]:
# Verify it exists # Verify it exists
zone = self.get_hosted_zone(id_) zone = self.get_hosted_zone(id_)
if len(zone.rrsets) > 0: if len(zone.rrsets) > 0:
@ -689,12 +740,14 @@ class Route53Backend(BaseBackend):
raise HostedZoneNotEmpty() raise HostedZoneNotEmpty()
return self.zones.pop(id_.replace("/hostedzone/", ""), None) return self.zones.pop(id_.replace("/hostedzone/", ""), None)
def update_hosted_zone_comment(self, id_, comment): def update_hosted_zone_comment(self, id_: str, comment: str) -> FakeZone:
zone = self.get_hosted_zone(id_) zone = self.get_hosted_zone(id_)
zone.comment = comment zone.comment = comment
return zone return zone
def create_health_check(self, caller_reference, health_check_args): def create_health_check(
self, caller_reference: str, health_check_args: Dict[str, Any]
) -> HealthCheck:
health_check_id = str(random.uuid4()) health_check_id = str(random.uuid4())
health_check = HealthCheck(health_check_id, caller_reference, health_check_args) health_check = HealthCheck(health_check_id, caller_reference, health_check_args)
health_check.set_children(health_check_args.get("children")) health_check.set_children(health_check_args.get("children"))
@ -702,10 +755,12 @@ class Route53Backend(BaseBackend):
self.health_checks[health_check_id] = health_check self.health_checks[health_check_id] = health_check
return health_check return health_check
def update_health_check(self, health_check_id, health_check_args): def update_health_check(
self, health_check_id: str, health_check_args: Dict[str, Any]
) -> HealthCheck:
health_check = self.health_checks.get(health_check_id) health_check = self.health_checks.get(health_check_id)
if not health_check: if not health_check:
raise NoSuchHealthCheck() raise NoSuchHealthCheck(health_check_id)
if health_check_args.get("ip_address"): if health_check_args.get("ip_address"):
health_check.ip_address = health_check_args.get("ip_address") health_check.ip_address = health_check_args.get("ip_address")
@ -736,30 +791,32 @@ class Route53Backend(BaseBackend):
return health_check return health_check
def list_health_checks(self): def list_health_checks(self) -> List[HealthCheck]:
return self.health_checks.values() return list(self.health_checks.values())
def delete_health_check(self, health_check_id): def delete_health_check(self, health_check_id: str) -> None:
return self.health_checks.pop(health_check_id, None) self.health_checks.pop(health_check_id, None)
def get_health_check(self, health_check_id): def get_health_check(self, health_check_id: str) -> HealthCheck:
health_check = self.health_checks.get(health_check_id) health_check = self.health_checks.get(health_check_id)
if not health_check: if not health_check:
raise NoSuchHealthCheck(health_check_id) raise NoSuchHealthCheck(health_check_id)
return health_check return health_check
@staticmethod @staticmethod
def _validate_arn(region, arn): def _validate_arn(region: str, arn: str) -> None:
match = re.match(rf"arn:aws:logs:{region}:\d{{12}}:log-group:.+", arn) match = re.match(rf"arn:aws:logs:{region}:\d{{12}}:log-group:.+", arn)
if not arn or not match: if not arn or not match:
raise InvalidCloudWatchArn() raise InvalidCloudWatchArn()
# The CloudWatch Logs log group must be in the "us-east-1" region. # The CloudWatch Logs log group must be in the "us-east-1" region.
match = re.match(r"^(?:[^:]+:){3}(?P<region>[^:]+).*", arn) match = re.match(r"^(?:[^:]+:){3}(?P<region>[^:]+).*", arn)
if match.group("region") != "us-east-1": if not match or match.group("region") != "us-east-1":
raise InvalidCloudWatchArn() raise InvalidCloudWatchArn()
def create_query_logging_config(self, region, hosted_zone_id, log_group_arn): def create_query_logging_config(
self, region: str, hosted_zone_id: str, log_group_arn: str
) -> QueryLoggingConfig:
"""Process the create_query_logging_config request.""" """Process the create_query_logging_config request."""
# Does the hosted_zone_id exist? # Does the hosted_zone_id exist?
response = self.list_hosted_zones() response = self.list_hosted_zones()
@ -785,7 +842,7 @@ class Route53Backend(BaseBackend):
response = logs_backends[self.account_id][region].describe_log_groups() response = logs_backends[self.account_id][region].describe_log_groups()
log_groups = response[0] if response else [] log_groups = response[0] if response else []
for entry in log_groups: for entry in log_groups: # type: ignore
if log_group_arn == entry["arn"]: if log_group_arn == entry["arn"]:
break break
else: else:
@ -806,20 +863,22 @@ class Route53Backend(BaseBackend):
self.query_logging_configs[query_logging_config_id] = query_logging_config self.query_logging_configs[query_logging_config_id] = query_logging_config
return query_logging_config return query_logging_config
def delete_query_logging_config(self, query_logging_config_id): def delete_query_logging_config(self, query_logging_config_id: str) -> None:
"""Delete query logging config, if it exists.""" """Delete query logging config, if it exists."""
if query_logging_config_id not in self.query_logging_configs: if query_logging_config_id not in self.query_logging_configs:
raise NoSuchQueryLoggingConfig() raise NoSuchQueryLoggingConfig()
self.query_logging_configs.pop(query_logging_config_id) self.query_logging_configs.pop(query_logging_config_id)
def get_query_logging_config(self, query_logging_config_id): def get_query_logging_config(
self, query_logging_config_id: str
) -> QueryLoggingConfig:
"""Return query logging config, if it exists.""" """Return query logging config, if it exists."""
if query_logging_config_id not in self.query_logging_configs: if query_logging_config_id not in self.query_logging_configs:
raise NoSuchQueryLoggingConfig() raise NoSuchQueryLoggingConfig()
return self.query_logging_configs[query_logging_config_id] return self.query_logging_configs[query_logging_config_id]
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_query_logging_configs(self, hosted_zone_id=None): def list_query_logging_configs(self, hosted_zone_id: Optional[str] = None) -> List[QueryLoggingConfig]: # type: ignore
"""Return a list of query logging configs.""" """Return a list of query logging configs."""
if hosted_zone_id: if hosted_zone_id:
# Does the hosted_zone_id exist? # Does the hosted_zone_id exist?
@ -834,28 +893,31 @@ class Route53Backend(BaseBackend):
return list(self.query_logging_configs.values()) return list(self.query_logging_configs.values())
def create_reusable_delegation_set( def create_reusable_delegation_set(
self, caller_reference, delegation_set_id=None, hosted_zone_id=None self,
): caller_reference: str,
name_servers = None delegation_set_id: Optional[str] = None,
hosted_zone_id: Optional[str] = None,
) -> DelegationSet:
name_servers: Optional[List[str]] = None
if hosted_zone_id: if hosted_zone_id:
hosted_zone = self.get_hosted_zone(hosted_zone_id) hosted_zone = self.get_hosted_zone(hosted_zone_id)
name_servers = hosted_zone.delegation_set.name_servers name_servers = hosted_zone.delegation_set.name_servers # type: ignore
delegation_set = DelegationSet( delegation_set = DelegationSet(
caller_reference, name_servers, delegation_set_id caller_reference, name_servers, delegation_set_id
) )
self.delegation_sets[delegation_set.id] = delegation_set self.delegation_sets[delegation_set.id] = delegation_set
return delegation_set return delegation_set
def list_reusable_delegation_sets(self): def list_reusable_delegation_sets(self) -> List[DelegationSet]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
return self.delegation_sets.values() return list(self.delegation_sets.values())
def delete_reusable_delegation_set(self, delegation_set_id): def delete_reusable_delegation_set(self, delegation_set_id: str) -> None:
self.delegation_sets.pop(delegation_set_id, None) self.delegation_sets.pop(delegation_set_id, None)
def get_reusable_delegation_set(self, delegation_set_id): def get_reusable_delegation_set(self, delegation_set_id: str) -> DelegationSet:
if delegation_set_id not in self.delegation_sets: if delegation_set_id not in self.delegation_sets:
raise NoSuchDelegationSet(delegation_set_id) raise NoSuchDelegationSet(delegation_set_id)
return self.delegation_sets[delegation_set_id] return self.delegation_sets[delegation_set_id]

View File

@ -2,11 +2,13 @@
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from jinja2 import Template from jinja2 import Template
from typing import Any
import xmltodict import xmltodict
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.route53.exceptions import InvalidChangeBatch from moto.route53.exceptions import InvalidChangeBatch
from moto.route53.models import route53_backends from moto.route53.models import route53_backends, Route53Backend
XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/" XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
@ -14,11 +16,11 @@ XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
class Route53(BaseResponse): class Route53(BaseResponse):
"""Handler for Route53 requests and responses.""" """Handler for Route53 requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="route53") super().__init__(service_name="route53")
@staticmethod @staticmethod
def _convert_to_bool(bool_str): def _convert_to_bool(bool_str: Any) -> bool: # type: ignore[misc]
if isinstance(bool_str, bool): if isinstance(bool_str, bool):
return bool_str return bool_str
@ -28,10 +30,10 @@ class Route53(BaseResponse):
return False return False
@property @property
def backend(self): def backend(self) -> Route53Backend:
return route53_backends[self.current_account]["global"] return route53_backends[self.current_account]["global"]
def list_or_create_hostzone_response(self, request, full_url, headers): def list_or_create_hostzone_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
# Set these here outside the scope of the try/except # Set these here outside the scope of the try/except
@ -83,33 +85,39 @@ class Route53(BaseResponse):
template = Template(LIST_HOSTED_ZONES_RESPONSE) template = Template(LIST_HOSTED_ZONES_RESPONSE)
return 200, headers, template.render(zones=all_zones) return 200, headers, template.render(zones=all_zones)
def list_hosted_zones_by_name_response(self, request, full_url, headers): def list_hosted_zones_by_name_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
query_params = parse_qs(parsed_url.query) query_params = parse_qs(parsed_url.query)
dnsname = query_params.get("dnsname") dnsnames = query_params.get("dnsname")
dnsname, zones = self.backend.list_hosted_zones_by_name(dnsname) dnsname, zones = self.backend.list_hosted_zones_by_name(dnsnames)
template = Template(LIST_HOSTED_ZONES_BY_NAME_RESPONSE) template = Template(LIST_HOSTED_ZONES_BY_NAME_RESPONSE)
return 200, headers, template.render(zones=zones, dnsname=dnsname, xmlns=XMLNS) return 200, headers, template.render(zones=zones, dnsname=dnsname, xmlns=XMLNS)
def list_hosted_zones_by_vpc_response(self, request, full_url, headers): def list_hosted_zones_by_vpc_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
query_params = parse_qs(parsed_url.query) query_params = parse_qs(parsed_url.query)
vpc_id = query_params.get("vpcid")[0] vpc_id = query_params.get("vpcid")[0] # type: ignore
zones = self.backend.list_hosted_zones_by_vpc(vpc_id) zones = self.backend.list_hosted_zones_by_vpc(vpc_id)
template = Template(LIST_HOSTED_ZONES_BY_VPC_RESPONSE) template = Template(LIST_HOSTED_ZONES_BY_VPC_RESPONSE)
return 200, headers, template.render(zones=zones, xmlns=XMLNS) return 200, headers, template.render(zones=zones, xmlns=XMLNS)
def get_hosted_zone_count_response(self, request, full_url, headers): def get_hosted_zone_count_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
num_zones = self.backend.get_hosted_zone_count() num_zones = self.backend.get_hosted_zone_count()
template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE) template = Template(GET_HOSTED_ZONE_COUNT_RESPONSE)
return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS) return 200, headers, template.render(zone_count=num_zones, xmlns=XMLNS)
def get_or_delete_hostzone_response(self, request, full_url, headers): def get_or_delete_hostzone_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1] zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
@ -130,7 +138,7 @@ class Route53(BaseResponse):
template = Template(UPDATE_HOSTED_ZONE_COMMENT_RESPONSE) template = Template(UPDATE_HOSTED_ZONE_COMMENT_RESPONSE)
return 200, headers, template.render(zone=zone) return 200, headers, template.render(zone=zone)
def get_dnssec_response(self, request, full_url, headers): def get_dnssec_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
# returns static response # returns static response
# TODO: implement enable/disable dnssec apis # TODO: implement enable/disable dnssec apis
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -144,7 +152,9 @@ class Route53(BaseResponse):
self.backend.get_dnssec(zoneid) self.backend.get_dnssec(zoneid)
return 200, headers, GET_DNSSEC return 200, headers, GET_DNSSEC
def associate_vpc_response(self, request, full_url, headers): def associate_vpc_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -163,7 +173,9 @@ class Route53(BaseResponse):
template = Template(ASSOCIATE_VPC_RESPONSE) template = Template(ASSOCIATE_VPC_RESPONSE)
return 200, headers, template.render(comment=comment) return 200, headers, template.render(comment=comment)
def disassociate_vpc_response(self, request, full_url, headers): def disassociate_vpc_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -181,7 +193,7 @@ class Route53(BaseResponse):
template = Template(DISASSOCIATE_VPC_RESPONSE) template = Template(DISASSOCIATE_VPC_RESPONSE)
return 200, headers, template.render(comment=comment) return 200, headers, template.render(comment=comment)
def rrset_response(self, request, full_url, headers): def rrset_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -226,8 +238,8 @@ class Route53(BaseResponse):
elif method == "GET": elif method == "GET":
querystring = parse_qs(parsed_url.query) querystring = parse_qs(parsed_url.query)
template = Template(LIST_RRSET_RESPONSE) template = Template(LIST_RRSET_RESPONSE)
start_type = querystring.get("type", [None])[0] start_type = querystring.get("type", [None])[0] # type: ignore
start_name = querystring.get("name", [None])[0] start_name = querystring.get("name", [None])[0] # type: ignore
max_items = int(querystring.get("maxitems", ["300"])[0]) max_items = int(querystring.get("maxitems", ["300"])[0])
if start_type and not start_name: if start_type and not start_name:
@ -244,19 +256,18 @@ class Route53(BaseResponse):
start_name=start_name, start_name=start_name,
max_items=max_items, max_items=max_items,
) )
template = template.render( r_template = template.render(
record_sets=record_sets, record_sets=record_sets,
next_name=next_name, next_name=next_name,
next_type=next_type, next_type=next_type,
max_items=max_items, max_items=max_items,
is_truncated=is_truncated, is_truncated=is_truncated,
) )
return 200, headers, template return 200, headers, r_template
def health_check_response1(self, request, full_url, headers): def health_check_response1(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url)
method = request.method method = request.method
if method == "POST": if method == "POST":
@ -285,11 +296,6 @@ class Route53(BaseResponse):
) )
template = Template(CREATE_HEALTH_CHECK_RESPONSE) template = Template(CREATE_HEALTH_CHECK_RESPONSE)
return 201, headers, template.render(health_check=health_check, xmlns=XMLNS) return 201, headers, template.render(health_check=health_check, xmlns=XMLNS)
elif method == "DELETE":
health_check_id = parsed_url.path.split("/")[-1]
self.backend.delete_health_check(health_check_id)
template = Template(DELETE_HEALTH_CHECK_RESPONSE)
return 200, headers, template.render(xmlns=XMLNS)
elif method == "GET": elif method == "GET":
template = Template(LIST_HEALTH_CHECKS_RESPONSE) template = Template(LIST_HEALTH_CHECKS_RESPONSE)
health_checks = self.backend.list_health_checks() health_checks = self.backend.list_health_checks()
@ -299,7 +305,7 @@ class Route53(BaseResponse):
template.render(health_checks=health_checks, xmlns=XMLNS), template.render(health_checks=health_checks, xmlns=XMLNS),
) )
def health_check_response2(self, request, full_url, headers): def health_check_response2(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -336,7 +342,9 @@ class Route53(BaseResponse):
template = Template(UPDATE_HEALTH_CHECK_RESPONSE) template = Template(UPDATE_HEALTH_CHECK_RESPONSE)
return 200, headers, template.render(health_check=health_check) return 200, headers, template.render(health_check=health_check)
def not_implemented_response(self, request, full_url, headers): def not_implemented_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
action = "" action = ""
@ -348,7 +356,7 @@ class Route53(BaseResponse):
f"The action for {action} has not been implemented for route 53" f"The action for {action} has not been implemented for route 53"
) )
def list_or_change_tags_for_resource_request(self, request, full_url, headers): def list_or_change_tags_for_resource_request(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
@ -368,15 +376,15 @@ class Route53(BaseResponse):
tags = xmltodict.parse(self.body)["ChangeTagsForResourceRequest"] tags = xmltodict.parse(self.body)["ChangeTagsForResourceRequest"]
if "AddTags" in tags: if "AddTags" in tags:
tags = tags["AddTags"] tags = tags["AddTags"] # type: ignore
elif "RemoveTagKeys" in tags: elif "RemoveTagKeys" in tags:
tags = tags["RemoveTagKeys"] tags = tags["RemoveTagKeys"] # type: ignore
self.backend.change_tags_for_resource(id_, tags) self.backend.change_tags_for_resource(id_, tags)
template = Template(CHANGE_TAGS_FOR_RESOURCE_RESPONSE) template = Template(CHANGE_TAGS_FOR_RESOURCE_RESPONSE)
return 200, headers, template.render() return 200, headers, template.render()
def get_change(self, request, full_url, headers): def get_change(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
@ -385,7 +393,7 @@ class Route53(BaseResponse):
template = Template(GET_CHANGE_RESPONSE) template = Template(GET_CHANGE_RESPONSE)
return 200, headers, template.render(change_id=change_id, xmlns=XMLNS) return 200, headers, template.render(change_id=change_id, xmlns=XMLNS)
def list_or_create_query_logging_config_response(self, request, full_url, headers): def list_or_create_query_logging_config_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "POST": if request.method == "POST":
@ -429,7 +437,7 @@ class Route53(BaseResponse):
), ),
) )
def get_or_delete_query_logging_config_response(self, request, full_url, headers): def get_or_delete_query_logging_config_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
query_logging_config_id = parsed_url.path.rstrip("/").rsplit("/", 1)[1] query_logging_config_id = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
@ -449,7 +457,7 @@ class Route53(BaseResponse):
self.backend.delete_query_logging_config(query_logging_config_id) self.backend.delete_query_logging_config(query_logging_config_id)
return 200, headers, "" return 200, headers, ""
def reusable_delegation_sets(self, request, full_url, headers): def reusable_delegation_sets(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
delegation_sets = self.backend.list_reusable_delegation_sets() delegation_sets = self.backend.list_reusable_delegation_sets()
@ -479,7 +487,7 @@ class Route53(BaseResponse):
template.render(delegation_set=delegation_set), template.render(delegation_set=delegation_set),
) )
def reusable_delegation_set(self, request, full_url, headers): def reusable_delegation_set(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
ds_id = parsed_url.path.rstrip("/").rsplit("/")[-1] ds_id = parsed_url.path.rstrip("/").rsplit("/")[-1]

View File

@ -1,15 +1,21 @@
"""Route53 base URL and path.""" """Route53 base URL and path."""
from typing import Any
from .responses import Route53 from .responses import Route53
from moto.core.common_types import TYPE_RESPONSE
url_bases = [r"https?://route53(\..+)?\.amazonaws.com"] url_bases = [r"https?://route53(\..+)?\.amazonaws.com"]
def tag_response1(*args, **kwargs): def tag_response1(request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
return Route53().list_or_change_tags_for_resource_request(*args, **kwargs) return Route53().list_or_change_tags_for_resource_request(
request, full_url, headers
)
def tag_response2(*args, **kwargs): def tag_response2(request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
return Route53().list_or_change_tags_for_resource_request(*args, **kwargs) return Route53().list_or_change_tags_for_resource_request(
request, full_url, headers
)
url_paths = { url_paths = {

View File

@ -1,13 +1,11 @@
"""Exceptions raised by the route53resolver service.""" from typing import List, Tuple
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
class RRValidationException(JsonRESTError): class RRValidationException(JsonRESTError):
"""Report one of more parameter validation errors."""
code = 400 code = 400
def __init__(self, error_tuples): def __init__(self, error_tuples: List[Tuple[str, str, str]]):
"""Validation errors are concatenated into one exception message. """Validation errors are concatenated into one exception message.
error_tuples is a list of tuples. Each tuple contains: error_tuples is a list of tuples. Each tuple contains:
@ -30,11 +28,10 @@ class RRValidationException(JsonRESTError):
class InvalidNextTokenException(JsonRESTError): class InvalidNextTokenException(JsonRESTError):
"""Invalid next token parameter used to return a list of entities."""
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidNextTokenException", "InvalidNextTokenException",
"Invalid value passed for the NextToken parameter", "Invalid value passed for the NextToken parameter",
@ -42,63 +39,56 @@ class InvalidNextTokenException(JsonRESTError):
class InvalidParameterException(JsonRESTError): class InvalidParameterException(JsonRESTError):
"""One or more parameters in request are not valid."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterException", message) super().__init__("InvalidParameterException", message)
class InvalidRequestException(JsonRESTError): class InvalidRequestException(JsonRESTError):
"""The request is invalid."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidRequestException", message) super().__init__("InvalidRequestException", message)
class LimitExceededException(JsonRESTError): class LimitExceededException(JsonRESTError):
"""The request caused one or more limits to be exceeded."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("LimitExceededException", message) super().__init__("LimitExceededException", message)
class ResourceExistsException(JsonRESTError): class ResourceExistsException(JsonRESTError):
"""The resource already exists."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceExistsException", message) super().__init__("ResourceExistsException", message)
class ResourceInUseException(JsonRESTError): class ResourceInUseException(JsonRESTError):
"""The resource has other resources associated with it."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceInUseException", message) super().__init__("ResourceInUseException", message)
class ResourceNotFoundException(JsonRESTError): class ResourceNotFoundException(JsonRESTError):
"""The specified resource doesn't exist."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceNotFoundException", message) super().__init__("ResourceNotFoundException", message)
class TagValidationException(JsonRESTError): class TagValidationException(JsonRESTError):
"""Tag validation failed."""
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("ValidationException", message) super().__init__("ValidationException", message)

View File

@ -2,6 +2,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from ipaddress import ip_address, ip_network, IPv4Address from ipaddress import ip_address, ip_network, IPv4Address
from typing import Any, Dict, List, Optional, Set
import re import re
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
@ -43,7 +44,12 @@ class ResolverRuleAssociation(BaseModel): # pylint: disable=too-few-public-meth
] ]
def __init__( def __init__(
self, region, resolver_rule_association_id, resolver_rule_id, vpc_id, name=None self,
region: str,
resolver_rule_association_id: str,
resolver_rule_id: str,
vpc_id: str,
name: str,
): # pylint: disable=too-many-arguments ): # pylint: disable=too-many-arguments
self.region = region self.region = region
self.resolver_rule_id = resolver_rule_id self.resolver_rule_id = resolver_rule_id
@ -55,7 +61,7 @@ class ResolverRuleAssociation(BaseModel): # pylint: disable=too-few-public-meth
self.status = "COMPLETE" self.status = "COMPLETE"
self.status_message = "" self.status_message = ""
def description(self): def description(self) -> Dict[str, Any]:
"""Return dictionary of relevant info for resolver rule association.""" """Return dictionary of relevant info for resolver rule association."""
return { return {
"Id": self.id, "Id": self.id,
@ -86,15 +92,15 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes
def __init__( def __init__(
self, self,
account_id, account_id: str,
region, region: str,
rule_id, rule_id: str,
creator_request_id, creator_request_id: str,
rule_type, rule_type: str,
domain_name, domain_name: str,
target_ips=None, target_ips: Optional[List[Dict[str, Any]]],
resolver_endpoint_id=None, resolver_endpoint_id: Optional[str],
name=None, name: str,
): # pylint: disable=too-many-arguments ): # pylint: disable=too-many-arguments
self.account_id = account_id self.account_id = account_id
self.region = region self.region = region
@ -122,11 +128,11 @@ class ResolverRule(BaseModel): # pylint: disable=too-many-instance-attributes
self.modification_time = datetime.now(timezone.utc).isoformat() self.modification_time = datetime.now(timezone.utc).isoformat()
@property @property
def arn(self): def arn(self) -> str:
"""Return ARN for this resolver rule.""" """Return ARN for this resolver rule."""
return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-rule/{self.id}" return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-rule/{self.id}"
def description(self): def description(self) -> Dict[str, Any]:
"""Return a dictionary of relevant info for this resolver rule.""" """Return a dictionary of relevant info for this resolver rule."""
return { return {
"Id": self.id, "Id": self.id,
@ -166,14 +172,14 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
def __init__( def __init__(
self, self,
account_id, account_id: str,
region, region: str,
endpoint_id, endpoint_id: str,
creator_request_id, creator_request_id: str,
security_group_ids, security_group_ids: List[str],
direction, direction: str,
ip_addresses, ip_addresses: List[Dict[str, Any]],
name=None, name: str,
): # pylint: disable=too-many-arguments ): # pylint: disable=too-many-arguments
self.account_id = account_id self.account_id = account_id
self.region = region self.region = region
@ -206,11 +212,11 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
self.modification_time = datetime.now(timezone.utc).isoformat() self.modification_time = datetime.now(timezone.utc).isoformat()
@property @property
def arn(self): def arn(self) -> str:
"""Return ARN for this resolver endpoint.""" """Return ARN for this resolver endpoint."""
return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-endpoint/{self.id}" return f"arn:aws:route53resolver:{self.region}:{self.account_id}:resolver-endpoint/{self.id}"
def _vpc_id_from_subnet(self): def _vpc_id_from_subnet(self) -> str:
"""Return VPC Id associated with the subnet. """Return VPC Id associated with the subnet.
The assumption is that all of the subnets are associated with the The assumption is that all of the subnets are associated with the
@ -221,19 +227,19 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
subnet_info = self.ec2_backend.describe_subnets(subnet_ids=[first_subnet_id])[0] subnet_info = self.ec2_backend.describe_subnets(subnet_ids=[first_subnet_id])[0]
return subnet_info.vpc_id return subnet_info.vpc_id
def _build_subnet_info(self): def _build_subnet_info(self) -> Dict[str, Any]:
"""Create a dict of subnet info, including ip addrs and ENI ids. """Create a dict of subnet info, including ip addrs and ENI ids.
self.subnets[subnet_id][ip_addr1] = eni-id1 ... self.subnets[subnet_id][ip_addr1] = eni-id1 ...
""" """
subnets = defaultdict(dict) subnets: Dict[str, Any] = defaultdict(dict)
for entry in self.ip_addresses: for entry in self.ip_addresses:
subnets[entry["SubnetId"]][ subnets[entry["SubnetId"]][
entry["Ip"] entry["Ip"]
] = f"rni-{mock_random.get_random_hex(17)}" ] = f"rni-{mock_random.get_random_hex(17)}"
return subnets return subnets
def create_eni(self): def create_eni(self) -> List[str]:
"""Create a VPC ENI for each combo of AZ, subnet and IP.""" """Create a VPC ENI for each combo of AZ, subnet and IP."""
eni_ids = [] eni_ids = []
for subnet, ip_info in self.subnets.items(): for subnet, ip_info in self.subnets.items():
@ -251,12 +257,12 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
eni_ids.append(eni_info.id) eni_ids.append(eni_info.id)
return eni_ids return eni_ids
def delete_eni(self): def delete_eni(self) -> None:
"""Delete the VPC ENI created for the subnet and IP combos.""" """Delete the VPC ENI created for the subnet and IP combos."""
for eni_id in self.eni_ids: for eni_id in self.eni_ids:
self.ec2_backend.delete_network_interface(eni_id) self.ec2_backend.delete_network_interface(eni_id)
def description(self): def description(self) -> Dict[str, Any]:
"""Return a dictionary of relevant info for this resolver endpoint.""" """Return a dictionary of relevant info for this resolver endpoint."""
return { return {
"Id": self.id, "Id": self.id,
@ -273,7 +279,7 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
"ModificationTime": self.modification_time, "ModificationTime": self.modification_time,
} }
def ip_descriptions(self): def ip_descriptions(self) -> List[Dict[str, Any]]:
"""Return a list of dicts describing resolver endpoint IP addresses.""" """Return a list of dicts describing resolver endpoint IP addresses."""
description = [] description = []
for subnet_id, ip_info in self.subnets.items(): for subnet_id, ip_info in self.subnets.items():
@ -291,12 +297,12 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
) )
return description return description
def update_name(self, name): def update_name(self, name: str) -> None:
"""Replace existing name with new name.""" """Replace existing name with new name."""
self.name = name self.name = name
self.modification_time = datetime.now(timezone.utc).isoformat() self.modification_time = datetime.now(timezone.utc).isoformat()
def associate_ip_address(self, value): def associate_ip_address(self, value: Dict[str, Any]) -> None:
self.ip_addresses.append(value) self.ip_addresses.append(value)
self.ip_address_count = len(self.ip_addresses) self.ip_address_count = len(self.ip_addresses)
@ -315,9 +321,9 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
) )
self.eni_ids.append(eni_info.id) self.eni_ids.append(eni_info.id)
def disassociate_ip_address(self, value): def disassociate_ip_address(self, value: Dict[str, Any]) -> None:
if not value.get("Ip") and value.get("IpId"): if not value.get("Ip") and value.get("IpId"):
for ip_addr, eni_id in self.subnets[value.get("SubnetId")].items(): for ip_addr, eni_id in self.subnets[value.get("SubnetId")].items(): # type: ignore
if value.get("IpId") == eni_id: if value.get("IpId") == eni_id:
value["Ip"] = ip_addr value["Ip"] = ip_addr
if value.get("Ip"): if value.get("Ip"):
@ -340,23 +346,33 @@ class ResolverEndpoint(BaseModel): # pylint: disable=too-many-instance-attribut
class Route53ResolverBackend(BaseBackend): class Route53ResolverBackend(BaseBackend):
"""Implementation of Route53Resolver APIs.""" """Implementation of Route53Resolver APIs."""
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.resolver_endpoints = {} # Key is self-generated ID (endpoint_id) self.resolver_endpoints: Dict[
self.resolver_rules = {} # Key is self-generated ID (rule_id) str, ResolverEndpoint
self.resolver_rule_associations = {} # Key is resolver_rule_association_id) ] = {} # Key is self-generated ID (endpoint_id)
self.resolver_rules: Dict[
str, ResolverRule
] = {} # Key is self-generated ID (rule_id)
self.resolver_rule_associations: Dict[
str, ResolverRuleAssociation
] = {} # Key is resolver_rule_association_id)
self.tagger = TaggingService() self.tagger = TaggingService()
self.ec2_backend = ec2_backends[self.account_id][self.region_name] self.ec2_backend = ec2_backends[self.account_id][self.region_name]
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""List of dicts representing default VPC endpoints for this service.""" """List of dicts representing default VPC endpoints for this service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "route53resolver" service_region, zones, "route53resolver"
) )
def associate_resolver_rule(self, resolver_rule_id, name, vpc_id): def associate_resolver_rule(
self, resolver_rule_id: str, name: str, vpc_id: str
) -> ResolverRuleAssociation:
validate_args( validate_args(
[("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)] [("resolverRuleId", resolver_rule_id), ("name", name), ("vPCId", vpc_id)]
) )
@ -399,7 +415,9 @@ class Route53ResolverBackend(BaseBackend):
self.resolver_rule_associations[rule_association_id] = rule_association self.resolver_rule_associations[rule_association_id] = rule_association
return rule_association return rule_association
def _verify_subnet_ips(self, ip_addresses, initial=True): def _verify_subnet_ips(
self, ip_addresses: List[Dict[str, Any]], initial: bool = True
) -> None:
""" """
Perform additional checks on the IPAddresses. Perform additional checks on the IPAddresses.
@ -412,7 +430,7 @@ class Route53ResolverBackend(BaseBackend):
"Resolver endpoint needs to have at least 2 IP addresses" "Resolver endpoint needs to have at least 2 IP addresses"
) )
subnets = defaultdict(set) subnets: Dict[str, Set[str]] = defaultdict(set)
for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]: for subnet_id, ip_addr in [(x["SubnetId"], x["Ip"]) for x in ip_addresses]:
try: try:
subnet_info = self.ec2_backend.describe_subnets(subnet_ids=[subnet_id])[ subnet_info = self.ec2_backend.describe_subnets(subnet_ids=[subnet_id])[
@ -438,7 +456,7 @@ class Route53ResolverBackend(BaseBackend):
) )
subnets[subnet_id].add(ip_addr) subnets[subnet_id].add(ip_addr)
def _verify_security_group_ids(self, security_group_ids): def _verify_security_group_ids(self, security_group_ids: List[str]) -> None:
"""Perform additional checks on the security groups.""" """Perform additional checks on the security groups."""
if len(security_group_ids) > 10: if len(security_group_ids) > 10:
raise InvalidParameterException("Maximum of 10 security groups are allowed") raise InvalidParameterException("Maximum of 10 security groups are allowed")
@ -458,14 +476,14 @@ class Route53ResolverBackend(BaseBackend):
def create_resolver_endpoint( def create_resolver_endpoint(
self, self,
region, region: str,
creator_request_id, creator_request_id: str,
name, name: str,
security_group_ids, security_group_ids: List[str],
direction, direction: str,
ip_addresses, ip_addresses: List[Dict[str, Any]],
tags, tags: List[Dict[str, str]],
): # pylint: disable=too-many-arguments ) -> ResolverEndpoint: # pylint: disable=too-many-arguments
""" """
Return description for a newly created resolver endpoint. Return description for a newly created resolver endpoint.
@ -529,15 +547,15 @@ class Route53ResolverBackend(BaseBackend):
def create_resolver_rule( def create_resolver_rule(
self, self,
region, region: str,
creator_request_id, creator_request_id: str,
name, name: str,
rule_type, rule_type: str,
domain_name, domain_name: str,
target_ips, target_ips: List[Dict[str, Any]],
resolver_endpoint_id, resolver_endpoint_id: str,
tags, tags: List[Dict[str, str]],
): # pylint: disable=too-many-arguments ) -> ResolverRule: # pylint: disable=too-many-arguments
"""Return description for a newly created resolver rule.""" """Return description for a newly created resolver rule."""
validate_args( validate_args(
[ [
@ -607,22 +625,22 @@ class Route53ResolverBackend(BaseBackend):
rule_id = f"rslvr-rr-{mock_random.get_random_hex(17)}" rule_id = f"rslvr-rr-{mock_random.get_random_hex(17)}"
resolver_rule = ResolverRule( resolver_rule = ResolverRule(
self.account_id, account_id=self.account_id,
region, region=region,
rule_id, rule_id=rule_id,
creator_request_id, creator_request_id=creator_request_id,
rule_type, rule_type=rule_type,
domain_name, domain_name=domain_name,
target_ips, target_ips=target_ips,
resolver_endpoint_id, resolver_endpoint_id=resolver_endpoint_id,
name, name=name,
) )
self.resolver_rules[rule_id] = resolver_rule self.resolver_rules[rule_id] = resolver_rule
self.tagger.tag_resource(resolver_rule.arn, tags or []) self.tagger.tag_resource(resolver_rule.arn, tags or [])
return resolver_rule return resolver_rule
def _validate_resolver_endpoint_id(self, resolver_endpoint_id): def _validate_resolver_endpoint_id(self, resolver_endpoint_id: str) -> None:
"""Raise an exception if the id is invalid or unknown.""" """Raise an exception if the id is invalid or unknown."""
validate_args([("resolverEndpointId", resolver_endpoint_id)]) validate_args([("resolverEndpointId", resolver_endpoint_id)])
if resolver_endpoint_id not in self.resolver_endpoints: if resolver_endpoint_id not in self.resolver_endpoints:
@ -630,7 +648,7 @@ class Route53ResolverBackend(BaseBackend):
f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist" f"Resolver endpoint with ID '{resolver_endpoint_id}' does not exist"
) )
def delete_resolver_endpoint(self, resolver_endpoint_id): def delete_resolver_endpoint(self, resolver_endpoint_id: str) -> ResolverEndpoint:
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
# Can't delete an endpoint if there are rules associated with it. # Can't delete an endpoint if there are rules associated with it.
@ -655,7 +673,7 @@ class Route53ResolverBackend(BaseBackend):
) )
return resolver_endpoint return resolver_endpoint
def _validate_resolver_rule_id(self, resolver_rule_id): def _validate_resolver_rule_id(self, resolver_rule_id: str) -> None:
"""Raise an exception if the id is invalid or unknown.""" """Raise an exception if the id is invalid or unknown."""
validate_args([("resolverRuleId", resolver_rule_id)]) validate_args([("resolverRuleId", resolver_rule_id)])
if resolver_rule_id not in self.resolver_rules: if resolver_rule_id not in self.resolver_rules:
@ -663,7 +681,7 @@ class Route53ResolverBackend(BaseBackend):
f"Resolver rule with ID '{resolver_rule_id}' does not exist" f"Resolver rule with ID '{resolver_rule_id}' does not exist"
) )
def delete_resolver_rule(self, resolver_rule_id): def delete_resolver_rule(self, resolver_rule_id: str) -> ResolverRule:
self._validate_resolver_rule_id(resolver_rule_id) self._validate_resolver_rule_id(resolver_rule_id)
# Can't delete an rule unless VPC's are disassociated. # Can't delete an rule unless VPC's are disassociated.
@ -686,7 +704,9 @@ class Route53ResolverBackend(BaseBackend):
) )
return resolver_rule return resolver_rule
def disassociate_resolver_rule(self, resolver_rule_id, vpc_id): def disassociate_resolver_rule(
self, resolver_rule_id: str, vpc_id: str
) -> ResolverRuleAssociation:
validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)]) validate_args([("resolverRuleId", resolver_rule_id), ("vPCId", vpc_id)])
# Non-existent rule or vpc ids? # Non-existent rule or vpc ids?
@ -715,16 +735,18 @@ class Route53ResolverBackend(BaseBackend):
rule_association.status_message = "Deleting Association" rule_association.status_message = "Deleting Association"
return rule_association return rule_association
def get_resolver_endpoint(self, resolver_endpoint_id): def get_resolver_endpoint(self, resolver_endpoint_id: str) -> ResolverEndpoint:
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
return self.resolver_endpoints[resolver_endpoint_id] return self.resolver_endpoints[resolver_endpoint_id]
def get_resolver_rule(self, resolver_rule_id): def get_resolver_rule(self, resolver_rule_id: str) -> ResolverRule:
"""Return info for specified resolver rule.""" """Return info for specified resolver rule."""
self._validate_resolver_rule_id(resolver_rule_id) self._validate_resolver_rule_id(resolver_rule_id)
return self.resolver_rules[resolver_rule_id] return self.resolver_rules[resolver_rule_id]
def get_resolver_rule_association(self, resolver_rule_association_id): def get_resolver_rule_association(
self, resolver_rule_association_id: str
) -> ResolverRuleAssociation:
validate_args([("resolverRuleAssociationId", resolver_rule_association_id)]) validate_args([("resolverRuleAssociationId", resolver_rule_association_id)])
if resolver_rule_association_id not in self.resolver_rule_associations: if resolver_rule_association_id not in self.resolver_rule_associations:
raise ResourceNotFoundException( raise ResourceNotFoundException(
@ -733,13 +755,13 @@ class Route53ResolverBackend(BaseBackend):
return self.resolver_rule_associations[resolver_rule_association_id] return self.resolver_rule_associations[resolver_rule_association_id]
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id): def list_resolver_endpoint_ip_addresses(self, resolver_endpoint_id: str) -> List[Dict[str, Any]]: # type: ignore[misc]
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
endpoint = self.resolver_endpoints[resolver_endpoint_id] endpoint = self.resolver_endpoints[resolver_endpoint_id]
return endpoint.ip_descriptions() return endpoint.ip_descriptions()
@staticmethod @staticmethod
def _add_field_name_to_filter(filters): def _add_field_name_to_filter(filters: List[Dict[str, Any]]) -> None: # type: ignore[misc]
"""Convert both styles of filter names to lowercase snake format. """Convert both styles of filter names to lowercase snake format.
"IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count". "IP_ADDRESS_COUNT" or "IpAddressCount" will become "ip_address_count".
@ -762,7 +784,7 @@ class Route53ResolverBackend(BaseBackend):
rr_filter["Field"] = filter_name.lower() rr_filter["Field"] = filter_name.lower()
@staticmethod @staticmethod
def _validate_filters(filters, allowed_filter_names): def _validate_filters(filters: Any, allowed_filter_names: List[str]) -> None: # type: ignore[misc]
"""Raise exception if filter names are not as expected.""" """Raise exception if filter names are not as expected."""
for rr_filter in filters: for rr_filter in filters:
if rr_filter["Field"] not in allowed_filter_names: if rr_filter["Field"] not in allowed_filter_names:
@ -775,7 +797,7 @@ class Route53ResolverBackend(BaseBackend):
) )
@staticmethod @staticmethod
def _matches_all_filters(entity, filters): def _matches_all_filters(entity: Any, filters: Any) -> bool: # type: ignore[misc]
"""Return True if this entity has fields matching all the filters.""" """Return True if this entity has fields matching all the filters."""
for rr_filter in filters: for rr_filter in filters:
field_value = getattr(entity, rr_filter["Field"]) field_value = getattr(entity, rr_filter["Field"])
@ -792,7 +814,7 @@ class Route53ResolverBackend(BaseBackend):
return True return True
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_endpoints(self, filters): def list_resolver_endpoints(self, filters: Any) -> List[ResolverEndpoint]: # type: ignore[misc]
if not filters: if not filters:
filters = [] filters = []
@ -806,7 +828,7 @@ class Route53ResolverBackend(BaseBackend):
return endpoints return endpoints
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rules(self, filters): def list_resolver_rules(self, filters: Any) -> List[ResolverRule]: # type: ignore[misc]
if not filters: if not filters:
filters = [] filters = []
@ -820,7 +842,7 @@ class Route53ResolverBackend(BaseBackend):
return rules return rules
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_resolver_rule_associations(self, filters): def list_resolver_rule_associations(self, filters: Any) -> List[ResolverRuleAssociation]: # type: ignore[misc]
if not filters: if not filters:
filters = [] filters = []
@ -835,7 +857,7 @@ class Route53ResolverBackend(BaseBackend):
rules.append(rule) rules.append(rule)
return rules return rules
def _matched_arn(self, resource_arn): def _matched_arn(self, resource_arn: str) -> None:
"""Given ARN, raise exception if there is no corresponding resource.""" """Given ARN, raise exception if there is no corresponding resource."""
for resolver_endpoint in self.resolver_endpoints.values(): for resolver_endpoint in self.resolver_endpoints.values():
if resolver_endpoint.arn == resource_arn: if resolver_endpoint.arn == resource_arn:
@ -848,11 +870,11 @@ class Route53ResolverBackend(BaseBackend):
) )
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(self, resource_arn: str) -> Optional[List[Dict[str, str]]]: # type: ignore[misc]
self._matched_arn(resource_arn) self._matched_arn(resource_arn)
return self.tagger.list_tags_for_resource(resource_arn).get("Tags") return self.tagger.list_tags_for_resource(resource_arn).get("Tags")
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self._matched_arn(resource_arn) self._matched_arn(resource_arn)
errmsg = self.tagger.validate_tags( errmsg = self.tagger.validate_tags(
tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT tags, limit=ResolverEndpoint.MAX_TAGS_PER_RESOLVER_ENDPOINT
@ -861,18 +883,22 @@ class Route53ResolverBackend(BaseBackend):
raise TagValidationException(errmsg) raise TagValidationException(errmsg)
self.tagger.tag_resource(resource_arn, tags) self.tagger.tag_resource(resource_arn, tags)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self._matched_arn(resource_arn) self._matched_arn(resource_arn)
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def update_resolver_endpoint(self, resolver_endpoint_id, name): def update_resolver_endpoint(
self, resolver_endpoint_id: str, name: str
) -> ResolverEndpoint:
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
validate_args([("name", name)]) validate_args([("name", name)])
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
resolver_endpoint.update_name(name) resolver_endpoint.update_name(name)
return resolver_endpoint return resolver_endpoint
def associate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value): def associate_resolver_endpoint_ip_address(
self, resolver_endpoint_id: str, value: Dict[str, Any]
) -> ResolverEndpoint:
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]
@ -886,7 +912,9 @@ class Route53ResolverBackend(BaseBackend):
resolver_endpoint.associate_ip_address(value) resolver_endpoint.associate_ip_address(value)
return resolver_endpoint return resolver_endpoint
def disassociate_resolver_endpoint_ip_address(self, resolver_endpoint_id, value): def disassociate_resolver_endpoint_ip_address(
self, resolver_endpoint_id: str, value: Dict[str, Any]
) -> ResolverEndpoint:
self._validate_resolver_endpoint_id(resolver_endpoint_id) self._validate_resolver_endpoint_id(resolver_endpoint_id)
resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id] resolver_endpoint = self.resolver_endpoints[resolver_endpoint_id]

View File

@ -4,22 +4,22 @@ import json
from moto.core.exceptions import InvalidToken from moto.core.exceptions import InvalidToken
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.route53resolver.exceptions import InvalidNextTokenException from moto.route53resolver.exceptions import InvalidNextTokenException
from moto.route53resolver.models import route53resolver_backends from moto.route53resolver.models import route53resolver_backends, Route53ResolverBackend
from moto.route53resolver.validations import validate_args from moto.route53resolver.validations import validate_args
class Route53ResolverResponse(BaseResponse): class Route53ResolverResponse(BaseResponse):
"""Handler for Route53Resolver requests and responses.""" """Handler for Route53Resolver requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="route53-resolver") super().__init__(service_name="route53-resolver")
@property @property
def route53resolver_backend(self): def route53resolver_backend(self) -> Route53ResolverBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return route53resolver_backends[self.current_account][self.region] return route53resolver_backends[self.current_account][self.region]
def associate_resolver_rule(self): def associate_resolver_rule(self) -> str:
"""Associate a Resolver rule with a VPC.""" """Associate a Resolver rule with a VPC."""
resolver_rule_id = self._get_param("ResolverRuleId") resolver_rule_id = self._get_param("ResolverRuleId")
name = self._get_param("Name") name = self._get_param("Name")
@ -35,7 +35,7 @@ class Route53ResolverResponse(BaseResponse):
{"ResolverRuleAssociation": resolver_rule_association.description()} {"ResolverRuleAssociation": resolver_rule_association.description()}
) )
def create_resolver_endpoint(self): def create_resolver_endpoint(self) -> str:
"""Create an inbound or outbound Resolver endpoint.""" """Create an inbound or outbound Resolver endpoint."""
creator_request_id = self._get_param("CreatorRequestId") creator_request_id = self._get_param("CreatorRequestId")
name = self._get_param("Name") name = self._get_param("Name")
@ -54,7 +54,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()}) return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
def create_resolver_rule(self): def create_resolver_rule(self) -> str:
"""Specify which Resolver enpoint the queries will pass through.""" """Specify which Resolver enpoint the queries will pass through."""
creator_request_id = self._get_param("CreatorRequestId") creator_request_id = self._get_param("CreatorRequestId")
name = self._get_param("Name") name = self._get_param("Name")
@ -75,7 +75,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverRule": resolver_rule.description()}) return json.dumps({"ResolverRule": resolver_rule.description()})
def delete_resolver_endpoint(self): def delete_resolver_endpoint(self) -> str:
"""Delete a Resolver endpoint.""" """Delete a Resolver endpoint."""
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")
resolver_endpoint = self.route53resolver_backend.delete_resolver_endpoint( resolver_endpoint = self.route53resolver_backend.delete_resolver_endpoint(
@ -83,7 +83,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()}) return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
def delete_resolver_rule(self): def delete_resolver_rule(self) -> str:
"""Delete a Resolver rule.""" """Delete a Resolver rule."""
resolver_rule_id = self._get_param("ResolverRuleId") resolver_rule_id = self._get_param("ResolverRuleId")
resolver_rule = self.route53resolver_backend.delete_resolver_rule( resolver_rule = self.route53resolver_backend.delete_resolver_rule(
@ -91,7 +91,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverRule": resolver_rule.description()}) return json.dumps({"ResolverRule": resolver_rule.description()})
def disassociate_resolver_rule(self): def disassociate_resolver_rule(self) -> str:
"""Remove the association between a Resolver rule and a VPC.""" """Remove the association between a Resolver rule and a VPC."""
vpc_id = self._get_param("VPCId") vpc_id = self._get_param("VPCId")
resolver_rule_id = self._get_param("ResolverRuleId") resolver_rule_id = self._get_param("ResolverRuleId")
@ -104,7 +104,7 @@ class Route53ResolverResponse(BaseResponse):
{"ResolverRuleAssociation": resolver_rule_association.description()} {"ResolverRuleAssociation": resolver_rule_association.description()}
) )
def get_resolver_endpoint(self): def get_resolver_endpoint(self) -> str:
"""Return info about a specific Resolver endpoint.""" """Return info about a specific Resolver endpoint."""
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")
resolver_endpoint = self.route53resolver_backend.get_resolver_endpoint( resolver_endpoint = self.route53resolver_backend.get_resolver_endpoint(
@ -112,7 +112,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()}) return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
def get_resolver_rule(self): def get_resolver_rule(self) -> str:
"""Return info about a specific Resolver rule.""" """Return info about a specific Resolver rule."""
resolver_rule_id = self._get_param("ResolverRuleId") resolver_rule_id = self._get_param("ResolverRuleId")
resolver_rule = self.route53resolver_backend.get_resolver_rule( resolver_rule = self.route53resolver_backend.get_resolver_rule(
@ -120,7 +120,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverRule": resolver_rule.description()}) return json.dumps({"ResolverRule": resolver_rule.description()})
def get_resolver_rule_association(self): def get_resolver_rule_association(self) -> str:
"""Return info about association between a Resolver rule and a VPC.""" """Return info about association between a Resolver rule and a VPC."""
resolver_rule_association_id = self._get_param("ResolverRuleAssociationId") resolver_rule_association_id = self._get_param("ResolverRuleAssociationId")
resolver_rule_association = ( resolver_rule_association = (
@ -132,7 +132,7 @@ class Route53ResolverResponse(BaseResponse):
{"ResolverRuleAssociation": resolver_rule_association.description()} {"ResolverRuleAssociation": resolver_rule_association.description()}
) )
def list_resolver_endpoint_ip_addresses(self): def list_resolver_endpoint_ip_addresses(self) -> str:
"""Returns list of IP addresses for specified Resolver endpoint.""" """Returns list of IP addresses for specified Resolver endpoint."""
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -158,7 +158,7 @@ class Route53ResolverResponse(BaseResponse):
response["NextToken"] = next_token response["NextToken"] = next_token
return json.dumps(response) return json.dumps(response)
def list_resolver_endpoints(self): def list_resolver_endpoints(self) -> str:
"""Returns list of all Resolver endpoints, filtered if specified.""" """Returns list of all Resolver endpoints, filtered if specified."""
filters = self._get_param("Filters") filters = self._get_param("Filters")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -176,7 +176,7 @@ class Route53ResolverResponse(BaseResponse):
response["NextToken"] = next_token response["NextToken"] = next_token
return json.dumps(response) return json.dumps(response)
def list_resolver_rules(self): def list_resolver_rules(self) -> str:
"""Returns list of all Resolver rules, filtered if specified.""" """Returns list of all Resolver rules, filtered if specified."""
filters = self._get_param("Filters") filters = self._get_param("Filters")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -197,7 +197,7 @@ class Route53ResolverResponse(BaseResponse):
response["NextToken"] = next_token response["NextToken"] = next_token
return json.dumps(response) return json.dumps(response)
def list_resolver_rule_associations(self): def list_resolver_rule_associations(self) -> str:
"""Returns list of all Resolver associations, filtered if specified.""" """Returns list of all Resolver associations, filtered if specified."""
filters = self._get_param("Filters") filters = self._get_param("Filters")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -221,7 +221,7 @@ class Route53ResolverResponse(BaseResponse):
response["NextToken"] = next_token response["NextToken"] = next_token
return json.dumps(response) return json.dumps(response)
def list_tags_for_resource(self): def list_tags_for_resource(self) -> str:
"""Lists all tags for the given resource.""" """Lists all tags for the given resource."""
resource_arn = self._get_param("ResourceArn") resource_arn = self._get_param("ResourceArn")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
@ -235,14 +235,14 @@ class Route53ResolverResponse(BaseResponse):
response["NextToken"] = next_token response["NextToken"] = next_token
return json.dumps(response) return json.dumps(response)
def tag_resource(self): def tag_resource(self) -> str:
"""Add one or more tags to a specified resource.""" """Add one or more tags to a specified resource."""
resource_arn = self._get_param("ResourceArn") resource_arn = self._get_param("ResourceArn")
tags = self._get_param("Tags") tags = self._get_param("Tags")
self.route53resolver_backend.tag_resource(resource_arn=resource_arn, tags=tags) self.route53resolver_backend.tag_resource(resource_arn=resource_arn, tags=tags)
return "" return ""
def untag_resource(self): def untag_resource(self) -> str:
"""Removes one or more tags from the specified resource.""" """Removes one or more tags from the specified resource."""
resource_arn = self._get_param("ResourceArn") resource_arn = self._get_param("ResourceArn")
tag_keys = self._get_param("TagKeys") tag_keys = self._get_param("TagKeys")
@ -251,7 +251,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return "" return ""
def update_resolver_endpoint(self): def update_resolver_endpoint(self) -> str:
"""Update name of Resolver endpoint.""" """Update name of Resolver endpoint."""
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")
name = self._get_param("Name") name = self._get_param("Name")
@ -260,7 +260,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()}) return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
def associate_resolver_endpoint_ip_address(self): def associate_resolver_endpoint_ip_address(self) -> str:
ip_address = self._get_param("IpAddress") ip_address = self._get_param("IpAddress")
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")
resolver_endpoint = ( resolver_endpoint = (
@ -271,7 +271,7 @@ class Route53ResolverResponse(BaseResponse):
) )
return json.dumps({"ResolverEndpoint": resolver_endpoint.description()}) return json.dumps({"ResolverEndpoint": resolver_endpoint.description()})
def disassociate_resolver_endpoint_ip_address(self): def disassociate_resolver_endpoint_ip_address(self) -> str:
ip_address = self._get_param("IpAddress") ip_address = self._get_param("IpAddress")
resolver_endpoint_id = self._get_param("ResolverEndpointId") resolver_endpoint_id = self._get_param("ResolverEndpointId")

View File

@ -3,11 +3,12 @@
Note that ValidationExceptions are accumulative. Note that ValidationExceptions are accumulative.
""" """
import re import re
from typing import Any, Dict, List, Tuple, Optional
from moto.route53resolver.exceptions import RRValidationException from moto.route53resolver.exceptions import RRValidationException
def validate_args(validators): def validate_args(validators: List[Tuple[str, Any]]) -> None:
"""Raise exception if any of the validations fails. """Raise exception if any of the validations fails.
validators is a list of tuples each containing the following: validators is a list of tuples each containing the following:
@ -36,56 +37,56 @@ def validate_args(validators):
# This eventually could be a switch (python 3.10), eliminating the need # This eventually could be a switch (python 3.10), eliminating the need
# for the above map and individual functions. # for the above map and individual functions.
for (fieldname, value) in validators: for (fieldname, value) in validators:
msg = validation_map[fieldname](value) msg = validation_map[fieldname](value) # type: ignore
if msg: if msg:
err_msgs.append((fieldname, value, msg)) err_msgs.append((fieldname, value, msg))
if err_msgs: if err_msgs:
raise RRValidationException(err_msgs) raise RRValidationException(err_msgs)
def validate_creator_request_id(value): def validate_creator_request_id(value: Optional[str]) -> str:
"""Raise exception if the creator_request_id has invalid length.""" """Raise exception if the creator_request_id has invalid length."""
if value and len(value) > 255: if value and len(value) > 255:
return "have length less than or equal to 255" return "have length less than or equal to 255"
return "" return ""
def validate_direction(value): def validate_direction(value: Optional[str]) -> str:
"""Raise exception if direction not one of the allowed values.""" """Raise exception if direction not one of the allowed values."""
if value and value not in ["INBOUND", "OUTBOUND"]: if value and value not in ["INBOUND", "OUTBOUND"]:
return "satisfy enum value set: [INBOUND, OUTBOUND]" return "satisfy enum value set: [INBOUND, OUTBOUND]"
return "" return ""
def validate_domain_name(value): def validate_domain_name(value: str) -> str:
"""Raise exception if the domain_name has invalid length.""" """Raise exception if the domain_name has invalid length."""
if len(value) > 256: if len(value) > 256:
return "have length less than or equal to 256" return "have length less than or equal to 256"
return "" return ""
def validate_endpoint_id(value): def validate_endpoint_id(value: Optional[str]) -> str:
"""Raise exception if resolver endpoint id has invalid length.""" """Raise exception if resolver endpoint id has invalid length."""
if value and len(value) > 64: if value and len(value) > 64:
return "have length less than or equal to 64" return "have length less than or equal to 64"
return "" return ""
def validate_ip_addresses(value): def validate_ip_addresses(value: str) -> str:
"""Raise exception if IPs fail to match length constraint.""" """Raise exception if IPs fail to match length constraint."""
if len(value) > 10: if len(value) > 10:
return "have length less than or equal to 10" return "have length less than or equal to 10"
return "" return ""
def validate_max_results(value): def validate_max_results(value: Optional[int]) -> str:
"""Raise exception if number of endpoints or IPs is too large.""" """Raise exception if number of endpoints or IPs is too large."""
if value and value > 100: if value and value > 100:
return "have length less than or equal to 100" return "have length less than or equal to 100"
return "" return ""
def validate_name(value): def validate_name(value: Optional[str]) -> str:
"""Raise exception if name fails to match constraints.""" """Raise exception if name fails to match constraints."""
if value: if value:
if len(value) > 64: if len(value) > 64:
@ -96,28 +97,28 @@ def validate_name(value):
return "" return ""
def validate_rule_association_id(value): def validate_rule_association_id(value: Optional[str]) -> str:
"""Raise exception if resolver rule association id has invalid length.""" """Raise exception if resolver rule association id has invalid length."""
if value and len(value) > 64: if value and len(value) > 64:
return "have length less than or equal to 64" return "have length less than or equal to 64"
return "" return ""
def validate_rule_id(value): def validate_rule_id(value: Optional[str]) -> str:
"""Raise exception if resolver rule id has invalid length.""" """Raise exception if resolver rule id has invalid length."""
if value and len(value) > 64: if value and len(value) > 64:
return "have length less than or equal to 64" return "have length less than or equal to 64"
return "" return ""
def validate_rule_type(value): def validate_rule_type(value: Optional[str]) -> str:
"""Raise exception if rule_type not one of the allowed values.""" """Raise exception if rule_type not one of the allowed values."""
if value and value not in ["FORWARD", "SYSTEM", "RECURSIVE"]: if value and value not in ["FORWARD", "SYSTEM", "RECURSIVE"]:
return "satisfy enum value set: [FORWARD, SYSTEM, RECURSIVE]" return "satisfy enum value set: [FORWARD, SYSTEM, RECURSIVE]"
return "" return ""
def validate_security_group_ids(value): def validate_security_group_ids(value: List[str]) -> str:
"""Raise exception if IPs fail to match length constraint.""" """Raise exception if IPs fail to match length constraint."""
# Too many security group IDs is an InvalidParameterException. # Too many security group IDs is an InvalidParameterException.
for group_id in value: for group_id in value:
@ -129,7 +130,7 @@ def validate_security_group_ids(value):
return "" return ""
def validate_subnets(value): def validate_subnets(value: List[Dict[str, Any]]) -> str:
"""Raise exception if subnets fail to match length constraint.""" """Raise exception if subnets fail to match length constraint."""
for subnet_id in [x["SubnetId"] for x in value]: for subnet_id in [x["SubnetId"] for x in value]:
if len(subnet_id) > 32: if len(subnet_id) > 32:
@ -137,14 +138,14 @@ def validate_subnets(value):
return "" return ""
def validate_target_port(value): def validate_target_port(value: Optional[Dict[str, int]]) -> str:
"""Raise exception if target port fails to match length constraint.""" """Raise exception if target port fails to match length constraint."""
if value and value["Port"] > 65535: if value and value["Port"] > 65535:
return "have value less than or equal to 65535" return "have value less than or equal to 65535"
return "" return ""
def validate_vpc_id(value): def validate_vpc_id(value: str) -> str:
"""Raise exception if VPC id has invalid length.""" """Raise exception if VPC id has invalid length."""
if len(value) > 64: if len(value) > 64:
return "have length less than or equal to 64" return "have length less than or equal to 64"

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds*,moto/redshift*,moto/rekognition,moto/resourcegroups*,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract