EC2: Add support for ED25519 SSH key pairs (#6904)

This commit is contained in:
Viren Nadkarni 2023-10-12 23:43:41 +05:30 committed by GitHub
parent e5944307fc
commit 760e28bb7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 104 additions and 52 deletions

View File

@ -8,11 +8,12 @@ from ..exceptions import (
InvalidKeyPairFormatError, InvalidKeyPairFormatError,
) )
from ..utils import ( from ..utils import (
random_key_pair,
rsa_public_key_fingerprint,
rsa_public_key_parse,
generic_filter, generic_filter,
public_key_fingerprint,
public_key_parse,
random_key_pair_id, random_key_pair_id,
random_ed25519_key_pair,
random_rsa_key_pair,
) )
from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow
@ -42,10 +43,14 @@ class KeyPairBackend:
def __init__(self) -> None: def __init__(self) -> None:
self.keypairs: Dict[str, KeyPair] = {} self.keypairs: Dict[str, KeyPair] = {}
def create_key_pair(self, name: str) -> KeyPair: def create_key_pair(self, name: str, key_type: str = "rsa") -> KeyPair:
if name in self.keypairs: if name in self.keypairs:
raise InvalidKeyPairDuplicateError(name) raise InvalidKeyPairDuplicateError(name)
keypair = KeyPair(name, **random_key_pair()) if key_type == "ed25519":
keypair = KeyPair(name, **random_ed25519_key_pair())
else:
keypair = KeyPair(name, **random_rsa_key_pair())
self.keypairs[name] = keypair self.keypairs[name] = keypair
return keypair return keypair
@ -77,11 +82,11 @@ class KeyPairBackend:
raise InvalidKeyPairDuplicateError(key_name) raise InvalidKeyPairDuplicateError(key_name)
try: try:
rsa_public_key = rsa_public_key_parse(public_key_material) public_key = public_key_parse(public_key_material)
except ValueError: except ValueError:
raise InvalidKeyPairFormatError() raise InvalidKeyPairFormatError()
fingerprint = rsa_public_key_fingerprint(rsa_public_key) fingerprint = public_key_fingerprint(public_key)
keypair = KeyPair( keypair = KeyPair(
key_name, material=public_key_material, fingerprint=fingerprint key_name, material=public_key_material, fingerprint=fingerprint
) )

View File

@ -4,9 +4,9 @@ from ._base_response import EC2BaseResponse
class KeyPairs(EC2BaseResponse): class KeyPairs(EC2BaseResponse):
def create_key_pair(self) -> str: def create_key_pair(self) -> str:
name = self._get_param("KeyName") name = self._get_param("KeyName")
key_type = self._get_param("KeyType")
self.error_on_dryrun() self.error_on_dryrun()
keypair = self.ec2_backend.create_key_pair(name, key_type)
keypair = self.ec2_backend.create_key_pair(name)
return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair) return self.response_template(CREATE_KEY_PAIR_RESPONSE).render(keypair=keypair)
def delete_key_pair(self) -> str: def delete_key_pair(self) -> str:

View File

