From aafbfa752f5c971536fefb281f6eea77a235888e Mon Sep 17 00:00:00 2001 From: Linuka Ratnayake <79963204+linukaratnayake@users.noreply.github.com> Date: Tue, 8 Apr 2025 15:12:58 +0530 Subject: [PATCH] Change logic of choosing early key shares --- .../bouncycastle/tls/AbstractTlsClient.java | 66 +++++++++++++++++-- .../java/org/bouncycastle/tls/TlsUtils.java | 10 +++ 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java b/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java index 446eec9898..fa4c5ab3f3 100644 --- a/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java +++ b/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java @@ -419,15 +419,73 @@ public Vector getEarlyKeyShareGroups() { return null; } + + Integer firstKemNamedGroup = null; + int firstKemNamedGroupIndex = -1; + + for (int i = 0; i < supportedGroups.size(); i++) + { + Integer group = (Integer) supportedGroups.elementAt(i); + if (NamedGroup.refersToASpecificKem(group)) + { + firstKemNamedGroup = group; + firstKemNamedGroupIndex = i; + break; + } + } + + Integer firstNonKemNamedGroup = null; + int firstNonKemNamedGroupIndex = -1; + if (supportedGroups.contains(Integers.valueOf(NamedGroup.x25519))) { - return TlsUtils.vectorOfOne(Integers.valueOf(NamedGroup.x25519)); + firstNonKemNamedGroup = Integers.valueOf(NamedGroup.x25519); + firstNonKemNamedGroupIndex = supportedGroups.indexOf(firstNonKemNamedGroup); } - if (supportedGroups.contains(Integers.valueOf(NamedGroup.secp256r1))) + else if (supportedGroups.contains(Integers.valueOf(NamedGroup.secp256r1))) { - return TlsUtils.vectorOfOne(Integers.valueOf(NamedGroup.secp256r1)); + firstNonKemNamedGroup = Integers.valueOf(NamedGroup.secp256r1); + firstNonKemNamedGroupIndex = supportedGroups.indexOf(firstNonKemNamedGroup); + } + else + { + for (int i = 0; i < supportedGroups.size(); i++) + { + Integer group = (Integer) supportedGroups.elementAt(i); + if (!NamedGroup.refersToASpecificKem(group)) + { + firstNonKemNamedGroup = group; + firstNonKemNamedGroupIndex = i; + break; + } + } } - return TlsUtils.vectorOfOne(supportedGroups.elementAt(0)); + + Vector earlyKeyShareGroups = new Vector<>(); + + if (firstKemNamedGroupIndex != -1 && firstNonKemNamedGroupIndex != -1) + { + if (firstKemNamedGroupIndex < firstNonKemNamedGroupIndex) + { + earlyKeyShareGroups.add(firstKemNamedGroup); + earlyKeyShareGroups.add(firstNonKemNamedGroup); + } + else + { + earlyKeyShareGroups.add(firstNonKemNamedGroup); + earlyKeyShareGroups.add(firstKemNamedGroup); + } + } + else if (firstKemNamedGroup != null) + { + earlyKeyShareGroups.add(firstKemNamedGroup); + } + else if (firstNonKemNamedGroup != null) + { + earlyKeyShareGroups.add(firstNonKemNamedGroup); + } + + return TlsUtils.vectorFromArray(earlyKeyShareGroups.toArray()); } public boolean shouldUseCompatibilityMode() diff --git a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java index 005bf51c8a..c073c1ddd0 100644 --- a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java +++ b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java @@ -2707,6 +2707,16 @@ public static Vector vectorOfOne(Object obj) return v; } + public static Vector vectorFromArray(Object[] array) + { + Vector v = new Vector(array.length); + for (Object obj: array) + { + v.addElement(obj); + } + return v; + } + public static int getCipherType(int cipherSuite) { int encryptionAlgorithm = getEncryptionAlgorithm(cipherSuite);