From 4beda260076a2b4027a22be7b819ac66298a67e6 Mon Sep 17 00:00:00 2001 From: Hugo Lopes Tavares Date: Wed, 25 Feb 2015 18:11:00 -0500 Subject: [PATCH] Change SecurityGroupBackend.{authorize,revoke}_security_group_ingress() methods to receive group name or id, never both --- moto/ec2/models.py | 31 ++++----- moto/ec2/responses/security_groups.py | 10 +-- .../test_cloudformation_stack_integration.py | 66 +++++++++++++++++++ 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index fc81174a3..c5d3c256c 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1085,8 +1085,7 @@ class SecurityGroup(TaggedEC2Resource): source_group_id = ingress_rule.get('SourceSecurityGroupId') ec2_backend.authorize_security_group_ingress( - group_name=security_group.name, - group_id=security_group.id, + group_name_or_id=security_group.id, ip_protocol=ingress_rule['IpProtocol'], from_port=ingress_rule['FromPort'], to_port=ingress_rule['ToPort'], @@ -1218,9 +1217,15 @@ class SecurityGroupBackend(object): default_group = self.create_security_group("default", "The default security group", vpc_id=vpc_id, force=True) return default_group + def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id): + # try searching by id, fallbacks to name search + group = self.get_security_group_from_id(group_name_or_id) + if group is None: + group = self.get_security_group_from_name(group_name_or_id, vpc_id) + return group + def authorize_security_group_ingress(self, - group_name, - group_id, + group_name_or_id, ip_protocol, from_port, to_port, @@ -1228,12 +1233,7 @@ class SecurityGroupBackend(object): source_group_names=None, source_group_ids=None, vpc_id=None): - # to auth a group in a VPC you need the group_id the name isn't enough - - if group_name: - group = self.get_security_group_from_name(group_name, vpc_id) - elif group_id: - group = self.get_security_group_from_id(group_id) + group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) if ip_ranges and not isinstance(ip_ranges, list): ip_ranges = [ip_ranges] @@ -1261,8 +1261,7 @@ class SecurityGroupBackend(object): group.ingress_rules.append(security_rule) def revoke_security_group_ingress(self, - group_name, - group_id, + group_name_or_id, ip_protocol, from_port, to_port, @@ -1271,10 +1270,7 @@ class SecurityGroupBackend(object): source_group_ids=None, vpc_id=None): - if group_name: - group = self.get_security_group_from_name(group_name, vpc_id) - elif group_id: - group = self.get_security_group_from_id(group_id) + group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) source_groups = [] for source_group_name in source_group_names: @@ -1340,8 +1336,7 @@ class SecurityGroupIngress(object): security_group = ec2_backend.describe_security_groups(groupnames=[group_name])[0] ec2_backend.authorize_security_group_ingress( - group_name=security_group.name, - group_id=security_group.id, + group_name_or_id=security_group.id, ip_protocol=ip_protocol, from_port=from_port, to_port=to_port, diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 38fadb883..eec27c3aa 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -4,14 +4,10 @@ from moto.ec2.utils import filters_from_querystring def process_rules_from_querystring(querystring): - - name = None - group_id = None - try: - name = querystring.get('GroupName')[0] + group_name_or_id = querystring.get('GroupName')[0] except: - group_id = querystring.get('GroupId')[0] + group_name_or_id = querystring.get('GroupId')[0] ip_protocol = querystring.get('IpPermissions.1.IpProtocol')[0] from_port = querystring.get('IpPermissions.1.FromPort')[0] @@ -30,7 +26,7 @@ def process_rules_from_querystring(querystring): elif 'IpPermissions.1.Groups' in key: source_groups.append(value[0]) - return (name, group_id, ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids) + return (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids) class SecurityGroups(BaseResponse): diff --git a/tests/test_cloudformation/test_cloudformation_stack_integration.py b/tests/test_cloudformation/test_cloudformation_stack_integration.py index ca60bf016..0ca96db20 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_integration.py +++ b/tests/test_cloudformation/test_cloudformation_stack_integration.py @@ -1128,3 +1128,69 @@ def test_security_group_ingress_separate_from_security_group_by_id(): security_group1.rules[0].ip_protocol.should.equal('tcp') security_group1.rules[0].from_port.should.equal('80') security_group1.rules[0].to_port.should.equal('8080') + + +@mock_cloudformation +@mock_ec2 +def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): + vpc_conn = boto.vpc.connect_to_region("us-west-1") + vpc = vpc_conn.create_vpc("10.0.0.0/16") + + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + "test-security-group1": { + "Type": "AWS::EC2::SecurityGroup", + "Properties": { + "GroupDescription": "test security group", + "VpcId": vpc.id, + "Tags": [ + { + "Key": "sg-name", + "Value": "sg1" + } + ] + }, + }, + "test-security-group2": { + "Type": "AWS::EC2::SecurityGroup", + "Properties": { + "GroupDescription": "test security group", + "VpcId": vpc.id, + "Tags": [ + { + "Key": "sg-name", + "Value": "sg2" + } + ] + }, + }, + "test-sg-ingress": { + "Type": "AWS::EC2::SecurityGroupIngress", + "Properties": { + "GroupId": {"Ref": "test-security-group1"}, + "VpcId": vpc.id, + "IpProtocol": "tcp", + "FromPort": "80", + "ToPort": "8080", + "SourceSecurityGroupId": {"Ref": "test-security-group2"}, + } + } + } + } + + template_json = json.dumps(template) + cf_conn = boto.cloudformation.connect_to_region("us-west-1") + cf_conn.create_stack( + "test_stack", + template_body=template_json, + ) + security_group1 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg1"})[0] + security_group2 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[0] + + security_group1.rules.should.have.length_of(1) + security_group1.rules[0].grants.should.have.length_of(1) + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal('tcp') + security_group1.rules[0].from_port.should.equal('80') + security_group1.rules[0].to_port.should.equal('8080')