@ -6,6 +6,11 @@ import ipaddress
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PublicKey,
Ed25519PrivateKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from typing import Any, Dict, List, Set, TypeVar, Tuple, Optional, Union from typing import Any, Dict, List, Set, TypeVar, Tuple, Optional, Union
from moto.core.utils import utcnow from moto.core.utils import utcnow
@ -552,7 +557,22 @@ def simple_aws_filter_to_re(filter_string: str) -> str:
return tmp_filter return tmp_filter
def random_key_pair() -> Dict[str, str]: def random_ed25519_key_pair() -> Dict[str, str]:
private_key = Ed25519PrivateKey.generate()
private_key_material = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption(),
)
fingerprint = public_key_fingerprint(private_key.public_key())
return {
"fingerprint": fingerprint,
"material": private_key_material.decode("ascii"),
}
def random_rsa_key_pair() -> Dict[str, str]:
private_key = rsa.generate_private_key( private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend() public_exponent=65537, key_size=2048, backend=default_backend()
) )
@ -561,10 +581,10 @@ def random_key_pair() -> Dict[str, str]:
format=serialization.PrivateFormat.TraditionalOpenSSL, format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(), encryption_algorithm=serialization.NoEncryption(),
) )
public_key_fingerprint = rsa_public_key_fingerprint(private_key.public_key()) fingerprint = public_key_fingerprint(private_key.public_key())
return { return {
"fingerprint": public_key_fingerprint, "fingerprint": fingerprint,
"material": private_key_material.decode("ascii"), "material": private_key_material.decode("ascii"),
} }
@ -645,7 +665,9 @@ def generate_instance_identity_document(instance: Any) -> Dict[str, Any]:
return document return document
def rsa_public_key_parse(key_material: Any) -> Any: def public_key_parse(
key_material: Union[str, bytes]
) -> Union[RSAPublicKey, Ed25519PublicKey]:
# These imports take ~.5s; let's keep them local # These imports take ~.5s; let's keep them local
import sshpubkeys.exceptions import sshpubkeys.exceptions
from sshpubkeys.keys import SSHKey from sshpubkeys.keys import SSHKey
@ -654,19 +676,26 @@ def rsa_public_key_parse(key_material: Any) -> Any:
if not isinstance(key_material, bytes): if not isinstance(key_material, bytes):
key_material = key_material.encode("ascii") key_material = key_material.encode("ascii")
decoded_key = base64.b64decode(key_material).decode("ascii") decoded_key = base64.b64decode(key_material)
public_key = SSHKey(decoded_key) public_key = SSHKey(decoded_key.decode("ascii"))
except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError): except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError):
raise ValueError("bad key") raise ValueError("bad key")
if not public_key.rsa: if public_key.rsa:
raise ValueError("bad key") return public_key.rsa
return public_key.rsa # `cryptography` currently does not support RSA RFC4716/SSH2 format, otherwise we could get rid of `sshpubkeys` and
# simply use `load_ssh_public_key()`
if public_key.key_type == b"ssh-ed25519":
return serialization.load_ssh_public_key(decoded_key) # type: ignore[return-value]
raise ValueError("bad key")
def rsa_public_key_fingerprint(rsa_public_key: Any) -> str: def public_key_fingerprint(public_key: Union[RSAPublicKey, Ed25519PublicKey]) -> str:
key_data = rsa_public_key.public_bytes( # TODO: Use different fingerprint calculation methods based on key type and source
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-keys.html#how-ec2-key-fingerprints-are-calculated
key_data = public_key.public_bytes(
encoding=serialization.Encoding.DER, encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo, format=serialization.PublicFormat.SubjectPublicKeyInfo,
) )

View File

@ -1,14 +1,23 @@
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa, ed25519
def rsa_check_private_key(private_key_material): def check_private_key(private_key_material, key_type):
assert isinstance(private_key_material, str) assert isinstance(private_key_material, str)
private_key = serialization.load_pem_private_key( if key_type == "rsa":
data=private_key_material.encode("ascii"), private_key = serialization.load_pem_private_key(
backend=default_backend(), data=private_key_material.encode("ascii"),
password=None, backend=default_backend(),
) password=None,
assert isinstance(private_key, rsa.RSAPrivateKey) )
assert isinstance(private_key, rsa.RSAPrivateKey)
elif key_type == "ed25519":
private_key = serialization.load_ssh_private_key(
data=private_key_material.encode("ascii"),
password=None,
)
assert isinstance(private_key, ed25519.Ed25519PrivateKey)
else:
raise AssertionError("Bad private key")

View File

