From 760e28bb7f6934a4c1ee664eee2449c216c06d9b Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Thu, 12 Oct 2023 23:43:41 +0530 Subject: [PATCH] EC2: Add support for ED25519 SSH key pairs (#6904) --- moto/ec2/models/key_pairs.py | 19 +++++++----- moto/ec2/responses/key_pairs.py | 4 +-- moto/ec2/utils.py | 51 +++++++++++++++++++++++++------- tests/test_ec2/helpers.py | 25 +++++++++++----- tests/test_ec2/test_key_pairs.py | 48 +++++++++++++++++------------- tests/test_ec2/test_utils.py | 9 ++++-- 6 files changed, 104 insertions(+), 52 deletions(-) diff --git a/moto/ec2/models/key_pairs.py b/moto/ec2/models/key_pairs.py index 7d29018ff..5ab46f529 100644 --- a/moto/ec2/models/key_pairs.py +++ b/moto/ec2/models/key_pairs.py @@ -8,11 +8,12 @@ from ..exceptions import ( InvalidKeyPairFormatError, ) from ..utils import ( - random_key_pair, - rsa_public_key_fingerprint, - rsa_public_key_parse, generic_filter, + public_key_fingerprint, + public_key_parse, random_key_pair_id, + random_ed25519_key_pair, + random_rsa_key_pair, ) from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow @@ -42,10 +43,14 @@ class KeyPairBackend: def __init__(self) -> None: self.keypairs: Dict[str, KeyPair] = {} - def create_key_pair(self, name: str) -> KeyPair: + def create_key_pair(self, name: str, key_type: str = "rsa") -> KeyPair: if name in self.keypairs: raise InvalidKeyPairDuplicateError(name) - keypair = KeyPair(name, **random_key_pair()) + if key_type == "ed25519": + keypair = KeyPair(name, **random_ed25519_key_pair()) + else: + keypair = KeyPair(name, **random_rsa_key_pair()) + self.keypairs[name] = keypair return keypair @@ -77,11 +82,11 @@ class KeyPairBackend: raise InvalidKeyPairDuplicateError(key_name) try: - rsa_public_key = rsa_public_key_parse(public_key_material) + public_key = public_key_parse(public_key_material) except ValueError: raise InvalidKeyPairFormatError() - fingerprint = rsa_public_key_fingerprint(rsa_public_key) + fingerprint = public_key_fingerprint(public_key) keypair = KeyPair( key_name, material=public_key_material, fingerprint=fingerprint ) diff --git a/moto/ec2/responses/key_pairs.py b/moto/ec2/responses/key_pairs.py index 65f6d2842..b0b80f01c 100644 --- a/moto/ec2/responses/key_pairs.py +++ b/moto/ec2/responses/key_pairs.py @@ -4,9 +4,9 @@ from ._base_response import EC2BaseResponse class KeyPairs(EC2BaseResponse): def create_key_pair(self) -> str: name = self._get_param("KeyName") + key_type = self._get_param("KeyType") self.error_on_dryrun() - - keypair = self.ec2_backend.create_key_pair(name) + keypair = self.ec2_backend.create_key_pair(name, key_type) return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair) def delete_key_pair(self) -> str: diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index c768f026c..2ce0b03df 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -6,6 +6,11 @@ import ipaddress from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PublicKey, + Ed25519PrivateKey, +) +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from typing import Any, Dict, List, Set, TypeVar, Tuple, Optional, Union from moto.core.utils import utcnow @@ -552,7 +557,22 @@ def simple_aws_filter_to_re(filter_string: str) -> str: return tmp_filter -def random_key_pair() -> Dict[str, str]: +def random_ed25519_key_pair() -> Dict[str, str]: + private_key = Ed25519PrivateKey.generate() + private_key_material = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption(), + ) + fingerprint = public_key_fingerprint(private_key.public_key()) + + return { + "fingerprint": fingerprint, + "material": private_key_material.decode("ascii"), + } + + +def random_rsa_key_pair() -> Dict[str, str]: private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) @@ -561,10 +581,10 @@ def random_key_pair() -> Dict[str, str]: format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), ) - public_key_fingerprint = rsa_public_key_fingerprint(private_key.public_key()) + fingerprint = public_key_fingerprint(private_key.public_key()) return { - "fingerprint": public_key_fingerprint, + "fingerprint": fingerprint, "material": private_key_material.decode("ascii"), } @@ -645,7 +665,9 @@ def generate_instance_identity_document(instance: Any) -> Dict[str, Any]: return document -def rsa_public_key_parse(key_material: Any) -> Any: +def public_key_parse( + key_material: Union[str, bytes] +) -> Union[RSAPublicKey, Ed25519PublicKey]: # These imports take ~.5s; let's keep them local import sshpubkeys.exceptions from sshpubkeys.keys import SSHKey @@ -654,19 +676,26 @@ def rsa_public_key_parse(key_material: Any) -> Any: if not isinstance(key_material, bytes): key_material = key_material.encode("ascii") - decoded_key = base64.b64decode(key_material).decode("ascii") - public_key = SSHKey(decoded_key) + decoded_key = base64.b64decode(key_material) + public_key = SSHKey(decoded_key.decode("ascii")) except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError): raise ValueError("bad key") - if not public_key.rsa: - raise ValueError("bad key") + if public_key.rsa: + return public_key.rsa - return public_key.rsa + # `cryptography` currently does not support RSA RFC4716/SSH2 format, otherwise we could get rid of `sshpubkeys` and + # simply use `load_ssh_public_key()` + if public_key.key_type == b"ssh-ed25519": + return serialization.load_ssh_public_key(decoded_key) # type: ignore[return-value] + + raise ValueError("bad key") -def rsa_public_key_fingerprint(rsa_public_key: Any) -> str: - key_data = rsa_public_key.public_bytes( +def public_key_fingerprint(public_key: Union[RSAPublicKey, Ed25519PublicKey]) -> str: + # TODO: Use different fingerprint calculation methods based on key type and source + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-keys.html#how-ec2-key-fingerprints-are-calculated + key_data = public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) diff --git a/tests/test_ec2/helpers.py b/tests/test_ec2/helpers.py index 9d231d19a..00a3b9efb 100644 --- a/tests/test_ec2/helpers.py +++ b/tests/test_ec2/helpers.py @@ -1,14 +1,23 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric import rsa, ed25519 -def rsa_check_private_key(private_key_material): +def check_private_key(private_key_material, key_type): assert isinstance(private_key_material, str) - private_key = serialization.load_pem_private_key( - data=private_key_material.encode("ascii"), - backend=default_backend(), - password=None, - ) - assert isinstance(private_key, rsa.RSAPrivateKey) + if key_type == "rsa": + private_key = serialization.load_pem_private_key( + data=private_key_material.encode("ascii"), + backend=default_backend(), + password=None, + ) + assert isinstance(private_key, rsa.RSAPrivateKey) + elif key_type == "ed25519": + private_key = serialization.load_ssh_private_key( + data=private_key_material.encode("ascii"), + password=None, + ) + assert isinstance(private_key, ed25519.Ed25519PrivateKey) + else: + raise AssertionError("Bad private key") diff --git a/tests/test_ec2/test_key_pairs.py b/tests/test_ec2/test_key_pairs.py index b0e6573df..45c659cc9 100644 --- a/tests/test_ec2/test_key_pairs.py +++ b/tests/test_ec2/test_key_pairs.py @@ -7,9 +7,16 @@ from moto import mock_ec2, settings from uuid import uuid4 from unittest import SkipTest -from .helpers import rsa_check_private_key +from .helpers import check_private_key +ED25519_PUBLIC_KEY_OPENSSH = b"""\ +ssh-ed25519 \ +AAAAC3NzaC1lZDI1NTE5AAAAIEwsSB9HbTeKCdkSlMZeTq9jZggaPJUwAsUi/7wakB+B \ +moto@getmoto""" + +ED25519_PUBLIC_KEY_FINGERPRINT = "6c:d9:cf:90:d7:f7:bc:46:83:9e:f5:56:aa:e1:13:38" + RSA_PUBLIC_KEY_OPENSSH = b"""\ ssh-rsa \ AAAAB3NzaC1yc2EAAAADAQABAAABAQDusXfgTE4eBP50NglSzCSEGnIL6+cr6m3H\ @@ -77,17 +84,18 @@ def test_key_pairs_create_dryrun_boto3(): @mock_ec2 -def test_key_pairs_create_boto3(): +@pytest.mark.parametrize("key_type", ["rsa", "ed25519"]) +def test_key_pairs_create_boto3(key_type): ec2 = boto3.resource("ec2", "us-west-1") client = boto3.client("ec2", "us-west-1") key_name = str(uuid4())[0:6] - kp = ec2.create_key_pair(KeyName=key_name) - rsa_check_private_key(kp.key_material) + kp = ec2.create_key_pair(KeyName=key_name, KeyType=key_type) + check_private_key(kp.key_material, key_type) # Verify the client can create a key_pair as well - should behave the same key_name2 = str(uuid4()) - kp2 = client.create_key_pair(KeyName=key_name2) - rsa_check_private_key(kp2["KeyMaterial"]) + kp2 = client.create_key_pair(KeyName=key_name2, KeyType=key_type) + check_private_key(kp2["KeyMaterial"], key_type) assert kp.key_material != kp2["KeyMaterial"] @@ -145,13 +153,22 @@ def test_key_pairs_delete_exist_boto3(): @mock_ec2 -def test_key_pairs_import_boto3(): +@pytest.mark.parametrize( + "public_key,fingerprint", + [ + (RSA_PUBLIC_KEY_OPENSSH, RSA_PUBLIC_KEY_FINGERPRINT), + (RSA_PUBLIC_KEY_RFC4716, RSA_PUBLIC_KEY_FINGERPRINT), + (ED25519_PUBLIC_KEY_OPENSSH, ED25519_PUBLIC_KEY_FINGERPRINT), + ], + ids=["rsa-openssh", "rsa-rfc4716", "ed25519"], +) +def test_key_pairs_import_boto3(public_key, fingerprint): client = boto3.client("ec2", "us-west-1") key_name = str(uuid4())[0:6] with pytest.raises(ClientError) as ex: client.import_key_pair( - KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH, DryRun=True + KeyName=key_name, PublicKeyMaterial=public_key, DryRun=True ) assert ex.value.response["Error"]["Code"] == "DryRunOperation" assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 412 @@ -160,26 +177,15 @@ def test_key_pairs_import_boto3(): == "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set" ) - kp1 = client.import_key_pair( - KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH - ) + kp1 = client.import_key_pair(KeyName=key_name, PublicKeyMaterial=public_key) assert "KeyPairId" in kp1 assert kp1["KeyName"] == key_name - assert kp1["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT - - key_name2 = str(uuid4()) - kp2 = client.import_key_pair( - KeyName=key_name2, PublicKeyMaterial=RSA_PUBLIC_KEY_RFC4716 - ) - assert "KeyPairId" in kp2 - assert kp2["KeyName"] == key_name2 - assert kp2["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT + assert kp1["KeyFingerprint"] == fingerprint all_kps = client.describe_key_pairs()["KeyPairs"] all_names = [kp["KeyName"] for kp in all_kps] assert kp1["KeyName"] in all_names - assert kp2["KeyName"] in all_names @mock_ec2 diff --git a/tests/test_ec2/test_utils.py b/tests/test_ec2/test_utils.py index 3061467e3..9211b59fa 100644 --- a/tests/test_ec2/test_utils.py +++ b/tests/test_ec2/test_utils.py @@ -5,17 +5,20 @@ from pytest import raises from moto.ec2 import utils -from .helpers import rsa_check_private_key +from .helpers import check_private_key def test_random_key_pair(): - key_pair = utils.random_key_pair() - rsa_check_private_key(key_pair["material"]) + key_pair = utils.random_rsa_key_pair() + check_private_key(key_pair["material"], "rsa") # AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1 # fingerprints with 59 characters. assert len(key_pair["fingerprint"]) == 47 + key_pair = utils.random_ed25519_key_pair() + check_private_key(key_pair["material"], "ed25519") + def test_random_ipv6_cidr(): def mocked_random_resource_id(chars: int):