From 444ab96b4f656ff104c7f12bbf5675f70c21cafc Mon Sep 17 00:00:00 2001 From: Omer Katz Date: Wed, 1 Oct 2014 15:33:12 +0300 Subject: [PATCH 1/3] Added the ability to filter by tag-key. --- moto/ec2/models.py | 9 +++++--- moto/ec2/utils.py | 13 ++++++++++++ tests/test_ec2/test_vpcs.py | 41 ++++++++++++++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index a9e9f0ad7..b2e6b050e 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals -import six import copy import itertools from collections import defaultdict +import six import boto from boto.ec2.instance import Instance as BotoInstance, Reservation from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType @@ -70,7 +70,7 @@ from .utils import ( random_volume_id, random_vpc_id, random_vpc_peering_connection_id, -) + is_filter_matching) class InstanceState(object): @@ -93,6 +93,9 @@ class TaggedEC2Instance(object): if tag['key'] == tagname: return tag['value'] + if filter_name == 'tag-key': + return [tag['key'] for tag in tags] + class NetworkInterface(object): def __init__(self, subnet, private_ip_address, device_index=0, public_ip_auto_assign=True, group_ids=None): @@ -1194,7 +1197,7 @@ class VPCBackend(object): if filters: for (_filter, _filter_value) in filters.items(): - vpcs = [ vpc for vpc in vpcs if vpc.get_filter_value(_filter) in _filter_value ] + vpcs = [ vpc for vpc in vpcs if is_filter_matching(vpc, _filter, _filter_value) ] return vpcs diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index 181732856..a6935aa2c 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -278,6 +278,19 @@ def filter_reservations(reservations, filter_dict): return result +def is_filter_matching(obj, filter, filter_value): + value = obj.get_filter_value(filter) + + if isinstance(value, six.string_types): + return value in filter_value + + try: + value = set(value) + return value.issubset(filter_value) or value.issuperset(filter_value) + except TypeError: + return value in filter_value + + # not really random ( http://xkcd.com/221/ ) def random_key_pair(): return { diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index d9d06fd68..47dc855c4 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals # Ensure 'assert_raises' context manager support for Python 2.6 -#import tests.backport_assert_raises +# import tests.backport_assert_raises from nose.tools import assert_raises import boto @@ -128,4 +128,43 @@ def test_vpc_get_by_tag(): vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) + vpc2.id.should.be.within(vpc_ids) + + +@mock_ec2 +def test_vpc_get_by_tag_key_superset(): + conn = boto.connect_vpc() + vpc1 = conn.create_vpc("10.0.0.0/16") + vpc2 = conn.create_vpc("10.0.0.0/16") + vpc3 = conn.create_vpc("10.0.0.0/24") + + vpc1.add_tag('Name', 'TestVPC') + vpc1.add_tag('Key', 'TestVPC2') + vpc2.add_tag('Name', 'TestVPC') + vpc2.add_tag('Key', 'TestVPC2') + vpc3.add_tag('Key', 'TestVPC2') + + vpcs = conn.get_all_vpcs(filters={'tag-key': 'Name'}) + vpcs.should.have.length_of(2) + vpc_ids = tuple(map(lambda v: v.id, vpcs)) + vpc1.id.should.be.within(vpc_ids) + vpc2.id.should.be.within(vpc_ids) + +@mock_ec2 +def test_vpc_get_by_tag_key_subset(): + conn = boto.connect_vpc() + vpc1 = conn.create_vpc("10.0.0.0/16") + vpc2 = conn.create_vpc("10.0.0.0/16") + vpc3 = conn.create_vpc("10.0.0.0/24") + + vpc1.add_tag('Name', 'TestVPC') + vpc1.add_tag('Key', 'TestVPC2') + vpc2.add_tag('Name', 'TestVPC') + vpc2.add_tag('Key', 'TestVPC2') + vpc3.add_tag('Test', 'TestVPC2') + + vpcs = conn.get_all_vpcs(filters={'tag-key': ['Name', 'Key']}) + vpcs.should.have.length_of(2) + vpc_ids = tuple(map(lambda v: v.id, vpcs)) + vpc1.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids) \ No newline at end of file From 298cf65569388d39026229673522a4fa0ac99fe9 Mon Sep 17 00:00:00 2001 From: Omer Katz Date: Wed, 1 Oct 2014 15:44:54 +0300 Subject: [PATCH 2/3] Added the ability to filter by tag-value and refactored the filters to be generic. --- moto/ec2/models.py | 13 ++++++------ moto/ec2/utils.py | 10 +++++++++- tests/test_ec2/test_vpcs.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index b2e6b050e..6f9ada1f3 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -70,7 +70,7 @@ from .utils import ( random_volume_id, random_vpc_id, random_vpc_peering_connection_id, - is_filter_matching) + generic_filter) class InstanceState(object): @@ -96,6 +96,9 @@ class TaggedEC2Instance(object): if filter_name == 'tag-key': return [tag['key'] for tag in tags] + if filter_name == 'tag-value': + return [tag['value'] for tag in tags] + class NetworkInterface(object): def __init__(self, subnet, private_ip_address, device_index=0, public_ip_auto_assign=True, group_ids=None): @@ -1156,7 +1159,7 @@ class VPC(TaggedEC2Instance): filter_value = super(VPC, self).get_filter_value(filter_name) - if not filter_value: + if filter_value is None: msg = "The filter '{0}' for DescribeVPCs has not been" \ " implemented in Moto yet. Feel free to open an issue at" \ " https://github.com/spulec/moto/issues".format(filter_name) @@ -1195,11 +1198,7 @@ class VPCBackend(object): else: vpcs = self.vpcs.values() - if filters: - for (_filter, _filter_value) in filters.items(): - vpcs = [ vpc for vpc in vpcs if is_filter_matching(vpc, _filter, _filter_value) ] - - return vpcs + return generic_filter(filters, vpcs) def delete_vpc(self, vpc_id): # Delete route table if only main route table remains. diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index a6935aa2c..8a0c702d3 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -286,11 +286,19 @@ def is_filter_matching(obj, filter, filter_value): try: value = set(value) - return value.issubset(filter_value) or value.issuperset(filter_value) + return (value and value.issubset(filter_value)) or value.issuperset(filter_value) except TypeError: return value in filter_value +def generic_filter(filters, objects): + if filters: + for (_filter, _filter_value) in filters.items(): + objects = [obj for obj in objects if is_filter_matching(obj, _filter, _filter_value)] + + return objects + + # not really random ( http://xkcd.com/221/ ) def random_key_pair(): return { diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index 47dc855c4..9f3bc8351 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -150,6 +150,7 @@ def test_vpc_get_by_tag_key_superset(): vpc1.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids) + @mock_ec2 def test_vpc_get_by_tag_key_subset(): conn = boto.connect_vpc() @@ -167,4 +168,43 @@ def test_vpc_get_by_tag_key_subset(): vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) + vpc2.id.should.be.within(vpc_ids) + + +@mock_ec2 +def test_vpc_get_by_tag_value_superset(): + conn = boto.connect_vpc() + vpc1 = conn.create_vpc("10.0.0.0/16") + vpc2 = conn.create_vpc("10.0.0.0/16") + vpc3 = conn.create_vpc("10.0.0.0/24") + + vpc1.add_tag('Name', 'TestVPC') + vpc1.add_tag('Key', 'TestVPC2') + vpc2.add_tag('Name', 'TestVPC') + vpc2.add_tag('Key', 'TestVPC2') + vpc3.add_tag('Key', 'TestVPC2') + + vpcs = conn.get_all_vpcs(filters={'tag-value': 'TestVPC'}) + vpcs.should.have.length_of(2) + vpc_ids = tuple(map(lambda v: v.id, vpcs)) + vpc1.id.should.be.within(vpc_ids) + vpc2.id.should.be.within(vpc_ids) + + +@mock_ec2 +def test_vpc_get_by_tag_value_subset(): + conn = boto.connect_vpc() + vpc1 = conn.create_vpc("10.0.0.0/16") + vpc2 = conn.create_vpc("10.0.0.0/16") + vpc3 = conn.create_vpc("10.0.0.0/24") + + vpc1.add_tag('Name', 'TestVPC') + vpc1.add_tag('Key', 'TestVPC2') + vpc2.add_tag('Name', 'TestVPC') + vpc2.add_tag('Key', 'TestVPC2') + + vpcs = conn.get_all_vpcs(filters={'tag-value': ['TestVPC', 'TestVPC2']}) + vpcs.should.have.length_of(2) + vpc_ids = tuple(map(lambda v: v.id, vpcs)) + vpc1.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids) \ No newline at end of file From efa687f41d03bb5d38013d5a8d6b4027e0f94239 Mon Sep 17 00:00:00 2001 From: Omer Katz Date: Wed, 1 Oct 2014 16:17:56 +0300 Subject: [PATCH 3/3] Added tag filters to some of the entities. --- moto/ec2/models.py | 80 +++++++++++++++++-------------------- tests/test_ec2/test_vpcs.py | 2 +- 2 files changed, 37 insertions(+), 45 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index e11a668f9..03dc29c5b 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -88,11 +88,13 @@ class TaggedEC2Instance(object): tags = self.get_tags() if filter_name.startswith('tag:'): - tagname = filter_name.split('tag:')[1] + tagname = filter_name.replace('tag:', '', 1) for tag in tags: if tag['key'] == tagname: return tag['value'] + return '' + if filter_name == 'tag-key': return [tag['key'] for tag in tags] @@ -621,13 +623,14 @@ class Ami(TaggedEC2Instance): return self.id elif filter_name == 'state': return self.state - elif filter_name.startswith('tag:'): - tag_name = filter_name.replace('tag:', '', 1) - tags = dict((tag['key'], tag['value']) for tag in self.get_tags()) - return tags.get(tag_name) - else: + + filter_value = super(Ami, self).get_filter_value(filter_name) + + if filter_value is None: ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeImages".format(filter_name)) + return filter_value + class AmiBackend(object): def __init__(self): @@ -645,9 +648,8 @@ class AmiBackend(object): def describe_images(self, ami_ids=(), filters=None): if filters: images = self.amis.values() - for (_filter, _filter_value) in filters.items(): - images = [ ami for ami in images if ami.get_filter_value(_filter) in _filter_value ] - return images + + return generic_filter(filters, images) else: images = [] for ami_id in ami_ids: @@ -1166,10 +1168,7 @@ class VPC(TaggedEC2Instance): filter_value = super(VPC, self).get_filter_value(filter_name) if filter_value is None: - msg = "The filter '{0}' for DescribeVPCs has not been" \ - " implemented in Moto yet. Feel free to open an issue at" \ - " https://github.com/spulec/moto/issues".format(filter_name) - raise NotImplementedError(msg) + ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeVPCs".format(filter_name)) return filter_value @@ -1348,11 +1347,13 @@ class Subnet(TaggedEC2Instance): return self.vpc_id elif filter_name == 'subnet-id': return self.id - else: - msg = "The filter '{0}' for DescribeSubnets has not been" \ - " implemented in Moto yet. Feel free to open an issue at" \ - " https://github.com/spulec/moto/issues".format(filter_name) - raise NotImplementedError(msg) + + filter_value = super(Subnet, self).get_filter_value(filter_name) + + if filter_value is None: + ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSubnets".format(filter_name)) + + return filter_value class SubnetBackend(object): @@ -1376,11 +1377,7 @@ class SubnetBackend(object): def get_all_subnets(self, filters=None): subnets = self.subnets.values() - if filters: - for (_filter, _filter_value) in filters.items(): - subnets = [ subnet for subnet in subnets if subnet.get_filter_value(_filter) in _filter_value ] - - return subnets + return generic_filter(filters, subnets) def delete_subnet(self, subnet_id): deleted = self.subnets.pop(subnet_id, None) @@ -1419,7 +1416,7 @@ class SubnetRouteTableAssociationBackend(object): return subnet_association -class RouteTable(object): +class RouteTable(TaggedEC2Instance): def __init__(self, route_table_id, vpc_id, main=False): self.id = route_table_id self.vpc_id = vpc_id @@ -1452,11 +1449,13 @@ class RouteTable(object): return 'false' elif filter_name == "vpc-id": return self.vpc_id - else: - msg = "The filter '{0}' for DescribeRouteTables has not been" \ - " implemented in Moto yet. Feel free to open an issue at" \ - " https://github.com/spulec/moto/issues".format(filter_name) - raise NotImplementedError(msg) + + filter_value = super(RouteTable, self).get_filter_value(filter_name) + + if filter_value is None: + ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeRouteTables".format(filter_name)) + + return filter_value class RouteTableBackend(object): @@ -1490,11 +1489,7 @@ class RouteTableBackend(object): invalid_id = list(set(route_table_ids).difference(set([route_table.id for route_table in route_tables])))[0] raise InvalidRouteTableIdError(invalid_id) - if filters: - for (_filter, _filter_value) in filters.items(): - route_tables = [ route_table for route_table in route_tables if route_table.get_filter_value(_filter) in _filter_value ] - - return route_tables + return generic_filter(filters, route_tables) def delete_route_table(self, route_table_id): deleted = self.route_tables.pop(route_table_id, None) @@ -1720,13 +1715,14 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Instance): def get_filter_value(self, filter_name): if filter_name == 'state': return self.state - elif filter_name.startswith('tag:'): - tag_name = filter_name.replace('tag:', '', 1) - tags = dict((tag['key'], tag['value']) for tag in self.get_tags()) - return tags.get(tag_name) - else: + + filter_value = super(SpotInstanceRequest, self).get_filter_value(filter_name) + + if filter_value is None: ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name)) + return filter_value + @six.add_metaclass(Model) class SpotRequestBackend(object): @@ -1756,11 +1752,7 @@ class SpotRequestBackend(object): def describe_spot_instance_requests(self, filters=None): requests = self.spot_instance_requests.values() - if filters: - for (_filter, _filter_value) in filters.items(): - requests = [ request for request in requests if request.get_filter_value(_filter) in _filter_value ] - - return requests + return generic_filter(filters, requests) def cancel_spot_instance_requests(self, request_ids): requests = [] diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index 9f3bc8351..ae18e4ce9 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals # Ensure 'assert_raises' context manager support for Python 2.6 -# import tests.backport_assert_raises +import tests.backport_assert_raises from nose.tools import assert_raises import boto