diff --git a/src/main/java/io/vertx/core/net/impl/SslChannelProvider.java b/src/main/java/io/vertx/core/net/impl/SslChannelProvider.java index 7543705fe97..1dae47fa9fa 100644 --- a/src/main/java/io/vertx/core/net/impl/SslChannelProvider.java +++ b/src/main/java/io/vertx/core/net/impl/SslChannelProvider.java @@ -27,8 +27,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import static io.vertx.core.net.impl.SslContextProvider.createTrustAllTrustManager; - /** * Provider for {@link SslHandler} and {@link SniHandler}. *
@@ -60,40 +58,34 @@ public SslContextProvider sslContextProvider() { } public SslContext sslClientContext(String serverName, boolean useAlpn, boolean trustAll) { + try { + return sslContext(serverName, useAlpn, false, trustAll); + } catch (Exception e) { + throw new VertxException(e); + } + } + + public SslContext sslContext(String serverName, boolean useAlpn, boolean server, boolean trustAll) throws Exception { int idx = idx(useAlpn); if (serverName == null) { if (sslContexts[idx] == null) { - SslContext context = sslContextProvider.createClientContext(useAlpn, trustAll); + SslContext context = sslContextProvider.createContext(server, null, null, null, useAlpn, trustAll); sslContexts[idx] = context; } return sslContexts[idx]; } else { - KeyManagerFactory kmf; - try { - kmf = sslContextProvider.resolveKeyManagerFactory(serverName); - } catch (Exception e) { - throw new VertxException(e); - } - TrustManager[] trustManagers; - if (trustAll) { - trustManagers = new TrustManager[] { createTrustAllTrustManager() }; - } else { - try { - trustManagers = sslContextProvider.resolveTrustManagers(serverName); - } catch (Exception e) { - throw new VertxException(e); - } - } - return sslContextMaps[idx].computeIfAbsent(serverName, s -> sslContextProvider.createClientContext(kmf, trustManagers, s, useAlpn)); + KeyManagerFactory kmf = sslContextProvider.resolveKeyManagerFactory(serverName); + TrustManager[] trustManagers = trustAll ? null : sslContextProvider.resolveTrustManagers(serverName); + return sslContextMaps[idx].computeIfAbsent(serverName, s -> sslContextProvider.createContext(server, kmf, trustManagers, s, useAlpn, trustAll)); } } public SslContext sslServerContext(boolean useAlpn) { - int idx = idx(useAlpn); - if (sslContexts[idx] == null) { - sslContexts[idx] = sslContextProvider.createServerContext(useAlpn); + try { + return sslContext(null, useAlpn, true, false); + } catch (Exception e) { + throw new VertxException(e); } - return sslContexts[idx]; } /** @@ -104,27 +96,14 @@ public SslContext sslServerContext(boolean useAlpn) { public AsyncMapping serverNameMapping(boolean useAlpn) { return (AsyncMapping) (serverName, promise) -> { workerPool.execute(() -> { - if (serverName == null) { - promise.setSuccess(sslServerContext(useAlpn)); - } else { - KeyManagerFactory kmf; - try { - kmf = sslContextProvider.resolveKeyManagerFactory(serverName); - } catch (Exception e) { - promise.setFailure(e); - return; - } - TrustManager[] trustManagers; - try { - trustManagers = sslContextProvider.resolveTrustManagers(serverName); - } catch (Exception e) { - promise.setFailure(e); - return; - } - int idx = idx(useAlpn); - SslContext sslContext = sslContextMaps[idx].computeIfAbsent(serverName, s -> sslContextProvider.createServerContext(kmf, trustManagers, s, useAlpn)); - promise.setSuccess(sslContext); + SslContext sslContext; + try { + sslContext = sslContext(serverName, useAlpn, true, false); + } catch (Exception e) { + promise.setFailure(e); + return; } + promise.setSuccess(sslContext); }); return promise; }; diff --git a/src/main/java/io/vertx/core/net/impl/SslContextProvider.java b/src/main/java/io/vertx/core/net/impl/SslContextProvider.java index 65fd59deea7..d8c42b2e0ed 100644 --- a/src/main/java/io/vertx/core/net/impl/SslContextProvider.java +++ b/src/main/java/io/vertx/core/net/impl/SslContextProvider.java @@ -66,16 +66,29 @@ public SslContextProvider(ClientAuth clientAuth, this.crls = crls; } - public VertxSslContext createClientContext( - boolean useAlpn, - boolean trustAll) { - TrustManager[] trustManagers = null; + public VertxSslContext createContext(boolean server, + KeyManagerFactory keyManagerFactory, + TrustManager[] trustManagers, + String serverName, + boolean useAlpn, + boolean trustAll) { + if (keyManagerFactory == null) { + keyManagerFactory = defaultKeyManagerFactory(); + } if (trustAll) { - trustManagers = new TrustManager[] { createTrustAllTrustManager() }; - } else if (trustManagerFactory != null) { - trustManagers = trustManagerFactory.getTrustManagers(); + trustManagers = SslContextProvider.createTrustAllManager(); + } else if (trustManagers == null) { + trustManagers = defaultTrustManagers(); } - return createClientContext(keyManagerFactory, trustManagers, null, useAlpn); + if (server) { + return createServerContext(keyManagerFactory, trustManagers, serverName, useAlpn); + } else { + return createClientContext(keyManagerFactory, trustManagers, serverName, useAlpn); + } + } + + public VertxSslContext createContext(boolean server, boolean useAlpn) { + return createContext(server, defaultKeyManagerFactory(), defaultTrustManagers(), null, useAlpn, false); } public VertxSslContext createClientContext( @@ -108,10 +121,6 @@ protected void initEngine(SSLEngine engine) { } } - public VertxSslContext createServerContext(boolean useAlpn) { - return createServerContext(keyManagerFactory, trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null, null, useAlpn); - } - public VertxSslContext createServerContext(KeyManagerFactory keyManagerFactory, TrustManager[] trustManagers, String serverName, @@ -152,6 +161,18 @@ public KeyManagerFactory loadKeyManagerFactory(String serverName) throws Excepti return null; } + public TrustManager[] defaultTrustManagers() { + return trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null; + } + + public TrustManagerFactory defaultTrustManagerFactory() { + return trustManagerFactory; + } + + public KeyManagerFactory defaultKeyManagerFactory() { + return keyManagerFactory; + } + /** * Resolve the {@link KeyManagerFactory} for the {@code serverName}, when a factory cannot be resolved, the default * factory is returned. @@ -242,22 +263,24 @@ public X509Certificate[] getAcceptedIssuers() { return trustMgrs; } - // Create a TrustManager which trusts everything - static TrustManager createTrustAllTrustManager() { - return new X509TrustManager() { - @Override - public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { - } + private static final TrustManager TRUST_ALL_MANAGER = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + } - @Override - public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { - } + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + } - @Override - public X509Certificate[] getAcceptedIssuers() { - return new X509Certificate[0]; - } - }; + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }; + + // Create a TrustManager which trusts everything + private static TrustManager[] createTrustAllManager() { + return new TrustManager[] { TRUST_ALL_MANAGER }; } public void configureEngine(SSLEngine engine, Set enabledProtocols, String serverName, boolean client) { diff --git a/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java b/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java index ae50bd0e958..1ea796c7268 100755 --- a/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java +++ b/src/test/java/io/vertx/core/net/impl/SSLHelperTest.java @@ -45,7 +45,7 @@ public void testUseJdkCiphersWhenNotSpecified() throws Exception { helper .buildContextProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_JKS.get()).setTrustOptions(Trust.SERVER_JKS.get()), null, ClientAuth.NONE, null, false, (ContextInternal) vertx.getOrCreateContext()) .onComplete(onSuccess(provider -> { - SslContext ctx = provider.createClientContext(false, false); + SslContext ctx = provider.createContext(false, false); assertEquals(new HashSet<>(Arrays.asList(expected)), new HashSet<>(ctx.cipherSuites())); testComplete(); })); @@ -57,7 +57,7 @@ public void testUseOpenSSLCiphersWhenNotSpecified() throws Exception { Set expected = OpenSsl.availableOpenSslCipherSuites(); SSLHelper helper = new SSLHelper(new OpenSSLEngineOptions()); helper.buildContextProvider(new SSLOptions().setKeyCertOptions(Cert.CLIENT_PEM.get()).setTrustOptions(Trust.SERVER_PEM.get()), null, ClientAuth.NONE, null, false, (ContextInternal) vertx.getOrCreateContext()).onComplete(onSuccess(provider -> { - SslContext ctx = provider.createClientContext(false, false); + SslContext ctx = provider.createContext(false, false); assertEquals(expected, new HashSet<>(ctx.cipherSuites())); testComplete(); })); @@ -91,7 +91,7 @@ private void testOpenSslServerSessionContext(boolean testDefault){ defaultHelper .buildContextProvider(sslOptions, null, ClientAuth.NONE, null, false, (ContextInternal) vertx.getOrCreateContext()) .onComplete(onSuccess(provider -> { - SslContext ctx = provider.createServerContext(false); + SslContext ctx = provider.createContext(true, false); SSLSessionContext sslSessionContext = ctx.sessionContext(); assertTrue(sslSessionContext instanceof OpenSslServerSessionContext); @@ -201,6 +201,6 @@ private void testTLSVersions(SSLOptions options, Consumer check) { } public SSLEngine createEngine(SslContextProvider provider) { - return provider.createClientContext(false, false).newEngine(ByteBufAllocator.DEFAULT); + return provider.createContext(false, false).newEngine(ByteBufAllocator.DEFAULT); } } diff --git a/src/test/java/io/vertx/it/SSLEngineTest.java b/src/test/java/io/vertx/it/SSLEngineTest.java index d447e14777e..c4520c528fc 100644 --- a/src/test/java/io/vertx/it/SSLEngineTest.java +++ b/src/test/java/io/vertx/it/SSLEngineTest.java @@ -99,7 +99,7 @@ private void doTest(SSLEngineOptions engine, } } SslContextProvider provider = ((HttpServerImpl)server).sslContextProvider(); - SslContext ctx = provider.createClientContext(false, false); + SslContext ctx = provider.createContext(false, false); switch (expectedSslContext != null ? expectedSslContext : "jdk") { case "jdk": assertTrue(ctx.sessionContext().getClass().getName().equals("sun.security.ssl.SSLSessionContextImpl"));