Skip to content

Commit

Permalink
Refactor proxy and jdbc adapter and add ImplicitTransactionCallback (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Jan 17, 2024
1 parent a8e57a0 commit 3da2f14
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public abstract class AbstractStatementAdapter extends AbstractUnsupportedOperat

private boolean closed;

protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final Collection<ExecutionContext> executionContexts) {
protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final ExecutionContext executionContext) {
if (connection.getAutoCommit()) {
return false;
}
Expand All @@ -66,16 +66,8 @@ protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConn
if (!TransactionType.isDistributedTransaction(connectionTransaction.getTransactionType()) || isInTransaction) {
return false;
}
if (1 == executionContexts.size()) {
SQLStatement sqlStatement = executionContexts.iterator().next().getSqlStatementContext().getSqlStatement();
return isWriteDMLStatement(sqlStatement) && executionContexts.iterator().next().getExecutionUnits().size() > 1;
}
for (ExecutionContext each : executionContexts) {
if (isWriteDMLStatement(each.getSqlStatementContext().getSqlStatement())) {
return true;
}
}
return false;
SQLStatement sqlStatement = executionContext.getSqlStatementContext().getSqlStatement();
return isWriteDMLStatement(sqlStatement) && executionContext.getExecutionUnits().size() > 1;
}

private boolean isWriteDMLStatement(final SQLStatement sqlStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.driver.jdbc.core.statement;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import lombok.AccessLevel;
import lombok.Getter;
Expand Down Expand Up @@ -88,6 +87,7 @@
import org.apache.shardingsphere.traffic.engine.TrafficEngine;
import org.apache.shardingsphere.traffic.exception.metadata.EmptyTrafficExecutionUnitException;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback;
import org.apache.shardingsphere.transaction.util.AutoCommitUtils;

import java.sql.Connection;
Expand Down Expand Up @@ -151,7 +151,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
@Getter
private final boolean selectContainsEnhancedTable;

private Collection<ExecutionContext> executionContexts;
private ExecutionContext executionContext;

private Map<String, Integer> columnLabelAndIndexMap;

Expand Down Expand Up @@ -240,8 +240,8 @@ public ResultSet executeQuery() throws SQLException {
if (useFederation) {
return executeFederationQuery(queryContext);
}
executionContexts = createExecutionContext(queryContext);
result = doExecuteQuery(executionContexts);
executionContext = createExecutionContext(queryContext);
result = doExecuteQuery(executionContext);
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -254,19 +254,14 @@ public ResultSet executeQuery() throws SQLException {
return result;
}

private ShardingSphereResultSet doExecuteQuery(final Collection<ExecutionContext> executionContexts) throws SQLException {
ShardingSphereResultSet result = null;
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
List<QueryResult> queryResults = executeQuery0(each);
MergedResult mergedResult = mergeQuery(queryResults, each.getSqlStatementContext());
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
result = new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, each, columnLabelAndIndexMap);
private ShardingSphereResultSet doExecuteQuery(final ExecutionContext executionContext) throws SQLException {
List<QueryResult> queryResults = executeQuery0(executionContext);
MergedResult mergedResult = mergeQuery(queryResults, executionContext.getSqlStatementContext());
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
return result;
return new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext, columnLabelAndIndexMap);
}

private boolean decide(final QueryContext queryContext, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) {
Expand Down Expand Up @@ -356,15 +351,13 @@ public int executeUpdate() throws SQLException {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
}
executionContexts = createExecutionContext(queryContext);
executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(results);
}
return executeUpdateWithExecutionContexts(executionContexts);
return executeUpdateWithExecutionContext(executionContext);
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -375,17 +368,11 @@ public int executeUpdate() throws SQLException {
}
}

private int useDriverToExecuteUpdate(final Collection<ExecutionContext> executionContexts) throws SQLException {
Integer result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
result = executor.getRegularExecutor().executeUpdate(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}
return result;
private int useDriverToExecuteUpdate(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}

private int accumulate(final Collection<ExecuteResult> results) {
Expand Down Expand Up @@ -434,16 +421,13 @@ public boolean execute() throws SQLException {
ResultSet resultSet = executeFederationQuery(queryContext);
return null != resultSet;
}
executionContexts = createExecutionContext(queryContext);
executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results = new LinkedList<>();
for (ExecutionContext each : executionContexts) {
// TODO process getStatement
results.addAll(executor.getRawExecutor().execute(createRawExecutionGroupContext(each), each.getQueryContext(), new RawSQLExecutorCallback()));
}
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
return executeWithExecutionContexts(executionContexts);
return executeWithExecutionContext(executionContext);
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -469,15 +453,15 @@ private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContex
.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(databaseName));
}

private boolean executeWithExecutionContexts(final Collection<ExecutionContext> executionContexts) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeWithImplicitCommitTransaction(executionContexts) : useDriverToExecute(executionContexts);
private boolean executeWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContext) ? executeWithImplicitCommitTransaction(() -> useDriverToExecute(executionContext)) : useDriverToExecute(executionContext);
}

