Skip to content

Enable SHA-384 based signature algorithms and SECP384R1 key exchange #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,16 +1300,19 @@ def __init__(
self._legacy_compression_methods: List[int] = [CompressionMethod.NULL]
self._psk_key_exchange_modes: List[int] = [PskKeyExchangeMode.PSK_DHE_KE]
self._signature_algorithms: List[int] = [
SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
SignatureAlgorithm.RSA_PKCS1_SHA256,
SignatureAlgorithm.ECDSA_SECP384R1_SHA384,
SignatureAlgorithm.RSA_PSS_RSAE_SHA384,
SignatureAlgorithm.RSA_PKCS1_SHA384,
SignatureAlgorithm.RSA_PKCS1_SHA1,
]
if default_backend().ed25519_supported():
self._signature_algorithms.append(SignatureAlgorithm.ED25519)
if default_backend().ed448_supported():
self._signature_algorithms.append(SignatureAlgorithm.ED448)
self._supported_groups = [Group.SECP256R1]
self._supported_groups = [Group.SECP256R1, Group.SECP384R1]
if default_backend().x25519_supported():
self._supported_groups.append(Group.X25519)
if default_backend().x448_supported():
Expand All @@ -1334,7 +1337,7 @@ def __init__(
self._dec_key: Optional[bytes] = None
self.__logger = logger

self._ec_private_key: Optional[ec.EllipticCurvePrivateKey] = None
self._ec_private_keys: List[ec.EllipticCurvePrivateKey] = []
self._x25519_private_key: Optional[x25519.X25519PrivateKey] = None
self._x448_private_key: Optional[x448.X448PrivateKey] = None

Expand Down Expand Up @@ -1522,13 +1525,7 @@ def _client_send_hello(self, output_buf: Buffer) -> None:
supported_groups: List[int] = []

for group in self._supported_groups:
if group == Group.SECP256R1:
self._ec_private_key = ec.generate_private_key(
GROUP_TO_CURVE[Group.SECP256R1]()
)
key_share.append(encode_public_key(self._ec_private_key.public_key()))
supported_groups.append(Group.SECP256R1)
elif group == Group.X25519:
if group == Group.X25519:
self._x25519_private_key = x25519.X25519PrivateKey.generate()
key_share.append(
encode_public_key(self._x25519_private_key.public_key())
Expand All @@ -1541,6 +1538,11 @@ def _client_send_hello(self, output_buf: Buffer) -> None:
elif group == Group.GREASE:
key_share.append((Group.GREASE, b"\x00"))
supported_groups.append(Group.GREASE)
elif group in GROUP_TO_CURVE:
ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[group]())
self._ec_private_keys.append(ec_private_key)
key_share.append(encode_public_key(ec_private_key.public_key()))
supported_groups.append(group)

assert len(key_share), "no key share entries"

Expand Down Expand Up @@ -1665,13 +1667,13 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None:
and self._x448_private_key is not None
):
shared_key = self._x448_private_key.exchange(peer_public_key)
elif (
isinstance(peer_public_key, ec.EllipticCurvePublicKey)
and self._ec_private_key is not None
and self._ec_private_key.public_key().curve.__class__
== peer_public_key.curve.__class__
):
shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key)
elif isinstance(peer_public_key, ec.EllipticCurvePublicKey):
for ec_private_key in self._ec_private_keys:
if (
ec_private_key.public_key().curve.__class__
== peer_public_key.curve.__class__
):
shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key)
assert shared_key is not None

self.key_schedule.update_hash(input_buf.data)
Expand Down Expand Up @@ -1986,11 +1988,10 @@ def _server_handle_hello(
shared_key = self._x448_private_key.exchange(peer_public_key)
break
elif isinstance(peer_public_key, ec.EllipticCurvePublicKey):
self._ec_private_key = ec.generate_private_key(
GROUP_TO_CURVE[key_share[0]]()
)
public_key = self._ec_private_key.public_key()
shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key)
ec_private_key = ec.generate_private_key(GROUP_TO_CURVE[key_share[0]]())
self._ec_private_keys.append(ec_private_key)
public_key = ec_private_key.public_key()
shared_key = ec_private_key.exchange(ec.ECDH(), peer_public_key)
break
assert shared_key is not None

