Fix Nat Gateway (#4281)

This commit is contained in:
Mohit Alonja 2021-09-15 02:10:17 +05:30 committed by GitHub
parent b89b0039e4
commit 002f9979ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 39 deletions

View File

@ -7656,39 +7656,44 @@ class TransitGatewayRelationsBackend(object):
return tgw_association
class NatGateway(CloudFormationModel):
class NatGateway(CloudFormationModel, TaggedEC2Resource):
def __init__(
self, backend, subnet_id, allocation_id, tags=[], connectivity_type="public"
self,
backend,
subnet_id,
allocation_id,
tags=[],
connectivity_type="public",
address_set=None,
):
# public properties
self.id = random_nat_gateway_id()
self.subnet_id = subnet_id
self.allocation_id = allocation_id
self.address_set = address_set or []
self.state = "available"
self.private_ip = random_private_ip()
self.connectivity_type = connectivity_type
# protected properties
self._created_at = datetime.utcnow()
self._backend = backend
self.ec2_backend = backend
# NOTE: this is the core of NAT Gateways creation
self._eni = self._backend.create_network_interface(
self._eni = self.ec2_backend.create_network_interface(
backend.get_subnet(self.subnet_id), self.private_ip
)
# associate allocation with ENI
self._backend.associate_address(eni=self._eni, allocation_id=self.allocation_id)
self.tags = tags
if allocation_id and connectivity_type != "private":
self.ec2_backend.associate_address(
eni=self._eni, allocation_id=allocation_id
)
self.add_tags(tags or {})
self.vpc_id = self.ec2_backend.get_subnet(subnet_id).vpc_id
@property
def physical_resource_id(self):
return self.id
@property
def vpc_id(self):
subnet = self._backend.get_subnet(self.subnet_id)
return subnet.vpc_id
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
@ -7728,8 +7733,11 @@ class NatGatewayBackend(object):
self.nat_gateways = {}
super(NatGatewayBackend, self).__init__()
def describe_nat_gateways(self, filters):
nat_gateways = self.nat_gateways.values()
def describe_nat_gateways(self, filters, nat_gateway_ids):
nat_gateways = list(self.nat_gateways.values())
if nat_gateway_ids:
nat_gateways = [item for item in nat_gateways if item.id in nat_gateway_ids]
if filters is not None:
if filters.get("nat-gateway-id") is not None:
@ -7765,11 +7773,23 @@ class NatGatewayBackend(object):
nat_gateway = NatGateway(
self, subnet_id, allocation_id, tags, connectivity_type
)
address_set = {}
if allocation_id:
eips = self.address_by_allocation([allocation_id])
eip = eips[0] if len(eips) > 0 else None
if eip:
address_set["allocationId"] = allocation_id
address_set["publicIp"] = eip.public_ip or None
address_set["networkInterfaceId"] = nat_gateway._eni.id
address_set["privateIp"] = nat_gateway._eni.private_ip_address
nat_gateway.address_set.append(address_set)
self.nat_gateways[nat_gateway.id] = nat_gateway
return nat_gateway
def delete_nat_gateway(self, nat_gateway_id):
return self.nat_gateways.pop(nat_gateway_id)
nat_gw = self.nat_gateways.get(nat_gateway_id)
nat_gw.state = "deleted"
return nat_gw
class LaunchTemplateVersion(object):

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
from moto.ec2.utils import filters_from_querystring, add_tag_specification
class NatGateways(BaseResponse):
@ -9,8 +9,8 @@ class NatGateways(BaseResponse):
allocation_id = self._get_param("AllocationId")
connectivity_type = self._get_param("ConnectivityType")
tags = self._get_multi_param("TagSpecification")
if tags:
tags = tags[0].get("Tag")
tags = add_tag_specification(tags)
nat_gateway = self.ec2_backend.create_nat_gateway(
subnet_id=subnet_id,
allocation_id=allocation_id,
@ -28,7 +28,8 @@ class NatGateways(BaseResponse):
def describe_nat_gateways(self):
filters = filters_from_querystring(self.querystring)
nat_gateways = self.ec2_backend.describe_nat_gateways(filters)
nat_gateway_ids = self._get_multi_param("NatGatewayId")
nat_gateways = self.ec2_backend.describe_nat_gateways(filters, nat_gateway_ids)
template = self.response_template(DESCRIBE_NAT_GATEWAYS_RESPONSE)
return template.render(nat_gateways=nat_gateways)
@ -40,28 +41,36 @@ DESCRIBE_NAT_GATEWAYS_RESPONSE = """<DescribeNatGatewaysResponse xmlns="http://e
<item>
<subnetId>{{ nat_gateway.subnet_id }}</subnetId>
<natGatewayAddressSet>
{% for address_set in nat_gateway.address_set %}
<item>
<networkInterfaceId>{{ nat_gateway.network_interface_id }}</networkInterfaceId>
<publicIp>{{ nat_gateway.public_ip }}</publicIp>
<allocationId>{{ nat_gateway.allocation_id }}</allocationId>
<privateIp>{{ nat_gateway.private_ip }}</privateIp>
{% if address_set.allocationId %}
<allocationId>{{ address_set.allocationId }}</allocationId>
{% endif %}
{% if address_set.privateIp %}
<privateIp>{{ address_set.privateIp }}</privateIp>
{% endif %}
{% if address_set.publicIp %}
<publicIp>{{ address_set.publicIp }}</publicIp>
{% endif %}
{% if address_set.networkInterfaceId %}
<networkInterfaceId>{{ address_set.networkInterfaceId }}</networkInterfaceId>
{% endif %}
</item>
{% endfor %}
</natGatewayAddressSet>
<createTime>{{ nat_gateway.create_time }}</createTime>
<vpcId>{{ nat_gateway.vpc_id }}</vpcId>
<natGatewayId>{{ nat_gateway.id }}</natGatewayId>
<connectivityType>{{ nat_gateway.connectivity_type }}</connectivityType>
<state>{{ nat_gateway.state }}</state>
{% if nat_gateway.tags %}
<tagSet>
{% for tag in nat_gateway.tags %}
<item>
<key>{{ tag['Key'] }}</key>
<value>{{ tag['Value'] }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
<tagSet>
{% for tag in nat_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</item>
{% endfor %}
</natGatewaySet>
@ -73,15 +82,36 @@ CREATE_NAT_GATEWAY = """<CreateNatGatewayResponse xmlns="http://ec2.amazonaws.co
<natGateway>
<subnetId>{{ nat_gateway.subnet_id }}</subnetId>
<natGatewayAddressSet>
<item>
<allocationId>{{ nat_gateway.allocation_id }}</allocationId>
</item>
{% for address_set in nat_gateway.address_set %}
<item>
{% if address_set.allocationId %}
<allocationId>{{ address_set.allocationId }}</allocationId>
{% endif %}
{% if address_set.privateIp %}
<privateIp>{{ address_set.privateIp }}</privateIp>
{% endif %}
{% if address_set.publicIp %}
<publicIp>{{ address_set.publicIp }}</publicIp>
{% endif %}
{% if address_set.networkInterfaceId %}
<networkInterfaceId>{{ address_set.networkInterfaceId }}</networkInterfaceId>
{% endif %}
</item>
{% endfor %}
</natGatewayAddressSet>
<createTime>{{ nat_gateway.create_time }}</createTime>
<vpcId>{{ nat_gateway.vpc_id }}</vpcId>
<natGatewayId>{{ nat_gateway.id }}</natGatewayId>
<connectivityType>{{ nat_gateway.connectivity_type }}</connectivityType>
<state>{{ nat_gateway.state }}</state>
<tagSet>
{% for tag in nat_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</natGateway>
</CreateNatGatewayResponse>
"""

View File

@ -123,4 +123,5 @@ TestAccAWSInternetGateway
TestAccAWSSecurityGroupRule_
TestAccAWSVpnGateway
TestAccAWSVpnGatewayAttachment
TestAccAWSEc2CarrierGateway
TestAccAWSEc2CarrierGateway
TestAccAWSNatGateway

View File

@ -51,7 +51,7 @@ def test_describe_nat_gateway_tags():
"ResourceType": "nat-gateway",
"Tags": [
{"Key": "name", "Value": "some-nat-gateway"},
{"Key": "name", "Value": "some-nat-gateway-1"},
{"Key": "name1", "Value": "some-nat-gateway-1"},
],
}
],
@ -62,7 +62,7 @@ def test_describe_nat_gateway_tags():
assert describe_response["NatGateways"][0]["VpcId"] == vpc_id
assert describe_response["NatGateways"][0]["Tags"] == [
{"Key": "name", "Value": "some-nat-gateway"},
{"Key": "name", "Value": "some-nat-gateway-1"},
{"Key": "name1", "Value": "some-nat-gateway-1"},
]