private boolean executeWithImplicitCommitTransaction(final Collection<ExecutionContext> executionContexts) throws SQLException {
private boolean executeWithImplicitCommitTransaction(final ImplicitTransactionCallback<Boolean> callback) throws SQLException {
boolean result;
try {
connection.setAutoCommit(false);
result = useDriverToExecute(executionContexts);
result = callback.execute();
connection.commit();
// CHECKSTYLE:OFF
} catch (final Exception ex) {
Expand All @@ -490,15 +474,16 @@ private boolean executeWithImplicitCommitTransaction(final Collection<ExecutionC
return result;
}

private int executeUpdateWithExecutionContexts(final Collection<ExecutionContext> executionContexts) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContexts) ? executeUpdateWithImplicitCommitTransaction(executionContexts) : useDriverToExecuteUpdate(executionContexts);
private int executeUpdateWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContext) ? executeUpdateWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(executionContext))
: useDriverToExecuteUpdate(executionContext);
}

private int executeUpdateWithImplicitCommitTransaction(final Collection<ExecutionContext> executionContexts) throws SQLException {
private int executeUpdateWithImplicitCommitTransaction(final ImplicitTransactionCallback<Integer> callback) throws SQLException {
int result;
try {
connection.setAutoCommit(false);
result = useDriverToExecuteUpdate(executionContexts);
result = callback.execute();
connection.commit();
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
Expand All @@ -511,17 +496,11 @@ private int executeUpdateWithImplicitCommitTransaction(final Collection<Executio
return result;
}

private boolean useDriverToExecute(final Collection<ExecutionContext> executionContexts) throws SQLException {
Boolean result = null;
Preconditions.checkArgument(!executionContexts.isEmpty());
// TODO support multi execution context, currently executionContexts.size() always equals 1
for (ExecutionContext each : executionContexts) {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(each);
cacheStatements(executionGroupContext.getInputGroups());
result = executor.getRegularExecutor().execute(executionGroupContext,
each.getQueryContext(), each.getRouteContext().getRouteUnits(), createExecuteCallback());
}
return result;
private boolean useDriverToExecute(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().execute(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback());
}

private JDBCExecutorCallback<Boolean> createExecuteCallback() {
Expand Down Expand Up @@ -557,7 +536,6 @@ public ResultSet getResultSet() throws SQLException {
if (useFederation) {
return executor.getSqlFederationEngine().getResultSet();
}
ExecutionContext executionContext = executionContexts.iterator().next();
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext
|| executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
List<ResultSet> resultSets = getResultSets();
Expand Down Expand Up @@ -594,19 +572,19 @@ private List<QueryResult> getQueryResults(final List<ResultSet> resultSets) thro
return result;
}

private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext) {
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return Collections.singleton(result);
return result;
}

private Collection<ExecutionContext> createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters()));
return Collections.singleton(new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext()));
return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext());
}

private QueryContext createQueryContext() {
Expand Down Expand Up @@ -663,7 +641,7 @@ public ResultSet getGeneratedKeys() throws SQLException {
if (null != currentBatchGeneratedKeysResultSet) {
return currentBatchGeneratedKeysResultSet;
}
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContexts.iterator().next());
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContext);
if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) {
return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this);
}
Expand All @@ -686,8 +664,8 @@ public void addBatch() {
try {
QueryContext queryContext = createQueryContext();
trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
executionContexts = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContexts.iterator().next().getExecutionUnits());
executionContext = null == trafficInstanceId ? createExecutionContext(queryContext) : createExecutionContext(queryContext, trafficInstanceId);
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits());
} finally {
currentResultSet = null;
clearParameters();
Expand All @@ -696,13 +674,13 @@ public void addBatch() {

@Override
public int[] executeBatch() throws SQLException {
if (null == executionContexts || executionContexts.isEmpty()) {
if (null == executionContext) {
return new int[0];
}
try {
// TODO add raw SQL executor
initBatchPreparedStatementExecutor();
int[] results = batchPreparedStatementExecutor.executeBatch(executionContexts.iterator().next().getSqlStatementContext());
int[] results = batchPreparedStatementExecutor.executeBatch(executionContext.getSqlStatementContext());
if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) {
List<Statement> batchPreparedStatementExecutorStatements = batchPreparedStatementExecutor.getStatements();
for (Statement statement : batchPreparedStatementExecutorStatements) {
Expand Down Expand Up @@ -732,7 +710,7 @@ private void initBatchPreparedStatementExecutor() throws SQLException {
ExecutionUnit executionUnit = each.getExecutionUnit();
executionUnits.add(executionUnit);
}
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
batchPreparedStatementExecutor.init(prepareEngine.prepare(executionContext.getRouteContext(), executionUnits, new ExecutionGroupReportContext(databaseName)));
setBatchParametersForStatements();
}

Expand Down Expand Up @@ -773,7 +751,7 @@ public int getResultSetHoldability() {
@Override
public boolean isAccumulate() {
return metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().findRules(DataNodeContainedRule.class).stream()
.anyMatch(each -> each.isNeedAccumulate(executionContexts.iterator().next().getSqlStatementContext().getTablesContext().getTableNames()));
.anyMatch(each -> each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames()));
}

@Override
Expand Down
Loading

0 comments on commit 3da2f14

Please sign in to comment.