From 2021e564fafcdaa701b53de49bd580c8691a5fcc Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 15 Oct 2023 21:40:34 +0000 Subject: [PATCH] EC2: KeyPairs now support tags (#6919) --- moto/ec2/models/key_pairs.py | 46 ++++++++++++++++++++++-------- moto/ec2/responses/key_pairs.py | 36 ++++++++++++++++++++++-- tests/test_ec2/test_key_pairs.py | 48 ++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 14 deletions(-) diff --git a/moto/ec2/models/key_pairs.py b/moto/ec2/models/key_pairs.py index 5ab46f529..dcdd1d357 100644 --- a/moto/ec2/models/key_pairs.py +++ b/moto/ec2/models/key_pairs.py @@ -1,8 +1,7 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from moto.core import BaseModel +from .core import TaggedEC2Resource from ..exceptions import ( - FilterNotImplementedError, InvalidKeyPairNameError, InvalidKeyPairDuplicateError, InvalidKeyPairFormatError, @@ -18,38 +17,55 @@ from ..utils import ( from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow -class KeyPair(BaseModel): - def __init__(self, name: str, fingerprint: str, material: str): +class KeyPair(TaggedEC2Resource): + def __init__( + self, + name: str, + fingerprint: str, + material: str, + tags: Dict[str, str], + ec2_backend: Any, + ): self.id = random_key_pair_id() self.name = name self.fingerprint = fingerprint self.material = material self.create_time = utcnow() + self.ec2_backend = ec2_backend + self.add_tags(tags or {}) @property def created_iso_8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.create_time) - def get_filter_value(self, filter_name: str) -> str: + def get_filter_value( + self, filter_name: str, method_name: Optional[str] = None + ) -> str: if filter_name == "key-name": return self.name elif filter_name == "fingerprint": return self.fingerprint else: - raise FilterNotImplementedError(filter_name, "DescribeKeyPairs") + return super().get_filter_value(filter_name, "DescribeKeyPairs") class KeyPairBackend: def __init__(self) -> None: self.keypairs: Dict[str, KeyPair] = {} - def create_key_pair(self, name: str, key_type: str = "rsa") -> KeyPair: + def create_key_pair( + self, name: str, key_type: str, tags: Dict[str, str] + ) -> KeyPair: if name in self.keypairs: raise InvalidKeyPairDuplicateError(name) if key_type == "ed25519": - keypair = KeyPair(name, **random_ed25519_key_pair()) + keypair = KeyPair( + name, **random_ed25519_key_pair(), tags=tags, ec2_backend=self + ) else: - keypair = KeyPair(name, **random_rsa_key_pair()) + keypair = KeyPair( + name, **random_rsa_key_pair(), tags=tags, ec2_backend=self + ) self.keypairs[name] = keypair return keypair @@ -77,7 +93,9 @@ class KeyPairBackend: else: return results - def import_key_pair(self, key_name: str, public_key_material: str) -> KeyPair: + def import_key_pair( + self, key_name: str, public_key_material: str, tags: Dict[str, str] + ) -> KeyPair: if key_name in self.keypairs: raise InvalidKeyPairDuplicateError(key_name) @@ -88,7 +106,11 @@ class KeyPairBackend: fingerprint = public_key_fingerprint(public_key) keypair = KeyPair( - key_name, material=public_key_material, fingerprint=fingerprint + key_name, + material=public_key_material, + fingerprint=fingerprint, + tags=tags, + ec2_backend=self, ) self.keypairs[key_name] = keypair return keypair diff --git a/moto/ec2/responses/key_pairs.py b/moto/ec2/responses/key_pairs.py index b0b80f01c..be24226d7 100644 --- a/moto/ec2/responses/key_pairs.py +++ b/moto/ec2/responses/key_pairs.py @@ -5,8 +5,9 @@ class KeyPairs(EC2BaseResponse): def create_key_pair(self) -> str: name = self._get_param("KeyName") key_type = self._get_param("KeyType") + tags = self._parse_tag_specification("key-pair").get("key-pair", {}) self.error_on_dryrun() - keypair = self.ec2_backend.create_key_pair(name, key_type) + keypair = self.ec2_backend.create_key_pair(name, key_type, tags=tags) return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair) def delete_key_pair(self) -> str: @@ -26,9 +27,10 @@ class KeyPairs(EC2BaseResponse): def import_key_pair(self) -> str: name = self._get_param("KeyName") material = self._get_param("PublicKeyMaterial") + tags = self._parse_tag_specification("key-pair").get("key-pair", {}) self.error_on_dryrun() - keypair = self.ec2_backend.import_key_pair(name, material) + keypair = self.ec2_backend.import_key_pair(name, material, tags=tags) return self.response_template(IMPORT_KEYPAIR_RESPONSE).render(keypair=keypair) @@ -41,6 +43,16 @@ DESCRIBE_KEY_PAIRS_RESPONSE = """ {{ keypair.id }} {{ keypair.name }} {{ keypair.fingerprint }} + {% if keypair.get_tags() %} + + {% for tag in keypair.get_tags() %} + + {{ tag.key }} + {{ tag.value }} + + {% endfor %} + + {% endif %} """ diff --git a/tests/test_ec2/test_key_pairs.py b/tests/test_ec2/test_key_pairs.py index 45c659cc9..cf684fdf1 100644 --- a/tests/test_ec2/test_key_pairs.py +++ b/tests/test_ec2/test_key_pairs.py @@ -254,3 +254,51 @@ def test_key_pair_filters_boto3(): Filters=[{"Name": "fingerprint", "Values": [kp3.key_fingerprint]}] )["KeyPairs"] assert set([kp["KeyName"] for kp in kp_by_name]) == set([kp3.name]) + + +@mock_ec2 +def test_key_pair_with_tags(): + client = boto3.client("ec2", "us-east-1") + + key_name_1 = str(uuid4())[0:6] + key_name_2 = str(uuid4())[0:6] + key_name_3 = str(uuid4())[0:6] + key_name_4 = str(uuid4())[0:6] + resp = client.create_key_pair( + KeyName=key_name_1, + TagSpecifications=[ + {"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val1"}]} + ], + ) + assert resp["Tags"] == [{"Key": "key", "Value": "val1"}] + kp1 = resp["KeyPairId"] + + kp2 = client.create_key_pair( + KeyName=key_name_2, + TagSpecifications=[ + {"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val2"}]} + ], + )["KeyPairId"] + + assert "Tags" not in client.create_key_pair(KeyName=key_name_3) + + key_pairs = client.describe_key_pairs( + Filters=[{"Name": "tag-key", "Values": ["key"]}] + )["KeyPairs"] + assert [kp["KeyPairId"] for kp in key_pairs] == [kp1, kp2] + + key_pairs = client.describe_key_pairs( + Filters=[{"Name": "tag:key", "Values": ["val1"]}] + )["KeyPairs"] + assert len(key_pairs) == 1 + assert key_pairs[0]["KeyPairId"] == kp1 + assert key_pairs[0]["Tags"] == [{"Key": "key", "Value": "val1"}] + + resp = client.import_key_pair( + KeyName=key_name_4, + PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH, + TagSpecifications=[ + {"ResourceType": "key-pair", "Tags": [{"Key": "key", "Value": "val4"}]} + ], + ) + assert resp["Tags"] == [{"Key": "key", "Value": "val4"}]