EC2: Add tag support for security group rules (#7204)

This commit is contained in:
Viren Nadkarni 2024-01-13 18:24:51 +05:30 committed by GitHub
parent 99e01b5adc
commit 624de34d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 13 deletions

View File

@ -4,7 +4,7 @@ import json
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from moto.core import CloudFormationModel
from moto.core import BaseModel, CloudFormationModel
from moto.core.utils import aws_api_matches
from ..exceptions import (
@ -28,10 +28,10 @@ from ..utils import (
from .core import TaggedEC2Resource
class SecurityRule:
class SecurityRule(TaggedEC2Resource):
def __init__(
self,
account_id: str,
ec2_backend: Any,
ip_protocol: str,
from_port: Optional[str],
to_port: Optional[str],
@ -40,7 +40,7 @@ class SecurityRule:
prefix_list_ids: Optional[List[Dict[str, str]]] = None,
is_egress: bool = True,
):
self.account_id = account_id
self.ec2_backend = ec2_backend
self.id = random_security_group_rule_id()
self.ip_protocol = str(ip_protocol) if ip_protocol else None
self.ip_ranges = ip_ranges or []
@ -76,7 +76,7 @@ class SecurityRule:
@property
def owner_id(self) -> str:
return self.account_id
return self.ec2_backend.account_id
def __eq__(self, other: "SecurityRule") -> bool: # type: ignore[override]
if self.ip_protocol != other.ip_protocol:
@ -111,6 +111,18 @@ class SecurityRule:
return True
def __deepcopy__(self, memodict: Dict[Any, Any]) -> BaseModel:
memodict = memodict or {}
cls = self.__class__
new = cls.__new__(cls)
memodict[id(self)] = new
for k, v in self.__dict__.items():
if k == "ec2_backend":
setattr(new, k, self.ec2_backend)
else:
setattr(new, k, copy.deepcopy(v, memodict))
return new
class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def __init__(
@ -142,13 +154,23 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
if vpc:
self.egress_rules.append(
SecurityRule(
self.owner_id, "-1", None, None, [{"CidrIp": "0.0.0.0/0"}], []
self.ec2_backend,
"-1",
None,
None,
[{"CidrIp": "0.0.0.0/0"}],
[],
)
)
if vpc and len(vpc.get_cidr_block_association_set(ipv6=True)) > 0:
self.egress_rules.append(
SecurityRule(
self.owner_id, "-1", None, None, [{"CidrIpv6": "::/0"}], []
self.ec2_backend,
"-1",
None,
None,
[{"CidrIpv6": "::/0"}],
[],
)
)
@ -671,7 +693,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,
@ -741,7 +763,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,
@ -836,7 +858,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,
@ -920,7 +942,7 @@ class SecurityGroupBackend:
ip_ranges.remove(item)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,
@ -1006,7 +1028,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,
@ -1061,7 +1083,7 @@ class SecurityGroupBackend:
_source_groups = self._add_source_group(source_groups, vpc_id)
security_rule = SecurityRule(
self.account_id, # type: ignore[attr-defined]
self,
ip_protocol,
from_port,
to_port,

View File

@ -272,6 +272,14 @@ DESCRIBE_SECURITY_GROUP_RULES_RESPONSE = """
<groupOwnerId>{{ rule.owner_id }}</groupOwnerId>
<isEgress>{{ 'true' if rule.is_egress else 'false' }}</isEgress>
<securityGroupRuleId>{{ rule.id }}</securityGroupRuleId>
<tagSet>
{% for tag in rule.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</item>
{% endfor %}
{% endfor %}

View File

@ -569,6 +569,35 @@ def test_authorize_all_protocols_with_no_port_specification():
assert "ToPort" not in permission
@mock_ec2
def test_security_group_rule_tagging():
ec2 = boto3.resource("ec2", "us-east-1")
client = boto3.client("ec2", "us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
sg_name = str(uuid4())
sg = client.create_security_group(
Description="Test SG", GroupName=sg_name, VpcId=vpc.id
)
response = client.describe_security_group_rules(
Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}]
)
rule_id = response["SecurityGroupRules"][0]["SecurityGroupRuleId"]
tag_name = str(uuid4())[0:6]
tag_val = str(uuid4())
client.create_tags(Resources=[rule_id], Tags=[{"Key": tag_name, "Value": tag_val}])
response = client.describe_security_group_rules(
Filters=[{"Name": "group-id", "Values": [sg["GroupId"]]}]
)
assert "Tags" in response["SecurityGroupRules"][0]
assert response["SecurityGroupRules"][0]["Tags"][0]["Key"] == tag_name
assert response["SecurityGroupRules"][0]["Tags"][0]["Value"] == tag_val
@mock_ec2
def test_create_and_describe_security_grp_rule():
ec2 = boto3.resource("ec2", "us-east-1")