Merge pull request #1157 from gvlproject/fix_security_group_filters

Raise InvalidGroup.NotFound in DescribeSecurityGroups
This commit is contained in:
Jack Danger 2017-09-18 12:58:51 -07:00 committed by GitHub
commit 0e33f44bbe
6 changed files with 120 additions and 45 deletions

View File

@ -1360,22 +1360,25 @@ class SecurityGroupBackend(object):
return group
def describe_security_groups(self, group_ids=None, groupnames=None, filters=None):
all_groups = itertools.chain(*[x.values()
for x in self.groups.values()])
groups = []
matches = itertools.chain(*[x.values()
for x in self.groups.values()])
if group_ids:
matches = [grp for grp in matches
if grp.id in group_ids]
if len(group_ids) > len(matches):
unknown_ids = set(group_ids) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_ids)
if groupnames:
matches = [grp for grp in matches
if grp.name in groupnames]
if len(groupnames) > len(matches):
unknown_names = set(groupnames) - set(matches)
raise InvalidSecurityGroupNotFoundError(unknown_names)
if filters:
matches = [grp for grp in matches
if grp.matches_filters(filters)]
if group_ids or groupnames or filters:
for group in all_groups:
if ((group_ids and group.id not in group_ids) or
(groupnames and group.name not in groupnames)):
continue
if filters and not group.matches_filters(filters):
continue
groups.append(group)
else:
groups = all_groups
return groups
return matches
def _delete_security_group(self, vpc_id, group_id):
if self.groups[vpc_id][group_id].enis:
@ -1772,11 +1775,17 @@ class EBSBackend(object):
self.volumes[volume_id] = volume
return volume
def describe_volumes(self, filters=None):
def describe_volumes(self, volume_ids=None, filters=None):
matches = self.volumes.values()
if volume_ids:
matches = [vol for vol in matches
if vol.id in volume_ids]
if len(volume_ids) > len(matches):
unknown_ids = set(volume_ids) - set(matches)
raise InvalidVolumeIdError(unknown_ids)
if filters:
volumes = self.volumes.values()
return generic_filter(filters, volumes)
return self.volumes.values()
matches = generic_filter(filters, matches)
return matches
def get_volume(self, volume_id):
volume = self.volumes.get(volume_id, None)
@ -1824,11 +1833,17 @@ class EBSBackend(object):
self.snapshots[snapshot_id] = snapshot
return snapshot
def describe_snapshots(self, filters=None):
def describe_snapshots(self, snapshot_ids=None, filters=None):
matches = self.snapshots.values()
if snapshot_ids:
matches = [snap for snap in matches
if snap.id in snapshot_ids]
if len(snapshot_ids) > len(matches):
unknown_ids = set(snapshot_ids) - set(matches)
raise InvalidSnapshotIdError(unknown_ids)
if filters:
snapshots = self.snapshots.values()
return generic_filter(filters, snapshots)
return self.snapshots.values()
matches = generic_filter(filters, matches)
return matches
def get_snapshot(self, snapshot_id):
snapshot = self.snapshots.get(snapshot_id, None)
@ -1947,12 +1962,16 @@ class VPCBackend(object):
return self.vpcs.get(vpc_id)
def get_all_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.values()
if vpc_ids:
vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids]
else:
vpcs = self.vpcs.values()
return generic_filter(filters, vpcs)
matches = [vpc for vpc in matches
if vpc.id in vpc_ids]
if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches)
raise InvalidVPCIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains.
@ -2189,16 +2208,19 @@ class SubnetBackend(object):
return subnet
def get_all_subnets(self, subnet_ids=None, filters=None):
subnets = []
# Extract a list of all subnets
matches = itertools.chain(*[x.values()
for x in self.subnets.values()])
if subnet_ids:
for subnet_id in subnet_ids:
for items in self.subnets.values():
if subnet_id in items:
subnets.append(items[subnet_id])
else:
for items in self.subnets.values():
subnets.extend(items.values())
return generic_filter(filters, subnets)
matches = [sn for sn in matches
if sn.id in subnet_ids]
if len(subnet_ids) > len(matches):
unknown_ids = set(subnet_ids) - set(matches)
raise InvalidSubnetIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_subnet(self, subnet_id):
for subnets in self.subnets.values():

