diff --git a/moto/ec2/models/security_groups.py b/moto/ec2/models/security_groups.py index 5cd32ee58..1f19ce8a7 100644 --- a/moto/ec2/models/security_groups.py +++ b/moto/ec2/models/security_groups.py @@ -36,16 +36,18 @@ class SecurityRule: ip_ranges: Optional[List[Any]], source_groups: List[Dict[str, Any]], prefix_list_ids: Optional[List[Dict[str, str]]] = None, + is_egress: bool = True, ): self.account_id = account_id self.id = random_security_group_rule_id() - self.ip_protocol = str(ip_protocol) + self.ip_protocol = str(ip_protocol) if ip_protocol else None self.ip_ranges = ip_ranges or [] self.source_groups = source_groups or [] self.prefix_list_ids = prefix_list_ids or [] self.from_port = self.to_port = None + self.is_egress = is_egress - if self.ip_protocol != "-1": + if self.ip_protocol and self.ip_protocol != "-1": self.from_port = int(from_port) # type: ignore[arg-type] self.to_port = int(to_port) # type: ignore[arg-type] @@ -65,7 +67,11 @@ class SecurityRule: "1": "icmp", "icmp": "icmp", } - proto = ip_protocol_keywords.get(self.ip_protocol.lower()) + proto = ( + ip_protocol_keywords.get(self.ip_protocol.lower()) + if self.ip_protocol + else None + ) self.ip_protocol = proto if proto else self.ip_protocol @property @@ -629,6 +635,7 @@ class SecurityGroupBackend: ip_ranges: List[Any], source_groups: Optional[List[Dict[str, str]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, # pylint:disable=unused-argument vpc_id: Optional[str] = None, ) -> Tuple[SecurityRule, SecurityGroup]: group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) @@ -669,6 +676,7 @@ class SecurityGroupBackend: ip_ranges, _source_groups, prefix_list_ids, + is_egress=False, ) if security_rule in group.ingress_rules: @@ -717,11 +725,18 @@ class SecurityGroupBackend: ip_ranges: List[Any], source_groups: Optional[List[Dict[str, Any]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, vpc_id: Optional[str] = None, - ) -> SecurityRule: + ) -> None: group: SecurityGroup = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) # type: ignore[assignment] + if security_rule_ids: + group.ingress_rules = [ + rule for rule in group.egress_rules if rule.id not in security_rule_ids + ] + return + _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( @@ -732,6 +747,7 @@ class SecurityGroupBackend: ip_ranges, _source_groups, prefix_list_ids, + is_egress=False, ) # To match drift property of the security rules. @@ -770,7 +786,7 @@ class SecurityGroupBackend: ): group.ingress_rules.remove(rule) self.sg_old_ingress_ruls[group.id] = group.ingress_rules.copy() - return security_rule + return raise InvalidPermissionNotFoundError() def authorize_security_group_egress( @@ -782,6 +798,7 @@ class SecurityGroupBackend: ip_ranges: List[Any], source_groups: Optional[List[Dict[str, Any]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, # pylint:disable=unused-argument vpc_id: Optional[str] = None, ) -> Tuple[SecurityRule, SecurityGroup]: group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) @@ -875,11 +892,18 @@ class SecurityGroupBackend: ip_ranges: List[Any], source_groups: Optional[List[Dict[str, Any]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, vpc_id: Optional[str] = None, - ) -> SecurityRule: + ) -> None: group: SecurityGroup = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) # type: ignore[assignment] + if security_rule_ids: + group.egress_rules = [ + rule for rule in group.egress_rules if rule.id not in security_rule_ids + ] + return + _source_groups = self._add_source_group(source_groups, vpc_id) # I don't believe this is required after changing the default egress rule @@ -942,7 +966,7 @@ class SecurityGroupBackend: ): group.egress_rules.remove(rule) self.sg_old_egress_ruls[group.id] = group.egress_rules.copy() - return security_rule + return raise InvalidPermissionNotFoundError() def update_security_group_rule_descriptions_ingress( @@ -954,6 +978,7 @@ class SecurityGroupBackend: ip_ranges: List[str], source_groups: Optional[List[Dict[str, Any]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, # pylint:disable=unused-argument vpc_id: Optional[str] = None, ) -> SecurityGroup: @@ -1010,6 +1035,7 @@ class SecurityGroupBackend: ip_ranges: List[str], source_groups: Optional[List[Dict[str, Any]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None, + security_rule_ids: Optional[List[str]] = None, # pylint:disable=unused-argument vpc_id: Optional[str] = None, ) -> SecurityGroup: diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 5d1df1475..2742f6518 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -71,6 +71,7 @@ def parse_sg_attributes_from_dict(sg_attributes: Dict[str, Any]) -> Tuple[Any, . class SecurityGroups(EC2BaseResponse): def _process_rules_from_querystring(self) -> Any: group_name_or_id = self._get_param("GroupName") or self._get_param("GroupId") + security_rule_ids = self._get_multi_param("SecurityGroupRuleId") querytree: Dict[str, Any] = {} for key, value in self.querystring.items(): @@ -103,6 +104,7 @@ class SecurityGroups(EC2BaseResponse): ip_ranges, source_groups, prefix_list_ids, + security_rule_ids, ) ip_permissions = querytree.get("IpPermissions") or {} @@ -126,6 +128,7 @@ class SecurityGroups(EC2BaseResponse): ip_ranges, source_groups, prefix_list_ids, + security_rule_ids, ) def authorize_security_group_egress(self) -> str: @@ -264,7 +267,7 @@ DESCRIBE_SECURITY_GROUP_RULES_RESPONSE = """ {% endif %} {{ rule.ip_protocol }} {{ rule.owner_id }} - true + {{ 'true' if rule.is_egress else 'false' }} {{ rule.id }} {% endfor %} diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 8a00d637d..c09a680de 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -1034,16 +1034,23 @@ def test_authorize_and_revoke_in_bulk(): @mock_ec2 def test_security_group_ingress_without_multirule(): ec2 = boto3.resource("ec2", "ca-central-1") + client = boto3.client("ec2", "ca-central-1") sg = ec2.create_security_group(Description="Test SG", GroupName=str(uuid4())) assert len(sg.ip_permissions) == 0 - sg.authorize_ingress( + rule_id = sg.authorize_ingress( CidrIp="192.168.0.1/32", FromPort=22, ToPort=22, IpProtocol="tcp" - ) + )["SecurityGroupRules"][0]["SecurityGroupRuleId"] - # Fails assert len(sg.ip_permissions) == 1 + rules = client.describe_security_group_rules(SecurityGroupRuleIds=[rule_id])[ + "SecurityGroupRules" + ] + ingress = [rule for rule in rules if rule["SecurityGroupRuleId"] == rule_id] + assert len(ingress) == 1 + assert ingress[0]["IsEgress"] is False + @mock_ec2 def test_security_group_ingress_without_multirule_after_reload(): @@ -1123,6 +1130,32 @@ def test_revoke_security_group_egress(): sg.ip_permissions_egress.should.have.length_of(0) +@mock_ec2 +def test_revoke_security_group_egress__without_ipprotocol(): + ec2 = boto3.resource("ec2", "eu-west-2") + client = boto3.client("ec2", region_name="eu-west-2") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + sec_group = client.describe_security_groups( + Filters=[ + {"Name": "group-name", "Values": ["default"]}, + {"Name": "vpc-id", "Values": [vpc.id]}, + ] + )["SecurityGroups"][0]["GroupId"] + + rule_id = client.describe_security_group_rules( + Filters=[{"Name": "group-id", "Values": [sec_group]}] + )["SecurityGroupRules"][0]["SecurityGroupRuleId"] + + client.revoke_security_group_egress( + GroupId=sec_group, SecurityGroupRuleIds=[rule_id] + ) + + remaining_rules = client.describe_security_group_rules( + Filters=[{"Name": "group-id", "Values": [sec_group]}] + )["SecurityGroupRules"] + assert len(remaining_rules) == 0 + + @mock_ec2 def test_update_security_group_rule_descriptions_egress(): ec2 = boto3.resource("ec2", "us-east-1") @@ -1237,6 +1270,7 @@ def test_security_group_rules_added_via_the_backend_can_be_revoked_via_the_api() # Add an ingress/egress rule using the EC2 backend directly. rule_ingress = { "group_name_or_id": sg.id, + "security_rule_ids": None, "from_port": 0, "ip_protocol": "udp", "ip_ranges": [], @@ -1246,6 +1280,7 @@ def test_security_group_rules_added_via_the_backend_can_be_revoked_via_the_api() ec2_backend.authorize_security_group_ingress(**rule_ingress) rule_egress = { "group_name_or_id": sg.id, + "security_rule_ids": None, "from_port": 8443, "ip_protocol": "tcp", "ip_ranges": [],