Skip to content

Commit

Permalink
fix: refactor splitBatchIntoTwo, add related testcases
Browse files Browse the repository at this point in the history
Signed-off-by: newborn22 <[email protected]>
  • Loading branch information
newborn22 committed Jan 10, 2024
1 parent 1d0845f commit ac7255a
Show file tree
Hide file tree
Showing 5 changed files with 459 additions and 552 deletions.
270 changes: 29 additions & 241 deletions go/vt/vttablet/jobcontroller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type JobController struct {
schedulerNotifyChan chan struct{} // jobScheduler每隔一段时间运行一次调度。但当它收到这个chan的消息后,会立刻开始一次调度
}

// todo newborn22 删除pktype?
type PKInfo struct {
pkName string
pkType querypb.Type
Expand Down Expand Up @@ -792,6 +793,15 @@ func (jc *JobController) execBatchAndRecord(ctx context.Context, tableSchema, ta
// 拆分的基本原理是遍历原先batch的batchCountSQL的结果集,将第batchSize条record的pk作为原先batch的PKEnd,第batchSize+1条record的pk作为新batch的PKStart
// 原先batch的PKStart和PKEnd分别成为原先batch的PKStart和新batch的PKEnd
func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, table, batchTable, batchSQL, batchCountSQL, batchID string, conn *connpool.DBConn, batchSize, expectedRow int64) (newCurrentBatchSQL string, err error) {
batchSQLStmt, err := sqlparser.Parse(batchSQL)
if err != nil {
return "", err
}
batchCountSQLStmt, err := sqlparser.Parse(batchCountSQL)
if err != nil {
return "", err
}

// 1.根据batchCountSQL生成查询pk值的select sql
// 1.1.获得PK信息
pkInfos, err := jc.getTablePkInfo(ctx, tableSchema, table)
Expand All @@ -801,11 +811,8 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab

// 1.2.根据当前batch的batchCountSQL生成select sql,用于获得拆分后batch的拆分列start和end
// 只需要将batchCountSQL的投影部分(SelectExprs)从count(*)改为拆分列即可
batchCountSQLStmt, err := sqlparser.Parse(batchCountSQL)
if err != nil {
return "", err
}
batchCountSQLStmtSelect, _ := batchCountSQLStmt.(*sqlparser.Select)
// 根据pk信息生成select exprs
var pkExprs []sqlparser.SelectExpr
for _, pkInfo := range pkInfos {
pkExprs = append(pkExprs, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName(pkInfo.pkName)})
Expand All @@ -821,85 +828,33 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab
// 拆成多个batch需要遍历完select的全部结果,这可能会导致超时

// 2.1.计算两个batch的batchPKStart和batchPKEnd。实际上,只要获得当前batch的新的PKEnd和新的batch的PKStart

// 遍历前threshold+1条,依然使用同一个连接
qr, err := conn.Exec(ctx, batchSplitSelectSQL, math.MaxInt32, true)
if err != nil {
return "", err
}

// todo literal
var curBatchNewEnd []any
var newBatchStart []any
var curBatchNewEnd []sqltypes.Value
var newBatchStart []sqltypes.Value

for rowCount, row := range qr.Named().Rows {
for rowCount, row := range qr.Rows {
// 将原本batch的PKEnd设在threshold条数处
if int64(rowCount) == batchSize-1 {
for _, pkInfo := range pkInfos {
pkName := pkInfo.pkName
keyVal, err := ProcessValue(row[pkName])
if err != nil {
return "", err
}
curBatchNewEnd = append(curBatchNewEnd, keyVal)
}
curBatchNewEnd = row
}
// 将第threshold+1条的PK作为新PK的起点
if int64(rowCount) == batchSize {
for _, pkInfo := range pkInfos {
pkName := pkInfo.pkName
keyVal, err := ProcessValue(row[pkName])
if err != nil {
return "", err
}
newBatchStart = append(newBatchStart, keyVal)
}
newBatchStart = row
break
}
}

// 2.2) 将curBatchNewEnd和newBatchStart转换成sql中where部分的<=和>=的字符串
// todo 直接从ast开始构建
curBatchLessThanPart, err := genPKsLessThanPart(pkInfos, curBatchNewEnd)
if err != nil {
return "", err
}
curBatchLessThanExpr, err := genExprNodeFromStr(curBatchLessThanPart)
if err != nil {
return "", err
}

newBatchGreatThanPart, err := genPKsGreaterThanPart(pkInfos, newBatchStart)
if err != nil {
return "", err
}
newBatchGreatThanExpr, err := genExprNodeFromStr(newBatchGreatThanPart)
// 2.2.生成新的batchSQL和新的batchCountSQL
curBatchSQL, newBatchSQL, newBatchCountSQL, err := genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLStmt, curBatchNewEnd, newBatchStart, pkInfos)
if err != nil {
return "", err
}

// 2.3) 通过parser,获得原先batchSQL的greatThan和lessThan的expr ast node
batchSQLStmt, err := sqlparser.Parse(batchSQL)
if err != nil {
return "", err
}
curBatchGreatThanExpr, newBatchLessThanExpr := jc.getBatchSQLGreatThanAndLessThanExprNode(batchSQLStmt)

// 2.4) 生成拆分后,当前batch的sql和新batch的sql
// 2.4.1) 先获得curBatchSQL和newBatchSQL的where expr ast node,需要将用户输入的where expr与上PK Condition Expr
oldBatchSQLUserWhereExpr := getUserWhereExpr(batchSQLStmt)
curBatchPKConditionExpr := sqlparser.AndExpr{Left: curBatchGreatThanExpr, Right: curBatchLessThanExpr}
newBatchPKConditionExpr := sqlparser.AndExpr{Left: newBatchGreatThanExpr, Right: newBatchLessThanExpr}
curBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: oldBatchSQLUserWhereExpr, Right: &curBatchPKConditionExpr}}
newBatchWhereExpr := sqlparser.Where{Expr: &sqlparser.AndExpr{Left: oldBatchSQLUserWhereExpr, Right: &newBatchPKConditionExpr}}

