Added invalid id exceptions when filtering vpcs and subnets

This commit is contained in:
Nuwan Goonasekera 2017-09-18 23:38:39 +05:30
parent 08c4eff0b2
commit ca56955a97
3 changed files with 55 additions and 16 deletions

View File

@ -1836,8 +1836,8 @@ class EBSBackend(object):
def describe_snapshots(self, snapshot_ids=None, filters=None): def describe_snapshots(self, snapshot_ids=None, filters=None):
matches = self.snapshots.values() matches = self.snapshots.values()
if snapshot_ids: if snapshot_ids:
matches = [vol for vol in matches matches = [snap for snap in matches
if vol.id in snapshot_ids] if snap.id in snapshot_ids]
if len(snapshot_ids) > len(matches): if len(snapshot_ids) > len(matches):
unknown_ids = set(snapshot_ids) - set(matches) unknown_ids = set(snapshot_ids) - set(matches)
raise InvalidSnapshotIdError(unknown_ids) raise InvalidSnapshotIdError(unknown_ids)
@ -1962,12 +1962,16 @@ class VPCBackend(object):
return self.vpcs.get(vpc_id) return self.vpcs.get(vpc_id)
def get_all_vpcs(self, vpc_ids=None, filters=None): def get_all_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.values()
if vpc_ids: if vpc_ids:
vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids] matches = [vpc for vpc in matches
else: if vpc.id in vpc_ids]
vpcs = self.vpcs.values() if len(vpc_ids) > len(matches):
unknown_ids = set(vpc_ids) - set(matches)
return generic_filter(filters, vpcs) raise InvalidVPCIdError(unknown_ids)
if filters:
matches = generic_filter(filters, matches)
return matches
def delete_vpc(self, vpc_id): def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains. # Delete route table if only main route table remains.
@ -2204,16 +2208,19 @@ class SubnetBackend(object):
return subnet return subnet
def get_all_subnets(self, subnet_ids=None, filters=None): def get_all_subnets(self, subnet_ids=None, filters=None):
subnets = [] # Extract a list of all subnets
matches = itertools.chain(*[x.values()
for x in self.subnets.values()])
if subnet_ids: if subnet_ids:
for subnet_id in subnet_ids: matches = [sn for sn in matches
for items in self.subnets.values(): if sn.id in subnet_ids]
if subnet_id in items: if len(subnet_ids) > len(matches):
subnets.append(items[subnet_id]) unknown_ids = set(subnet_ids) - set(matches)
else: raise InvalidSubnetIdError(unknown_ids)
for items in self.subnets.values(): if filters:
subnets.extend(items.values()) matches = generic_filter(filters, matches)
return generic_filter(filters, subnets)
return matches
def delete_subnet(self, subnet_id): def delete_subnet(self, subnet_id):
for subnets in self.subnets.values(): for subnets in self.subnets.values():

View File

@ -158,6 +158,32 @@ def test_modify_subnet_attribute_validation():
SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'}) SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'})
@mock_ec2_deprecated
def test_subnet_get_by_id():
ec2 = boto.ec2.connect_to_region('us-west-1')
conn = boto.vpc.connect_to_region('us-west-1')
vpcA = conn.create_vpc("10.0.0.0/16")
subnetA = conn.create_subnet(
vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a')
vpcB = conn.create_vpc("10.0.0.0/16")
subnetB1 = conn.create_subnet(
vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a')
subnetB2 = conn.create_subnet(
vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b')
subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id])
subnets_by_id.should.have.length_of(2)
subnets_by_id = tuple(map(lambda s: s.id, subnets_by_id))
subnetA.id.should.be.within(subnets_by_id)
subnetB1.id.should.be.within(subnets_by_id)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_subnets(subnet_ids=['subnet-does_not_exist'])
cm.exception.code.should.equal('InvalidSubnetID.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated @mock_ec2_deprecated
def test_get_subnets_filtering(): def test_get_subnets_filtering():
ec2 = boto.ec2.connect_to_region('us-west-1') ec2 = boto.ec2.connect_to_region('us-west-1')

View File

@ -113,6 +113,12 @@ def test_vpc_get_by_id():
vpc1.id.should.be.within(vpc_ids) vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids)
with assert_raises(EC2ResponseError) as cm:
conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist'])
cm.exception.code.should.equal('InvalidVpcID.NotFound')
cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated @mock_ec2_deprecated
def test_vpc_get_by_cidr_block(): def test_vpc_get_by_cidr_block():