diff --git a/core/src/main/java/io/grpc/internal/AuthorityVerifier.java b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java new file mode 100644 index 00000000000..e6164a7dc4d --- /dev/null +++ b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.Status; + +/** Verifier for the outgoing authority pseudo-header against peer cert. */ +public interface AuthorityVerifier { + Status verifyAuthority(String authority); +} diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java new file mode 100644 index 00000000000..7efd16eaf27 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Collection; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.security.auth.x500.X500Principal; + +/** + * Contains certificate/key PEM file utility method(s) for internal usage. + */ +public class CertificateUtils { + /** + * Creates X509TrustManagers using the provided CA certs. + */ + public static TrustManager[] createTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + + private static X509Certificate[] getX509Certificates(InputStream inputStream) + throws CertificateException { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + Collection certs = factory.generateCertificates(inputStream); + return certs.toArray(new X509Certificate[0]); + } +} diff --git a/core/src/main/java/io/grpc/internal/GrpcAttributes.java b/core/src/main/java/io/grpc/internal/GrpcAttributes.java index da43ae14800..f95f9b9dab8 100644 --- a/core/src/main/java/io/grpc/internal/GrpcAttributes.java +++ b/core/src/main/java/io/grpc/internal/GrpcAttributes.java @@ -42,5 +42,8 @@ public final class GrpcAttributes { public static final Attributes.Key ATTR_CLIENT_EAG_ATTRS = Attributes.Key.create("io.grpc.internal.GrpcAttributes.clientEagAttrs"); + public static final Attributes.Key ATTR_AUTHORITY_VERIFIER = + Attributes.Key.create("io.grpc.internal.GrpcAttributes.authorityVerifier"); + private GrpcAttributes() {} } diff --git a/core/src/main/java/io/grpc/internal/NoopSslSession.java b/core/src/main/java/io/grpc/internal/NoopSslSession.java new file mode 100644 index 00000000000..9a79d281ad5 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/NoopSslSession.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.security.Principal; +import java.security.cert.Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; + +/** A no-op ssl session, to facilitate overriding only the required methods in specific + * implementations. + */ +public class NoopSslSession implements SSLSession { + @Override + public byte[] getId() { + return new byte[0]; + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + @SuppressWarnings("deprecation") + public javax.security.cert.X509Certificate[] getPeerCertificateChain() { + throw new UnsupportedOperationException("This method is deprecated and marked for removal. " + + "Use the getPeerCertificates() method instead."); + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() { + } + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String s, Object o) { + } + + @Override + public Object getValue(String s) { + return null; + } + + @Override + public void removeValue(String s) { + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } +} diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java index 0489e135813..7959f0fdd9d 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java @@ -46,6 +46,15 @@ static GrpcHttp2OutboundHeaders clientRequestHeaders(byte[][] serializedMetadata return new GrpcHttp2OutboundHeaders(preHeaders, serializedMetadata); } + String getAuthority() { + for (int i = 0; i < preHeaders.length / 2; i++) { + if (preHeaders[i].equals(Http2Headers.PseudoHeaderName.AUTHORITY.value())) { + return preHeaders[i + 1].toString(); + } + } + return null; + } + static GrpcHttp2OutboundHeaders serverResponseHeaders(byte[][] serializedMetadata) { AsciiString[] preHeaders = new AsciiString[] { Http2Headers.PseudoHeaderName.STATUS.value(), Utils.STATUS_OK, diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 3d1aa83d9ff..039ea6c4f24 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -44,7 +44,7 @@ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslCo ObjectPool executorPool, Optional handshakeCompleteRunnable) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool, handshakeCompleteRunnable); + executorPool, handshakeCompleteRunnable, null); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -170,7 +170,7 @@ public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, - Optional.absent()); + Optional.absent(), null, null); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 8b80f7b4e46..46566eaca1a 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -652,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent()); + return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 194decb1120..c4cf38b897b 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -83,6 +83,8 @@ import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.nio.channels.ClosedChannelException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -94,6 +96,8 @@ */ class NettyClientHandler extends AbstractNettyHandler { private static final Logger logger = Logger.getLogger(NettyClientHandler.class.getName()); + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag("GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK", false); /** * A message that simply passes through the channel without any real processing. It is useful to @@ -128,6 +132,13 @@ protected void handleNotInUse() { lifecycleManager.notifyInUse(false); } }; + private final Map peerVerificationResults = + new LinkedHashMap() { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + }; private WriteQueue clientWriteQueue; private Http2Ping ping; @@ -575,6 +586,17 @@ protected boolean isGracefulShutdownComplete() { && ((StreamBufferingEncoder) encoder()).numBufferedStreams() == 0; } + private String getAuthorityPseudoHeader(Http2Headers http2Headers) { + if (http2Headers instanceof GrpcHttp2OutboundHeaders) { + return ((GrpcHttp2OutboundHeaders) http2Headers).getAuthority(); + } + try { + return http2Headers.authority().toString(); + } catch (UnsupportedOperationException e) { + return null; + } + } + /** * Attempts to create a new stream from the given command. If there are too many active streams, * the creation request is queued. @@ -591,6 +613,29 @@ private void createStream(CreateStreamCommand command, ChannelPromise promise) return; } + String authority = getAuthorityPseudoHeader(command.headers()); + if (authority != null) { + Status authorityVerificationStatus = peerVerificationResults.get(authority); + if (authorityVerificationStatus == null && attributes != null + && attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) != null) { + authorityVerificationStatus = attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) + .verifyAuthority(((GrpcHttp2OutboundHeaders) command.headers()).getAuthority()); + peerVerificationResults.put(authority, authorityVerificationStatus); + } + if (authorityVerificationStatus != null && !authorityVerificationStatus.isOk()) { + logger.log(Level.WARNING, String.format("%s.%s", + authorityVerificationStatus.getDescription(), enablePerRpcAuthorityCheck + ? "" : "This will be an error in the future."), + authorityVerificationStatus.getCause()); + if (enablePerRpcAuthorityCheck) { + command.stream().setNonExistent(); + command.stream().transportReportStatus( + authorityVerificationStatus, RpcProgress.DROPPED, true, new Metadata()); + promise.setFailure(authorityVerificationStatus.getCause()); + return; + } + } + } // Get the stream ID for the new stream. int streamId; try { diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index 54d1641a7ed..86d8991ba95 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -106,6 +106,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final boolean useGetForSafeMethods; private final Ticker ticker; + NettyClientTransport( SocketAddress address, ChannelFactory channelFactory, diff --git a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java index ede511b68f6..3d3fdc67e8e 100644 --- a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java @@ -34,6 +34,6 @@ public static ChannelCredentials create(SslContext sslContext) { Preconditions.checkArgument(sslContext.isClient(), "Server SSL context can not be used for client channel"); GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); - return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext)); + return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext, null)); } } diff --git a/netty/src/main/java/io/grpc/netty/NoopSslEngine.java b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java new file mode 100644 index 00000000000..7e14dbf0e79 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java @@ -0,0 +1,151 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import java.nio.ByteBuffer; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; + +/** + * A no-op implementation of SslEngine, to facilitate overriding only the required methods in + * specific implementations. + */ +class NoopSslEngine extends SSLEngine { + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst) + throws SSLException { + return null; + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length) + throws SSLException { + return null; + } + + @Override + public Runnable getDelegatedTask() { + return null; + } + + @Override + public void closeInbound() throws SSLException { + + } + + @Override + public boolean isInboundDone() { + return false; + } + + @Override + public void closeOutbound() { + + } + + @Override + public boolean isOutboundDone() { + return false; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void beginHandshake() throws SSLException { + + } + + @Override + public SSLEngineResult.HandshakeStatus getHandshakeStatus() { + return null; + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java index 8a2c6f104b2..4332fdf2919 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java @@ -63,4 +63,5 @@ interface ServerFactory { */ ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool); } + } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 3f4d59bb334..2def192c36f 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; @@ -42,8 +41,10 @@ import io.grpc.Status; import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.NoopSslSession; import io.grpc.internal.ObjectPool; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; @@ -71,8 +72,11 @@ import java.net.SocketAddress; import java.net.URI; import java.nio.channels.ClosedChannelException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; import java.util.Arrays; import java.util.EnumSet; +import java.util.List; import java.util.Set; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -82,6 +86,9 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** @@ -95,7 +102,15 @@ final class ProtocolNegotiators { private static final EnumSet understoodServerTlsFeatures = EnumSet.of( TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); + private static Class x509ExtendedTrustManagerClass; + static { + try { + x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + } catch (ClassNotFoundException e) { + // Will disallow per-rpc authority override via call option. + } + } private ProtocolNegotiators() { } @@ -118,14 +133,32 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { new ByteArrayInputStream(tlsCreds.getPrivateKey()), tlsCreds.getPrivateKeyPassword()); } - if (tlsCreds.getTrustManagers() != null) { - builder.trustManager(new FixedTrustManagerFactory(tlsCreds.getTrustManagers())); - } else if (tlsCreds.getRootCertificates() != null) { - builder.trustManager(new ByteArrayInputStream(tlsCreds.getRootCertificates())); - } // else use system default try { - return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build())); - } catch (SSLException ex) { + List trustManagers; + if (tlsCreds.getTrustManagers() != null) { + trustManagers = tlsCreds.getTrustManagers(); + } else if (tlsCreds.getRootCertificates() != null) { + trustManagers = Arrays.asList(CertificateUtils.createTrustManager( + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + trustManagers = Arrays.asList(tmf.getTrustManagers()); + } + builder.trustManager(new FixedTrustManagerFactory(trustManagers)); + TrustManager x509ExtendedTrustManager = null; + if (x509ExtendedTrustManagerClass != null) { + for (TrustManager trustManager : trustManagers) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + x509ExtendedTrustManager = trustManager; + break; + } + } + } + return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), + (X509TrustManager) x509ExtendedTrustManager)); + } catch (SSLException | GeneralSecurityException ex) { log.log(Level.FINE, "Exception building SslContext", ex); return FromChannelCredentialsResult.error( "Unable to create SslContext: " + ex.getMessage()); @@ -411,8 +444,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { ServerTlsHandler(ChannelHandler next, SslContext sslContext, final ObjectPool executorPool) { - this.sslContext = checkNotNull(sslContext, "sslContext"); - this.next = checkNotNull(next, "next"); + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { this.executor = executorPool.getObject(); } @@ -469,8 +502,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, final @Nullable String proxyUsername, final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { - checkNotNull(negotiator, "negotiator"); - checkNotNull(proxyAddress, "proxyAddress"); + Preconditions.checkNotNull(negotiator, "negotiator"); + Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); class ProxyNegotiator implements ProtocolNegotiator { @Override @@ -516,7 +549,7 @@ public ProxyProtocolNegotiationHandler( ChannelHandler next, ChannelLogger negotiationLogger) { super(next, negotiationLogger); - this.address = checkNotNull(address, "address"); + this.address = Preconditions.checkNotNull(address, "address"); this.userName = userName; this.password = password; } @@ -545,18 +578,21 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable) { - this.sslContext = checkNotNull(sslContext, "sslContext"); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager) { + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { this.executor = this.executorPool.getObject(); } this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } private final SslContext sslContext; private final ObjectPool executorPool; private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509ExtendedTrustManager; private Executor executor; @Override @@ -569,7 +605,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable); + this.executor, negotiationLogger, handshakeCompleteRunnable, this, + x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -579,6 +616,11 @@ public void close() { this.executorPool.returnObject(this.executor); } } + + @VisibleForTesting + boolean hasX509ExtendedTrustManager() { + return x509ExtendedTrustManager != null; + } } static final class ClientTlsHandler extends ProtocolNegotiationHandler { @@ -588,23 +630,28 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final int port; private Executor executor; private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509ExtendedTrustManager; + private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, Executor executor, ChannelLogger negotiationLogger, - Optional handshakeCompleteRunnable) { + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); - this.sslContext = checkNotNull(sslContext, "sslContext"); + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); HostPort hostPort = parseAuthority(authority); this.host = hostPort.host; this.port = hostPort.port; this.executor = executor; this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -661,6 +708,8 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) Attributes attrs = existingPne.getAttributes().toBuilder() .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, new X509AuthorityVerifier( + sslEngine, x509ExtendedTrustManager)) .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); if (handshakeCompleteRunnable.isPresent()) { @@ -700,8 +749,10 @@ static HostPort parseAuthority(String authority) { * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable) { - return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager) { + return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, + x509ExtendedTrustManager); } /** @@ -709,25 +760,31 @@ public static ProtocolNegotiator tls(SslContext sslContext, * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null, Optional.absent()); + public static ProtocolNegotiator tls(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return tls(sslContext, null, Optional.absent(), + x509ExtendedTrustManager); } - public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { - return new TlsProtocolNegotiatorClientFactory(sslContext); + public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } @VisibleForTesting static final class TlsProtocolNegotiatorClientFactory implements ProtocolNegotiator.ClientFactory { private final SslContext sslContext; + private final X509TrustManager x509ExtendedTrustManager; - public TlsProtocolNegotiatorClientFactory(SslContext sslContext) { + public TlsProtocolNegotiatorClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } @Override public ProtocolNegotiator newNegotiator() { - return tls(sslContext); + return tls(sslContext, x509ExtendedTrustManager); } @Override public int getDefaultPort() { @@ -801,8 +858,8 @@ static final class Http2UpgradeAndGrpcHandler extends ChannelInboundHandlerAdapt private ProtocolNegotiationEvent pne; Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) { - this.authority = checkNotNull(authority, "authority"); - this.next = checkNotNull(next, "next"); + this.authority = Preconditions.checkNotNull(authority, "authority"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiationLogger = next.getNegotiationLogger(); } @@ -846,9 +903,9 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } /** - * Returns a {@link ChannelHandler} that ensures that the {@code handler} is added to the - * pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, even before it - * is active. + * Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is + * added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, + * even before it is active. */ public static ProtocolNegotiator plaintext() { return new PlaintextProtocolNegotiator(); @@ -926,7 +983,7 @@ static final class GrpcNegotiationHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler next; public GrpcNegotiationHandler(GrpcHttp2ConnectionHandler next) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); } @Override @@ -1048,15 +1105,15 @@ static class ProtocolNegotiationHandler extends ChannelDuplexHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = negotiatorName; - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = getClass().getSimpleName().replace("Handler", ""); - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } @Override @@ -1097,7 +1154,7 @@ protected final ProtocolNegotiationEvent getProtocolNegotiationEvent() { protected final void replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) { checkState(this.pne != null, "previous protocol negotiation event hasn't triggered"); - this.pne = checkNotNull(pne); + this.pne = Preconditions.checkNotNull(pne); } protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { @@ -1107,4 +1164,42 @@ protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { ctx.fireUserEventTriggered(pne); } } + + static final class SslEngineWrapper extends NoopSslEngine { + private final SSLEngine sslEngine; + private final String peerHost; + + SslEngineWrapper(SSLEngine sslEngine, String peerHost) { + this.sslEngine = sslEngine; + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + + @Override + public SSLSession getHandshakeSession() { + return new FakeSslSession(peerHost); + } + + @Override + public SSLParameters getSSLParameters() { + return sslEngine.getSSLParameters(); + } + } + + static final class FakeSslSession extends NoopSslSession { + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java new file mode 100644 index 00000000000..d080bb78733 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java @@ -0,0 +1,105 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.Status; +import io.grpc.internal.AuthorityVerifier; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.annotation.Nonnull; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.X509TrustManager; + +public class X509AuthorityVerifier implements AuthorityVerifier { + private final SSLEngine sslEngine; + private final X509TrustManager x509ExtendedTrustManager; + + private static final Method checkServerTrustedMethod; + + static { + Method method = null; + try { + Class x509ExtendedTrustManagerClass = + Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + method = x509ExtendedTrustManagerClass.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, SSLEngine.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority overriding via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. + } + checkServerTrustedMethod = method; + } + + public X509AuthorityVerifier(SSLEngine sslEngine, X509TrustManager x509ExtendedTrustManager) { + this.sslEngine = sslEngine; + this.x509ExtendedTrustManager = x509ExtendedTrustManager; + } + + @Override + public Status verifyAuthority(@Nonnull String authority) { + // sslEngine won't be set when creating ClientTlsHandler from InternalProtocolNegotiators + // for example. + if (sslEngine == null || x509ExtendedTrustManager == null) { + return Status.FAILED_PRECONDITION.withDescription( + "Can't allow authority override in rpc when SslEngine or X509ExtendedTrustManager" + + " is not available"); + } + Status peerVerificationStatus; + try { + // Because the authority pseudo-header can contain a port number: + // https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.3com + verifyAuthorityAllowedForPeerCert(removeAnyPortNumber(authority)); + peerVerificationStatus = Status.OK; + } catch (SSLPeerUnverifiedException | CertificateException | InvocationTargetException + | IllegalAccessException | IllegalStateException e) { + peerVerificationStatus = Status.UNAVAILABLE.withDescription( + String.format("Peer hostname verification during rpc failed for authority '%s'", + authority)).withCause(e); + } + return peerVerificationStatus; + } + + private String removeAnyPortNumber(String authority) { + int closingSquareBracketIndex = authority.lastIndexOf(']'); + int portNumberSeperatorColonIndex = authority.lastIndexOf(':'); + if (portNumberSeperatorColonIndex > closingSquareBracketIndex) { + return authority.substring(0, portNumberSeperatorColonIndex); + } + return authority; + } + + private void verifyAuthorityAllowedForPeerCert(String authority) + throws SSLPeerUnverifiedException, CertificateException, InvocationTargetException, + IllegalAccessException { + SSLEngine sslEngineWrapper = new ProtocolNegotiators.SslEngineWrapper(sslEngine, authority); + // The typecasting of Certificate to X509Certificate should work because this method will only + // be called when using TLS and thus X509. + Certificate[] peerCertificates = sslEngine.getSession().getPeerCertificates(); + X509Certificate[] x509PeerCertificates = new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + checkServerTrustedMethod.invoke( + x509ExtendedTrustManager, x509PeerCertificates, "RSA", sslEngineWrapper); + } +} diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index d0a6456c430..1e52cd1f592 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -59,6 +59,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusException; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; @@ -76,6 +77,7 @@ import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; @@ -101,9 +103,14 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -115,8 +122,15 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -131,6 +145,7 @@ * Tests for {@link NettyClientTransport}. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class NettyClientTransportTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -203,7 +218,7 @@ public void addDefaultUserAgent() throws Exception { } @Test - public void setSoLingerChannelOption() throws IOException { + public void setSoLingerChannelOption() throws IOException, GeneralSecurityException { startServer(); Map, Object> channelOptions = new HashMap<>(); // set SO_LINGER option @@ -354,7 +369,7 @@ public void tlsNegotiationFailurePropagatesToStatus() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, null); final NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); verify(clientTransportListener, timeout(5000)).transportTerminated(); @@ -821,7 +836,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent()); + Optional.absent(), null); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); @@ -836,6 +851,179 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { assertEquals(false, serverExecutorPool.isInUse()); } + /** + * This test tests the case of TlsCredentials passed to ProtocolNegotiators not having an instance + * of X509ExtendedTrustManager (this is not testable in ProtocolNegotiatorsTest without creating + * accessors for the internal state of negotiator whether it has a X509ExtendedTrustManager, + * hence the need to test it in this class instead). To establish a successful handshake we create + * a fake X509TrustManager not implementing X509ExtendedTrustManager but wraps the real + * X509ExtendedTrustManager. + */ + @Test + public void authorityOverrideInCallOptions_noX509ExtendedTrustManager_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + InputStream caCert = TlsTesting.loadCert("ca.pem"); + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)).build()); + NettyClientTransport transport = newTransport(result.negotiator.newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Can't allow authority override in rpc " + + "when SslEngine or X509ExtendedTrustManager is not available"); + assertThat(status.getCode()).isEqualTo(Code.FAILED_PRECONDITION); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_doesntMatchServerPeerHost_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority 'foo.test.google.in'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative DNS name matching foo.test.google.in found."); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_matchesServerPeerHost_newStreamCreationSucceeds() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false;; + } + } + + // Without removing the port number part that {@link X509AuthorityVerifier} does, there will be a + // java.security.cert.CertificateException: Illegal given domain name: foo.test.google.fr:12345 + @Test + public void authorityOverrideInCallOptions_portNumberInAuthority_isStrippedForPeerVerification() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr:12345").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false;; + } + } + + @Test + public void authorityOverrideInCallOptions_portNumberAndIpv6_isStrippedForPeerVerification() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345") + .waitForResponse(); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority '[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + // Port number is removed by {@link X509AuthorityVerifier}. + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative names matching IP address 2001:db8:3333:4444:5555:6666:1.2.3.4 " + + "found"); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false;; + } + } + + @Test + public void authorityOverrideInCallOptions_notMatches_flagDisabled_createsStream() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.in").waitForResponse(); + } + private Throwable getRootCause(Throwable t) { if (t.getCause() == null) { return t; @@ -843,10 +1031,37 @@ private Throwable getRootCause(Throwable t) { return getRootCause(t.getCause()); } - private ProtocolNegotiator newNegotiator() throws IOException { + private ProtocolNegotiator newNegotiator() throws IOException, GeneralSecurityException { InputStream caCert = TlsTesting.loadCert("ca.pem"); SslContext clientContext = GrpcSslContexts.forClient().trustManager(caCert).build(); - return ProtocolNegotiators.tls(clientContext); + return ProtocolNegotiators.tls(clientContext, + (X509TrustManager) getX509ExtendedTrustManager(TlsTesting.loadCert("ca.pem"))); + } + + private static TrustManager getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) { + if (trustManager instanceof X509ExtendedTrustManager) { + return trustManager; + } + } + return null; } private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { @@ -965,13 +1180,20 @@ private static class Rpc { final TestClientStreamListener listener = new TestClientStreamListener(); Rpc(NettyClientTransport transport) { - this(transport, new Metadata()); + this(transport, new Metadata(), null); } Rpc(NettyClientTransport transport, Metadata headers) { + this(transport, headers, null); + } + + Rpc(NettyClientTransport transport, Metadata headers, String authorityOverride) { stream = transport.newStream( METHOD, headers, CallOptions.DEFAULT, new ClientStreamTracer[]{ new ClientStreamTracer() {} }); + if (authorityOverride != null) { + stream.setAuthority(authorityOverride); + } stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); @@ -1169,4 +1391,62 @@ public void log(ChannelLogLevel level, String message) {} @Override public void log(ChannelLogLevel level, String messageFormat, Object... args) {} } + + static class FakeClientTransportListener implements ManagedClientTransport.Listener { + private final SettableFuture connected; + + @GuardedBy("this") + private boolean isConnected = false; + + public FakeClientTransportListener(SettableFuture connected) { + this.connected = connected; + } + + @Override + public void transportShutdown(Status s) {} + + @Override + public void transportTerminated() {} + + @Override + public void transportReady() { + synchronized (this) { + isConnected = true; + } + connected.set(null); + } + + synchronized boolean isConnected() { + return isConnected; + } + + @Override + public void transportInUse(boolean inUse) {} + } + + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } } diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 6dff3de2b2a..4829bcc7419 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -112,10 +112,14 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayDeque; import java.util.Arrays; @@ -222,13 +226,52 @@ public ChannelCredentials withoutBearerTokens() { } @Test - public void fromClient_tls() { + public void fromClient_tls_trustManager() + throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { + KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); + certStore.load(null); + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + try (InputStream ca = TlsTesting.loadCert("ca.pem")) { + for (X509Certificate cert : CertificateUtils.getX509Certificates(ca)) { + certStore.setCertificateEntry(cert.getSubjectX500Principal().getName("RFC2253"), cert); + } + } + trustManagerFactory.init(certStore); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(trustManagerFactory.getTrustManagers()).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_CaCertsInputStream() throws IOException { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(TlsTesting.loadCert("ca.pem")).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_systemDefault() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(TlsChannelCredentials.create()); assertThat(result.error).isNull(); assertThat(result.callCredentials).isNull(); assertThat(result.negotiator) .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); } @Test @@ -877,7 +920,8 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger, Optional.absent()); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -915,7 +959,8 @@ public String applicationProtocol() { .applicationProtocolConfig(apn).build(); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger, Optional.absent()); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -939,7 +984,8 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger, Optional.absent()); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -967,7 +1013,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @Test public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", null, noopLogger, Optional.absent()); + "authority", null, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -979,6 +1026,12 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { .isEqualTo(Status.Code.UNAVAILABLE); } + private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { + return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( + TlsTesting.loadCert("ca.pem")).build(), + null, Optional.absent(), null); + } + @Test public void engineLog() { ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); @@ -1007,7 +1060,7 @@ public boolean isLoggable(LogRecord record) { public void tls_failsOnNullSslContext() { thrown.expect(NullPointerException.class); - Object unused = ProtocolNegotiators.tls(null); + Object unused = ProtocolNegotiators.tls(null, null); } @Test @@ -1230,7 +1283,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception { } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, - null, Optional.absent()); + null, Optional.absent(), null); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index f42cb9fb16d..7eaaa6fd763 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -81,8 +81,6 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; /** Convenience class for building channels with the OkHttp transport. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") @@ -705,32 +703,12 @@ static KeyManager[] createKeyManager(InputStream certChain, InputStream privateK static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { InputStream rootCertsStream = new ByteArrayInputStream(rootCerts); try { - return createTrustManager(rootCertsStream); + return io.grpc.internal.CertificateUtils.createTrustManager(rootCertsStream); } finally { GrpcUtil.closeQuietly(rootCertsStream); } } - static TrustManager[] createTrustManager(InputStream rootCerts) throws GeneralSecurityException { - KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); - try { - ks.load(null, null); - } catch (IOException ex) { - // Shouldn't really happen, as we're not loading any data. - throw new GeneralSecurityException(ex); - } - X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); - for (X509Certificate cert : certs) { - X500Principal principal = cert.getSubjectX500Principal(); - ks.setCertificateEntry(principal.getName("RFC2253"), cert); - } - - TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - trustManagerFactory.init(ks); - return trustManagerFactory.getTrustManagers(); - } - static Collection> getSupportedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 3670cd057c1..c86e80656e3 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -34,6 +34,7 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.TlsChannelCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.FakeClock; @@ -212,7 +213,7 @@ public void sslSocketFactoryFrom_tls_mtls() throws Exception { TrustManager[] trustManagers; try (InputStream ca = TlsTesting.loadCert("ca.pem")) { - trustManagers = OkHttpChannelBuilder.createTrustManager(ca); + trustManagers = CertificateUtils.createTrustManager(ca); } SSLContext serverContext = SSLContext.getInstance("TLS"); @@ -257,7 +258,7 @@ public void sslSocketFactoryFrom_tls_mtls_keyFile() throws Exception { InputStream ca = TlsTesting.loadCert("ca.pem")) { serverContext.init( OkHttpChannelBuilder.createKeyManager(server1Chain, server1Key), - OkHttpChannelBuilder.createTrustManager(ca), + CertificateUtils.createTrustManager(ca), null); } final SSLServerSocket serverListenSocket =