// 2.4.2) 替换原先batchSQL和batchCountSQL的where expr来生成新的sql
curBatchSQL := GenSQLByReplaceWhereExprNode(batchSQLStmt, curBatchWhereExpr)
newBatchSQL := GenSQLByReplaceWhereExprNode(batchSQLStmt, newBatchWhereExpr)

// 2.4.3) 生成新batch的batchCountSQL,原理同上
newBatchCountSQL := GenSQLByReplaceWhereExprNode(batchCountSQLStmt, newBatchWhereExpr)

// 构建当前batch新的batch begin及end字段以及新batch的begin及end字段
// 2.3.计算两个batch的batch start和end字段
getBatchBeginAndEndSQL := fmt.Sprintf(sqlTemplateGetBatchBeginAndEnd, batchTable)
getBatchBeginAndEndQuery, err := sqlparser.ParseAndBind(getBatchBeginAndEndSQL, sqltypes.StringBindVariable(batchID))
if err != nil {
Expand All @@ -914,14 +869,13 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab
}
currentBatchNewBeginStr := qr.Named().Rows[0]["batch_begin"].ToString()
newBatchEndStr := qr.Named().Rows[0]["batch_end"].ToString()
// todo newborn22 next
currentBatchNewEndStr, newBatchBegintStr, err := genBatchStartAndEndStr(nil, nil)
currentBatchNewEndStr, newBatchBeginStr, err := genBatchStartAndEndStr(curBatchNewEnd, newBatchStart)
if err != nil {
return "", err
}

