From c50737d02786e87ec65537606fae3633c2bdcf81 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 13 Feb 2023 19:28:15 -0100 Subject: [PATCH] Techdebt: MyPy EC2 models (#5925) --- moto/ec2/exceptions.py | 44 +-- moto/ec2/models/__init__.py | 37 +- moto/ec2/models/vpc_peering_connections.py | 73 ++-- moto/ec2/models/vpc_service_configuration.py | 68 ++-- moto/ec2/models/vpcs.py | 344 ++++++++++--------- moto/ec2/models/vpn_connections.py | 56 ++- moto/ec2/models/vpn_gateway.py | 92 +++-- moto/ec2/models/windows.py | 2 +- moto/ec2/responses/vpcs.py | 13 +- moto/ec2/responses/vpn_connections.py | 2 +- moto/ec2/utils.py | 30 +- setup.cfg | 2 +- 12 files changed, 422 insertions(+), 341 deletions(-) diff --git a/moto/ec2/exceptions.py b/moto/ec2/exceptions.py index 4a441a085..aad3b6738 100644 --- a/moto/ec2/exceptions.py +++ b/moto/ec2/exceptions.py @@ -30,7 +30,7 @@ class EC2ClientError(RESTError): class DefaultVpcAlreadyExists(EC2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "DefaultVpcAlreadyExists", "A Default VPC already exists for this account in this region.", @@ -59,7 +59,7 @@ class InvalidDHCPOptionsIdError(EC2ClientError): class InvalidRequest(EC2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidRequest", "The request received was invalid") @@ -98,7 +98,7 @@ class InvalidKeyPairFormatError(EC2ClientError): class InvalidVPCIdError(EC2ClientError): - def __init__(self, vpc_id: str): + def __init__(self, vpc_id: Any): super().__init__("InvalidVpcID.NotFound", f"VpcID {vpc_id} does not exist.") @@ -134,7 +134,7 @@ class InvalidNetworkAclIdError(EC2ClientError): class InvalidVpnGatewayIdError(EC2ClientError): - def __init__(self, vpn_gw): + def __init__(self, vpn_gw: str): super().__init__( "InvalidVpnGatewayID.NotFound", f"The virtual private gateway ID '{vpn_gw}' does not exist", @@ -142,7 +142,7 @@ class InvalidVpnGatewayIdError(EC2ClientError): class InvalidVpnGatewayAttachmentError(EC2ClientError): - def __init__(self, vpn_gw, vpc_id): + def __init__(self, vpn_gw: str, vpc_id: str): super().__init__( "InvalidVpnGatewayAttachment.NotFound", f"The attachment with vpn gateway ID '{vpn_gw}' and vpc ID '{vpc_id}' does not exist", @@ -150,7 +150,7 @@ class InvalidVpnGatewayAttachmentError(EC2ClientError): class InvalidVpnConnectionIdError(EC2ClientError): - def __init__(self, network_acl_id): + def __init__(self, network_acl_id: str): super().__init__( "InvalidVpnConnectionID.NotFound", f"The vpnConnection ID '{network_acl_id}' does not exist", @@ -365,7 +365,7 @@ class InvalidAssociationIdError(EC2ClientError): class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): - def __init__(self, association_id): + def __init__(self, association_id: str): super().__init__( "InvalidVpcCidrBlockAssociationIdError.NotFound", f"The vpc CIDR block association ID '{association_id}' does not exist", @@ -373,7 +373,7 @@ class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): class InvalidVPCPeeringConnectionIdError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): + def __init__(self, vpc_peering_connection_id: str): super().__init__( "InvalidVpcPeeringConnectionId.NotFound", f"VpcPeeringConnectionID {vpc_peering_connection_id} does not exist.", @@ -381,7 +381,7 @@ class InvalidVPCPeeringConnectionIdError(EC2ClientError): class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): + def __init__(self, vpc_peering_connection_id: str): super().__init__( "InvalidStateTransition", f"VpcPeeringConnectionID {vpc_peering_connection_id} is not in the correct state for the request.", @@ -389,7 +389,7 @@ class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): class InvalidServiceName(EC2ClientError): - def __init__(self, service_name): + def __init__(self, service_name: str): super().__init__( "InvalidServiceName", f"The Vpc Endpoint Service '{service_name}' does not exist", @@ -402,7 +402,7 @@ class InvalidFilter(EC2ClientError): class InvalidNextToken(EC2ClientError): - def __init__(self, next_token): + def __init__(self, next_token: str): super().__init__("InvalidNextToken", f"The token '{next_token}' is invalid") @@ -436,7 +436,7 @@ class InvalidParameterValueError(EC2ClientError): class EmptyTagSpecError(EC2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidParameterValue", "Tag specification must have at least one tag" ) @@ -498,7 +498,7 @@ class TagLimitExceeded(EC2ClientError): class InvalidID(EC2ClientError): - def __init__(self, resource_id): + def __init__(self, resource_id: str): super().__init__("InvalidID", f"The ID '{resource_id}' is not valid") @@ -532,7 +532,7 @@ class FilterNotImplementedError(MotoNotImplementedError): class CidrLimitExceeded(EC2ClientError): - def __init__(self, vpc_id, max_cidr_limit): + def __init__(self, vpc_id: str, max_cidr_limit: int): super().__init__( "CidrLimitExceeded", f"This network '{vpc_id}' has met its maximum number of allowed CIDRs: {max_cidr_limit}", @@ -540,14 +540,14 @@ class CidrLimitExceeded(EC2ClientError): class UnsupportedTenancy(EC2ClientError): - def __init__(self, tenancy): + def __init__(self, tenancy: str): super().__init__( "UnsupportedTenancy", f"The tenancy value {tenancy} is not supported." ) class OperationNotPermitted(EC2ClientError): - def __init__(self, association_id): + def __init__(self, association_id: str): super().__init__( "OperationNotPermitted", f"The vpc CIDR block with association ID {association_id} may not be disassociated. It is the primary IPv4 CIDR block of the VPC", @@ -611,13 +611,13 @@ class InvalidSubnetConflictError(EC2ClientError): class InvalidVPCRangeError(EC2ClientError): - def __init__(self, cidr_block): + def __init__(self, cidr_block: str): super().__init__("InvalidVpc.Range", f"The CIDR '{cidr_block}' is invalid.") # accept exception class OperationNotPermitted2(EC2ClientError): - def __init__(self, client_region, pcx_id, acceptor_region): + def __init__(self, client_region: str, pcx_id: str, acceptor_region: str): super().__init__( "OperationNotPermitted", f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted in region {acceptor_region}", @@ -626,7 +626,7 @@ class OperationNotPermitted2(EC2ClientError): # reject exception class OperationNotPermitted3(EC2ClientError): - def __init__(self, client_region, pcx_id, acceptor_region): + def __init__(self, client_region: str, pcx_id: str, acceptor_region: str): super().__init__( "OperationNotPermitted", f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted or rejected in region {acceptor_region}", @@ -690,7 +690,7 @@ class InvalidAssociationIDIamProfileAssociationError(EC2ClientError): class InvalidVpcEndPointIdError(EC2ClientError): - def __init__(self, vpc_end_point_id): + def __init__(self, vpc_end_point_id: str): super().__init__( "InvalidVpcEndpointId.NotFound", f"The VpcEndPoint ID '{vpc_end_point_id}' does not exist", @@ -730,7 +730,7 @@ class InvalidCarrierGatewayID(EC2ClientError): class NoLoadBalancersProvided(EC2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidParameter", "exactly one of network_load_balancer_arn or gateway_load_balancer_arn is a required member", @@ -738,7 +738,7 @@ class NoLoadBalancersProvided(EC2ClientError): class UnknownVpcEndpointService(EC2ClientError): - def __init__(self, service_id): + def __init__(self, service_id: str): super().__init__( "InvalidVpcEndpointServiceId.NotFound", f"The VpcEndpointService Id '{service_id}' does not exist", diff --git a/moto/ec2/models/__init__.py b/moto/ec2/models/__init__.py index c57182bc8..70bcd56f1 100644 --- a/moto/ec2/models/__init__.py +++ b/moto/ec2/models/__init__.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List from moto.core import BaseBackend, BackendDict from ..exceptions import ( EC2ClientError, @@ -49,7 +50,7 @@ from ..utils import ( ) -def validate_resource_ids(resource_ids): +def validate_resource_ids(resource_ids: List[str]) -> bool: if not resource_ids: raise MissingParameterError(parameter="resourceIdSet") for resource_id in resource_ids: @@ -59,19 +60,19 @@ def validate_resource_ids(resource_ids): class SettingsBackend: - def __init__(self): + def __init__(self) -> None: self.ebs_encryption_by_default = False - def disable_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.account_id][self.region_name] + def disable_ebs_encryption_by_default(self) -> None: + ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined] ec2_backend.ebs_encryption_by_default = False - def enable_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.account_id][self.region_name] + def enable_ebs_encryption_by_default(self) -> None: + ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined] ec2_backend.ebs_encryption_by_default = True - def get_ebs_encryption_by_default(self): - ec2_backend = ec2_backends[self.account_id][self.region_name] + def get_ebs_encryption_by_default(self) -> None: + ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined] return ec2_backend.ebs_encryption_by_default @@ -129,11 +130,11 @@ class EC2Backend( """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): BaseBackend.__init__(self, region_name, account_id) for backend in EC2Backend.__mro__: if backend not in [EC2Backend, BaseBackend, object]: - backend.__init__(self) + backend.__init__(self) # type: ignore # Default VPC exists by default, which is the current behavior # of EC2-VPC. See for detail: @@ -145,23 +146,23 @@ class EC2Backend( else: # For now this is included for potential # backward-compatibility issues - vpc = self.vpcs.values()[0] + vpc = list(self.vpcs.values())[0] self.default_vpc = vpc # Create default subnet for each availability zone ip, _ = vpc.cidr_block.split("/") - ip = ip.split(".") - ip[2] = 0 + ip = ip.split(".") # type: ignore + ip[2] = 0 # type: ignore for zone in self.describe_availability_zones(): az_name = zone.name cidr_block = ".".join(str(i) for i in ip) + "/20" self.create_subnet(vpc.id, cidr_block, availability_zone=az_name) - ip[2] += 16 + ip[2] += 16 # type: ignore @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc] """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "ec2" @@ -171,13 +172,13 @@ class EC2Backend( # Use this to generate a proper error template response when in a response # handler. - def raise_error(self, code, message): + def raise_error(self, code: str, message: str) -> None: raise EC2ClientError(code, message) - def raise_not_implemented_error(self, blurb: str): + def raise_not_implemented_error(self, blurb: str) -> None: raise MotoNotImplementedError(blurb) - def do_resources_exist(self, resource_ids): + def do_resources_exist(self, resource_ids: List[str]) -> bool: for resource_id in resource_ids: resource_prefix = get_prefix(resource_id) if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]: diff --git a/moto/ec2/models/vpc_peering_connections.py b/moto/ec2/models/vpc_peering_connections.py index dfb026328..0498080fd 100644 --- a/moto/ec2/models/vpc_peering_connections.py +++ b/moto/ec2/models/vpc_peering_connections.py @@ -1,5 +1,6 @@ import weakref from collections import defaultdict +from typing import Any, Dict, Iterator, List, Optional from moto.core import CloudFormationModel from ..exceptions import ( InvalidVPCPeeringConnectionIdError, @@ -8,6 +9,7 @@ from ..exceptions import ( OperationNotPermitted3, ) from .core import TaggedEC2Resource +from .vpcs import VPC from ..utils import random_vpc_peering_connection_id @@ -44,7 +46,14 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel): "AllowDnsResolutionFromRemoteVpc": "false", } - def __init__(self, backend, vpc_pcx_id, vpc, peer_vpc, tags=None): + def __init__( + self, + backend: Any, + vpc_pcx_id: str, + vpc: VPC, + peer_vpc: VPC, + tags: Optional[Dict[str, str]] = None, + ): self.id = vpc_pcx_id self.ec2_backend = backend self.vpc = vpc @@ -55,18 +64,23 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel): self._status = PeeringConnectionStatus() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcpeeringconnection.html return "AWS::EC2::VPCPeeringConnection" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any + ) -> "VPCPeeringConnection": from ..models import ec2_backends properties = cloudformation_json["Properties"] @@ -80,26 +94,28 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel): return vpc_pcx @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id class VPCPeeringConnectionBackend: # for cross region vpc reference - vpc_pcx_refs = defaultdict(set) + vpc_pcx_refs = defaultdict(set) # type: ignore - def __init__(self): - self.vpc_pcxs = {} + def __init__(self) -> None: + self.vpc_pcxs: Dict[str, VPCPeeringConnection] = {} self.vpc_pcx_refs[self.__class__].add(weakref.ref(self)) @classmethod - def get_vpc_pcx_refs(cls): + def get_vpc_pcx_refs(cls) -> Iterator[VPCPeeringConnection]: for inst_ref in cls.vpc_pcx_refs[cls]: inst = inst_ref() if inst is not None: yield inst - def create_vpc_peering_connection(self, vpc, peer_vpc, tags=None): + def create_vpc_peering_connection( + self, vpc: VPC, peer_vpc: VPC, tags: Optional[Dict[str, str]] = None + ) -> VPCPeeringConnection: vpc_pcx_id = random_vpc_peering_connection_id() vpc_pcx = VPCPeeringConnection(self, vpc_pcx_id, vpc, peer_vpc, tags) vpc_pcx._status.pending() @@ -111,49 +127,54 @@ class VPCPeeringConnectionBackend: vpc_pcx_cx.vpc_pcxs[vpc_pcx_id] = vpc_pcx return vpc_pcx - def describe_vpc_peering_connections(self, vpc_peering_ids=None): - all_pcxs = self.vpc_pcxs.copy().values() + def describe_vpc_peering_connections( + self, vpc_peering_ids: Optional[List[str]] = None + ) -> List[VPCPeeringConnection]: + all_pcxs = list(self.vpc_pcxs.values()) if vpc_peering_ids: return [pcx for pcx in all_pcxs if pcx.id in vpc_peering_ids] return all_pcxs - def get_vpc_peering_connection(self, vpc_pcx_id): + def get_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection: if vpc_pcx_id not in self.vpc_pcxs: raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id) - return self.vpc_pcxs.get(vpc_pcx_id) + return self.vpc_pcxs[vpc_pcx_id] - def delete_vpc_peering_connection(self, vpc_pcx_id): + def delete_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection: deleted = self.get_vpc_peering_connection(vpc_pcx_id) deleted._status.deleted() return deleted - def accept_vpc_peering_connection(self, vpc_pcx_id): + def accept_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection: vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) # if cross region need accepter from another region pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name - if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: - raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) + if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined] + raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined] if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.accept() return vpc_pcx - def reject_vpc_peering_connection(self, vpc_pcx_id): + def reject_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection: vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) # if cross region need accepter from another region pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name - if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: - raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) + if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined] + raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined] if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.reject() return vpc_pcx def modify_vpc_peering_connection_options( - self, vpc_pcx_id, accepter_options=None, requester_options=None - ): + self, + vpc_pcx_id: str, + accepter_options: Optional[Dict[str, Any]] = None, + requester_options: Optional[Dict[str, Any]] = None, + ) -> None: vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) if not vpc_pcx: raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id) diff --git a/moto/ec2/models/vpc_service_configuration.py b/moto/ec2/models/vpc_service_configuration.py index a7c760e70..cccad8d4d 100644 --- a/moto/ec2/models/vpc_service_configuration.py +++ b/moto/ec2/models/vpc_service_configuration.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Optional from moto.core import CloudFormationModel from moto.moto_api._internal import mock_random from .core import TaggedEC2Resource @@ -6,7 +7,12 @@ from ..exceptions import UnknownVpcEndpointService class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel): def __init__( - self, load_balancers, region, acceptance_required, private_dns_name, ec2_backend + self, + load_balancers: List[Any], + region: str, + acceptance_required: bool, + private_dns_name: str, + ec2_backend: Any, ): self.id = f"vpce-svc-{mock_random.get_random_hex(length=8)}" self.service_name = f"com.amazonaws.vpce.{region}.{self.id}" @@ -32,43 +38,49 @@ class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel): self.private_dns_name = private_dns_name self.endpoint_dns_name = f"{self.id}.{region}.vpce.amazonaws.com" - self.principals = [] + self.principals: List[str] = [] self.ec2_backend = ec2_backend class VPCServiceConfigurationBackend: - def __init__(self): - self.configurations = {} + def __init__(self) -> None: + self.configurations: Dict[str, VPCServiceConfiguration] = {} @property - def elbv2_backend(self): + def elbv2_backend(self) -> Any: # type: ignore[misc] from moto.elbv2.models import elbv2_backends - return elbv2_backends[self.account_id][self.region_name] + return elbv2_backends[self.account_id][self.region_name] # type: ignore[attr-defined] - def get_vpc_endpoint_service(self, resource_id): + def get_vpc_endpoint_service( + self, resource_id: str + ) -> Optional[VPCServiceConfiguration]: return self.configurations.get(resource_id) def create_vpc_endpoint_service_configuration( - self, lb_arns, acceptance_required, private_dns_name, tags - ): + self, + lb_arns: List[Any], + acceptance_required: bool, + private_dns_name: str, + tags: List[Dict[str, str]], + ) -> VPCServiceConfiguration: lbs = self.elbv2_backend.describe_load_balancers(arns=lb_arns, names=None) config = VPCServiceConfiguration( load_balancers=lbs, - region=self.region_name, + region=self.region_name, # type: ignore[attr-defined] acceptance_required=acceptance_required, private_dns_name=private_dns_name, ec2_backend=self, ) for tag in tags or []: - tag_key = tag.get("Key") - tag_value = tag.get("Value") - config.add_tag(tag_key, tag_value) + config.add_tag(tag["Key"], tag["Value"]) self.configurations[config.id] = config return config - def describe_vpc_endpoint_service_configurations(self, service_ids): + def describe_vpc_endpoint_service_configurations( + self, service_ids: Optional[List[str]] + ) -> List[VPCServiceConfiguration]: """ The Filters, MaxResults, NextToken parameters are not yet implemented """ @@ -80,15 +92,17 @@ class VPCServiceConfigurationBackend: else: raise UnknownVpcEndpointService(service_id) return found_configs - return self.configurations.copy().values() + return list(self.configurations.values()) - def delete_vpc_endpoint_service_configurations(self, service_ids): + def delete_vpc_endpoint_service_configurations( + self, service_ids: List[str] + ) -> List[str]: missing = [s for s in service_ids if s not in self.configurations] for s in service_ids: self.configurations.pop(s, None) return missing - def describe_vpc_endpoint_service_permissions(self, service_id): + def describe_vpc_endpoint_service_permissions(self, service_id: str) -> List[str]: """ The Filters, MaxResults, NextToken parameters are not yet implemented """ @@ -96,8 +110,8 @@ class VPCServiceConfigurationBackend: return config.principals def modify_vpc_endpoint_service_permissions( - self, service_id, add_principals, remove_principals - ): + self, service_id: str, add_principals: List[str], remove_principals: List[str] + ) -> None: config = self.describe_vpc_endpoint_service_configurations([service_id])[0] config.principals += add_principals config.principals = [p for p in config.principals if p not in remove_principals] @@ -105,14 +119,14 @@ class VPCServiceConfigurationBackend: def modify_vpc_endpoint_service_configuration( self, - service_id, - acceptance_required, - private_dns_name, - add_network_lbs, - remove_network_lbs, - add_gateway_lbs, - remove_gateway_lbs, - ): + service_id: str, + acceptance_required: Optional[str], + private_dns_name: Optional[str], + add_network_lbs: List[str], + remove_network_lbs: List[str], + add_gateway_lbs: List[str], + remove_gateway_lbs: List[str], + ) -> None: """ The following parameters are not yet implemented: RemovePrivateDnsName """ diff --git a/moto/ec2/models/vpcs.py b/moto/ec2/models/vpcs.py index 78c48591f..5ccdcbf02 100644 --- a/moto/ec2/models/vpcs.py +++ b/moto/ec2/models/vpcs.py @@ -2,6 +2,7 @@ import ipaddress import json import weakref from collections import defaultdict +from typing import Any, Dict, List, Optional from operator import itemgetter from moto.core import CloudFormationModel @@ -35,27 +36,27 @@ from ..utils import ( ) MAX_NUMBER_OF_ENDPOINT_SERVICES_RESULTS = 1000 -DEFAULT_VPC_ENDPOINT_SERVICES = [] +DEFAULT_VPC_ENDPOINT_SERVICES: List[Dict[str, str]] = [] class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): def __init__( self, - ec2_backend, - endpoint_id, - vpc_id, - service_name, - endpoint_type=None, - policy_document=False, - route_table_ids=None, - subnet_ids=None, - network_interface_ids=None, - dns_entries=None, - client_token=None, - security_group_ids=None, - tags=None, - private_dns_enabled=None, - destination_prefix_list_id=None, + ec2_backend: Any, + endpoint_id: str, + vpc_id: str, + service_name: str, + endpoint_type: Optional[str], + policy_document: Optional[str], + route_table_ids: List[str], + subnet_ids: Optional[List[str]] = None, + network_interface_ids: Optional[List[str]] = None, + dns_entries: Optional[List[Dict[str, str]]] = None, + client_token: Optional[str] = None, + security_group_ids: Optional[List[str]] = None, + tags: Optional[Dict[str, str]] = None, + private_dns_enabled: Optional[str] = None, + destination_prefix_list_id: Optional[str] = None, ): self.ec2_backend = ec2_backend self.id = endpoint_id @@ -76,11 +77,17 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): self.created_at = utc_date_and_time() - def modify(self, policy_doc, add_subnets, add_route_tables, remove_route_tables): + def modify( + self, + policy_doc: Optional[str], + add_subnets: Optional[List[str]], + add_route_tables: Optional[List[str]], + remove_route_tables: Optional[List[str]], + ) -> None: if policy_doc: self.policy_document = policy_doc if add_subnets: - self.subnet_ids.extend(add_subnets) + self.subnet_ids.extend(add_subnets) # type: ignore[union-attr] if add_route_tables: self.route_table_ids.extend(add_route_tables) if remove_route_tables: @@ -90,32 +97,39 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): if rt_id not in remove_route_tables ] - def get_filter_value(self, filter_name): + def get_filter_value( + self, filter_name: str, method_name: Optional[str] = None + ) -> Any: if filter_name in ("vpc-endpoint-type", "vpc_endpoint_type"): return self.endpoint_type else: return super().get_filter_value(filter_name, "DescribeVpcs") @property - def owner_id(self): + def owner_id(self) -> str: return self.ec2_backend.account_id @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::EC2::VPCEndpoint" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "VPCEndPoint": from ..models import ec2_backends properties = cloudformation_json["Properties"] @@ -146,19 +160,19 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): class VPC(TaggedEC2Resource, CloudFormationModel): def __init__( self, - ec2_backend, - vpc_id, - cidr_block, - is_default, - instance_tenancy="default", - amazon_provided_ipv6_cidr_block=False, - ipv6_cidr_block_network_border_group=None, + ec2_backend: Any, + vpc_id: str, + cidr_block: str, + is_default: bool, + instance_tenancy: str = "default", + amazon_provided_ipv6_cidr_block: bool = False, + ipv6_cidr_block_network_border_group: Optional[str] = None, ): self.ec2_backend = ec2_backend self.id = vpc_id self.cidr_block = cidr_block - self.cidr_block_association_set = {} + self.cidr_block_association_set: Dict[str, Any] = {} self.dhcp_options = None self.state = "available" self.instance_tenancy = instance_tenancy @@ -180,22 +194,27 @@ class VPC(TaggedEC2Resource, CloudFormationModel): ) @property - def owner_id(self): + def owner_id(self) -> str: return self.ec2_backend.account_id @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpc.html return "AWS::EC2::VPC" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "VPC": from ..models import ec2_backends properties = cloudformation_json["Properties"] @@ -213,10 +232,12 @@ class VPC(TaggedEC2Resource, CloudFormationModel): return vpc @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id - def get_filter_value(self, filter_name): + def get_filter_value( + self, filter_name: str, method_name: Optional[str] = None + ) -> Any: if filter_name in ("vpc-id", "vpcId"): return self.id elif filter_name in ("cidr", "cidr-block", "cidrBlock"): @@ -255,23 +276,22 @@ class VPC(TaggedEC2Resource, CloudFormationModel): else: return super().get_filter_value(filter_name, "DescribeVpcs") - def modify_vpc_tenancy(self, tenancy): + def modify_vpc_tenancy(self, tenancy: str) -> None: if tenancy != "default": raise UnsupportedTenancy(tenancy) self.instance_tenancy = tenancy - return True def associate_vpc_cidr_block( self, - cidr_block, - amazon_provided_ipv6_cidr_block=False, - ipv6_cidr_block_network_border_group=None, - ): + cidr_block: str, + amazon_provided_ipv6_cidr_block: bool = False, + ipv6_cidr_block_network_border_group: Optional[str] = None, + ) -> Dict[str, Any]: max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1 for cidr in self.cidr_block_association_set.copy(): if ( - self.cidr_block_association_set.get(cidr) + self.cidr_block_association_set.get(cidr) # type: ignore[union-attr] .get("cidr_block_state") .get("state") == "disassociated" @@ -285,7 +305,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel): association_id = random_vpc_cidr_association_id() - association_set = { + association_set: Dict[str, Any] = { "association_id": association_id, "cidr_block_state": {"state": "associated", "StatusMessage": ""}, } @@ -301,7 +321,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel): self.cidr_block_association_set[association_id] = association_set return association_set - def enable_vpc_classic_link(self): + def enable_vpc_classic_link(self) -> str: # Check if current cidr block doesn't fall within the 10.0.0.0/8 block, excluding 10.0.0.0/16 and 10.1.0.0/16. # Doesn't check any route tables, maybe something for in the future? # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/vpc-classiclink.html#classiclink-limitations @@ -315,19 +335,19 @@ class VPC(TaggedEC2Resource, CloudFormationModel): return self.classic_link_enabled - def disable_vpc_classic_link(self): + def disable_vpc_classic_link(self) -> str: self.classic_link_enabled = "false" return self.classic_link_enabled - def enable_vpc_classic_link_dns_support(self): + def enable_vpc_classic_link_dns_support(self) -> str: self.classic_link_dns_supported = "true" return self.classic_link_dns_supported - def disable_vpc_classic_link_dns_support(self): + def disable_vpc_classic_link_dns_support(self) -> str: self.classic_link_dns_supported = "false" return self.classic_link_dns_supported - def disassociate_vpc_cidr_block(self, association_id): + def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]: if self.cidr_block == self.cidr_block_association_set.get( association_id, {} ).get("cidr_block"): @@ -341,7 +361,9 @@ class VPC(TaggedEC2Resource, CloudFormationModel): entry["cidr_block_state"]["state"] = "disassociated" return response - def get_cidr_block_association_set(self, ipv6=False): + def get_cidr_block_association_set( + self, ipv6: bool = False + ) -> List[Dict[str, Any]]: return [ c for c in self.cidr_block_association_set.values() @@ -350,14 +372,14 @@ class VPC(TaggedEC2Resource, CloudFormationModel): class VPCBackend: - vpc_refs = defaultdict(set) + vpc_refs = defaultdict(set) # type: ignore - def __init__(self): - self.vpcs = {} - self.vpc_end_points = {} + def __init__(self) -> None: + self.vpcs: Dict[str, VPC] = {} + self.vpc_end_points: Dict[str, VPCEndPoint] = {} self.vpc_refs[self.__class__].add(weakref.ref(self)) - def create_default_vpc(self): + def create_default_vpc(self) -> VPC: default_vpc = self.describe_vpcs(filters={"is-default": "true"}) if default_vpc: raise DefaultVpcAlreadyExists @@ -366,13 +388,13 @@ class VPCBackend: def create_vpc( self, - cidr_block, - instance_tenancy="default", - amazon_provided_ipv6_cidr_block=False, - ipv6_cidr_block_network_border_group=None, - tags=None, - is_default=False, - ): + cidr_block: str, + instance_tenancy: str = "default", + amazon_provided_ipv6_cidr_block: bool = False, + ipv6_cidr_block_network_border_group: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + is_default: bool = False, + ) -> VPC: vpc_id = random_vpc_id() try: vpc_cidr_block = ipaddress.IPv4Network(str(cidr_block), strict=False) @@ -391,45 +413,45 @@ class VPCBackend: ) for tag in tags or []: - tag_key = tag.get("Key") - tag_value = tag.get("Value") - vpc.add_tag(tag_key, tag_value) + vpc.add_tag(tag["Key"], tag["Value"]) self.vpcs[vpc_id] = vpc # AWS creates a default main route table and security group. - self.create_route_table(vpc_id, main=True) + self.create_route_table(vpc_id, main=True) # type: ignore[attr-defined] # AWS creates a default Network ACL - self.create_network_acl(vpc_id, default=True) + self.create_network_acl(vpc_id, default=True) # type: ignore[attr-defined] - default = self.get_security_group_from_name("default", vpc_id=vpc_id) + default = self.get_security_group_from_name("default", vpc_id=vpc_id) # type: ignore[attr-defined] if not default: - self.create_security_group( + self.create_security_group( # type: ignore[attr-defined] "default", "default VPC security group", vpc_id=vpc_id, is_default=True ) return vpc - def get_vpc(self, vpc_id): + def get_vpc(self, vpc_id: str) -> VPC: if vpc_id not in self.vpcs: raise InvalidVPCIdError(vpc_id) - return self.vpcs.get(vpc_id) + return self.vpcs[vpc_id] - def describe_vpcs(self, vpc_ids=None, filters=None): - matches = self.vpcs.copy().values() + def describe_vpcs( + self, vpc_ids: Optional[List[str]] = None, filters: Any = None + ) -> List[VPC]: + matches = list(self.vpcs.values()) if vpc_ids: matches = [vpc for vpc in matches if vpc.id in vpc_ids] if len(vpc_ids) > len(matches): - unknown_ids = set(vpc_ids) - set(matches) + unknown_ids = set(vpc_ids) - set(matches) # type: ignore[arg-type] raise InvalidVPCIdError(unknown_ids) if filters: matches = generic_filter(filters, matches) return matches - def delete_vpc(self, vpc_id): + def delete_vpc(self, vpc_id: str) -> VPC: # Do not delete if any VPN Gateway is attached - vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id}) + vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id}) # type: ignore[attr-defined] vpn_gateways = [ item for item in vpn_gateways @@ -441,18 +463,18 @@ class VPCBackend: ) # Delete route table if only main route table remains. - route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id}) + route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id}) # type: ignore[attr-defined] if len(route_tables) > 1: raise DependencyViolationError( f"The vpc {vpc_id} has dependencies and cannot be deleted." ) for route_table in route_tables: - self.delete_route_table(route_table.id) + self.delete_route_table(route_table.id) # type: ignore[attr-defined] # Delete default security group if exists. - default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) + default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) # type: ignore[attr-defined] if default: - self.delete_security_group(group_id=default.id) + self.delete_security_group(group_id=default.id) # type: ignore[attr-defined] # Now delete VPC. vpc = self.vpcs.pop(vpc_id, None) @@ -465,7 +487,7 @@ class VPCBackend: vpc.dhcp_options = None return vpc - def describe_vpc_attribute(self, vpc_id, attr_name): + def describe_vpc_attribute(self, vpc_id: str, attr_name: str) -> Any: vpc = self.get_vpc(vpc_id) if attr_name in ( "enable_dns_support", @@ -476,27 +498,29 @@ class VPCBackend: else: raise InvalidParameterValueError(attr_name) - def modify_vpc_tenancy(self, vpc_id, tenancy): + def modify_vpc_tenancy(self, vpc_id: str, tenancy: str) -> None: vpc = self.get_vpc(vpc_id) - return vpc.modify_vpc_tenancy(tenancy) + vpc.modify_vpc_tenancy(tenancy) - def enable_vpc_classic_link(self, vpc_id): + def enable_vpc_classic_link(self, vpc_id: str) -> str: vpc = self.get_vpc(vpc_id) return vpc.enable_vpc_classic_link() - def disable_vpc_classic_link(self, vpc_id): + def disable_vpc_classic_link(self, vpc_id: str) -> str: vpc = self.get_vpc(vpc_id) return vpc.disable_vpc_classic_link() - def enable_vpc_classic_link_dns_support(self, vpc_id): + def enable_vpc_classic_link_dns_support(self, vpc_id: str) -> str: vpc = self.get_vpc(vpc_id) return vpc.enable_vpc_classic_link_dns_support() - def disable_vpc_classic_link_dns_support(self, vpc_id): + def disable_vpc_classic_link_dns_support(self, vpc_id: str) -> str: vpc = self.get_vpc(vpc_id) return vpc.disable_vpc_classic_link_dns_support() - def modify_vpc_attribute(self, vpc_id, attr_name, attr_value): + def modify_vpc_attribute( + self, vpc_id: str, attr_name: str, attr_value: str + ) -> None: vpc = self.get_vpc(vpc_id) if attr_name in ( "enable_dns_support", @@ -507,58 +531,58 @@ class VPCBackend: else: raise InvalidParameterValueError(attr_name) - def disassociate_vpc_cidr_block(self, association_id): + def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]: for vpc in self.vpcs.copy().values(): response = vpc.disassociate_vpc_cidr_block(association_id) - for route_table in self.route_tables.copy().values(): + for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined] if route_table.vpc_id == response.get("vpc_id"): - if "::/" in response.get("cidr_block"): - self.delete_route( + if "::/" in response.get("cidr_block"): # type: ignore[operator] + self.delete_route( # type: ignore[attr-defined] route_table.id, None, response.get("cidr_block") ) else: - self.delete_route(route_table.id, response.get("cidr_block")) + self.delete_route(route_table.id, response.get("cidr_block")) # type: ignore[attr-defined] if response: return response raise InvalidVpcCidrBlockAssociationIdError(association_id) def associate_vpc_cidr_block( - self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block - ): + self, vpc_id: str, cidr_block: str, amazon_provided_ipv6_cidr_block: bool + ) -> Dict[str, Any]: vpc = self.get_vpc(vpc_id) association_set = vpc.associate_vpc_cidr_block( cidr_block, amazon_provided_ipv6_cidr_block ) - for route_table in self.route_tables.copy().values(): + for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined] if route_table.vpc_id == vpc_id: if amazon_provided_ipv6_cidr_block: - self.create_route( + self.create_route( # type: ignore[attr-defined] route_table.id, None, destination_ipv6_cidr_block=association_set["cidr_block"], local=True, ) else: - self.create_route( + self.create_route( # type: ignore[attr-defined] route_table.id, association_set["cidr_block"], local=True ) return association_set def create_vpc_endpoint( self, - vpc_id, - service_name, - endpoint_type=None, - policy_document=False, - route_table_ids=None, - subnet_ids=None, - network_interface_ids=None, - dns_entries=None, - client_token=None, - security_group_ids=None, - tags=None, - private_dns_enabled=None, - ): + vpc_id: str, + service_name: str, + endpoint_type: Optional[str], + policy_document: Optional[str], + route_table_ids: List[str], + subnet_ids: Optional[List[str]] = None, + network_interface_ids: Optional[List[str]] = None, + dns_entries: Optional[Dict[str, str]] = None, + client_token: Optional[str] = None, + security_group_ids: Optional[List[str]] = None, + tags: Optional[Dict[str, str]] = None, + private_dns_enabled: Optional[str] = None, + ) -> VPCEndPoint: vpc_endpoint_id = random_vpc_ep_id() @@ -570,21 +594,18 @@ class VPCBackend: network_interface_ids = [] for subnet_id in subnet_ids or []: - self.get_subnet(subnet_id) - eni = self.create_network_interface(subnet_id, random_private_ip()) + self.get_subnet(subnet_id) # type: ignore[attr-defined] + eni = self.create_network_interface(subnet_id, random_private_ip()) # type: ignore[attr-defined] network_interface_ids.append(eni.id) dns_entries = create_dns_entries(service_name, vpc_endpoint_id) else: # considering gateway if type is not mentioned. - for prefix_list in self.managed_prefix_lists.values(): + for prefix_list in self.managed_prefix_lists.values(): # type: ignore[attr-defined] if prefix_list.prefix_list_name == service_name: destination_prefix_list_id = prefix_list.id - if dns_entries: - dns_entries = [dns_entries] - vpc_end_point = VPCEndPoint( self, vpc_endpoint_id, @@ -595,19 +616,19 @@ class VPCBackend: route_table_ids, subnet_ids, network_interface_ids, - dns_entries, - client_token, - security_group_ids, - tags, - private_dns_enabled, - destination_prefix_list_id, + dns_entries=[dns_entries] if dns_entries else None, + client_token=client_token, + security_group_ids=security_group_ids, + tags=tags, + private_dns_enabled=private_dns_enabled, + destination_prefix_list_id=destination_prefix_list_id, ) self.vpc_end_points[vpc_endpoint_id] = vpc_end_point if destination_prefix_list_id: for route_table_id in route_table_ids: - self.create_route( + self.create_route( # type: ignore[attr-defined] route_table_id, None, gateway_id=vpc_endpoint_id, @@ -617,28 +638,34 @@ class VPCBackend: return vpc_end_point def modify_vpc_endpoint( - self, vpc_id, policy_doc, add_subnets, remove_route_tables, add_route_tables - ): + self, + vpc_id: str, + policy_doc: str, + add_subnets: Optional[List[str]], + remove_route_tables: Optional[List[str]], + add_route_tables: Optional[List[str]], + ) -> None: endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0] endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables) - def delete_vpc_endpoints(self, vpce_ids=None): + def delete_vpc_endpoints(self, vpce_ids: Optional[List[str]] = None) -> None: for vpce_id in vpce_ids or []: vpc_endpoint = self.vpc_end_points.get(vpce_id, None) if vpc_endpoint: - if vpc_endpoint.endpoint_type.lower() == "interface": + if vpc_endpoint.endpoint_type.lower() == "interface": # type: ignore[union-attr] for eni_id in vpc_endpoint.network_interface_ids: - self.enis.pop(eni_id, None) + self.enis.pop(eni_id, None) # type: ignore[attr-defined] else: for route_table_id in vpc_endpoint.route_table_ids: - self.delete_route( + self.delete_route( # type: ignore[attr-defined] route_table_id, vpc_endpoint.destination_prefix_list_id ) vpc_endpoint.state = "deleted" - return True - def describe_vpc_endpoints(self, vpc_end_point_ids, filters=None): - vpc_end_points = self.vpc_end_points.values() + def describe_vpc_endpoints( + self, vpc_end_point_ids: Optional[List[str]], filters: Any = None + ) -> List[VPCEndPoint]: + vpc_end_points = list(self.vpc_end_points.values()) if vpc_end_point_ids: vpc_end_points = [ @@ -657,7 +684,9 @@ class VPCBackend: return generic_filter(filters, vpc_end_points) @staticmethod - def _collect_default_endpoint_services(account_id, region): + def _collect_default_endpoint_services( + account_id: str, region: str + ) -> List[Dict[str, str]]: """Return list of default services using list of backends.""" if DEFAULT_VPC_ENDPOINT_SERVICES: return DEFAULT_VPC_ENDPOINT_SERVICES @@ -672,14 +701,16 @@ class VPCBackend: from moto import backends # pylint: disable=import-outside-toplevel for _backends in backends.service_backends(): - _backends = _backends[account_id] - if region in _backends: - service = _backends[region].default_vpc_endpoint_service(region, zones) + account_backend = _backends[account_id] + if region in account_backend: + service = account_backend[region].default_vpc_endpoint_service( + region, zones + ) if service: DEFAULT_VPC_ENDPOINT_SERVICES.extend(service) - if "global" in _backends: - service = _backends["global"].default_vpc_endpoint_service( + if "global" in account_backend: + service = account_backend["global"].default_vpc_endpoint_service( region, zones ) if service: @@ -687,7 +718,7 @@ class VPCBackend: return DEFAULT_VPC_ENDPOINT_SERVICES @staticmethod - def _matches_service_by_tags(service, filter_item): + def _matches_service_by_tags(service: Dict[str, Any], filter_item: Dict[str, Any]) -> bool: # type: ignore[misc] """Return True if service tags are not filtered by their tags. Note that the API specifies a key of "Values" for a filter, but @@ -719,7 +750,7 @@ class VPCBackend: return matched @staticmethod - def _filter_endpoint_services(service_names_filters, filters, services): + def _filter_endpoint_services(service_names_filters: List[str], filters: List[Dict[str, Any]], services: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # type: ignore[misc] """Return filtered list of VPC endpoint services.""" if not service_names_filters and not filters: return services @@ -774,11 +805,16 @@ class VPCBackend: return filtered_services def describe_vpc_endpoint_services( - self, dry_run, service_names, filters, max_results, next_token, region - ): # pylint: disable=unused-argument,too-many-arguments + self, + service_names: List[str], + filters: Any, + max_results: int, + next_token: Optional[str], + region: str, + ) -> Dict[str, Any]: # pylint: disable=too-many-arguments """Return info on services to which you can create a VPC endpoint. - Currently only the default endpoing services are returned. When + Currently only the default endpoint services are returned. When create_vpc_endpoint_service_configuration() is implemented, a list of those private endpoints would be kept and when this API is invoked, those private endpoints would be added to the list of @@ -787,7 +823,7 @@ class VPCBackend: The DryRun parameter is ignored. """ default_services = self._collect_default_endpoint_services( - self.account_id, region + self.account_id, region # type: ignore[attr-defined] ) for service_name in service_names: if service_name not in [x["ServiceName"] for x in default_services]: @@ -827,7 +863,7 @@ class VPCBackend: "nextToken": next_token, } - def get_vpc_end_point(self, vpc_end_point_id): + def get_vpc_end_point(self, vpc_end_point_id: str) -> VPCEndPoint: vpc_end_point = self.vpc_end_points.get(vpc_end_point_id) if not vpc_end_point: raise InvalidVpcEndPointIdError(vpc_end_point_id) diff --git a/moto/ec2/models/vpn_connections.py b/moto/ec2/models/vpn_connections.py index 18f8d7b85..82c4ca7dc 100644 --- a/moto/ec2/models/vpn_connections.py +++ b/moto/ec2/models/vpn_connections.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, List, Optional from .core import TaggedEC2Resource from ..exceptions import InvalidVpnConnectionIdError from ..utils import generic_filter, random_vpn_connection_id @@ -6,18 +7,18 @@ from ..utils import generic_filter, random_vpn_connection_id class VPNConnection(TaggedEC2Resource): def __init__( self, - ec2_backend, - vpn_connection_id, - vpn_conn_type, - customer_gateway_id, - vpn_gateway_id=None, - transit_gateway_id=None, - tags=None, + ec2_backend: Any, + vpn_connection_id: str, + vpn_conn_type: str, + customer_gateway_id: str, + vpn_gateway_id: Optional[str] = None, + transit_gateway_id: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, ): self.ec2_backend = ec2_backend self.id = vpn_connection_id self.state = "available" - self.customer_gateway_configuration = {} + self.customer_gateway_configuration: Dict[str, str] = {} self.type = vpn_conn_type self.customer_gateway_id = customer_gateway_id self.vpn_gateway_id = vpn_gateway_id @@ -27,23 +28,25 @@ class VPNConnection(TaggedEC2Resource): self.static_routes = None self.add_tags(tags or {}) - def get_filter_value(self, filter_name): + def get_filter_value( + self, filter_name: str, method_name: Optional[str] = None + ) -> Any: return super().get_filter_value(filter_name, "DescribeVpnConnections") class VPNConnectionBackend: - def __init__(self): - self.vpn_connections = {} + def __init__(self) -> None: + self.vpn_connections: Dict[str, VPNConnection] = {} def create_vpn_connection( self, - vpn_conn_type, - customer_gateway_id, - vpn_gateway_id=None, - transit_gateway_id=None, - static_routes_only=None, - tags=None, - ): + vpn_conn_type: str, + customer_gateway_id: str, + vpn_gateway_id: Optional[str] = None, + transit_gateway_id: Optional[str] = None, + static_routes_only: Optional[bool] = None, + tags: Optional[Dict[str, str]] = None, + ) -> VPNConnection: vpn_connection_id = random_vpn_connection_id() if static_routes_only: pass @@ -59,7 +62,7 @@ class VPNConnectionBackend: self.vpn_connections[vpn_connection.id] = vpn_connection return vpn_connection - def delete_vpn_connection(self, vpn_connection_id): + def delete_vpn_connection(self, vpn_connection_id: str) -> VPNConnection: if vpn_connection_id in self.vpn_connections: self.vpn_connections[vpn_connection_id].state = "deleted" @@ -67,17 +70,10 @@ class VPNConnectionBackend: raise InvalidVpnConnectionIdError(vpn_connection_id) return self.vpn_connections[vpn_connection_id] - def describe_vpn_connections(self, vpn_connection_ids=None): - vpn_connections = [] - for vpn_connection_id in vpn_connection_ids or []: - if vpn_connection_id in self.vpn_connections: - vpn_connections.append(self.vpn_connections[vpn_connection_id]) - else: - raise InvalidVpnConnectionIdError(vpn_connection_id) - return vpn_connections or self.vpn_connections.values() - - def get_all_vpn_connections(self, vpn_connection_ids=None, filters=None): - vpn_connections = self.vpn_connections.values() + def describe_vpn_connections( + self, vpn_connection_ids: Optional[List[str]] = None, filters: Any = None + ) -> List[VPNConnection]: + vpn_connections = list(self.vpn_connections.values()) if vpn_connection_ids: vpn_connections = [ diff --git a/moto/ec2/models/vpn_gateway.py b/moto/ec2/models/vpn_gateway.py index 47040829b..18068648c 100644 --- a/moto/ec2/models/vpn_gateway.py +++ b/moto/ec2/models/vpn_gateway.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, List, Optional from moto.core import CloudFormationModel from .core import TaggedEC2Resource from ..exceptions import InvalidVpnGatewayIdError, InvalidVpnGatewayAttachmentError @@ -15,18 +15,23 @@ class VPCGatewayAttachment(CloudFormationModel): self.state = state @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html return "AWS::EC2::VPCGatewayAttachment" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any + ) -> "VPCGatewayAttachment": from ..models import ec2_backends properties = cloudformation_json["Properties"] @@ -45,20 +50,20 @@ class VPCGatewayAttachment(CloudFormationModel): return attachment @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.vpc_id class VpnGateway(CloudFormationModel, TaggedEC2Resource): def __init__( self, - ec2_backend, - gateway_id, - gateway_type, - amazon_side_asn, - availability_zone, - tags=None, - state="available", + ec2_backend: Any, + gateway_id: str, + gateway_type: str, + amazon_side_asn: Optional[str], + availability_zone: Optional[str], + tags: Optional[Dict[str, str]] = None, + state: str = "available", ): self.ec2_backend = ec2_backend self.id = gateway_id @@ -67,22 +72,27 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource): self.availability_zone = availability_zone self.state = state self.add_tags(tags or {}) - self.attachments = {} + self.attachments: Dict[str, VPCGatewayAttachment] = {} super().__init__() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html return "AWS::EC2::VPNGateway" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any + ) -> "VpnGateway": from ..models import ec2_backends properties = cloudformation_json["Properties"] @@ -93,10 +103,12 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource): return ec2_backend.create_vpn_gateway(gateway_type=_type, amazon_side_asn=asn) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id - def get_filter_value(self, filter_name): + def get_filter_value( + self, filter_name: str, method_name: Optional[str] = None + ) -> Any: if filter_name == "attachment.vpc-id": return self.attachments.keys() elif filter_name == "attachment.state": @@ -109,16 +121,16 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource): class VpnGatewayBackend: - def __init__(self): - self.vpn_gateways = {} + def __init__(self) -> None: + self.vpn_gateways: Dict[str, VpnGateway] = {} def create_vpn_gateway( self, - gateway_type="ipsec.1", - amazon_side_asn=None, - availability_zone=None, - tags=None, - ): + gateway_type: str = "ipsec.1", + amazon_side_asn: Optional[str] = None, + availability_zone: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + ) -> VpnGateway: vpn_gateway_id = random_vpn_gateway_id() vpn_gateway = VpnGateway( self, vpn_gateway_id, gateway_type, amazon_side_asn, availability_zone, tags @@ -126,21 +138,25 @@ class VpnGatewayBackend: self.vpn_gateways[vpn_gateway_id] = vpn_gateway return vpn_gateway - def describe_vpn_gateways(self, filters=None, vpn_gw_ids=None): + def describe_vpn_gateways( + self, filters: Any = None, vpn_gw_ids: Optional[List[str]] = None + ) -> List[VpnGateway]: vpn_gateways = list(self.vpn_gateways.values() or []) if vpn_gw_ids: vpn_gateways = [item for item in vpn_gateways if item.id in vpn_gw_ids] return generic_filter(filters, vpn_gateways) - def get_vpn_gateway(self, vpn_gateway_id): + def get_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway: vpn_gateway = self.vpn_gateways.get(vpn_gateway_id, None) if not vpn_gateway: raise InvalidVpnGatewayIdError(vpn_gateway_id) return vpn_gateway - def attach_vpn_gateway(self, vpn_gateway_id, vpc_id): + def attach_vpn_gateway( + self, vpn_gateway_id: str, vpc_id: str + ) -> VPCGatewayAttachment: vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) - self.get_vpc(vpc_id) + self.get_vpc(vpc_id) # type: ignore[attr-defined] attachment = VPCGatewayAttachment(vpc_id, state="attached") for key in vpn_gateway.attachments.copy(): if key.startswith("vpc-"): @@ -148,14 +164,16 @@ class VpnGatewayBackend: vpn_gateway.attachments[vpc_id] = attachment return attachment - def delete_vpn_gateway(self, vpn_gateway_id): + def delete_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway: deleted = self.vpn_gateways.get(vpn_gateway_id, None) if not deleted: raise InvalidVpnGatewayIdError(vpn_gateway_id) deleted.state = "deleted" return deleted - def detach_vpn_gateway(self, vpn_gateway_id, vpc_id): + def detach_vpn_gateway( + self, vpn_gateway_id: str, vpc_id: str + ) -> VPCGatewayAttachment: vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) detached = vpn_gateway.attachments.get(vpc_id, None) if not detached: diff --git a/moto/ec2/models/windows.py b/moto/ec2/models/windows.py index 5021bf49b..0951a09c8 100644 --- a/moto/ec2/models/windows.py +++ b/moto/ec2/models/windows.py @@ -5,7 +5,7 @@ from moto.core import BaseModel class WindowsBackend(BaseModel): def get_password_data(self, instance_id: str) -> str: - instance = self.get_instance(instance_id) + instance = self.get_instance(instance_id) # type: ignore[attr-defined] if instance.platform == "windows": return random.get_random_string(length=128) return "" diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 87eb46974..5e4bc5d55 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -65,9 +65,8 @@ class VPCs(EC2BaseResponse): def modify_vpc_tenancy(self): vpc_id = self._get_param("VpcId") tenancy = self._get_param("InstanceTenancy") - value = self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy) - template = self.response_template(MODIFY_VPC_TENANCY_RESPONSE) - return template.render(value=value) + self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy) + return self.response_template(MODIFY_VPC_TENANCY_RESPONSE).render() def describe_vpc_attribute(self): vpc_id = self._get_param("VpcId") @@ -237,7 +236,6 @@ class VPCs(EC2BaseResponse): def describe_vpc_endpoint_services(self): vpc_end_point_services = self.ec2_backend.describe_vpc_endpoint_services( - dry_run=self._get_bool_param("DryRun"), service_names=self._get_multi_param("ServiceName"), filters=self._get_multi_param("Filter"), max_results=self._get_int_param("MaxResults"), @@ -260,9 +258,8 @@ class VPCs(EC2BaseResponse): def delete_vpc_endpoints(self): vpc_end_points_ids = self._get_multi_param("VpcEndpointId") - response = self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids) - template = self.response_template(DELETE_VPC_ENDPOINT_RESPONSE) - return template.render(response=response) + self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids) + return self.response_template(DELETE_VPC_ENDPOINT_RESPONSE).render() def create_managed_prefix_list(self): address_family = self._get_param("AddressFamily") @@ -767,7 +764,7 @@ DESCRIBE_VPC_ENDPOINT_RESPONSE = """ 19a9ff46-7df6-49b8-9726-3df27527089d - {{ 'Error' if not response else '' }} + """ diff --git a/moto/ec2/responses/vpn_connections.py b/moto/ec2/responses/vpn_connections.py index edf268c77..e2857b02b 100644 --- a/moto/ec2/responses/vpn_connections.py +++ b/moto/ec2/responses/vpn_connections.py @@ -42,7 +42,7 @@ class VPNConnections(EC2BaseResponse): def describe_vpn_connections(self): vpn_connection_ids = self._get_multi_param("VpnConnectionId") filters = self._filters_from_querystring() - vpn_connections = self.ec2_backend.get_all_vpn_connections( + vpn_connections = self.ec2_backend.describe_vpn_connections( vpn_connection_ids=vpn_connection_ids, filters=filters ) template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE) diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index 3c724732a..dc4c95eb7 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -134,11 +134,11 @@ def random_network_acl_subnet_association_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"]) -def random_vpn_gateway_id(): +def random_vpn_gateway_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"]) -def random_vpn_connection_id(): +def random_vpn_connection_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"]) @@ -154,19 +154,19 @@ def random_key_pair_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["key-pair"]) -def random_vpc_id(): +def random_vpc_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"]) -def random_vpc_ep_id(): +def random_vpc_ep_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-endpoint"], size=8) -def random_vpc_cidr_association_id(): +def random_vpc_cidr_association_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"]) -def random_vpc_peering_connection_id(): +def random_vpc_peering_connection_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]) @@ -272,11 +272,11 @@ def random_mac_address() -> str: return f"02:00:00:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x" -def randor_ipv4_cidr(): +def randor_ipv4_cidr() -> str: return f"10.0.{random.randint(0, 255)}.{random.randint(0, 255)}/16" -def random_ipv6_cidr(): +def random_ipv6_cidr() -> str: return f"2400:6500:{random_resource_id(4)}:{random_resource_id(2)}00::/56" @@ -297,13 +297,11 @@ def random_managed_prefix_list_id() -> str: return random_id(prefix=EC2_RESOURCE_TO_PREFIX["managed-prefix-list"], size=8) -def create_dns_entries(service_name, vpc_endpoint_id): - dns_entries = {} - dns_entries[ - "dns_name" - ] = f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}" - dns_entries["hosted_zone_id"] = random_resource_id(13).upper() - return dns_entries +def create_dns_entries(service_name: str, vpc_endpoint_id: str) -> Dict[str, str]: + return { + "dns_name": f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}", + "hosted_zone_id": random_resource_id(13).upper(), + } def utc_date_and_time() -> str: @@ -589,7 +587,7 @@ def get_prefix(resource_id: str) -> str: return resource_id_prefix -def is_valid_resource_id(resource_id): +def is_valid_resource_id(resource_id: str) -> bool: valid_prefixes = EC2_RESOURCE_TO_PREFIX.values() resource_id_prefix = get_prefix(resource_id) if resource_id_prefix not in valid_prefixes: diff --git a/setup.cfg b/setup.cfg index ebedb8289..80aaa293b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -229,7 +229,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/a*,moto/ec2/models/c*,moto/ec2/models/d*,moto/ec2/models/e*,moto/ec2/models/f*,moto/ec2/models/h*,moto/ec2/models/i*,moto/ec2/models/k*,moto/ec2/models/l*,moto/ec2/models/m*,moto/ec2/models/n*,moto/ec2/models/r*,moto/ec2/models/s*,moto/ec2/models/t*,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/**/*.py,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract