Added keypair filtering

This commit is contained in:
Nuwan Goonasekera 2017-09-14 14:59:13 +05:30
parent ace54787c0
commit d0c610c5ac
3 changed files with 58 additions and 30 deletions

View File

@ -795,16 +795,29 @@ class InstanceBackend(object):
return reservations return reservations
class KeyPair(object):
def __init__(self, name, fingerprint, material):
self.name = name
self.fingerprint = fingerprint
self.material = material
def get_filter_value(self, filter_name):
if filter_name == 'key-name':
return self.name
elif filter_name == 'fingerprint':
return self.fingerprint
class KeyPairBackend(object): class KeyPairBackend(object):
def __init__(self): def __init__(self):
self.keypairs = defaultdict(dict) self.keypairs = {}
super(KeyPairBackend, self).__init__() super(KeyPairBackend, self).__init__()
def create_key_pair(self, name): def create_key_pair(self, name):
if name in self.keypairs: if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name) raise InvalidKeyPairDuplicateError(name)
self.keypairs[name] = keypair = random_key_pair() keypair = KeyPair(name, **random_key_pair())
keypair['name'] = name self.keypairs[name] = keypair
return keypair return keypair
def delete_key_pair(self, name): def delete_key_pair(self, name):
@ -812,24 +825,27 @@ class KeyPairBackend(object):
self.keypairs.pop(name) self.keypairs.pop(name)
return True return True
def describe_key_pairs(self, filter_names=None): def describe_key_pairs(self, key_names=None, filters=None):
results = [] results = []
for name, keypair in self.keypairs.items(): if key_names:
if not filter_names or name in filter_names: results = [keypair for keypair in self.keypairs.values()
keypair['name'] = name if keypair.name in key_names]
results.append(keypair) if len(key_names) > len(results):
unknown_keys = set(key_names) - set(results)
raise InvalidKeyPairNameError(unknown_keys)
else:
results = self.keypairs.values()
# TODO: Trim error message down to specific invalid name. if filters:
if filter_names and len(filter_names) > len(results): return generic_filter(filters, results)
raise InvalidKeyPairNameError(filter_names) else:
return results
return results
def import_key_pair(self, key_name, public_key_material): def import_key_pair(self, key_name, public_key_material):
if key_name in self.keypairs: if key_name in self.keypairs:
raise InvalidKeyPairDuplicateError(key_name) raise InvalidKeyPairDuplicateError(key_name)
self.keypairs[key_name] = keypair = random_key_pair() keypair = KeyPair(key_name, **random_key_pair())
keypair['name'] = key_name self.keypairs[key_name] = keypair
return keypair return keypair

View File

@ -11,7 +11,7 @@ class KeyPairs(BaseResponse):
if self.is_not_dryrun('CreateKeyPair'): if self.is_not_dryrun('CreateKeyPair'):
keypair = self.ec2_backend.create_key_pair(name) keypair = self.ec2_backend.create_key_pair(name)
template = self.response_template(CREATE_KEY_PAIR_RESPONSE) template = self.response_template(CREATE_KEY_PAIR_RESPONSE)
return template.render(**keypair) return template.render(keypair=keypair)
def delete_key_pair(self): def delete_key_pair(self):
name = self.querystring.get('KeyName')[0] name = self.querystring.get('KeyName')[0]
@ -23,11 +23,7 @@ class KeyPairs(BaseResponse):
def describe_key_pairs(self): def describe_key_pairs(self):
names = keypair_names_from_querystring(self.querystring) names = keypair_names_from_querystring(self.querystring)
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
if len(filters) > 0: keypairs = self.ec2_backend.describe_key_pairs(names, filters)
raise NotImplementedError(
'Using filters in KeyPairs.describe_key_pairs is not yet implemented')
keypairs = self.ec2_backend.describe_key_pairs(names)
template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE) template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE)
return template.render(keypairs=keypairs) return template.render(keypairs=keypairs)
@ -37,7 +33,7 @@ class KeyPairs(BaseResponse):
if self.is_not_dryrun('ImportKeyPair'): if self.is_not_dryrun('ImportKeyPair'):
keypair = self.ec2_backend.import_key_pair(name, material) keypair = self.ec2_backend.import_key_pair(name, material)
template = self.response_template(IMPORT_KEYPAIR_RESPONSE) template = self.response_template(IMPORT_KEYPAIR_RESPONSE)
return template.render(**keypair) return template.render(keypair=keypair)
DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -54,12 +50,9 @@ DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.ama
CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<keyName>{{ name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
{{ fingerprint }} <keyMaterial>{{ keypair.material }}</keyMaterial>
</keyFingerprint>
<keyMaterial>{{ material }}
</keyMaterial>
</CreateKeyPairResponse>""" </CreateKeyPairResponse>"""
@ -71,6 +64,6 @@ DELETE_KEY_PAIR_RESPONSE = """<DeleteKeyPairResponse xmlns="http://ec2.amazonaws
IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?> IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ImportKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <ImportKeyPairResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>471f9fdd-8fe2-4a84-86b0-bd3d3e350979</requestId> <requestId>471f9fdd-8fe2-4a84-86b0-bd3d3e350979</requestId>
<keyName>{{ name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ fingerprint }}</keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
</ImportKeyPairResponse>""" </ImportKeyPairResponse>"""

View File

@ -130,3 +130,22 @@ def test_key_pairs_import_exist():
cm.exception.code.should.equal('InvalidKeyPair.Duplicate') cm.exception.code.should.equal('InvalidKeyPair.Duplicate')
cm.exception.status.should.equal(400) cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_key_pair_filters():
conn = boto.connect_ec2('the_key', 'the_secret')
_ = conn.create_key_pair('kpfltr1')
kp2 = conn.create_key_pair('kpfltr2')
kp3 = conn.create_key_pair('kpfltr3')
kp_by_name = conn.get_all_key_pairs(
filters={'key-name': 'kpfltr2'})
set([kp.name for kp in kp_by_name]
).should.equal(set([kp2.name]))
kp_by_name = conn.get_all_key_pairs(
filters={'fingerprint': kp3.fingerprint})
set([kp.name for kp in kp_by_name]
).should.equal(set([kp3.name]))