Fix SecuirtyGroupRules and added default Ec2-VPC support for SG. (#4267)

This commit is contained in:
Mohit Alonja 2021-09-09 18:09:48 +05:30 committed by GitHub
parent b3b326f578
commit eef21767f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 211 additions and 42 deletions

View File

@ -963,7 +963,7 @@ class InstanceBackend(object):
new_reservation.id = random_reservation_id()
security_groups = [
self.get_security_group_from_name(name) for name in security_group_names
self.get_security_group_by_name_or_id(name) for name in security_group_names
]
security_groups.extend(
self.get_security_group_from_id(sg_id)
@ -2102,7 +2102,14 @@ class SecurityRule(object):
class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def __init__(
self, ec2_backend, group_id, name, description, vpc_id=None, tags=None
self,
ec2_backend,
group_id,
name,
description,
vpc_id=None,
tags=None,
is_default=None,
):
self.ec2_backend = ec2_backend
self.id = group_id
@ -2114,6 +2121,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
self.vpc_id = vpc_id
self.owner_id = ACCOUNT_ID
self.add_tags(tags or {})
self.is_default = is_default or False
# Append default IPv6 egress rule for VPCs with IPv6 support
if vpc_id:
@ -2198,7 +2206,9 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
@classmethod
def _delete_security_group_given_vpc_id(cls, resource_name, vpc_id, region_name):
ec2_backend = ec2_backends[region_name]
security_group = ec2_backend.get_security_group_from_name(resource_name, vpc_id)
security_group = ec2_backend.get_security_group_by_name_or_id(
resource_name, vpc_id
)
if security_group:
security_group.delete(region_name)
@ -2285,24 +2295,28 @@ class SecurityGroupBackend(object):
self.sg_old_ingress_ruls = {}
self.sg_old_egress_ruls = {}
# Create the default security group
self.create_security_group("default", "default group")
super(SecurityGroupBackend, self).__init__()
def create_security_group(
self, name, description, vpc_id=None, tags=None, force=False
self, name, description, vpc_id=None, tags=None, force=False, is_default=None
):
vpc_id = vpc_id or self.default_vpc.id
if not description:
raise MissingParameterError("GroupDescription")
group_id = random_security_group_id()
if not force:
existing_group = self.get_security_group_from_name(name, vpc_id)
existing_group = self.get_security_group_by_name_or_id(name, vpc_id)
if existing_group:
raise InvalidSecurityGroupDuplicateError(name)
group = SecurityGroup(
self, group_id, name, description, vpc_id=vpc_id, tags=tags
self,
group_id,
name,
description,
vpc_id=vpc_id,
tags=tags,
is_default=is_default,
)
self.groups[vpc_id][group_id] = group
@ -2326,6 +2340,7 @@ class SecurityGroupBackend(object):
return matches
def _delete_security_group(self, vpc_id, group_id):
vpc_id = vpc_id or self.default_vpc.id
if self.groups[vpc_id][group_id].enis:
raise DependencyViolationError(
"{0} is being utilized by {1}".format(group_id, "ENIs")
@ -2342,7 +2357,7 @@ class SecurityGroupBackend(object):
elif name:
# Group Name. Has to be in standard EC2, VPC needs to be
# identified by group_id
group = self.get_security_group_from_name(name)
group = self.get_security_group_by_name_or_id(name)
if group:
return self._delete_security_group(None, group.id)
raise InvalidSecurityGroupNotFoundError(name)
@ -2367,7 +2382,8 @@ class SecurityGroupBackend(object):
if group.name == name:
return group
def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id):
def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id=None):
# try searching by id, fallbacks to name search
group = self.get_security_group_from_id(group_name_or_id)
if group is None:
@ -2675,6 +2691,128 @@ class SecurityGroupBackend(object):
return security_rule
raise InvalidPermissionNotFoundError()
def update_security_group_rule_descriptions_ingress(
self,
group_name_or_id,
ip_protocol,
from_port,
to_port,
ip_ranges,
source_groups=[],
prefix_list_ids=[],
vpc_id=None,
):
group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
if group is None:
raise InvalidSecurityGroupNotFoundError(group_name_or_id)
if ip_ranges and not isinstance(ip_ranges, list):
if isinstance(ip_ranges, str) and "CidrIp" not in ip_ranges:
ip_ranges = [{"CidrIp": ip_ranges}]
else:
ip_ranges = [json.loads(ip_ranges)]
if ip_ranges:
for cidr in ip_ranges:
if (
type(cidr) is dict
and not any(
[
is_valid_cidr(cidr.get("CidrIp", "")),
is_valid_ipv6_cidr(cidr.get("CidrIpv6", "")),
]
)
) or (
type(cidr) is str
and not any([is_valid_cidr(cidr), is_valid_ipv6_cidr(cidr)])
):
raise InvalidCIDRSubnetError(cidr=cidr)
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids
)
for rule in group.ingress_rules:
if (
security_rule.from_port == rule.from_port
and security_rule.to_port == rule.to_port
and security_rule.ip_protocol == rule.ip_protocol
):
self._sg_update_description(security_rule, rule)
return group
def update_security_group_rule_descriptions_egress(
self,
group_name_or_id,
ip_protocol,
from_port,
to_port,
ip_ranges,
source_groups=[],
prefix_list_ids=[],
vpc_id=None,
):
group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id)
if group is None:
raise InvalidSecurityGroupNotFoundError(group_name_or_id)
if ip_ranges and not isinstance(ip_ranges, list):
if isinstance(ip_ranges, str) and "CidrIp" not in ip_ranges:
ip_ranges = [{"CidrIp": ip_ranges}]
else:
ip_ranges = [json.loads(ip_ranges)]
if ip_ranges:
for cidr in ip_ranges:
if (
type(cidr) is dict
and not any(
[
is_valid_cidr(cidr.get("CidrIp", "")),
is_valid_ipv6_cidr(cidr.get("CidrIpv6", "")),
]
)
) or (
type(cidr) is str
and not any([is_valid_cidr(cidr), is_valid_ipv6_cidr(cidr)])
):
raise InvalidCIDRSubnetError(cidr=cidr)
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids
)
for rule in group.egress_rules:
if (
security_rule.from_port == rule.from_port
and security_rule.to_port == rule.to_port
and security_rule.ip_protocol == rule.ip_protocol
):
self._sg_update_description(security_rule, rule)
return group
def _sg_update_description(self, security_rule, rule):
for item in security_rule.ip_ranges:
for cidr_item in rule.ip_ranges:
if cidr_item.get("CidrIp") == item.get("CidrIp"):
cidr_item["Description"] = item.get("Description")
if cidr_item.get("CidrIp6") == item.get("CidrIp6"):
cidr_item["Description"] = item.get("Description")
for item in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
for item in security_rule.source_groups:
for source_group in rule.source_groups:
if source_group.get("GroupId") == item.get(
"GroupId"
) or source_group.get("GroupName") == item.get("GroupName"):
source_group["Description"] = item.get("Description")
def _remove_items_from_rule(self, ip_ranges, _source_groups, prefix_list_ids, rule):
for item in ip_ranges:
if item not in rule.ip_ranges:
@ -2702,26 +2840,27 @@ class SecurityGroupBackend(object):
item["OwnerId"] = ACCOUNT_ID
# for VPCs
if "GroupId" in item:
if not self.get_security_group_from_id(item.get("GroupId")):
if not self.get_security_group_by_name_or_id(
item.get("GroupId"), vpc_id
):
raise InvalidSecurityGroupNotFoundError(item.get("GroupId"))
if "GroupName" in item:
source_group = self.get_security_group_from_name(
source_group = self.get_security_group_by_name_or_id(
item.get("GroupName"), vpc_id
)
if not source_group:
raise InvalidSecurityGroupNotFoundError(item.get("GroupName"))
else:
item["GroupId"] = source_group.id
item.pop("GroupName")
if vpc_id:
item["VpcId"] = vpc_id
_source_groups.append(item)
return _source_groups
def _verify_group_will_respect_rule_count_limit(
self, group, current_rule_nb, ip_ranges, source_groups=None, egress=False,
):
max_nb_rules = 50 if group.vpc_id else 100
max_nb_rules = 60 if group.vpc_id else 100
future_group_nb_rules = current_rule_nb
if ip_ranges:
future_group_nb_rules += len(ip_ranges)
@ -3408,7 +3547,7 @@ class VPCBackend(object):
default = self.get_security_group_from_name("default", vpc_id=vpc_id)
if not default:
self.create_security_group(
"default", "default VPC security group", vpc_id=vpc_id
"default", "default VPC security group", vpc_id=vpc_id, is_default=True
)
return vpc
@ -3447,7 +3586,7 @@ class VPCBackend(object):
self.delete_route_table(route_table.id)
# Delete default security group if exists.
default = self.get_security_group_from_name("default", vpc_id=vpc_id)
default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id)
if default:
self.delete_security_group(group_id=default.id)
@ -5285,12 +5424,12 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource):
if security_groups:
for group_name in security_groups:
group = self.ec2_backend.get_security_group_from_name(group_name)
group = self.ec2_backend.get_security_group_by_name_or_id(group_name)
if group:
ls.groups.append(group)
else:
# If not security groups, add the default
default_group = self.ec2_backend.get_security_group_from_name("default")
default_group = self.ec2_backend.get_security_group_by_name_or_id("default")
ls.groups.append(default_group)
self.instance = self.launch_instance()
@ -7717,8 +7856,8 @@ class EC2Backend(
TagBackend,
EBSBackend,
RegionsAndZonesBackend,
SecurityGroupBackend,
AmiBackend,
SecurityGroupBackend,
VPCBackend,
ManagedPrefixListBackend,
SubnetBackend,
@ -7765,6 +7904,8 @@ class EC2Backend(
# backward-compatibility issues
vpc = self.vpcs.values()[0]
self.default_vpc = vpc
# Create default subnet for each availability zone
ip, _ = vpc.cidr_block.split("/")
ip = ip.split(".")

View File

@ -218,6 +218,22 @@ class SecurityGroups(BaseResponse):
self.ec2_backend.revoke_security_group_ingress(*args)
return REVOKE_SECURITY_GROUP_INGRESS_RESPONSE
def update_security_group_rule_descriptions_ingress(self):
for args in self._process_rules_from_querystring():
group = self.ec2_backend.update_security_group_rule_descriptions_ingress(
*args
)
self.ec2_backend.sg_old_ingress_ruls[group.id] = group.ingress_rules.copy()
return UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_INGRESS
def update_security_group_rule_descriptions_egress(self):
for args in self._process_rules_from_querystring():
group = self.ec2_backend.update_security_group_rule_descriptions_egress(
*args
)
self.ec2_backend.sg_old_egress_ruls[group.id] = group.egress_rules.copy()
return UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_EGRESS
CREATE_SECURITY_GROUP_RESPONSE = """<CreateSecurityGroupResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
@ -552,3 +568,13 @@ REVOKE_SECURITY_GROUP_EGRESS_RESPONSE = """<RevokeSecurityGroupEgressResponse xm
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return>
</RevokeSecurityGroupEgressResponse>"""
UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_INGRESS = """<UpdateSecurityGroupRuleDescriptionsIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return>
</UpdateSecurityGroupRuleDescriptionsIngressResponse>"""
UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_EGRESS = """<UpdateSecurityGroupRuleDescriptionsEgressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return>
</UpdateSecurityGroupRuleDescriptionsEgressResponse>"""

View File

@ -81,7 +81,7 @@ def random_reservation_id():
def random_security_group_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"])
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"], size=17)
def random_security_group_rule_id():

View File

@ -6,3 +6,4 @@ TestAccAWSFms
TestAccAWSIAMRolePolicy
TestAccAWSSecurityGroup_forceRevokeRules_
TestAccAWSSSMDocument_package
TestAccAWSDefaultSecurityGroup_Classic_

View File

@ -23,7 +23,7 @@ TestAccAWSDataSourceIAMPolicyDocument
TestAccAWSDataSourceIAMRole
TestAccAWSDataSourceIAMSessionContext
TestAccAWSDataSourceIAMUser
TestAccAWSDefaultSecurityGroup
TestAccAWSDefaultSecurityGroup_
TestAccAWSDefaultSubnet
TestAccAWSDefaultTagsDataSource
TestAccAWSDynamoDbTableItem
@ -116,3 +116,4 @@ TestAccAwsEc2ManagedPrefixList
TestAccAWSEgressOnlyInternetGateway
TestAccAWSSecurityGroup_
TestAccAWSInternetGateway
TestAccAWSSecurityGroupRule_

View File

@ -48,7 +48,7 @@ def test_create_and_describe_security_group():
all_groups = conn.get_all_security_groups()
# The default group gets created automatically
all_groups.should.have.length_of(3)
all_groups.should.have.length_of(2)
group_names = [group.name for group in all_groups]
set(group_names).should.equal(set(["default", "test security group"]))
@ -68,7 +68,7 @@ def test_create_security_group_without_description_raises_error():
def test_default_security_group():
conn = boto.ec2.connect_to_region("us-east-1")
groups = conn.get_all_security_groups()
groups.should.have.length_of(2)
groups.should.have.length_of(1)
groups[0].name.should.equal("default")
@ -118,7 +118,7 @@ def test_create_two_security_groups_with_same_name_in_different_vpc():
all_groups = conn.get_all_security_groups()
all_groups.should.have.length_of(4)
all_groups.should.have.length_of(3)
group_names = [group.name for group in all_groups]
# The default group is created automatically
set(group_names).should.equal(set(["default", "test security group"]))
@ -143,7 +143,7 @@ def test_deleting_security_groups():
security_group1 = conn.create_security_group("test1", "test1")
conn.create_security_group("test2", "test2")
conn.get_all_security_groups().should.have.length_of(4)
conn.get_all_security_groups().should.have.length_of(3)
# Deleting a group that doesn't exist should throw an error
with pytest.raises(EC2ResponseError) as cm:
@ -162,11 +162,11 @@ def test_deleting_security_groups():
)
conn.delete_security_group("test2")
conn.get_all_security_groups().should.have.length_of(3)
conn.get_all_security_groups().should.have.length_of(2)
# Delete by group id
conn.delete_security_group(group_id=security_group1.id)
conn.get_all_security_groups().should.have.length_of(2)
conn.get_all_security_groups().should.have.length_of(1)
@mock_ec2_deprecated
@ -384,12 +384,14 @@ def test_authorize_other_group_egress_and_revoke():
"Ipv6Ranges": [],
"PrefixListIds": [],
}
org_ip_permission = ip_permission.copy()
ip_permission["UserIdGroupPairs"][0].pop("GroupName")
sg01.authorize_egress(IpPermissions=[ip_permission])
sg01.authorize_egress(IpPermissions=[org_ip_permission])
sg01.ip_permissions_egress.should.have.length_of(2)
sg01.ip_permissions_egress.should.contain(ip_permission)
sg01.revoke_egress(IpPermissions=[ip_permission])
sg01.revoke_egress(IpPermissions=[org_ip_permission])
sg01.ip_permissions_egress.should.have.length_of(1)
@ -467,7 +469,7 @@ def test_get_all_security_groups():
resp[0].id.should.equal(sg1.id)
resp = conn.get_all_security_groups()
resp.should.have.length_of(4)
resp.should.have.length_of(3)
@mock_ec2_deprecated
@ -557,7 +559,7 @@ def test_sec_group_rule_limit():
success = ec2_conn.authorize_security_group(
group_id=sg.id,
ip_protocol="-1",
cidr_ip=["{0}.0.0.0/0".format(i) for i in range(99)],
cidr_ip=["{0}.0.0.0/0".format(i) for i in range(1, 60)],
)
success.should.be.true
# verify that we cannot authorize past the limit for a CIDR IP
@ -578,8 +580,8 @@ def test_sec_group_rule_limit():
ec2_conn.authorize_security_group_egress(
group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id
)
# fill the rules up the limit
for i in range(1, 100):
# fill the rules up the limit, 59 + 1 default rule
for i in range(1, 59):
ec2_conn.authorize_security_group_egress(
group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i)
)
@ -626,7 +628,7 @@ def test_sec_group_rule_limit_vpc():
success = ec2_conn.authorize_security_group(
group_id=sg.id,
ip_protocol="-1",
cidr_ip=["{0}.0.0.0/0".format(i) for i in range(49)],
cidr_ip=["{0}.0.0.0/0".format(i) for i in range(59)],
)
# verify that we cannot authorize past the limit for a CIDR IP
success.should.be.true
@ -650,7 +652,7 @@ def test_sec_group_rule_limit_vpc():
# fill the rules up the limit
# remember that by default, when created a sec group contains 1 egress rule
# so our other_sg rule + 48 CIDR IP rules + 1 by default == 50 the limit
for i in range(1, 49):
for i in range(1, 59):
ec2_conn.authorize_security_group_egress(
group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i)
)
@ -857,9 +859,7 @@ def test_authorize_and_revoke_in_bulk():
"IpProtocol": "tcp",
"FromPort": 27017,
"ToPort": 27017,
"UserIdGroupPairs": [
{"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id}
],
"UserIdGroupPairs": [{"GroupId": sg02.id, "UserId": sg02.owner_id}],
"IpRanges": [],
"Ipv6Ranges": [],
"PrefixListIds": [],
@ -877,7 +877,7 @@ def test_authorize_and_revoke_in_bulk():
"IpProtocol": "tcp",
"FromPort": 27017,
"ToPort": 27017,
"UserIdGroupPairs": [{"GroupName": "sg03", "UserId": sg03.owner_id}],
"UserIdGroupPairs": [{"GroupId": sg03.id, "UserId": sg03.owner_id}],
"IpRanges": [],
"Ipv6Ranges": [],
"PrefixListIds": [],
@ -886,7 +886,7 @@ def test_authorize_and_revoke_in_bulk():
"IpProtocol": "tcp",
"FromPort": 27015,
"ToPort": 27015,
"UserIdGroupPairs": [{"GroupName": "sg04", "UserId": sg04.owner_id}],
"UserIdGroupPairs": [{"GroupId": sg04.id, "UserId": sg04.owner_id}],
"IpRanges": [
{"CidrIp": "10.10.10.0/24", "Description": "Some Description"}
],