From 002f9979ef7d4a6539084468f0bc1558817ed4f9 Mon Sep 17 00:00:00 2001 From: Mohit Alonja Date: Wed, 15 Sep 2021 02:10:17 +0530 Subject: [PATCH] Fix Nat Gateway (#4281) --- moto/ec2/models.py | 50 ++++++++++++++------- moto/ec2/responses/nat_gateways.py | 72 +++++++++++++++++++++--------- tests/terraform-tests.success.txt | 3 +- tests/test_ec2/test_nat_gateway.py | 4 +- 4 files changed, 90 insertions(+), 39 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index d4ab7f931..d754fd4bf 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -7656,39 +7656,44 @@ class TransitGatewayRelationsBackend(object): return tgw_association -class NatGateway(CloudFormationModel): +class NatGateway(CloudFormationModel, TaggedEC2Resource): def __init__( - self, backend, subnet_id, allocation_id, tags=[], connectivity_type="public" + self, + backend, + subnet_id, + allocation_id, + tags=[], + connectivity_type="public", + address_set=None, ): # public properties self.id = random_nat_gateway_id() self.subnet_id = subnet_id - self.allocation_id = allocation_id + self.address_set = address_set or [] self.state = "available" self.private_ip = random_private_ip() self.connectivity_type = connectivity_type # protected properties self._created_at = datetime.utcnow() - self._backend = backend + self.ec2_backend = backend # NOTE: this is the core of NAT Gateways creation - self._eni = self._backend.create_network_interface( + self._eni = self.ec2_backend.create_network_interface( backend.get_subnet(self.subnet_id), self.private_ip ) # associate allocation with ENI - self._backend.associate_address(eni=self._eni, allocation_id=self.allocation_id) - self.tags = tags + if allocation_id and connectivity_type != "private": + self.ec2_backend.associate_address( + eni=self._eni, allocation_id=allocation_id + ) + self.add_tags(tags or {}) + self.vpc_id = self.ec2_backend.get_subnet(subnet_id).vpc_id @property def physical_resource_id(self): return self.id - @property - def vpc_id(self): - subnet = self._backend.get_subnet(self.subnet_id) - return subnet.vpc_id - @property def create_time(self): return iso_8601_datetime_with_milliseconds(self._created_at) @@ -7728,8 +7733,11 @@ class NatGatewayBackend(object): self.nat_gateways = {} super(NatGatewayBackend, self).__init__() - def describe_nat_gateways(self, filters): - nat_gateways = self.nat_gateways.values() + def describe_nat_gateways(self, filters, nat_gateway_ids): + nat_gateways = list(self.nat_gateways.values()) + + if nat_gateway_ids: + nat_gateways = [item for item in nat_gateways if item.id in nat_gateway_ids] if filters is not None: if filters.get("nat-gateway-id") is not None: @@ -7765,11 +7773,23 @@ class NatGatewayBackend(object): nat_gateway = NatGateway( self, subnet_id, allocation_id, tags, connectivity_type ) + address_set = {} + if allocation_id: + eips = self.address_by_allocation([allocation_id]) + eip = eips[0] if len(eips) > 0 else None + if eip: + address_set["allocationId"] = allocation_id + address_set["publicIp"] = eip.public_ip or None + address_set["networkInterfaceId"] = nat_gateway._eni.id + address_set["privateIp"] = nat_gateway._eni.private_ip_address + nat_gateway.address_set.append(address_set) self.nat_gateways[nat_gateway.id] = nat_gateway return nat_gateway def delete_nat_gateway(self, nat_gateway_id): - return self.nat_gateways.pop(nat_gateway_id) + nat_gw = self.nat_gateways.get(nat_gateway_id) + nat_gw.state = "deleted" + return nat_gw class LaunchTemplateVersion(object): diff --git a/moto/ec2/responses/nat_gateways.py b/moto/ec2/responses/nat_gateways.py index e61a726a5..8654253dc 100644 --- a/moto/ec2/responses/nat_gateways.py +++ b/moto/ec2/responses/nat_gateways.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.ec2.utils import filters_from_querystring +from moto.ec2.utils import filters_from_querystring, add_tag_specification class NatGateways(BaseResponse): @@ -9,8 +9,8 @@ class NatGateways(BaseResponse): allocation_id = self._get_param("AllocationId") connectivity_type = self._get_param("ConnectivityType") tags = self._get_multi_param("TagSpecification") - if tags: - tags = tags[0].get("Tag") + tags = add_tag_specification(tags) + nat_gateway = self.ec2_backend.create_nat_gateway( subnet_id=subnet_id, allocation_id=allocation_id, @@ -28,7 +28,8 @@ class NatGateways(BaseResponse): def describe_nat_gateways(self): filters = filters_from_querystring(self.querystring) - nat_gateways = self.ec2_backend.describe_nat_gateways(filters) + nat_gateway_ids = self._get_multi_param("NatGatewayId") + nat_gateways = self.ec2_backend.describe_nat_gateways(filters, nat_gateway_ids) template = self.response_template(DESCRIBE_NAT_GATEWAYS_RESPONSE) return template.render(nat_gateways=nat_gateways) @@ -40,28 +41,36 @@ DESCRIBE_NAT_GATEWAYS_RESPONSE = """