@ -7,9 +7,16 @@ from moto import mock_ec2, settings
from uuid import uuid4 from uuid import uuid4
from unittest import SkipTest from unittest import SkipTest
from .helpers import rsa_check_private_key from .helpers import check_private_key
ED25519_PUBLIC_KEY_OPENSSH = b"""\
ssh-ed25519 \
AAAAC3NzaC1lZDI1NTE5AAAAIEwsSB9HbTeKCdkSlMZeTq9jZggaPJUwAsUi/7wakB+B \
moto@getmoto"""
ED25519_PUBLIC_KEY_FINGERPRINT = "6c:d9:cf:90:d7:f7:bc:46:83:9e:f5:56:aa:e1:13:38"
RSA_PUBLIC_KEY_OPENSSH = b"""\ RSA_PUBLIC_KEY_OPENSSH = b"""\
ssh-rsa \ ssh-rsa \
AAAAB3NzaC1yc2EAAAADAQABAAABAQDusXfgTE4eBP50NglSzCSEGnIL6+cr6m3H\ AAAAB3NzaC1yc2EAAAADAQABAAABAQDusXfgTE4eBP50NglSzCSEGnIL6+cr6m3H\
@ -77,17 +84,18 @@ def test_key_pairs_create_dryrun_boto3():
@mock_ec2 @mock_ec2
def test_key_pairs_create_boto3(): @pytest.mark.parametrize("key_type", ["rsa", "ed25519"])
def test_key_pairs_create_boto3(key_type):
ec2 = boto3.resource("ec2", "us-west-1") ec2 = boto3.resource("ec2", "us-west-1")
client = boto3.client("ec2", "us-west-1") client = boto3.client("ec2", "us-west-1")
key_name = str(uuid4())[0:6] key_name = str(uuid4())[0:6]
kp = ec2.create_key_pair(KeyName=key_name) kp = ec2.create_key_pair(KeyName=key_name, KeyType=key_type)
rsa_check_private_key(kp.key_material) check_private_key(kp.key_material, key_type)
# Verify the client can create a key_pair as well - should behave the same # Verify the client can create a key_pair as well - should behave the same
key_name2 = str(uuid4()) key_name2 = str(uuid4())
kp2 = client.create_key_pair(KeyName=key_name2) kp2 = client.create_key_pair(KeyName=key_name2, KeyType=key_type)
rsa_check_private_key(kp2["KeyMaterial"]) check_private_key(kp2["KeyMaterial"], key_type)
assert kp.key_material != kp2["KeyMaterial"] assert kp.key_material != kp2["KeyMaterial"]
@ -145,13 +153,22 @@ def test_key_pairs_delete_exist_boto3():
@mock_ec2 @mock_ec2
def test_key_pairs_import_boto3(): @pytest.mark.parametrize(
"public_key,fingerprint",
[
(RSA_PUBLIC_KEY_OPENSSH, RSA_PUBLIC_KEY_FINGERPRINT),
(RSA_PUBLIC_KEY_RFC4716, RSA_PUBLIC_KEY_FINGERPRINT),
(ED25519_PUBLIC_KEY_OPENSSH, ED25519_PUBLIC_KEY_FINGERPRINT),
],
ids=["rsa-openssh", "rsa-rfc4716", "ed25519"],
)
def test_key_pairs_import_boto3(public_key, fingerprint):
client = boto3.client("ec2", "us-west-1") client = boto3.client("ec2", "us-west-1")
key_name = str(uuid4())[0:6] key_name = str(uuid4())[0:6]
with pytest.raises(ClientError) as ex: with pytest.raises(ClientError) as ex:
client.import_key_pair( client.import_key_pair(
KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH, DryRun=True KeyName=key_name, PublicKeyMaterial=public_key, DryRun=True
) )
assert ex.value.response["Error"]["Code"] == "DryRunOperation" assert ex.value.response["Error"]["Code"] == "DryRunOperation"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 412 assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 412
@ -160,26 +177,15 @@ def test_key_pairs_import_boto3():
== "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set" == "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set"
) )
kp1 = client.import_key_pair( kp1 = client.import_key_pair(KeyName=key_name, PublicKeyMaterial=public_key)
KeyName=key_name, PublicKeyMaterial=RSA_PUBLIC_KEY_OPENSSH
)
assert "KeyPairId" in kp1 assert "KeyPairId" in kp1
assert kp1["KeyName"] == key_name assert kp1["KeyName"] == key_name
assert kp1["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT assert kp1["KeyFingerprint"] == fingerprint
key_name2 = str(uuid4())
kp2 = client.import_key_pair(
KeyName=key_name2, PublicKeyMaterial=RSA_PUBLIC_KEY_RFC4716
)
assert "KeyPairId" in kp2
assert kp2["KeyName"] == key_name2
assert kp2["KeyFingerprint"] == RSA_PUBLIC_KEY_FINGERPRINT
all_kps = client.describe_key_pairs()["KeyPairs"] all_kps = client.describe_key_pairs()["KeyPairs"]
all_names = [kp["KeyName"] for kp in all_kps] all_names = [kp["KeyName"] for kp in all_kps]
assert kp1["KeyName"] in all_names assert kp1["KeyName"] in all_names
assert kp2["KeyName"] in all_names
@mock_ec2 @mock_ec2

View File

@ -5,17 +5,20 @@ from pytest import raises
from moto.ec2 import utils from moto.ec2 import utils
from .helpers import rsa_check_private_key from .helpers import check_private_key
def test_random_key_pair(): def test_random_key_pair():
key_pair = utils.random_key_pair() key_pair = utils.random_rsa_key_pair()
rsa_check_private_key(key_pair["material"]) check_private_key(key_pair["material"], "rsa")
# AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1 # AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1
# fingerprints with 59 characters. # fingerprints with 59 characters.
assert len(key_pair["fingerprint"]) == 47 assert len(key_pair["fingerprint"]) == 47
key_pair = utils.random_ed25519_key_pair()
check_private_key(key_pair["material"], "ed25519")
def test_random_ipv6_cidr(): def test_random_ipv6_cidr():
def mocked_random_resource_id(chars: int): def mocked_random_resource_id(chars: int):