From ca56955a97c98057bffb7b3b9bba9f49fd8212bf Mon Sep 17 00:00:00 2001 From: Nuwan Goonasekera Date: Mon, 18 Sep 2017 23:38:39 +0530 Subject: [PATCH] Added invalid id exceptions when filtering vpcs and subnets --- moto/ec2/models.py | 39 ++++++++++++++++++++-------------- tests/test_ec2/test_subnets.py | 26 +++++++++++++++++++++++ tests/test_ec2/test_vpcs.py | 6 ++++++ 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 4b143eeab..e7e8a1dd8 100755 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1836,8 +1836,8 @@ class EBSBackend(object): def describe_snapshots(self, snapshot_ids=None, filters=None): matches = self.snapshots.values() if snapshot_ids: - matches = [vol for vol in matches - if vol.id in snapshot_ids] + matches = [snap for snap in matches + if snap.id in snapshot_ids] if len(snapshot_ids) > len(matches): unknown_ids = set(snapshot_ids) - set(matches) raise InvalidSnapshotIdError(unknown_ids) @@ -1962,12 +1962,16 @@ class VPCBackend(object): return self.vpcs.get(vpc_id) def get_all_vpcs(self, vpc_ids=None, filters=None): + matches = self.vpcs.values() if vpc_ids: - vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids] - else: - vpcs = self.vpcs.values() - - return generic_filter(filters, vpcs) + matches = [vpc for vpc in matches + if vpc.id in vpc_ids] + if len(vpc_ids) > len(matches): + unknown_ids = set(vpc_ids) - set(matches) + raise InvalidVPCIdError(unknown_ids) + if filters: + matches = generic_filter(filters, matches) + return matches def delete_vpc(self, vpc_id): # Delete route table if only main route table remains. @@ -2204,16 +2208,19 @@ class SubnetBackend(object): return subnet def get_all_subnets(self, subnet_ids=None, filters=None): - subnets = [] + # Extract a list of all subnets + matches = itertools.chain(*[x.values() + for x in self.subnets.values()]) if subnet_ids: - for subnet_id in subnet_ids: - for items in self.subnets.values(): - if subnet_id in items: - subnets.append(items[subnet_id]) - else: - for items in self.subnets.values(): - subnets.extend(items.values()) - return generic_filter(filters, subnets) + matches = [sn for sn in matches + if sn.id in subnet_ids] + if len(subnet_ids) > len(matches): + unknown_ids = set(subnet_ids) - set(matches) + raise InvalidSubnetIdError(unknown_ids) + if filters: + matches = generic_filter(filters, matches) + + return matches def delete_subnet(self, subnet_id): for subnets in self.subnets.values(): diff --git a/tests/test_ec2/test_subnets.py b/tests/test_ec2/test_subnets.py index 38565a28f..99e6d45d8 100644 --- a/tests/test_ec2/test_subnets.py +++ b/tests/test_ec2/test_subnets.py @@ -158,6 +158,32 @@ def test_modify_subnet_attribute_validation(): SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'}) +@mock_ec2_deprecated +def test_subnet_get_by_id(): + ec2 = boto.ec2.connect_to_region('us-west-1') + conn = boto.vpc.connect_to_region('us-west-1') + vpcA = conn.create_vpc("10.0.0.0/16") + subnetA = conn.create_subnet( + vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB = conn.create_vpc("10.0.0.0/16") + subnetB1 = conn.create_subnet( + vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetB2 = conn.create_subnet( + vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + + subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id]) + subnets_by_id.should.have.length_of(2) + subnets_by_id = tuple(map(lambda s: s.id, subnets_by_id)) + subnetA.id.should.be.within(subnets_by_id) + subnetB1.id.should.be.within(subnets_by_id) + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_subnets(subnet_ids=['subnet-does_not_exist']) + cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + @mock_ec2_deprecated def test_get_subnets_filtering(): ec2 = boto.ec2.connect_to_region('us-west-1') diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index 904603f6d..fc0a93cbb 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -113,6 +113,12 @@ def test_vpc_get_by_id(): vpc1.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist']) + cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_vpc_get_by_cidr_block():