View File

@ -54,20 +54,14 @@ class ElasticBlockStore(BaseResponse):
def describe_snapshots(self):
filters = filters_from_querystring(self.querystring)
snapshot_ids = self._get_multi_param('SnapshotId')
snapshots = self.ec2_backend.describe_snapshots(filters=filters)
# Describe snapshots to handle filter on snapshot_ids
snapshots = [
s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots
snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters)
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots)
def describe_volumes(self):
filters = filters_from_querystring(self.querystring)
volume_ids = self._get_multi_param('VolumeId')
volumes = self.ec2_backend.describe_volumes(filters=filters)
# Describe volumes to handle filter on volume_ids
volumes = [
v for v in volumes if v.id in volume_ids] if volume_ids else volumes
volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters)
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes)

View File

@ -83,6 +83,12 @@ def test_filter_volume_by_id():
vol2 = conn.get_all_volumes(volume_ids=[volume1.id, volume2.id])
vol2.should.have.length_of(2)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_volumes(volume_ids=['vol-does_not_exist'])
cm.exception.code.should.equal('InvalidVolume.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_volume_filters():
@ -302,6 +308,12 @@ def test_filter_snapshot_by_id():
s.volume_id.should.be.within([volume2.id, volume3.id])
s.region.name.should.equal(conn.region.name)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_snapshots(snapshot_ids=['snap-does_not_exist'])
cm.exception.code.should.equal('InvalidSnapshot.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_snapshot_filters():

View File

@ -348,6 +348,15 @@ def test_get_all_security_groups():
resp.should.have.length_of(1)
resp[0].id.should.equal(sg1.id)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_security_groups(groupnames=['does_not_exist'])
cm.exception.code.should.equal('InvalidGroup.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
resp.should.have.length_of(1)
resp[0].id.should.equal(sg1.id)
resp = conn.get_all_security_groups(filters={'vpc-id': ['vpc-mjm05d27']})
resp.should.have.length_of(1)
resp[0].id.should.equal(sg1.id)
@ -681,3 +690,9 @@ def test_get_all_security_groups_filter_with_same_vpc_id():
security_groups = conn.get_all_security_groups(
group_ids=[security_group.id], filters={'vpc-id': [vpc_id]})
security_groups.should.have.length_of(1)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_security_groups(group_ids=['does_not_exist'])
cm.exception.code.should.equal('InvalidGroup.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none

View File

@ -158,6 +158,32 @@ def test_modify_subnet_attribute_validation():
SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'})
@mock_ec2_deprecated
def test_subnet_get_by_id():
ec2 = boto.ec2.connect_to_region('us-west-1')
conn = boto.vpc.connect_to_region('us-west-1')
vpcA = conn.create_vpc("10.0.0.0/16")
subnetA = conn.create_subnet(
vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a')
vpcB = conn.create_vpc("10.0.0.0/16")
subnetB1 = conn.create_subnet(
vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a')
subnetB2 = conn.create_subnet(
vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b')
subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id])
subnets_by_id.should.have.length_of(2)
subnets_by_id = tuple(map(lambda s: s.id, subnets_by_id))
subnetA.id.should.be.within(subnets_by_id)
subnetB1.id.should.be.within(subnets_by_id)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_subnets(subnet_ids=['subnet-does_not_exist'])
cm.exception.code.should.equal('InvalidSubnetID.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_get_subnets_filtering():
ec2 = boto.ec2.connect_to_region('us-west-1')

View File

@ -113,6 +113,12 @@ def test_vpc_get_by_id():
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist'])
cm.exception.code.should.equal('InvalidVpcID.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_vpc_get_by_cidr_block():