diff --git a/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java b/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java index 446eec9898..fa4c5ab3f3 100644 --- a/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java +++ b/tls/src/main/java/org/bouncycastle/tls/AbstractTlsClient.java @@ -419,15 +419,73 @@ public Vector getEarlyKeyShareGroups() { return null; } + + Integer firstKemNamedGroup = null; + int firstKemNamedGroupIndex = -1; + + for (int i = 0; i < supportedGroups.size(); i++) + { + Integer group = (Integer) supportedGroups.elementAt(i); + if (NamedGroup.refersToASpecificKem(group)) + { + firstKemNamedGroup = group; + firstKemNamedGroupIndex = i; + break; + } + } + + Integer firstNonKemNamedGroup = null; + int firstNonKemNamedGroupIndex = -1; + if (supportedGroups.contains(Integers.valueOf(NamedGroup.x25519))) { - return TlsUtils.vectorOfOne(Integers.valueOf(NamedGroup.x25519)); + firstNonKemNamedGroup = Integers.valueOf(NamedGroup.x25519); + firstNonKemNamedGroupIndex = supportedGroups.indexOf(firstNonKemNamedGroup); } - if (supportedGroups.contains(Integers.valueOf(NamedGroup.secp256r1))) + else if (supportedGroups.contains(Integers.valueOf(NamedGroup.secp256r1))) { - return TlsUtils.vectorOfOne(Integers.valueOf(NamedGroup.secp256r1)); + firstNonKemNamedGroup = Integers.valueOf(NamedGroup.secp256r1); + firstNonKemNamedGroupIndex = supportedGroups.indexOf(firstNonKemNamedGroup); + } + else + { + for (int i = 0; i < supportedGroups.size(); i++) + { + Integer group = (Integer) supportedGroups.elementAt(i); + if (!NamedGroup.refersToASpecificKem(group)) + { + firstNonKemNamedGroup = group; + firstNonKemNamedGroupIndex = i; + break; + } + } } - return TlsUtils.vectorOfOne(supportedGroups.elementAt(0)); + + Vector earlyKeyShareGroups = new Vector<>(); + + if (firstKemNamedGroupIndex != -1 && firstNonKemNamedGroupIndex != -1) + { + if (firstKemNamedGroupIndex < firstNonKemNamedGroupIndex) + { + earlyKeyShareGroups.add(firstKemNamedGroup); + earlyKeyShareGroups.add(firstNonKemNamedGroup); + } + else + { + earlyKeyShareGroups.add(firstNonKemNamedGroup); + earlyKeyShareGroups.add(firstKemNamedGroup); + } + } + else if (firstKemNamedGroup != null) + { + earlyKeyShareGroups.add(firstKemNamedGroup); + } + else if (firstNonKemNamedGroup != null) + { + earlyKeyShareGroups.add(firstNonKemNamedGroup); + } + + return TlsUtils.vectorFromArray(earlyKeyShareGroups.toArray()); } public boolean shouldUseCompatibilityMode() diff --git a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java index 005bf51c8a..c073c1ddd0 100644 --- a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java +++ b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java @@ -2707,6 +2707,16 @@ public static Vector vectorOfOne(Object obj) return v; } + public static Vector vectorFromArray(Object[] array) + { + Vector v = new Vector(array.length); + for (Object obj: array) + { + v.addElement(obj); + } + return v; + } + public static int getCipherType(int cipherSuite) { int encryptionAlgorithm = getEncryptionAlgorithm(cipherSuite);