Techdebt: MyPy EC2 models (#5925)

This commit is contained in:
Bert Blommers 2023-02-13 19:28:15 -01:00 committed by GitHub
parent 859114114c
commit c50737d027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 422 additions and 341 deletions

View File

@ -30,7 +30,7 @@ class EC2ClientError(RESTError):
class DefaultVpcAlreadyExists(EC2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"DefaultVpcAlreadyExists",
"A Default VPC already exists for this account in this region.",
@ -59,7 +59,7 @@ class InvalidDHCPOptionsIdError(EC2ClientError):
class InvalidRequest(EC2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__("InvalidRequest", "The request received was invalid")
@ -98,7 +98,7 @@ class InvalidKeyPairFormatError(EC2ClientError):
class InvalidVPCIdError(EC2ClientError):
def __init__(self, vpc_id: str):
def __init__(self, vpc_id: Any):
super().__init__("InvalidVpcID.NotFound", f"VpcID {vpc_id} does not exist.")
@ -134,7 +134,7 @@ class InvalidNetworkAclIdError(EC2ClientError):
class InvalidVpnGatewayIdError(EC2ClientError):
def __init__(self, vpn_gw):
def __init__(self, vpn_gw: str):
super().__init__(
"InvalidVpnGatewayID.NotFound",
f"The virtual private gateway ID '{vpn_gw}' does not exist",
@ -142,7 +142,7 @@ class InvalidVpnGatewayIdError(EC2ClientError):
class InvalidVpnGatewayAttachmentError(EC2ClientError):
def __init__(self, vpn_gw, vpc_id):
def __init__(self, vpn_gw: str, vpc_id: str):
super().__init__(
"InvalidVpnGatewayAttachment.NotFound",
f"The attachment with vpn gateway ID '{vpn_gw}' and vpc ID '{vpc_id}' does not exist",
@ -150,7 +150,7 @@ class InvalidVpnGatewayAttachmentError(EC2ClientError):
class InvalidVpnConnectionIdError(EC2ClientError):
def __init__(self, network_acl_id):
def __init__(self, network_acl_id: str):
super().__init__(
"InvalidVpnConnectionID.NotFound",
f"The vpnConnection ID '{network_acl_id}' does not exist",
@ -365,7 +365,7 @@ class InvalidAssociationIdError(EC2ClientError):
class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
def __init__(self, association_id):
def __init__(self, association_id: str):
super().__init__(
"InvalidVpcCidrBlockAssociationIdError.NotFound",
f"The vpc CIDR block association ID '{association_id}' does not exist",
@ -373,7 +373,7 @@ class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
class InvalidVPCPeeringConnectionIdError(EC2ClientError):
def __init__(self, vpc_peering_connection_id):
def __init__(self, vpc_peering_connection_id: str):
super().__init__(
"InvalidVpcPeeringConnectionId.NotFound",
f"VpcPeeringConnectionID {vpc_peering_connection_id} does not exist.",
@ -381,7 +381,7 @@ class InvalidVPCPeeringConnectionIdError(EC2ClientError):
class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
def __init__(self, vpc_peering_connection_id):
def __init__(self, vpc_peering_connection_id: str):
super().__init__(
"InvalidStateTransition",
f"VpcPeeringConnectionID {vpc_peering_connection_id} is not in the correct state for the request.",
@ -389,7 +389,7 @@ class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
class InvalidServiceName(EC2ClientError):
def __init__(self, service_name):
def __init__(self, service_name: str):
super().__init__(
"InvalidServiceName",
f"The Vpc Endpoint Service '{service_name}' does not exist",
@ -402,7 +402,7 @@ class InvalidFilter(EC2ClientError):
class InvalidNextToken(EC2ClientError):
def __init__(self, next_token):
def __init__(self, next_token: str):
super().__init__("InvalidNextToken", f"The token '{next_token}' is invalid")
@ -436,7 +436,7 @@ class InvalidParameterValueError(EC2ClientError):
class EmptyTagSpecError(EC2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidParameterValue", "Tag specification must have at least one tag"
)
@ -498,7 +498,7 @@ class TagLimitExceeded(EC2ClientError):
class InvalidID(EC2ClientError):
def __init__(self, resource_id):
def __init__(self, resource_id: str):
super().__init__("InvalidID", f"The ID '{resource_id}' is not valid")
@ -532,7 +532,7 @@ class FilterNotImplementedError(MotoNotImplementedError):
class CidrLimitExceeded(EC2ClientError):
def __init__(self, vpc_id, max_cidr_limit):
def __init__(self, vpc_id: str, max_cidr_limit: int):
super().__init__(
"CidrLimitExceeded",
f"This network '{vpc_id}' has met its maximum number of allowed CIDRs: {max_cidr_limit}",
@ -540,14 +540,14 @@ class CidrLimitExceeded(EC2ClientError):
class UnsupportedTenancy(EC2ClientError):
def __init__(self, tenancy):
def __init__(self, tenancy: str):
super().__init__(
"UnsupportedTenancy", f"The tenancy value {tenancy} is not supported."
)
class OperationNotPermitted(EC2ClientError):
def __init__(self, association_id):
def __init__(self, association_id: str):
super().__init__(
"OperationNotPermitted",
f"The vpc CIDR block with association ID {association_id} may not be disassociated. It is the primary IPv4 CIDR block of the VPC",
@ -611,13 +611,13 @@ class InvalidSubnetConflictError(EC2ClientError):
class InvalidVPCRangeError(EC2ClientError):
def __init__(self, cidr_block):
def __init__(self, cidr_block: str):
super().__init__("InvalidVpc.Range", f"The CIDR '{cidr_block}' is invalid.")
# accept exception
class OperationNotPermitted2(EC2ClientError):
def __init__(self, client_region, pcx_id, acceptor_region):
def __init__(self, client_region: str, pcx_id: str, acceptor_region: str):
super().__init__(
"OperationNotPermitted",
f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted in region {acceptor_region}",
@ -626,7 +626,7 @@ class OperationNotPermitted2(EC2ClientError):
# reject exception
class OperationNotPermitted3(EC2ClientError):
def __init__(self, client_region, pcx_id, acceptor_region):
def __init__(self, client_region: str, pcx_id: str, acceptor_region: str):
super().__init__(
"OperationNotPermitted",
f"Incorrect region ({client_region}) specified for this request.VPC peering connection {pcx_id} must be accepted or rejected in region {acceptor_region}",
@ -690,7 +690,7 @@ class InvalidAssociationIDIamProfileAssociationError(EC2ClientError):
class InvalidVpcEndPointIdError(EC2ClientError):
def __init__(self, vpc_end_point_id):
def __init__(self, vpc_end_point_id: str):
super().__init__(
"InvalidVpcEndpointId.NotFound",
f"The VpcEndPoint ID '{vpc_end_point_id}' does not exist",
@ -730,7 +730,7 @@ class InvalidCarrierGatewayID(EC2ClientError):
class NoLoadBalancersProvided(EC2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidParameter",
"exactly one of network_load_balancer_arn or gateway_load_balancer_arn is a required member",
@ -738,7 +738,7 @@ class NoLoadBalancersProvided(EC2ClientError):
class UnknownVpcEndpointService(EC2ClientError):
def __init__(self, service_id):
def __init__(self, service_id: str):
super().__init__(
"InvalidVpcEndpointServiceId.NotFound",
f"The VpcEndpointService Id '{service_id}' does not exist",

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict
from ..exceptions import (
EC2ClientError,
@ -49,7 +50,7 @@ from ..utils import (
)
def validate_resource_ids(resource_ids):
def validate_resource_ids(resource_ids: List[str]) -> bool:
if not resource_ids:
raise MissingParameterError(parameter="resourceIdSet")
for resource_id in resource_ids:
@ -59,19 +60,19 @@ def validate_resource_ids(resource_ids):
class SettingsBackend:
def __init__(self):
def __init__(self) -> None:
self.ebs_encryption_by_default = False
def disable_ebs_encryption_by_default(self):
ec2_backend = ec2_backends[self.account_id][self.region_name]
def disable_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
ec2_backend.ebs_encryption_by_default = False
def enable_ebs_encryption_by_default(self):
ec2_backend = ec2_backends[self.account_id][self.region_name]
def enable_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
ec2_backend.ebs_encryption_by_default = True
def get_ebs_encryption_by_default(self):
ec2_backend = ec2_backends[self.account_id][self.region_name]
def get_ebs_encryption_by_default(self) -> None:
ec2_backend = ec2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
return ec2_backend.ebs_encryption_by_default
@ -129,11 +130,11 @@ class EC2Backend(
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
BaseBackend.__init__(self, region_name, account_id)
for backend in EC2Backend.__mro__:
if backend not in [EC2Backend, BaseBackend, object]:
backend.__init__(self)
backend.__init__(self) # type: ignore
# Default VPC exists by default, which is the current behavior
# of EC2-VPC. See for detail:
@ -145,23 +146,23 @@ class EC2Backend(
else:
# For now this is included for potential
# backward-compatibility issues
vpc = self.vpcs.values()[0]
vpc = list(self.vpcs.values())[0]
self.default_vpc = vpc
# Create default subnet for each availability zone
ip, _ = vpc.cidr_block.split("/")
ip = ip.split(".")
ip[2] = 0
ip = ip.split(".") # type: ignore
ip[2] = 0 # type: ignore
for zone in self.describe_availability_zones():
az_name = zone.name
cidr_block = ".".join(str(i) for i in ip) + "/20"
self.create_subnet(vpc.id, cidr_block, availability_zone=az_name)
ip[2] += 16
ip[2] += 16 # type: ignore
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "ec2"
@ -171,13 +172,13 @@ class EC2Backend(
# Use this to generate a proper error template response when in a response
# handler.
def raise_error(self, code, message):
def raise_error(self, code: str, message: str) -> None:
raise EC2ClientError(code, message)
def raise_not_implemented_error(self, blurb: str):
def raise_not_implemented_error(self, blurb: str) -> None:
raise MotoNotImplementedError(blurb)
def do_resources_exist(self, resource_ids):
def do_resources_exist(self, resource_ids: List[str]) -> bool:
for resource_id in resource_ids:
resource_prefix = get_prefix(resource_id)
if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]:

View File

@ -1,5 +1,6 @@
import weakref
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional
from moto.core import CloudFormationModel
from ..exceptions import (
InvalidVPCPeeringConnectionIdError,
@ -8,6 +9,7 @@ from ..exceptions import (
OperationNotPermitted3,
)
from .core import TaggedEC2Resource
from .vpcs import VPC
from ..utils import random_vpc_peering_connection_id
@ -44,7 +46,14 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
"AllowDnsResolutionFromRemoteVpc": "false",
}
def __init__(self, backend, vpc_pcx_id, vpc, peer_vpc, tags=None):
def __init__(
self,
backend: Any,
vpc_pcx_id: str,
vpc: VPC,
peer_vpc: VPC,
tags: Optional[Dict[str, str]] = None,
):
self.id = vpc_pcx_id
self.ec2_backend = backend
self.vpc = vpc
@ -55,18 +64,23 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
self._status = PeeringConnectionStatus()
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcpeeringconnection.html
return "AWS::EC2::VPCPeeringConnection"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VPCPeeringConnection":
from ..models import ec2_backends
properties = cloudformation_json["Properties"]
@ -80,26 +94,28 @@ class VPCPeeringConnection(TaggedEC2Resource, CloudFormationModel):
return vpc_pcx
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.id
class VPCPeeringConnectionBackend:
# for cross region vpc reference
vpc_pcx_refs = defaultdict(set)
vpc_pcx_refs = defaultdict(set) # type: ignore
def __init__(self):
self.vpc_pcxs = {}
def __init__(self) -> None:
self.vpc_pcxs: Dict[str, VPCPeeringConnection] = {}
self.vpc_pcx_refs[self.__class__].add(weakref.ref(self))
@classmethod
def get_vpc_pcx_refs(cls):
def get_vpc_pcx_refs(cls) -> Iterator[VPCPeeringConnection]:
for inst_ref in cls.vpc_pcx_refs[cls]:
inst = inst_ref()
if inst is not None:
yield inst
def create_vpc_peering_connection(self, vpc, peer_vpc, tags=None):
def create_vpc_peering_connection(
self, vpc: VPC, peer_vpc: VPC, tags: Optional[Dict[str, str]] = None
) -> VPCPeeringConnection:
vpc_pcx_id = random_vpc_peering_connection_id()
vpc_pcx = VPCPeeringConnection(self, vpc_pcx_id, vpc, peer_vpc, tags)
vpc_pcx._status.pending()
@ -111,49 +127,54 @@ class VPCPeeringConnectionBackend:
vpc_pcx_cx.vpc_pcxs[vpc_pcx_id] = vpc_pcx
return vpc_pcx
def describe_vpc_peering_connections(self, vpc_peering_ids=None):
all_pcxs = self.vpc_pcxs.copy().values()
def describe_vpc_peering_connections(
self, vpc_peering_ids: Optional[List[str]] = None
) -> List[VPCPeeringConnection]:
all_pcxs = list(self.vpc_pcxs.values())
if vpc_peering_ids:
return [pcx for pcx in all_pcxs if pcx.id in vpc_peering_ids]
return all_pcxs
def get_vpc_peering_connection(self, vpc_pcx_id):
def get_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
if vpc_pcx_id not in self.vpc_pcxs:
raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id)
return self.vpc_pcxs.get(vpc_pcx_id)
return self.vpc_pcxs[vpc_pcx_id]
def delete_vpc_peering_connection(self, vpc_pcx_id):
def delete_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
deleted = self.get_vpc_peering_connection(vpc_pcx_id)
deleted._status.deleted()
return deleted
def accept_vpc_peering_connection(self, vpc_pcx_id):
def accept_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
# if cross region need accepter from another region
pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name
pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region:
raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region)
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined]
raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined]
if vpc_pcx._status.code != "pending-acceptance":
raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id)
vpc_pcx._status.accept()
return vpc_pcx
def reject_vpc_peering_connection(self, vpc_pcx_id):
def reject_vpc_peering_connection(self, vpc_pcx_id: str) -> VPCPeeringConnection:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
# if cross region need accepter from another region
pcx_req_region = vpc_pcx.vpc.ec2_backend.region_name
pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region:
raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region)
if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: # type: ignore[attr-defined]
raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) # type: ignore[attr-defined]
if vpc_pcx._status.code != "pending-acceptance":
raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id)
vpc_pcx._status.reject()
return vpc_pcx
def modify_vpc_peering_connection_options(
self, vpc_pcx_id, accepter_options=None, requester_options=None
):
self,
vpc_pcx_id: str,
accepter_options: Optional[Dict[str, Any]] = None,
requester_options: Optional[Dict[str, Any]] = None,
) -> None:
vpc_pcx = self.get_vpc_peering_connection(vpc_pcx_id)
if not vpc_pcx:
raise InvalidVPCPeeringConnectionIdError(vpc_pcx_id)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from moto.core import CloudFormationModel
from moto.moto_api._internal import mock_random
from .core import TaggedEC2Resource
@ -6,7 +7,12 @@ from ..exceptions import UnknownVpcEndpointService
class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel):
def __init__(
self, load_balancers, region, acceptance_required, private_dns_name, ec2_backend
self,
load_balancers: List[Any],
region: str,
acceptance_required: bool,
private_dns_name: str,
ec2_backend: Any,
):
self.id = f"vpce-svc-{mock_random.get_random_hex(length=8)}"
self.service_name = f"com.amazonaws.vpce.{region}.{self.id}"
@ -32,43 +38,49 @@ class VPCServiceConfiguration(TaggedEC2Resource, CloudFormationModel):
self.private_dns_name = private_dns_name
self.endpoint_dns_name = f"{self.id}.{region}.vpce.amazonaws.com"
self.principals = []
self.principals: List[str] = []
self.ec2_backend = ec2_backend
class VPCServiceConfigurationBackend:
def __init__(self):
self.configurations = {}
def __init__(self) -> None:
self.configurations: Dict[str, VPCServiceConfiguration] = {}
@property
def elbv2_backend(self):
def elbv2_backend(self) -> Any: # type: ignore[misc]
from moto.elbv2.models import elbv2_backends
return elbv2_backends[self.account_id][self.region_name]
return elbv2_backends[self.account_id][self.region_name] # type: ignore[attr-defined]
def get_vpc_endpoint_service(self, resource_id):
def get_vpc_endpoint_service(
self, resource_id: str
) -> Optional[VPCServiceConfiguration]:
return self.configurations.get(resource_id)
def create_vpc_endpoint_service_configuration(
self, lb_arns, acceptance_required, private_dns_name, tags
):
self,
lb_arns: List[Any],
acceptance_required: bool,
private_dns_name: str,
tags: List[Dict[str, str]],
) -> VPCServiceConfiguration:
lbs = self.elbv2_backend.describe_load_balancers(arns=lb_arns, names=None)
config = VPCServiceConfiguration(
load_balancers=lbs,
region=self.region_name,
region=self.region_name, # type: ignore[attr-defined]
acceptance_required=acceptance_required,
private_dns_name=private_dns_name,
ec2_backend=self,
)
for tag in tags or []:
tag_key = tag.get("Key")
tag_value = tag.get("Value")
config.add_tag(tag_key, tag_value)
config.add_tag(tag["Key"], tag["Value"])
self.configurations[config.id] = config
return config
def describe_vpc_endpoint_service_configurations(self, service_ids):
def describe_vpc_endpoint_service_configurations(
self, service_ids: Optional[List[str]]
) -> List[VPCServiceConfiguration]:
"""
The Filters, MaxResults, NextToken parameters are not yet implemented
"""
@ -80,15 +92,17 @@ class VPCServiceConfigurationBackend:
else:
raise UnknownVpcEndpointService(service_id)
return found_configs
return self.configurations.copy().values()
return list(self.configurations.values())
def delete_vpc_endpoint_service_configurations(self, service_ids):
def delete_vpc_endpoint_service_configurations(
self, service_ids: List[str]
) -> List[str]:
missing = [s for s in service_ids if s not in self.configurations]
for s in service_ids:
self.configurations.pop(s, None)
return missing
def describe_vpc_endpoint_service_permissions(self, service_id):
def describe_vpc_endpoint_service_permissions(self, service_id: str) -> List[str]:
"""
The Filters, MaxResults, NextToken parameters are not yet implemented
"""
@ -96,8 +110,8 @@ class VPCServiceConfigurationBackend:
return config.principals
def modify_vpc_endpoint_service_permissions(
self, service_id, add_principals, remove_principals
):
self, service_id: str, add_principals: List[str], remove_principals: List[str]
) -> None:
config = self.describe_vpc_endpoint_service_configurations([service_id])[0]
config.principals += add_principals
config.principals = [p for p in config.principals if p not in remove_principals]
@ -105,14 +119,14 @@ class VPCServiceConfigurationBackend:
def modify_vpc_endpoint_service_configuration(
self,
service_id,
acceptance_required,
private_dns_name,
add_network_lbs,
remove_network_lbs,
add_gateway_lbs,
remove_gateway_lbs,
):
service_id: str,
acceptance_required: Optional[str],
private_dns_name: Optional[str],
add_network_lbs: List[str],
remove_network_lbs: List[str],
add_gateway_lbs: List[str],
remove_gateway_lbs: List[str],
) -> None:
"""
The following parameters are not yet implemented: RemovePrivateDnsName
"""

View File

@ -2,6 +2,7 @@ import ipaddress
import json
import weakref
from collections import defaultdict
from typing import Any, Dict, List, Optional
from operator import itemgetter
from moto.core import CloudFormationModel
@ -35,27 +36,27 @@ from ..utils import (
)
MAX_NUMBER_OF_ENDPOINT_SERVICES_RESULTS = 1000
DEFAULT_VPC_ENDPOINT_SERVICES = []
DEFAULT_VPC_ENDPOINT_SERVICES: List[Dict[str, str]] = []
class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
def __init__(
self,
ec2_backend,
endpoint_id,
vpc_id,
service_name,
endpoint_type=None,
policy_document=False,
route_table_ids=None,
subnet_ids=None,
network_interface_ids=None,
dns_entries=None,
client_token=None,
security_group_ids=None,
tags=None,
private_dns_enabled=None,
destination_prefix_list_id=None,
ec2_backend: Any,
endpoint_id: str,
vpc_id: str,
service_name: str,
endpoint_type: Optional[str],
policy_document: Optional[str],
route_table_ids: List[str],
subnet_ids: Optional[List[str]] = None,
network_interface_ids: Optional[List[str]] = None,
dns_entries: Optional[List[Dict[str, str]]] = None,
client_token: Optional[str] = None,
security_group_ids: Optional[List[str]] = None,
tags: Optional[Dict[str, str]] = None,
private_dns_enabled: Optional[str] = None,
destination_prefix_list_id: Optional[str] = None,
):
self.ec2_backend = ec2_backend
self.id = endpoint_id
@ -76,11 +77,17 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
self.created_at = utc_date_and_time()
def modify(self, policy_doc, add_subnets, add_route_tables, remove_route_tables):
def modify(
self,
policy_doc: Optional[str],
add_subnets: Optional[List[str]],
add_route_tables: Optional[List[str]],
remove_route_tables: Optional[List[str]],
) -> None:
if policy_doc:
self.policy_document = policy_doc
if add_subnets:
self.subnet_ids.extend(add_subnets)
self.subnet_ids.extend(add_subnets) # type: ignore[union-attr]
if add_route_tables:
self.route_table_ids.extend(add_route_tables)
if remove_route_tables:
@ -90,32 +97,39 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
if rt_id not in remove_route_tables
]
def get_filter_value(self, filter_name):
def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name in ("vpc-endpoint-type", "vpc_endpoint_type"):
return self.endpoint_type
else:
return super().get_filter_value(filter_name, "DescribeVpcs")
@property
def owner_id(self):
def owner_id(self) -> str:
return self.ec2_backend.account_id
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.id
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
return "AWS::EC2::VPCEndpoint"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "VPCEndPoint":
from ..models import ec2_backends
properties = cloudformation_json["Properties"]
@ -146,19 +160,19 @@ class VPCEndPoint(TaggedEC2Resource, CloudFormationModel):
class VPC(TaggedEC2Resource, CloudFormationModel):
def __init__(
self,
ec2_backend,
vpc_id,
cidr_block,
is_default,
instance_tenancy="default",
amazon_provided_ipv6_cidr_block=False,
ipv6_cidr_block_network_border_group=None,
ec2_backend: Any,
vpc_id: str,
cidr_block: str,
is_default: bool,
instance_tenancy: str = "default",
amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group: Optional[str] = None,
):
self.ec2_backend = ec2_backend
self.id = vpc_id
self.cidr_block = cidr_block
self.cidr_block_association_set = {}
self.cidr_block_association_set: Dict[str, Any] = {}
self.dhcp_options = None
self.state = "available"
self.instance_tenancy = instance_tenancy
@ -180,22 +194,27 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
)
@property
def owner_id(self):
def owner_id(self) -> str:
return self.ec2_backend.account_id
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpc.html
return "AWS::EC2::VPC"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "VPC":
from ..models import ec2_backends
properties = cloudformation_json["Properties"]
@ -213,10 +232,12 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
return vpc
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.id
def get_filter_value(self, filter_name):
def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name in ("vpc-id", "vpcId"):
return self.id
elif filter_name in ("cidr", "cidr-block", "cidrBlock"):
@ -255,23 +276,22 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
else:
return super().get_filter_value(filter_name, "DescribeVpcs")
def modify_vpc_tenancy(self, tenancy):
def modify_vpc_tenancy(self, tenancy: str) -> None:
if tenancy != "default":
raise UnsupportedTenancy(tenancy)
self.instance_tenancy = tenancy
return True
def associate_vpc_cidr_block(
self,
cidr_block,
amazon_provided_ipv6_cidr_block=False,
ipv6_cidr_block_network_border_group=None,
):
cidr_block: str,
amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group: Optional[str] = None,
) -> Dict[str, Any]:
max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1
for cidr in self.cidr_block_association_set.copy():
if (
self.cidr_block_association_set.get(cidr)
self.cidr_block_association_set.get(cidr) # type: ignore[union-attr]
.get("cidr_block_state")
.get("state")
== "disassociated"
@ -285,7 +305,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
association_id = random_vpc_cidr_association_id()
association_set = {
association_set: Dict[str, Any] = {
"association_id": association_id,
"cidr_block_state": {"state": "associated", "StatusMessage": ""},
}
@ -301,7 +321,7 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
self.cidr_block_association_set[association_id] = association_set
return association_set
def enable_vpc_classic_link(self):
def enable_vpc_classic_link(self) -> str:
# Check if current cidr block doesn't fall within the 10.0.0.0/8 block, excluding 10.0.0.0/16 and 10.1.0.0/16.
# Doesn't check any route tables, maybe something for in the future?
# See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/vpc-classiclink.html#classiclink-limitations
@ -315,19 +335,19 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
return self.classic_link_enabled
def disable_vpc_classic_link(self):
def disable_vpc_classic_link(self) -> str:
self.classic_link_enabled = "false"
return self.classic_link_enabled
def enable_vpc_classic_link_dns_support(self):
def enable_vpc_classic_link_dns_support(self) -> str:
self.classic_link_dns_supported = "true"
return self.classic_link_dns_supported
def disable_vpc_classic_link_dns_support(self):
def disable_vpc_classic_link_dns_support(self) -> str:
self.classic_link_dns_supported = "false"
return self.classic_link_dns_supported
def disassociate_vpc_cidr_block(self, association_id):
def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]:
if self.cidr_block == self.cidr_block_association_set.get(
association_id, {}
).get("cidr_block"):
@ -341,7 +361,9 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
entry["cidr_block_state"]["state"] = "disassociated"
return response
def get_cidr_block_association_set(self, ipv6=False):
def get_cidr_block_association_set(
self, ipv6: bool = False
) -> List[Dict[str, Any]]:
return [
c
for c in self.cidr_block_association_set.values()
@ -350,14 +372,14 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
class VPCBackend:
vpc_refs = defaultdict(set)
vpc_refs = defaultdict(set) # type: ignore
def __init__(self):
self.vpcs = {}
self.vpc_end_points = {}
def __init__(self) -> None:
self.vpcs: Dict[str, VPC] = {}
self.vpc_end_points: Dict[str, VPCEndPoint] = {}
self.vpc_refs[self.__class__].add(weakref.ref(self))
def create_default_vpc(self):
def create_default_vpc(self) -> VPC:
default_vpc = self.describe_vpcs(filters={"is-default": "true"})
if default_vpc:
raise DefaultVpcAlreadyExists
@ -366,13 +388,13 @@ class VPCBackend:
def create_vpc(
self,
cidr_block,
instance_tenancy="default",
amazon_provided_ipv6_cidr_block=False,
ipv6_cidr_block_network_border_group=None,
tags=None,
is_default=False,
):
cidr_block: str,
instance_tenancy: str = "default",
amazon_provided_ipv6_cidr_block: bool = False,
ipv6_cidr_block_network_border_group: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
is_default: bool = False,
) -> VPC:
vpc_id = random_vpc_id()
try:
vpc_cidr_block = ipaddress.IPv4Network(str(cidr_block), strict=False)
@ -391,45 +413,45 @@ class VPCBackend:
)
for tag in tags or []:
tag_key = tag.get("Key")
tag_value = tag.get("Value")
vpc.add_tag(tag_key, tag_value)
vpc.add_tag(tag["Key"], tag["Value"])
self.vpcs[vpc_id] = vpc
# AWS creates a default main route table and security group.
self.create_route_table(vpc_id, main=True)
self.create_route_table(vpc_id, main=True) # type: ignore[attr-defined]
# AWS creates a default Network ACL
self.create_network_acl(vpc_id, default=True)
self.create_network_acl(vpc_id, default=True) # type: ignore[attr-defined]
default = self.get_security_group_from_name("default", vpc_id=vpc_id)
default = self.get_security_group_from_name("default", vpc_id=vpc_id) # type: ignore[attr-defined]
if not default:
self.create_security_group(
self.create_security_group( # type: ignore[attr-defined]
"default", "default VPC security group", vpc_id=vpc_id, is_default=True
)
return vpc
def get_vpc(self, vpc_id):
def get_vpc(self, vpc_id: str) -> VPC:
if vpc_id not in self.vpcs:
raise InvalidVPCIdError(vpc_id)
return self.vpcs.get(vpc_id)
return self.vpcs[vpc_id]
def describe_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.copy().values()
def describe_vpcs(
self, vpc_ids: Optional[List[str]] = None, filters: Any = None
) -> List[VPC]:
matches = list(self.vpcs.values())
if vpc_ids:
matches = [vpc for vpc in matches if vpc.id in vpc_ids]
if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches)
unknown_ids = set(vpc_ids) - set(matches) # type: ignore[arg-type]
raise InvalidVPCIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_vpc(self, vpc_id):
def delete_vpc(self, vpc_id: str) -> VPC:
# Do not delete if any VPN Gateway is attached
vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id})
vpn_gateways = self.describe_vpn_gateways(filters={"attachment.vpc-id": vpc_id}) # type: ignore[attr-defined]
vpn_gateways = [
item
for item in vpn_gateways
@ -441,18 +463,18 @@ class VPCBackend:
)
# Delete route table if only main route table remains.
route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id})
route_tables = self.describe_route_tables(filters={"vpc-id": vpc_id}) # type: ignore[attr-defined]
if len(route_tables) > 1:
raise DependencyViolationError(
f"The vpc {vpc_id} has dependencies and cannot be deleted."
)
for route_table in route_tables:
self.delete_route_table(route_table.id)
self.delete_route_table(route_table.id) # type: ignore[attr-defined]
# Delete default security group if exists.
default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id)
default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) # type: ignore[attr-defined]
if default:
self.delete_security_group(group_id=default.id)
self.delete_security_group(group_id=default.id) # type: ignore[attr-defined]
# Now delete VPC.
vpc = self.vpcs.pop(vpc_id, None)
@ -465,7 +487,7 @@ class VPCBackend:
vpc.dhcp_options = None
return vpc
def describe_vpc_attribute(self, vpc_id, attr_name):
def describe_vpc_attribute(self, vpc_id: str, attr_name: str) -> Any:
vpc = self.get_vpc(vpc_id)
if attr_name in (
"enable_dns_support",
@ -476,27 +498,29 @@ class VPCBackend:
else:
raise InvalidParameterValueError(attr_name)
def modify_vpc_tenancy(self, vpc_id, tenancy):
def modify_vpc_tenancy(self, vpc_id: str, tenancy: str) -> None:
vpc = self.get_vpc(vpc_id)
return vpc.modify_vpc_tenancy(tenancy)
vpc.modify_vpc_tenancy(tenancy)
def enable_vpc_classic_link(self, vpc_id):
def enable_vpc_classic_link(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id)
return vpc.enable_vpc_classic_link()
def disable_vpc_classic_link(self, vpc_id):
def disable_vpc_classic_link(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id)
return vpc.disable_vpc_classic_link()
def enable_vpc_classic_link_dns_support(self, vpc_id):
def enable_vpc_classic_link_dns_support(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id)
return vpc.enable_vpc_classic_link_dns_support()
def disable_vpc_classic_link_dns_support(self, vpc_id):
def disable_vpc_classic_link_dns_support(self, vpc_id: str) -> str:
vpc = self.get_vpc(vpc_id)
return vpc.disable_vpc_classic_link_dns_support()
def modify_vpc_attribute(self, vpc_id, attr_name, attr_value):
def modify_vpc_attribute(
self, vpc_id: str, attr_name: str, attr_value: str
) -> None:
vpc = self.get_vpc(vpc_id)
if attr_name in (
"enable_dns_support",
@ -507,58 +531,58 @@ class VPCBackend:
else:
raise InvalidParameterValueError(attr_name)
def disassociate_vpc_cidr_block(self, association_id):
def disassociate_vpc_cidr_block(self, association_id: str) -> Dict[str, Any]:
for vpc in self.vpcs.copy().values():
response = vpc.disassociate_vpc_cidr_block(association_id)
for route_table in self.route_tables.copy().values():
for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined]
if route_table.vpc_id == response.get("vpc_id"):
if "::/" in response.get("cidr_block"):
self.delete_route(
if "::/" in response.get("cidr_block"): # type: ignore[operator]
self.delete_route( # type: ignore[attr-defined]
route_table.id, None, response.get("cidr_block")
)
else:
self.delete_route(route_table.id, response.get("cidr_block"))
self.delete_route(route_table.id, response.get("cidr_block")) # type: ignore[attr-defined]
if response:
return response
raise InvalidVpcCidrBlockAssociationIdError(association_id)
def associate_vpc_cidr_block(
self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block
):
self, vpc_id: str, cidr_block: str, amazon_provided_ipv6_cidr_block: bool
) -> Dict[str, Any]:
vpc = self.get_vpc(vpc_id)
association_set = vpc.associate_vpc_cidr_block(
cidr_block, amazon_provided_ipv6_cidr_block
)
for route_table in self.route_tables.copy().values():
for route_table in self.route_tables.copy().values(): # type: ignore[attr-defined]
if route_table.vpc_id == vpc_id:
if amazon_provided_ipv6_cidr_block:
self.create_route(
self.create_route( # type: ignore[attr-defined]
route_table.id,
None,
destination_ipv6_cidr_block=association_set["cidr_block"],
local=True,
)
else:
self.create_route(
self.create_route( # type: ignore[attr-defined]
route_table.id, association_set["cidr_block"], local=True
)
return association_set
def create_vpc_endpoint(
self,
vpc_id,
service_name,
endpoint_type=None,
policy_document=False,
route_table_ids=None,
subnet_ids=None,
network_interface_ids=None,
dns_entries=None,
client_token=None,
security_group_ids=None,
tags=None,
private_dns_enabled=None,
):
vpc_id: str,
service_name: str,
endpoint_type: Optional[str],
policy_document: Optional[str],
route_table_ids: List[str],
subnet_ids: Optional[List[str]] = None,
network_interface_ids: Optional[List[str]] = None,
dns_entries: Optional[Dict[str, str]] = None,
client_token: Optional[str] = None,
security_group_ids: Optional[List[str]] = None,
tags: Optional[Dict[str, str]] = None,
private_dns_enabled: Optional[str] = None,
) -> VPCEndPoint:
vpc_endpoint_id = random_vpc_ep_id()
@ -570,21 +594,18 @@ class VPCBackend:
network_interface_ids = []
for subnet_id in subnet_ids or []:
self.get_subnet(subnet_id)
eni = self.create_network_interface(subnet_id, random_private_ip())
self.get_subnet(subnet_id) # type: ignore[attr-defined]
eni = self.create_network_interface(subnet_id, random_private_ip()) # type: ignore[attr-defined]
network_interface_ids.append(eni.id)
dns_entries = create_dns_entries(service_name, vpc_endpoint_id)
else:
# considering gateway if type is not mentioned.
for prefix_list in self.managed_prefix_lists.values():
for prefix_list in self.managed_prefix_lists.values(): # type: ignore[attr-defined]
if prefix_list.prefix_list_name == service_name:
destination_prefix_list_id = prefix_list.id
if dns_entries:
dns_entries = [dns_entries]
vpc_end_point = VPCEndPoint(
self,
vpc_endpoint_id,
@ -595,19 +616,19 @@ class VPCBackend:
route_table_ids,
subnet_ids,
network_interface_ids,
dns_entries,
client_token,
security_group_ids,
tags,
private_dns_enabled,
destination_prefix_list_id,
dns_entries=[dns_entries] if dns_entries else None,
client_token=client_token,
security_group_ids=security_group_ids,
tags=tags,
private_dns_enabled=private_dns_enabled,
destination_prefix_list_id=destination_prefix_list_id,
)
self.vpc_end_points[vpc_endpoint_id] = vpc_end_point
if destination_prefix_list_id:
for route_table_id in route_table_ids:
self.create_route(
self.create_route( # type: ignore[attr-defined]
route_table_id,
None,
gateway_id=vpc_endpoint_id,
@ -617,28 +638,34 @@ class VPCBackend:
return vpc_end_point
def modify_vpc_endpoint(
self, vpc_id, policy_doc, add_subnets, remove_route_tables, add_route_tables
):
self,
vpc_id: str,
policy_doc: str,
add_subnets: Optional[List[str]],
remove_route_tables: Optional[List[str]],
add_route_tables: Optional[List[str]],
) -> None:
endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0]
endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables)
def delete_vpc_endpoints(self, vpce_ids=None):
def delete_vpc_endpoints(self, vpce_ids: Optional[List[str]] = None) -> None:
for vpce_id in vpce_ids or []:
vpc_endpoint = self.vpc_end_points.get(vpce_id, None)
if vpc_endpoint:
if vpc_endpoint.endpoint_type.lower() == "interface":
if vpc_endpoint.endpoint_type.lower() == "interface": # type: ignore[union-attr]
for eni_id in vpc_endpoint.network_interface_ids:
self.enis.pop(eni_id, None)
self.enis.pop(eni_id, None) # type: ignore[attr-defined]
else:
for route_table_id in vpc_endpoint.route_table_ids:
self.delete_route(
self.delete_route( # type: ignore[attr-defined]
route_table_id, vpc_endpoint.destination_prefix_list_id
)
vpc_endpoint.state = "deleted"
return True
def describe_vpc_endpoints(self, vpc_end_point_ids, filters=None):
vpc_end_points = self.vpc_end_points.values()
def describe_vpc_endpoints(
self, vpc_end_point_ids: Optional[List[str]], filters: Any = None
) -> List[VPCEndPoint]:
vpc_end_points = list(self.vpc_end_points.values())
if vpc_end_point_ids:
vpc_end_points = [
@ -657,7 +684,9 @@ class VPCBackend:
return generic_filter(filters, vpc_end_points)
@staticmethod
def _collect_default_endpoint_services(account_id, region):
def _collect_default_endpoint_services(
account_id: str, region: str
) -> List[Dict[str, str]]:
"""Return list of default services using list of backends."""
if DEFAULT_VPC_ENDPOINT_SERVICES:
return DEFAULT_VPC_ENDPOINT_SERVICES
@ -672,14 +701,16 @@ class VPCBackend:
from moto import backends # pylint: disable=import-outside-toplevel
for _backends in backends.service_backends():
_backends = _backends[account_id]
if region in _backends:
service = _backends[region].default_vpc_endpoint_service(region, zones)
account_backend = _backends[account_id]
if region in account_backend:
service = account_backend[region].default_vpc_endpoint_service(
region, zones
)
if service:
DEFAULT_VPC_ENDPOINT_SERVICES.extend(service)
if "global" in _backends:
service = _backends["global"].default_vpc_endpoint_service(
if "global" in account_backend:
service = account_backend["global"].default_vpc_endpoint_service(
region, zones
)
if service:
@ -687,7 +718,7 @@ class VPCBackend:
return DEFAULT_VPC_ENDPOINT_SERVICES
@staticmethod
def _matches_service_by_tags(service, filter_item):
def _matches_service_by_tags(service: Dict[str, Any], filter_item: Dict[str, Any]) -> bool: # type: ignore[misc]
"""Return True if service tags are not filtered by their tags.
Note that the API specifies a key of "Values" for a filter, but
@ -719,7 +750,7 @@ class VPCBackend:
return matched
@staticmethod
def _filter_endpoint_services(service_names_filters, filters, services):
def _filter_endpoint_services(service_names_filters: List[str], filters: List[Dict[str, Any]], services: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Return filtered list of VPC endpoint services."""
if not service_names_filters and not filters:
return services
@ -774,11 +805,16 @@ class VPCBackend:
return filtered_services
def describe_vpc_endpoint_services(
self, dry_run, service_names, filters, max_results, next_token, region
): # pylint: disable=unused-argument,too-many-arguments
self,
service_names: List[str],
filters: Any,
max_results: int,
next_token: Optional[str],
region: str,
) -> Dict[str, Any]: # pylint: disable=too-many-arguments
"""Return info on services to which you can create a VPC endpoint.
Currently only the default endpoing services are returned. When
Currently only the default endpoint services are returned. When
create_vpc_endpoint_service_configuration() is implemented, a
list of those private endpoints would be kept and when this API
is invoked, those private endpoints would be added to the list of
@ -787,7 +823,7 @@ class VPCBackend:
The DryRun parameter is ignored.
"""
default_services = self._collect_default_endpoint_services(
self.account_id, region
self.account_id, region # type: ignore[attr-defined]
)
for service_name in service_names:
if service_name not in [x["ServiceName"] for x in default_services]:
@ -827,7 +863,7 @@ class VPCBackend:
"nextToken": next_token,
}
def get_vpc_end_point(self, vpc_end_point_id):
def get_vpc_end_point(self, vpc_end_point_id: str) -> VPCEndPoint:
vpc_end_point = self.vpc_end_points.get(vpc_end_point_id)
if not vpc_end_point:
raise InvalidVpcEndPointIdError(vpc_end_point_id)

View File

@ -1,3 +1,4 @@
from typing import Any, Dict, List, Optional
from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnConnectionIdError
from ..utils import generic_filter, random_vpn_connection_id
@ -6,18 +7,18 @@ from ..utils import generic_filter, random_vpn_connection_id
class VPNConnection(TaggedEC2Resource):
def __init__(
self,
ec2_backend,
vpn_connection_id,
vpn_conn_type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
tags=None,
ec2_backend: Any,
vpn_connection_id: str,
vpn_conn_type: str,
customer_gateway_id: str,
vpn_gateway_id: Optional[str] = None,
transit_gateway_id: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
):
self.ec2_backend = ec2_backend
self.id = vpn_connection_id
self.state = "available"
self.customer_gateway_configuration = {}
self.customer_gateway_configuration: Dict[str, str] = {}
self.type = vpn_conn_type
self.customer_gateway_id = customer_gateway_id
self.vpn_gateway_id = vpn_gateway_id
@ -27,23 +28,25 @@ class VPNConnection(TaggedEC2Resource):
self.static_routes = None
self.add_tags(tags or {})
def get_filter_value(self, filter_name):
def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
return super().get_filter_value(filter_name, "DescribeVpnConnections")
class VPNConnectionBackend:
def __init__(self):
self.vpn_connections = {}
def __init__(self) -> None:
self.vpn_connections: Dict[str, VPNConnection] = {}
def create_vpn_connection(
self,
vpn_conn_type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
static_routes_only=None,
tags=None,
):
vpn_conn_type: str,
customer_gateway_id: str,
vpn_gateway_id: Optional[str] = None,
transit_gateway_id: Optional[str] = None,
static_routes_only: Optional[bool] = None,
tags: Optional[Dict[str, str]] = None,
) -> VPNConnection:
vpn_connection_id = random_vpn_connection_id()
if static_routes_only:
pass
@ -59,7 +62,7 @@ class VPNConnectionBackend:
self.vpn_connections[vpn_connection.id] = vpn_connection
return vpn_connection
def delete_vpn_connection(self, vpn_connection_id):
def delete_vpn_connection(self, vpn_connection_id: str) -> VPNConnection:
if vpn_connection_id in self.vpn_connections:
self.vpn_connections[vpn_connection_id].state = "deleted"
@ -67,17 +70,10 @@ class VPNConnectionBackend:
raise InvalidVpnConnectionIdError(vpn_connection_id)
return self.vpn_connections[vpn_connection_id]
def describe_vpn_connections(self, vpn_connection_ids=None):
vpn_connections = []
for vpn_connection_id in vpn_connection_ids or []:
if vpn_connection_id in self.vpn_connections:
vpn_connections.append(self.vpn_connections[vpn_connection_id])
else:
raise InvalidVpnConnectionIdError(vpn_connection_id)
return vpn_connections or self.vpn_connections.values()
def get_all_vpn_connections(self, vpn_connection_ids=None, filters=None):
vpn_connections = self.vpn_connections.values()
def describe_vpn_connections(
self, vpn_connection_ids: Optional[List[str]] = None, filters: Any = None
) -> List[VPNConnection]:
vpn_connections = list(self.vpn_connections.values())
if vpn_connection_ids:
vpn_connections = [

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, List, Optional
from moto.core import CloudFormationModel
from .core import TaggedEC2Resource
from ..exceptions import InvalidVpnGatewayIdError, InvalidVpnGatewayAttachmentError
@ -15,18 +15,23 @@ class VPCGatewayAttachment(CloudFormationModel):
self.state = state
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html
return "AWS::EC2::VPCGatewayAttachment"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VPCGatewayAttachment":
from ..models import ec2_backends
properties = cloudformation_json["Properties"]
@ -45,20 +50,20 @@ class VPCGatewayAttachment(CloudFormationModel):
return attachment
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.vpc_id
class VpnGateway(CloudFormationModel, TaggedEC2Resource):
def __init__(
self,
ec2_backend,
gateway_id,
gateway_type,
amazon_side_asn,
availability_zone,
tags=None,
state="available",
ec2_backend: Any,
gateway_id: str,
gateway_type: str,
amazon_side_asn: Optional[str],
availability_zone: Optional[str],
tags: Optional[Dict[str, str]] = None,
state: str = "available",
):
self.ec2_backend = ec2_backend
self.id = gateway_id
@ -67,22 +72,27 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
self.availability_zone = availability_zone
self.state = state
self.add_tags(tags or {})
self.attachments = {}
self.attachments: Dict[str, VPCGatewayAttachment] = {}
super().__init__()
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-vpcgatewayattachment.html
return "AWS::EC2::VPNGateway"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any
) -> "VpnGateway":
from ..models import ec2_backends
properties = cloudformation_json["Properties"]
@ -93,10 +103,12 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
return ec2_backend.create_vpn_gateway(gateway_type=_type, amazon_side_asn=asn)
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.id
def get_filter_value(self, filter_name):
def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> Any:
if filter_name == "attachment.vpc-id":
return self.attachments.keys()
elif filter_name == "attachment.state":
@ -109,16 +121,16 @@ class VpnGateway(CloudFormationModel, TaggedEC2Resource):
class VpnGatewayBackend:
def __init__(self):
self.vpn_gateways = {}
def __init__(self) -> None:
self.vpn_gateways: Dict[str, VpnGateway] = {}
def create_vpn_gateway(
self,
gateway_type="ipsec.1",
amazon_side_asn=None,
availability_zone=None,
tags=None,
):
gateway_type: str = "ipsec.1",
amazon_side_asn: Optional[str] = None,
availability_zone: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> VpnGateway:
vpn_gateway_id = random_vpn_gateway_id()
vpn_gateway = VpnGateway(
self, vpn_gateway_id, gateway_type, amazon_side_asn, availability_zone, tags
@ -126,21 +138,25 @@ class VpnGatewayBackend:
self.vpn_gateways[vpn_gateway_id] = vpn_gateway
return vpn_gateway
def describe_vpn_gateways(self, filters=None, vpn_gw_ids=None):
def describe_vpn_gateways(
self, filters: Any = None, vpn_gw_ids: Optional[List[str]] = None
) -> List[VpnGateway]:
vpn_gateways = list(self.vpn_gateways.values() or [])
if vpn_gw_ids:
vpn_gateways = [item for item in vpn_gateways if item.id in vpn_gw_ids]
return generic_filter(filters, vpn_gateways)
def get_vpn_gateway(self, vpn_gateway_id):
def get_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway:
vpn_gateway = self.vpn_gateways.get(vpn_gateway_id, None)
if not vpn_gateway:
raise InvalidVpnGatewayIdError(vpn_gateway_id)
return vpn_gateway
def attach_vpn_gateway(self, vpn_gateway_id, vpc_id):
def attach_vpn_gateway(
self, vpn_gateway_id: str, vpc_id: str
) -> VPCGatewayAttachment:
vpn_gateway = self.get_vpn_gateway(vpn_gateway_id)
self.get_vpc(vpc_id)
self.get_vpc(vpc_id) # type: ignore[attr-defined]
attachment = VPCGatewayAttachment(vpc_id, state="attached")
for key in vpn_gateway.attachments.copy():
if key.startswith("vpc-"):
@ -148,14 +164,16 @@ class VpnGatewayBackend:
vpn_gateway.attachments[vpc_id] = attachment
return attachment
def delete_vpn_gateway(self, vpn_gateway_id):
def delete_vpn_gateway(self, vpn_gateway_id: str) -> VpnGateway:
deleted = self.vpn_gateways.get(vpn_gateway_id, None)
if not deleted:
raise InvalidVpnGatewayIdError(vpn_gateway_id)
deleted.state = "deleted"
return deleted
def detach_vpn_gateway(self, vpn_gateway_id, vpc_id):
def detach_vpn_gateway(
self, vpn_gateway_id: str, vpc_id: str
) -> VPCGatewayAttachment:
vpn_gateway = self.get_vpn_gateway(vpn_gateway_id)
detached = vpn_gateway.attachments.get(vpc_id, None)
if not detached:

View File

@ -5,7 +5,7 @@ from moto.core import BaseModel
class WindowsBackend(BaseModel):
def get_password_data(self, instance_id: str) -> str:
instance = self.get_instance(instance_id)
instance = self.get_instance(instance_id) # type: ignore[attr-defined]
if instance.platform == "windows":
return random.get_random_string(length=128)
return ""

View File

@ -65,9 +65,8 @@ class VPCs(EC2BaseResponse):
def modify_vpc_tenancy(self):
vpc_id = self._get_param("VpcId")
tenancy = self._get_param("InstanceTenancy")
value = self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy)
template = self.response_template(MODIFY_VPC_TENANCY_RESPONSE)
return template.render(value=value)
self.ec2_backend.modify_vpc_tenancy(vpc_id, tenancy)
return self.response_template(MODIFY_VPC_TENANCY_RESPONSE).render()
def describe_vpc_attribute(self):
vpc_id = self._get_param("VpcId")
@ -237,7 +236,6 @@ class VPCs(EC2BaseResponse):
def describe_vpc_endpoint_services(self):
vpc_end_point_services = self.ec2_backend.describe_vpc_endpoint_services(
dry_run=self._get_bool_param("DryRun"),
service_names=self._get_multi_param("ServiceName"),
filters=self._get_multi_param("Filter"),
max_results=self._get_int_param("MaxResults"),
@ -260,9 +258,8 @@ class VPCs(EC2BaseResponse):
def delete_vpc_endpoints(self):
vpc_end_points_ids = self._get_multi_param("VpcEndpointId")
response = self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids)
template = self.response_template(DELETE_VPC_ENDPOINT_RESPONSE)
return template.render(response=response)
self.ec2_backend.delete_vpc_endpoints(vpce_ids=vpc_end_points_ids)
return self.response_template(DELETE_VPC_ENDPOINT_RESPONSE).render()
def create_managed_prefix_list(self):
address_family = self._get_param("AddressFamily")
@ -767,7 +764,7 @@ DESCRIBE_VPC_ENDPOINT_RESPONSE = """<DescribeVpcEndpointsResponse xmlns="http://
DELETE_VPC_ENDPOINT_RESPONSE = """<DeleteVpcEndpointsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>19a9ff46-7df6-49b8-9726-3df27527089d</requestId>
<unsuccessful>{{ 'Error' if not response else '' }}</unsuccessful>
<unsuccessful></unsuccessful>
</DeleteVpcEndpointsResponse>"""

View File

@ -42,7 +42,7 @@ class VPNConnections(EC2BaseResponse):
def describe_vpn_connections(self):
vpn_connection_ids = self._get_multi_param("VpnConnectionId")
filters = self._filters_from_querystring()
vpn_connections = self.ec2_backend.get_all_vpn_connections(
vpn_connections = self.ec2_backend.describe_vpn_connections(
vpn_connection_ids=vpn_connection_ids, filters=filters
)
template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE)

View File

@ -134,11 +134,11 @@ def random_network_acl_subnet_association_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"])
def random_vpn_gateway_id():
def random_vpn_gateway_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"])
def random_vpn_connection_id():
def random_vpn_connection_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"])
@ -154,19 +154,19 @@ def random_key_pair_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["key-pair"])
def random_vpc_id():
def random_vpc_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"])
def random_vpc_ep_id():
def random_vpc_ep_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-endpoint"], size=8)
def random_vpc_cidr_association_id():
def random_vpc_cidr_association_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"])
def random_vpc_peering_connection_id():
def random_vpc_peering_connection_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"])
@ -272,11 +272,11 @@ def random_mac_address() -> str:
return f"02:00:00:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x:{random.randint(0, 255)}02x"
def randor_ipv4_cidr():
def randor_ipv4_cidr() -> str:
return f"10.0.{random.randint(0, 255)}.{random.randint(0, 255)}/16"
def random_ipv6_cidr():
def random_ipv6_cidr() -> str:
return f"2400:6500:{random_resource_id(4)}:{random_resource_id(2)}00::/56"
@ -297,13 +297,11 @@ def random_managed_prefix_list_id() -> str:
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["managed-prefix-list"], size=8)
def create_dns_entries(service_name, vpc_endpoint_id):
dns_entries = {}
dns_entries[
"dns_name"
] = f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}"
dns_entries["hosted_zone_id"] = random_resource_id(13).upper()
return dns_entries
def create_dns_entries(service_name: str, vpc_endpoint_id: str) -> Dict[str, str]:
return {
"dns_name": f"{vpc_endpoint_id}-{random_resource_id(8)}.{service_name}",
"hosted_zone_id": random_resource_id(13).upper(),
}
def utc_date_and_time() -> str:
@ -589,7 +587,7 @@ def get_prefix(resource_id: str) -> str:
return resource_id_prefix
def is_valid_resource_id(resource_id):
def is_valid_resource_id(resource_id: str) -> bool:
valid_prefixes = EC2_RESOURCE_TO_PREFIX.values()
resource_id_prefix = get_prefix(resource_id)
if resource_id_prefix not in valid_prefixes:

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/a*,moto/ec2/models/c*,moto/ec2/models/d*,moto/ec2/models/e*,moto/ec2/models/f*,moto/ec2/models/h*,moto/ec2/models/i*,moto/ec2/models/k*,moto/ec2/models/l*,moto/ec2/models/m*,moto/ec2/models/n*,moto/ec2/models/r*,moto/ec2/models/s*,moto/ec2/models/t*,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2/models/**/*.py,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract