Merge pull request #221 from thedrow/topic/filters

More tagging filters and refactorings
This commit is contained in:
Steve Pulec 2014-10-01 09:43:08 -04:00
commit 0a99aae99f
3 changed files with 147 additions and 53 deletions

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import six
import copy import copy
import itertools import itertools
from collections import defaultdict from collections import defaultdict
import six
import boto import boto
from boto.ec2.instance import Instance as BotoInstance, Reservation from boto.ec2.instance import Instance as BotoInstance, Reservation
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
@ -70,7 +70,7 @@ from .utils import (
random_volume_id, random_volume_id,
random_vpc_id, random_vpc_id,
random_vpc_peering_connection_id, random_vpc_peering_connection_id,
) generic_filter)
class InstanceState(object): class InstanceState(object):
@ -88,11 +88,19 @@ class TaggedEC2Instance(object):
tags = self.get_tags() tags = self.get_tags()
if filter_name.startswith('tag:'): if filter_name.startswith('tag:'):
tagname = filter_name.split('tag:')[1] tagname = filter_name.replace('tag:', '', 1)
for tag in tags: for tag in tags:
if tag['key'] == tagname: if tag['key'] == tagname:
return tag['value'] return tag['value']
return ''
if filter_name == 'tag-key':
return [tag['key'] for tag in tags]
if filter_name == 'tag-value':
return [tag['value'] for tag in tags]
class NetworkInterface(object): class NetworkInterface(object):
def __init__(self, subnet, private_ip_address, device_index=0, public_ip_auto_assign=True, group_ids=None): def __init__(self, subnet, private_ip_address, device_index=0, public_ip_auto_assign=True, group_ids=None):
@ -615,13 +623,14 @@ class Ami(TaggedEC2Instance):
return self.id return self.id
elif filter_name == 'state': elif filter_name == 'state':
return self.state return self.state
elif filter_name.startswith('tag:'):
tag_name = filter_name.replace('tag:', '', 1) filter_value = super(Ami, self).get_filter_value(filter_name)
tags = dict((tag['key'], tag['value']) for tag in self.get_tags())
return tags.get(tag_name) if filter_value is None:
else:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeImages".format(filter_name)) ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeImages".format(filter_name))
return filter_value
class AmiBackend(object): class AmiBackend(object):
def __init__(self): def __init__(self):
@ -639,9 +648,8 @@ class AmiBackend(object):
def describe_images(self, ami_ids=(), filters=None): def describe_images(self, ami_ids=(), filters=None):
if filters: if filters:
images = self.amis.values() images = self.amis.values()
for (_filter, _filter_value) in filters.items():
images = [ ami for ami in images if ami.get_filter_value(_filter) in _filter_value ] return generic_filter(filters, images)
return images
else: else:
images = [] images = []
for ami_id in ami_ids: for ami_id in ami_ids:
@ -1159,11 +1167,8 @@ class VPC(TaggedEC2Instance):
filter_value = super(VPC, self).get_filter_value(filter_name) filter_value = super(VPC, self).get_filter_value(filter_name)
if not filter_value: if filter_value is None:
msg = "The filter '{0}' for DescribeVPCs has not been" \ ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeVPCs".format(filter_name))
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name)
raise NotImplementedError(msg)
return filter_value return filter_value
@ -1198,11 +1203,7 @@ class VPCBackend(object):
else: else:
vpcs = self.vpcs.values() vpcs = self.vpcs.values()
if filters: return generic_filter(filters, vpcs)
for (_filter, _filter_value) in filters.items():
vpcs = [ vpc for vpc in vpcs if vpc.get_filter_value(_filter) in _filter_value ]
return vpcs
def delete_vpc(self, vpc_id): def delete_vpc(self, vpc_id):
# Delete route table if only main route table remains. # Delete route table if only main route table remains.
@ -1346,11 +1347,13 @@ class Subnet(TaggedEC2Instance):
return self.vpc_id return self.vpc_id
elif filter_name == 'subnet-id': elif filter_name == 'subnet-id':
return self.id return self.id
else:
msg = "The filter '{0}' for DescribeSubnets has not been" \ filter_value = super(Subnet, self).get_filter_value(filter_name)
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name) if filter_value is None:
raise NotImplementedError(msg) ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSubnets".format(filter_name))
return filter_value
class SubnetBackend(object): class SubnetBackend(object):
@ -1374,11 +1377,7 @@ class SubnetBackend(object):
def get_all_subnets(self, filters=None): def get_all_subnets(self, filters=None):
subnets = self.subnets.values() subnets = self.subnets.values()
if filters: return generic_filter(filters, subnets)
for (_filter, _filter_value) in filters.items():
subnets = [ subnet for subnet in subnets if subnet.get_filter_value(_filter) in _filter_value ]
return subnets
def delete_subnet(self, subnet_id): def delete_subnet(self, subnet_id):
deleted = self.subnets.pop(subnet_id, None) deleted = self.subnets.pop(subnet_id, None)
@ -1417,7 +1416,7 @@ class SubnetRouteTableAssociationBackend(object):
return subnet_association return subnet_association
class RouteTable(object): class RouteTable(TaggedEC2Instance):
def __init__(self, route_table_id, vpc_id, main=False): def __init__(self, route_table_id, vpc_id, main=False):
self.id = route_table_id self.id = route_table_id
self.vpc_id = vpc_id self.vpc_id = vpc_id
@ -1450,11 +1449,13 @@ class RouteTable(object):
return 'false' return 'false'
elif filter_name == "vpc-id": elif filter_name == "vpc-id":
return self.vpc_id return self.vpc_id
else:
msg = "The filter '{0}' for DescribeRouteTables has not been" \ filter_value = super(RouteTable, self).get_filter_value(filter_name)
" implemented in Moto yet. Feel free to open an issue at" \
" https://github.com/spulec/moto/issues".format(filter_name) if filter_value is None:
raise NotImplementedError(msg) ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeRouteTables".format(filter_name))
return filter_value
class RouteTableBackend(object): class RouteTableBackend(object):
@ -1488,11 +1489,7 @@ class RouteTableBackend(object):
invalid_id = list(set(route_table_ids).difference(set([route_table.id for route_table in route_tables])))[0] invalid_id = list(set(route_table_ids).difference(set([route_table.id for route_table in route_tables])))[0]
raise InvalidRouteTableIdError(invalid_id) raise InvalidRouteTableIdError(invalid_id)
if filters: return generic_filter(filters, route_tables)
for (_filter, _filter_value) in filters.items():
route_tables = [ route_table for route_table in route_tables if route_table.get_filter_value(_filter) in _filter_value ]
return route_tables
def delete_route_table(self, route_table_id): def delete_route_table(self, route_table_id):
deleted = self.route_tables.pop(route_table_id, None) deleted = self.route_tables.pop(route_table_id, None)
@ -1718,13 +1715,14 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Instance):
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
if filter_name == 'state': if filter_name == 'state':
return self.state return self.state
elif filter_name.startswith('tag:'):
tag_name = filter_name.replace('tag:', '', 1) filter_value = super(SpotInstanceRequest, self).get_filter_value(filter_name)
tags = dict((tag['key'], tag['value']) for tag in self.get_tags())
return tags.get(tag_name) if filter_value is None:
else:
ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name)) ec2_backend.raise_not_implemented_error("The filter '{0}' for DescribeSpotInstanceRequests".format(filter_name))
return filter_value
@six.add_metaclass(Model) @six.add_metaclass(Model)
class SpotRequestBackend(object): class SpotRequestBackend(object):
@ -1754,11 +1752,7 @@ class SpotRequestBackend(object):
def describe_spot_instance_requests(self, filters=None): def describe_spot_instance_requests(self, filters=None):
requests = self.spot_instance_requests.values() requests = self.spot_instance_requests.values()
if filters: return generic_filter(filters, requests)
for (_filter, _filter_value) in filters.items():
requests = [ request for request in requests if request.get_filter_value(_filter) in _filter_value ]
return requests
def cancel_spot_instance_requests(self, request_ids): def cancel_spot_instance_requests(self, request_ids):
requests = [] requests = []

View File

@ -278,6 +278,27 @@ def filter_reservations(reservations, filter_dict):
return result return result
def is_filter_matching(obj, filter, filter_value):
value = obj.get_filter_value(filter)
if isinstance(value, six.string_types):
return value in filter_value
try:
value = set(value)
return (value and value.issubset(filter_value)) or value.issuperset(filter_value)
except TypeError:
return value in filter_value
def generic_filter(filters, objects):
if filters:
for (_filter, _filter_value) in filters.items():
objects = [obj for obj in objects if is_filter_matching(obj, _filter, _filter_value)]
return objects
# not really random ( http://xkcd.com/221/ ) # not really random ( http://xkcd.com/221/ )
def random_key_pair(): def random_key_pair():
return { return {

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# Ensure 'assert_raises' context manager support for Python 2.6 # Ensure 'assert_raises' context manager support for Python 2.6
#import tests.backport_assert_raises import tests.backport_assert_raises
from nose.tools import assert_raises from nose.tools import assert_raises
import boto import boto
@ -129,3 +129,82 @@ def test_vpc_get_by_tag():
vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids) vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids)
@mock_ec2
def test_vpc_get_by_tag_key_superset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")
vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Key', 'TestVPC2')
vpcs = conn.get_all_vpcs(filters={'tag-key': 'Name'})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)
@mock_ec2
def test_vpc_get_by_tag_key_subset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")
vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Test', 'TestVPC2')
vpcs = conn.get_all_vpcs(filters={'tag-key': ['Name', 'Key']})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)
@mock_ec2
def test_vpc_get_by_tag_value_superset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")
vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpc3.add_tag('Key', 'TestVPC2')
vpcs = conn.get_all_vpcs(filters={'tag-value': 'TestVPC'})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)
@mock_ec2
def test_vpc_get_by_tag_value_subset():
conn = boto.connect_vpc()
vpc1 = conn.create_vpc("10.0.0.0/16")
vpc2 = conn.create_vpc("10.0.0.0/16")
vpc3 = conn.create_vpc("10.0.0.0/24")
vpc1.add_tag('Name', 'TestVPC')
vpc1.add_tag('Key', 'TestVPC2')
vpc2.add_tag('Name', 'TestVPC')
vpc2.add_tag('Key', 'TestVPC2')
vpcs = conn.get_all_vpcs(filters={'tag-value': ['TestVPC', 'TestVPC2']})
vpcs.should.have.length_of(2)
vpc_ids = tuple(map(lambda v: v.id, vpcs))
vpc1.id.should.be.within(vpc_ids)
vpc2.id.should.be.within(vpc_ids)