Techdebt: MyPy EC2 models (#5925)

This commit is contained in:
Bert Blommers 2023-02-13 19:28:15 -01:00 committed by GitHub
parent 859114114c
commit c50737d027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 422 additions and 341 deletions

View File

@ -30,7 +30,7 @@ class EC2ClientError(RESTError):
class DefaultVpcAlreadyExists(EC2ClientError): class DefaultVpcAlreadyExists(EC2ClientError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"DefaultVpcAlreadyExists", "DefaultVpcAlreadyExists",
"A Default VPC already exists for this account in this region.", "A Default VPC already exists for this account in this region.",
@ -59,7 +59,7 @@ class InvalidDHCPOptionsIdError(EC2ClientError):
class InvalidRequest(EC2ClientError): class InvalidRequest(EC2ClientError):
def __init__(self): def __init__(self) -> None:
super().__init__("InvalidRequest", "The request received was invalid") super().__init__("InvalidRequest", "The request received was invalid")
@ -98,7 +98,7 @@ class InvalidKeyPairFormatError(EC2ClientError):
class InvalidVPCIdError(EC2ClientError): class InvalidVPCIdError(EC2ClientError):
def __init__(self, vpc_id: str): def __init__(self, vpc_id: Any):
super().__init__("InvalidVpcID.NotFound", f"VpcID {vpc_id} does not exist.") super().__init__("InvalidVpcID.NotFound", f"VpcID {vpc_id} does not exist.")
@ -134,7 +134,7 @@ class InvalidNetworkAclIdError(EC2ClientError):
class InvalidVpnGatewayIdError(EC2ClientError): class InvalidVpnGatewayIdError(EC2ClientError):
def __init__(self, vpn_gw): def __init__(self, vpn_gw: str):
super().__init__( super().__init__(
"InvalidVpnGatewayID.NotFound", "InvalidVpnGatewayID.NotFound",
f"The virtual private gateway ID '{vpn_gw}' does not exist", f"The virtual private gateway ID '{vpn_gw}' does not exist",
@ -142,7 +142,7 @@ class InvalidVpnGatewayIdError(EC2ClientError):
class InvalidVpnGatewayAttachmentError(EC2ClientError): class InvalidVpnGatewayAttachmentError(EC2ClientError):
def __init__(self, vpn_gw, vpc_id): def __init__(self, vpn_gw: str, vpc_id: str):
super().__init__( super().__init__(
"InvalidVpnGatewayAttachment.NotFound", "InvalidVpnGatewayAttachment.NotFound",
f"The attachment with vpn gateway ID '{vpn_gw}' and vpc ID '{vpc_id}' does not exist", f"The attachment with vpn gateway ID '{vpn_gw}' and vpc ID '{vpc_id}' does not exist",
@ -150,7 +150,7 @@ class InvalidVpnGatewayAttachmentError(EC2ClientError):
class InvalidVpnConnectionIdError(EC2ClientError): class InvalidVpnConnectionIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id: str):
super().__init__( super().__init__(
"InvalidVpnConnectionID.NotFound", "InvalidVpnConnectionID.NotFound",
f"The vpnConnection ID '{network_acl_id}' does not exist", f"The vpnConnection ID '{network_acl_id}' does not exist",
@ -365,7 +365,7 @@ class InvalidAssociationIdError(EC2ClientError):
class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id: str):
super().__init__( super().__init__(
"InvalidVpcCidrBlockAssociationIdError.NotFound", "InvalidVpcCidrBlockAssociationIdError.NotFound",
f"The vpc CIDR block association ID '{association_id}' does not exist", f"The vpc CIDR block association ID '{association_id}' does not exist",
@ -373,7 +373,7 @@ class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
class InvalidVPCPeeringConnectionIdError(EC2ClientError): class InvalidVPCPeeringConnectionIdError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id: str):
super().__init__( super().__init__(
"InvalidVpcPeeringConnectionId.NotFound", "InvalidVpcPeeringConnectionId.NotFound",
f"VpcPeeringConnectionID {vpc_peering_connection_id} does not exist.", f"VpcPeeringConnectionID {vpc_peering_connection_id} does not exist.",
@ -381,7 +381,7 @@ class InvalidVPCPeeringConnectionIdError(EC2ClientError):
class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id: str):
super().__init__( super().__init__(
"InvalidStateTransition", "InvalidStateTransition",
f"VpcPeeringConnectionID {vpc_peering_connection_id} is not in the correct state for the request.", f"VpcPeeringConnectionID {vpc_peering_connection_id} is not in the correct state for the request.",
@ -389,7 +389,7 @@ class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
class InvalidServiceName(EC2ClientError): class InvalidServiceName(EC2ClientError):
def __init__(self, service_name): def __init__(self, service_name: str):
super().__init__( super().__init__(
"InvalidServiceName", "InvalidServiceName",
f"The Vpc Endpoint Service '{service_name}' does not exist", f"The Vpc Endpoint Service '{service_name}' does not exist",
@ -402,7 +402,7 @@ class InvalidFilter(EC2ClientError):
class InvalidNextToken(EC2ClientError): class InvalidNextToken(EC2ClientError):
def __init__(self, next_token): def __init__(self, next_token: str):
super().__init__("InvalidNextToken", f"The token '{next_token}' is invalid") super().__init__("InvalidNextToken", f"The token '{next_token}' is invalid")
@ -436,7 +436,7 @@ class InvalidParameterValueError(EC2ClientError):
class EmptyTagSpecError(EC2ClientError): class EmptyTagSpecError(EC2ClientError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidParameterValue", "Tag specification must have at least one tag" "InvalidParameterValue", "Tag specification must have at least one tag"
) )
@ -498,7 +498,7 @@ class TagLimitExceeded(EC2ClientError):
class InvalidID(EC2ClientError): class InvalidID(EC2ClientError):
def __init__(self, resource_id): def __init__(self, resource_id: str):
super().__init__("InvalidID", f"The ID '{resource_id}' is not valid") super().__init__("InvalidID", f"The ID '{resource_id}' is not valid")
@ -532,7 +532,7 @@ class FilterNotImplementedError(MotoNotImplementedError):
class CidrLimitExceeded(EC2ClientError): class CidrLimitExceeded(EC2ClientError):
def __init__(self, vpc_id, max_cidr_limit): def __init__(self, vpc_id: str, max_cidr_limit: int):
super().__init__( super().__init__(
"CidrLimitExceeded", "CidrLimitExceeded",
f"This network '{vpc_id}' has met its maximum number of allowed CIDRs: {max_cidr_limit}", f"This network '{vpc_id}' has met its maximum number of allowed CIDRs: {max_cidr_limit}",
@ -540,14 +540,14 @@ class CidrLimitExceeded(EC2ClientError):
class UnsupportedTenancy(EC2ClientError): class UnsupportedTenancy(EC2ClientError):
def __init__(self, tenancy): def __init__(self, tenancy: str):
super().__init__( super().__init__(
"UnsupportedTenancy", f"The tenancy value {tenancy} is not supported." "UnsupportedTenancy", f"The tenancy value {tenancy} is not supported."
) )
class OperationNotPermitted(EC2ClientError): class OperationNotPermitted(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id: str):
super().__init__( super().__init__(
"OperationNotPermitted", "OperationNotPermitted",
f"The vpc CIDR block with association ID {association_id} may not be disassociated. It is the primary IPv4 CIDR block of the VPC", f"The vpc CIDR block with association ID {association_id} may not be disassociated. It is the primary IPv4 CIDR block of the VPC",
@ -611,13 +611,13 @@ class InvalidSubnetConflictError(EC2ClientError):
class InvalidVPCRangeError(EC2ClientError): class InvalidVPCRangeError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block: str):
super().__init__("InvalidVpc.Range", f"The CIDR '{cidr_block}' is invalid.") super().__init__("InvalidVpc.Range", f"The CIDR '{cidr_block}' is invalid.")
# accept exception # accept exception
class OperationNotPermitted2(EC2ClientError): class OperationNotPermitted2(EC2ClientError):
def __init__(self, client_region, pcx_id, acceptor_region): def __init__(self, client_region: str, pcx_id: str, acceptor_region: str):
super().__init__( super().__init__(
"OperationNotPermitted", "OperationNotPermitted",
f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted in region {acceptor_region}", f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted in region {acceptor_region}",
@ -626,7 +626,7 @@ class OperationNotPermitted2(EC2ClientError):
# reject exception # reject exception
class OperationNotPermitted3(EC2ClientError): class OperationNotPermitted3(EC2ClientError):
def __init__(self, client_region, pcx_id, acceptor_region): def __init__(self, client_region: str, pcx_id: str, acceptor_region: str):
super().__init__( super().__init__(
"OperationNotPermitted", "OperationNotPermitted",
f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted or rejected in region {acceptor_region}", f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted or rejected in region {acceptor_region}",
@ -690,7 +690,7 @@ class InvalidAssociationIDIamProfileAssociationError(EC2ClientError):
class InvalidVpcEndPointIdError(EC2ClientError): class InvalidVpcEndPointIdError(EC2ClientError):
def __init__(self, vpc_end_point_id): def __init__(self, vpc_end_point_id: str):
super().__init__( super().__init__(
"InvalidVpcEndpointId.NotFound", "InvalidVpcEndpointId.NotFound",
f"The VpcEndPoint ID '{vpc_end_point_id}' does not exist", f"The VpcEndPoint ID '{vpc_end_point_id}' does not exist",
@ -730,7 +730,7 @@ class InvalidCarrierGatewayID(EC2ClientError):
class NoLoadBalancersProvided(EC2ClientError): class NoLoadBalancersProvided(EC2ClientError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidParameter", "InvalidParameter",
"exactly one of network_load_balancer_arn or gateway_load_balancer_arn is a required member", "exactly one of network_load_balancer_arn or gateway_load_balancer_arn is a required member",
@ -738,7 +738,7 @@ class NoLoadBalancersProvided(EC2ClientError):
class UnknownVpcEndpointService(EC2ClientError): class UnknownVpcEndpointService(EC2ClientError):
def __init__(self, service_id): def __init__(self, service_id: str):
super().__init__( super().__init__(
"InvalidVpcEndpointServiceId.NotFound", "InvalidVpcEndpointServiceId.NotFound",
f"The VpcEndpointService Id '{service_id}' does not exist", f"The VpcEndpointService Id '{service_id}' does not exist",

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict from moto.core import BaseBackend, BackendDict
from ..exceptions import ( from ..exceptions import (
EC2ClientError, EC2ClientError,
@ -49,7 +50,7 @@ from ..utils import (
) )
def validate_resource_ids(resource_ids): def validate_resource_ids(resource_ids: List[str]) -> bool:
if not resource_ids: if not resource_ids:
raise MissingParameterError(parameter="resourceIdSet") raise MissingParameterError(parameter="resourceIdSet")
for resource_id in resource_ids: for resource_id in resource_ids:
@ -59,19 +60,19 @@ def validate_resource_ids(resource_ids):
class SettingsBackend: class SettingsBackend:
def __init__(self): def __init__(self) -> None:
self.ebs_encryption_by_default = False self.ebs_encryption_by_default = False
def disable_ebs_encryption_by_default(self): def disable_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
ec2_backend.ebs_encryption_by_default = False ec2_backend.ebs_encryption_by_default = False
def enable_ebs_encryption_by_default(self): def enable_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
ec2_backend.ebs_encryption_by_default = True ec2_backend.ebs_encryption_by_default = True
def get_ebs_encryption_by_default(self): def get_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
return ec2_backend.ebs_encryption_by_default return ec2_backend.ebs_encryption_by_default
@ -129,11 +130,11 @@ class EC2Backend(
""" """
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
BaseBackend.__init__(self, region_name, account_id) BaseBackend.__init__(self, region_name, account_id)
for backend in EC2Backend.__mro__: for backend in EC2Backend.__mro__:
if backend not in [EC2Backend, BaseBackend, object]: if backend not in [EC2Backend, BaseBackend, object]:
backend.__init__(self) backend.__init__(self) # type: ignore
# Default VPC exists by default, which is the current behavior # Default VPC exists by default, which is the current behavior
# of EC2-VPC. See for detail: # of EC2-VPC. See for detail:
@ -145,23 +146,23 @@ class EC2Backend(
else: else:
# For now this is included for potential # For now this is included for potential
# backward-compatibility issues # backward-compatibility issues
vpc = self.vpcs.values()[0] vpc = list(self.vpcs.values())[0]
self.default_vpc = vpc self.default_vpc = vpc
# Create default subnet for each availability zone # Create default subnet for each availability zone
ip, _ = vpc.cidr_block.split("/") ip, _ = vpc.cidr_block.split("/")
ip = ip.split(".") ip = ip.split(".") # type: ignore
ip[2] = 0 ip[2] = 0 # type: ignore
for zone in self.describe_availability_zones(): for zone in self.describe_availability_zones():
az_name = zone.name az_name = zone.name
cidr_block = ".".join(str(i) for i in ip) + "/20" cidr_block = ".".join(str(i) for i in ip) + "/20"
self.create_subnet(vpc.id, cidr_block, availability_zone=az_name) self.create_subnet(vpc.id, cidr_block, availability_zone=az_name)
ip[2] += 16 ip[2] += 16 # type: ignore
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "ec2" service_region, zones, "ec2"
@ -171,13 +172,13 @@ class EC2Backend(
# Use this to generate a proper error template response when in a response # Use this to generate a proper error template response when in a response
# handler. # handler.
def raise_error(self, code, message): def raise_error(self, code: str, message: str) -> None:
raise EC2ClientError(code, message) raise EC2ClientError(code, message)
def raise_not_implemented_error(self, blurb: str): def raise_not_implemented_error(self, blurb: str) -> None:
raise MotoNotImplementedError(blurb) raise MotoNotImplementedError(blurb)
def do_resources_exist(self, resource_ids): def do_resources_exist(self, resource_ids: List[str]) -> bool:
for resource_id in resource_ids: for resource_id in resource_ids:
resource_prefix = get_prefix(resource_id) resource_prefix = get_prefix(resource_id)
if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]: if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]:

View File

@ -1,5 +1,6 @@
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from ..exceptions import ( from ..exceptions import (
InvalidVPCPeeringConnectionIdError, InvalidVPCPeeringConnectionIdError,
@ -8,6 +9,7 @@ from ..exceptions import (
OperationNotPermitted3, OperationNotPermitted3,
) )
from .core import TaggedEC2Resource from .core import TaggedEC2Resource
from .vpcs import VPC
from ..utils import random_vpc_peering_connection_id from ..utils import random_vpc_peering_connection_id
@ -44,7 +46,14 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
"AllowDnsResolutionFromRemoteVpc": "false", "AllowDnsResolutionFromRemoteVpc": "false",
} }
def __init__(self, backend, vpc_pcx_id, vpc, peer_vpc, tags=None): def __init__(
self,
backend: Any,
vpc_pcx_id: str,
vpc: VPC,
peer_vpc: VPC,
tags: Optional[Dict[str, str]] = None,
):
self.id = vpc_pcx_id self.id = vpc_pcx_id
self.ec2_backend = backend self.ec2_backend = backend
self.vpc = vpc self.vpc = vpc
@ -55,18 +64,23 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
self._status = PeeringConnectionStatus() self._status = PeeringConnectionStatus()
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcpeeringconnection.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcpeeringconnection.html
return "AWS::EC2::VPCPeeringConnection" return "AWS::EC2::VPCPeeringConnection"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VPCPeeringConnection":
from ..models import ec2_backends from ..models import ec2_backends
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -80,26 +94,28 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
return vpc_pcx return vpc_pcx
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
class VPCPeeringConnectionBackend: class VPCPeeringConnectionBackend:
# for cross region vpc reference # for cross region vpc reference
vpc_pcx_refs = defaultdict(set) vpc_pcx_refs = defaultdict(set) # type: ignore
def __init__(self): def __init__(self) -> None:
self.vpc_pcxs = {} self.vpc_pcxs: Dict[str, VPCPeeringConnection] = {}
self.vpc_pcx_refs[self.__class__].add(weakref.ref(self)) self.vpc_pcx_refs[self.__class__].add(weakref.ref(self))
@classmethod @classmethod
def get_vpc_pcx_refs(cls): def get_vpc_pcx_refs(cls) -> Iterator[VPCPeeringConnection]:
for inst_ref in cls.vpc_pcx_refs[cls]: for inst_ref in cls.vpc_pcx_refs[cls]:
inst = inst_ref() inst = inst_ref()
if inst is not None: if inst is not None:
yield inst yield inst
def create_vpc_peering_connection(self, vpc, peer_vpc, tags=None): def create_vpc_peering_connection(
self, vpc: VPC, peer_vpc: VPC, tags: Optional[Dict[str, str]] = None
) -> VPCPeeringConnection:
vpc_pcx_id = random_vpc_peering_connection_id() vpc_pcx_id = random_vpc_peering_connection_id()
vpc_pcx = VPCPeeringConnection(self, vpc_pcx_id, vpc, peer_vpc, tags) vpc_pcx = VPCPeeringConnection(self, vpc_pcx_id, vpc, peer_vpc, tags)
vpc_pcx._status.pending() vpc_pcx._status.pending()
@ -111,49 +127,54 @@ class VPCPeeringConnectionBackend:
vpc_pcx_cx.vpc_pcxs[vpc_pcx_id] = vpc_pcx vpc_pcx_cx.vpc_pcxs[vpc_pcx_id] = vpc_pcx
return vpc_pcx return vpc_pcx
def describe_vpc_peering_connections(self, vpc_peering_ids=None): def describe_vpc_peering_connections(
all_pcxs = self.vpc_pcxs.copy().values() self, vpc_peering_ids: Optional[List[str]] = None
) -> List[VPCPeeringConnection]:
all_pcxs = list(self.vpc_pcxs.values())
if vpc_peering_ids: if vpc_peering_ids:
return [pcx for pcx in all_pcxs if pcx.id in vpc_peering_ids] return [pcx for pcx in all_pcxs if pcx.id in vpc_peering_ids]
return all_pcxs return all_pcxs
def get_vpc_peering_connection(self, vpc_pcx_id): def get_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
if vpc_pcx_id not in self.vpc_pcxs: if vpc_pcx_id not in self.vpc_pcxs:
raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id) raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id)
return self.vpc_pcxs.get(vpc_pcx_id) return self.vpc_pcxs[vpc_pcx_id]
def delete_vpc_peering_connection(self, vpc_pcx_id): def delete_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
deleted = self.get_vpc_peering_connection(vpc_pcx_id) deleted = self.get_vpc_peering_connection(vpc_pcx_id)
deleted._status.deleted() deleted._status.deleted()
return deleted return deleted
def accept_vpc_peering_connection(self, vpc_pcx_id): def accept_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
# if cross region need accepter from another region # if cross region need accepter from another region
pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name
pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined]
raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined]
if vpc_pcx._status.code != "pending-acceptance": if vpc_pcx._status.code != "pending-acceptance":
raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id)
vpc_pcx._status.accept() vpc_pcx._status.accept()
return vpc_pcx return vpc_pcx
def reject_vpc_peering_connection(self, vpc_pcx_id): def reject_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
# if cross region need accepter from another region # if cross region need accepter from another region
pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name
pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined]
raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined]
if vpc_pcx._status.code != "pending-acceptance": if vpc_pcx._status.code != "pending-acceptance":
raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id)
vpc_pcx._status.reject() vpc_pcx._status.reject()
return vpc_pcx return vpc_pcx
def modify_vpc_peering_connection_options( def modify_vpc_peering_connection_options(
self, vpc_pcx_id, accepter_options=None, requester_options=None self,
): vpc_pcx_id: str,
accepter_options: Optional[Dict[str, Any]] = None,
requester_options: Optional[Dict[str, Any]] = None,
) -> None:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id) vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
if not vpc_pcx: if not vpc_pcx:
raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id) raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from .core import TaggedEC2Resource from .core import TaggedEC2Resource
@ -6,7 +7,12 @@ from ..exceptions import UnknownVpcEndpointService
class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel): class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel):
def __init__( def __init__(
self, load_balancers, region, acceptance_required, private_dns_name, ec2_backend self,
load_balancers: List[Any],
region: str,
acceptance_required: bool,
private_dns_name: str,
ec2_backend: Any,
): ):
self.id = f"vpce-svc-{mock_random.get_random_hex(length=8)}" self.id = f"vpce-svc-{mock_random.get_random_hex(length=8)}"
self.service_name = f"com.amazonaws.vpce.{region}.{self.id}" self.service_name = f"com.amazonaws.vpce.{region}.{self.id}"
@ -32,43 +38,49 @@ class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel):
self.private_dns_name = private_dns_name self.private_dns_name = private_dns_name
self.endpoint_dns_name = f"{self.id}.{region}.vpce.amazonaws.com" self.endpoint_dns_name = f"{self.id}.{region}.vpce.amazonaws.com"
self.principals = [] self.principals: List[str] = []
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
class VPCServiceConfigurationBackend: class VPCServiceConfigurationBackend:
def __init__(self): def __init__(self) -> None:
self.configurations = {} self.configurations: Dict[str, VPCServiceConfiguration] = {}
@property @property
def elbv2_backend(self): def elbv2_backend(self) -> Any: # type: ignore[misc]
from moto.elbv2.models import elbv2_backends from moto.elbv2.models import elbv2_backends
return elbv2_backends[self.account_id][self.region_name] return elbv2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
def get_vpc_endpoint_service(self, resource_id): def get_vpc_endpoint_service(
self, resource_id: str
) -> Optional[VPCServiceConfiguration]:
return self.configurations.get(resource_id) return self.configurations.get(resource_id)
def create_vpc_endpoint_service_configuration( def create_vpc_endpoint_service_configuration(
self, lb_arns, acceptance_required, private_dns_name, tags self,
): lb_arns: List[Any],
acceptance_required: bool,
private_dns_name: str,
tags: List[Dict[str, str]],
) -> VPCServiceConfiguration:
lbs = self.elbv2_backend.describe_load_balancers(arns=lb_arns, names=None) lbs = self.elbv2_backend.describe_load_balancers(arns=lb_arns, names=None)
config = VPCServiceConfiguration( config = VPCServiceConfiguration(
load_balancers=lbs, load_balancers=lbs,
region=self.region_name, region=self.region_name, # type: ignore[attr-defined]
acceptance_required=acceptance_required, acceptance_required=acceptance_required,
private_dns_name=private_dns_name, private_dns_name=private_dns_name,
ec2_backend=self, ec2_backend=self,
) )
for tag in tags or []: for tag in tags or []:
tag_key = tag.get("Key") config.add_tag(tag["Key"], tag["Value"])
tag_value = tag.get("Value")
config.add_tag(tag_key, tag_value)
self.configurations[config.id] = config self.configurations[config.id] = config
return config return config
def describe_vpc_endpoint_service_configurations(self, service_ids): def describe_vpc_endpoint_service_configurations(
self, service_ids: Optional[List[str]]
) -> List[VPCServiceConfiguration]:
""" """
The Filters, MaxResults, NextToken parameters are not yet implemented The Filters, MaxResults, NextToken parameters are not yet implemented
""" """
@ -80,15 +92,17 @@ class VPCServiceConfigurationBackend:
else: else:
raise UnknownVpcEndpointService(service_id) raise UnknownVpcEndpointService(service_id)
return found_configs return found_configs
return self.configurations.copy().values() return list(self.configurations.values())
def delete_vpc_endpoint_service_configurations(self, service_ids): def delete_vpc_endpoint_service_configurations(
self, service_ids: List[str]
) -> List[str]:
missing = [s for s in service_ids if s not in self.configurations] missing = [s for s in service_ids if s not in self.configurations]
for s in service_ids: for s in service_ids:
self.configurations.pop(s, None) self.configurations.pop(s, None)
return missing return missing
def describe_vpc_endpoint_service_permissions(self, service_id): def describe_vpc_endpoint_service_permissions(self, service_id: str) -> List[str]:
""" """
The Filters, MaxResults, NextToken parameters are not yet implemented The Filters, MaxResults, NextToken parameters are not yet implemented
""" """
@ -96,8 +110,8 @@ class VPCServiceConfigurationBackend:
return config.principals return config.principals
def modify_vpc_endpoint_service_permissions( def modify_vpc_endpoint_service_permissions(
self, service_id, add_principals, remove_principals self, service_id: str, add_principals: List[str], remove_principals: List[str]
): ) -> None:
config = self.describe_vpc_endpoint_service_configurations([service_id])[0] config = self.describe_vpc_endpoint_service_configurations([service_id])[0]
config.principals += add_principals config.principals += add_principals
config.principals = [p for p in config.principals if p not in remove_principals] config.principals = [p for p in config.principals if p not in remove_principals]
@ -105,14 +119,14 @@ class VPCServiceConfigurationBackend:
def modify_vpc_endpoint_service_configuration( def modify_vpc_endpoint_service_configuration(
self, self,
service_id, service_id: str,
acceptance_required, acceptance_required: Optional[str],
private_dns_name, private_dns_name: Optional[str],
add_network_lbs, add_network_lbs: List[str],
remove_network_lbs, remove_network_lbs: List[str],
add_gateway_lbs, add_gateway_lbs: List[str],
remove_gateway_lbs, remove_gateway_lbs: List[str],
): ) -> None:
""" """
The following parameters are not yet implemented: RemovePrivateDnsName The following parameters are not yet implemented: RemovePrivateDnsName
""" """

View File

@ -2,6 +2,7 @@ import ipaddress
import json import json
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional
from operator import itemgetter from operator import itemgetter
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
@ -35,27 +36,27 @@ from ..utils import (
) )
MAX_NUMBER_OF_ENDPOINT_SERVICES_RESULTS = 1000 MAX_NUMBER_OF_ENDPOINT_SERVICES_RESULTS = 1000
DEFAULT_VPC_ENDPOINT_SERVICES = [] DEFAULT_VPC_ENDPOINT_SERVICES: List[Dict[str, str]] = []
class VPCEndPoint(TaggedEC2Resource, CloudFormationModel): class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
def __init__( def __init__(
self, self,
ec2_backend, ec2_backend: Any,
endpoint_id, endpoint_id: str,
vpc_id, vpc_id: str,
service_name, service_name: str,
endpoint_type=None, endpoint_type: Optional[str],
policy_document=False, policy_document: Optional[str],
route_table_ids=None, route_table_ids: List[str],
subnet_ids=None, subnet_ids: Optional[List[str]] = None,
network_interface_ids=None, network_interface_ids: Optional[List[str]] = None,
dns_entries=None, dns_entries: Optional[List[Dict[str, str]]] = None,
client_token=None, client_token: Optional[str] = None,
security_group_ids=None, security_group_ids: Optional[List[str]] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
private_dns_enabled=None, private_dns_enabled: Optional[str] = None,
destination_prefix_list_id=None, destination_prefix_list_id: Optional[str] = None,
): ):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = endpoint_id self.id = endpoint_id
@ -76,11 +77,17 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
self.created_at = utc_date_and_time() self.created_at = utc_date_and_time()
def modify(self, policy_doc, add_subnets, add_route_tables, remove_route_tables): def modify(
self,
policy_doc: Optional[str],
add_subnets: Optional[List[str]],
add_route_tables: Optional[List[str]],
remove_route_tables: Optional[List[str]],
) -> None:
if policy_doc: if policy_doc:
self.policy_document = policy_doc self.policy_document = policy_doc
if add_subnets: if add_subnets:
self.subnet_ids.extend(add_subnets) self.subnet_ids.extend(add_subnets) # type: ignore[union-attr]
if add_route_tables: if add_route_tables:
self.route_table_ids.extend(add_route_tables) self.route_table_ids.extend(add_route_tables)
if remove_route_tables: if remove_route_tables:
@ -90,32 +97,39 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
if rt_id not in remove_route_tables if rt_id not in remove_route_tables
] ]
def get_filter_value(self, filter_name): def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name in ("vpc-endpoint-type", "vpc_endpoint_type"): if filter_name in ("vpc-endpoint-type", "vpc_endpoint_type"):
return self.endpoint_type return self.endpoint_type
else: else:
return super().get_filter_value(filter_name, "DescribeVpcs") return super().get_filter_value(filter_name, "DescribeVpcs")
@property @property
def owner_id(self): def owner_id(self) -> str:
return self.ec2_backend.account_id return self.ec2_backend.account_id
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
return "AWS::EC2::VPCEndpoint" return "AWS::EC2::VPCEndpoint"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "VPCEndPoint":
from ..models import ec2_backends from ..models import ec2_backends
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -146,19 +160,19 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
class VPC(TaggedEC2Resource, CloudFormationModel): class VPC(TaggedEC2Resource, CloudFormationModel):
def __init__( def __init__(
self, self,
ec2_backend, ec2_backend: Any,
vpc_id, vpc_id: str,
cidr_block, cidr_block: str,
is_default, is_default: bool,
instance_tenancy="default", instance_tenancy: str = "default",
amazon_provided_ipv6_cidr_block=False, amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group=None, ipv6_cidr_block_network_border_group: Optional[str] = None,
): ):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = vpc_id self.id = vpc_id
self.cidr_block = cidr_block self.cidr_block = cidr_block
self.cidr_block_association_set = {} self.cidr_block_association_set: Dict[str, Any] = {}
self.dhcp_options = None self.dhcp_options = None
self.state = "available" self.state = "available"
self.instance_tenancy = instance_tenancy self.instance_tenancy = instance_tenancy
@ -180,22 +194,27 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
) )
@property @property
def owner_id(self): def owner_id(self) -> str:
return self.ec2_backend.account_id return self.ec2_backend.account_id
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpc.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpc.html
return "AWS::EC2::VPC" return "AWS::EC2::VPC"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "VPC":
from ..models import ec2_backends from ..models import ec2_backends
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -213,10 +232,12 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
return vpc return vpc
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
def get_filter_value(self, filter_name): def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name in ("vpc-id", "vpcId"): if filter_name in ("vpc-id", "vpcId"):
return self.id return self.id
elif filter_name in ("cidr", "cidr-block", "cidrBlock"): elif filter_name in ("cidr", "cidr-block", "cidrBlock"):
@ -255,23 +276,22 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
else: else:
return super().get_filter_value(filter_name, "DescribeVpcs") return super().get_filter_value(filter_name, "DescribeVpcs")
def modify_vpc_tenancy(self, tenancy): def modify_vpc_tenancy(self, tenancy: str) -> None:
if tenancy != "default": if tenancy != "default":
raise UnsupportedTenancy(tenancy) raise UnsupportedTenancy(tenancy)
self.instance_tenancy = tenancy self.instance_tenancy = tenancy
return True
def associate_vpc_cidr_block( def associate_vpc_cidr_block(
self, self,
cidr_block, cidr_block: str,
amazon_provided_ipv6_cidr_block=False, amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group=None, ipv6_cidr_block_network_border_group: Optional[str] = None,
): ) -> Dict[str, Any]:
max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1 max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1
for cidr in self.cidr_block_association_set.copy(): for cidr in self.cidr_block_association_set.copy():
if ( if (
self.cidr_block_association_set.get(cidr) self.cidr_block_association_set.get(cidr) # type: ignore[union-attr]
.get("cidr_block_state") .get("cidr_block_state")
.get("state") .get("state")
== "disassociated" == "disassociated"
@ -285,7 +305,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
association_id = random_vpc_cidr_association_id() association_id = random_vpc_cidr_association_id()
association_set = { association_set: Dict[str, Any] = {
"association_id": association_id, "association_id": association_id,
"cidr_block_state": {"state": "associated", "StatusMessage": ""}, "cidr_block_state": {"state": "associated", "StatusMessage": ""},
} }
@ -301,7 +321,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
self.cidr_block_association_set[association_id] = association_set self.cidr_block_association_set[association_id] = association_set
return association_set return association_set
def enable_vpc_classic_link(self): def enable_vpc_classic_link(self) -> str:
# Check if current cidr block doesn't fall within the 10.0.0.0/8 block, excluding 10.0.0.0/16 and 10.1.0.0/16. # Check if current cidr block doesn't fall within the 10.0.0.0/8 block, excluding 10.0.0.0/16 and 10.1.0.0/16.
# Doesn't check any route tables, maybe something for in the future? # Doesn't check any route tables, maybe something for in the future?
# See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/vpc-classiclink.html#classiclink-limitations # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/vpc-classiclink.html#classiclink-limitations
@ -315,19 +335,19 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
return self.classic_link_enabled return self.classic_link_enabled
def disable_vpc_classic_link(self): def disable_vpc_classic_link(self) -> str:
self.classic_link_enabled = "false" self.classic_link_enabled = "false"
return self.classic_link_enabled return self.classic_link_enabled
def enable_vpc_classic_link_dns_support(self): def enable_vpc_classic_link_dns_support(self) -> str:
self.classic_link_dns_supported = "true" self.classic_link_dns_supported = "true"
return self.classic_link_dns_supported return self.classic_link_dns_supported
def disable_vpc_classic_link_dns_support(self): def disable_vpc_classic_link_dns_support(self) -> str:
self.classic_link_dns_supported = "false" self.classic_link_dns_supported = "false"
return self.classic_link_dns_supported return self.classic_link_dns_supported
def disassociate_vpc_cidr_block(self, association_id): def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]:
if self.cidr_block == self.cidr_block_association_set.get( if self.cidr_block == self.cidr_block_association_set.get(
association_id, {} association_id, {}
).get("cidr_block"): ).get("cidr_block"):
@ -341,7 +361,9 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
entry["cidr_block_state"]["state"] = "disassociated" entry["cidr_block_state"]["state"] = "disassociated"
return response return response
def get_cidr_block_association_set(self, ipv6=False): def get_cidr_block_association_set(
self, ipv6: bool = False
) -> List[Dict[str, Any]]:
return [ return [
c c
for c in self.cidr_block_association_set.values() for c in self.cidr_block_association_set.values()
@ -350,14 +372,14 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
class VPCBackend: class VPCBackend:
vpc_refs = defaultdict(set) vpc_refs = defaultdict(set) # type: ignore
def __init__(self): def __init__(self) -> None:
self.vpcs = {} self.vpcs: Dict[str, VPC] = {}
self.vpc_end_points = {} self.vpc_end_points: Dict[str, VPCEndPoint] = {}
self.vpc_refs[self.__class__].add(weakref.ref(self)) self.vpc_refs[self.__class__].add(weakref.ref(self))
def create_default_vpc(self): def create_default_vpc(self) -> VPC:
default_vpc = self.describe_vpcs(filters={"is-default": "true"}) default_vpc = self.describe_vpcs(filters={"is-default": "true"})
if default_vpc: if default_vpc:
raise DefaultVpcAlreadyExists raise DefaultVpcAlreadyExists
@ -366,13 +388,13 @@ class VPCBackend:
def create_vpc( def create_vpc(
self, self,
cidr_block, cidr_block: str,
instance_tenancy="default", instance_tenancy: str = "default",
amazon_provided_ipv6_cidr_block=False, amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group=None, ipv6_cidr_block_network_border_group: Optional[str] = None,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
is_default=False, is_default: bool = False,
): ) -> VPC:
vpc_id = random_vpc_id() vpc_id = random_vpc_id()
try: try:
vpc_cidr_block = ipaddress.IPv4Network(str(cidr_block), strict=False) vpc_cidr_block = ipaddress.IPv4Network(str(cidr_block), strict=False)
@ -391,45 +413,45 @@ class VPCBackend:
) )
for tag in tags or []: for tag in tags or []:
tag_key = tag.get("Key") vpc.add_tag(tag["Key"], tag["Value"])
tag_value = tag.get("Value")
vpc.add_tag(tag_key, tag_value)
self.vpcs[vpc_id] = vpc self.vpcs[vpc_id] = vpc
# AWS creates a default main route table and security group. # AWS creates a default main route table and security group.
self.create_route_table(vpc_id, main=True) self.create_route_table(vpc_id, main=True) # type: ignore[attr-defined]
# AWS creates a default Network ACL # AWS creates a default Network ACL
self.create_network_acl(vpc_id, default=True) self.create_network_acl(vpc_id, default=True) # type: ignore[attr-defined]
default = self.get_security_group_from_name("default", vpc_id=vpc_id) default = self.get_security_group_from_name("default", vpc_id=vpc_id) # type: ignore[attr-defined]
if not default: if not default:
self.create_security_group( self.create_security_group( # type: ignore[attr-defined]
"default", "default VPC security group", vpc_id=vpc_id, is_default=True "default", "default VPC security group", vpc_id=vpc_id, is_default=True
) )
return vpc return vpc
def get_vpc(self, vpc_id): def get_vpc(self, vpc_id: str) -> VPC:
if vpc_id not in self.vpcs: if vpc_id not in self.vpcs:
raise InvalidVPCIdError(vpc_id) raise InvalidVPCIdError(vpc_id)
return self.vpcs.get(vpc_id) return self.vpcs[vpc_id]
def describe_vpcs(self, vpc_ids=None, filters=None): def describe_vpcs(
matches = self.vpcs.copy().values() self, vpc_ids: Optional[List[str]] = None, filters: Any = None
) -> List[VPC]:
matches = list(self.vpcs.values())
if vpc_ids: if vpc_ids:
matches = [vpc for vpc in matches if vpc.id in vpc_ids] matches = [vpc for vpc in matches if vpc.id in vpc_ids]
if len(vpc_ids) > len(matches): if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches) unknown_ids = set(vpc_ids) - set(matches) # type: ignore[arg-type]
raise InvalidVPCIdError(unknown_ids) raise InvalidVPCIdError(unknown_ids)
if filters: if filters:
matches = generic_filter(filters, matches) matches = generic_filter(filters, matches)
return matches return matches
def delete_vpc(self, vpc_id): def delete_vpc(self, vpc_id: str) -> VPC:
# Do not delete if any VPN Gateway is attached # Do not delete if any VPN Gateway is attached
vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id}) vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id}) # type: ignore[attr-defined]
vpn_gateways = [ vpn_gateways = [
item item
for item in vpn_gateways for item in vpn_gateways
@ -441,18 +463,18 @@ class VPCBackend:
) )
# Delete route table if only main route table remains. # Delete route table if only main route table remains.
route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id}) route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id}) # type: ignore[attr-defined]
if len(route_tables) > 1: if len(route_tables) > 1:
raise DependencyViolationError( raise DependencyViolationError(
f"The vpc {vpc_id} has dependencies and cannot be deleted." f"The vpc {vpc_id} has dependencies and cannot be deleted."
) )
for route_table in route_tables: for route_table in route_tables:
self.delete_route_table(route_table.id) self.delete_route_table(route_table.id) # type: ignore[attr-defined]
# Delete default security group if exists. # Delete default security group if exists.
default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) # type: ignore[attr-defined]
if default: if default:
self.delete_security_group(group_id=default.id) self.delete_security_group(group_id=default.id) # type: ignore[attr-defined]
# Now delete VPC. # Now delete VPC.
vpc = self.vpcs.pop(vpc_id, None) vpc = self.vpcs.pop(vpc_id, None)
@ -465,7 +487,7 @@ class VPCBackend:
vpc.dhcp_options = None vpc.dhcp_options = None
return vpc return vpc
def describe_vpc_attribute(self, vpc_id, attr_name): def describe_vpc_attribute(self, vpc_id: str, attr_name: str) -> Any:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
if attr_name in ( if attr_name in (
"enable_dns_support", "enable_dns_support",
@ -476,27 +498,29 @@ class VPCBackend:
else: else:
raise InvalidParameterValueError(attr_name) raise InvalidParameterValueError(attr_name)
def modify_vpc_tenancy(self, vpc_id, tenancy): def modify_vpc_tenancy(self, vpc_id: str, tenancy: str) -> None:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
return vpc.modify_vpc_tenancy(tenancy) vpc.modify_vpc_tenancy(tenancy)
def enable_vpc_classic_link(self, vpc_id): def enable_vpc_classic_link(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
return vpc.enable_vpc_classic_link() return vpc.enable_vpc_classic_link()
def disable_vpc_classic_link(self, vpc_id): def disable_vpc_classic_link(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
return vpc.disable_vpc_classic_link() return vpc.disable_vpc_classic_link()
def enable_vpc_classic_link_dns_support(self, vpc_id): def enable_vpc_classic_link_dns_support(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
return vpc.enable_vpc_classic_link_dns_support() return vpc.enable_vpc_classic_link_dns_support()
def disable_vpc_classic_link_dns_support(self, vpc_id): def disable_vpc_classic_link_dns_support(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
return vpc.disable_vpc_classic_link_dns_support() return vpc.disable_vpc_classic_link_dns_support()
def modify_vpc_attribute(self, vpc_id, attr_name, attr_value): def modify_vpc_attribute(
self, vpc_id: str, attr_name: str, attr_value: str
) -> None:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
if attr_name in ( if attr_name in (
"enable_dns_support", "enable_dns_support",
@ -507,58 +531,58 @@ class VPCBackend:
else: else:
raise InvalidParameterValueError(attr_name) raise InvalidParameterValueError(attr_name)
def disassociate_vpc_cidr_block(self, association_id): def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]:
for vpc in self.vpcs.copy().values(): for vpc in self.vpcs.copy().values():
response = vpc.disassociate_vpc_cidr_block(association_id) response = vpc.disassociate_vpc_cidr_block(association_id)
for route_table in self.route_tables.copy().values(): for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined]
if route_table.vpc_id == response.get("vpc_id"): if route_table.vpc_id == response.get("vpc_id"):
if "::/" in response.get("cidr_block"): if "::/" in response.get("cidr_block"): # type: ignore[operator]
self.delete_route( self.delete_route( # type: ignore[attr-defined]
route_table.id, None, response.get("cidr_block") route_table.id, None, response.get("cidr_block")
) )
else: else:
self.delete_route(route_table.id, response.get("cidr_block")) self.delete_route(route_table.id, response.get("cidr_block")) # type: ignore[attr-defined]
if response: if response:
return response return response
raise InvalidVpcCidrBlockAssociationIdError(association_id) raise InvalidVpcCidrBlockAssociationIdError(association_id)
def associate_vpc_cidr_block( def associate_vpc_cidr_block(
self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block self, vpc_id: str, cidr_block: str, amazon_provided_ipv6_cidr_block: bool
): ) -> Dict[str, Any]:
vpc = self.get_vpc(vpc_id) vpc = self.get_vpc(vpc_id)
association_set = vpc.associate_vpc_cidr_block( association_set = vpc.associate_vpc_cidr_block(
cidr_block, amazon_provided_ipv6_cidr_block cidr_block, amazon_provided_ipv6_cidr_block
) )
for route_table in self.route_tables.copy().values(): for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined]
if route_table.vpc_id == vpc_id: if route_table.vpc_id == vpc_id:
if amazon_provided_ipv6_cidr_block: if amazon_provided_ipv6_cidr_block:
self.create_route( self.create_route( # type: ignore[attr-defined]
route_table.id, route_table.id,
None, None,
destination_ipv6_cidr_block=association_set["cidr_block"], destination_ipv6_cidr_block=association_set["cidr_block"],
local=True, local=True,
) )
else: else:
self.create_route( self.create_route( # type: ignore[attr-defined]
route_table.id, association_set["cidr_block"], local=True route_table.id, association_set["cidr_block"], local=True
) )
return association_set return association_set
def create_vpc_endpoint( def create_vpc_endpoint(
self, self,
vpc_id, vpc_id: str,
service_name, service_name: str,
endpoint_type=None, endpoint_type: Optional[str],
policy_document=False, policy_document: Optional[str],
route_table_ids=None, route_table_ids: List[str],
subnet_ids=None, subnet_ids: Optional[List[str]] = None,
network_interface_ids=None, network_interface_ids: Optional[List[str]] = None,
dns_entries=None, dns_entries: Optional[Dict[str, str]] = None,
client_token=None, client_token: Optional[str] = None,
security_group_ids=None, security_group_ids: Optional[List[str]] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
private_dns_enabled=None, private_dns_enabled: Optional[str] = None,
): ) -> VPCEndPoint:
vpc_endpoint_id = random_vpc_ep_id() vpc_endpoint_id = random_vpc_ep_id()
@ -570,21 +594,18 @@ class VPCBackend:
network_interface_ids = [] network_interface_ids = []
for subnet_id in subnet_ids or []: for subnet_id in subnet_ids or []:
self.get_subnet(subnet_id) self.get_subnet(subnet_id) # type: ignore[attr-defined]
eni = self.create_network_interface(subnet_id, random_private_ip()) eni = self.create_network_interface(subnet_id, random_private_ip()) # type: ignore[attr-defined]
network_interface_ids.append(eni.id) network_interface_ids.append(eni.id)
dns_entries = create_dns_entries(service_name, vpc_endpoint_id) dns_entries = create_dns_entries(service_name, vpc_endpoint_id)
else: else:
# considering gateway if type is not mentioned. # considering gateway if type is not mentioned.
for prefix_list in self.managed_prefix_lists.values(): for prefix_list in self.managed_prefix_lists.values(): # type: ignore[attr-defined]
if prefix_list.prefix_list_name == service_name: if prefix_list.prefix_list_name == service_name:
destination_prefix_list_id = prefix_list.id destination_prefix_list_id = prefix_list.id
if dns_entries:
dns_entries = [dns_entries]
vpc_end_point = VPCEndPoint( vpc_end_point = VPCEndPoint(
self, self,
vpc_endpoint_id, vpc_endpoint_id,
@ -595,19 +616,19 @@ class VPCBackend:
route_table_ids, route_table_ids,
subnet_ids, subnet_ids,
network_interface_ids, network_interface_ids,
dns_entries, dns_entries=[dns_entries] if dns_entries else None,
client_token, client_token=client_token,
security_group_ids, security_group_ids=security_group_ids,
tags, tags=tags,
private_dns_enabled, private_dns_enabled=private_dns_enabled,
destination_prefix_list_id, destination_prefix_list_id=destination_prefix_list_id,
) )
self.vpc_end_points[vpc_endpoint_id] = vpc_end_point self.vpc_end_points[vpc_endpoint_id] = vpc_end_point
if destination_prefix_list_id: if destination_prefix_list_id:
for route_table_id in route_table_ids: for route_table_id in route_table_ids:
self.create_route( self.create_route( # type: ignore[attr-defined]
route_table_id, route_table_id,
None, None,
gateway_id=vpc_endpoint_id, gateway_id=vpc_endpoint_id,
@ -617,28 +638,34 @@ class VPCBackend:
return vpc_end_point return vpc_end_point
def modify_vpc_endpoint( def modify_vpc_endpoint(
self, vpc_id, policy_doc, add_subnets, remove_route_tables, add_route_tables self,
): vpc_id: str,
policy_doc: str,
add_subnets: Optional[List[str]],
remove_route_tables: Optional[List[str]],
add_route_tables: Optional[List[str]],
) -> None:
endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0] endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0]
endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables) endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables)
def delete_vpc_endpoints(self, vpce_ids=None): def delete_vpc_endpoints(self, vpce_ids: Optional[List[str]] = None) -> None:
for vpce_id in vpce_ids or []: for vpce_id in vpce_ids or []:
vpc_endpoint = self.vpc_end_points.get(vpce_id, None) vpc_endpoint = self.vpc_end_points.get(vpce_id, None)
if vpc_endpoint: if vpc_endpoint:
if vpc_endpoint.endpoint_type.lower() == "interface": if vpc_endpoint.endpoint_type.lower() == "interface": # type: ignore[union-attr]
for eni_id in vpc_endpoint.network_interface_ids: for eni_id in vpc_endpoint.network_interface_ids:
self.enis.pop(eni_id, None) self.enis.pop(eni_id, None) # type: ignore[attr-defined]
else: else:
for route_table_id in vpc_endpoint.route_table_ids: for route_table_id in vpc_endpoint.route_table_ids:
self.delete_route( self.delete_route( # type: ignore[attr-defined]
route_table_id, vpc_endpoint.destination_prefix_list_id route_table_id, vpc_endpoint.destination_prefix_list_id
) )
vpc_endpoint.state = "deleted" vpc_endpoint.state = "deleted"
return True
def describe_vpc_endpoints(self, vpc_end_point_ids, filters=None): def describe_vpc_endpoints(
vpc_end_points = self.vpc_end_points.values() self, vpc_end_point_ids: Optional[List[str]], filters: Any = None
) -> List[VPCEndPoint]:
vpc_end_points = list(self.vpc_end_points.values())
if vpc_end_point_ids: if vpc_end_point_ids:
vpc_end_points = [ vpc_end_points = [
@ -657,7 +684,9 @@ class VPCBackend:
return generic_filter(filters, vpc_end_points) return generic_filter(filters, vpc_end_points)
@staticmethod @staticmethod
def _collect_default_endpoint_services(account_id, region): def _collect_default_endpoint_services(
account_id: str, region: str
) -> List[Dict[str, str]]:
"""Return list of default services using list of backends.""" """Return list of default services using list of backends."""
if DEFAULT_VPC_ENDPOINT_SERVICES: if DEFAULT_VPC_ENDPOINT_SERVICES:
return DEFAULT_VPC_ENDPOINT_SERVICES return DEFAULT_VPC_ENDPOINT_SERVICES
@ -672,14 +701,16 @@ class VPCBackend:
from moto import backends # pylint: disable=import-outside-toplevel from moto import backends # pylint: disable=import-outside-toplevel
for _backends in backends.service_backends(): for _backends in backends.service_backends():
_backends = _backends[account_id] account_backend = _backends[account_id]
if region in _backends: if region in account_backend:
service = _backends[region].default_vpc_endpoint_service(region, zones) service = account_backend[region].default_vpc_endpoint_service(
region, zones
)
if service: if service:
DEFAULT_VPC_ENDPOINT_SERVICES.extend(service) DEFAULT_VPC_ENDPOINT_SERVICES.extend(service)
if "global" in _backends: if "global" in account_backend:
service = _backends["global"].default_vpc_endpoint_service( service = account_backend["global"].default_vpc_endpoint_service(
region, zones region, zones
) )
if service: if service:
@ -687,7 +718,7 @@ class VPCBackend:
return DEFAULT_VPC_ENDPOINT_SERVICES return DEFAULT_VPC_ENDPOINT_SERVICES
@staticmethod @staticmethod
def _matches_service_by_tags(service, filter_item): def _matches_service_by_tags(service: Dict[str, Any], filter_item: Dict[str, Any]) -> bool: # type: ignore[misc]
"""Return True if service tags are not filtered by their tags. """Return True if service tags are not filtered by their tags.
Note that the API specifies a key of "Values" for a filter, but Note that the API specifies a key of "Values" for a filter, but
@ -719,7 +750,7 @@ class VPCBackend:
return matched return matched
@staticmethod @staticmethod
def _filter_endpoint_services(service_names_filters, filters, services): def _filter_endpoint_services(service_names_filters: List[str], filters: List[Dict[str, Any]], services: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Return filtered list of VPC endpoint services.""" """Return filtered list of VPC endpoint services."""
if not service_names_filters and not filters: if not service_names_filters and not filters:
return services return services
@ -774,11 +805,16 @@ class VPCBackend:
return filtered_services return filtered_services
def describe_vpc_endpoint_services( def describe_vpc_endpoint_services(
self, dry_run, service_names, filters, max_results, next_token, region self,
): # pylint: disable=unused-argument,too-many-arguments service_names: List[str],
filters: Any,
max_results: int,
next_token: Optional[str],
region: str,
) -> Dict[str, Any]: # pylint: disable=too-many-arguments
"""Return info on services to which you can create a VPC endpoint. """Return info on services to which you can create a VPC endpoint.
Currently only the default endpoing services are returned. When Currently only the default endpoint services are returned. When
create_vpc_endpoint_service_configuration() is implemented, a create_vpc_endpoint_service_configuration() is implemented, a
list of those private endpoints would be kept and when this API list of those private endpoints would be kept and when this API
is invoked, those private endpoints would be added to the list of is invoked, those private endpoints would be added to the list of
@ -787,7 +823,7 @@ class VPCBackend:
The DryRun parameter is ignored. The DryRun parameter is ignored.
""" """
default_services = self._collect_default_endpoint_services( default_services = self._collect_default_endpoint_services(
self.account_id, region self.account_id, region # type: ignore[attr-defined]
) )
for service_name in service_names: for service_name in service_names:
if service_name not in [x["ServiceName"] for x in default_services]: if service_name not in [x["ServiceName"] for x in default_services]:
@ -827,7 +863,7 @@ class VPCBackend:
"nextToken": next_token, "nextToken": next_token,
} }
def get_vpc_end_point(self, vpc_end_point_id): def get_vpc_end_point(self, vpc_end_point_id: str) -> VPCEndPoint:
vpc_end_point = self.vpc_end_points.get(vpc_end_point_id) vpc_end_point = self.vpc_end_points.get(vpc_end_point_id)
if not vpc_end_point: if not vpc_end_point:
raise InvalidVpcEndPointIdError(vpc_end_point_id) raise InvalidVpcEndPointIdError(vpc_end_point_id)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from .core import TaggedEC2Resource from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnConnectionIdError from ..exceptions import InvalidVpnConnectionIdError
from ..utils import generic_filter, random_vpn_connection_id from ..utils import generic_filter, random_vpn_connection_id
@ -6,18 +7,18 @@ from ..utils import generic_filter, random_vpn_connection_id
class VPNConnection(TaggedEC2Resource): class VPNConnection(TaggedEC2Resource):
def __init__( def __init__(
self, self,
ec2_backend, ec2_backend: Any,
vpn_connection_id, vpn_connection_id: str,
vpn_conn_type, vpn_conn_type: str,
customer_gateway_id, customer_gateway_id: str,
vpn_gateway_id=None, vpn_gateway_id: Optional[str] = None,
transit_gateway_id=None, transit_gateway_id: Optional[str] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
): ):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = vpn_connection_id self.id = vpn_connection_id
self.state = "available" self.state = "available"
self.customer_gateway_configuration = {} self.customer_gateway_configuration: Dict[str, str] = {}
self.type = vpn_conn_type self.type = vpn_conn_type
self.customer_gateway_id = customer_gateway_id self.customer_gateway_id = customer_gateway_id
self.vpn_gateway_id = vpn_gateway_id self.vpn_gateway_id = vpn_gateway_id
@ -27,23 +28,25 @@ class VPNConnection(TaggedEC2Resource):
self.static_routes = None self.static_routes = None
self.add_tags(tags or {}) self.add_tags(tags or {})
def get_filter_value(self, filter_name): def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
return super().get_filter_value(filter_name, "DescribeVpnConnections") return super().get_filter_value(filter_name, "DescribeVpnConnections")
class VPNConnectionBackend: class VPNConnectionBackend:
def __init__(self): def __init__(self) -> None:
self.vpn_connections = {} self.vpn_connections: Dict[str, VPNConnection] = {}
def create_vpn_connection( def create_vpn_connection(
self, self,
vpn_conn_type, vpn_conn_type: str,
customer_gateway_id, customer_gateway_id: str,
vpn_gateway_id=None, vpn_gateway_id: Optional[str] = None,
transit_gateway_id=None, transit_gateway_id: Optional[str] = None,
static_routes_only=None, static_routes_only: Optional[bool] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
): ) -> VPNConnection:
vpn_connection_id = random_vpn_connection_id() vpn_connection_id = random_vpn_connection_id()
if static_routes_only: if static_routes_only:
pass pass
@ -59,7 +62,7 @@ class VPNConnectionBackend:
self.vpn_connections[vpn_connection.id] = vpn_connection self.vpn_connections[vpn_connection.id] = vpn_connection
return vpn_connection return vpn_connection
def delete_vpn_connection(self, vpn_connection_id): def delete_vpn_connection(self, vpn_connection_id: str) -> VPNConnection:
if vpn_connection_id in self.vpn_connections: if vpn_connection_id in self.vpn_connections:
self.vpn_connections[vpn_connection_id].state = "deleted" self.vpn_connections[vpn_connection_id].state = "deleted"
@ -67,17 +70,10 @@ class VPNConnectionBackend:
raise InvalidVpnConnectionIdError(vpn_connection_id) raise InvalidVpnConnectionIdError(vpn_connection_id)
return self.vpn_connections[vpn_connection_id] return self.vpn_connections[vpn_connection_id]
def describe_vpn_connections(self, vpn_connection_ids=None): def describe_vpn_connections(
vpn_connections = [] self, vpn_connection_ids: Optional[List[str]] = None, filters: Any = None
for vpn_connection_id in vpn_connection_ids or []: ) -> List[VPNConnection]:
if vpn_connection_id in self.vpn_connections: vpn_connections = list(self.vpn_connections.values())
vpn_connections.append(self.vpn_connections[vpn_connection_id])
else:
raise InvalidVpnConnectionIdError(vpn_connection_id)
return vpn_connections or self.vpn_connections.values()
def get_all_vpn_connections(self, vpn_connection_ids=None, filters=None):
vpn_connections = self.vpn_connections.values()
if vpn_connection_ids: if vpn_connection_ids:
vpn_connections = [ vpn_connections = [

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Any, Dict, List, Optional
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from .core import TaggedEC2Resource from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnGatewayIdError, InvalidVpnGatewayAttachmentError from ..exceptions import InvalidVpnGatewayIdError, InvalidVpnGatewayAttachmentError
@ -15,18 +15,23 @@ class VPCGatewayAttachment(CloudFormationModel):
self.state = state self.state = state
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html
return "AWS::EC2::VPCGatewayAttachment" return "AWS::EC2::VPCGatewayAttachment"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VPCGatewayAttachment":
from ..models import ec2_backends from ..models import ec2_backends
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -45,20 +50,20 @@ class VPCGatewayAttachment(CloudFormationModel):
return attachment return attachment
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.vpc_id return self.vpc_id
class VpnGateway(CloudFormationModel, TaggedEC2Resource): class VpnGateway(CloudFormationModel, TaggedEC2Resource):
def __init__( def __init__(
self, self,
ec2_backend, ec2_backend: Any,
gateway_id, gateway_id: str,
gateway_type, gateway_type: str,
amazon_side_asn, amazon_side_asn: Optional[str],
availability_zone, availability_zone: Optional[str],
tags=None, tags: Optional[Dict[str, str]] = None,
state="available", state: str = "available",
): ):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = gateway_id self.id = gateway_id
@ -67,22 +72,27 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
self.availability_zone = availability_zone self.availability_zone = availability_zone
self.state = state self.state = state
self.add_tags(tags or {}) self.add_tags(tags or {})
self.attachments = {} self.attachments: Dict[str, VPCGatewayAttachment] = {}
super().__init__() super().__init__()
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html
return "AWS::EC2::VPNGateway" return "AWS::EC2::VPNGateway"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VpnGateway":
from ..models import ec2_backends from ..models import ec2_backends
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -93,10 +103,12 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
return ec2_backend.create_vpn_gateway(gateway_type=_type, amazon_side_asn=asn) return ec2_backend.create_vpn_gateway(gateway_type=_type, amazon_side_asn=asn)
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
def get_filter_value(self, filter_name): def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name == "attachment.vpc-id": if filter_name == "attachment.vpc-id":
return self.attachments.keys() return self.attachments.keys()
elif filter_name == "attachment.state": elif filter_name == "attachment.state":
@ -109,16 +121,16 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
class VpnGatewayBackend: class VpnGatewayBackend:
def __init__(self): def __init__(self) -> None:
self.vpn_gateways = {} self.vpn_gateways: Dict[str, VpnGateway] = {}
def create_vpn_gateway( def create_vpn_gateway(
self, self,
gateway_type="ipsec.1", gateway_type: str = "ipsec.1",
amazon_side_asn=None, amazon_side_asn: Optional[str] = None,
availability_zone=None, availability_zone: Optional[str] = None,
tags=None, tags: Optional[Dict[str, str]] = None,
): ) -> VpnGateway:
vpn_gateway_id = random_vpn_gateway_id() vpn_gateway_id = random_vpn_gateway_id()
vpn_gateway = VpnGateway( vpn_gateway = VpnGateway(
self, vpn_gateway_id, gateway_type, amazon_side_asn, availability_zone, tags self, vpn_gateway_id, gateway_type, amazon_side_asn, availability_zone, tags
@ -126,21 +138,25 @@ class VpnGatewayBackend:
self.vpn_gateways[vpn_gateway_id] = vpn_gateway self.vpn_gateways[vpn_gateway_id] = vpn_gateway
return vpn_gateway return vpn_gateway
def describe_vpn_gateways(self, filters=None, vpn_gw_ids=None): def describe_vpn_gateways(
self, filters: Any = None, vpn_gw_ids: Optional[List[str]] = None
) -> List[VpnGateway]:
vpn_gateways = list(self.vpn_gateways.values() or []) vpn_gateways = list(self.vpn_gateways.values() or [])
if vpn_gw_ids: if vpn_gw_ids:
vpn_gateways = [item for item in vpn_gateways if item.id in vpn_gw_ids] vpn_gateways = [item for item in vpn_gateways if item.id in vpn_gw_ids]
return generic_filter(filters, vpn_gateways) return generic_filter(filters, vpn_gateways)
def get_vpn_gateway(self, vpn_gateway_id): def get_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway:
vpn_gateway = self.vpn_gateways.get(vpn_gateway_id, None) vpn_gateway = self.vpn_gateways.get(vpn_gateway_id, None)
if not vpn_gateway: if not vpn_gateway:
raise InvalidVpnGatewayIdError(vpn_gateway_id) raise InvalidVpnGatewayIdError(vpn_gateway_id)
return vpn_gateway return vpn_gateway
def attach_vpn_gateway(self, vpn_gateway_id, vpc_id): def attach_vpn_gateway(
self, vpn_gateway_id: str, vpc_id: str
) -> VPCGatewayAttachment:
vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) vpn_gateway = self.get_vpn_gateway(vpn_gateway_id)
self.get_vpc(vpc_id) self.get_vpc(vpc_id) # type: ignore[attr-defined]
attachment = VPCGatewayAttachment(vpc_id, state="attached") attachment = VPCGatewayAttachment(vpc_id, state="attached")
for key in vpn_gateway.attachments.copy(): for key in vpn_gateway.attachments.copy():
if key.startswith("vpc-"): if key.startswith("vpc-"):
@ -148,14 +164,16 @@ class VpnGatewayBackend:
vpn_gateway.attachments[vpc_id] = attachment vpn_gateway.attachments[vpc_id] = attachment
return attachment return attachment
def delete_vpn_gateway(self, vpn_gateway_id): def delete_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway:
deleted = self.vpn_gateways.get(vpn_gateway_id, None) deleted = self.vpn_gateways.get(vpn_gateway_id, None)
if not deleted: if not deleted:
raise InvalidVpnGatewayIdError(vpn_gateway_id) raise InvalidVpnGatewayIdError(vpn_gateway_id)
deleted.state = "deleted" deleted.state = "deleted"
return deleted return deleted
def detach_vpn_gateway(self, vpn_gateway_id, vpc_id): def detach_vpn_gateway(
self, vpn_gateway_id: str, vpc_id: str
) -> VPCGatewayAttachment:
vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) vpn_gateway = self.get_vpn_gateway(vpn_gateway_id)
detached = vpn_gateway.attachments.get(vpc_id, None) detached = vpn_gateway.attachments.get(vpc_id, None)
if not detached: if not detached:

View File

@ -5,7 +5,7 @@ from moto.core import BaseModel
class WindowsBackend(BaseModel): class WindowsBackend(BaseModel):
def get_password_data(self, instance_id: str) -> str: def get_password_data(self, instance_id: str) -> str:
instance = self.get_instance(instance_id) instance = self.get_instance(instance_id) # type: ignore[attr-defined]
if instance.platform == "windows": if instance.platform == "windows":
return random.get_random_string(length=128) return random.get_random_string(length=128)
return "" return ""

View File

@ -65,9 +65,8 @@ class VPCs(EC2BaseResponse):
def modify_vpc_tenancy(self): def modify_vpc_tenancy(self):
vpc_id = self._get_param("VpcId") vpc_id = self._get_param("VpcId")
tenancy = self._get_param("InstanceTenancy") tenancy = self._get_param("InstanceTenancy")
value = self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy) self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy)
template = self.response_template(MODIFY_VPC_TENANCY_RESPONSE) return self.response_template(MODIFY_VPC_TENANCY_RESPONSE).render()
return template.render(value=value)
def describe_vpc_attribute(self): def describe_vpc_attribute(self):
vpc_id = self._get_param("VpcId") vpc_id = self._get_param("VpcId")
@ -237,7 +236,6 @@ class VPCs(EC2BaseResponse):
def describe_vpc_endpoint_services(self): def describe_vpc_endpoint_services(self):
vpc_end_point_services = self.ec2_backend.describe_vpc_endpoint_services( vpc_end_point_services = self.ec2_backend.describe_vpc_endpoint_services(
dry_run=self._get_bool_param("DryRun"),
service_names=self._get_multi_param("ServiceName"), service_names=self._get_multi_param("ServiceName"),
filters=self._get_multi_param("Filter"), filters=self._get_multi_param("Filter"),
max_results=self._get_int_param("MaxResults"), max_results=self._get_int_param("MaxResults"),
@ -260,9 +258,8 @@ class VPCs(EC2BaseResponse):
def delete_vpc_endpoints(self): def delete_vpc_endpoints(self):
vpc_end_points_ids = self._get_multi_param("VpcEndpointId") vpc_end_points_ids = self._get_multi_param("VpcEndpointId")
response = self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids) self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids)
template = self.response_template(DELETE_VPC_ENDPOINT_RESPONSE) return self.response_template(DELETE_VPC_ENDPOINT_RESPONSE).render()
return template.render(response=response)
def create_managed_prefix_list(self): def create_managed_prefix_list(self):
address_family = self._get_param("AddressFamily") address_family = self._get_param("AddressFamily")
@ -767,7 +764,7 @@ DESCRIBE_VPC_ENDPOINT_RESPONSE = """<DescribeVpcEndpointsResponse xmlns="http://
DELETE_VPC_ENDPOINT_RESPONSE = """<DeleteVpcEndpointsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/"> DELETE_VPC_ENDPOINT_RESPONSE = """<DeleteVpcEndpointsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>19a9ff46-7df6-49b8-9726-3df27527089d</requestId> <requestId>19a9ff46-7df6-49b8-9726-3df27527089d</requestId>
<unsuccessful>{{ 'Error' if not response else '' }}</unsuccessful> <unsuccessful></unsuccessful>
</DeleteVpcEndpointsResponse>""" </DeleteVpcEndpointsResponse>"""

View File

@ -42,7 +42,7 @@ class VPNConnections(EC2BaseResponse):
def describe_vpn_connections(self): def describe_vpn_connections(self):
vpn_connection_ids = self._get_multi_param("VpnConnectionId") vpn_connection_ids = self._get_multi_param("VpnConnectionId")
filters = self._filters_from_querystring() filters = self._filters_from_querystring()
vpn_connections = self.ec2_backend.get_all_vpn_connections( vpn_connections = self.ec2_backend.describe_vpn_connections(
vpn_connection_ids=vpn_connection_ids, filters=filters vpn_connection_ids=vpn_connection_ids, filters=filters
) )
template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE) template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE)

View File

@ -134,11 +134,11 @@ def random_network_acl_subnet_association_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"])
def random_vpn_gateway_id(): def random_vpn_gateway_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"])
def random_vpn_connection_id(): def random_vpn_connection_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"])
@ -154,19 +154,19 @@ def random_key_pair_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["key-pair"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["key-pair"])
def random_vpc_id(): def random_vpc_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"])
def random_vpc_ep_id(): def random_vpc_ep_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-endpoint"], size=8) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-endpoint"], size=8)
def random_vpc_cidr_association_id(): def random_vpc_cidr_association_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"])
def random_vpc_peering_connection_id(): def random_vpc_peering_connection_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"])
@ -272,11 +272,11 @@ def random_mac_address() -> str:
return f"02:00:00:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x" return f"02:00:00:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x"
def randor_ipv4_cidr(): def randor_ipv4_cidr() -> str:
return f"10.0.{random.randint(0, 255)}.{random.randint(0, 255)}/16" return f"10.0.{random.randint(0, 255)}.{random.randint(0, 255)}/16"
def random_ipv6_cidr(): def random_ipv6_cidr() -> str:
return f"2400:6500:{random_resource_id(4)}:{random_resource_id(2)}00::/56" return f"2400:6500:{random_resource_id(4)}:{random_resource_id(2)}00::/56"
@ -297,13 +297,11 @@ def random_managed_prefix_list_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["managed-prefix-list"], size=8) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["managed-prefix-list"], size=8)
def create_dns_entries(service_name, vpc_endpoint_id): def create_dns_entries(service_name: str, vpc_endpoint_id: str) -> Dict[str, str]:
dns_entries = {} return {
dns_entries[ "dns_name": f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}",
"dns_name" "hosted_zone_id": random_resource_id(13).upper(),
] = f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}" }
dns_entries["hosted_zone_id"] = random_resource_id(13).upper()
return dns_entries
def utc_date_and_time() -> str: def utc_date_and_time() -> str:
@ -589,7 +587,7 @@ def get_prefix(resource_id: str) -> str:
return resource_id_prefix return resource_id_prefix
def is_valid_resource_id(resource_id): def is_valid_resource_id(resource_id: str) -> bool:
valid_prefixes = EC2_RESOURCE_TO_PREFIX.values() valid_prefixes = EC2_RESOURCE_TO_PREFIX.values()
resource_id_prefix = get_prefix(resource_id) resource_id_prefix = get_prefix(resource_id)
if resource_id_prefix not in valid_prefixes: if resource_id_prefix not in valid_prefixes:

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/a*,moto/ec2/models/c*,moto/ec2/models/d*,moto/ec2/models/e*,moto/ec2/models/f*,moto/ec2/models/h*,moto/ec2/models/i*,moto/ec2/models/k*,moto/ec2/models/l*,moto/ec2/models/m*,moto/ec2/models/n*,moto/ec2/models/r*,moto/ec2/models/s*,moto/ec2/models/t*,moto/moto_api files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/**/*.py,moto/moto_api
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract