Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: the problem that the number of mirrors before and after may be inconsistent #6348

Open
wants to merge 8 commits into
base: 2.x
Choose a base branch
from
Next Next commit
bugfix: the problem that the number of mirrors before and after may b…
…e inconsistent
funky-eyes committed Feb 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 7ef6fab1e34258266c55319afecf9b5e8cba53d3
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.apache.seata.common.exception.ShouldNeverHappenException;
import org.apache.seata.common.util.CollectionUtils;
import org.apache.seata.rm.datasource.AbstractConnectionProxy;
import org.apache.seata.rm.datasource.ConnectionContext;
@@ -34,6 +35,7 @@
import org.apache.seata.rm.datasource.exception.TableMetaException;
import org.apache.seata.rm.datasource.sql.struct.TableRecords;
import org.apache.seata.sqlparser.SQLRecognizer;
import org.apache.seata.sqlparser.SQLType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -97,6 +99,16 @@ protected T executeAutoCommitFalse(Object[] args) throws Exception {
try {
TableRecords beforeImage = beforeImage();
T result = statementCallback.execute(statementProxy.getTargetStatement(), args);
int updateCount = statementProxy.getUpdateCount();
if (updateCount > 0) {
if (SQLType.UPDATE == sqlRecognizer.getSQLType()) {
if (updateCount > beforeImage.size()) {
String errorMsg =
"Before image size is not equaled to after image size, probably because you use read committed, please retry transaction.";
throw new ShouldNeverHappenException(errorMsg);
}
}
}
TableRecords afterImage = afterImage(beforeImage);
prepareUndoLog(beforeImage, afterImage);
return result;
Original file line number Diff line number Diff line change
@@ -95,20 +95,27 @@ protected String buildBeforeImageSQL(TableMeta tableMeta, ArrayList<List<Object>

@Override
protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
TableMeta tmeta = getTableMeta();
if (beforeImage == null || beforeImage.size() == 0) {
return TableRecords.empty(getTableMeta());
}
String selectSQL = buildAfterImageSQL(tmeta, beforeImage);
TableMeta tmeta = getTableMeta();
PreparedStatement pst = null;
ResultSet rs = null;
try {
pst = statementProxy.getConnection().prepareStatement(selectSQL);
SqlGenerateUtils.setParamForPk(beforeImage.pkRows(), getTableMeta().getPrimaryKeyOnlyName(), pst);
rs = pst.executeQuery();
return TableRecords.buildRecords(tmeta, rs);
} finally {
IOUtil.close(rs, pst);
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer)sqlRecognizer;
List<String> whereColumns = recognizer.getWhereColumns();
boolean contain = tmeta.containsPK(whereColumns);
if (contain) {
String selectSQL = buildAfterImageSQL(tmeta, beforeImage);
try {
pst = statementProxy.getConnection().prepareStatement(selectSQL);
SqlGenerateUtils.setParamForPk(beforeImage.pkRows(), getTableMeta().getPrimaryKeyOnlyName(), pst);
rs = pst.executeQuery();
return TableRecords.buildRecords(tmeta, rs);
} finally {
IOUtil.close(rs, pst);
}
} else {
return beforeImage();
}
}

Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
*/
package org.apache.seata.sqlparser.antlr.mysql;

import org.apache.seata.common.util.CollectionUtils;
import org.apache.seata.sqlparser.util.ColumnUtils;
import org.apache.seata.sqlparser.ParametersHolder;
import org.apache.seata.sqlparser.SQLType;
@@ -29,6 +30,7 @@
import org.antlr.v4.runtime.tree.ParseTreeWalker;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@@ -129,4 +131,24 @@ public String getTableName() {
public String getOriginalSQL() {
return sqlContext.getOriginalSQL();
}

@Override
public List<String> getWhereColumns() {
List<MySqlContext.SQL> sqls = sqlContext.getUpdateForWhereColumnNames();
if (CollectionUtils.isNotEmpty(sqls)) {
List<String> list = new ArrayList<>(sqls.size());
for (MySqlContext.SQL sql : sqls) {
String column = sql.getUpdateWhereColumnName();
int index = column.indexOf(".");
if (index > 0) {
// table.column -> column name
column = column.substring(index + 1);
}
list.add(column);
}
return list;
}
return Collections.emptyList();
}

}
Original file line number Diff line number Diff line change
@@ -55,4 +55,12 @@ default String getTableAlias(String tableName) {
* @return (`a`, `b`, `c`) -> (a, b, c)
*/
List<String> getUpdateColumnsUnEscape();

/**
* Gets update where columns.
*
* @return the update where columns
*/
List<String> getWhereColumns();

}
Original file line number Diff line number Diff line change
@@ -16,23 +16,30 @@
*/
package org.apache.seata.sqlparser.druid;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLExistsExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLPropertyExpr;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.ast.statement.SQLMergeStatement;
import com.alibaba.druid.sql.ast.statement.SQLReplaceStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.visitor.SQLASTVisitor;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import org.apache.seata.common.exception.NotSupportYetException;
import org.apache.seata.common.util.CollectionUtils;
import org.apache.seata.sqlparser.SQLParsingException;
import org.apache.seata.sqlparser.SQLRecognizer;

@@ -160,4 +167,60 @@ public boolean visit(SQLInsertStatement x) {
getAst().accept(visitor);
return true;
}

public List<String> getWhereColumns(SQLExpr sqlExpr) {
// single condition
if (sqlExpr instanceof SQLBinaryOpExpr) {
return getWhereColumns(Collections.singletonList(sqlExpr));
} else {
// multiple conditions
return getWhereColumns(sqlExpr.getChildren());
}
}

public List<String> getWhereColumns(List<SQLObject> list) {
if (CollectionUtils.isNotEmpty(list)) {
List<String> columns = new ArrayList<>(list.size());
for (SQLObject sqlObject : list) {
if (sqlObject instanceof SQLIdentifierExpr) {
columns.add(((SQLIdentifierExpr)sqlObject).getName());
} else {
getWhereColumns(sqlObject, columns);
}
}
return columns;
}
return Collections.emptyList();
}

public void getWhereColumns(SQLObject sqlExpr, List<String> list) {
if (sqlExpr instanceof SQLBinaryOpExpr) {
SQLExpr left = ((SQLBinaryOpExpr)sqlExpr).getLeft();
getWhereColumn(left, list);
SQLExpr right = ((SQLBinaryOpExpr)sqlExpr).getRight();
getWhereColumn(right, list);
}
}

public void getWhereColumn(SQLExpr left, List<String> list) {
if (left instanceof SQLBetweenExpr) {
SQLExpr expr = ((SQLBetweenExpr)left).getTestExpr();
if (expr instanceof SQLIdentifierExpr) {
list.add(((SQLIdentifierExpr)expr).getName());
}
if (expr instanceof SQLPropertyExpr) {
list.add(((SQLPropertyExpr)expr).getName());
}
} else if (left instanceof SQLIdentifierExpr) {
list.add(((SQLIdentifierExpr)left).getName());
} else if (left instanceof SQLInListExpr) {
SQLExpr expr = ((SQLInListExpr)left).getExpr();
if (expr instanceof SQLIdentifierExpr) {
list.add(((SQLIdentifierExpr)expr).getName());
}
} else if (left instanceof SQLBinaryOpExpr) {
getWhereColumns(left, list);
}
}

}
Original file line number Diff line number Diff line change
@@ -173,6 +173,13 @@ public List<String> getUpdateColumnsUnEscape() {
return ColumnUtils.delEscape(updateColumns, getDbType());
}


@Override
public List<String> getWhereColumns() {
SQLExpr where = ast.getWhere();
return super.getWhereColumns(where);
}

@Override
protected SQLStatement getAst() {
return this.ast;
Original file line number Diff line number Diff line change
@@ -249,4 +249,12 @@ public boolean visit(SQLExprTableSource x) {
visitor.visit(tableSource);
return tableName.toString();
}


@Override
public List<String> getWhereColumns() {
SQLExpr where = ast.getWhere();
return super.getWhereColumns(where);
}

}
Original file line number Diff line number Diff line change
@@ -176,6 +176,13 @@ public boolean visit(SQLJoinTableSource x) {
return sb.toString();
}


@Override
public List<String> getWhereColumns() {
SQLExpr where = ast.getWhere();
return super.getWhereColumns(where);
}

@Override
protected SQLStatement getAst() {
return ast;
Original file line number Diff line number Diff line change
@@ -173,6 +173,12 @@ public String getOrderByCondition(ParametersHolder parametersHolder, ArrayList<L
return null;
}

@Override
public List<String> getWhereColumns() {
SQLExpr where = ast.getWhere();
return super.getWhereColumns(where);
}

@Override
protected SQLStatement getAst() {
return ast;
Original file line number Diff line number Diff line change
@@ -201,6 +201,13 @@ public String getOrderByCondition(ParametersHolder parametersHolder, ArrayList<L
return null;
}


@Override
public List<String> getWhereColumns() {
SQLExpr where = ast.getWhere();
return super.getWhereColumns(where);
}

@Override
protected SQLStatement getAst() {
return ast;
Original file line number Diff line number Diff line change
@@ -375,4 +375,34 @@ public void testGetUpdateColumns_2() {
Assertions.assertTrue(updateColumn.contains("`"));
}
}

@Test
public void testGetWhereColumns() {
String sql = "UPDATE t1 SET name1 = 'name1', name2 = 'name2' WHERE t1.id between ? and ? or name1= ? and name2= ?";

SQLStatement statement = getSQLStatement(sql);

MySQLUpdateRecognizer mySQLUpdateRecognizer = new MySQLUpdateRecognizer(sql, statement);
List<String> whereColumns = mySQLUpdateRecognizer.getWhereColumns();
Assertions.assertEquals("id", whereColumns.get(0));
Assertions.assertEquals("name1", whereColumns.get(1));
Assertions.assertEquals("name2", whereColumns.get(2));
sql = "UPDATE t1 SET name1 = 'name1', name2 = 'name2' WHERE id between ? and ?";

statement = getSQLStatement(sql);

mySQLUpdateRecognizer = new MySQLUpdateRecognizer(sql, statement);
whereColumns = mySQLUpdateRecognizer.getWhereColumns();
Assertions.assertEquals("id", whereColumns.get(0));

sql = "UPDATE t1 SET name1 = 'name1', name2 = 'name2' WHERE id in(?,? ) and createTime between ? and ?";

statement = getSQLStatement(sql);

mySQLUpdateRecognizer = new MySQLUpdateRecognizer(sql, statement);
whereColumns = mySQLUpdateRecognizer.getWhereColumns();
Assertions.assertEquals("id", whereColumns.get(0));
Assertions.assertEquals("createTime", whereColumns.get(1));
}

}