Expand Down Expand Up @@ -2161,12 +2162,18 @@ def _signature_algorithms_for_private_key(self) -> List[SignatureAlgorithm]:
signature_algorithms = [
SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
SignatureAlgorithm.RSA_PKCS1_SHA256,
SignatureAlgorithm.RSA_PSS_RSAE_SHA384,
SignatureAlgorithm.RSA_PKCS1_SHA384,
SignatureAlgorithm.RSA_PKCS1_SHA1,
]
elif isinstance(
self.certificate_private_key, ec.EllipticCurvePrivateKey
) and isinstance(self.certificate_private_key.curve, ec.SECP256R1):
signature_algorithms = [SignatureAlgorithm.ECDSA_SECP256R1_SHA256]
elif isinstance(
self.certificate_private_key, ec.EllipticCurvePrivateKey
) and isinstance(self.certificate_private_key.curve, ec.SECP384R1):
signature_algorithms = [SignatureAlgorithm.ECDSA_SECP384R1_SHA384]
elif isinstance(self.certificate_private_key, ed25519.Ed25519PrivateKey):
signature_algorithms = [SignatureAlgorithm.ED25519]
elif isinstance(self.certificate_private_key, ed448.Ed448PrivateKey):
Expand Down
75 changes: 61 additions & 14 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

from .utils import (
SERVER_CACERTFILE,
Expand Down Expand Up @@ -115,6 +116,11 @@ def reset_buffers(buffers):


class ContextTest(TestCase):
def assertClientHello(self, data: bytes):
self.assertEqual(data[0], tls.HandshakeType.CLIENT_HELLO)
self.assertGreaterEqual(len(data), 191)
self.assertLessEqual(len(data), 564)

def create_client(
self, alpn_protocols=None, cadata=None, cafile=SERVER_CACERTFILE, **kwargs
):
Expand Down Expand Up @@ -378,8 +384,7 @@ def _handshake(self, client, server):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -444,8 +449,7 @@ def test_handshake_with_certificate_request_no_certificate(self):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -503,8 +507,7 @@ def test_handshake_with_certificate_request_with_certificate(self):
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 181)
self.assertLessEqual(len(server_input), 358)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -565,9 +568,14 @@ def _test_handshake_with_certificate(self, certificate, private_key):
self.assertEqual(client.alpn_negotiated, None)
self.assertEqual(server.alpn_negotiated, None)

def test_handshake_with_ec_certificate(self):
def test_handshake_with_ec_certificate_secp256r1(self):
self._test_handshake_with_certificate(
*generate_ec_certificate(common_name="example.com")
*generate_ec_certificate(common_name="example.com", curve=ec.SECP256R1)
)

def test_handshake_with_ec_certificate_secp384r1(self):
self._test_handshake_with_certificate(
*generate_ec_certificate(common_name="example.com", curve=ec.SECP384R1)
)

def test_handshake_with_ed25519_certificate(self):
Expand Down Expand Up @@ -598,13 +606,41 @@ def test_handshake_with_alpn_fail(self):
self._handshake(client, server)
self.assertEqual(str(cm.exception), "No common ALPN protocols")

def test_handshake_with_rsa_pkcs1_sha1_signature(self):
client = self.create_client()
client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA1]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_rsa_pkcs1_sha256_signature(self):
client = self.create_client()
client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA256]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_rsa_pkcs1_sha384_signature(self):
client = self.create_client()
client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA384]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_rsa_pss_rsae_sha256_signature(self):
client = self.create_client()
client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_rsa_pss_rsae_sha384_signature(self):
client = self.create_client()
client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_certificate_error(self):
client = self.create_client(cafile=None)
server = self.create_server()
Expand All @@ -626,6 +662,20 @@ def test_handshake_with_grease_group(self):

self._handshake(client, server)

def test_handshake_with_secp256r1_group(self):
client = self.create_client()
client._supported_groups = [tls.Group.SECP256R1]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_secp384r1_group(self):
client = self.create_client()
client._supported_groups = [tls.Group.SECP384R1]
server = self.create_server()

self._handshake(client, server)

def test_handshake_with_x25519(self):
client = self.create_client()
client._supported_groups = [tls.Group.X25519]
Expand Down Expand Up @@ -695,8 +745,7 @@ def second_handshake():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# Handle client hello.
Expand Down Expand Up @@ -748,8 +797,7 @@ def second_handshake_bad_binder():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# tamper with binder
Expand All @@ -774,8 +822,7 @@ def second_handshake_bad_pre_shared_key():
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
self.assertGreaterEqual(len(server_input), 383)
self.assertLessEqual(len(server_input), 483)
self.assertClientHello(server_input)
reset_buffers(client_buf)

# handle client hello
Expand Down
Loading