We're getting back the correct group from get_security_group_from_id,

but hitting another issue with the source_group_name also using an id
rather than a name
This commit is contained in:
Jon Haddad 2014-03-20 17:26:08 -07:00
parent 1480f8b44a
commit cbdc8ba183
3 changed files with 31 additions and 6 deletions

View File

@ -373,6 +373,16 @@ class SecurityGroupBackend(object):
if group: if group:
return self.groups[None].pop(group.id) return self.groups[None].pop(group.id)
def get_security_group_from_id(self, group_id):
# 2 levels of chaining necessary since it's a complex structure
all_groups = itertools.chain.from_iterable([x.values() for x in self.groups.values()])
for group in itertools.chain(all_groups):
if group.id == group_id:
return group
def get_security_group_from_name(self, name, vpc_id): def get_security_group_from_name(self, name, vpc_id):
for group_id, group in self.groups[vpc_id].iteritems(): for group_id, group in self.groups[vpc_id].iteritems():
if group.name == name: if group.name == name:
@ -383,8 +393,14 @@ class SecurityGroupBackend(object):
default_group = ec2_backend.create_security_group("default", "The default security group", force=True) default_group = ec2_backend.create_security_group("default", "The default security group", force=True)
return default_group return default_group
def authorize_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None): def authorize_security_group_ingress(self, group_name, group_id, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None):
group = self.get_security_group_from_name(group_name, vpc_id) # to auth a group in a VPC you need the group_id the name isn't enough
if group_name:
group = self.get_security_group_from_name(group_name, vpc_id)
elif group_id:
group = self.get_security_group_from_id(group_id)
source_groups = [] source_groups = []
for source_group_name in source_group_names: for source_group_name in source_group_names:
source_group = self.get_security_group_from_name(source_group_name, vpc_id) source_group = self.get_security_group_from_name(source_group_name, vpc_id)
@ -394,7 +410,7 @@ class SecurityGroupBackend(object):
security_rule = SecurityRule(ip_protocol, from_port, to_port, ip_ranges, source_groups) security_rule = SecurityRule(ip_protocol, from_port, to_port, ip_ranges, source_groups)
group.ingress_rules.append(security_rule) group.ingress_rules.append(security_rule)
def revoke_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None): def revoke_security_group_ingress(self, group_name, group_id, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None):
group = self.get_security_group_from_name(group_name, vpc_id) group = self.get_security_group_from_name(group_name, vpc_id)
source_groups = [] source_groups = []
for source_group_name in source_group_names: for source_group_name in source_group_names:

View File

@ -5,7 +5,15 @@ from moto.ec2.models import ec2_backend
def process_rules_from_querystring(querystring): def process_rules_from_querystring(querystring):
name = querystring.get('GroupName')[0]
name = None
group_id = None
try:
name = querystring.get('GroupName')[0]
except:
group_id = querystring.get('GroupId')[0]
ip_protocol = querystring.get('IpPermissions.1.IpProtocol')[0] ip_protocol = querystring.get('IpPermissions.1.IpProtocol')[0]
from_port = querystring.get('IpPermissions.1.FromPort')[0] from_port = querystring.get('IpPermissions.1.FromPort')[0]
to_port = querystring.get('IpPermissions.1.ToPort')[0] to_port = querystring.get('IpPermissions.1.ToPort')[0]
@ -18,7 +26,7 @@ def process_rules_from_querystring(querystring):
for key, value in querystring.iteritems(): for key, value in querystring.iteritems():
if 'IpPermissions.1.Groups' in key: if 'IpPermissions.1.Groups' in key:
source_groups.append(value[0]) source_groups.append(value[0])
return (name, ip_protocol, from_port, to_port, ip_ranges, source_groups) return (name, group_id, ip_protocol, from_port, to_port, ip_ranges, source_groups)
class SecurityGroups(BaseResponse): class SecurityGroups(BaseResponse):

View File

@ -133,7 +133,7 @@ def test_authorize_other_group_and_revoke():
security_group.rules.should.have.length_of(0) security_group.rules.should.have.length_of(0)
@mock_ec2 @mock_ec2
def test_authorize_ip_in_vpc(): def test_authorize_group_in_vpc():
conn = boto.connect_ec2('the_key', 'the_secret') conn = boto.connect_ec2('the_key', 'the_secret')
vpc_id = "vpc-12345" vpc_id = "vpc-12345"
@ -142,5 +142,6 @@ def test_authorize_ip_in_vpc():
security_group2 = conn.create_security_group('test2', 'test2', vpc_id) security_group2 = conn.create_security_group('test2', 'test2', vpc_id)
success = security_group1.authorize(ip_protocol="tcp", from_port="22", to_port="2222", src_group=security_group2) success = security_group1.authorize(ip_protocol="tcp", from_port="22", to_port="2222", src_group=security_group2)
success = security_group1.revoke(ip_protocol="tcp", from_port="22", to_port="2222", src_group=security_group2)