Allow passing security groups by ID when creating instances

This commit is contained in:
Chris St. Pierre 2014-05-07 09:16:28 -04:00
parent b244457c4f
commit 745368242e
3 changed files with 24 additions and 5 deletions

View File

@ -107,7 +107,10 @@ class InstanceBackend(object):
new_reservation = Reservation() new_reservation = Reservation()
new_reservation.id = random_reservation_id() new_reservation.id = random_reservation_id()
security_groups = [self.get_security_group_from_name(name) for name in security_group_names] security_groups = [self.get_security_group_from_name(name)
for name in security_group_names]
security_groups.extend(self.get_security_group_from_id(sg_id)
for sg_id in kwargs.pop("security_group_ids", []))
for index in range(count): for index in range(count):
new_instance = Instance( new_instance = Instance(
image_id, image_id,

View File

@ -9,7 +9,8 @@ from moto.ec2.exceptions import InvalidIdError
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
def _get_multi_param(self, param_prefix): def _get_multi_param(self, param_prefix):
return [value[0] for key, value in self.querystring.items() if key.startswith(param_prefix)] return [value[0] for key, value in self.querystring.items()
if key.startswith(param_prefix + ".")]
def describe_instances(self): def describe_instances(self):
instance_ids = instance_ids_from_querystring(self.querystring) instance_ids = instance_ids_from_querystring(self.querystring)
@ -33,13 +34,14 @@ class InstanceResponse(BaseResponse):
image_id = self.querystring.get('ImageId')[0] image_id = self.querystring.get('ImageId')[0]
user_data = self.querystring.get('UserData') user_data = self.querystring.get('UserData')
security_group_names = self._get_multi_param('SecurityGroup') security_group_names = self._get_multi_param('SecurityGroup')
security_group_ids = self._get_multi_param('SecurityGroupId')
instance_type = self.querystring.get("InstanceType", ["m1.small"])[0] instance_type = self.querystring.get("InstanceType", ["m1.small"])[0]
subnet_id = self.querystring.get("SubnetId", [None])[0] subnet_id = self.querystring.get("SubnetId", [None])[0]
key_name = self.querystring.get("KeyName", [None])[0] key_name = self.querystring.get("KeyName", [None])[0]
new_reservation = ec2_backend.add_instances( new_reservation = ec2_backend.add_instances(
image_id, min_count, user_data, security_group_names, image_id, min_count, user_data, security_group_names,
instance_type=instance_type, subnet_id=subnet_id, instance_type=instance_type, subnet_id=subnet_id,
key_name=key_name) key_name=key_name, security_group_ids=security_group_ids)
template = Template(EC2_RUN_INSTANCES) template = Template(EC2_RUN_INSTANCES)
return template.render(reservation=new_reservation) return template.render(reservation=new_reservation)

View File

@ -173,11 +173,25 @@ def test_user_data_with_run_instance():
@mock_ec2 @mock_ec2
def test_run_instance_with_security_group(): def test_run_instance_with_security_group_name():
conn = boto.connect_ec2('the_key', 'the_secret') conn = boto.connect_ec2('the_key', 'the_secret')
group = conn.create_security_group('group1', "some description") group = conn.create_security_group('group1', "some description")
reservation = conn.run_instances('ami-1234abcd', security_groups=['group1']) reservation = conn.run_instances('ami-1234abcd',
security_groups=['group1'])
instance = reservation.instances[0]
instance.groups[0].id.should.equal(group.id)
instance.groups[0].name.should.equal("group1")
@mock_ec2
def test_run_instance_with_security_group_id():
conn = boto.connect_ec2('the_key', 'the_secret')
group = conn.create_security_group('group1', "some description")
reservation = conn.run_instances('ami-1234abcd',
security_group_ids=[group.id])
instance = reservation.instances[0] instance = reservation.instances[0]
instance.groups[0].id.should.equal(group.id) instance.groups[0].id.should.equal(group.id)