From 624de34d82a1b2c521727b14a2173380e196f1d8 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Sat, 13 Jan 2024 18:24:51 +0530 Subject: [PATCH] EC2: Add tag support for security group rules (#7204) --- moto/ec2/models/security_groups.py | 48 +++++++++++++++++++------- moto/ec2/responses/security_groups.py | 8 +++++ tests/test_ec2/test_security_groups.py | 29 ++++++++++++++++ 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/moto/ec2/models/security_groups.py b/moto/ec2/models/security_groups.py index 36dc25564..13222eb03 100644 --- a/moto/ec2/models/security_groups.py +++ b/moto/ec2/models/security_groups.py @@ -4,7 +4,7 @@ import json from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple -from moto.core import CloudFormationModel +from moto.core import BaseModel, CloudFormationModel from moto.core.utils import aws_api_matches from ..exceptions import ( @@ -28,10 +28,10 @@ from ..utils import ( from .core import TaggedEC2Resource -class SecurityRule: +class SecurityRule(TaggedEC2Resource): def __init__( self, - account_id: str, + ec2_backend: Any, ip_protocol: str, from_port: Optional[str], to_port: Optional[str], @@ -40,7 +40,7 @@ class SecurityRule: prefix_list_ids: Optional[List[Dict[str, str]]] = None, is_egress: bool = True, ): - self.account_id = account_id + self.ec2_backend = ec2_backend self.id = random_security_group_rule_id() self.ip_protocol = str(ip_protocol) if ip_protocol else None self.ip_ranges = ip_ranges or [] @@ -76,7 +76,7 @@ class SecurityRule: @property def owner_id(self) -> str: - return self.account_id + return self.ec2_backend.account_id def __eq__(self, other: "SecurityRule") -> bool: # type: ignore[override] if self.ip_protocol != other.ip_protocol: @@ -111,6 +111,18 @@ class SecurityRule: return True + def __deepcopy__(self, memodict: Dict[Any, Any]) -> BaseModel: + memodict = memodict or {} + cls = self.__class__ + new = cls.__new__(cls) + memodict[id(self)] = new + for k, v in self.__dict__.items(): + if k == "ec2_backend": + setattr(new, k, self.ec2_backend) + else: + setattr(new, k, copy.deepcopy(v, memodict)) + return new + class SecurityGroup(TaggedEC2Resource, CloudFormationModel): def __init__( @@ -142,13 +154,23 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel): if vpc: self.egress_rules.append( SecurityRule( - self.owner_id, "-1", None, None, [{"CidrIp": "0.0.0.0/0"}], [] + self.ec2_backend, + "-1", + None, + None, + [{"CidrIp": "0.0.0.0/0"}], + [], ) ) if vpc and len(vpc.get_cidr_block_association_set(ipv6=True)) > 0: self.egress_rules.append( SecurityRule( - self.owner_id, "-1", None, None, [{"CidrIpv6": "::/0"}], [] + self.ec2_backend, + "-1", + None, + None, + [{"CidrIpv6": "::/0"}], + [], ) ) @@ -671,7 +693,7 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, @@ -741,7 +763,7 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, @@ -836,7 +858,7 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, @@ -920,7 +942,7 @@ class SecurityGroupBackend: ip_ranges.remove(item) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, @@ -1006,7 +1028,7 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, @@ -1061,7 +1083,7 @@ class SecurityGroupBackend: _source_groups = self._add_source_group(source_groups, vpc_id) security_rule = SecurityRule( - self.account_id, # type: ignore[attr-defined] + self, ip_protocol, from_port, to_port, diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index f49c74ab3..b65949291 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -272,6 +272,14 @@ DESCRIBE_SECURITY_GROUP_RULES_RESPONSE = """ {{ rule.owner_id }} {{ 'true' if rule.is_egress else 'false' }} {{ rule.id }} + + {% for tag in rule.get_tags() %} + + {{ tag.key }} + {{ tag.value }} + + {% endfor %} + {% endfor %} {% endfor %} diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 5ead9689b..3587912c6 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -569,6 +569,35 @@ def test_authorize_all_protocols_with_no_port_specification(): assert "ToPort" not in permission +@mock_ec2 +def test_security_group_rule_tagging(): + ec2 = boto3.resource("ec2", "us-east-1") + client = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + sg_name = str(uuid4()) + sg = client.create_security_group( + Description="Test SG", GroupName=sg_name, VpcId=vpc.id + ) + + response = client.describe_security_group_rules( + Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}] + ) + rule_id = response["SecurityGroupRules"][0]["SecurityGroupRuleId"] + + tag_name = str(uuid4())[0:6] + tag_val = str(uuid4()) + + client.create_tags(Resources=[rule_id], Tags=[{"Key": tag_name, "Value": tag_val}]) + + response = client.describe_security_group_rules( + Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}] + ) + assert "Tags" in response["SecurityGroupRules"][0] + assert response["SecurityGroupRules"][0]["Tags"][0]["Key"] == tag_name + assert response["SecurityGroupRules"][0]["Tags"][0]["Value"] == tag_val + + @mock_ec2 def test_create_and_describe_security_grp_rule(): ec2 = boto3.resource("ec2", "us-east-1")