From eef21767f89b5ea550fd6af741967e331c3ac258 Mon Sep 17 00:00:00 2001 From: Mohit Alonja Date: Thu, 9 Sep 2021 18:09:48 +0530 Subject: [PATCH] Fix SecuirtyGroupRules and added default Ec2-VPC support for SG. (#4267) --- moto/ec2/models.py | 183 ++++++++++++++++++++++--- moto/ec2/responses/security_groups.py | 26 ++++ moto/ec2/utils.py | 2 +- tests/terraform-tests.failures.txt | 1 + tests/terraform-tests.success.txt | 3 +- tests/test_ec2/test_security_groups.py | 38 ++--- 6 files changed, 211 insertions(+), 42 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 9b382a1aa..33eaef774 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -963,7 +963,7 @@ class InstanceBackend(object): new_reservation.id = random_reservation_id() security_groups = [ - self.get_security_group_from_name(name) for name in security_group_names + self.get_security_group_by_name_or_id(name) for name in security_group_names ] security_groups.extend( self.get_security_group_from_id(sg_id) @@ -2102,7 +2102,14 @@ class SecurityRule(object): class SecurityGroup(TaggedEC2Resource, CloudFormationModel): def __init__( - self, ec2_backend, group_id, name, description, vpc_id=None, tags=None + self, + ec2_backend, + group_id, + name, + description, + vpc_id=None, + tags=None, + is_default=None, ): self.ec2_backend = ec2_backend self.id = group_id @@ -2114,6 +2121,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): self.vpc_id = vpc_id self.owner_id = ACCOUNT_ID self.add_tags(tags or {}) + self.is_default = is_default or False # Append default IPv6 egress rule for VPCs with IPv6 support if vpc_id: @@ -2198,7 +2206,9 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): @classmethod def _delete_security_group_given_vpc_id(cls, resource_name, vpc_id, region_name): ec2_backend = ec2_backends[region_name] - security_group = ec2_backend.get_security_group_from_name(resource_name, vpc_id) + security_group = ec2_backend.get_security_group_by_name_or_id( + resource_name, vpc_id + ) if security_group: security_group.delete(region_name) @@ -2285,24 +2295,28 @@ class SecurityGroupBackend(object): self.sg_old_ingress_ruls = {} self.sg_old_egress_ruls = {} - # Create the default security group - self.create_security_group("default", "default group") - super(SecurityGroupBackend, self).__init__() def create_security_group( - self, name, description, vpc_id=None, tags=None, force=False + self, name, description, vpc_id=None, tags=None, force=False, is_default=None ): + vpc_id = vpc_id or self.default_vpc.id if not description: raise MissingParameterError("GroupDescription") group_id = random_security_group_id() if not force: - existing_group = self.get_security_group_from_name(name, vpc_id) + existing_group = self.get_security_group_by_name_or_id(name, vpc_id) if existing_group: raise InvalidSecurityGroupDuplicateError(name) group = SecurityGroup( - self, group_id, name, description, vpc_id=vpc_id, tags=tags + self, + group_id, + name, + description, + vpc_id=vpc_id, + tags=tags, + is_default=is_default, ) self.groups[vpc_id][group_id] = group @@ -2326,6 +2340,7 @@ class SecurityGroupBackend(object): return matches def _delete_security_group(self, vpc_id, group_id): + vpc_id = vpc_id or self.default_vpc.id if self.groups[vpc_id][group_id].enis: raise DependencyViolationError( "{0} is being utilized by {1}".format(group_id, "ENIs") @@ -2342,7 +2357,7 @@ class SecurityGroupBackend(object): elif name: # Group Name. Has to be in standard EC2, VPC needs to be # identified by group_id - group = self.get_security_group_from_name(name) + group = self.get_security_group_by_name_or_id(name) if group: return self._delete_security_group(None, group.id) raise InvalidSecurityGroupNotFoundError(name) @@ -2367,7 +2382,8 @@ class SecurityGroupBackend(object): if group.name == name: return group - def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id): + def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id=None): + # try searching by id, fallbacks to name search group = self.get_security_group_from_id(group_name_or_id) if group is None: @@ -2675,6 +2691,128 @@ class SecurityGroupBackend(object): return security_rule raise InvalidPermissionNotFoundError() + def update_security_group_rule_descriptions_ingress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups=[], + prefix_list_ids=[], + vpc_id=None, + ): + + group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) + if group is None: + raise InvalidSecurityGroupNotFoundError(group_name_or_id) + if ip_ranges and not isinstance(ip_ranges, list): + + if isinstance(ip_ranges, str) and "CidrIp" not in ip_ranges: + ip_ranges = [{"CidrIp": ip_ranges}] + else: + ip_ranges = [json.loads(ip_ranges)] + if ip_ranges: + for cidr in ip_ranges: + if ( + type(cidr) is dict + and not any( + [ + is_valid_cidr(cidr.get("CidrIp", "")), + is_valid_ipv6_cidr(cidr.get("CidrIpv6", "")), + ] + ) + ) or ( + type(cidr) is str + and not any([is_valid_cidr(cidr), is_valid_ipv6_cidr(cidr)]) + ): + raise InvalidCIDRSubnetError(cidr=cidr) + _source_groups = self._add_source_group(source_groups, vpc_id) + + security_rule = SecurityRule( + ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + ) + for rule in group.ingress_rules: + if ( + security_rule.from_port == rule.from_port + and security_rule.to_port == rule.to_port + and security_rule.ip_protocol == rule.ip_protocol + ): + self._sg_update_description(security_rule, rule) + return group + + def update_security_group_rule_descriptions_egress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups=[], + prefix_list_ids=[], + vpc_id=None, + ): + + group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) + if group is None: + raise InvalidSecurityGroupNotFoundError(group_name_or_id) + if ip_ranges and not isinstance(ip_ranges, list): + + if isinstance(ip_ranges, str) and "CidrIp" not in ip_ranges: + ip_ranges = [{"CidrIp": ip_ranges}] + else: + ip_ranges = [json.loads(ip_ranges)] + if ip_ranges: + for cidr in ip_ranges: + if ( + type(cidr) is dict + and not any( + [ + is_valid_cidr(cidr.get("CidrIp", "")), + is_valid_ipv6_cidr(cidr.get("CidrIpv6", "")), + ] + ) + ) or ( + type(cidr) is str + and not any([is_valid_cidr(cidr), is_valid_ipv6_cidr(cidr)]) + ): + raise InvalidCIDRSubnetError(cidr=cidr) + _source_groups = self._add_source_group(source_groups, vpc_id) + + security_rule = SecurityRule( + ip_protocol, from_port, to_port, ip_ranges, _source_groups, prefix_list_ids + ) + for rule in group.egress_rules: + if ( + security_rule.from_port == rule.from_port + and security_rule.to_port == rule.to_port + and security_rule.ip_protocol == rule.ip_protocol + ): + self._sg_update_description(security_rule, rule) + return group + + def _sg_update_description(self, security_rule, rule): + for item in security_rule.ip_ranges: + for cidr_item in rule.ip_ranges: + if cidr_item.get("CidrIp") == item.get("CidrIp"): + cidr_item["Description"] = item.get("Description") + if cidr_item.get("CidrIp6") == item.get("CidrIp6"): + cidr_item["Description"] = item.get("Description") + + for item in security_rule.source_groups: + for source_group in rule.source_groups: + if source_group.get("GroupId") == item.get( + "GroupId" + ) or source_group.get("GroupName") == item.get("GroupName"): + source_group["Description"] = item.get("Description") + + for item in security_rule.source_groups: + for source_group in rule.source_groups: + if source_group.get("GroupId") == item.get( + "GroupId" + ) or source_group.get("GroupName") == item.get("GroupName"): + source_group["Description"] = item.get("Description") + def _remove_items_from_rule(self, ip_ranges, _source_groups, prefix_list_ids, rule): for item in ip_ranges: if item not in rule.ip_ranges: @@ -2702,26 +2840,27 @@ class SecurityGroupBackend(object): item["OwnerId"] = ACCOUNT_ID # for VPCs if "GroupId" in item: - if not self.get_security_group_from_id(item.get("GroupId")): + if not self.get_security_group_by_name_or_id( + item.get("GroupId"), vpc_id + ): raise InvalidSecurityGroupNotFoundError(item.get("GroupId")) if "GroupName" in item: - source_group = self.get_security_group_from_name( + source_group = self.get_security_group_by_name_or_id( item.get("GroupName"), vpc_id ) if not source_group: raise InvalidSecurityGroupNotFoundError(item.get("GroupName")) else: item["GroupId"] = source_group.id + item.pop("GroupName") - if vpc_id: - item["VpcId"] = vpc_id _source_groups.append(item) return _source_groups def _verify_group_will_respect_rule_count_limit( self, group, current_rule_nb, ip_ranges, source_groups=None, egress=False, ): - max_nb_rules = 50 if group.vpc_id else 100 + max_nb_rules = 60 if group.vpc_id else 100 future_group_nb_rules = current_rule_nb if ip_ranges: future_group_nb_rules += len(ip_ranges) @@ -3408,7 +3547,7 @@ class VPCBackend(object): default = self.get_security_group_from_name("default", vpc_id=vpc_id) if not default: self.create_security_group( - "default", "default VPC security group", vpc_id=vpc_id + "default", "default VPC security group", vpc_id=vpc_id, is_default=True ) return vpc @@ -3447,7 +3586,7 @@ class VPCBackend(object): self.delete_route_table(route_table.id) # Delete default security group if exists. - default = self.get_security_group_from_name("default", vpc_id=vpc_id) + default = self.get_security_group_by_name_or_id("default", vpc_id=vpc_id) if default: self.delete_security_group(group_id=default.id) @@ -5285,12 +5424,12 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource): if security_groups: for group_name in security_groups: - group = self.ec2_backend.get_security_group_from_name(group_name) + group = self.ec2_backend.get_security_group_by_name_or_id(group_name) if group: ls.groups.append(group) else: # If not security groups, add the default - default_group = self.ec2_backend.get_security_group_from_name("default") + default_group = self.ec2_backend.get_security_group_by_name_or_id("default") ls.groups.append(default_group) self.instance = self.launch_instance() @@ -7717,8 +7856,8 @@ class EC2Backend( TagBackend, EBSBackend, RegionsAndZonesBackend, - SecurityGroupBackend, AmiBackend, + SecurityGroupBackend, VPCBackend, ManagedPrefixListBackend, SubnetBackend, @@ -7765,6 +7904,8 @@ class EC2Backend( # backward-compatibility issues vpc = self.vpcs.values()[0] + self.default_vpc = vpc + # Create default subnet for each availability zone ip, _ = vpc.cidr_block.split("/") ip = ip.split(".") diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 089ea0ff9..aafae8c52 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -218,6 +218,22 @@ class SecurityGroups(BaseResponse): self.ec2_backend.revoke_security_group_ingress(*args) return REVOKE_SECURITY_GROUP_INGRESS_RESPONSE + def update_security_group_rule_descriptions_ingress(self): + for args in self._process_rules_from_querystring(): + group = self.ec2_backend.update_security_group_rule_descriptions_ingress( + *args + ) + self.ec2_backend.sg_old_ingress_ruls[group.id] = group.ingress_rules.copy() + return UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_INGRESS + + def update_security_group_rule_descriptions_egress(self): + for args in self._process_rules_from_querystring(): + group = self.ec2_backend.update_security_group_rule_descriptions_egress( + *args + ) + self.ec2_backend.sg_old_egress_ruls[group.id] = group.egress_rules.copy() + return UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_EGRESS + CREATE_SECURITY_GROUP_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE @@ -552,3 +568,13 @@ REVOKE_SECURITY_GROUP_EGRESS_RESPONSE = """59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ + +UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_INGRESS = """ + 59dbff89-35bd-4eac-99ed-be587EXAMPLE + true +""" + +UPDATE_SECURITY_GROUP_RULE_DESCRIPTIONS_EGRESS = """ + 59dbff89-35bd-4eac-99ed-be587EXAMPLE + true +""" diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index 72b2fce35..7043acd1a 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -81,7 +81,7 @@ def random_reservation_id(): def random_security_group_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"]) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"], size=17) def random_security_group_rule_id(): diff --git a/tests/terraform-tests.failures.txt b/tests/terraform-tests.failures.txt index 1a6cc5f66..b3d78799f 100644 --- a/tests/terraform-tests.failures.txt +++ b/tests/terraform-tests.failures.txt @@ -6,3 +6,4 @@ TestAccAWSFms TestAccAWSIAMRolePolicy TestAccAWSSecurityGroup_forceRevokeRules_ TestAccAWSSSMDocument_package +TestAccAWSDefaultSecurityGroup_Classic_ diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index 0e71a82fc..9d90ca3ed 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -23,7 +23,7 @@ TestAccAWSDataSourceIAMPolicyDocument TestAccAWSDataSourceIAMRole TestAccAWSDataSourceIAMSessionContext TestAccAWSDataSourceIAMUser -TestAccAWSDefaultSecurityGroup +TestAccAWSDefaultSecurityGroup_ TestAccAWSDefaultSubnet TestAccAWSDefaultTagsDataSource TestAccAWSDynamoDbTableItem @@ -116,3 +116,4 @@ TestAccAwsEc2ManagedPrefixList TestAccAWSEgressOnlyInternetGateway TestAccAWSSecurityGroup_ TestAccAWSInternetGateway +TestAccAWSSecurityGroupRule_ diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index f39148e50..2542051c9 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -48,7 +48,7 @@ def test_create_and_describe_security_group(): all_groups = conn.get_all_security_groups() # The default group gets created automatically - all_groups.should.have.length_of(3) + all_groups.should.have.length_of(2) group_names = [group.name for group in all_groups] set(group_names).should.equal(set(["default", "test security group"])) @@ -68,7 +68,7 @@ def test_create_security_group_without_description_raises_error(): def test_default_security_group(): conn = boto.ec2.connect_to_region("us-east-1") groups = conn.get_all_security_groups() - groups.should.have.length_of(2) + groups.should.have.length_of(1) groups[0].name.should.equal("default") @@ -118,7 +118,7 @@ def test_create_two_security_groups_with_same_name_in_different_vpc(): all_groups = conn.get_all_security_groups() - all_groups.should.have.length_of(4) + all_groups.should.have.length_of(3) group_names = [group.name for group in all_groups] # The default group is created automatically set(group_names).should.equal(set(["default", "test security group"])) @@ -143,7 +143,7 @@ def test_deleting_security_groups(): security_group1 = conn.create_security_group("test1", "test1") conn.create_security_group("test2", "test2") - conn.get_all_security_groups().should.have.length_of(4) + conn.get_all_security_groups().should.have.length_of(3) # Deleting a group that doesn't exist should throw an error with pytest.raises(EC2ResponseError) as cm: @@ -162,11 +162,11 @@ def test_deleting_security_groups(): ) conn.delete_security_group("test2") - conn.get_all_security_groups().should.have.length_of(3) + conn.get_all_security_groups().should.have.length_of(2) # Delete by group id conn.delete_security_group(group_id=security_group1.id) - conn.get_all_security_groups().should.have.length_of(2) + conn.get_all_security_groups().should.have.length_of(1) @mock_ec2_deprecated @@ -384,12 +384,14 @@ def test_authorize_other_group_egress_and_revoke(): "Ipv6Ranges": [], "PrefixListIds": [], } + org_ip_permission = ip_permission.copy() + ip_permission["UserIdGroupPairs"][0].pop("GroupName") - sg01.authorize_egress(IpPermissions=[ip_permission]) + sg01.authorize_egress(IpPermissions=[org_ip_permission]) sg01.ip_permissions_egress.should.have.length_of(2) sg01.ip_permissions_egress.should.contain(ip_permission) - sg01.revoke_egress(IpPermissions=[ip_permission]) + sg01.revoke_egress(IpPermissions=[org_ip_permission]) sg01.ip_permissions_egress.should.have.length_of(1) @@ -467,7 +469,7 @@ def test_get_all_security_groups(): resp[0].id.should.equal(sg1.id) resp = conn.get_all_security_groups() - resp.should.have.length_of(4) + resp.should.have.length_of(3) @mock_ec2_deprecated @@ -557,7 +559,7 @@ def test_sec_group_rule_limit(): success = ec2_conn.authorize_security_group( group_id=sg.id, ip_protocol="-1", - cidr_ip=["{0}.0.0.0/0".format(i) for i in range(99)], + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(1, 60)], ) success.should.be.true # verify that we cannot authorize past the limit for a CIDR IP @@ -578,8 +580,8 @@ def test_sec_group_rule_limit(): ec2_conn.authorize_security_group_egress( group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id ) - # fill the rules up the limit - for i in range(1, 100): + # fill the rules up the limit, 59 + 1 default rule + for i in range(1, 59): ec2_conn.authorize_security_group_egress( group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) ) @@ -626,7 +628,7 @@ def test_sec_group_rule_limit_vpc(): success = ec2_conn.authorize_security_group( group_id=sg.id, ip_protocol="-1", - cidr_ip=["{0}.0.0.0/0".format(i) for i in range(49)], + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(59)], ) # verify that we cannot authorize past the limit for a CIDR IP success.should.be.true @@ -650,7 +652,7 @@ def test_sec_group_rule_limit_vpc(): # fill the rules up the limit # remember that by default, when created a sec group contains 1 egress rule # so our other_sg rule + 48 CIDR IP rules + 1 by default == 50 the limit - for i in range(1, 49): + for i in range(1, 59): ec2_conn.authorize_security_group_egress( group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) ) @@ -857,9 +859,7 @@ def test_authorize_and_revoke_in_bulk(): "IpProtocol": "tcp", "FromPort": 27017, "ToPort": 27017, - "UserIdGroupPairs": [ - {"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id} - ], + "UserIdGroupPairs": [{"GroupId": sg02.id, "UserId": sg02.owner_id}], "IpRanges": [], "Ipv6Ranges": [], "PrefixListIds": [], @@ -877,7 +877,7 @@ def test_authorize_and_revoke_in_bulk(): "IpProtocol": "tcp", "FromPort": 27017, "ToPort": 27017, - "UserIdGroupPairs": [{"GroupName": "sg03", "UserId": sg03.owner_id}], + "UserIdGroupPairs": [{"GroupId": sg03.id, "UserId": sg03.owner_id}], "IpRanges": [], "Ipv6Ranges": [], "PrefixListIds": [], @@ -886,7 +886,7 @@ def test_authorize_and_revoke_in_bulk(): "IpProtocol": "tcp", "FromPort": 27015, "ToPort": 27015, - "UserIdGroupPairs": [{"GroupName": "sg04", "UserId": sg04.owner_id}], + "UserIdGroupPairs": [{"GroupId": sg04.id, "UserId": sg04.owner_id}], "IpRanges": [ {"CidrIp": "10.10.10.0/24", "Description": "Some Description"} ],