diff --git a/pom.xml b/pom.xml
index 7e2db780..9b4f26ac 100644
--- a/pom.xml
+++ b/pom.xml
@@ -130,6 +130,11 @@
scram-client
${scram-client.version}
+
+ com.ongres.scram
+ scram-common
+ ${scram-client.version}
+
io.projectreactor
reactor-core
diff --git a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java
index e3383c18..31431753 100644
--- a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java
+++ b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java
@@ -20,6 +20,7 @@
import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler;
import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler;
import io.r2dbc.postgresql.client.Client;
+import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.client.ConnectionSettings;
import io.r2dbc.postgresql.client.PostgresStartupParameterProvider;
import io.r2dbc.postgresql.client.StartupMessageFlow;
@@ -46,7 +47,7 @@ public Mono connect(SocketAddress endpoint, ConnectionSettings settings)
return this.upstreamFunction.connect(endpoint, settings)
.delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow
- .exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(),
+ .exchange(auth -> getAuthenticationHandler(auth, credentials, client.getContext()), client, this.configuration.getDatabase(), credentials.getUsername(),
getParameterProvider(this.configuration, settings)))
.handle(ExceptionFactory.INSTANCE::handleErrorResponse));
}
@@ -55,13 +56,13 @@ private static PostgresStartupParameterProvider getParameterProvider(PostgresqlC
return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings);
}
- protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) {
+ protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword, ConnectionContext context) {
if (PasswordAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername());
} else if (SASLAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
- return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername());
+ return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername(), context);
} else {
throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message));
}
diff --git a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java
index c38c6bac..0ca21ac9 100644
--- a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java
+++ b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java
@@ -3,7 +3,8 @@
import com.ongres.scram.client.ScramClient;
import com.ongres.scram.common.StringPreparation;
import com.ongres.scram.common.exception.ScramException;
-
+import com.ongres.scram.common.util.TlsServerEndpoint;
+import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
@@ -14,14 +15,26 @@
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.ByteBufferUtils;
import reactor.core.Exceptions;
+import reactor.util.Logger;
+import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLSession;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+
public class SASLAuthenticationHandler implements AuthenticationHandler {
+ private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);
+
private final CharSequence password;
private final String username;
+ private final ConnectionContext context;
+
private ScramClient scramClient;
/**
@@ -29,11 +42,13 @@ public class SASLAuthenticationHandler implements AuthenticationHandler {
*
* @param password the password to use for authentication
* @param username the username to use for authentication
+ * @param context the connection context
* @throws IllegalArgumentException if {@code password} or {@code user} is {@code null}
*/
- public SASLAuthenticationHandler(CharSequence password, String username) {
+ public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) {
this.password = Assert.requireNonNull(password, "password must not be null");
this.username = Assert.requireNonNull(username, "username must not be null");
+ this.context = Assert.requireNonNull(context, "context must not be null");
}
/**
@@ -67,14 +82,44 @@ public FrontendMessage handle(AuthenticationMessage message) {
}
private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
- this.scramClient = ScramClient.builder()
+
+ char[] password = new char[this.password.length()];
+ for (int i = 0; i < password.length; i++) {
+ password[i] = this.password.charAt(i);
+ }
+
+ ScramClient.FinalBuildStage builder = ScramClient.builder()
.advertisedMechanisms(message.getAuthenticationMechanisms())
- .username(username) // ignored by the server, use startup message
- .password(password.toString().toCharArray())
- .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION)
- .build();
+ .username(this.username) // ignored by the server, use startup message
+ .password(password)
+ .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);
+
+ SSLSession sslSession = this.context.getSslSession();
- return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName());
+ if (sslSession != null && sslSession.isValid()) {
+ builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
+ }
+
+ this.scramClient = builder.build();
+
+ return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName());
+ }
+
+ private static byte[] extractSslEndpoint(SSLSession sslSession) {
+ try {
+ Certificate[] certificates = sslSession.getPeerCertificates();
+ if (certificates != null && certificates.length > 0) {
+ Certificate peerCert = certificates[0]; // First certificate is the peer's certificate
+ if (peerCert instanceof X509Certificate) {
+ X509Certificate cert = (X509Certificate) peerCert;
+ return TlsServerEndpoint.getChannelBindingData(cert);
+
+ }
+ }
+ } catch (CertificateException | SSLException e) {
+ LOG.debug("Cannot extract X509Certificate from SSL session", e);
+ }
+ return new byte[0];
}
private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
diff --git a/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java b/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java
index 90796c33..c348d606 100644
--- a/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java
+++ b/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java
@@ -20,7 +20,9 @@
import reactor.util.Loggers;
import javax.annotation.Nullable;
+import javax.net.ssl.SSLSession;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
/**
* Value object capturing diagnostic connection context. Allows for log-message post-processing with {@link #getMessage(String) if the logger category for
@@ -50,6 +52,8 @@ public final class ConnectionContext {
private final String connectionIdPrefix;
+ private final Supplier sslSession;
+
/**
* Create a new {@link ConnectionContext} with a unique connection Id.
*/
@@ -58,13 +62,15 @@ public ConnectionContext() {
this.connectionCounter = incrementConnectionCounter();
this.connectionIdPrefix = getConnectionIdPrefix();
this.channelId = null;
+ this.sslSession = () -> null;
}
- private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter) {
+ private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter, Supplier sslSession) {
this.processId = processId;
this.channelId = channelId;
this.connectionCounter = connectionCounter;
this.connectionIdPrefix = getConnectionIdPrefix();
+ this.sslSession = sslSession;
}
private String incrementConnectionCounter() {
@@ -101,6 +107,11 @@ public String getMessage(String original) {
return original;
}
+ @Nullable
+ public SSLSession getSslSession() {
+ return this.sslSession.get();
+ }
+
/**
* Create a new {@link ConnectionContext} by associating the {@code channelId}.
*
@@ -108,7 +119,17 @@ public String getMessage(String original) {
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code channelId}.
*/
public ConnectionContext withChannelId(String channelId) {
- return new ConnectionContext(this.processId, channelId, this.connectionCounter);
+ return new ConnectionContext(this.processId, channelId, this.connectionCounter, this.sslSession);
+ }
+
+ /**
+ * Create a new {@link ConnectionContext} by associating the {@code sslSession}.
+ *
+ * @param sslSession the SSL session supplier.
+ * @return a new {@link ConnectionContext} with all previously set values and the associated {@code sslSession}.
+ */
+ public ConnectionContext withSslSession(Supplier sslSession) {
+ return new ConnectionContext(this.processId, this.channelId, this.connectionCounter, sslSession);
}
/**
@@ -118,7 +139,7 @@ public ConnectionContext withChannelId(String channelId) {
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code processId}.
*/
public ConnectionContext withProcessId(int processId) {
- return new ConnectionContext(processId, this.channelId, this.connectionCounter);
+ return new ConnectionContext(processId, this.channelId, this.connectionCounter, this.sslSession);
}
}
diff --git a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java
index 1effe894..8ff9cf30 100644
--- a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java
+++ b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java
@@ -25,6 +25,7 @@
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.ssl.SslHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@@ -148,7 +149,23 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
this.connection = connection;
this.byteBufAllocator = connection.outbound().alloc();
- this.context = new ConnectionContext().withChannelId(connection.channel().toString());
+
+ ConnectionContext connectionContext = new ConnectionContext().withChannelId(connection.channel().toString());
+ SslHandler sslHandler = this.connection.channel().pipeline().get(SslHandler.class);
+
+ if (sslHandler == null) {
+ SSLSessionHandlerAdapter handlerAdapter = this.connection.channel().pipeline().get(SSLSessionHandlerAdapter.class);
+ if (handlerAdapter != null) {
+ sslHandler = handlerAdapter.getSslHandler();
+ }
+ }
+
+ if (sslHandler != null) {
+ SslHandler toUse = sslHandler;
+ connectionContext = connectionContext.withSslSession(() -> toUse.engine().getSession());
+ }
+
+ this.context = connectionContext;
AtomicReference receiveError = new AtomicReference<>();
diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java
index 9247cd62..616f2da2 100644
--- a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java
+++ b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java
@@ -45,7 +45,7 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
- if (negotiating) {
+ if (this.negotiating) {
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
}
super.channelActive(ctx);
@@ -53,7 +53,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
- if (negotiating) {
+ if (this.negotiating) {
// If we receive channel inactive before negotiated, then the inbound has closed early.
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
completeHandshakeExceptionally(e);
@@ -63,7 +63,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
- if (negotiating) {
+ if (this.negotiating) {
ByteBuf buf = (ByteBuf) msg;
char response = (char) buf.readByte();
try {
@@ -79,7 +79,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
} finally {
buf.release();
- negotiating = false;
+ this.negotiating = false;
}
} else {
super.channelRead(ctx, msg);
diff --git a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java
index ecb94ef0..c03aba09 100644
--- a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java
+++ b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java
@@ -451,6 +451,21 @@ void exchangeSslWithClientCertNoCert() {
.expectError(R2dbcPermissionDeniedException.class));
}
+ @Test
+ void exchangeSslWitScram() {
+ client(
+ c -> c
+ .sslRootCert(SERVER.getServerCrt())
+ .username("test-ssl-scram")
+ .password("test-ssl-scram"),
+ c -> c.map(client -> client.createStatement("SELECT 10")
+ .execute()
+ .flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
+ .as(StepVerifier::create)
+ .expectNext(10)
+ .verifyComplete()));
+ }
+
@Test
void exchangeSslWithPassword() {
client(
diff --git a/src/test/resources/pg_hba.conf b/src/test/resources/pg_hba.conf
index 6acee8b0..6330483e 100644
--- a/src/test/resources/pg_hba.conf
+++ b/src/test/resources/pg_hba.conf
@@ -1,5 +1,6 @@
hostnossl all test all md5
hostnossl all test-scram all scram-sha-256
hostssl all test-ssl all password
+hostssl all test-ssl-scram all scram-sha-256
hostssl all test-ssl-with-cert all cert
local all all md5