// 2.5) 在batch表中更改旧的条目的sql,并插入新batch条目
// 在表中更改旧的sql
// 3 将结果记录在表中:在batch表中更改旧的条目的sql,并插入新batch条目
// 3.1.在表中更改旧的sql
updateBatchSQL := fmt.Sprintf(sqlTemplateUpdateBatchSQL, batchTable)
updateBatchSQLQuery, err := sqlparser.ParseAndBind(updateBatchSQL,
sqltypes.StringBindVariable(curBatchSQL),
Expand All @@ -935,7 +889,7 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab
if err != nil {
return "", err
}
// 插入新batch条目
// 3.2.插入新batch条目
newCurrentBatchSQL = curBatchSQL
// todo 1-1 -> 1-2开始
nextBatchID, err := genNewBatchID(batchID)
Expand All @@ -949,7 +903,7 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab
sqltypes.StringBindVariable(newBatchSQL),
sqltypes.StringBindVariable(newBatchCountSQL),
sqltypes.Int64BindVariable(newBatchSize),
sqltypes.StringBindVariable(newBatchBegintStr),
sqltypes.StringBindVariable(newBatchBeginStr),
sqltypes.StringBindVariable(newBatchEndStr))
if err != nil {
return "", err
Expand Down Expand Up @@ -1538,11 +1492,11 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql,
currentBatchSize++

if currentBatchSize == batchSize {
batchSQL, finalWhereStr, err := GenBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos)
batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos)
if err != nil {
return "", err
}
countSQL := GenCountSQL(tableSchema, tableName, finalWhereStr)
countSQL := genCountSQL(tableSchema, tableName, finalWhereStr)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -1573,11 +1527,11 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql,
}
// 最后一个batch的行数不一定是batchSize,在循环结束时要将剩余的行数划分到最后一个batch中
if currentBatchSize != 0 {
batchSQL, finalWhereStr, err := GenBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos)
batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos)
if err != nil {
return "", err
}
countSQL := GenCountSQL(tableSchema, tableName, finalWhereStr)
countSQL := genCountSQL(tableSchema, tableName, finalWhereStr)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -1615,125 +1569,6 @@ func currentBatchIDInc(currentBatchID string) (string, error) {
return strconv.FormatInt(currentBatchIDInt64, 10), nil
}

func genBatchStartAndEndStr(currentBatchStart, currentBatchEnd []sqltypes.Value) (currentBatchStartStr string, currentBatchStartEnd string, err error) {
prefix := ""
for i := range currentBatchStart {
prefix = ","
currentBatchStartStr += prefix + currentBatchStart[i].ToString()
currentBatchStartEnd += prefix + currentBatchEnd[i].ToString()
}
return currentBatchStartStr, currentBatchStartEnd, nil
}

func genPlaceholderByType(typ querypb.Type) (string, error) {
switch typ {
case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_INT64:
return "%d", nil
case querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32, querypb.Type_UINT64:
return "%d", nil
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
return "%f", nil
// todo string
case querypb.Type_TIMESTAMP, querypb.Type_DATE, querypb.Type_TIME, querypb.Type_DATETIME, querypb.Type_YEAR,
querypb.Type_TEXT, querypb.Type_VARCHAR, querypb.Type_CHAR:
return "'%s'", nil
default:
return "", fmt.Errorf("Unsupported type: %v", typ)
}
}

func ProcessValue(value sqltypes.Value) (any, error) {
typ := value.Type()

switch typ {
case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_INT64:
return value.ToInt64()
case querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32, querypb.Type_UINT64:
return value.ToUint64()
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
return value.ToFloat64()
case querypb.Type_TIMESTAMP, querypb.Type_DATE, querypb.Type_TIME, querypb.Type_DATETIME, querypb.Type_YEAR,
querypb.Type_TEXT, querypb.Type_VARCHAR, querypb.Type_CHAR:
return value.ToString(), nil
default:
return nil, fmt.Errorf("Unsupported type: %v", typ)
}
}

func genPKsGreaterThanPart(pkInfos []PKInfo, currentBatchStart []any) (string, error) {
curIdx := 0
pksNum := len(pkInfos)
var equalStr, rst string
for curIdx < pksNum {
curPkName := pkInfos[curIdx].pkName
curPKType := pkInfos[curIdx].pkType
// mysql的浮点类型在比较时有精度损失,不适合作为拆分列
if curPKType == querypb.Type_FLOAT32 || curPKType == querypb.Type_FLOAT64 {
return "", fmt.Errorf("unsupported type: %v", curPKType)
}

placeholder, err := genPlaceholderByType(curPKType)
if err != nil {
return "", err
}

if curIdx == 0 {
rst = fmt.Sprintf("( %s > %s )", curPkName, placeholder)
} else if curIdx != (pksNum - 1) {
rst += fmt.Sprintf(" OR ( %s AND %s > %s )", equalStr, curPkName, placeholder)
} else if curIdx == (pksNum - 1) {
rst += fmt.Sprintf(" OR ( %s AND %s >= %s )", equalStr, curPkName, placeholder)
}
rst = fmt.Sprintf(rst, currentBatchStart[curIdx])

if curIdx == 0 {
equalStr = fmt.Sprintf("%s = %s", curPkName, placeholder)
} else {
equalStr += fmt.Sprintf(" AND %s = %s", curPkName, placeholder)
}
equalStr = fmt.Sprintf(equalStr, currentBatchStart[curIdx])
curIdx++
}
return rst, nil
}

