EC2: KeyPairs now support tags (#6919)

This commit is contained in:
Bert Blommers 2023-10-15 21:40:34 +00:00 committed by GitHub
parent 5794b619e2
commit 2021e564fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 116 additions and 14 deletions

View File

@ -1,8 +1,7 @@
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from moto.core import BaseModel from .core import TaggedEC2Resource
from ..exceptions import ( from ..exceptions import (
FilterNotImplementedError,
InvalidKeyPairNameError, InvalidKeyPairNameError,
InvalidKeyPairDuplicateError, InvalidKeyPairDuplicateError,
InvalidKeyPairFormatError, InvalidKeyPairFormatError,
@ -18,38 +17,55 @@ from ..utils import (
from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow
class KeyPair(BaseModel): class KeyPair(TaggedEC2Resource):
def __init__(self, name: str, fingerprint: str, material: str): def __init__(
self,
name: str,
fingerprint: str,
material: str,
tags: Dict[str, str],
ec2_backend: Any,
):
self.id = random_key_pair_id() self.id = random_key_pair_id()
self.name = name self.name = name
self.fingerprint = fingerprint self.fingerprint = fingerprint
self.material = material self.material = material
self.create_time = utcnow() self.create_time = utcnow()
self.ec2_backend = ec2_backend
self.add_tags(tags or {})
@property @property
def created_iso_8601(self) -> str: def created_iso_8601(self) -> str:
return iso_8601_datetime_with_milliseconds(self.create_time) return iso_8601_datetime_with_milliseconds(self.create_time)
def get_filter_value(self, filter_name: str) -> str: def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
) -> str:
if filter_name == "key-name": if filter_name == "key-name":
return self.name return self.name
elif filter_name == "fingerprint": elif filter_name == "fingerprint":
return self.fingerprint return self.fingerprint
else: else:
raise FilterNotImplementedError(filter_name, "DescribeKeyPairs") return super().get_filter_value(filter_name, "DescribeKeyPairs")
class KeyPairBackend: class KeyPairBackend:
def __init__(self) -> None: def __init__(self) -> None:
self.keypairs: Dict[str, KeyPair] = {} self.keypairs: Dict[str, KeyPair] = {}
def create_key_pair(self, name: str, key_type: str = "rsa") -> KeyPair: def create_key_pair(
self, name: str, key_type: str, tags: Dict[str, str]
) -> KeyPair:
if name in self.keypairs: if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name) raise InvalidKeyPairDuplicateError(name)
if key_type == "ed25519": if key_type == "ed25519":
keypair = KeyPair(name, **random_ed25519_key_pair()) keypair = KeyPair(
name, **random_ed25519_key_pair(), tags=tags, ec2_backend=self
)
else: else:
keypair = KeyPair(name, **random_rsa_key_pair()) keypair = KeyPair(
name, **random_rsa_key_pair(), tags=tags, ec2_backend=self
)
self.keypairs[name] = keypair self.keypairs[name] = keypair
return keypair return keypair
@ -77,7 +93,9 @@ class KeyPairBackend:
else: else:
return results return results
def import_key_pair(self, key_name: str, public_key_material: str) -> KeyPair: def import_key_pair(
self, key_name: str, public_key_material: str, tags: Dict[str, str]
) -> KeyPair:
if key_name in self.keypairs: if key_name in self.keypairs:
raise InvalidKeyPairDuplicateError(key_name) raise InvalidKeyPairDuplicateError(key_name)
@ -88,7 +106,11 @@ class KeyPairBackend:
fingerprint = public_key_fingerprint(public_key) fingerprint = public_key_fingerprint(public_key)
keypair = KeyPair( keypair = KeyPair(
key_name, material=public_key_material, fingerprint=fingerprint key_name,
material=public_key_material,
fingerprint=fingerprint,
tags=tags,
ec2_backend=self,
) )
self.keypairs[key_name] = keypair self.keypairs[key_name] = keypair
return keypair return keypair

View File

@ -5,8 +5,9 @@ class KeyPairs(EC2BaseResponse):
def create_key_pair(self) -> str: def create_key_pair(self) -> str:
name = self._get_param("KeyName") name = self._get_param("KeyName")
key_type = self._get_param("KeyType") key_type = self._get_param("KeyType")
tags = self._parse_tag_specification("key-pair").get("key-pair", {})
self.error_on_dryrun() self.error_on_dryrun()
keypair = self.ec2_backend.create_key_pair(name, key_type) keypair = self.ec2_backend.create_key_pair(name, key_type, tags=tags)
return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair) return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair)
def delete_key_pair(self) -> str: def delete_key_pair(self) -> str:
@ -26,9 +27,10 @@ class KeyPairs(EC2BaseResponse):
def import_key_pair(self) -> str: def import_key_pair(self) -> str:
name = self._get_param("KeyName") name = self._get_param("KeyName")
material = self._get_param("PublicKeyMaterial") material = self._get_param("PublicKeyMaterial")
tags = self._parse_tag_specification("key-pair").get("key-pair", {})
self.error_on_dryrun() self.error_on_dryrun()
keypair = self.ec2_backend.import_key_pair(name, material) keypair = self.ec2_backend.import_key_pair(name, material, tags=tags)
return self.response_template(IMPORT_KEYPAIR_RESPONSE).render(keypair=keypair) return self.response_template(IMPORT_KEYPAIR_RESPONSE).render(keypair=keypair)
@ -41,6 +43,16 @@ DESCRIBE_KEY_PAIRS_RESPONSE = """<DescribeKeyPairsResponse xmlns="http://ec2.ama
<keyPairId>{{ keypair.id }}</keyPairId> <keyPairId>{{ keypair.id }}</keyPairId>
<keyName>{{ keypair.name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
{% if keypair.get_tags() %}
<tagSet>
{% for tag in keypair.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</item> </item>
{% endfor %} {% endfor %}
</keySet> </keySet>
@ -52,6 +64,16 @@ CREATE_KEY_PAIR_RESPONSE = """<CreateKeyPairResponse xmlns="http://ec2.amazonaws
<keyName>{{ keypair.name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
<keyMaterial>{{ keypair.material }}</keyMaterial> <keyMaterial>{{ keypair.material }}</keyMaterial>
{% if keypair.get_tags() %}
<tagSet>
{% for tag in keypair.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</CreateKeyPairResponse>""" </CreateKeyPairResponse>"""
@ -66,4 +88,14 @@ IMPORT_KEYPAIR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<keyPairId>{{ keypair.id }}</keyPairId> <keyPairId>{{ keypair.id }}</keyPairId>
<keyName>{{ keypair.name }}</keyName> <keyName>{{ keypair.name }}</keyName>
<keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint> <keyFingerprint>{{ keypair.fingerprint }}</keyFingerprint>
{% if keypair.get_tags() %}
<tagSet>
{% for tag in keypair.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</ImportKeyPairResponse>""" </ImportKeyPairResponse>"""

View File

@ -254,3 +254,51 @@ def test_key_pair_filters_boto3():
Filters=[{"Name": "fingerprint", "Values": [kp3.key_fingerprint]}] Filters=[{"Name": "fingerprint", "Values": [kp3.key_fingerprint]}]
)["KeyPairs"] )["KeyPairs"]
assert set([kp["KeyName"] for kp in kp_by_name]) == set([kp3.name]) assert set([kp["KeyName"] for kp in kp_by_name]) == set([kp3.name])
@mock_ec2
def test_key_pair_with_tags():
client = boto3.client("ec2", "us-east-1")
key_name_1 = str(uuid4())[0:6]
key_name_2 = str(uuid4())[0:6]
key_name_3 = str(uuid4())[0:6]
key_name_4 = str(uuid4())[0:6]
resp = client.create_key_pair(
KeyName=key_name_1,
TagSpecifications=[
{"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val1"}]}
],
)
assert resp["Tags"] == [{"Key": "key", "Value": "val1"}]
kp1 = resp["KeyPairId"]
kp2 = client.create_key_pair(
KeyName=key_name_2,
TagSpecifications=[
{"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val2"}]}
],
)["KeyPairId"]
assert "Tags" not in client.create_key_pair(KeyName=key_name_3)
key_pairs = client.describe_key_pairs(
Filters=[{"Name": "tag-key", "Values": ["key"]}]
)["KeyPairs"]
assert [kp["KeyPairId"] for kp in key_pairs] == [kp1, kp2]
key_pairs = client.describe_key_pairs(
Filters=[{"Name": "tag:key", "Values": ["val1"]}]
)["KeyPairs"]
assert len(key_pairs) == 1
assert key_pairs[0]["KeyPairId"] == kp1
assert key_pairs[0]["Tags"] == [{"Key": "key", "Value": "val1"}]
resp = client.import_key_pair(
KeyName=key_name_4,
PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH,
TagSpecifications=[
{"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val4"}]}
],
)
assert resp["Tags"] == [{"Key": "key", "Value": "val4"}]