From ac7255a4ff6cab062ab443465e84166f5accff83 Mon Sep 17 00:00:00 2001 From: newborn22 <953950914@qq.com> Date: Mon, 8 Jan 2024 23:28:05 +0800 Subject: [PATCH] fix: refactor splitBatchIntoTwo, add related testcases Signed-off-by: newborn22 <953950914@qq.com> --- go/vt/vttablet/jobcontroller/controller.go | 270 ++---------------- go/vt/vttablet/jobcontroller/gen_sql.go | 156 ---------- go/vt/vttablet/jobcontroller/gen_sql_test.go | 155 ---------- go/vt/vttablet/jobcontroller/sql_related.go | 213 ++++++++++++++ .../jobcontroller/sql_related_test.go | 217 ++++++++++++++ 5 files changed, 459 insertions(+), 552 deletions(-) delete mode 100644 go/vt/vttablet/jobcontroller/gen_sql.go delete mode 100644 go/vt/vttablet/jobcontroller/gen_sql_test.go create mode 100644 go/vt/vttablet/jobcontroller/sql_related.go create mode 100644 go/vt/vttablet/jobcontroller/sql_related_test.go diff --git a/go/vt/vttablet/jobcontroller/controller.go b/go/vt/vttablet/jobcontroller/controller.go index edc52aae1f..db606580a8 100644 --- a/go/vt/vttablet/jobcontroller/controller.go +++ b/go/vt/vttablet/jobcontroller/controller.go @@ -108,6 +108,7 @@ type JobController struct { schedulerNotifyChan chan struct{} // jobScheduler每隔一段时间运行一次调度。但当它收到这个chan的消息后,会立刻开始一次调度 } +// todo newborn22 删除pktype? type PKInfo struct { pkName string pkType querypb.Type @@ -792,6 +793,15 @@ func (jc *JobController) execBatchAndRecord(ctx context.Context, tableSchema, ta // 拆分的基本原理是遍历原先batch的batchCountSQL的结果集,将第batchSize条record的pk作为原先batch的PKEnd,第batchSize+1条record的pk作为新batch的PKStart // 原先batch的PKStart和PKEnd分别成为原先batch的PKStart和新batch的PKEnd func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, table, batchTable, batchSQL, batchCountSQL, batchID string, conn *connpool.DBConn, batchSize, expectedRow int64) (newCurrentBatchSQL string, err error) { + batchSQLStmt, err := sqlparser.Parse(batchSQL) + if err != nil { + return "", err + } + batchCountSQLStmt, err := sqlparser.Parse(batchCountSQL) + if err != nil { + return "", err + } + // 1.根据batchCountSQL生成查询pk值的select sql // 1.1.获得PK信息 pkInfos, err := jc.getTablePkInfo(ctx, tableSchema, table) @@ -801,11 +811,8 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab // 1.2.根据当前batch的batchCountSQL生成select sql,用于获得拆分后batch的拆分列start和end // 只需要将batchCountSQL的投影部分(SelectExprs)从count(*)改为拆分列即可 - batchCountSQLStmt, err := sqlparser.Parse(batchCountSQL) - if err != nil { - return "", err - } batchCountSQLStmtSelect, _ := batchCountSQLStmt.(*sqlparser.Select) + // 根据pk信息生成select exprs var pkExprs []sqlparser.SelectExpr for _, pkInfo := range pkInfos { pkExprs = append(pkExprs, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName(pkInfo.pkName)}) @@ -821,85 +828,33 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab // 拆成多个batch需要遍历完select的全部结果,这可能会导致超时 // 2.1.计算两个batch的batchPKStart和batchPKEnd。实际上,只要获得当前batch的新的PKEnd和新的batch的PKStart - // 遍历前threshold+1条,依然使用同一个连接 qr, err := conn.Exec(ctx, batchSplitSelectSQL, math.MaxInt32, true) if err != nil { return "", err } - // todo literal - var curBatchNewEnd []any - var newBatchStart []any + var curBatchNewEnd []sqltypes.Value + var newBatchStart []sqltypes.Value - for rowCount, row := range qr.Named().Rows { + for rowCount, row := range qr.Rows { // 将原本batch的PKEnd设在threshold条数处 if int64(rowCount) == batchSize-1 { - for _, pkInfo := range pkInfos { - pkName := pkInfo.pkName - keyVal, err := ProcessValue(row[pkName]) - if err != nil { - return "", err - } - curBatchNewEnd = append(curBatchNewEnd, keyVal) - } + curBatchNewEnd = row } // 将第threshold+1条的PK作为新PK的起点 if int64(rowCount) == batchSize { - for _, pkInfo := range pkInfos { - pkName := pkInfo.pkName - keyVal, err := ProcessValue(row[pkName]) - if err != nil { - return "", err - } - newBatchStart = append(newBatchStart, keyVal) - } + newBatchStart = row + break } } - - // 2.2) 将curBatchNewEnd和newBatchStart转换成sql中where部分的<=和>=的字符串 - // todo 直接从ast开始构建 - curBatchLessThanPart, err := genPKsLessThanPart(pkInfos, curBatchNewEnd) - if err != nil { - return "", err - } - curBatchLessThanExpr, err := genExprNodeFromStr(curBatchLessThanPart) - if err != nil { - return "", err - } - - newBatchGreatThanPart, err := genPKsGreaterThanPart(pkInfos, newBatchStart) - if err != nil { - return "", err - } - newBatchGreatThanExpr, err := genExprNodeFromStr(newBatchGreatThanPart) + // 2.2.生成新的batchSQL和新的batchCountSQL + curBatchSQL, newBatchSQL, newBatchCountSQL, err := genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) if err != nil { return "", err } - // 2.3) 通过parser,获得原先batchSQL的greatThan和lessThan的expr ast node - batchSQLStmt, err := sqlparser.Parse(batchSQL) - if err != nil { - return "", err - } - curBatchGreatThanExpr, newBatchLessThanExpr := jc.getBatchSQLGreatThanAndLessThanExprNode(batchSQLStmt) - - // 2.4) 生成拆分后,当前batch的sql和新batch的sql - // 2.4.1) 先获得curBatchSQL和newBatchSQL的where expr ast node,需要将用户输入的where expr与上PK Condition Expr - oldBatchSQLUserWhereExpr := getUserWhereExpr(batchSQLStmt) - curBatchPKConditionExpr := sqlparser.AndExpr{Left: curBatchGreatThanExpr, Right: curBatchLessThanExpr} - newBatchPKConditionExpr := sqlparser.AndExpr{Left: newBatchGreatThanExpr, Right: newBatchLessThanExpr} - curBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: oldBatchSQLUserWhereExpr, Right: &curBatchPKConditionExpr}} - newBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: oldBatchSQLUserWhereExpr, Right: &newBatchPKConditionExpr}} - - // 2.4.2) 替换原先batchSQL和batchCountSQL的where expr来生成新的sql - curBatchSQL := GenSQLByReplaceWhereExprNode(batchSQLStmt, curBatchWhereExpr) - newBatchSQL := GenSQLByReplaceWhereExprNode(batchSQLStmt, newBatchWhereExpr) - - // 2.4.3) 生成新batch的batchCountSQL,原理同上 - newBatchCountSQL := GenSQLByReplaceWhereExprNode(batchCountSQLStmt, newBatchWhereExpr) - - // 构建当前batch新的batch begin及end字段以及新batch的begin及end字段 + // 2.3.计算两个batch的batch start和end字段 getBatchBeginAndEndSQL := fmt.Sprintf(sqlTemplateGetBatchBeginAndEnd, batchTable) getBatchBeginAndEndQuery, err := sqlparser.ParseAndBind(getBatchBeginAndEndSQL, sqltypes.StringBindVariable(batchID)) if err != nil { @@ -914,14 +869,13 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab } currentBatchNewBeginStr := qr.Named().Rows[0]["batch_begin"].ToString() newBatchEndStr := qr.Named().Rows[0]["batch_end"].ToString() - // todo newborn22 next - currentBatchNewEndStr, newBatchBegintStr, err := genBatchStartAndEndStr(nil, nil) + currentBatchNewEndStr, newBatchBeginStr, err := genBatchStartAndEndStr(curBatchNewEnd, newBatchStart) if err != nil { return "", err } - // 2.5) 在batch表中更改旧的条目的sql,并插入新batch条目 - // 在表中更改旧的sql + // 3 将结果记录在表中:在batch表中更改旧的条目的sql,并插入新batch条目 + // 3.1.在表中更改旧的sql updateBatchSQL := fmt.Sprintf(sqlTemplateUpdateBatchSQL, batchTable) updateBatchSQLQuery, err := sqlparser.ParseAndBind(updateBatchSQL, sqltypes.StringBindVariable(curBatchSQL), @@ -935,7 +889,7 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab if err != nil { return "", err } - // 插入新batch条目 + // 3.2.插入新batch条目 newCurrentBatchSQL = curBatchSQL // todo 1-1 -> 1-2开始 nextBatchID, err := genNewBatchID(batchID) @@ -949,7 +903,7 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab sqltypes.StringBindVariable(newBatchSQL), sqltypes.StringBindVariable(newBatchCountSQL), sqltypes.Int64BindVariable(newBatchSize), - sqltypes.StringBindVariable(newBatchBegintStr), + sqltypes.StringBindVariable(newBatchBeginStr), sqltypes.StringBindVariable(newBatchEndStr)) if err != nil { return "", err @@ -1538,11 +1492,11 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, currentBatchSize++ if currentBatchSize == batchSize { - batchSQL, finalWhereStr, err := GenBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) + batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) if err != nil { return "", err } - countSQL := GenCountSQL(tableSchema, tableName, finalWhereStr) + countSQL := genCountSQL(tableSchema, tableName, finalWhereStr) if err != nil { return "", err } @@ -1573,11 +1527,11 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, } // 最后一个batch的行数不一定是batchSize,在循环结束时要将剩余的行数划分到最后一个batch中 if currentBatchSize != 0 { - batchSQL, finalWhereStr, err := GenBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) + batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) if err != nil { return "", err } - countSQL := GenCountSQL(tableSchema, tableName, finalWhereStr) + countSQL := genCountSQL(tableSchema, tableName, finalWhereStr) if err != nil { return "", err } @@ -1615,125 +1569,6 @@ func currentBatchIDInc(currentBatchID string) (string, error) { return strconv.FormatInt(currentBatchIDInt64, 10), nil } -func genBatchStartAndEndStr(currentBatchStart, currentBatchEnd []sqltypes.Value) (currentBatchStartStr string, currentBatchStartEnd string, err error) { - prefix := "" - for i := range currentBatchStart { - prefix = "," - currentBatchStartStr += prefix + currentBatchStart[i].ToString() - currentBatchStartEnd += prefix + currentBatchEnd[i].ToString() - } - return currentBatchStartStr, currentBatchStartEnd, nil -} - -func genPlaceholderByType(typ querypb.Type) (string, error) { - switch typ { - case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_INT64: - return "%d", nil - case querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32, querypb.Type_UINT64: - return "%d", nil - case querypb.Type_FLOAT32, querypb.Type_FLOAT64: - return "%f", nil - // todo string - case querypb.Type_TIMESTAMP, querypb.Type_DATE, querypb.Type_TIME, querypb.Type_DATETIME, querypb.Type_YEAR, - querypb.Type_TEXT, querypb.Type_VARCHAR, querypb.Type_CHAR: - return "'%s'", nil - default: - return "", fmt.Errorf("Unsupported type: %v", typ) - } -} - -func ProcessValue(value sqltypes.Value) (any, error) { - typ := value.Type() - - switch typ { - case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_INT64: - return value.ToInt64() - case querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32, querypb.Type_UINT64: - return value.ToUint64() - case querypb.Type_FLOAT32, querypb.Type_FLOAT64: - return value.ToFloat64() - case querypb.Type_TIMESTAMP, querypb.Type_DATE, querypb.Type_TIME, querypb.Type_DATETIME, querypb.Type_YEAR, - querypb.Type_TEXT, querypb.Type_VARCHAR, querypb.Type_CHAR: - return value.ToString(), nil - default: - return nil, fmt.Errorf("Unsupported type: %v", typ) - } -} - -func genPKsGreaterThanPart(pkInfos []PKInfo, currentBatchStart []any) (string, error) { - curIdx := 0 - pksNum := len(pkInfos) - var equalStr, rst string - for curIdx < pksNum { - curPkName := pkInfos[curIdx].pkName - curPKType := pkInfos[curIdx].pkType - // mysql的浮点类型在比较时有精度损失,不适合作为拆分列 - if curPKType == querypb.Type_FLOAT32 || curPKType == querypb.Type_FLOAT64 { - return "", fmt.Errorf("unsupported type: %v", curPKType) - } - - placeholder, err := genPlaceholderByType(curPKType) - if err != nil { - return "", err - } - - if curIdx == 0 { - rst = fmt.Sprintf("( %s > %s )", curPkName, placeholder) - } else if curIdx != (pksNum - 1) { - rst += fmt.Sprintf(" OR ( %s AND %s > %s )", equalStr, curPkName, placeholder) - } else if curIdx == (pksNum - 1) { - rst += fmt.Sprintf(" OR ( %s AND %s >= %s )", equalStr, curPkName, placeholder) - } - rst = fmt.Sprintf(rst, currentBatchStart[curIdx]) - - if curIdx == 0 { - equalStr = fmt.Sprintf("%s = %s", curPkName, placeholder) - } else { - equalStr += fmt.Sprintf(" AND %s = %s", curPkName, placeholder) - } - equalStr = fmt.Sprintf(equalStr, currentBatchStart[curIdx]) - curIdx++ - } - return rst, nil -} - -func genPKsLessThanPart(pkInfos []PKInfo, currentBatchEnd []any) (string, error) { - curIdx := 0 - pksNum := len(pkInfos) - var equalStr, rst string - for curIdx < pksNum { - curPkName := pkInfos[curIdx].pkName - curPKType := pkInfos[curIdx].pkType - // mysql的浮点类型在比较时有精度损失,不适合作为拆分列 - if curPKType == querypb.Type_FLOAT32 || curPKType == querypb.Type_FLOAT64 { - return "", fmt.Errorf("unsupported type: %v", curPKType) - } - - placeholder, err := genPlaceholderByType(curPKType) - if err != nil { - return "", err - } - - if curIdx == 0 { - rst = fmt.Sprintf("( %s < %s )", curPkName, placeholder) - } else if curIdx != (pksNum - 1) { - rst += fmt.Sprintf(" OR ( %s AND %s < %s )", equalStr, curPkName, placeholder) - } else if curIdx == (pksNum - 1) { - rst += fmt.Sprintf(" OR ( %s AND %s <= %s )", equalStr, curPkName, placeholder) - } - rst = fmt.Sprintf(rst, currentBatchEnd[curIdx]) - - if curIdx == 0 { - equalStr = fmt.Sprintf("%s = %s", curPkName, placeholder) - } else { - equalStr += fmt.Sprintf(" AND %s = %s", curPkName, placeholder) - } - equalStr = fmt.Sprintf(equalStr, currentBatchEnd[curIdx]) - curIdx++ - } - return rst, nil -} - // 通知jobScheduler让它立刻开始一次调度。 func (jc *JobController) notifyJobScheduler() { if jc.schedulerNotifyChan == nil { @@ -1797,50 +1632,3 @@ func (jc *JobController) updateBatchStatus(batchTableSchema, batchTableName, sta _, err = jc.execQuery(context.Background(), batchTableSchema, query) return err } - -func getUserWhereExpr(stmt sqlparser.Statement) (expr sqlparser.Expr) { - switch s := stmt.(type) { - case *sqlparser.Update: - tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) - expr = tempAndExpr.Left - return expr - case *sqlparser.Delete: - tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) - expr = tempAndExpr.Left - return expr - default: - // the code won't reach here - return nil - } -} - -func (jc *JobController) getBatchSQLGreatThanAndLessThanExprNode(stmt sqlparser.Statement) (greatThanExpr sqlparser.Expr, lessThanExpr sqlparser.Expr) { - switch s := stmt.(type) { - case *sqlparser.Update: - // the type switch will be ok - andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) - pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr) - greatThanExpr = pkConditionExpr.Left - lessThanExpr = pkConditionExpr.Right - return greatThanExpr, lessThanExpr - case *sqlparser.Delete: - // the type switch will be ok - andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) - pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr) - greatThanExpr = pkConditionExpr.Left - lessThanExpr = pkConditionExpr.Right - return greatThanExpr, lessThanExpr - } - // the code won't reach here - return nil, nil -} - -func genExprNodeFromStr(condition string) (sqlparser.Expr, error) { - tmpSQL := fmt.Sprintf("select 1 where %s", condition) - tmpStmt, err := sqlparser.Parse(tmpSQL) - if err != nil { - return nil, err - } - tmpStmtSelect, _ := tmpStmt.(*sqlparser.Select) - return tmpStmtSelect.Where.Expr, nil -} diff --git a/go/vt/vttablet/jobcontroller/gen_sql.go b/go/vt/vttablet/jobcontroller/gen_sql.go deleted file mode 100644 index 339f9efe89..0000000000 --- a/go/vt/vttablet/jobcontroller/gen_sql.go +++ /dev/null @@ -1,156 +0,0 @@ -/* -Copyright ApeCloud, Inc. -Licensed under the Apache v2(found in the LICENSE file in the root directory). -*/ - -package jobcontroller - -import ( - "errors" - "fmt" - - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/sqlparser" -) - -func GenPKsGreaterEqualOrLessEqualStr(pkInfos []PKInfo, currentBatchStart []sqltypes.Value, greatEqual bool) (string, error) { - buf := sqlparser.NewTrackedBuffer(nil) - prefix := "" - // This loop handles the case for composite pks. For example, - // if lastpk was (1,2), and the greatEqual is true, then clause would be: - // (col1 > 1) or (col1 = 1 and col2 >= 2). - for curCol := 0; curCol <= len(pkInfos)-1; curCol++ { - buf.Myprintf("%s(", prefix) - prefix = " or " - for i, pk := range currentBatchStart[:curCol] { - buf.Myprintf("%s = ", pkInfos[i].pkName) - pk.EncodeSQL(buf) - buf.Myprintf(" and ") - } - if curCol == len(pkInfos)-1 { - if greatEqual { - buf.Myprintf("%s >= ", pkInfos[curCol].pkName) - } else { - buf.Myprintf("%s <= ", pkInfos[curCol].pkName) - } - } else { - if greatEqual { - buf.Myprintf("%s > ", pkInfos[curCol].pkName) - } else { - buf.Myprintf("%s < ", pkInfos[curCol].pkName) - } - } - currentBatchStart[curCol].EncodeSQL(buf) - buf.Myprintf(")") - } - return buf.String(), nil -} - -func GenPKConditionExprByStr(greatThanPart, lessThanPart string) (sqlparser.Expr, error) { - tmpSQL := fmt.Sprintf("select 1 where (%s) AND (%s)", greatThanPart, lessThanPart) - tmpStmt, err := sqlparser.Parse(tmpSQL) - if err != nil { - return nil, err - } - tmpStmtSelect, ok := tmpStmt.(*sqlparser.Select) - if !ok { - return nil, errors.New("genPKConditionExprByStr: tmpStmt is not *sqlparser.Select") - } - return tmpStmtSelect.Where.Expr, nil -} - -func GenSQLByReplaceWhereExprNode(stmt sqlparser.Statement, whereExpr sqlparser.Where) string { - switch s := stmt.(type) { - case *sqlparser.Update: - s.Where = &whereExpr - return sqlparser.String(s) - case *sqlparser.Delete: - s.Where = &whereExpr - return sqlparser.String(s) - case *sqlparser.Select: - // 针对batchCountSQL - s.Where = &whereExpr - return sqlparser.String(s) - default: - // the code won't reach here - return "" - } -} - -func ReplaceWhereExprNode(stmt sqlparser.Statement, whereExpr sqlparser.Where) sqlparser.Statement { - switch s := stmt.(type) { - case *sqlparser.Update: - s.Where = &whereExpr - return s - case *sqlparser.Delete: - s.Where = &whereExpr - return s - default: - // the code won't reach here - return nil - } -} - -// todo newborn22 对参数进行调整 -func GenBatchSQL(sql string, stmt sqlparser.Statement, whereExpr sqlparser.Expr, currentBatchStart, currentBatchEnd []sqltypes.Value, pkInfos []PKInfo) (batchSQL, finalWhereStr string, err error) { - // 1. 生成>=的部分 - greatThanPart, err := GenPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchStart, true) - if err != nil { - return "", "", err - } - - // 2.生成<=的部分 - lessThanPart, err := GenPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchEnd, false) - if err != nil { - return "", "", err - } - - // 3.将pk>= and pk <= 拼接起来并生成相应的condition expr ast node - pkConditionExpr, err := GenPKConditionExprByStr(greatThanPart, lessThanPart) - if err != nil { - return "", "", err - } - - // 4.将原本sql stmt中的where expr ast node用AND拼接上pkConditionExpr,作为batchSQL的where expr ast node - // 4.1先生成新的condition ast node - andExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: whereExpr, Right: pkConditionExpr}} - batchSQL = GenSQLByReplaceWhereExprNode(stmt, andExpr) - finalWhereStr = sqlparser.String(andExpr.Expr) - - return batchSQL, finalWhereStr, nil -} - -// todo newbon22 删除 -// todo newborn22 batchSQL和batchCountSQL对浮点数进行拦截,可能在获得pkInfo时就进行拦截。 -// 拆分列所支持的类型需要满足以下条件: -// 1.在sql中可以正确地使用between或>=,<=进行比较运算,且没有精度问题。 -// 2.可以转换成go中的int64,float64或string三种类型之一,且转换后,在golang中的比较规则和mysql中的比较规则相同 -func GenCountSQLOld(tableSchema, tableName, wherePart string, currentBatchStart, currentBatchEnd []sqltypes.Value, pkInfos []PKInfo) (countSQL string, err error) { - // 1. 生成>=的部分 - greatThanPart, err := GenPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchStart, true) - if err != nil { - return "", err - } - - // 2.生成<=的部分 - lessThanPart, err := GenPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchEnd, false) - if err != nil { - return "", err - } - - // 3.将各部分拼接成最终的countSQL - countSQL = fmt.Sprintf("select count(*) as count_rows from %s.%s where (%s) and ((%s) and (%s))", - tableSchema, tableName, wherePart, greatThanPart, lessThanPart) - - return countSQL, nil -} - -// todo newborn22 batchSQL和batchCountSQL对浮点数进行拦截,可能在获得pkInfo时就进行拦截。 -// 拆分列所支持的类型需要满足以下条件: -// 1.在sql中可以正确地使用between或>=,<=进行比较运算,且没有精度问题。 -// 2.可以转换成go中的int64,float64或string三种类型之一,且转换后,在golang中的比较规则和mysql中的比较规则相同 -func GenCountSQL(tableSchema, tableName, whereExpr string) (countSQL string) { - countSQL = fmt.Sprintf("select count(*) as count_rows from %s.%s where %s)", - tableSchema, tableName, whereExpr) - return countSQL -} diff --git a/go/vt/vttablet/jobcontroller/gen_sql_test.go b/go/vt/vttablet/jobcontroller/gen_sql_test.go deleted file mode 100644 index b06b7c9dff..0000000000 --- a/go/vt/vttablet/jobcontroller/gen_sql_test.go +++ /dev/null @@ -1,155 +0,0 @@ -/* -Copyright ApeCloud, Inc. -Licensed under the Apache v2(found in the LICENSE file in the root directory). -*/ - -package jobcontroller - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/sqlparser" -) - -func TestGenPKsGreaterEqualOrLessEqual(t *testing.T) { - type args struct { - pkInfos []PKInfo - currentBatchStart []sqltypes.Value - greatEqual bool - } - tests := []struct { - name string - args args - want string - }{ - { - name: "Test GenPKsGreaterEqualOrLessEqualStr, Single Int", - args: args{ - pkInfos: []PKInfo{ - {pkName: "a"}, - }, - currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1)}, - greatEqual: true, - }, - want: "(a >= 1)", - }, - { - name: "Test GenPKsGreaterEqualOrLessEqualStr, Two INTs", - args: args{ - pkInfos: []PKInfo{ - {pkName: "a"}, - {pkName: "b"}, - }, - currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, - greatEqual: true, - }, - want: "(a > 1) or (a = 1 and b >= 2)", - }, - { - name: "Test GenPKsGreaterEqualOrLessEqualStr, One INT With One String", - args: args{ - pkInfos: []PKInfo{ - {pkName: "a"}, - {pkName: "b"}, - }, - currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewTimestamp("1704630977")}, - greatEqual: false, - }, - want: "(a < 1) or (a = 1 and b <= '1704630977')", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, _ := GenPKsGreaterEqualOrLessEqualStr(tt.args.pkInfos, tt.args.currentBatchStart, tt.args.greatEqual) - assert.Equalf(t, tt.want, got, "GenPKsGreaterEqualOrLessEqualStr(%v, %v, %v)", tt.args.pkInfos, tt.args.currentBatchStart, tt.args.greatEqual) - }) - } -} - -func TestGenBatchSQL(t *testing.T) { - sql := "update t set c = 1 where 1 = 1 or 2 = 2 and 3 = 3" - stmt, _ := sqlparser.Parse(sql) - whereExpr := stmt.(*sqlparser.Update).Where - currentBatchStart := []sqltypes.Value{sqltypes.NewInt64(1)} - currentBatchEnd := []sqltypes.Value{sqltypes.NewInt64(9)} - pkInfos := []PKInfo{{pkName: "pk1"}} - batchSQL, finalWhereStr, _ := GenBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - expectedBatchSQL := "update t set c = 1 where (1 = 1 or 2 = 2 and 3 = 3) and (pk1 >= 1 and pk1 <= 9)" - expectedWhereStr := "(1 = 1 or 2 = 2 and 3 = 3) and (pk1 >= 1 and pk1 <= 9)" - assert.Equalf(t, expectedBatchSQL, batchSQL, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - assert.Equalf(t, expectedWhereStr, finalWhereStr, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - - sql = "update t set c = 1 where 1 = 1 or 2 = 2 " - stmt, _ = sqlparser.Parse(sql) - whereExpr = stmt.(*sqlparser.Update).Where - currentBatchStart = []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1)} - currentBatchEnd = []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9)} - pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}} - batchSQL, finalWhereStr, _ = GenBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - expectedBatchSQL = "update t set c = 1 where (1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" - expectedWhereStr = "(1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" - assert.Equalf(t, expectedBatchSQL, batchSQL, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - assert.Equalf(t, expectedWhereStr, finalWhereStr, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - - sql = "update t set c = 1 where 1 = 1 or 2 = 2 " - stmt, _ = sqlparser.Parse(sql) - whereExpr = stmt.(*sqlparser.Update).Where - currentBatchStart = []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1), sqltypes.NewInt64(1)} - currentBatchEnd = []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9), sqltypes.NewInt64(9)} - pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}, {pkName: "pk3"}} - batchSQL, finalWhereStr, _ = GenBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - expectedBatchSQL = "update t set c = 1 where (1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 > 1 or pk1 = 1 and pk2 = 1 and pk3 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 < 9 or pk1 = 9 and pk2 = 9 and pk3 <= 9))" - expectedWhereStr = "(1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 > 1 or pk1 = 1 and pk2 = 1 and pk3 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 < 9 or pk1 = 9 and pk2 = 9 and pk3 <= 9))" - assert.Equalf(t, expectedBatchSQL, batchSQL, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) - assert.Equalf(t, expectedWhereStr, finalWhereStr, "GenBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) -} - -//func TestGenCountSQL_old(t *testing.T) { -// type args struct { -// tableSchema string -// tableName string -// wherePart string -// currentBatchStart []sqltypes.Value -// currentBatchEnd []sqltypes.Value -// pkInfos []PKInfo -// } -// tests := []struct { -// name string -// args args -// want string -// }{ -// { -// name: "test_db", -// args: args{ -// tableSchema: "test_db", -// tableName: "test_table", -// wherePart: "id = 1", -// currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1)}, -// currentBatchEnd: []sqltypes.Value{sqltypes.NewInt64(9)}, -// pkInfos: []PKInfo{{pkName: "pk1"}}, -// }, -// want: "select count(*) as count_rows from test_db.test_table where (id = 1) and (((pk1 >= 1)) and ((pk1 <= 9)))", -// }, -// { -// name: "test_db", -// args: args{ -// tableSchema: "test_db", -// tableName: "test_table", -// wherePart: "id = 1", -// currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1)}, -// currentBatchEnd: []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9)}, -// pkInfos: []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}}, -// }, -// want: "select count(*) as count_rows from test_db.test_table where (id = 1) and (((pk1 > 1) or (pk1 = 1 and pk2 >= 1)) and ((pk1 < 9) or (pk1 = 9 and pk2 <= 9)))", -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got, _ := GenCountSQL(tt.args.tableSchema, tt.args.tableName, tt.args.wherePart, tt.args.currentBatchStart, tt.args.currentBatchEnd, tt.args.pkInfos) -// assert.Equalf(t, tt.want, got, "GenCountSQL(%v,%v,%v,%v,%v,%v)", tt.args.tableSchema, tt.args.tableName, tt.args.wherePart, tt.args.currentBatchStart, tt.args.currentBatchEnd, tt.args.pkInfos) -// }) -// } -//} diff --git a/go/vt/vttablet/jobcontroller/sql_related.go b/go/vt/vttablet/jobcontroller/sql_related.go new file mode 100644 index 0000000000..e56926d5a4 --- /dev/null +++ b/go/vt/vttablet/jobcontroller/sql_related.go @@ -0,0 +1,213 @@ +/* +Copyright ApeCloud, Inc. +Licensed under the Apache v2(found in the LICENSE file in the root directory). +*/ + +package jobcontroller + +import ( + "errors" + "fmt" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" +) + +func genPKsGreaterEqualOrLessEqualStr(pkInfos []PKInfo, currentBatchStart []sqltypes.Value, greatEqual bool) (string, error) { + buf := sqlparser.NewTrackedBuffer(nil) + prefix := "" + // This loop handles the case for composite pks. For example, + // if lastpk was (1,2), and the greatEqual is true, then clause would be: + // (col1 > 1) or (col1 = 1 and col2 >= 2). + for curCol := 0; curCol <= len(pkInfos)-1; curCol++ { + buf.Myprintf("%s(", prefix) + prefix = " or " + for i, pk := range currentBatchStart[:curCol] { + buf.Myprintf("%s = ", pkInfos[i].pkName) + pk.EncodeSQL(buf) + buf.Myprintf(" and ") + } + if curCol == len(pkInfos)-1 { + if greatEqual { + buf.Myprintf("%s >= ", pkInfos[curCol].pkName) + } else { + buf.Myprintf("%s <= ", pkInfos[curCol].pkName) + } + } else { + if greatEqual { + buf.Myprintf("%s > ", pkInfos[curCol].pkName) + } else { + buf.Myprintf("%s < ", pkInfos[curCol].pkName) + } + } + currentBatchStart[curCol].EncodeSQL(buf) + buf.Myprintf(")") + } + return buf.String(), nil +} + +func genPKConditionExprByStr(greatThanPart, lessThanPart string) (sqlparser.Expr, error) { + tmpSQL := fmt.Sprintf("select 1 where (%s) AND (%s)", greatThanPart, lessThanPart) + tmpStmt, err := sqlparser.Parse(tmpSQL) + if err != nil { + return nil, err + } + tmpStmtSelect, ok := tmpStmt.(*sqlparser.Select) + if !ok { + return nil, errors.New("genPKConditionExprByStr: tmpStmt is not *sqlparser.Select") + } + return tmpStmtSelect.Where.Expr, nil +} + +func genSQLByReplaceWhereExprNode(stmt sqlparser.Statement, whereExpr sqlparser.Where) string { + switch s := stmt.(type) { + case *sqlparser.Update: + s.Where = &whereExpr + return sqlparser.String(s) + case *sqlparser.Delete: + s.Where = &whereExpr + return sqlparser.String(s) + case *sqlparser.Select: + // 针对batchCountSQL + s.Where = &whereExpr + return sqlparser.String(s) + default: + // the code won't reach here + return "" + } +} + +// todo newborn22 对参数进行调整 +func genBatchSQL(sql string, stmt sqlparser.Statement, whereExpr sqlparser.Expr, currentBatchStart, currentBatchEnd []sqltypes.Value, pkInfos []PKInfo) (batchSQL, finalWhereStr string, err error) { + // 1. 生成>=的部分 + greatThanPart, err := genPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchStart, true) + if err != nil { + return "", "", err + } + + // 2.生成<=的部分 + lessThanPart, err := genPKsGreaterEqualOrLessEqualStr(pkInfos, currentBatchEnd, false) + if err != nil { + return "", "", err + } + + // 3.将pk>= and pk <= 拼接起来并生成相应的condition expr ast node + pkConditionExpr, err := genPKConditionExprByStr(greatThanPart, lessThanPart) + if err != nil { + return "", "", err + } + + // 4.将原本sql stmt中的where expr ast node用AND拼接上pkConditionExpr,作为batchSQL的where expr ast node + // 4.1先生成新的condition ast node + andExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: whereExpr, Right: pkConditionExpr}} + batchSQL = genSQLByReplaceWhereExprNode(stmt, andExpr) + finalWhereStr = sqlparser.String(andExpr.Expr) + + return batchSQL, finalWhereStr, nil +} + +// todo newborn22 batchSQL和batchCountSQL对浮点数进行拦截,可能在获得pkInfo时就进行拦截。 +// 拆分列所支持的类型需要满足以下条件: +// 1.在sql中可以正确地使用between或>=,<=进行比较运算,且没有精度问题。 +// 2.可以转换成go中的int64,float64或string三种类型之一,且转换后,在golang中的比较规则和mysql中的比较规则相同 +func genCountSQL(tableSchema, tableName, whereExpr string) (countSQL string) { + countSQL = fmt.Sprintf("select count(*) as count_rows from %s.%s where %s)", + tableSchema, tableName, whereExpr) + return countSQL +} + +func genBatchStartAndEndStr(currentBatchStart, currentBatchEnd []sqltypes.Value) (currentBatchStartStr string, currentBatchStartEnd string, err error) { + prefix := "" + for i := range currentBatchStart { + prefix = "," + currentBatchStartStr += prefix + currentBatchStart[i].ToString() + currentBatchStartEnd += prefix + currentBatchEnd[i].ToString() + } + return currentBatchStartStr, currentBatchStartEnd, nil +} + +func genExprNodeFromStr(condition string) (sqlparser.Expr, error) { + tmpSQL := fmt.Sprintf("select 1 where %s", condition) + tmpStmt, err := sqlparser.Parse(tmpSQL) + if err != nil { + return nil, err + } + tmpStmtSelect, _ := tmpStmt.(*sqlparser.Select) + return tmpStmtSelect.Where.Expr, nil +} +func getBatchSQLGreatThanAndLessThanExprNode(stmt sqlparser.Statement) (greatThanExpr sqlparser.Expr, lessThanExpr sqlparser.Expr) { + switch s := stmt.(type) { + case *sqlparser.Update: + // the type switch will be ok + andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) + pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr) + greatThanExpr = pkConditionExpr.Left + lessThanExpr = pkConditionExpr.Right + return greatThanExpr, lessThanExpr + case *sqlparser.Delete: + // the type switch will be ok + andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) + pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr) + greatThanExpr = pkConditionExpr.Left + lessThanExpr = pkConditionExpr.Right + return greatThanExpr, lessThanExpr + } + // the code won't reach here + return nil, nil +} + +func getUserWhereExpr(stmt sqlparser.Statement) (expr sqlparser.Expr) { + switch s := stmt.(type) { + case *sqlparser.Update: + tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) + expr = tempAndExpr.Left + return expr + case *sqlparser.Delete: + tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr) + expr = tempAndExpr.Left + return expr + default: + // the code won't reach here + return nil + } +} + +func genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLStmt sqlparser.Statement, curBatchNewEnd, newBatchStart []sqltypes.Value, pkInfos []PKInfo) (curBatchSQL, newBatchSQL, newBatchCountSQL string, err error) { + // 1) 将curBatchNewEnd和newBatchStart转换成<=和>=的字符串,然后将字符串转成expr ast node + curBatchLessThanPart, err := genPKsGreaterEqualOrLessEqualStr(pkInfos, curBatchNewEnd, false) + if err != nil { + return "", "", "", err + } + curBatchLessThanExpr, err := genExprNodeFromStr(curBatchLessThanPart) + if err != nil { + return "", "", "", err + } + + newBatchGreatThanPart, err := genPKsGreaterEqualOrLessEqualStr(pkInfos, newBatchStart, true) + if err != nil { + return "", "", "", err + } + newBatchGreatThanExpr, err := genExprNodeFromStr(newBatchGreatThanPart) + if err != nil { + return "", "", "", err + } + + // 2) 通过parser,获得原先batchSQL的greatThan和lessThan的expr ast node + curBatchGreatThanExpr, newBatchLessThanExpr := getBatchSQLGreatThanAndLessThanExprNode(batchSQLStmt) + + // 3) 生成拆batchSQL和batchCountSQL + // 3.1) 先构建curBatchSQL和newBatchSQL的where expr ast node:将用户输入的where expr与上PK Condition Expr + userWhereExpr := getUserWhereExpr(batchSQLStmt) + curBatchPKConditionExpr := sqlparser.AndExpr{Left: curBatchGreatThanExpr, Right: curBatchLessThanExpr} + newBatchPKConditionExpr := sqlparser.AndExpr{Left: newBatchGreatThanExpr, Right: newBatchLessThanExpr} + curBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: userWhereExpr, Right: &curBatchPKConditionExpr}} + newBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: userWhereExpr, Right: &newBatchPKConditionExpr}} + + // 3.2) 替换原先batchSQL的where expr来生成batchSQL + curBatchSQL = genSQLByReplaceWhereExprNode(batchSQLStmt, curBatchWhereExpr) + newBatchSQL = genSQLByReplaceWhereExprNode(batchSQLStmt, newBatchWhereExpr) + + // 3.3) 同理生成batchCountSQL + newBatchCountSQL = genSQLByReplaceWhereExprNode(batchCountSQLStmt, newBatchWhereExpr) + return curBatchSQL, newBatchSQL, newBatchCountSQL, nil +} diff --git a/go/vt/vttablet/jobcontroller/sql_related_test.go b/go/vt/vttablet/jobcontroller/sql_related_test.go new file mode 100644 index 0000000000..5b80c604b7 --- /dev/null +++ b/go/vt/vttablet/jobcontroller/sql_related_test.go @@ -0,0 +1,217 @@ +/* +Copyright ApeCloud, Inc. +Licensed under the Apache v2(found in the LICENSE file in the root directory). +*/ + +package jobcontroller + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" +) + +func TestGenPKsGreaterEqualOrLessEqual(t *testing.T) { + type args struct { + pkInfos []PKInfo + currentBatchStart []sqltypes.Value + greatEqual bool + } + tests := []struct { + name string + args args + want string + }{ + { + name: "Test genPKsGreaterEqualOrLessEqualStr, Single Int", + args: args{ + pkInfos: []PKInfo{ + {pkName: "a"}, + }, + currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1)}, + greatEqual: true, + }, + want: "(a >= 1)", + }, + { + name: "Test genPKsGreaterEqualOrLessEqualStr, Two INTs", + args: args{ + pkInfos: []PKInfo{ + {pkName: "a"}, + {pkName: "b"}, + }, + currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, + greatEqual: true, + }, + want: "(a > 1) or (a = 1 and b >= 2)", + }, + { + name: "Test genPKsGreaterEqualOrLessEqualStr, One INT With One String", + args: args{ + pkInfos: []PKInfo{ + {pkName: "a"}, + {pkName: "b"}, + }, + currentBatchStart: []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewTimestamp("1704630977")}, + greatEqual: false, + }, + want: "(a < 1) or (a = 1 and b <= '1704630977')", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := genPKsGreaterEqualOrLessEqualStr(tt.args.pkInfos, tt.args.currentBatchStart, tt.args.greatEqual) + assert.Equalf(t, tt.want, got, "genPKsGreaterEqualOrLessEqualStr(%v, %v, %v)", tt.args.pkInfos, tt.args.currentBatchStart, tt.args.greatEqual) + }) + } +} + +func TestGenBatchSQL(t *testing.T) { + sql := "update t set c = 1 where 1 = 1 or 2 = 2 and 3 = 3" + stmt, _ := sqlparser.Parse(sql) + whereExpr := stmt.(*sqlparser.Update).Where + currentBatchStart := []sqltypes.Value{sqltypes.NewInt64(1)} + currentBatchEnd := []sqltypes.Value{sqltypes.NewInt64(9)} + pkInfos := []PKInfo{{pkName: "pk1"}} + batchSQL, finalWhereStr, _ := genBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + expectedBatchSQL := "update t set c = 1 where (1 = 1 or 2 = 2 and 3 = 3) and (pk1 >= 1 and pk1 <= 9)" + expectedWhereStr := "(1 = 1 or 2 = 2 and 3 = 3) and (pk1 >= 1 and pk1 <= 9)" + assert.Equalf(t, expectedBatchSQL, batchSQL, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + assert.Equalf(t, expectedWhereStr, finalWhereStr, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + + sql = "update t set c = 1 where 1 = 1 or 2 = 2 " + stmt, _ = sqlparser.Parse(sql) + whereExpr = stmt.(*sqlparser.Update).Where + currentBatchStart = []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1)} + currentBatchEnd = []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9)} + pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}} + batchSQL, finalWhereStr, _ = genBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + expectedBatchSQL = "update t set c = 1 where (1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + expectedWhereStr = "(1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + assert.Equalf(t, expectedBatchSQL, batchSQL, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + assert.Equalf(t, expectedWhereStr, finalWhereStr, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + + sql = "update t set c = 1 where 1 = 1 or 2 = 2 " + stmt, _ = sqlparser.Parse(sql) + whereExpr = stmt.(*sqlparser.Update).Where + currentBatchStart = []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1), sqltypes.NewInt64(1)} + currentBatchEnd = []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9), sqltypes.NewInt64(9)} + pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}, {pkName: "pk3"}} + batchSQL, finalWhereStr, _ = genBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + expectedBatchSQL = "update t set c = 1 where (1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 > 1 or pk1 = 1 and pk2 = 1 and pk3 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 < 9 or pk1 = 9 and pk2 = 9 and pk3 <= 9))" + expectedWhereStr = "(1 = 1 or 2 = 2) and ((pk1 > 1 or pk1 = 1 and pk2 > 1 or pk1 = 1 and pk2 = 1 and pk3 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 < 9 or pk1 = 9 and pk2 = 9 and pk3 <= 9))" + assert.Equalf(t, expectedBatchSQL, batchSQL, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + assert.Equalf(t, expectedWhereStr, finalWhereStr, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) +} + +func TestGenNewBatchSQLsAndCountSQLsWhenSplittingBatch(t *testing.T) { + pkInfos := []PKInfo{{pkName: "pk1"}} + whereStr := "where (1 = 1 and 2 = 2) and ((pk1 >= 1) and (pk1 <= 9))" + batchSQL := fmt.Sprintf("update t set c1 = '123' %s", whereStr) + batchCountSQL := fmt.Sprintf("select count(*) from t %s", whereStr) + batchSQLStmt, _ := sqlparser.Parse(batchSQL) + batchCountSQLStmt, _ := sqlparser.Parse(batchCountSQL) + curBatchNewEnd := []sqltypes.Value{sqltypes.NewInt64(5)} + newBatchStart := []sqltypes.Value{sqltypes.NewInt64(7)} + expectedCurBatchSQL := "update t set c1 = '123' where 1 = 1 and 2 = 2 and (pk1 >= 1 and pk1 <= 5)" + expectedNewBatchSQL := "update t set c1 = '123' where 1 = 1 and 2 = 2 and (pk1 >= 7 and pk1 <= 9)" + expectedNewBatchCountSQL := "select count(*) from t where 1 = 1 and 2 = 2 and (pk1 >= 7 and pk1 <= 9)" + curBatchSQL, newBatchSQL, newBatchCountSQL, _ := genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedCurBatchSQL, curBatchSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedNewBatchSQL, newBatchSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedNewBatchCountSQL, newBatchCountSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + + pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}} + whereStr = "where (1 = 1) and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + batchSQL = fmt.Sprintf("update t set c1 = '123' %s", whereStr) + batchCountSQL = fmt.Sprintf("select count(*) from t %s", whereStr) + batchSQLStmt, _ = sqlparser.Parse(batchSQL) + batchCountSQLStmt, _ = sqlparser.Parse(batchCountSQL) + curBatchNewEnd = []sqltypes.Value{sqltypes.NewInt64(5), sqltypes.NewInt64(5)} + newBatchStart = []sqltypes.Value{sqltypes.NewInt64(7), sqltypes.NewInt64(7)} + expectedCurBatchSQL = "update t set c1 = '123' where 1 = 1 and ((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 5 or pk1 = 5 and pk2 <= 5))" + expectedNewBatchSQL = "update t set c1 = '123' where 1 = 1 and ((pk1 > 7 or pk1 = 7 and pk2 >= 7) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + expectedNewBatchCountSQL = "select count(*) from t where 1 = 1 and ((pk1 > 7 or pk1 = 7 and pk2 >= 7) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + curBatchSQL, newBatchSQL, newBatchCountSQL, _ = genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedCurBatchSQL, curBatchSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedNewBatchSQL, newBatchSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + assert.Equalf(t, expectedNewBatchCountSQL, newBatchCountSQL, "genNewBatchSQLsAndCountSQLsWhenSplittingBatch(%v,%v,%v,%v,%v)", batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos) + +} + +func TestGetUserWhereExpr(t *testing.T) { + sql := "update t set c1 = '123' where 1 = 1 and 2 = 2 and 3 = 3 and (pk1 >= 1 and pk1 <= 9)" + stmt, _ := sqlparser.Parse(sql) + userWhereExpr := getUserWhereExpr(stmt) + userWhereStr := sqlparser.String(userWhereExpr) + expectedUserWhereStr := "1 = 1 and 2 = 2 and 3 = 3" + assert.Equalf(t, expectedUserWhereStr, userWhereStr, "getUserWhereExpr(%v)", stmt) +} + +func TestBatchSQL(t *testing.T) { + userWhereStr := "1 = 1 and 2 = 2" + pkConditionStr := "(pk1 >= 1 and pk1 <= 9)" + // 1. get batchSQL by calling genBatchSQL + sql := fmt.Sprintf("update t set c = 1 where %s", userWhereStr) + stmt, _ := sqlparser.Parse(sql) + whereExpr := stmt.(*sqlparser.Update).Where + currentBatchStart := []sqltypes.Value{sqltypes.NewInt64(1)} + currentBatchEnd := []sqltypes.Value{sqltypes.NewInt64(9)} + pkInfos := []PKInfo{{pkName: "pk1"}} + batchSQL, finalWhereStr, _ := genBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + // Different from the 1st test case, here we enclose userWhereStr with parentheses because it has "or" operators. + expectedBatchSQL := fmt.Sprintf("update t set c = 1 where (%s) and %s", userWhereStr, pkConditionStr) + expectedWhereStr := fmt.Sprintf("(%s) and %s", userWhereStr, pkConditionStr) + assert.Equalf(t, expectedBatchSQL, batchSQL, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + assert.Equalf(t, expectedWhereStr, finalWhereStr, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + + batchSQLStmt, _ := sqlparser.Parse(batchSQL) + // 2. test getUserWhereExpr base on batchSQ + gotUserWhereExpr := getUserWhereExpr(batchSQLStmt) + gotUserWhereStr := sqlparser.String(gotUserWhereExpr) + expectedUserWhereStr := userWhereStr + assert.Equalf(t, expectedUserWhereStr, gotUserWhereStr, "getUserWhereExpr(%v)", stmt) + + // 3. test getBatchSQLGreatThanAndLessThanExprNode base on batchSQL + gtNode, lsNode := getBatchSQLGreatThanAndLessThanExprNode(batchSQLStmt) + gtStr := sqlparser.String(gtNode) + lsStr := sqlparser.String(lsNode) + expectedGtStr := "pk1 >= 1" + expectedLsStr := "pk1 <= 9" + assert.Equalf(t, expectedGtStr, gtStr, "getBatchSQLGreatThanAndLessThanExprNode(%v)", batchSQLStmt) + assert.Equalf(t, expectedLsStr, lsStr, "getBatchSQLGreatThanAndLessThanExprNode(%v)", batchSQLStmt) + + // repeat the same steps but with different args + userWhereStr = "1 = 1 and 2 = 2 or 3 > 2 and 1 > 3 or 3 > 4 and 1 = 1" + pkConditionStr = "((pk1 > 1 or pk1 = 1 and pk2 >= 1) and (pk1 < 9 or pk1 = 9 and pk2 <= 9))" + sql = fmt.Sprintf("update t set c = 1 where %s", userWhereStr) + stmt, _ = sqlparser.Parse(sql) + whereExpr = stmt.(*sqlparser.Update).Where + currentBatchStart = []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(1)} + currentBatchEnd = []sqltypes.Value{sqltypes.NewInt64(9), sqltypes.NewInt64(9)} + pkInfos = []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}} + batchSQL, finalWhereStr, _ = genBatchSQL(sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + expectedBatchSQL = fmt.Sprintf("update t set c = 1 where %s and %s", userWhereStr, pkConditionStr) + expectedWhereStr = fmt.Sprintf("%s and %s", userWhereStr, pkConditionStr) + assert.Equalf(t, expectedBatchSQL, batchSQL, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + assert.Equalf(t, expectedWhereStr, finalWhereStr, "genBatchSQL(%v, %v, %v,%v,%v, %v)", sql, stmt, whereExpr.Expr, currentBatchStart, currentBatchEnd, pkInfos) + batchSQLStmt, _ = sqlparser.Parse(batchSQL) + gotUserWhereExpr = getUserWhereExpr(batchSQLStmt) + gotUserWhereStr = sqlparser.String(gotUserWhereExpr) + expectedUserWhereStr = userWhereStr + assert.Equalf(t, expectedUserWhereStr, gotUserWhereStr, "getUserWhereExpr(%v)", stmt) + + // 3. test getBatchSQLGreatThanAndLessThanExprNode base on batchSQL + gtNode, lsNode = getBatchSQLGreatThanAndLessThanExprNode(batchSQLStmt) + gtStr = sqlparser.String(gtNode) + lsStr = sqlparser.String(lsNode) + expectedGtStr = "pk1 > 1 or pk1 = 1 and pk2 >= 1" + expectedLsStr = "pk1 < 9 or pk1 = 9 and pk2 <= 9" + assert.Equalf(t, expectedGtStr, gtStr, "getBatchSQLGreatThanAndLessThanExprNode(%v)", batchSQLStmt) + assert.Equalf(t, expectedLsStr, lsStr, "getBatchSQLGreatThanAndLessThanExprNode(%v)", batchSQLStmt) + +}