func genPKsLessThanPart(pkInfos []PKInfo, currentBatchEnd []any) (string, error) {
curIdx := 0
pksNum := len(pkInfos)
var equalStr, rst string
for curIdx < pksNum {
curPkName := pkInfos[curIdx].pkName
curPKType := pkInfos[curIdx].pkType
// mysql的浮点类型在比较时有精度损失,不适合作为拆分列
if curPKType == querypb.Type_FLOAT32 || curPKType == querypb.Type_FLOAT64 {
return "", fmt.Errorf("unsupported type: %v", curPKType)
}

placeholder, err := genPlaceholderByType(curPKType)
if err != nil {
return "", err
}

if curIdx == 0 {
rst = fmt.Sprintf("( %s < %s )", curPkName, placeholder)
} else if curIdx != (pksNum - 1) {
rst += fmt.Sprintf(" OR ( %s AND %s < %s )", equalStr, curPkName, placeholder)
} else if curIdx == (pksNum - 1) {
rst += fmt.Sprintf(" OR ( %s AND %s <= %s )", equalStr, curPkName, placeholder)
}
rst = fmt.Sprintf(rst, currentBatchEnd[curIdx])

if curIdx == 0 {
equalStr = fmt.Sprintf("%s = %s", curPkName, placeholder)
} else {
equalStr += fmt.Sprintf(" AND %s = %s", curPkName, placeholder)
}
equalStr = fmt.Sprintf(equalStr, currentBatchEnd[curIdx])
curIdx++
}
return rst, nil
}

// 通知jobScheduler让它立刻开始一次调度。
func (jc *JobController) notifyJobScheduler() {
if jc.schedulerNotifyChan == nil {
Expand Down Expand Up @@ -1797,50 +1632,3 @@ func (jc *JobController) updateBatchStatus(batchTableSchema, batchTableName, sta
_, err = jc.execQuery(context.Background(), batchTableSchema, query)
return err
}

func getUserWhereExpr(stmt sqlparser.Statement) (expr sqlparser.Expr) {
switch s := stmt.(type) {
case *sqlparser.Update:
tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr)
expr = tempAndExpr.Left
return expr
case *sqlparser.Delete:
tempAndExpr, _ := s.Where.Expr.(*sqlparser.AndExpr)
expr = tempAndExpr.Left
return expr
default:
// the code won't reach here
return nil
}
}

func (jc *JobController) getBatchSQLGreatThanAndLessThanExprNode(stmt sqlparser.Statement) (greatThanExpr sqlparser.Expr, lessThanExpr sqlparser.Expr) {
switch s := stmt.(type) {
case *sqlparser.Update:
// the type switch will be ok
andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr)
pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr)
greatThanExpr = pkConditionExpr.Left
lessThanExpr = pkConditionExpr.Right
return greatThanExpr, lessThanExpr
case *sqlparser.Delete:
// the type switch will be ok
andExpr, _ := s.Where.Expr.(*sqlparser.AndExpr)
pkConditionExpr, _ := andExpr.Right.(*sqlparser.AndExpr)
greatThanExpr = pkConditionExpr.Left
lessThanExpr = pkConditionExpr.Right
return greatThanExpr, lessThanExpr
}
// the code won't reach here
return nil, nil
}

func genExprNodeFromStr(condition string) (sqlparser.Expr, error) {
tmpSQL := fmt.Sprintf("select 1 where %s", condition)
tmpStmt, err := sqlparser.Parse(tmpSQL)
if err != nil {
return nil, err
}
tmpStmtSelect, _ := tmpStmt.(*sqlparser.Select)
return tmpStmtSelect.Where.Expr, nil
}
Loading

0 comments on commit ac7255a

Please sign in to comment.