Fix security group filters 2 (#4481)

This commit is contained in:
James Light 2021-11-16 07:24:14 -05:00 committed by GitHub
parent 52aeac1cee
commit f4abd5528f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 123 additions and 28 deletions

View File

@ -399,13 +399,22 @@ def merge_dicts(dict1, dict2, remove_nulls=False):
dict1.pop(key) dict1.pop(key)
def glob_matches(pattern, string): def aws_api_matches(pattern, string):
"""AWS API-style globbing regexes""" """
pattern, n = re.subn(r"[^\\]\*", r".*", pattern) AWS API can match a value based on a glob, or an exact match
pattern, m = re.subn(r"[^\\]\?", r".?", pattern) """
# use a negative lookback regex to match stars that are not prefixed with a backslash
# and replace all stars not prefixed w/ a backslash with '.*' to take this from "glob" to PCRE syntax
pattern, n = re.subn(r"(?<!\\)\*", r".*", pattern)
pattern = ".*" + pattern + ".*" # ? in the AWS glob form becomes .? in regex
# also, don't substitute it if it is prefixed w/ a backslash
pattern, m = re.subn(r"(?<!\\)\?", r".?", pattern)
if re.match(pattern, str(string)): # aws api seems to anchor
anchored_pattern = f"^{pattern}$"
if re.match(anchored_pattern, str(string)):
return True return True
return False else:
return False

View File

@ -31,7 +31,7 @@ from moto.core.models import Model, BaseModel, CloudFormationModel
from moto.core.utils import ( from moto.core.utils import (
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
camelcase_to_underscores, camelcase_to_underscores,
glob_matches, aws_api_matches,
) )
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.kms import kms_backends from moto.kms import kms_backends
@ -2439,7 +2439,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def filter_description(self, values): def filter_description(self, values):
for value in values: for value in values:
if glob_matches(value, self.description): if aws_api_matches(value, self.description):
return True return True
return False return False
@ -2447,14 +2447,16 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
for cidr in rule.ip_ranges: for cidr in rule.ip_ranges:
if glob_matches(value, cidr): if aws_api_matches(value, cidr.get("CidrIp", "NONE")):
return True return True
return False return False
def filter_egress__ip_permission__from_port(self, values): def filter_egress__ip_permission__from_port(self, values):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
if rule.ip_protocol != -1 and glob_matches(value, str(rule.from_port)): if rule.ip_protocol != -1 and aws_api_matches(
value, str(rule.from_port)
):
return True return True
return False return False
@ -2462,7 +2464,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
for sg in rule.source_groups: for sg in rule.source_groups:
if glob_matches(value, sg.get("GroupId", None)): if aws_api_matches(value, sg.get("GroupId", None)):
return True return True
return False return False
@ -2470,7 +2472,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
for group in rule.source_groups: for group in rule.source_groups:
if glob_matches(value, group.get("GroupName", None)): if aws_api_matches(value, group.get("GroupName", None)):
return True return True
return False return False
@ -2483,33 +2485,33 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def filter_egress__ip_permission__protocol(self, values): def filter_egress__ip_permission__protocol(self, values):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
if glob_matches(value, rule.ip_protocol): if aws_api_matches(value, rule.ip_protocol):
return True return True
return False return False
def filter_egress__ip_permission__to_port(self, values): def filter_egress__ip_permission__to_port(self, values):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
if glob_matches(value, rule.to_port): if aws_api_matches(value, rule.to_port):
return True return True
return False return False
def filter_egress__ip_permission__user_id(self, values): def filter_egress__ip_permission__user_id(self, values):
for value in values: for value in values:
for rule in self.egress_rules: for rule in self.egress_rules:
if glob_matches(value, rule.owner_id): if aws_api_matches(value, rule.owner_id):
return True return True
return False return False
def filter_group_id(self, values): def filter_group_id(self, values):
for value in values: for value in values:
if glob_matches(value, self.id): if aws_api_matches(value, self.id):
return True return True
return False return False
def filter_group_name(self, values): def filter_group_name(self, values):
for value in values: for value in values:
if glob_matches(value, self.group_name): if aws_api_matches(value, self.group_name):
return True return True
return False return False
@ -2517,14 +2519,14 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
for cidr in rule.ip_ranges: for cidr in rule.ip_ranges:
if glob_matches(value, cidr): if aws_api_matches(value, cidr.get("CidrIp", "NONE")):
return True return True
return False return False
def filter_ip_permission__from_port(self, values): def filter_ip_permission__from_port(self, values):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
if glob_matches(value, rule.from_port): if aws_api_matches(value, rule.from_port):
return True return True
return False return False
@ -2532,7 +2534,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
for group in rule.source_groups: for group in rule.source_groups:
if glob_matches(value, group.get("GroupId", None)): if aws_api_matches(value, group.get("GroupId", None)):
return True return True
return False return False
@ -2540,7 +2542,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
for group in rule.source_groups: for group in rule.source_groups:
if glob_matches(value, group.get("GroupName", None)): if aws_api_matches(value, group.get("GroupName", None)):
return True return True
return False return False
@ -2553,33 +2555,33 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def filter_ip_permission__protocol(self, values): def filter_ip_permission__protocol(self, values):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
if glob_matches(value, rule.protocol): if aws_api_matches(value, rule.protocol):
return True return True
return False return False
def filter_ip_permission__to_port(self, values): def filter_ip_permission__to_port(self, values):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
if glob_matches(rule.to_port): if aws_api_matches(value, rule.to_port):
return True return True
return False return False
def filter_ip_permission__user_id(self, values): def filter_ip_permission__user_id(self, values):
for value in values: for value in values:
for rule in self.ingress_rules: for rule in self.ingress_rules:
if glob_matches(value, rule.owner_id): if aws_api_matches(value, rule.owner_id):
return True return True
return False return False
def filter_owner_id(self, values): def filter_owner_id(self, values):
for value in values: for value in values:
if glob_matches(value, self.owner_id): if aws_api_matches(value, self.owner_id):
return True return True
return False return False
def filter_vpc_id(self, values): def filter_vpc_id(self, values):
for value in values: for value in values:
if glob_matches(value, self.vpc_id): if aws_api_matches(value, self.vpc_id):
return True return True
return False return False

View File

@ -1919,7 +1919,7 @@ def test_filter_description():
filter_to_match_group_1_description = { filter_to_match_group_1_description = {
"Name": "description", "Name": "description",
"Values": [unique], "Values": [f"*{unique}*"],
} }
security_groups = ec2r.security_groups.filter( security_groups = ec2r.security_groups.filter(
@ -1931,6 +1931,53 @@ def test_filter_description():
assert security_groups[0].group_id == sg1.group_id assert security_groups[0].group_id == sg1.group_id
@mock_ec2
def test_filter_ip_permission__cidr():
if settings.TEST_SERVER_MODE:
raise SkipTest(
"CIDR's might already exist due to other tests creating IP ranges"
)
ec2r = boto3.resource("ec2", region_name="us-west-1")
vpc = ec2r.create_vpc(CidrBlock="10.250.1.0/16")
sg1 = vpc.create_security_group(
Description="A Described Description Descriptor", GroupName="test-1"
)
sg2 = vpc.create_security_group(
Description="Another Description That Awes The Human Mind", GroupName="test-2"
)
sg1.authorize_ingress(
IpPermissions=[
{
"FromPort": 7357,
"ToPort": 7357,
"IpProtocol": "tcp",
"IpRanges": [{"CidrIp": "10.250.0.0/16"}, {"CidrIp": "10.251.0.0/16"}],
}
]
)
sg2.authorize_ingress(
IpPermissions=[
{
"FromPort": 7357,
"ToPort": 7357,
"IpProtocol": "tcp",
"IpRanges": [{"CidrIp": "172.16.0.0/16"}, {"CidrIp": "172.17.0.0/16"}],
}
]
)
filter_to_match_group_1 = {
"Name": "ip-permission.cidr",
"Values": ["10.250.0.0/16"],
}
security_groups = ec2r.security_groups.filter(Filters=[filter_to_match_group_1])
security_groups = list(security_groups)
assert len(security_groups) == 1
assert security_groups[0].group_id == sg1.group_id
@mock_ec2 @mock_ec2
def test_filter_egress__ip_permission__cidr(): def test_filter_egress__ip_permission__cidr():
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
@ -2332,3 +2379,40 @@ def test_get_groups_by_ippermissions_group_id_filter_across_vpcs():
security_groups = list(security_groups) security_groups = list(security_groups)
assert len(security_groups) == 1 assert len(security_groups) == 1
assert security_groups[0].group_id == sg1.group_id assert security_groups[0].group_id == sg1.group_id
@mock_ec2
def test_filter_group_name():
"""
this filter is an exact match, not a glob
"""
ec2r = boto3.resource("ec2", region_name="us-west-1")
vpc = ec2r.create_vpc(CidrBlock="10.250.1.0/16")
uniq_sg_name_prefix = str(uuid4())[0:6]
sg1 = vpc.create_security_group(
Description="A Described Description Descriptor",
GroupName=f"{uniq_sg_name_prefix}-test-1",
)
vpc.create_security_group(
Description="Another Description That Awes The Human Mind",
GroupName=f"{uniq_sg_name_prefix}-test-12",
)
vpc.create_security_group(
Description="Yet Another Descriptive Description",
GroupName=f"{uniq_sg_name_prefix}-test-13",
)
vpc.create_security_group(
Description="Such Description Much Described",
GroupName=f"{uniq_sg_name_prefix}-test-14",
)
filter_to_match_group_1 = {
"Name": "group-name",
"Values": [sg1.group_name],
}
security_groups = ec2r.security_groups.filter(Filters=[filter_to_match_group_1])
security_groups = list(security_groups)
assert len(security_groups) == 1
assert security_groups[0].group_name == sg1.group_name