From 3da2f14c062089da44462ce4c2a122b35dcd3009 Mon Sep 17 00:00:00 2001 From: Zhengqiang Duan Date: Wed, 17 Jan 2024 15:15:13 +0800 Subject: [PATCH] Refactor proxy and jdbc adapter and add ImplicitTransactionCallback (#29748) --- .../adapter/AbstractStatementAdapter.java | 14 +- .../ShardingSpherePreparedStatement.java | 120 +++++++----------- .../statement/ShardingSphereStatement.java | 112 +++++++--------- .../jdbc/adapter/StatementAdapterTest.java | 2 +- .../implicit/ImplicitTransactionCallback.java | 36 ++++++ .../backend/connector/DatabaseConnector.java | 78 +++++------- 6 files changed, 162 insertions(+), 200 deletions(-) create mode 100644 kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/implicit/ImplicitTransactionCallback.java diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java index 3ccbc6d1404b6..b86728c578cfd 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/adapter/AbstractStatementAdapter.java @@ -57,7 +57,7 @@ public abstract class AbstractStatementAdapter extends AbstractUnsupportedOperat private boolean closed; - protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final Collection executionContexts) { + protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final ExecutionContext executionContext) { if (connection.getAutoCommit()) { return false; } @@ -66,16 +66,8 @@ protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConn if (!TransactionType.isDistributedTransaction(connectionTransaction.getTransactionType()) || isInTransaction) { return false; } - if (1 == executionContexts.size()) { - SQLStatement sqlStatement = executionContexts.iterator().next().getSqlStatementContext().getSqlStatement(); - return isWriteDMLStatement(sqlStatement) && executionContexts.iterator().next().getExecutionUnits().size() > 1; - } - for (ExecutionContext each : executionContexts) { - if (isWriteDMLStatement(each.getSqlStatementContext().getSqlStatement())) { - return true; - } - } - return false; + SQLStatement sqlStatement = executionContext.getSqlStatementContext().getSqlStatement(); + return isWriteDMLStatement(sqlStatement) && executionContext.getExecutionUnits().size() > 1; } private boolean isWriteDMLStatement(final SQLStatement sqlStatement) { diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 220ceea2956dd..0d64b5af9c496 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -17,7 +17,6 @@ package org.apache.shardingsphere.driver.jdbc.core.statement; -import com.google.common.base.Preconditions; import com.google.common.base.Strings; import lombok.AccessLevel; import lombok.Getter; @@ -88,6 +87,7 @@ import org.apache.shardingsphere.traffic.engine.TrafficEngine; import org.apache.shardingsphere.traffic.exception.metadata.EmptyTrafficExecutionUnitException; import org.apache.shardingsphere.traffic.rule.TrafficRule; +import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback; import org.apache.shardingsphere.transaction.util.AutoCommitUtils; import java.sql.Connection; @@ -151,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState @Getter private final boolean selectContainsEnhancedTable; - private Collection executionContexts; + private ExecutionContext executionContext; private Map columnLabelAndIndexMap; @@ -240,8 +240,8 @@ public ResultSet executeQuery() throws SQLException { if (useFederation) { return executeFederationQuery(queryContext); } - executionContexts = createExecutionContext(queryContext); - result = doExecuteQuery(executionContexts); + executionContext = createExecutionContext(queryContext); + result = doExecuteQuery(executionContext); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -254,19 +254,14 @@ public ResultSet executeQuery() throws SQLException { return result; } - private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { - ShardingSphereResultSet result = null; - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - List queryResults = executeQuery0(each); - MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); - List resultSets = getResultSets(); - if (null == columnLabelAndIndexMap) { - columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); - } - result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap); + private ShardingSphereResultSet doExecuteQuery(final ExecutionContext executionContext) throws SQLException { + List queryResults = executeQuery0(executionContext); + MergedResult mergedResult = mergeQuery(queryResults, executionContext.getSqlStatementContext()); + List resultSets = getResultSets(); + if (null == columnLabelAndIndexMap) { + columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData()); } - return result; + return new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap); } private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { @@ -356,15 +351,13 @@ public int executeUpdate() throws SQLException { JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); } - executionContexts = createExecutionContext(queryContext); + executionContext = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - Collection results = new LinkedList<>(); - for (ExecutionContext each : executionContexts) { - results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); - } + Collection results = + executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()); return accumulate(results); } - return executeUpdateWithExecutionContexts(executionContexts); + return executeUpdateWithExecutionContext(executionContext); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -375,17 +368,11 @@ public int executeUpdate() throws SQLException { } } - private int useDriverToExecuteUpdate(final Collection executionContexts) throws SQLException { - Integer result = null; - Preconditions.checkArgument(!executionContexts.isEmpty()); - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); - cacheStatements(executionGroupContext.getInputGroups()); - result = executor.getRegularExecutor().executeUpdate(executionGroupContext, - each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); - } - return result; + private int useDriverToExecuteUpdate(final ExecutionContext executionContext) throws SQLException { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); + cacheStatements(executionGroupContext.getInputGroups()); + return executor.getRegularExecutor().executeUpdate(executionGroupContext, + executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback()); } private int accumulate(final Collection results) { @@ -434,16 +421,13 @@ public boolean execute() throws SQLException { ResultSet resultSet = executeFederationQuery(queryContext); return null != resultSet; } - executionContexts = createExecutionContext(queryContext); + executionContext = createExecutionContext(queryContext); if (hasRawExecutionRule()) { - Collection results = new LinkedList<>(); - for (ExecutionContext each : executionContexts) { - // TODO process getStatement - results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); - } + Collection results = + executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()); return results.iterator().next() instanceof QueryResult; } - return executeWithExecutionContexts(executionContexts); + return executeWithExecutionContext(executionContext); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -469,15 +453,15 @@ private ExecutionGroupContext createRawExecutionGroupContex .prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); } - private boolean executeWithExecutionContexts(final Collection executionContexts) throws SQLException { - return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction(executionContexts) : useDriverToExecute(executionContexts); + private boolean executeWithExecutionContext(final ExecutionContext executionContext) throws SQLException { + return isNeedImplicitCommitTransaction(connection, executionContext) ? executeWithImplicitCommitTransaction(() -> useDriverToExecute(executionContext)) : useDriverToExecute(executionContext); } - private boolean executeWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { + private boolean executeWithImplicitCommitTransaction(final ImplicitTransactionCallback callback) throws SQLException { boolean result; try { connection.setAutoCommit(false); - result = useDriverToExecute(executionContexts); + result = callback.execute(); connection.commit(); // CHECKSTYLE:OFF } catch (final Exception ex) { @@ -490,15 +474,16 @@ private boolean executeWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { - return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction(executionContexts) : useDriverToExecuteUpdate(executionContexts); + private int executeUpdateWithExecutionContext(final ExecutionContext executionContext) throws SQLException { + return isNeedImplicitCommitTransaction(connection, executionContext) ? executeUpdateWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(executionContext)) + : useDriverToExecuteUpdate(executionContext); } - private int executeUpdateWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { + private int executeUpdateWithImplicitCommitTransaction(final ImplicitTransactionCallback callback) throws SQLException { int result; try { connection.setAutoCommit(false); - result = useDriverToExecuteUpdate(executionContexts); + result = callback.execute(); connection.commit(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { @@ -511,17 +496,11 @@ private int executeUpdateWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { - Boolean result = null; - Preconditions.checkArgument(!executionContexts.isEmpty()); - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); - cacheStatements(executionGroupContext.getInputGroups()); - result = executor.getRegularExecutor().execute(executionGroupContext, - each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback()); - } - return result; + private boolean useDriverToExecute(final ExecutionContext executionContext) throws SQLException { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); + cacheStatements(executionGroupContext.getInputGroups()); + return executor.getRegularExecutor().execute(executionGroupContext, + executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback()); } private JDBCExecutorCallback createExecuteCallback() { @@ -557,7 +536,6 @@ public ResultSet getResultSet() throws SQLException { if (useFederation) { return executor.getSqlFederationEngine().getResultSet(); } - ExecutionContext executionContext = executionContexts.iterator().next(); if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) { List resultSets = getResultSets(); @@ -594,19 +572,19 @@ private List getQueryResults(final List resultSets) thro return result; } - private Collection createExecutionContext(final QueryContext queryContext) { + private ExecutionContext createExecutionContext(final QueryContext queryContext) { RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData(); ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName); SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext()); ExecutionContext result = kernelProcessor.generateExecutionContext( queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); - return Collections.singleton(result); + return result; } - private Collection createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { + private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); - return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext())); + return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()); } private QueryContext createQueryContext() { @@ -663,7 +641,7 @@ public ResultSet getGeneratedKeys() throws SQLException { if (null != currentBatchGeneratedKeysResultSet) { return currentBatchGeneratedKeysResultSet; } - Optional generatedKey = findGeneratedKey(executionContexts.iterator().next()); + Optional generatedKey = findGeneratedKey(executionContext); if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) { return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this); } @@ -686,8 +664,8 @@ public void addBatch() { try { QueryContext queryContext = createQueryContext(); trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); - executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); - batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits()); + executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId); + batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits()); } finally { currentResultSet = null; clearParameters(); @@ -696,13 +674,13 @@ public void addBatch() { @Override public int[] executeBatch() throws SQLException { - if (null == executionContexts || executionContexts.isEmpty()) { + if (null == executionContext) { return new int[0]; } try { // TODO add raw SQL executor initBatchPreparedStatementExecutor(); - int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext()); + int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext()); if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) { List batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements(); for (Statement statement : batchPreparedStatementExecutorStatements) { @@ -732,7 +710,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException { ExecutionUnit executionUnit = each.getExecutionUnit(); executionUnits.add(executionUnit); } - batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); + batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName))); setBatchParametersForStatements(); } @@ -773,7 +751,7 @@ public int getResultSetHoldability() { @Override public boolean isAccumulate() { return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream() - .anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames())); + .anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames())); } @Override diff --git a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java index 61a3a48996a96..5cacd97385c9c 100644 --- a/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java +++ b/jdbc/core/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java @@ -17,7 +17,6 @@ package org.apache.shardingsphere.driver.jdbc.core.statement; -import com.google.common.base.Preconditions; import com.google.common.base.Strings; import lombok.AccessLevel; import lombok.Getter; @@ -82,6 +81,7 @@ import org.apache.shardingsphere.traffic.exception.metadata.EmptyTrafficExecutionUnitException; import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback; import org.apache.shardingsphere.traffic.rule.TrafficRule; +import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback; import org.apache.shardingsphere.transaction.util.AutoCommitUtils; import java.sql.Connection; @@ -124,7 +124,7 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter { private boolean returnGeneratedKeys; - private Collection executionContexts; + private ExecutionContext executionContext; private ResultSet currentResultSet; @@ -175,8 +175,8 @@ public ResultSet executeQuery(final String sql) throws SQLException { if (useFederation) { return executeFederationQuery(queryContext); } - executionContexts = createExecutionContext(queryContext); - result = doExecuteQuery(executionContexts); + executionContext = createExecutionContext(queryContext); + result = doExecuteQuery(executionContext); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { // CHECKSTYLE:ON @@ -189,18 +189,12 @@ public ResultSet executeQuery(final String sql) throws SQLException { return result; } - private ShardingSphereResultSet doExecuteQuery(final Collection executionContexts) throws SQLException { - ShardingSphereResultSet result = null; - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - List queryResults = executeQuery0(each); - MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext()); - boolean selectContainsEnhancedTable = - each.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext) each.getSqlStatementContext()).isContainsEnhancedTable(); - result = new ShardingSphereResultSet(getResultSets(), mergedResult, this, selectContainsEnhancedTable, each); - - } - return result; + private ShardingSphereResultSet doExecuteQuery(final ExecutionContext executionContext) throws SQLException { + List queryResults = executeQuery0(executionContext); + MergedResult mergedResult = mergeQuery(queryResults, executionContext.getSqlStatementContext()); + boolean selectContainsEnhancedTable = + executionContext.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext) executionContext.getSqlStatementContext()).isContainsEnhancedTable(); + return new ShardingSphereResultSet(getResultSets(), mergedResult, this, selectContainsEnhancedTable, executionContext); } private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) { @@ -315,9 +309,10 @@ public int executeUpdate(final String sql, final String[] columnNames) throws SQ } } - private int executeUpdate(final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final Collection executionContexts) throws SQLException { - return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction(updateCallback, sqlStatementContext, executionContexts) - : useDriverToExecuteUpdate(updateCallback, sqlStatementContext, executionContexts); + private int executeUpdate(final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final ExecutionContext executionContext) throws SQLException { + return isNeedImplicitCommitTransaction(connection, executionContext) + ? executeUpdateWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(updateCallback, sqlStatementContext, executionContext)) + : useDriverToExecuteUpdate(updateCallback, sqlStatementContext, executionContext); } private int executeUpdate0(final String sql, final ExecuteUpdateCallback updateCallback, final TrafficExecutorCallback trafficCallback) throws SQLException { @@ -330,23 +325,19 @@ private int executeUpdate0(final String sql, final ExecuteUpdateCallback updateC JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); return executor.getTrafficExecutor().execute(executionUnit, trafficCallback); } - executionContexts = createExecutionContext(queryContext); + executionContext = createExecutionContext(queryContext); if (metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules().stream().anyMatch(RawExecutionRule.class::isInstance)) { - Collection results = new LinkedList<>(); - for (ExecutionContext each : executionContexts) { - results.addAll(executor.getRawExecutor().execute(createRawExecutionContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); - } + Collection results = executor.getRawExecutor().execute(createRawExecutionContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()); return accumulate(results); } - return executeUpdate(updateCallback, queryContext.getSqlStatementContext(), executionContexts); + return executeUpdate(updateCallback, queryContext.getSqlStatementContext(), executionContext); } - private int executeUpdateWithImplicitCommitTransaction(final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, - final Collection executionContexts) throws SQLException { + private int executeUpdateWithImplicitCommitTransaction(final ImplicitTransactionCallback callback) throws SQLException { int result; try { connection.setAutoCommit(false); - result = useDriverToExecuteUpdate(updateCallback, sqlStatementContext, executionContexts); + result = callback.execute(); connection.commit(); // CHECKSTYLE:OFF } catch (final RuntimeException ex) { @@ -360,18 +351,12 @@ private int executeUpdateWithImplicitCommitTransaction(final ExecuteUpdateCallba } private int useDriverToExecuteUpdate(final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, - final Collection executionContexts) throws SQLException { - Integer result = null; - Preconditions.checkArgument(!executionContexts.isEmpty()); - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); - cacheStatements(executionGroupContext.getInputGroups()); - JDBCExecutorCallback callback = createExecuteUpdateCallback(updateCallback, sqlStatementContext); - result = executor.getRegularExecutor().executeUpdate(executionGroupContext, - each.getQueryContext(), each.getRouteContext().getRouteUnits(), callback); - } - return result; + final ExecutionContext executionContext) throws SQLException { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); + cacheStatements(executionGroupContext.getInputGroups()); + JDBCExecutorCallback callback = createExecuteUpdateCallback(updateCallback, sqlStatementContext); + return executor.getRegularExecutor().executeUpdate(executionGroupContext, + executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), callback); } private JDBCExecutorCallback createExecuteUpdateCallback(final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext) { @@ -469,15 +454,12 @@ private boolean execute0(final String sql, final ExecuteCallback executeCallback ResultSet resultSet = executeFederationQuery(queryContext); return null != resultSet; } - executionContexts = createExecutionContext(queryContext); + executionContext = createExecutionContext(queryContext); if (metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules().stream().anyMatch(RawExecutionRule.class::isInstance)) { - Collection results = new LinkedList<>(); - for (ExecutionContext each : executionContexts) { - results.addAll(executor.getRawExecutor().execute(createRawExecutionContext(each), each.getQueryContext(), new RawSQLExecutorCallback())); - } + Collection results = executor.getRawExecutor().execute(createRawExecutionContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback()); return results.iterator().next() instanceof QueryResult; } - return executeWithExecutionContexts(executeCallback, executionContexts); + return executeWithExecutionContext(executeCallback, executionContext); } finally { currentResultSet = null; } @@ -528,13 +510,13 @@ private QueryContext createQueryContext(final String originSQL) { return new QueryContext(sqlStatementContext, sql, Collections.emptyList(), hintValueContext); } - private Collection createExecutionContext(final QueryContext queryContext) throws SQLException { + private ExecutionContext createExecutionContext(final QueryContext queryContext) throws SQLException { clearStatements(); RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData(); ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName); SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext()); - return Collections.singleton(kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), - connection.getDatabaseConnectionManager().getConnectionContext())); + return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), + connection.getDatabaseConnectionManager().getConnectionContext()); } private ExecutionGroupContext createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { @@ -548,16 +530,16 @@ private ExecutionGroupContext createRawExecutionContext(fin .prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName)); } - private boolean executeWithExecutionContexts(final ExecuteCallback executeCallback, final Collection executionContexts) throws SQLException { - return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction(executeCallback, executionContexts) - : useDriverToExecute(executeCallback, executionContexts); + private boolean executeWithExecutionContext(final ExecuteCallback executeCallback, final ExecutionContext executionContext) throws SQLException { + return isNeedImplicitCommitTransaction(connection, executionContext) ? executeWithImplicitCommitTransaction(() -> useDriverToExecute(executeCallback, executionContext)) + : useDriverToExecute(executeCallback, executionContext); } - private boolean executeWithImplicitCommitTransaction(final ExecuteCallback callback, final Collection executionContexts) throws SQLException { + private boolean executeWithImplicitCommitTransaction(final ImplicitTransactionCallback callback) throws SQLException { boolean result; try { connection.setAutoCommit(false); - result = useDriverToExecute(callback, executionContexts); + result = callback.execute(); connection.commit(); // CHECKSTYLE:OFF } catch (final Exception ex) { @@ -570,18 +552,12 @@ private boolean executeWithImplicitCommitTransaction(final ExecuteCallback callb return result; } - private boolean useDriverToExecute(final ExecuteCallback callback, final Collection executionContexts) throws SQLException { - Boolean result = null; - Preconditions.checkArgument(!executionContexts.isEmpty()); - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - ExecutionGroupContext executionGroupContext = createExecutionGroupContext(each); - cacheStatements(executionGroupContext.getInputGroups()); - JDBCExecutorCallback jdbcExecutorCallback = createExecuteCallback(callback, each.getSqlStatementContext().getSqlStatement()); - result = executor.getRegularExecutor().execute(executionGroupContext, - each.getQueryContext(), each.getRouteContext().getRouteUnits(), jdbcExecutorCallback); - } - return result; + private boolean useDriverToExecute(final ExecuteCallback callback, final ExecutionContext executionContext) throws SQLException { + ExecutionGroupContext executionGroupContext = createExecutionGroupContext(executionContext); + cacheStatements(executionGroupContext.getInputGroups()); + JDBCExecutorCallback jdbcExecutorCallback = createExecuteCallback(callback, executionContext.getSqlStatementContext().getSqlStatement()); + return executor.getRegularExecutor().execute(executionGroupContext, + executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), jdbcExecutorCallback); } private void cacheStatements(final Collection> executionGroups) throws SQLException { @@ -625,7 +601,6 @@ public ResultSet getResultSet() throws SQLException { if (useFederation) { return executor.getSqlFederationEngine().getResultSet(); } - ExecutionContext executionContext = executionContexts.iterator().next(); if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) { List resultSets = getResultSets(); @@ -686,7 +661,7 @@ public int getResultSetHoldability() { @Override public boolean isAccumulate() { return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream() - .anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames())); + .anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames())); } @Override @@ -712,7 +687,6 @@ public ResultSet getGeneratedKeys() throws SQLException { } private Optional findGeneratedKey() { - ExecutionContext executionContext = executionContexts.iterator().next(); return executionContext.getSqlStatementContext() instanceof InsertStatementContext ? ((InsertStatementContext) executionContext.getSqlStatementContext()).getGeneratedKeyContext() : Optional.empty(); diff --git a/jdbc/core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java b/jdbc/core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java index 816625d8614af..e907a752eac05 100644 --- a/jdbc/core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java +++ b/jdbc/core/src/test/java/org/apache/shardingsphere/driver/jdbc/adapter/StatementAdapterTest.java @@ -272,6 +272,6 @@ private ShardingSphereStatement mockShardingSphereStatementWithNeedAccumulate(fi @SneakyThrows(ReflectiveOperationException.class) private void setExecutionContext(final ShardingSphereStatement statement, final ExecutionContext executionContext) { - Plugins.getMemberAccessor().set(statement.getClass().getDeclaredField("executionContexts"), statement, Collections.singleton(executionContext)); + Plugins.getMemberAccessor().set(statement.getClass().getDeclaredField("executionContext"), statement, executionContext); } } diff --git a/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/implicit/ImplicitTransactionCallback.java b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/implicit/ImplicitTransactionCallback.java new file mode 100644 index 0000000000000..e9616dd80577f --- /dev/null +++ b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/implicit/ImplicitTransactionCallback.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.transaction.implicit; + +import java.sql.SQLException; + +/** + * Implicit transaction callback. + * + * @param type of return value + */ +public interface ImplicitTransactionCallback { + + /** + * Execute. + * + * @return return value + * @throws SQLException SQL exception + */ + T execute() throws SQLException; +} diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java index ee51ee8e6dca6..b073eaa2c61df 100644 --- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java +++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/connector/DatabaseConnector.java @@ -43,6 +43,7 @@ import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption; import org.apache.shardingsphere.infra.merge.MergeEngine; import org.apache.shardingsphere.infra.merge.result.MergedResult; +import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.schema.util.SystemSchemaUtils; import org.apache.shardingsphere.infra.rule.identifier.type.DataNodeContainedRule; @@ -69,10 +70,11 @@ import org.apache.shardingsphere.sharding.merge.common.IteratorStreamMergedResult; import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement; +import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; -import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLInsertStatement; import org.apache.shardingsphere.sqlfederation.executor.SQLFederationExecutorContext; import org.apache.shardingsphere.transaction.api.TransactionType; +import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback; import java.sql.Connection; import java.sql.ResultSet; @@ -165,21 +167,17 @@ public ResponseHeader execute() throws SQLException { ResultSet resultSet = doExecuteFederation(queryContext, metaDataContexts); return processExecuteFederation(resultSet, metaDataContexts); } - Collection executionContexts = generateExecutionContexts(); - return isNeedImplicitCommitTransaction(executionContexts) ? doExecuteWithImplicitCommitTransaction(executionContexts) : doExecute(executionContexts); + ExecutionContext executionContext = generateExecutionContext(); + return isNeedImplicitCommitTransaction(executionContext) ? doExecuteWithImplicitCommitTransaction(() -> doExecute(executionContext)) : doExecute(executionContext); } - private Collection generateExecutionContexts() { - Collection result = new LinkedList<>(); - MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts(); - ExecutionContext executionContext = new KernelProcessor().generateExecutionContext(queryContext, database, metaDataContexts.getMetaData().getGlobalRuleMetaData(), - metaDataContexts.getMetaData().getProps(), databaseConnectionManager.getConnectionSession().getConnectionContext()); - result.add(executionContext); - // TODO support logical SQL optimize to generate multiple logical SQL - return result; + private ExecutionContext generateExecutionContext() { + ShardingSphereMetaData metaData = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(); + return new KernelProcessor().generateExecutionContext(queryContext, database, metaData.getGlobalRuleMetaData(), metaData.getProps(), + databaseConnectionManager.getConnectionSession().getConnectionContext()); } - private boolean isNeedImplicitCommitTransaction(final Collection executionContexts) { + private boolean isNeedImplicitCommitTransaction(final ExecutionContext executionContext) { if (!databaseConnectionManager.getConnectionSession().isAutoCommit()) { return false; } @@ -187,28 +185,20 @@ private boolean isNeedImplicitCommitTransaction(final Collection 1; - } - for (ExecutionContext each : executionContexts) { - if (isWriteDMLStatement(each.getSqlStatementContext().getSqlStatement())) { - return true; - } - } - return false; + SQLStatement sqlStatement = executionContext.getSqlStatementContext().getSqlStatement(); + return isWriteDMLStatement(sqlStatement) && executionContext.getExecutionUnits().size() > 1; } private boolean isWriteDMLStatement(final SQLStatement sqlStatement) { return sqlStatement instanceof DMLStatement && !(sqlStatement instanceof SelectStatement); } - private ResponseHeader doExecuteWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { - ResponseHeader result; + private T doExecuteWithImplicitCommitTransaction(final ImplicitTransactionCallback callback) throws SQLException { + T result; BackendTransactionManager transactionManager = new BackendTransactionManager(databaseConnectionManager); try { transactionManager.begin(); - result = doExecute(executionContexts); + result = callback.execute(); transactionManager.commit(); // CHECKSTYLE:OFF } catch (final Exception ex) { @@ -220,15 +210,6 @@ private ResponseHeader doExecuteWithImplicitCommitTransaction(final Collection executionContexts) throws SQLException { - ResponseHeader result = null; - // TODO support multi execution context, currently executionContexts.size() always equals 1 - for (ExecutionContext each : executionContexts) { - result = doExecute(each); - } - return result; - } - @SuppressWarnings({"unchecked", "rawtypes"}) private ResponseHeader doExecute(final ExecutionContext executionContext) throws SQLException { if (executionContext.getExecutionUnits().isEmpty()) { @@ -238,11 +219,12 @@ private ResponseHeader doExecute(final ExecutionContext executionContext) throws List result = proxySQLExecutor.execute(executionContext); refreshMetaData(executionContext); Object executeResultSample = result.iterator().next(); - return executeResultSample instanceof QueryResult ? processExecuteQuery(executionContext, result, (QueryResult) executeResultSample) : processExecuteUpdate(executionContext, result); + return executeResultSample instanceof QueryResult ? processExecuteQuery(queryContext.getSqlStatementContext(), result, (QueryResult) executeResultSample) + : processExecuteUpdate(executionContext, result); } private ResultSet doExecuteFederation(final QueryContext queryContext, final MetaDataContexts metaDataContexts) { - boolean isReturnGeneratedKeys = queryContext.getSqlStatementContext().getSqlStatement() instanceof MySQLInsertStatement; + boolean isReturnGeneratedKeys = queryContext.getSqlStatementContext().getSqlStatement() instanceof InsertStatement; ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseConnectionManager.getConnectionSession().getDatabaseName()); DatabaseType protocolType = database.getProtocolType(); ProxyJDBCExecutorCallback callback = ProxyJDBCExecutorCallbackFactory.newInstance(driverType, protocolType, database.getResourceMetaData(), @@ -302,25 +284,25 @@ private void refreshMetaData(final ExecutionContext executionContext) throws SQL contextManager.getMetaDataContexts().getMetaData().getProps()).refresh(executionContext.getSqlStatementContext(), executionContext.getRouteContext().getRouteUnits()); } - private QueryResponseHeader processExecuteQuery(final ExecutionContext executionContext, final List queryResults, final QueryResult queryResultSample) throws SQLException { - queryHeaders = createQueryHeaders(executionContext, queryResultSample); - mergedResult = mergeQuery(executionContext.getSqlStatementContext(), queryResults); + private QueryResponseHeader processExecuteQuery(final SQLStatementContext sqlStatementContext, final List queryResults, final QueryResult queryResultSample) throws SQLException { + queryHeaders = createQueryHeaders(sqlStatementContext, queryResultSample); + mergedResult = mergeQuery(sqlStatementContext, queryResults); return new QueryResponseHeader(queryHeaders); } - private List createQueryHeaders(final ExecutionContext executionContext, final QueryResult queryResultSample) throws SQLException { - int columnCount = getColumnCount(executionContext, queryResultSample); + private List createQueryHeaders(final SQLStatementContext sqlStatementContext, final QueryResult queryResultSample) throws SQLException { + int columnCount = getColumnCount(sqlStatementContext, queryResultSample); List result = new ArrayList<>(columnCount); QueryHeaderBuilderEngine queryHeaderBuilderEngine = new QueryHeaderBuilderEngine(database.getProtocolType()); for (int columnIndex = 1; columnIndex <= columnCount; columnIndex++) { - result.add(createQueryHeader(queryHeaderBuilderEngine, executionContext, queryResultSample, database, columnIndex)); + result.add(createQueryHeader(queryHeaderBuilderEngine, sqlStatementContext, queryResultSample, database, columnIndex)); } return result; } - private int getColumnCount(final ExecutionContext executionContext, final QueryResult queryResultSample) throws SQLException { - return selectContainsEnhancedTable && hasSelectExpandProjections(executionContext.getSqlStatementContext()) - ? ((SelectStatementContext) executionContext.getSqlStatementContext()).getProjectionsContext().getExpandProjections().size() + private int getColumnCount(final SQLStatementContext sqlStatementContext, final QueryResult queryResultSample) throws SQLException { + return selectContainsEnhancedTable && hasSelectExpandProjections(sqlStatementContext) + ? ((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections().size() : queryResultSample.getMetaData().getColumnCount(); } @@ -328,10 +310,10 @@ private boolean hasSelectExpandProjections(final SQLStatementContext sqlStatemen return sqlStatementContext instanceof SelectStatementContext && !((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections().isEmpty(); } - private QueryHeader createQueryHeader(final QueryHeaderBuilderEngine queryHeaderBuilderEngine, final ExecutionContext executionContext, + private QueryHeader createQueryHeader(final QueryHeaderBuilderEngine queryHeaderBuilderEngine, final SQLStatementContext sqlStatementContext, final QueryResult queryResultSample, final ShardingSphereDatabase database, final int columnIndex) throws SQLException { - return selectContainsEnhancedTable && hasSelectExpandProjections(executionContext.getSqlStatementContext()) - ? queryHeaderBuilderEngine.build(((SelectStatementContext) executionContext.getSqlStatementContext()).getProjectionsContext(), queryResultSample.getMetaData(), database, columnIndex) + return selectContainsEnhancedTable && hasSelectExpandProjections(sqlStatementContext) + ? queryHeaderBuilderEngine.build(((SelectStatementContext) sqlStatementContext).getProjectionsContext(), queryResultSample.getMetaData(), database, columnIndex) : queryHeaderBuilderEngine.build(queryResultSample.getMetaData(), database, columnIndex); }