From c4544fb78619c31338085e7f7d11c8dc643235dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simen=20R=C3=B8kaas?= Date: Mon, 5 Jun 2023 13:35:58 +0200 Subject: [PATCH 1/9] Revert "Ensure thatPreferredCursorExecution is the only configuration flag to control cursor preference." This reverts commit 07cb863d257909a7ea961c00c5a281f174200daf. --- .../io/r2dbc/mssql/ConnectionOptions.java | 5 + .../mssql/MssqlConnectionConfiguration.java | 54 ++--- .../io/r2dbc/mssql/SimpleMssqlStatement.java | 25 +- .../io/r2dbc/mssql/MssqlBatchUnitTests.java | 6 +- ...MssqlConnectionConfigurationUnitTests.java | 224 ++++++++---------- .../r2dbc/mssql/MssqlConnectionUnitTests.java | 2 +- .../mssql/SimpleMssqlStatementUnitTests.java | 30 ++- .../io/r2dbc/mssql/TestConnectionOptions.java | 29 --- 8 files changed, 180 insertions(+), 195 deletions(-) delete mode 100644 src/test/java/io/r2dbc/mssql/TestConnectionOptions.java diff --git a/src/main/java/io/r2dbc/mssql/ConnectionOptions.java b/src/main/java/io/r2dbc/mssql/ConnectionOptions.java index a051a874..468436d5 100644 --- a/src/main/java/io/r2dbc/mssql/ConnectionOptions.java +++ b/src/main/java/io/r2dbc/mssql/ConnectionOptions.java @@ -17,6 +17,7 @@ package io.r2dbc.mssql; import io.r2dbc.mssql.codec.Codecs; +import io.r2dbc.mssql.codec.DefaultCodecs; import reactor.util.annotation.Nullable; import java.time.Duration; @@ -37,6 +38,10 @@ class ConnectionOptions { private volatile Duration statementTimeout = Duration.ZERO; + ConnectionOptions() { + this(sql -> false, new DefaultCodecs(), new IndefinitePreparedStatementCache(), true); + } + ConnectionOptions(Predicate preferCursoredExecution, Codecs codecs, PreparedStatementCache preparedStatementCache, boolean sendStringParametersAsUnicode) { this.preferCursoredExecution = preferCursoredExecution; this.codecs = codecs; diff --git a/src/main/java/io/r2dbc/mssql/MssqlConnectionConfiguration.java b/src/main/java/io/r2dbc/mssql/MssqlConnectionConfiguration.java index 95e39233..59852c61 100644 --- a/src/main/java/io/r2dbc/mssql/MssqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/mssql/MssqlConnectionConfiguration.java @@ -45,7 +45,6 @@ import java.security.KeyStore; import java.time.Duration; import java.util.Arrays; -import java.util.Locale; import java.util.Optional; import java.util.UUID; import java.util.function.Function; @@ -183,14 +182,14 @@ MssqlConnectionConfiguration withRedirect(Redirect redirect) { } return new MssqlConnectionConfiguration(this.applicationName, this.connectionId, this.connectTimeout, this.database, redirectServerName, hostNameInCertificate, this.lockWaitTimeout, - this.password, - this.preferCursoredExecution, redirect.getPort(), this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, - this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, this.tcpNoDelay, this.trustServerCertificate, this.trustStore, this.trustStoreType, this.trustStorePassword, this.username); + this.password, + this.preferCursoredExecution, redirect.getPort(), this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, + this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, this.tcpNoDelay, this.trustServerCertificate, this.trustStore, this.trustStoreType, this.trustStorePassword, this.username); } ClientConfiguration toClientConfiguration() { return new DefaultClientConfiguration(this.connectTimeout, this.host, this.hostNameInCertificate, this.port, this.ssl, this.sslContextBuilderCustomizer, - this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, this.tcpNoDelay, this.trustServerCertificate, this.trustStore, this.trustStoreType, this.trustStorePassword); + this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, this.tcpNoDelay, this.trustServerCertificate, this.trustStore, this.trustStoreType, this.trustStorePassword); } ConnectionOptions toConnectionOptions() { @@ -355,7 +354,7 @@ public static final class Builder { @Nullable private Duration lockWaitTimeout; - private Predicate preferCursoredExecution = DefaultCursorPreference.INSTANCE; + private Predicate preferCursoredExecution = sql -> false; private CharSequence password; @@ -715,11 +714,11 @@ public MssqlConnectionConfiguration build() { } return new MssqlConnectionConfiguration(this.applicationName, this.connectionId, this.connectTimeout, this.database, this.host, this.hostNameInCertificate, this.lockWaitTimeout, - this.password, - this.preferCursoredExecution, this.port, this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, - this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, - this.tcpNoDelay, this.trustServerCertificate, this.trustStore, - this.trustStoreType, this.trustStorePassword, this.username); + this.password, + this.preferCursoredExecution, this.port, this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, + this.sslTunnelSslContextBuilderCustomizer, this.tcpKeepAlive, + this.tcpNoDelay, this.trustServerCertificate, this.trustStore, + this.trustStoreType, this.trustStorePassword, this.username); } } @@ -888,35 +887,12 @@ public SslContext getSslContext() throws GeneralSecurityException { private static SslContextBuilder createSslContextBuilder() { SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); sslContextBuilder.sslProvider( - OpenSsl.isAvailable() ? - io.netty.handler.ssl.SslProvider.OPENSSL : - io.netty.handler.ssl.SslProvider.JDK) - .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) - .applicationProtocolConfig(null); + OpenSsl.isAvailable() ? + io.netty.handler.ssl.SslProvider.OPENSSL : + io.netty.handler.ssl.SslProvider.JDK) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .applicationProtocolConfig(null); return sslContextBuilder; } - - static class DefaultCursorPreference implements Predicate { - - static final DefaultCursorPreference INSTANCE = new DefaultCursorPreference(); - - @Override - public boolean test(String sql) { - - if (sql.isEmpty()) { - return false; - } - - String lc = sql.trim().toLowerCase(Locale.ENGLISH); - if (lc.contains("for xml") || lc.contains("for json")) { - return false; - } - - char c = sql.charAt(0); - - return (c == 's' || c == 'S') && lc.startsWith("select"); - } - } - } diff --git a/src/main/java/io/r2dbc/mssql/SimpleMssqlStatement.java b/src/main/java/io/r2dbc/mssql/SimpleMssqlStatement.java index c210012e..902a4741 100644 --- a/src/main/java/io/r2dbc/mssql/SimpleMssqlStatement.java +++ b/src/main/java/io/r2dbc/mssql/SimpleMssqlStatement.java @@ -29,6 +29,7 @@ import reactor.util.Logger; import reactor.util.Loggers; +import java.util.Locale; import java.util.function.Predicate; /** @@ -60,7 +61,7 @@ final class SimpleMssqlStatement extends MssqlStatementSupport implements MssqlS */ SimpleMssqlStatement(Client client, ConnectionOptions connectionOptions, String sql) { - super(connectionOptions.prefersCursors(sql)); + super(connectionOptions.prefersCursors(sql) || prefersCursors(sql)); this.connectionOptions = connectionOptions; Assert.requireNonNull(client, "Client must not be null"); @@ -161,4 +162,26 @@ public SimpleMssqlStatement fetchSize(int fetchSize) { return this; } + /** + * Returns {@code true} if the query is supported by this {@link MssqlStatement}. Cursored execution is supported for {@literal SELECT} queries. + * + * @param sql the query to inspect. + * @return {@code true} if the {@code sql} query is supported. + */ + static boolean prefersCursors(String sql) { + + if (sql.isEmpty()) { + return false; + } + + String lc = sql.trim().toLowerCase(Locale.ENGLISH); + if (lc.contains("for xml") || lc.contains("for json")) { + return false; + } + + char c = sql.charAt(0); + + return (c == 's' || c == 'S') && lc.startsWith("select"); + } + } diff --git a/src/test/java/io/r2dbc/mssql/MssqlBatchUnitTests.java b/src/test/java/io/r2dbc/mssql/MssqlBatchUnitTests.java index 87198754..d99b53fb 100644 --- a/src/test/java/io/r2dbc/mssql/MssqlBatchUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/MssqlBatchUnitTests.java @@ -39,7 +39,7 @@ void shouldExecuteSingleBatch() { .thenRespond(DoneToken.create(1)) .build(); - new MssqlBatch(client, new TestConnectionOptions()) + new MssqlBatch(client, new ConnectionOptions()) .add("foo") .execute() .as(StepVerifier::create) @@ -55,7 +55,7 @@ void shouldExecuteMultiBatch() { .thenRespond(DoneToken.create(1), DoneToken.create(1)) .build(); - new MssqlBatch(client, new TestConnectionOptions()) + new MssqlBatch(client, new ConnectionOptions()) .add("foo") .add("bar") .execute() @@ -73,7 +73,7 @@ void shouldFailOnExecution() { "proc", 0)) .build(); - new MssqlBatch(client, new TestConnectionOptions()) + new MssqlBatch(client, new ConnectionOptions()) .add("foo") .execute() .as(StepVerifier::create) diff --git a/src/test/java/io/r2dbc/mssql/MssqlConnectionConfigurationUnitTests.java b/src/test/java/io/r2dbc/mssql/MssqlConnectionConfigurationUnitTests.java index 6b7bd2d3..952dce68 100644 --- a/src/test/java/io/r2dbc/mssql/MssqlConnectionConfigurationUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/MssqlConnectionConfigurationUnitTests.java @@ -19,8 +19,6 @@ import io.r2dbc.mssql.message.tds.Redirect; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.testcontainers.shaded.org.bouncycastle.asn1.x500.X500Name; import org.testcontainers.shaded.org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; import org.testcontainers.shaded.org.bouncycastle.cert.X509CertificateHolder; @@ -55,31 +53,31 @@ final class MssqlConnectionConfigurationUnitTests { @Test void builderNoApplicationName() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder().applicationName(null)) - .withMessage("applicationName must not be null"); + .withMessage("applicationName must not be null"); } @Test void builderNoConnectionId() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder().connectionId(null)) - .withMessage("connectionId must not be null"); + .withMessage("connectionId must not be null"); } @Test void builderNoHost() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder().host(null)) - .withMessage("host must not be null"); + .withMessage("host must not be null"); } @Test void builderNoPassword() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder().password(null)) - .withMessage("password must not be null"); + .withMessage("password must not be null"); } @Test void builderNoUsername() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder().username(null)) - .withMessage("username must not be null"); + .withMessage("username must not be null"); } @Test @@ -87,150 +85,150 @@ void configuration() { UUID connectionId = UUID.randomUUID(); Predicate TRUE = s -> true; MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .connectionId(connectionId) - .database("test-database") - .host("test-host") - .password("test-password") - .preferCursoredExecution(TRUE) - .port(100) - .username("test-username") - .sendStringParametersAsUnicode(false) - .build(); + .connectionId(connectionId) + .database("test-database") + .host("test-host") + .password("test-password") + .preferCursoredExecution(TRUE) + .port(100) + .username("test-username") + .sendStringParametersAsUnicode(false) + .build(); assertThat(configuration) - .hasFieldOrPropertyWithValue("connectionId", connectionId) - .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") - .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("preferCursoredExecution", TRUE) - .hasFieldOrPropertyWithValue("port", 100) - .hasFieldOrPropertyWithValue("username", "test-username") - .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", false); + .hasFieldOrPropertyWithValue("connectionId", connectionId) + .hasFieldOrPropertyWithValue("database", "test-database") + .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("password", "test-password") + .hasFieldOrPropertyWithValue("preferCursoredExecution", TRUE) + .hasFieldOrPropertyWithValue("port", 100) + .hasFieldOrPropertyWithValue("username", "test-username") + .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", false); } @Test void configurationDefaults() { MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .applicationName("r2dbc") - .database("test-database") - .host("test-host") - .password("test-password") - .username("test-username") - .build(); + .applicationName("r2dbc") + .database("test-database") + .host("test-host") + .password("test-password") + .username("test-username") + .build(); assertThat(configuration) - .hasFieldOrPropertyWithValue("applicationName", "r2dbc") - .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") - .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 1433) - .hasFieldOrPropertyWithValue("username", "test-username") - .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true); + .hasFieldOrPropertyWithValue("applicationName", "r2dbc") + .hasFieldOrPropertyWithValue("database", "test-database") + .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("password", "test-password") + .hasFieldOrPropertyWithValue("port", 1433) + .hasFieldOrPropertyWithValue("username", "test-username") + .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true); } @Test void constructorNoNoHost() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder() - .password("test-password") - .username("test-username") - .build()) - .withMessage("host must not be null"); + .password("test-password") + .username("test-username") + .build()) + .withMessage("host must not be null"); } @Test void constructorNoPassword() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder() - .host("test-host") - .username("test-username") - .build()) - .withMessage("password must not be null"); + .host("test-host") + .username("test-username") + .build()) + .withMessage("password must not be null"); } @Test void constructorNoUsername() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder() - .host("test-host") - .password("test-password") - .build()) - .withMessage("username must not be null"); + .host("test-host") + .password("test-password") + .build()) + .withMessage("username must not be null"); } @Test void constructorNoSslCustomizer() { assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder() - .sslContextBuilderCustomizer(null) - .build()) - .withMessage("sslContextBuilderCustomizer must not be null"); + .sslContextBuilderCustomizer(null) + .build()) + .withMessage("sslContextBuilderCustomizer must not be null"); } @Test void redirect() { MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .applicationName("r2dbc") - .database("test-database") - .host("test-host") - .password("test-password") - .username("test-username") - .build(); + .applicationName("r2dbc") + .database("test-database") + .host("test-host") + .password("test-password") + .username("test-username") + .build(); MssqlConnectionConfiguration target = configuration.withRedirect(Redirect.create("target", 1234)); assertThat(target) - .hasFieldOrPropertyWithValue("applicationName", "r2dbc") - .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "target") - .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 1234) - .hasFieldOrPropertyWithValue("username", "test-username") - .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) - .hasFieldOrPropertyWithValue("hostNameInCertificate", "test-host"); + .hasFieldOrPropertyWithValue("applicationName", "r2dbc") + .hasFieldOrPropertyWithValue("database", "test-database") + .hasFieldOrPropertyWithValue("host", "target") + .hasFieldOrPropertyWithValue("password", "test-password") + .hasFieldOrPropertyWithValue("port", 1234) + .hasFieldOrPropertyWithValue("username", "test-username") + .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) + .hasFieldOrPropertyWithValue("hostNameInCertificate", "test-host"); } @Test void redirectOtherDomain() { MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .applicationName("r2dbc") - .database("test-database") - .host("test-host.windows.net") - .password("test-password") - .username("test-username") - .build(); + .applicationName("r2dbc") + .database("test-database") + .host("test-host.windows.net") + .password("test-password") + .username("test-username") + .build(); MssqlConnectionConfiguration target = configuration.withRedirect(Redirect.create("target.other.domain", 1234)); assertThat(target) - .hasFieldOrPropertyWithValue("applicationName", "r2dbc") - .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "target.other.domain") - .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 1234) - .hasFieldOrPropertyWithValue("username", "test-username") - .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) - .hasFieldOrPropertyWithValue("hostNameInCertificate", "test-host.windows.net"); + .hasFieldOrPropertyWithValue("applicationName", "r2dbc") + .hasFieldOrPropertyWithValue("database", "test-database") + .hasFieldOrPropertyWithValue("host", "target.other.domain") + .hasFieldOrPropertyWithValue("password", "test-password") + .hasFieldOrPropertyWithValue("port", 1234) + .hasFieldOrPropertyWithValue("username", "test-username") + .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) + .hasFieldOrPropertyWithValue("hostNameInCertificate", "test-host.windows.net"); } @Test void redirectInDomain() { MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .applicationName("r2dbc") - .database("test-database") - .host("test-host.windows.net") - .password("test-password") - .username("test-username") - .hostNameInCertificate("*.windows.net") - .build(); + .applicationName("r2dbc") + .database("test-database") + .host("test-host.windows.net") + .password("test-password") + .username("test-username") + .hostNameInCertificate("*.windows.net") + .build(); MssqlConnectionConfiguration target = configuration.withRedirect(Redirect.create("worker.target.windows.net", 1234)); assertThat(target) - .hasFieldOrPropertyWithValue("applicationName", "r2dbc") - .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "worker.target.windows.net") - .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 1234) - .hasFieldOrPropertyWithValue("username", "test-username") - .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) - .hasFieldOrPropertyWithValue("hostNameInCertificate", "*.target.windows.net"); + .hasFieldOrPropertyWithValue("applicationName", "r2dbc") + .hasFieldOrPropertyWithValue("database", "test-database") + .hasFieldOrPropertyWithValue("host", "worker.target.windows.net") + .hasFieldOrPropertyWithValue("password", "test-password") + .hasFieldOrPropertyWithValue("port", 1234) + .hasFieldOrPropertyWithValue("username", "test-username") + .hasFieldOrPropertyWithValue("sendStringParametersAsUnicode", true) + .hasFieldOrPropertyWithValue("hostNameInCertificate", "*.target.windows.net"); } @Test @@ -246,7 +244,7 @@ void configureKeyStore(@TempDir File tempDir) throws Exception { Certificate selfSignedCertificate = selfSign(keypair, "CN=dummy"); KeyStore.Entry entry = new KeyStore.PrivateKeyEntry(keypair.getPrivate(), - new Certificate[]{selfSignedCertificate}); + new Certificate[]{selfSignedCertificate}); keyStore.setEntry("dummy", entry, new KeyStore.PasswordProtection("key-password".toCharArray())); @@ -256,13 +254,13 @@ void configureKeyStore(@TempDir File tempDir) throws Exception { } MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder() - .database("test-database") - .host("test-host.windows.net") - .password("test-password") - .username("test-username") - .trustStore(file) - .trustStorePassword("my-password".toCharArray()) - .build(); + .database("test-database") + .host("test-host.windows.net") + .password("test-password") + .username("test-username") + .trustStore(file) + .trustStorePassword("my-password".toCharArray()) + .build(); MssqlConnectionConfiguration.DefaultClientConfiguration clientConfiguration = (MssqlConnectionConfiguration.DefaultClientConfiguration) configuration.toClientConfiguration(); @@ -273,7 +271,7 @@ void configureKeyStore(@TempDir File tempDir) throws Exception { } private static Certificate selfSign(KeyPair keyPair, String subjectDN) - throws Exception { + throws Exception { Date startDate = new Date(); X500Name dnName = new X500Name(subjectDN); @@ -285,28 +283,16 @@ private static Certificate selfSign(KeyPair keyPair, String subjectDN) SubjectPublicKeyInfo subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair - .getPublic().getEncoded()); + .getPublic().getEncoded()); X509v3CertificateBuilder certificateBuilder = new X509v3CertificateBuilder(dnName, - BigInteger.valueOf(1), startDate, endDate, dnName, subjectPublicKeyInfo); + BigInteger.valueOf(1), startDate, endDate, dnName, subjectPublicKeyInfo); ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate()); X509CertificateHolder certificateHolder = certificateBuilder.build(contentSigner); return new JcaX509CertificateConverter() - .getCertificate(certificateHolder); - } - - @ParameterizedTest - @ValueSource(strings = {"select", "SELECT", "sElEcT"}) - void shouldAcceptQueries(String query) { - assertThat(MssqlConnectionConfiguration.DefaultCursorPreference.INSTANCE).accepts(query); - } - - @ParameterizedTest - @ValueSource(strings = {" select", "sp_cursor", "INSERT"}) - void shouldRejectQueries(String query) { - assertThat(MssqlConnectionConfiguration.DefaultCursorPreference.INSTANCE).rejects(query); + .getCertificate(certificateHolder); } } diff --git a/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java b/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java index 1991231c..b6f8b2c9 100644 --- a/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java @@ -50,7 +50,7 @@ class MssqlConnectionUnitTests { static MssqlConnectionMetadata metadata = new MssqlConnectionMetadata("SQL Server", "1.0"); - static ConnectionOptions conectionOptions = new TestConnectionOptions(); + static ConnectionOptions conectionOptions = new ConnectionOptions(); @Test void shouldBeginTransactionFromInitialState() { diff --git a/src/test/java/io/r2dbc/mssql/SimpleMssqlStatementUnitTests.java b/src/test/java/io/r2dbc/mssql/SimpleMssqlStatementUnitTests.java index 68b620df..3d6fb04d 100644 --- a/src/test/java/io/r2dbc/mssql/SimpleMssqlStatementUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/SimpleMssqlStatementUnitTests.java @@ -24,7 +24,16 @@ import io.r2dbc.mssql.message.TransactionDescriptor; import io.r2dbc.mssql.message.tds.Encode; import io.r2dbc.mssql.message.tds.ServerCharset; -import io.r2dbc.mssql.message.token.*; +import io.r2dbc.mssql.message.token.Column; +import io.r2dbc.mssql.message.token.ColumnMetadataToken; +import io.r2dbc.mssql.message.token.DataToken; +import io.r2dbc.mssql.message.token.DoneToken; +import io.r2dbc.mssql.message.token.ErrorToken; +import io.r2dbc.mssql.message.token.RowToken; +import io.r2dbc.mssql.message.token.RowTokenFactory; +import io.r2dbc.mssql.message.token.RpcRequest; +import io.r2dbc.mssql.message.token.SqlBatch; +import io.r2dbc.mssql.message.token.Tabular; import io.r2dbc.mssql.message.type.Collation; import io.r2dbc.mssql.message.type.LengthStrategy; import io.r2dbc.mssql.message.type.SqlServerType; @@ -33,6 +42,8 @@ import io.r2dbc.spi.R2dbcNonTransientResourceException; import io.r2dbc.spi.Result; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -50,7 +61,9 @@ import static io.r2dbc.mssql.message.type.TypeInformation.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Unit tests for {@link SimpleMssqlStatement}. @@ -66,7 +79,7 @@ class SimpleMssqlStatementUnitTests { createColumn(3, "salary", SqlServerType.MONEY, 8, LengthStrategy.BYTELENTYPE, null)).toArray(new Column[0]); - static final ConnectionOptions OPTIONS = new TestConnectionOptions(); + static final ConnectionOptions OPTIONS = new ConnectionOptions(); @Test void shouldReportNumberOfAffectedRows() { @@ -339,4 +352,15 @@ private static Column createColumn(int index, String name, SqlServerType serverT return new Column(index, name, type, null); } + @ParameterizedTest + @ValueSource(strings = {"select", "SELECT", "sElEcT"}) + void shouldAcceptQueries(String query) { + assertThat(SimpleMssqlStatement.prefersCursors(query)).isTrue(); + } + + @ParameterizedTest + @ValueSource(strings = {" select", "sp_cursor", "INSERT"}) + void shouldRejectQueries(String query) { + assertThat(SimpleMssqlStatement.prefersCursors(query)).isFalse(); + } } diff --git a/src/test/java/io/r2dbc/mssql/TestConnectionOptions.java b/src/test/java/io/r2dbc/mssql/TestConnectionOptions.java deleted file mode 100644 index 914aa614..00000000 --- a/src/test/java/io/r2dbc/mssql/TestConnectionOptions.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.r2dbc.mssql; - -import io.r2dbc.mssql.codec.DefaultCodecs; - -/** - * @author Mark Paluch - */ -class TestConnectionOptions extends ConnectionOptions { - - TestConnectionOptions() { - super(MssqlConnectionConfiguration.DefaultCursorPreference.INSTANCE, new DefaultCodecs(), new IndefinitePreparedStatementCache(), true); - } -} From c127cf921b85d528d05437a070e0a90cd0e76f27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simen=20R=C3=B8kaas?= Date: Mon, 5 Jun 2023 15:07:56 +0200 Subject: [PATCH 2/9] Renamed version, cannot refer to SNAPSHOT version in artifact repository. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 386c64d6..0104584e 100644 --- a/pom.xml +++ b/pom.xml @@ -24,7 +24,7 @@ io.r2dbc r2dbc-mssql - 1.1.0.BUILD-SNAPSHOT + 1.0.2.DSB jar Reactive Relational Database Connectivity - Microsoft SQL Server From 617ea4808bda0a591cb5e8be31c925ee0dbf948b Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 23 Jun 2023 14:36:36 +0200 Subject: [PATCH 3/9] Capture encoder in Encoded instead of the plain value. We now accept Suppliers to encode a value multiple times. [resolves #272] Signed-off-by: Mark Paluch --- .../mssql/codec/BinaryCodecBenchmarks.java | 4 +- .../mssql/codec/BooleanCodecBenchmarks.java | 4 +- .../mssql/codec/ByteCodecBenchmarks.java | 4 +- .../mssql/codec/DecimalCodecBenchmarks.java | 4 +- .../mssql/codec/DoubleCodecBenchmarks.java | 4 +- .../mssql/codec/IntegerCodecBenchmarks.java | 4 +- .../mssql/codec/LocalDateCodecBenchmarks.java | 4 +- .../codec/LocalDateTimeCodecBenchmarks.java | 4 +- .../mssql/codec/LocalTimeCodecBenchmarks.java | 4 +- .../mssql/codec/LongCodecBenchmarks.java | 4 +- .../mssql/codec/ShortCodecBenchmarks.java | 4 +- .../mssql/codec/StringCodecBenchmarks.java | 4 +- .../mssql/codec/UuidCodecBenchmarks.java | 4 +- src/main/java/io/r2dbc/mssql/Binding.java | 4 +- .../mssql/ParametrizedMssqlStatement.java | 10 +- .../io/r2dbc/mssql/codec/BinaryCodec.java | 50 +++++--- .../java/io/r2dbc/mssql/codec/ByteArray.java | 2 +- .../r2dbc/mssql/codec/CharacterEncoder.java | 34 +++--- .../io/r2dbc/mssql/codec/DecimalCodec.java | 22 ++-- .../java/io/r2dbc/mssql/codec/Encoded.java | 110 +++++++++++++++--- .../io/r2dbc/mssql/codec/LocalDateCodec.java | 18 +-- .../mssql/codec/OffsetDateTimeCodec.java | 20 ++-- .../java/io/r2dbc/mssql/codec/PlpEncoded.java | 8 +- .../io/r2dbc/mssql/codec/RpcEncoding.java | 53 ++++++--- .../r2dbc/mssql/message/token/RpcRequest.java | 7 +- .../ParametrizedMssqlStatementUnitTests.java | 20 ++-- .../mssql/codec/PlpEncodedUnitTests.java | 8 +- .../mssql/util/BlockhoundExceptions.java | 32 ----- 28 files changed, 262 insertions(+), 188 deletions(-) delete mode 100644 src/test/java/io/r2dbc/mssql/util/BlockhoundExceptions.java diff --git a/src/jmh/java/io/r2dbc/mssql/codec/BinaryCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/BinaryCodecBenchmarks.java index 5b4ffcc0..07db370d 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/BinaryCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/BinaryCodecBenchmarks.java @@ -75,13 +75,13 @@ public Encoded encodeByteBuffer() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, byte[].class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/BooleanCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/BooleanCodecBenchmarks.java index 33f46602..21caac93 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/BooleanCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/BooleanCodecBenchmarks.java @@ -57,13 +57,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Boolean.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/ByteCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/ByteCodecBenchmarks.java index 6b4c31b3..ea9d4a67 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/ByteCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/ByteCodecBenchmarks.java @@ -57,13 +57,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Byte.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/DecimalCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/DecimalCodecBenchmarks.java index e86205a2..2fbef2fb 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/DecimalCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/DecimalCodecBenchmarks.java @@ -61,13 +61,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, BigDecimal.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/DoubleCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/DoubleCodecBenchmarks.java index dc2e77cc..c0a54b37 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/DoubleCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/DoubleCodecBenchmarks.java @@ -65,13 +65,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Double.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/IntegerCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/IntegerCodecBenchmarks.java index ee008173..c72606cf 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/IntegerCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/IntegerCodecBenchmarks.java @@ -57,13 +57,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Integer.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/LocalDateCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/LocalDateCodecBenchmarks.java index e1b888de..053d6069 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/LocalDateCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/LocalDateCodecBenchmarks.java @@ -61,13 +61,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, LocalDate.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/LocalDateTimeCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/LocalDateTimeCodecBenchmarks.java index 83a12f1e..db3f0ed3 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/LocalDateTimeCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/LocalDateTimeCodecBenchmarks.java @@ -84,13 +84,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, LocalDateTime.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/LocalTimeCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/LocalTimeCodecBenchmarks.java index e07222a6..cd020f96 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/LocalTimeCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/LocalTimeCodecBenchmarks.java @@ -61,13 +61,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, LocalTime.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/LongCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/LongCodecBenchmarks.java index 1328d085..56916a14 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/LongCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/LongCodecBenchmarks.java @@ -57,13 +57,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Long.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/ShortCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/ShortCodecBenchmarks.java index c6c74b19..b36d469e 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/ShortCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/ShortCodecBenchmarks.java @@ -57,13 +57,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, Short.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/StringCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/StringCodecBenchmarks.java index 69a7a9fb..ccdf0695 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/StringCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/StringCodecBenchmarks.java @@ -122,13 +122,13 @@ public String decodeUuid() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(TestByteBufAllocator.TEST, String.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Collation collation, Object value) { Encoded encoded = codecs.encode(TestByteBufAllocator.TEST, RpcParameterContext.in(ValueContext.character(collation, true)), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/jmh/java/io/r2dbc/mssql/codec/UuidCodecBenchmarks.java b/src/jmh/java/io/r2dbc/mssql/codec/UuidCodecBenchmarks.java index ff5a5f62..b1cdd5ba 100644 --- a/src/jmh/java/io/r2dbc/mssql/codec/UuidCodecBenchmarks.java +++ b/src/jmh/java/io/r2dbc/mssql/codec/UuidCodecBenchmarks.java @@ -61,13 +61,13 @@ public Encoded encode() { @Benchmark public Encoded encodeNull() { Encoded encoded = codecs.encodeNull(alloc, UUID.class); - encoded.release(); + encoded.dispose(); return encoded; } private Encoded doEncode(Object value) { Encoded encoded = codecs.encode(alloc, RpcParameterContext.in(), value); - encoded.release(); + encoded.dispose(); return encoded; } } diff --git a/src/main/java/io/r2dbc/mssql/Binding.java b/src/main/java/io/r2dbc/mssql/Binding.java index db6738ad..9cf8abe4 100644 --- a/src/main/java/io/r2dbc/mssql/Binding.java +++ b/src/main/java/io/r2dbc/mssql/Binding.java @@ -79,8 +79,8 @@ boolean hasOutParameters() { void clear() { this.parameters.forEach((s, parameter) -> { - while (parameter.encoded.refCnt() > 0) { - parameter.encoded.release(); + if (!parameter.encoded.isDisposed()) { + parameter.encoded.dispose(); } }); diff --git a/src/main/java/io/r2dbc/mssql/ParametrizedMssqlStatement.java b/src/main/java/io/r2dbc/mssql/ParametrizedMssqlStatement.java index 3c695906..2d15c26c 100644 --- a/src/main/java/io/r2dbc/mssql/ParametrizedMssqlStatement.java +++ b/src/main/java/io/r2dbc/mssql/ParametrizedMssqlStatement.java @@ -35,13 +35,7 @@ import reactor.util.Loggers; import reactor.util.annotation.Nullable; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Objects; +import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -245,7 +239,6 @@ public ParametrizedMssqlStatement bind(String identifier, Object value) { } Encoded encoded = this.codecs.encode(this.client.getByteBufAllocator(), parameterContext, value); - encoded.touch("ParametrizedMssqlStatement.bind(…)"); addBinding(getParameterName(identifier), isIn ? RpcDirection.IN : RpcDirection.OUT, encoded); @@ -272,7 +265,6 @@ public ParametrizedMssqlStatement bindNull(String identifier, Class type) { } Encoded encoded = this.codecs.encodeNull(this.client.getByteBufAllocator(), type); - encoded.touch("ParametrizedMssqlStatement.bindNull(…)"); addBinding(getParameterName(identifier), RpcDirection.IN, encoded); return this; } diff --git a/src/main/java/io/r2dbc/mssql/codec/BinaryCodec.java b/src/main/java/io/r2dbc/mssql/codec/BinaryCodec.java index 175850b7..84c69e4d 100644 --- a/src/main/java/io/r2dbc/mssql/codec/BinaryCodec.java +++ b/src/main/java/io/r2dbc/mssql/codec/BinaryCodec.java @@ -20,13 +20,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.r2dbc.mssql.message.tds.Encode; -import io.r2dbc.mssql.message.type.Length; -import io.r2dbc.mssql.message.type.LengthStrategy; -import io.r2dbc.mssql.message.type.PlpLength; -import io.r2dbc.mssql.message.type.SqlServerType; -import io.r2dbc.mssql.message.type.TdsDataType; -import io.r2dbc.mssql.message.type.TypeInformation; -import io.r2dbc.mssql.message.type.TypeUtils; +import io.r2dbc.mssql.message.type.*; import io.r2dbc.mssql.util.Assert; import io.r2dbc.spi.Blob; import reactor.core.publisher.Mono; @@ -35,6 +29,8 @@ import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Set; +import java.util.function.IntFunction; +import java.util.function.Supplier; /** * Codec for binary values that are represented as {@code byte[]} or {@link ByteBuffer}. @@ -64,7 +60,7 @@ class BinaryCodec implements Codec { }); private static final Set SUPPORTED_TYPES = EnumSet.of(SqlServerType.BINARY, SqlServerType.VARBINARY, - SqlServerType.VARBINARYMAX, SqlServerType.IMAGE); + SqlServerType.VARBINARYMAX, SqlServerType.IMAGE); private BinaryCodec() { } @@ -84,8 +80,7 @@ public Encoded encode(ByteBufAllocator allocator, RpcParameterContext context, O Assert.requireNonNull(context, "RpcParameterContext must not be null"); Assert.requireNonNull(value, "Value must not be null"); - ByteBuf buffer; - + int length; if (value instanceof byte[]) { byte[] bytes = (byte[]) value; @@ -93,9 +88,7 @@ public Encoded encode(ByteBufAllocator allocator, RpcParameterContext context, O if (exceedsBigVarbinary(bytes.length)) { return BlobCodec.INSTANCE.encode(allocator, context, Blob.from(Mono.just(ByteBuffer.wrap(bytes)))); } - - buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.BIGVARBINARY.getLengthStrategy(), SqlServerType.VARBINARY.getMaxLength(), bytes.length); - buffer.writeBytes(bytes); + length = bytes.length; } else { ByteBuffer bytes = (ByteBuffer) value; @@ -104,11 +97,30 @@ public Encoded encode(ByteBufAllocator allocator, RpcParameterContext context, O return BlobCodec.INSTANCE.encode(allocator, context, Blob.from(Mono.just(bytes))); } - buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.BIGVARBINARY.getLengthStrategy(), SqlServerType.VARBINARY.getMaxLength(), bytes.remaining()); - buffer.writeBytes(bytes); + length = bytes.remaining(); } - return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, buffer); + + IntFunction encoder = actualLength -> { + ByteBuf buffer; + + if (value instanceof byte[]) { + + byte[] bytes = (byte[]) value; + + buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.BIGVARBINARY.getLengthStrategy(), SqlServerType.VARBINARY.getMaxLength(), actualLength); + buffer.writeBytes(bytes); + } else { + + ByteBuffer bytes = (ByteBuffer) value; + + buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.BIGVARBINARY.getLengthStrategy(), SqlServerType.VARBINARY.getMaxLength(), actualLength); + buffer.writeBytes(bytes.asReadOnlyBuffer()); + } + return buffer; + }; + + return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, Encoded.ofLengthAware(length, encoder)); } @Override @@ -133,12 +145,12 @@ public Class getType() { @Override public Encoded encodeNull(ByteBufAllocator allocator) { - return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, Unpooled.wrappedBuffer(NULL)); + return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, () -> Unpooled.wrappedBuffer(NULL)); } @Override public Encoded encodeNull(ByteBufAllocator allocator, SqlServerType serverType) { - return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, Unpooled.wrappedBuffer(NULL)); + return new VarbinaryEncoded(TdsDataType.BIGVARBINARY, () -> Unpooled.wrappedBuffer(NULL)); } @Override @@ -206,7 +218,7 @@ static class VarbinaryEncoded extends RpcEncoding.HintedEncoded { private static final String FORMAL_TYPE = SqlServerType.VARBINARY + "(" + TypeUtils.SHORT_VARTYPE_MAX_BYTES + ")"; - VarbinaryEncoded(TdsDataType dataType, ByteBuf value) { + VarbinaryEncoded(TdsDataType dataType, Supplier value) { super(dataType, SqlServerType.VARBINARY, value); } diff --git a/src/main/java/io/r2dbc/mssql/codec/ByteArray.java b/src/main/java/io/r2dbc/mssql/codec/ByteArray.java index 4a436bd0..28a596de 100644 --- a/src/main/java/io/r2dbc/mssql/codec/ByteArray.java +++ b/src/main/java/io/r2dbc/mssql/codec/ByteArray.java @@ -45,7 +45,7 @@ static byte[] fromEncoded(Function encodeFunction) { try { return ByteBufUtil.getBytes(encoded.getValue()); } finally { - encoded.release(); + encoded.dispose(); } } diff --git a/src/main/java/io/r2dbc/mssql/codec/CharacterEncoder.java b/src/main/java/io/r2dbc/mssql/codec/CharacterEncoder.java index f45e30d3..952653d0 100644 --- a/src/main/java/io/r2dbc/mssql/codec/CharacterEncoder.java +++ b/src/main/java/io/r2dbc/mssql/codec/CharacterEncoder.java @@ -23,17 +23,15 @@ import io.r2dbc.mssql.codec.RpcParameterContext.CharacterValueContext; import io.r2dbc.mssql.message.tds.Encode; import io.r2dbc.mssql.message.tds.ServerCharset; -import io.r2dbc.mssql.message.type.Collation; -import io.r2dbc.mssql.message.type.Length; -import io.r2dbc.mssql.message.type.SqlServerType; -import io.r2dbc.mssql.message.type.TdsDataType; -import io.r2dbc.mssql.message.type.TypeUtils; +import io.r2dbc.mssql.message.type.*; import io.r2dbc.spi.Clob; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.annotation.Nullable; import java.nio.CharBuffer; +import java.util.function.Function; +import java.util.function.Supplier; import static io.r2dbc.mssql.message.type.SqlServerType.Category.NCHARACTER; @@ -63,10 +61,10 @@ class CharacterEncoder { static Encoded encodeNull(SqlServerType serverType) { if (isNational(serverType)) { - return new VarcharEncoded(TdsDataType.NVARCHAR, Unpooled.wrappedBuffer(NULL)); + return new VarcharEncoded(TdsDataType.NVARCHAR, () -> Unpooled.wrappedBuffer(NULL)); } - return new NvarcharEncoded(TdsDataType.NVARCHAR, Unpooled.wrappedBuffer(NULL)); + return new NvarcharEncoded(TdsDataType.NVARCHAR, () -> Unpooled.wrappedBuffer(NULL)); } /** @@ -77,16 +75,20 @@ static Encoded encodeNull(SqlServerType serverType) { static Encoded encodeBigVarchar(ByteBufAllocator allocator, RpcDirection direction, @Nullable SqlServerType serverType, Collation collation, boolean sendStringParametersAsUnicode, @Nullable CharSequence value) { - ByteBuf buffer = allocator.buffer((value != null ? value.length() * 2 : 0) + 7); + int initialCapacity = (value != null ? value.length() * 2 : 0) + 7; + Function encoder = unicode -> { - if (isNational(serverType) || sendStringParametersAsUnicode) { - encodeBigVarchar(buffer, direction, collation, true, value); - return new NvarcharEncoded(TdsDataType.NVARCHAR, buffer); + ByteBuf buffer = allocator.buffer(initialCapacity); + encodeBigVarchar(buffer, direction, collation, unicode, value); + return buffer; + }; + + if (isNational(serverType) || sendStringParametersAsUnicode) { + return new NvarcharEncoded(TdsDataType.NVARCHAR, () -> encoder.apply(true)); } - encodeBigVarchar(buffer, direction, collation, false, value); - return new VarcharEncoded(TdsDataType.BIGVARCHAR, buffer); + return new VarcharEncoded(TdsDataType.BIGVARCHAR, Encoded.ofLengthAware(initialCapacity, i -> encoder.apply(false))); } /** @@ -205,14 +207,14 @@ private static SqlServerType getPlpType(@Nullable SqlServerType serverType, Char private static ByteBuf encodeCharSequence(ByteBufAllocator allocator, boolean isNational, CharacterValueContext valueContext, CharSequence it) { return ByteBufUtil.encodeString(allocator, CharBuffer.wrap(it), isNational || valueContext.isSendStringParametersAsUnicode() ? ServerCharset.UNICODE.charset() : - valueContext.getCollation().getCharset()); + valueContext.getCollation().getCharset()); } private static class NvarcharEncoded extends RpcEncoding.HintedEncoded { private static final String FORMAL_TYPE = SqlServerType.NVARCHAR + "(" + (TypeUtils.SHORT_VARTYPE_MAX_BYTES / 2) + ")"; - NvarcharEncoded(TdsDataType dataType, ByteBuf value) { + NvarcharEncoded(TdsDataType dataType, Supplier value) { super(dataType, SqlServerType.NVARCHAR, value); } @@ -227,7 +229,7 @@ private static class VarcharEncoded extends RpcEncoding.HintedEncoded { private static final String FORMAL_TYPE = SqlServerType.VARCHAR + "(" + TypeUtils.SHORT_VARTYPE_MAX_BYTES + ")"; - VarcharEncoded(TdsDataType dataType, ByteBuf value) { + VarcharEncoded(TdsDataType dataType, Supplier value) { super(dataType, SqlServerType.NVARCHAR, value); } diff --git a/src/main/java/io/r2dbc/mssql/codec/DecimalCodec.java b/src/main/java/io/r2dbc/mssql/codec/DecimalCodec.java index b65fc2be..059773d3 100644 --- a/src/main/java/io/r2dbc/mssql/codec/DecimalCodec.java +++ b/src/main/java/io/r2dbc/mssql/codec/DecimalCodec.java @@ -27,6 +27,7 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.util.function.Supplier; /** * Codec for fixed floating-point values that are represented as {@link BigDecimal}. @@ -66,26 +67,31 @@ private DecimalCodec() { @Override Encoded doEncode(ByteBufAllocator allocator, RpcParameterContext context, BigDecimal value) { - BigDecimal valueToUse = value; + BigDecimal valueToUse; // Handle negative scale as a special case for Java 1.5 and later - if (valueToUse.scale() < 0) { - valueToUse = valueToUse.setScale(0); + if (value.scale() < 0) { + valueToUse = value.setScale(0); + } else { + valueToUse = value; } if (exceedsMaxPrecisionOrScale(valueToUse)) { throw new IllegalArgumentException("One or more values is out of range of values for the DECIMAL SQL type"); } - ByteBuf buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.DECIMALN.getLengthStrategy(), 0x11, SqlServerType.DECIMAL.getMaxLength()); + return new DecimalEncoded(TdsDataType.DECIMALN, () -> { - encodeBigDecimal(buffer, valueToUse); - return new DecimalEncoded(TdsDataType.DECIMALN, buffer, MAX_PRECISION, valueToUse.scale()); + ByteBuf buffer = RpcEncoding.prepareBuffer(allocator, TdsDataType.DECIMALN.getLengthStrategy(), 0x11, SqlServerType.DECIMAL.getMaxLength()); + + encodeBigDecimal(buffer, valueToUse); + return buffer; + }, MAX_PRECISION, valueToUse.scale()); } @Override Encoded doEncodeNull(ByteBufAllocator allocator) { - return new DecimalEncoded(TdsDataType.DECIMALN, Unpooled.wrappedBuffer(NULL), MAX_PRECISION, 0); + return new DecimalEncoded(TdsDataType.DECIMALN, () -> Unpooled.wrappedBuffer(NULL), MAX_PRECISION, 0); } @Override @@ -145,7 +151,7 @@ static class DecimalEncoded extends RpcEncoding.HintedEncoded { private final int scale; - DecimalEncoded(TdsDataType dataType, ByteBuf value, int length, int scale) { + DecimalEncoded(TdsDataType dataType, Supplier value, int length, int scale) { super(dataType, SqlServerType.DECIMAL, value); this.length = length; this.scale = scale; diff --git a/src/main/java/io/r2dbc/mssql/codec/Encoded.java b/src/main/java/io/r2dbc/mssql/codec/Encoded.java index 05a294c0..c831c723 100644 --- a/src/main/java/io/r2dbc/mssql/codec/Encoded.java +++ b/src/main/java/io/r2dbc/mssql/codec/Encoded.java @@ -17,25 +17,34 @@ package io.r2dbc.mssql.codec; import io.netty.buffer.ByteBuf; -import io.netty.util.AbstractReferenceCounted; import io.r2dbc.mssql.message.type.SqlServerType; import io.r2dbc.mssql.message.type.TdsDataType; +import reactor.core.Disposable; + +import java.util.function.IntFunction; +import java.util.function.Supplier; /** + * Encoded value, either providing a singleton {@link ByteBuf} or a {@link Supplier} of buffers. + * * @author Mark Paluch */ -public class Encoded extends AbstractReferenceCounted { +public class Encoded implements Disposable { private final TdsDataType dataType; - private final ByteBuf value; + private final Supplier encoder; - protected Encoded(TdsDataType dataType, ByteBuf value) { + Encoded(TdsDataType dataType, Supplier encoder) { this.dataType = dataType; - this.value = value; + this.encoder = encoder; } public static Encoded of(TdsDataType dataType, ByteBuf value) { + return new Encoded(dataType, new DisposableSupplier(value)); + } + + public static Encoded of(TdsDataType dataType, Supplier value) { return new Encoded(dataType, value); } @@ -44,18 +53,7 @@ public TdsDataType getDataType() { } public ByteBuf getValue() { - return this.value; - } - - @Override - public Encoded touch(Object hint) { - this.value.touch(hint); - return this; - } - - @Override - protected void deallocate() { - this.value.release(); + return this.encoder.get(); } /** @@ -77,4 +75,82 @@ public String getFormalType() { throw new IllegalStateException(String.format("Cannot determine a formal type for %s", this.dataType)); } + /** + * Attempt to estimate the length of the buffer to apply allocation optimizations. + * + * @return the estimated length. Can be an approximation or zero, if the buffer size cannot be estimated. + */ + public int estimateLength() { + + if (this.encoder instanceof DisposableSupplier) { + return ((DisposableSupplier) this.encoder).get().readableBytes(); + } + + if (this.encoder instanceof LengthAwareSupplier) { + return ((LengthAwareSupplier) this.encoder).getLength(); + } + + return 0; + } + + public static Supplier ofLengthAware(int length, IntFunction supplier) { + return new LengthAwareSupplier(length, supplier); + } + + @Override + public void dispose() { + + if (this.encoder instanceof DisposableSupplier) { + ((DisposableSupplier) this.encoder).dispose(); + } + } + + static class DisposableSupplier implements Supplier, Disposable { + + private final ByteBuf buf; + + DisposableSupplier(ByteBuf buf) { + this.buf = buf; + } + + @Override + public ByteBuf get() { + return buf.asReadOnly(); + } + + @Override + public void dispose() { + + if (!isDisposed()) { + buf.release(); + } + } + + @Override + public boolean isDisposed() { + return buf.refCnt() == 0; + } + } + + + static class LengthAwareSupplier implements Supplier { + + private final int length; + + private final IntFunction delegate; + + public LengthAwareSupplier(int length, IntFunction delegate) { + this.length = length; + this.delegate = delegate; + } + + @Override + public ByteBuf get() { + return delegate.apply(length); + } + + public int getLength() { + return length; + } + } } diff --git a/src/main/java/io/r2dbc/mssql/codec/LocalDateCodec.java b/src/main/java/io/r2dbc/mssql/codec/LocalDateCodec.java index cc120388..22889715 100644 --- a/src/main/java/io/r2dbc/mssql/codec/LocalDateCodec.java +++ b/src/main/java/io/r2dbc/mssql/codec/LocalDateCodec.java @@ -18,11 +18,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.r2dbc.mssql.message.type.Length; -import io.r2dbc.mssql.message.type.SqlServerType; -import io.r2dbc.mssql.message.type.TdsDataType; -import io.r2dbc.mssql.message.type.TypeInformation; -import io.r2dbc.mssql.message.type.TypeUtils; +import io.r2dbc.mssql.message.type.*; import java.time.LocalDate; import java.time.temporal.ChronoUnit; @@ -59,11 +55,15 @@ private LocalDateCodec() { @Override Encoded doEncode(ByteBufAllocator allocator, RpcParameterContext context, LocalDate value) { - ByteBuf buffer = allocator.buffer(4); - buffer.writeByte(TypeUtils.DAYS_INTO_CE_LENGTH); - encode(buffer, value); - return new RpcEncoding.HintedEncoded(TdsDataType.DATEN, SqlServerType.DATE, buffer); + return new RpcEncoding.HintedEncoded(TdsDataType.DATEN, SqlServerType.DATE, () -> { + + ByteBuf buffer = allocator.buffer(4); + buffer.writeByte(TypeUtils.DAYS_INTO_CE_LENGTH); + encode(buffer, value); + + return buffer; + }); } @Override diff --git a/src/main/java/io/r2dbc/mssql/codec/OffsetDateTimeCodec.java b/src/main/java/io/r2dbc/mssql/codec/OffsetDateTimeCodec.java index 61830c0f..6d5d5990 100644 --- a/src/main/java/io/r2dbc/mssql/codec/OffsetDateTimeCodec.java +++ b/src/main/java/io/r2dbc/mssql/codec/OffsetDateTimeCodec.java @@ -20,11 +20,7 @@ import io.netty.buffer.ByteBufAllocator; import io.r2dbc.mssql.message.tds.Decode; import io.r2dbc.mssql.message.tds.Encode; -import io.r2dbc.mssql.message.type.Length; -import io.r2dbc.mssql.message.type.SqlServerType; -import io.r2dbc.mssql.message.type.TdsDataType; -import io.r2dbc.mssql.message.type.TypeInformation; -import io.r2dbc.mssql.message.type.TypeUtils; +import io.r2dbc.mssql.message.type.*; import java.time.LocalDate; import java.time.LocalTime; @@ -58,14 +54,18 @@ private OffsetDateTimeCodec() { @Override Encoded doEncode(ByteBufAllocator allocator, RpcParameterContext context, OffsetDateTime value) { - ByteBuf buffer = allocator.buffer(12); - Encode.asByte(buffer, 7); // scale - Encode.asByte(buffer, 0x0a); // length + return new RpcEncoding.HintedEncoded(TdsDataType.DATETIMEOFFSETN, SqlServerType.DATETIMEOFFSET, () -> { - doEncode(buffer, value.minusSeconds(value.getOffset().getTotalSeconds())); + ByteBuf buffer = allocator.buffer(12); - return new RpcEncoding.HintedEncoded(TdsDataType.DATETIMEOFFSETN, SqlServerType.DATETIMEOFFSET, buffer); + Encode.asByte(buffer, 7); // scale + Encode.asByte(buffer, 0x0a); // length + + doEncode(buffer, value.minusSeconds(value.getOffset().getTotalSeconds())); + + return buffer; + }); } @Override diff --git a/src/main/java/io/r2dbc/mssql/codec/PlpEncoded.java b/src/main/java/io/r2dbc/mssql/codec/PlpEncoded.java index 087afb02..6291c51f 100644 --- a/src/main/java/io/r2dbc/mssql/codec/PlpEncoded.java +++ b/src/main/java/io/r2dbc/mssql/codec/PlpEncoded.java @@ -57,7 +57,7 @@ public class PlpEncoded extends Encoded { public PlpEncoded(SqlServerType dataType, ByteBufAllocator allocator, Publisher dataStream, Disposable disposable) { - super(dataType.getNullableType(), Unpooled.EMPTY_BUFFER); + super(dataType.getNullableType(), () -> Unpooled.EMPTY_BUFFER); this.serverType = dataType; this.allocator = allocator; @@ -72,12 +72,12 @@ public void encodeHeader(ByteBuf byteBuf) { } @Override - public PlpEncoded touch(Object hint) { - return this; + public boolean isDisposed() { + return this.disposable.isDisposed(); } @Override - protected void deallocate() { + public void dispose() { this.disposable.dispose(); } diff --git a/src/main/java/io/r2dbc/mssql/codec/RpcEncoding.java b/src/main/java/io/r2dbc/mssql/codec/RpcEncoding.java index 4f565b71..f4b8a08d 100644 --- a/src/main/java/io/r2dbc/mssql/codec/RpcEncoding.java +++ b/src/main/java/io/r2dbc/mssql/codec/RpcEncoding.java @@ -30,6 +30,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; +import java.util.function.Supplier; /** * Utility methods to encode RPC parameters. @@ -127,11 +128,15 @@ public static Encoded encodeFixed(ByteBufAllocator allocator, SqlServerType Assert.notNull(serverType.getNullableType(), "Server type provides no nullable type"); LengthStrategy lengthStrategy = serverType.getNullableType().getLengthStrategy(); - ByteBuf buffer = prepareBuffer(allocator, lengthStrategy, serverType.getMaxLength(), serverType.getMaxLength()); - valueEncoder.accept(buffer, value); + return new HintedEncoded(serverType.getNullableType(), serverType, () -> { - return new HintedEncoded(serverType.getNullableType(), serverType, buffer); + ByteBuf buffer = prepareBuffer(allocator, lengthStrategy, serverType.getMaxLength(), serverType.getMaxLength()); + + valueEncoder.accept(buffer, value); + + return buffer; + }); } /** @@ -150,11 +155,16 @@ public static Encoded encode(ByteBufAllocator allocator, SqlServerType serve Assert.notNull(serverType.getNullableType(), "Server type provides no nullable type"); TdsDataType dataType = serverType.getNullableType(); - ByteBuf buffer = prepareBuffer(allocator, dataType.getLengthStrategy(), serverType.getMaxLength(), length); - valueEncoder.accept(buffer, value); - return new HintedEncoded(dataType, serverType, buffer); + return new HintedEncoded(dataType, serverType, () -> { + + ByteBuf buffer = prepareBuffer(allocator, dataType.getLengthStrategy(), serverType.getMaxLength(), length); + + valueEncoder.accept(buffer, value); + + return buffer; + }); } /** @@ -167,9 +177,8 @@ public static Encoded encode(ByteBufAllocator allocator, SqlServerType serve public static Encoded encodeNull(ByteBufAllocator allocator, SqlServerType serverType) { Assert.notNull(serverType.getNullableType(), "Server type does not declare a nullable type"); - ByteBuf buffer = prepareBuffer(allocator, serverType.getNullableType().getLengthStrategy(), serverType.getMaxLength(), 0); - return new HintedEncoded(serverType.getNullableType(), serverType, buffer); + return new HintedEncoded(serverType.getNullableType(), serverType, () -> prepareBuffer(allocator, serverType.getNullableType().getLengthStrategy(), serverType.getMaxLength(), 0)); } /** @@ -184,7 +193,7 @@ public static Encoded wrap(byte[] buffer, SqlServerType serverType) { Assert.isTrue(serverType.getMaxLength() > 0, "Server type does not declare a max length"); Assert.notNull(serverType.getNullableType(), "Server type does not declare a nullable type"); - return new HintedEncoded(serverType.getNullableType(), serverType, Unpooled.wrappedBuffer(buffer)); + return new HintedEncoded(serverType.getNullableType(), serverType, () -> Unpooled.wrappedBuffer(buffer)); } /** @@ -197,11 +206,16 @@ public static Encoded wrap(byte[] buffer, SqlServerType serverType) { public static Encoded encodeTemporalNull(ByteBufAllocator allocator, SqlServerType serverType) { Assert.notNull(serverType.getNullableType(), "Server type does not declare a nullable type"); - ByteBuf buffer = allocator.buffer(1); - Encode.asByte(buffer, 0); - return new HintedEncoded(serverType.getNullableType(), serverType, buffer); + return new HintedEncoded(serverType.getNullableType(), serverType, () -> { + + ByteBuf buffer = allocator.buffer(1); + + Encode.asByte(buffer, 0); + + return buffer; + }); } /** @@ -215,12 +229,17 @@ public static Encoded encodeTemporalNull(ByteBufAllocator allocator, SqlServerTy public static Encoded encodeTemporalNull(ByteBufAllocator allocator, SqlServerType serverType, int scale) { Assert.notNull(serverType.getNullableType(), "Server type does not declare a nullable type"); - ByteBuf buffer = allocator.buffer(1); - Encode.asByte(buffer, scale); - Encode.asByte(buffer, 0); // value length - return new HintedEncoded(serverType.getNullableType(), serverType, buffer); + return new HintedEncoded(serverType.getNullableType(), serverType, () -> { + + ByteBuf buffer = allocator.buffer(1); + + Encode.asByte(buffer, scale); + Encode.asByte(buffer, 0); // value length + + return buffer; + }); } static ByteBuf prepareBuffer(ByteBufAllocator allocator, LengthStrategy lengthStrategy, int maxLength, int length) { @@ -265,7 +284,7 @@ static class HintedEncoded extends Encoded { private final SqlServerType sqlServerType; - public HintedEncoded(TdsDataType dataType, SqlServerType sqlServerType, ByteBuf value) { + public HintedEncoded(TdsDataType dataType, SqlServerType sqlServerType, Supplier value) { super(dataType, value); this.sqlServerType = sqlServerType; } diff --git a/src/main/java/io/r2dbc/mssql/message/token/RpcRequest.java b/src/main/java/io/r2dbc/mssql/message/token/RpcRequest.java index ff3192b1..43ec2496 100644 --- a/src/main/java/io/r2dbc/mssql/message/token/RpcRequest.java +++ b/src/main/java/io/r2dbc/mssql/message/token/RpcRequest.java @@ -777,8 +777,9 @@ public Encoded getValue() { @Override void encode(ByteBuf buffer) { encodeHeader(buffer); - buffer.writeBytes(this.value.getValue()); - this.value.release(); + ByteBuf value = this.value.getValue(); + buffer.writeBytes(value); + value.release(); } void encodeHeader(ByteBuf buffer) { @@ -789,7 +790,7 @@ void encodeHeader(ByteBuf buffer) { int estimateLength() { int estimate = 2 + (getName() != null ? (getName().length() + 1) * 2 : 0); - estimate += this.value.getValue().readableBytes(); + estimate += this.value.estimateLength(); return estimate; } diff --git a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java index 662408ae..d9be4387 100644 --- a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java @@ -16,6 +16,7 @@ package io.r2dbc.mssql; +import io.netty.buffer.ByteBuf; import io.r2dbc.mssql.ParametrizedMssqlStatement.ParsedParameter; import io.r2dbc.mssql.client.TestClient; import io.r2dbc.mssql.codec.DefaultCodecs; @@ -151,17 +152,18 @@ void shouldRejectBindForUnknownParameters() { void shouldCachePreparedStatementHandle() { Encoded encodedPreparedStatementHandle = new DefaultCodecs().encode(TestByteBufAllocator.TEST, RpcParameterContext.in(), 1); - encodedPreparedStatementHandle.getValue().skipBytes(1); // skip maxlen byte + ByteBuf value = encodedPreparedStatementHandle.getValue(); + value.skipBytes(1); // skip maxlen byte TestClient testClient = TestClient.builder() - .assertNextRequestWith(it -> { - assertThat(it).isInstanceOf(RpcRequest.class); - RpcRequest request = (RpcRequest) it; - assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec); - }) - .thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(), - encodedPreparedStatementHandle.getValue())) - .build(); + .assertNextRequestWith(it -> { + assertThat(it).isInstanceOf(RpcRequest.class); + RpcRequest request = (RpcRequest) it; + assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec); + }) + .thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(), + value)) + .build(); String sql = "SELECT * from FOO where firstname = @firstname"; ParametrizedMssqlStatement statement = new ParametrizedMssqlStatement(testClient, this.connectionOptions, sql); diff --git a/src/test/java/io/r2dbc/mssql/codec/PlpEncodedUnitTests.java b/src/test/java/io/r2dbc/mssql/codec/PlpEncodedUnitTests.java index 817afec5..71e8541b 100644 --- a/src/test/java/io/r2dbc/mssql/codec/PlpEncodedUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/codec/PlpEncodedUnitTests.java @@ -30,11 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.*; /** * Unit tests for {@link PlpEncoded}. @@ -125,7 +121,7 @@ void shouldDisposeUnusedBuffers() { PlpEncoded encoded = new PlpEncoded(SqlServerType.VARBINARYMAX, TestByteBufAllocator.TEST, Flux.empty(), () -> dispose.set(true)); - encoded.release(); + encoded.dispose(); assertThat(dispose).isTrue(); } diff --git a/src/test/java/io/r2dbc/mssql/util/BlockhoundExceptions.java b/src/test/java/io/r2dbc/mssql/util/BlockhoundExceptions.java deleted file mode 100644 index dd08c670..00000000 --- a/src/test/java/io/r2dbc/mssql/util/BlockhoundExceptions.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2020-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.r2dbc.mssql.util; - -import io.r2dbc.mssql.client.ssl.TdsSslHandler; -import reactor.blockhound.BlockHound; -import reactor.blockhound.integration.BlockHoundIntegration; - -import java.security.SecureRandom; - -public class BlockhoundExceptions implements BlockHoundIntegration { - - @Override - public void applyTo(BlockHound.Builder builder) { - builder.allowBlockingCallsInside(SecureRandom.class.getName(), "next"); - builder.allowBlockingCallsInside(TdsSslHandler.class.getName(), "createSslHandler"); - } -} From 4bf95d63312ad27e5f87f6f742cd8bb1fe7f56e7 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 23 Jun 2023 15:18:08 +0200 Subject: [PATCH 4/9] Reprepare invalid and missing prepared statements. We now reprepare (retry) a statement that is contextually invalid or cannot be found on the server. [resolves #271] Signed-off-by: Mark Paluch --- .../io/r2dbc/mssql/RpcQueryMessageFlow.java | 89 ++++++++++++++----- ...etrizedMssqlStatementIntegrationTests.java | 42 +++++++++ .../ParametrizedMssqlStatementUnitTests.java | 16 ++-- 3 files changed, 116 insertions(+), 31 deletions(-) diff --git a/src/main/java/io/r2dbc/mssql/RpcQueryMessageFlow.java b/src/main/java/io/r2dbc/mssql/RpcQueryMessageFlow.java index f1e8a3f8..7e7b1f70 100644 --- a/src/main/java/io/r2dbc/mssql/RpcQueryMessageFlow.java +++ b/src/main/java/io/r2dbc/mssql/RpcQueryMessageFlow.java @@ -1,5 +1,5 @@ /* - * Copyright 2018-2022 the original author or authors. + * Copyright 2018-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,15 +25,7 @@ import io.r2dbc.mssql.message.ClientMessage; import io.r2dbc.mssql.message.Message; import io.r2dbc.mssql.message.TransactionDescriptor; -import io.r2dbc.mssql.message.token.AbstractDoneToken; -import io.r2dbc.mssql.message.token.AbstractInfoToken; -import io.r2dbc.mssql.message.token.ColumnMetadataToken; -import io.r2dbc.mssql.message.token.DoneInProcToken; -import io.r2dbc.mssql.message.token.DoneProcToken; -import io.r2dbc.mssql.message.token.ErrorToken; -import io.r2dbc.mssql.message.token.ReturnValue; -import io.r2dbc.mssql.message.token.RowToken; -import io.r2dbc.mssql.message.token.RpcRequest; +import io.r2dbc.mssql.message.token.*; import io.r2dbc.mssql.message.type.Collation; import io.r2dbc.mssql.util.Assert; import io.r2dbc.mssql.util.Operators; @@ -46,6 +38,7 @@ import reactor.util.Loggers; import javax.annotation.processing.Completion; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.function.Consumer; import java.util.function.Predicate; @@ -224,29 +217,30 @@ static Flux exchange(PreparedStatementCache statementCache, Client clie Assert.requireNonNull(query, "Query must not be null"); Sinks.Many outbound = Sinks.many().unicast().onBackpressureBuffer(); - CursorState state = new CursorState(); - int handle = statementCache.getHandle(query, binding); - boolean needsPrepare; + AtomicBoolean retryReprepare = new AtomicBoolean(true); + AtomicBoolean needsPrepare = new AtomicBoolean(false); + Flux messageProducer; if (handle == PreparedStatementCache.UNPREPARED) { messageProducer = Flux.defer(() -> { outbound.emitNext(spCursorPrepExec(PreparedStatementCache.UNPREPARED, query, binding, client.getRequiredCollation(), - client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST); + client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST); return outbound.asFlux(); }); - needsPrepare = true; + needsPrepare.set(true); } else { messageProducer = Flux.defer(() -> { outbound.emitNext(spCursorExec(handle, binding, client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST); return outbound.asFlux(); }); - needsPrepare = false; + needsPrepare.set(false); } + CursorState state = new CursorState(); Flux exchange = client.exchange(messageProducer, isFinalToken(state)); OnCursorComplete cursorComplete = new OnCursorComplete(); @@ -258,7 +252,7 @@ static Flux exchange(PreparedStatementCache statementCache, Client clie ReturnValue returnValue = (ReturnValue) message; - emit = handleSpCursorReturnValue(statementCache, codecs, query, binding, state, needsPrepare, returnValue); + emit = handleSpCursorReturnValue(statementCache, codecs, query, binding, state, needsPrepare.get(), returnValue); if (!emit) { returnValue.release(); @@ -267,6 +261,27 @@ static Flux exchange(PreparedStatementCache statementCache, Client clie state.update(message); + if (message instanceof ErrorToken) { + if (isPreparedStatementNotFound(((ErrorToken) message).getNumber()) && retryReprepare.compareAndSet(true, false)) { + logger.debug("Prepared statement no longer valid: {}", handle); + state.update(Phase.PREPARE_RETRY); + } + } + + if (state.phase == Phase.PREPARE_RETRY) { + emit = false; + } + + if (DoneProcToken.isDone(message) && state.phase == Phase.PREPARE_RETRY) { + + logger.debug("Attempting to re-prepare statement: {}", query); + needsPrepare.set(true); + state.update(Phase.NONE); + outbound.emitNext(spCursorPrepExec(PreparedStatementCache.UNPREPARED, query, binding, client.getRequiredCollation(), + client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST); + return; + } + handleMessage(client, fetchSize, outbound, state, message, sink, cursorComplete, emit); }) .filter(FILTER_PREDICATE); @@ -277,6 +292,21 @@ static Flux exchange(PreparedStatementCache statementCache, Client clie .transform(it -> Operators.discardOnCancel(it, state::cancel).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)).takeUntilOther(cursorComplete.takeUntil()); } + /** + * Check whether the error indicates a prepared statement requiring reprepare. + *

+ *

  • 586: The prepared statement handle %d is not valid in this context. Please verify that current database, user + * default schema ANSI_NULLS and QUOTED_IDENTIFIER set options are not changed since the handle is prepared.
  • + *
  • 8179: Could not find prepared statement with handle %d.
  • + *
+ * + * @param errorNumber + * @return + */ + private static boolean isPreparedStatementNotFound(long errorNumber) { + return errorNumber == 8179 || errorNumber == 586; + } + private static boolean handleSpCursorReturnValue(PreparedStatementCache statementCache, Codecs codecs, String query, Binding binding, CursorState state, boolean needsPrepare, ReturnValue returnValue) { @@ -356,7 +386,7 @@ private static void handleMessage(Client client, int fetchSize, Consumer request completion.run(); - state.phase = Phase.CLOSED; + state.update(Phase.CLOSED); return; } @@ -394,11 +424,11 @@ static void onDone(Client client, int fetchSize, Consumer request if (((state.hasMore && phase == Phase.NONE) || state.hasSeenRows) && state.wantsMore()) { if (phase == Phase.NONE) { - state.phase = Phase.FETCHING; + state.update(Phase.FETCHING); } requests.accept(spCursorFetch(state.cursorId, FETCH_NEXT, fetchSize, client.getTransactionDescriptor())); } else { - state.phase = Phase.CLOSING; + state.update(Phase.CLOSING); // TODO: spCursorClose should happen also if a subscriber cancels its subscription. requests.accept(spCursorClose(state.cursorId, client.getTransactionDescriptor())); } @@ -628,6 +658,8 @@ static class CursorState { volatile boolean cancelRequested; + volatile ErrorToken errorToken; + Phase phase = Phase.NONE; boolean wantsMore() { @@ -644,12 +676,23 @@ void update(Message it) { } if (it instanceof ErrorToken) { + this.errorToken = (ErrorToken) it; this.hasSeenError = true; } } + public void update(Phase newPhase) { + + this.phase = newPhase; + + if (newPhase == Phase.PREPARE_RETRY) { + errorToken = null; + hasSeenError = false; + } + } + enum Phase { - NONE, FETCHING, CLOSING, CLOSED, ERROR + NONE, FETCHING, PREPARE_RETRY, CLOSING, CLOSED, ERROR } } diff --git a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementIntegrationTests.java b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementIntegrationTests.java index c1685b92..0116330f 100644 --- a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementIntegrationTests.java +++ b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementIntegrationTests.java @@ -144,6 +144,48 @@ void shouldEmitSingleResultForCursoredExecution() { assertThat(rowCounter).hasValue(3); } + @Test + void shouldRepreparePreparedStatement() { + + shouldExecuteBatch(); + + connection.createStatement("SET ANSI_NULLS ON") + .execute() + .flatMap(MssqlResult::getRowsUpdated) + .as(StepVerifier::create) + .verifyComplete(); + + Flux.from(connection.createStatement("SELECT first_name FROM r2dbc_example where id != @P0") + .fetchSize(2) + .bind("P0", 99) + .execute()) + .flatMap(result -> { + + return result.map((row, rowMetadata) -> new Object()); + }) + .as(StepVerifier::create) + .expectNextCount(3) + .verifyComplete(); + + connection.createStatement("SET ANSI_NULLS OFF") + .execute() + .flatMap(MssqlResult::getRowsUpdated) + .as(StepVerifier::create) + .verifyComplete(); + + Flux.from(connection.createStatement("SELECT first_name FROM r2dbc_example where id != @P0") + .fetchSize(2) + .bind("P0", 99) + .execute()) + .flatMap(result -> { + + return result.map((row, rowMetadata) -> new Object()); + }) + .as(StepVerifier::create) + .expectNextCount(3) + .verifyComplete(); + } + @Test void shouldRunStatementWithMultipleResults() { diff --git a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java index d9be4387..a8957ba9 100644 --- a/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/ParametrizedMssqlStatementUnitTests.java @@ -156,14 +156,14 @@ void shouldCachePreparedStatementHandle() { value.skipBytes(1); // skip maxlen byte TestClient testClient = TestClient.builder() - .assertNextRequestWith(it -> { - assertThat(it).isInstanceOf(RpcRequest.class); - RpcRequest request = (RpcRequest) it; - assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec); - }) - .thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(), - value)) - .build(); + .assertNextRequestWith(it -> { + assertThat(it).isInstanceOf(RpcRequest.class); + RpcRequest request = (RpcRequest) it; + assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec); + }) + .thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(), + value)) + .build(); String sql = "SELECT * from FOO where firstname = @firstname"; ParametrizedMssqlStatement statement = new ParametrizedMssqlStatement(testClient, this.connectionOptions, sql); From 7d5b08b765d3a15754ed08f8f1374e018b98b558 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 23 Jun 2023 15:18:23 +0200 Subject: [PATCH 5/9] Polishing. [#272] Signed-off-by: Mark Paluch --- pom.xml | 23 ------------------- ...ockhound.integration.BlockHoundIntegration | 1 - 2 files changed, 24 deletions(-) delete mode 100644 src/test/resources/META-INF/services/reactor.blockhound.integration.BlockHoundIntegration diff --git a/pom.xml b/pom.xml index 386c64d6..f5370178 100644 --- a/pom.xml +++ b/pom.xml @@ -33,7 +33,6 @@ 3.23.1 - 1.0.6.RELEASE 4.0.3 1.8 3.0.2 @@ -135,12 +134,6 @@ reactor-test test - - io.projectreactor.tools - blockhound - ${blockhound.version} - test - io.r2dbc r2dbc-spi-test @@ -388,22 +381,6 @@ - - java18-blockhound - - 1.8 - - - - - io.projectreactor.tools - blockhound-junit-platform - ${blockhound.version} - test - - - - jmh diff --git a/src/test/resources/META-INF/services/reactor.blockhound.integration.BlockHoundIntegration b/src/test/resources/META-INF/services/reactor.blockhound.integration.BlockHoundIntegration deleted file mode 100644 index 2aa25f09..00000000 --- a/src/test/resources/META-INF/services/reactor.blockhound.integration.BlockHoundIntegration +++ /dev/null @@ -1 +0,0 @@ -io.r2dbc.mssql.util.BlockhoundExceptions From fb52750967d9c251f7de0cc72264a4f52026aa88 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 14 Jul 2023 09:29:24 +0200 Subject: [PATCH 6/9] Auto-close staging repository after staging. [#274] Signed-off-by: Mark Paluch --- pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/pom.xml b/pom.xml index f5370178..7bdfc3c9 100644 --- a/pom.xml +++ b/pom.xml @@ -524,7 +524,6 @@ https://oss.sonatype.org/ false true - true From e17bd93317721fef3e9d9a54f464eb5686383a0a Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 14 Jul 2023 09:30:13 +0200 Subject: [PATCH 7/9] Upgrade to Reactor 2022.0.9. [resolves #275] Signed-off-by: Mark Paluch --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 7bdfc3c9..d69c8ef0 100644 --- a/pom.xml +++ b/pom.xml @@ -46,7 +46,7 @@ UTF-8 1.0.0.RELEASE 1.0.0.RELEASE - 2022.0.0 + 2022.0.9 5.3.23 1.17.5 From 68b1984be15498dc077e1a1a4449bc7db174ce83 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 14 Jul 2023 09:35:58 +0200 Subject: [PATCH 8/9] Upgrade build plugins. [#274] Signed-off-by: Mark Paluch --- pom.xml | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pom.xml b/pom.xml index d69c8ef0..fd3acc15 100644 --- a/pom.xml +++ b/pom.xml @@ -40,15 +40,15 @@ 1.33 0.3.0.RELEASE 1.2.11 - 4.8.1 - 11.1.2.jre8-preview + 4.11.0 + 12.2.0.jre8 UTF-8 UTF-8 1.0.0.RELEASE - 1.0.0.RELEASE + 1.0.1.RELEASE 2022.0.9 - 5.3.23 - 1.17.5 + 5.3.29 + 1.18.3 @@ -209,7 +209,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.10.1 + 3.11.0 true ${java.version} @@ -236,12 +236,12 @@ org.apache.maven.plugins maven-deploy-plugin - 3.0.0 + 3.1.1 org.apache.maven.plugins maven-enforcer-plugin - 3.1.0 + 3.3.0 enforce-no-snapshots @@ -262,7 +262,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.4.1 + 3.5.0 io.r2dbc.mssql.authentication,io.r2dbc.mssql.client,io.r2dbc.mssql.codec,io.r2dbc.mssql.message,io.r2dbc.mssql.util @@ -285,7 +285,7 @@ org.apache.maven.plugins maven-source-plugin - 3.2.1 + 3.3.0 attach-javadocs @@ -298,7 +298,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.22.2 + 3.1.2 random @@ -313,7 +313,7 @@ org.apache.maven.plugins maven-failsafe-plugin - 2.22.2 + 3.1.2 @@ -334,7 +334,7 @@ org.codehaus.mojo flatten-maven-plugin - 1.3.0 + 1.5.0 flatten @@ -494,7 +494,7 @@ org.apache.maven.plugins maven-gpg-plugin - 3.0.1 + 3.1.0 sign-artifacts From 9e802b70d717c749cdc2726500d4a87c561ecdbd Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 14 Jul 2023 10:17:25 +0200 Subject: [PATCH 9/9] Adapt to Reactor changes. [#275] Signed-off-by: Mark Paluch --- .../java/io/r2dbc/mssql/util/FluxDiscardOnCancelUnitTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/io/r2dbc/mssql/util/FluxDiscardOnCancelUnitTests.java b/src/test/java/io/r2dbc/mssql/util/FluxDiscardOnCancelUnitTests.java index 60d9504a..4d68d66e 100644 --- a/src/test/java/io/r2dbc/mssql/util/FluxDiscardOnCancelUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/util/FluxDiscardOnCancelUnitTests.java @@ -119,7 +119,7 @@ void shouldNotConsumeItemsOnCancel() { .thenCancel() .verify(); - assertThat(items).toIterable().containsSequence(2, 3); + assertThat(items).toIterable().containsOnly(3); } @Test