From af8697c9a72280030f5e0c5c80223fbf1857ee99 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Fri, 13 Apr 2018 15:03:07 -0400 Subject: [PATCH] Fix security group rules for single rule case. Closes #1522. --- moto/ec2/responses/security_groups.py | 48 ++++++++++++++++---------- tests/test_ec2/test_security_groups.py | 25 ++++++++++++++ 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 9118c01b3..34d22be8b 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -11,6 +11,29 @@ def try_parse_int(value, default=None): return default +def parse_sg_attributes_from_dict(sg_attributes): + ip_protocol = sg_attributes.get('IpProtocol', [None])[0] + from_port = sg_attributes.get('FromPort', [None])[0] + to_port = sg_attributes.get('ToPort', [None])[0] + + ip_ranges = [] + ip_ranges_tree = sg_attributes.get('IpRanges') or {} + for ip_range_idx in sorted(ip_ranges_tree.keys()): + ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0]) + + source_groups = [] + source_group_ids = [] + groups_tree = sg_attributes.get('Groups') or {} + for group_idx in sorted(groups_tree.keys()): + group_dict = groups_tree[group_idx] + if 'GroupId' in group_dict: + source_group_ids.append(group_dict['GroupId'][0]) + elif 'GroupName' in group_dict: + source_groups.append(group_dict['GroupName'][0]) + + return ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids + + class SecurityGroups(BaseResponse): def _process_rules_from_querystring(self): @@ -29,28 +52,17 @@ class SecurityGroups(BaseResponse): d = d[subkey] d[key_splitted[-1]] = value + if 'IpPermissions' not in querytree: + # Handle single rule syntax + ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(querytree) + yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, + source_groups, source_group_ids) + ip_permissions = querytree.get('IpPermissions') or {} for ip_permission_idx in sorted(ip_permissions.keys()): ip_permission = ip_permissions[ip_permission_idx] - ip_protocol = ip_permission.get('IpProtocol', [None])[0] - from_port = ip_permission.get('FromPort', [None])[0] - to_port = ip_permission.get('ToPort', [None])[0] - - ip_ranges = [] - ip_ranges_tree = ip_permission.get('IpRanges') or {} - for ip_range_idx in sorted(ip_ranges_tree.keys()): - ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0]) - - source_groups = [] - source_group_ids = [] - groups_tree = ip_permission.get('Groups') or {} - for group_idx in sorted(groups_tree.keys()): - group_dict = groups_tree[group_idx] - if 'GroupId' in group_dict: - source_group_ids.append(group_dict['GroupId'][0]) - elif 'GroupName' in group_dict: - source_groups.append(group_dict['GroupName'][0]) + ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(ip_permission) yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids) diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 0d7565a31..d843087a6 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -689,6 +689,31 @@ def test_authorize_and_revoke_in_bulk(): sg01.ip_permissions_egress.shouldnt.contain(ip_permission) +@mock_ec2 +def test_security_group_ingress_without_multirule(): + ec2 = boto3.resource('ec2', 'ca-central-1') + sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + + assert len(sg.ip_permissions) == 0 + sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + + # Fails + assert len(sg.ip_permissions) == 1 + + +@mock_ec2 +def test_security_group_ingress_without_multirule_after_reload(): + ec2 = boto3.resource('ec2', 'ca-central-1') + sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + + assert len(sg.ip_permissions) == 0 + sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + + # Also Fails + sg_after = ec2.SecurityGroup(sg.id) + assert len(sg_after.ip_permissions) == 1 + + @mock_ec2_deprecated def test_get_all_security_groups_filter_with_same_vpc_id(): conn = boto.connect_ec2('the_key', 'the_secret')