EC2: Add tag support for security group rules (#7204)

This commit is contained in:
Viren Nadkarni 2024-01-13 18:24:51 +05:30 committed by GitHub
parent 99e01b5adc
commit 624de34d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 13 deletions

View File

@ -4,7 +4,7 @@ import json
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from moto.core import CloudFormationModel from moto.core import BaseModel, CloudFormationModel
from moto.core.utils import aws_api_matches from moto.core.utils import aws_api_matches
from ..exceptions import ( from ..exceptions import (
@ -28,10 +28,10 @@ from ..utils import (
from .core import TaggedEC2Resource from .core import TaggedEC2Resource
class SecurityRule: class SecurityRule(TaggedEC2Resource):
def __init__( def __init__(
self, self,
account_id: str, ec2_backend: Any,
ip_protocol: str, ip_protocol: str,
from_port: Optional[str], from_port: Optional[str],
to_port: Optional[str], to_port: Optional[str],
@ -40,7 +40,7 @@ class SecurityRule:
prefix_list_ids: Optional[List[Dict[str, str]]] = None, prefix_list_ids: Optional[List[Dict[str, str]]] = None,
is_egress: bool = True, is_egress: bool = True,
): ):
self.account_id = account_id self.ec2_backend = ec2_backend
self.id = random_security_group_rule_id() self.id = random_security_group_rule_id()
self.ip_protocol = str(ip_protocol) if ip_protocol else None self.ip_protocol = str(ip_protocol) if ip_protocol else None
self.ip_ranges = ip_ranges or [] self.ip_ranges = ip_ranges or []
@ -76,7 +76,7 @@ class SecurityRule:
@property @property
def owner_id(self) -> str: def owner_id(self) -> str:
return self.account_id return self.ec2_backend.account_id
def __eq__(self, other: "SecurityRule") -> bool: # type: ignore[override] def __eq__(self, other: "SecurityRule") -> bool: # type: ignore[override]
if self.ip_protocol != other.ip_protocol: if self.ip_protocol != other.ip_protocol:
@ -111,6 +111,18 @@ class SecurityRule:
return True return True
def __deepcopy__(self, memodict: Dict[Any, Any]) -> BaseModel:
memodict = memodict or {}
cls = self.__class__
new = cls.__new__(cls)
memodict[id(self)] = new
for k, v in self.__dict__.items():
if k == "ec2_backend":
setattr(new, k, self.ec2_backend)
else:
setattr(new, k, copy.deepcopy(v, memodict))
return new
class SecurityGroup(TaggedEC2Resource, CloudFormationModel): class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def __init__( def __init__(
@ -142,13 +154,23 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
if vpc: if vpc:
self.egress_rules.append( self.egress_rules.append(
SecurityRule( SecurityRule(
self.owner_id, "-1", None, None, [{"CidrIp": "0.0.0.0/0"}], [] self.ec2_backend,
"-1",
None,
None,
[{"CidrIp": "0.0.0.0/0"}],
[],
) )
) )
if vpc and len(vpc.get_cidr_block_association_set(ipv6=True)) > 0: if vpc and len(vpc.get_cidr_block_association_set(ipv6=True)) > 0:
self.egress_rules.append( self.egress_rules.append(
SecurityRule( SecurityRule(
self.owner_id, "-1", None, None, [{"CidrIpv6": "::/0"}], [] self.ec2_backend,
"-1",
None,
None,
[{"CidrIpv6": "::/0"}],
[],
) )
) )
@ -671,7 +693,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id) _source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,
@ -741,7 +763,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id) _source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,
@ -836,7 +858,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id) _source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,
@ -920,7 +942,7 @@ class SecurityGroupBackend:
ip_ranges.remove(item) ip_ranges.remove(item)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,
@ -1006,7 +1028,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id) _source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,
@ -1061,7 +1083,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id) _source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule( security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined] self,
ip_protocol, ip_protocol,
from_port, from_port,
to_port, to_port,

View File

@ -272,6 +272,14 @@ DESCRIBE_SECURITY_GROUP_RULES_RESPONSE = """
<groupOwnerId>{{ rule.owner_id }}</groupOwnerId> <groupOwnerId>{{ rule.owner_id }}</groupOwnerId>
<isEgress>{{ 'true' if rule.is_egress else 'false' }}</isEgress> <isEgress>{{ 'true' if rule.is_egress else 'false' }}</isEgress>
<securityGroupRuleId>{{ rule.id }}</securityGroupRuleId> <securityGroupRuleId>{{ rule.id }}</securityGroupRuleId>
<tagSet>
{% for tag in rule.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</item> </item>
{% endfor %} {% endfor %}
{% endfor %} {% endfor %}

View File

@ -569,6 +569,35 @@ def test_authorize_all_protocols_with_no_port_specification():
assert "ToPort" not in permission assert "ToPort" not in permission
@mock_ec2
def test_security_group_rule_tagging():
ec2 = boto3.resource("ec2", "us-east-1")
client = boto3.client("ec2", "us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
sg_name = str(uuid4())
sg = client.create_security_group(
Description="Test SG", GroupName=sg_name, VpcId=vpc.id
)
response = client.describe_security_group_rules(
Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}]
)
rule_id = response["SecurityGroupRules"][0]["SecurityGroupRuleId"]
tag_name = str(uuid4())[0:6]
tag_val = str(uuid4())
client.create_tags(Resources=[rule_id], Tags=[{"Key": tag_name, "Value": tag_val}])
response = client.describe_security_group_rules(
Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}]
)
assert "Tags" in response["SecurityGroupRules"][0]
assert response["SecurityGroupRules"][0]["Tags"][0]["Key"] == tag_name
assert response["SecurityGroupRules"][0]["Tags"][0]["Value"] == tag_val
@mock_ec2 @mock_ec2
def test_create_and_describe_security_grp_rule(): def test_create_and_describe_security_grp_rule():
ec2 = boto3.resource("ec2", "us-east-1") ec2 = boto3.resource("ec2", "us-east-1")