From 83e7beba6e5fa5f5a6538ef1c93fb90d858bc88b Mon Sep 17 00:00:00 2001 From: newborn22 <953950914@qq.com> Date: Wed, 10 Jan 2024 13:18:41 +0800 Subject: [PATCH] fix: split large functions into smaller ones Signed-off-by: newborn22 <953950914@qq.com> --- .../jobcontroller/batch_sql_related.go | 79 +++ .../jobcontroller/batch_sql_related_test.go | 11 + go/vt/vttablet/jobcontroller/controller.go | 487 +++++------------- go/vt/vttablet/jobcontroller/sqls.go | 6 +- go/vt/vttablet/jobcontroller/util.go | 153 ++++++ go/vt/vttablet/jobcontroller/util_test.go | 58 +++ 6 files changed, 439 insertions(+), 355 deletions(-) create mode 100644 go/vt/vttablet/jobcontroller/util_test.go diff --git a/go/vt/vttablet/jobcontroller/batch_sql_related.go b/go/vt/vttablet/jobcontroller/batch_sql_related.go index c39847a254..75fa740f46 100644 --- a/go/vt/vttablet/jobcontroller/batch_sql_related.go +++ b/go/vt/vttablet/jobcontroller/batch_sql_related.go @@ -6,8 +6,12 @@ Licensed under the Apache v2(found in the LICENSE file in the root directory). package jobcontroller import ( + "context" "errors" "fmt" + "math" + + "vitess.io/vitess/go/vt/vttablet/tabletserver/connpool" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -211,3 +215,78 @@ func genNewBatchSQLsAndCountSQLsWhenSplittingBatch(batchSQLStmt, batchCountSQLSt newBatchCountSQL = genSQLByReplaceWhereExprNode(batchCountSQLStmt, newBatchWhereExpr) return curBatchSQL, newBatchSQL, newBatchCountSQL, nil } + +// replace selectExprs in batchCountSQLStmt with PK cols to generate selectPKsSQL +// the function will not change the original batchCountSQLStmt +func genSelectPKsSQL(batchCountSQLStmt sqlparser.Statement, pkInfos []PKInfo) string { + batchCountSQLStmtSelect, _ := batchCountSQLStmt.(*sqlparser.Select) + // 根据pk信息生成select exprs + var pkExprs []sqlparser.SelectExpr + for _, pkInfo := range pkInfos { + pkExprs = append(pkExprs, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName(pkInfo.pkName)}) + } + oldBatchCountSQLStmtSelectExprs := batchCountSQLStmtSelect.SelectExprs + batchCountSQLStmtSelect.SelectExprs = pkExprs + batchSplitSelectSQL := sqlparser.String(batchCountSQLStmtSelect) + // undo the change of select exprs + batchCountSQLStmtSelect.SelectExprs = oldBatchCountSQLStmtSelectExprs + return batchSplitSelectSQL +} + +// get the begin and end fields of the batches newly created during splitting +func getNewBatchesBeginAndEndStr(ctx context.Context, conn *connpool.DBConn, batchTable, batchID string, curBatchNewEnd, newBatchStart []sqltypes.Value) (currentBatchNewBeginStr, currentBatchNewEndStr, newBatchBeginStr, newBatchEndStr string, err error) { + getBatchBeginAndEndSQL := fmt.Sprintf(sqlTemplateGetBatchBeginAndEnd, batchTable) + getBatchBeginAndEndQuery, err := sqlparser.ParseAndBind(getBatchBeginAndEndSQL, sqltypes.StringBindVariable(batchID)) + if err != nil { + return "", "", "", "", err + } + qr, err := conn.Exec(ctx, getBatchBeginAndEndQuery, math.MaxInt32, true) + if err != nil { + return "", "", "", "", err + } + if len(qr.Named().Rows) != 1 { + return "", "", "", "", errors.New("can not get batch begin and end") + } + currentBatchNewBeginStr = qr.Named().Rows[0]["batch_begin"].ToString() + newBatchEndStr = qr.Named().Rows[0]["batch_end"].ToString() + currentBatchNewEndStr, newBatchBeginStr, err = genBatchStartAndEndStr(curBatchNewEnd, newBatchStart) + if err != nil { + return "", "", "", "", err + } + return currentBatchNewBeginStr, currentBatchNewEndStr, newBatchBeginStr, newBatchEndStr, nil +} +func updateBatchInfoTableEntry(ctx context.Context, conn *connpool.DBConn, batchTable string, curBatchSQL, currentBatchNewBeginStr, currentBatchNewEndStr, batchID string) (err error) { + sqlUpdateBatchInfoTableEntry := fmt.Sprintf(sqlTemplateUpdateBatchSQL, batchTable) + queryUpdateBatchInfoTableEntry, err := sqlparser.ParseAndBind(sqlUpdateBatchInfoTableEntry, + sqltypes.StringBindVariable(curBatchSQL), + sqltypes.StringBindVariable(currentBatchNewBeginStr), + sqltypes.StringBindVariable(currentBatchNewEndStr), + sqltypes.StringBindVariable(batchID)) + if err != nil { + return err + } + _, err = conn.Exec(ctx, queryUpdateBatchInfoTableEntry, math.MaxInt32, false) + if err != nil { + return err + } + return nil +} + +func insertBatchInfoTableEntry(ctx context.Context, conn *connpool.DBConn, batchTable, nextBatchID, newBatchSQL, newBatchCountSQL, newBatchBeginStr, newBatchEndStr string, newBatchSize int64) (err error) { + sqlInsertBatchInfoTableEntry := fmt.Sprintf(sqlTemplateInsertBatchEntry, batchTable) + queryInsertBatchInfoTableEntry, err := sqlparser.ParseAndBind(sqlInsertBatchInfoTableEntry, + sqltypes.StringBindVariable(nextBatchID), + sqltypes.StringBindVariable(newBatchSQL), + sqltypes.StringBindVariable(newBatchCountSQL), + sqltypes.Int64BindVariable(newBatchSize), + sqltypes.StringBindVariable(newBatchBeginStr), + sqltypes.StringBindVariable(newBatchEndStr)) + if err != nil { + return err + } + _, err = conn.Exec(ctx, queryInsertBatchInfoTableEntry, math.MaxInt32, false) + if err != nil { + return err + } + return nil +} diff --git a/go/vt/vttablet/jobcontroller/batch_sql_related_test.go b/go/vt/vttablet/jobcontroller/batch_sql_related_test.go index 5f8065b99b..1a1d8d0f2f 100644 --- a/go/vt/vttablet/jobcontroller/batch_sql_related_test.go +++ b/go/vt/vttablet/jobcontroller/batch_sql_related_test.go @@ -250,3 +250,14 @@ func TestStripComments(t *testing.T) { }) } } + +func TestGenSelectPKsSQL(t *testing.T) { + batchCountSQL := "select count(*) from t where 1 = 1" + batchCountSQLStmt, _ := sqlparser.Parse(batchCountSQL) + pkInfos := []PKInfo{{pkName: "pk1"}, {pkName: "pk2"}} + gotSQL := genSelectPKsSQL(batchCountSQLStmt, pkInfos) + expectedSQL := "select pk1, pk2 from t where 1 = 1" + assert.Equalf(t, expectedSQL, gotSQL, "genSelectPKsSQL(%v,%v)", batchCountSQLStmt, pkInfos) + // batchCountSQLStmt should not be changed + assert.Equalf(t, batchCountSQL, sqlparser.String(batchCountSQLStmt), "genSelectPKsSQL(%v,%v)", batchCountSQLStmt, pkInfos) +} diff --git a/go/vt/vttablet/jobcontroller/controller.go b/go/vt/vttablet/jobcontroller/controller.go index fb0ee2e997..8db7fa7b38 100644 --- a/go/vt/vttablet/jobcontroller/controller.go +++ b/go/vt/vttablet/jobcontroller/controller.go @@ -111,6 +111,16 @@ type PKInfo struct { pkType querypb.Type } +type JobRunnerArgs struct { + uuid, table, tableSchema, batchInfoTable, failPolicy string + batchInterval, batchSize int64 + timePeriodStart, timePeriodEnd *time.Time +} + +type JobHealthCheckArgs struct { + uuid, tableSchema, batchInfoTable, statusSetTime string +} + func (jc *JobController) Open() error { jc.initMutex.Lock() defer jc.initMutex.Unlock() @@ -178,8 +188,6 @@ func (jc *JobController) SubmitJob(sql, tableSchema, runningTimePeriodStart, run jc.tableMutex.Lock() defer jc.tableMutex.Unlock() - ctx := context.Background() - jobUUID, err := schema.CreateUUIDWithDelimiter("-") if err != nil { return nil, err @@ -191,22 +199,15 @@ func (jc *JobController) SubmitJob(sql, tableSchema, runningTimePeriodStart, run if userBatchSize == 0 { userBatchSize = int64(defaultBatchSize) } - // 取用户输入的batchSize和程序的threshold的最小值作为每个batch最终的batchSize - var batchSize int64 - if userBatchSize < batchSizeThreshold { - batchSize = userBatchSize - } else { - batchSize = batchSizeThreshold - } // 创建batchInfo表 - tableName, batchInfoTable, batchSize, err := jc.createJobBatches(jobUUID, sql, tableSchema, batchSize) - batchInfoTableSchema := tableSchema + tableName, batchInfoTable, batchSize, err := jc.createJobBatches(jobUUID, sql, tableSchema, userBatchSize) if err != nil { return &sqltypes.Result{}, err } if batchInfoTable == "" { return &sqltypes.Result{}, errors.New("this DML sql won't affect any rows") } + batchInfoTableSchema := tableSchema jobStatus := queuedStatus if postponeLaunch { @@ -236,32 +237,15 @@ func (jc *JobController) SubmitJob(sql, tableSchema, runningTimePeriodStart, run } } - submitQuery, err := sqlparser.ParseAndBind(sqlDMLJobSubmit, - sqltypes.StringBindVariable(jobUUID), - sqltypes.StringBindVariable(sql), - sqltypes.StringBindVariable(tableSchema), - sqltypes.StringBindVariable(tableName), - sqltypes.StringBindVariable(batchInfoTableSchema), - sqltypes.StringBindVariable(batchInfoTable), - sqltypes.Int64BindVariable(timeGapInMs), - sqltypes.Int64BindVariable(batchSize), - sqltypes.StringBindVariable(jobStatus), - sqltypes.StringBindVariable(statusSetTime), - sqltypes.StringBindVariable(failPolicy), - sqltypes.StringBindVariable(runningTimePeriodStart), - sqltypes.StringBindVariable(runningTimePeriodEnd)) - if err != nil { - return nil, err - } - - _, err = jc.execQuery(ctx, "", submitQuery) + err = jc.insertJobEntry(jobUUID, sql, tableSchema, tableName, batchInfoTableSchema, batchInfoTable, + jobStatus, statusSetTime, failPolicy, runningTimePeriodStart, runningTimePeriodEnd, timeGapInMs, batchSize) if err != nil { return &sqltypes.Result{}, err } jc.notifyJobScheduler() - return jc.buildJobSubmitResult(jobUUID, batchInfoTable, timeGapInMs, userBatchSize, postponeLaunch, failPolicy), nil + return jc.buildJobSubmitResult(jobUUID, batchInfoTable, timeGapInMs, batchSize, postponeLaunch, failPolicy), nil } // 和cancel的区别:1.pasue不会删除元数据 2.cancel状态的job在经过一段时间后会被后台协程回收 @@ -315,18 +299,12 @@ func (jc *JobController) ResumeJob(uuid string) (*sqltypes.Result, error) { return emptyResult, errors.New("the len of qr of querying job info by uuid is not 1") } row := rst.Named().Rows[0] - tableSchema := row["table_schema"].ToString() - table := row["table_name"].ToString() - jobBatchTable := row["batch_info_table_name"].ToString() - batchInterval, _ := row["batch_interval_in_ms"].ToInt64() - batchSize, _ := row["batch_size"].ToInt64() - runningTimePeriodStart := row["running_time_period_start"].ToString() - runningTimePeriodEnd := row["running_time_period_end"].ToString() - periodStartTimePtr, periodEndTimePtr := getRunningPeriodTime(runningTimePeriodStart, runningTimePeriodEnd) - failPolicy := row["fail_policy"].ToString() + + runnerArgs := JobRunnerArgs{} + runnerArgs.initArgsByQueryResult(row) // 拉起runner协程,协程内会将状态改为running - go jc.dmlJobBatchRunner(uuid, table, tableSchema, jobBatchTable, failPolicy, batchInterval, batchSize, periodStartTimePtr, periodEndTimePtr) + go jc.dmlJobBatchRunner(runnerArgs.uuid, runnerArgs.table, runnerArgs.tableSchema, runnerArgs.batchInfoTable, runnerArgs.failPolicy, runnerArgs.batchInterval, runnerArgs.batchSize, runnerArgs.timePeriodStart, runnerArgs.timePeriodEnd) emptyResult.RowsAffected = 1 return emptyResult, nil } @@ -423,21 +401,13 @@ func (jc *JobController) jobScheduler(checkBeforeSchedule chan struct{}) { if qr != nil { for _, row := range qr.Named().Rows { status := row["status"].ToString() - schema := row["table_schema"].ToString() - table := row["table_name"].ToString() - uuid := row["job_uuid"].ToString() - jobBatchTable := row["batch_info_table_name"].ToString() - batchInterval, _ := row["batch_interval_in_ms"].ToInt64() - batchSize, _ := row["batch_size"].ToInt64() - runningTimePeriodStart := row["running_time_period_start"].ToString() - runningTimePeriodEnd := row["running_time_period_end"].ToString() - periodStartTimePtr, periodEndTimePtr := getRunningPeriodTime(runningTimePeriodStart, runningTimePeriodEnd) - failPolicy := row["fail_policy"].ToString() - - if jc.checkDmlJobRunnable(uuid, status, table, periodStartTimePtr, periodEndTimePtr) { + runnerArgs := JobRunnerArgs{} + runnerArgs.initArgsByQueryResult(row) + + if jc.checkDmlJobRunnable(runnerArgs.uuid, status, runnerArgs.table, runnerArgs.timePeriodStart, runnerArgs.timePeriodEnd) { // 初始化Job在内存中的元数据,防止在dmlJobBatchRunner修改表中的状态前,scheduler多次启动同一个job - jc.initDMLJobRunningMeta(uuid, table) - go jc.dmlJobBatchRunner(uuid, table, schema, jobBatchTable, failPolicy, batchInterval, batchSize, periodStartTimePtr, periodEndTimePtr) + jc.initDMLJobRunningMeta(runnerArgs.uuid, runnerArgs.table) + go jc.dmlJobBatchRunner(runnerArgs.uuid, runnerArgs.table, runnerArgs.tableSchema, runnerArgs.batchInfoTable, runnerArgs.failPolicy, runnerArgs.batchInterval, runnerArgs.batchSize, runnerArgs.timePeriodStart, runnerArgs.timePeriodEnd) } } } @@ -616,41 +586,26 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab return "", err } - // 1.根据batchCountSQL生成查询pk值的select sql - // 1.1.获得PK信息 + // 1.根据batchCountSQL生成select PKs SQL pkInfos, err := jc.getTablePkInfo(ctx, tableSchema, table) if err != nil { return "", err } - - // 1.2.根据当前batch的batchCountSQL生成select sql,用于获得拆分后batch的拆分列start和end - // 只需要将batchCountSQL的投影部分(SelectExprs)从count(*)改为拆分列即可 - batchCountSQLStmtSelect, _ := batchCountSQLStmt.(*sqlparser.Select) - // 根据pk信息生成select exprs - var pkExprs []sqlparser.SelectExpr - for _, pkInfo := range pkInfos { - pkExprs = append(pkExprs, &sqlparser.AliasedExpr{Expr: sqlparser.NewColName(pkInfo.pkName)}) - } - oldBatchCountSQLStmtSelectExprs := batchCountSQLStmtSelect.SelectExprs - batchCountSQLStmtSelect.SelectExprs = pkExprs - batchSplitSelectSQL := sqlparser.String(batchCountSQLStmtSelect) - // batchCountSQLStmt在后续生成newBatchCountSQL时还需用到,因此这里将其恢复原样 - batchCountSQLStmtSelect.SelectExprs = oldBatchCountSQLStmtSelectExprs + selectPKsSQL := genSelectPKsSQL(batchCountSQLStmt, pkInfos) // 2.根据select sql将batch拆分,生成两个新的batch。 - //这里每次只将超过threshold的batch拆成两个batch而不是多个小于等于threshold的batch的原因是: + // 这里每次只将超过threshold的batch拆成两个batch而不是多个小于等于threshold的batch的原因是: // 拆成多个batch需要遍历完select的全部结果,这可能会导致超时 // 2.1.计算两个batch的batchPKStart和batchPKEnd。实际上,只要获得当前batch的新的PKEnd和新的batch的PKStart // 遍历前threshold+1条,依然使用同一个连接 - qr, err := conn.Exec(ctx, batchSplitSelectSQL, math.MaxInt32, true) - if err != nil { - return "", err - } - var curBatchNewEnd []sqltypes.Value var newBatchStart []sqltypes.Value + qr, err := conn.Exec(ctx, selectPKsSQL, math.MaxInt32, true) + if err != nil { + return "", err + } for rowCount, row := range qr.Rows { // 将原本batch的PKEnd设在threshold条数处 if int64(rowCount) == batchSize-1 { @@ -669,62 +624,29 @@ func (jc *JobController) splitBatchIntoTwo(ctx context.Context, tableSchema, tab } // 2.3.计算两个batch的batch start和end字段 - getBatchBeginAndEndSQL := fmt.Sprintf(sqlTemplateGetBatchBeginAndEnd, batchTable) - getBatchBeginAndEndQuery, err := sqlparser.ParseAndBind(getBatchBeginAndEndSQL, sqltypes.StringBindVariable(batchID)) - if err != nil { - return "", err - } - qr, err = conn.Exec(ctx, getBatchBeginAndEndQuery, math.MaxInt32, true) - if err != nil { - return "", err - } - if len(qr.Named().Rows) != 1 { - return "", errors.New("can not get batch begin and end") - } - currentBatchNewBeginStr := qr.Named().Rows[0]["batch_begin"].ToString() - newBatchEndStr := qr.Named().Rows[0]["batch_end"].ToString() - currentBatchNewEndStr, newBatchBeginStr, err := genBatchStartAndEndStr(curBatchNewEnd, newBatchStart) + currentBatchNewBeginStr, currentBatchNewEndStr, newBatchBeginStr, newBatchEndStr, err := getNewBatchesBeginAndEndStr(ctx, conn, batchTable, batchID, curBatchNewEnd, newBatchStart) if err != nil { return "", err } - // 3 将结果记录在表中:在batch表中更改旧的条目的sql,并插入新batch条目 - // 3.1.在表中更改旧的sql - updateBatchSQL := fmt.Sprintf(sqlTemplateUpdateBatchSQL, batchTable) - updateBatchSQLQuery, err := sqlparser.ParseAndBind(updateBatchSQL, - sqltypes.StringBindVariable(curBatchSQL), - sqltypes.StringBindVariable(currentBatchNewBeginStr), - sqltypes.StringBindVariable(currentBatchNewEndStr), - sqltypes.StringBindVariable(batchID)) - if err != nil { - return "", err - } - _, err = conn.Exec(ctx, updateBatchSQLQuery, math.MaxInt32, false) + // 3 将结果记录在表中:在batch表中更改原本batch的条目的sql,并插入新batch条目 + // 更改原本batch条目 + err = updateBatchInfoTableEntry(ctx, conn, batchTable, curBatchSQL, currentBatchNewBeginStr, currentBatchNewEndStr, batchID) if err != nil { return "", err } - // 3.2.插入新batch条目 - newCurrentBatchSQL = curBatchSQL + // 插入新batch条目 nextBatchID, err := genNewBatchID(batchID) if err != nil { return "", err } newBatchSize := expectedRow - batchSize - insertBatchSQL := fmt.Sprintf(sqlTemplateInsertBatchEntry, batchTable) - insertBatchSQLQuery, err := sqlparser.ParseAndBind(insertBatchSQL, - sqltypes.StringBindVariable(nextBatchID), - sqltypes.StringBindVariable(newBatchSQL), - sqltypes.StringBindVariable(newBatchCountSQL), - sqltypes.Int64BindVariable(newBatchSize), - sqltypes.StringBindVariable(newBatchBeginStr), - sqltypes.StringBindVariable(newBatchEndStr)) - if err != nil { - return "", err - } - _, err = conn.Exec(ctx, insertBatchSQLQuery, math.MaxInt32, false) + err = insertBatchInfoTableEntry(ctx, conn, batchTable, nextBatchID, newBatchSQL, newBatchCountSQL, newBatchBeginStr, newBatchEndStr, newBatchSize) if err != nil { return "", err } + + newCurrentBatchSQL = curBatchSQL return newCurrentBatchSQL, nil } @@ -760,7 +682,7 @@ func (jc *JobController) dmlJobBatchRunner(uuid, table, tableSchema, batchTable, } // 检查是否在运维窗口内 - // todo,增加时区支持,以及是否可能由于脑裂问题导致错误fail掉job? + // todo feat 增加时区支持,以及是否可能由于脑裂问题导致错误fail掉job? if timePeriodStart != nil && timePeriodEnd != nil { currentTime := time.Now() if !(currentTime.After(*timePeriodStart) && currentTime.Before(*timePeriodEnd)) { @@ -802,12 +724,12 @@ func (jc *JobController) dmlJobBatchRunner(uuid, table, tableSchema, batchTable, err = jc.execBatchAndRecord(ctx, tableSchema, table, batchSQL, batchCountSQL, uuid, batchTable, batchIDToExec, batchSize) // 如果执行batch时失败,则根据failPolicy决定处理策略 if err != nil { + // todo feat 支持batch并行时需要重新考虑逻辑 switch failPolicy { case failPolicyAbort: jc.FailJob(ctx, uuid, err.Error(), table) return case failPolicySkip: - // todo,由于目前batch是串行执行,不存在多个协程同时访问batch表的情况,因此暂时不用加锁。 _ = jc.updateBatchStatus(tableSchema, batchTable, failPolicySkip, batchIDToExec, err.Error()) continue case failPolicyPause: @@ -831,121 +753,40 @@ func (jc *JobController) deleteDMLJobRunningMeta(uuid, table string) { delete(jc.workingTables, table) } -func (jc *JobController) execSubtaskAndRecord(ctx context.Context, tableSchema, subtaskSQL, uuid string) (affectedRows int64, err error) { - defer jc.env.LogError() - - var setting pools.Setting - if tableSchema != "" { - setting.SetWithoutDBName(false) - setting.SetQuery(fmt.Sprintf("use %s", tableSchema)) - setting.SetResetQuery(fmt.Sprintf("use %s", jc.env.Config().DB.DBName)) - } - conn, err := jc.pool.Get(ctx, &setting) - defer conn.Recycle() - if err != nil { - return 0, err - } - - _, err = conn.Exec(ctx, "start transaction", math.MaxInt32, false) - if err != nil { - return 0, err - } - - qr, err := conn.Exec(ctx, subtaskSQL, math.MaxInt32, true) - affectedRows = int64(qr.RowsAffected) - - jc.tableMutex.Lock() - defer jc.tableMutex.Unlock() - recordRstSQL, err := sqlparser.ParseAndBind(sqlDMLJobUpdateAffectedRows, - sqltypes.Int64BindVariable(affectedRows), - sqltypes.StringBindVariable(uuid)) - _, err = conn.Exec(ctx, recordRstSQL, math.MaxInt32, false) - if err != nil { - return 0, err - } - _, err = conn.Exec(ctx, "commit", math.MaxInt32, false) - if err != nil { - return 0, err - } - - return affectedRows, nil -} - func (jc *JobController) jobHealthCheck(checkBeforeSchedule chan struct{}) { ctx := context.Background() - // 1.启动时,先检查是否有处于"running"或"paused"的job,并恢复它们在内存的状态 - qr, _ := jc.execQuery(ctx, "", sqlDMLJobGetAllJobs) - if qr != nil { - - jc.workingTablesMutex.Lock() - jc.tableMutex.Lock() - - for _, row := range qr.Named().Rows { - status := row["status"].ToString() - tableSchema := row["table_schema"].ToString() - table := row["table_name"].ToString() - jobBatchTable := row["batch_info_table_name"].ToString() - uuid := row["job_uuid"].ToString() - batchInterval, _ := row["batch_interval_in_ms"].ToInt64() - batchSize, _ := row["batch_size"].ToInt64() - runningTimePeriodStart := row["running_time_period_start"].ToString() - runningTimePeriodEnd := row["running_time_period_end"].ToString() - periodStartTimePtr, periodEndTimePtr := getRunningPeriodTime(runningTimePeriodStart, runningTimePeriodEnd) - failPolicy := row["fail_policy"].ToString() - - switch status { - case runningStatus: - jc.initDMLJobRunningMeta(uuid, table) - go jc.dmlJobBatchRunner(uuid, table, tableSchema, jobBatchTable, failPolicy, batchInterval, batchSize, periodStartTimePtr, periodEndTimePtr) - case pausedStatus: - jc.initDMLJobRunningMeta(uuid, table) - } - } - - jc.workingTablesMutex.Unlock() - jc.tableMutex.Unlock() - } + jc.recoverJobsMetadata(ctx) - log.Info("check of running and paused done \n") + log.Info("jobHealthCheck: metadata of all running and paused jobs are restored to memory\n") // 内存状态恢复完毕后,唤醒Job调度协程 checkBeforeSchedule <- struct{}{} // 2.每隔一段时间轮询一次,根据job的状态进行不同的处理 timer := time.NewTicker(healthCheckInterval) defer timer.Stop() - for range timer.C { // 防止vttablet不再是primary时该协程继续执行 if jc.tabletTypeFunc() != topodatapb.TabletType_PRIMARY { return } - jc.tableMutex.Lock() + qr, _ := jc.execQuery(ctx, "", sqlDMLJobGetAllJobs) if qr != nil { for _, row := range qr.Named().Rows { status := row["status"].ToString() - statusSetTime := row["status_set_time"].ToString() - uuid := row["job_uuid"].ToString() - jobBatchTable := row["batch_info_table_name"].ToString() - tableSchema := row["table_schema"].ToString() + args := JobHealthCheckArgs{} + args.initArgsByQueryResult(row) switch status { + // 清理已经运行结束的job的表及条目 case canceledStatus, failedStatus, completedStatus: - statusSetTimeObj, err := time.Parse(time.RFC3339, statusSetTime) + err := jc.tableGC(ctx, args.uuid, args.tableSchema, args.batchInfoTable, args.statusSetTime) if err != nil { + log.Errorf("jobHealthCheck: tableGC failed, %s", err) continue } - if time.Now().After(statusSetTimeObj.Add(tableEntryGCInterval)) { - deleteJobSQL, err := sqlparser.ParseAndBind(sqlDMLJobDeleteJob, - sqltypes.StringBindVariable(uuid)) - if err != nil { - continue - } - _, _ = jc.execQuery(ctx, "", deleteJobSQL) - _, _ = jc.execQuery(ctx, tableSchema, fmt.Sprintf(sqlTemplateDropTable, jobBatchTable)) - } case runningStatus: // todo feat 增加对长时间未增加rows的running job的处理 } @@ -956,14 +797,70 @@ func (jc *JobController) jobHealthCheck(checkBeforeSchedule chan struct{}) { } } +func (jc *JobController) recoverJobsMetadata(ctx context.Context) { + qr, _ := jc.execQuery(ctx, "", sqlDMLJobGetAllJobs) + if qr != nil { + jc.workingTablesMutex.Lock() + jc.tableMutex.Lock() + + for _, row := range qr.Named().Rows { + status := row["status"].ToString() + runnerArgs := JobRunnerArgs{} + runnerArgs.initArgsByQueryResult(row) + + switch status { + case runningStatus: + jc.initDMLJobRunningMeta(runnerArgs.uuid, runnerArgs.table) + go jc.dmlJobBatchRunner(runnerArgs.uuid, runnerArgs.table, runnerArgs.tableSchema, runnerArgs.batchInfoTable, runnerArgs.failPolicy, runnerArgs.batchInterval, runnerArgs.batchSize, runnerArgs.timePeriodStart, runnerArgs.timePeriodEnd) + case pausedStatus: + jc.initDMLJobRunningMeta(runnerArgs.uuid, runnerArgs.table) + } + } + + jc.workingTablesMutex.Unlock() + jc.tableMutex.Unlock() + } +} + +func (jc *JobController) tableGC(ctx context.Context, uuid, tableSchema, batchInfoTable, statusSetTime string) error { + statusSetTimeObj, err := time.Parse(time.RFC3339, statusSetTime) + if err != nil { + return err + } + // 如果Job设置结束状态的时间距离当前已经超过了一定的时间间隔,则删除该Job在表中的条目,并将其batch表删除 + if time.Now().After(statusSetTimeObj.Add(tableEntryGCInterval)) { + deleteJobSQL, err := sqlparser.ParseAndBind(sqlDMLJobDeleteJob, + sqltypes.StringBindVariable(uuid)) + if err != nil { + return err + } + _, _ = jc.execQuery(ctx, "", deleteJobSQL) + _, _ = jc.execQuery(ctx, tableSchema, fmt.Sprintf(sqlTemplateDropTable, batchInfoTable)) + } + return nil +} + func (jc *JobController) createJobBatches(jobUUID, sql, tableSchema string, userBatchSize int64) (tableName, batchTableName string, batchSize int64, err error) { - // 1.解析用户提交的DML sql,返回DML的各个部分。其中selectSQL用于确定每一个batch的pk范围,生成每一个batch所要执行的batch sql - selectSQL, tableName, wherePart, pkPart, whereExpr, pkInfos, stmt, err := jc.parseDML(sql, tableSchema) + // 1.对用户提交的DML sql进行合法性检验和解析 + tableName, whereExpr, stmt, err := parseDML(sql) + if err != nil { + return "", "", 0, err + } + // 2.检查PK列类型的合法性 + ctx := context.Background() + pkInfos, err := jc.getTablePkInfo(ctx, tableSchema, tableName) if err != nil { return "", "", 0, err } + if existUnSupportedPK(pkInfos) { + return "", "", 0, errors.New("the table has unsupported PK type") + } + // 3.拼接生成selectSQL,用于生成batch表 + pkCols := getPKColsStr(pkInfos) + selectSQL := fmt.Sprintf("select %s from %s.%s where %s order by %s", + pkCols, tableSchema, tableName, sqlparser.String(whereExpr), pkCols) - // 2.利用selectSQL为该job生成batch表,在此之前生成每个batch的batchSize + // 4.计算每个batch的batchSize // batchSize = min(userBatchSize, batchSizeThreshold / 每个表的index数量 * ratioOfBatchSizeThreshold) indexCount, err := jc.getIndexCount(tableSchema, tableName) if err != nil { @@ -975,105 +872,13 @@ func (jc *JobController) createJobBatches(jobUUID, sql, tableSchema string, user } else { batchSize = actualThreshold } - // 3.创建batchTable表,并在表中记录每个batch所要执行的sql - batchTableName, err = jc.createBatchTable(jobUUID, selectSQL, tableSchema, sql, tableName, wherePart, pkPart, whereExpr, stmt, pkInfos, batchSize) + // 5.基于selectSQL生成batch表 + batchTableName, err = jc.createBatchTable(jobUUID, selectSQL, tableSchema, sql, tableName, whereExpr, stmt, pkInfos, batchSize) return tableName, batchTableName, batchSize, err } -func (jc *JobController) parseDML(sql, tableSchema string) (selectSQL, tableName, wherePart, pkPart string, whereExpr sqlparser.Expr, pkInfos []PKInfo, stmt sqlparser.Statement, err error) { - stmt, err = sqlparser.Parse(sql) - if err != nil { - return "", "", "", "", nil, nil, nil, err - } - // 根据stmt,分析DML SQL的各个部分,包括涉及的表,where条件 - switch s := stmt.(type) { - case *sqlparser.Delete: - if len(s.TableExprs) != 1 { - return "", "", "", "", nil, nil, nil, errors.New("the number of table is more than one") - } - tableExpr, ok := s.TableExprs[0].(*sqlparser.AliasedTableExpr) - // todo feat 目前暂不支持join和多表 - if !ok { - return "", "", "", "", nil, nil, nil, errors.New("don't support join table now") - } - tableName = sqlparser.String(tableExpr) - wherePart = sqlparser.String(s.Where) - if wherePart == "" { - return "", "", "", "", nil, nil, nil, errors.New("the sql doesn't have where condition") - } - // 将where字符串中的"where"字符串删除,便于对真正的条件部分增加括号 - wherePart = wherePart[strings.Index(wherePart, "where")+5:] - whereExpr = s.Where.Expr - - limitPart := sqlparser.String(s.Limit) - if limitPart != "" { - return "", "", "", "", nil, nil, nil, errors.New("the SQL should not have limit clause") - } - orderByPart := sqlparser.String(s.OrderBy) - if orderByPart != "" { - return "", "", "", "", nil, nil, nil, errors.New("the SQL should not have order by clause") - } - - case *sqlparser.Update: - if len(s.TableExprs) != 1 { - return "", "", "", "", nil, nil, nil, errors.New("the number of table is more than one") - } - tableExpr, ok := s.TableExprs[0].(*sqlparser.AliasedTableExpr) - if !ok { - return "", "", "", "", nil, nil, nil, errors.New("don't support join table now") - } - tableName = sqlparser.String(tableExpr) - wherePart = sqlparser.String(s.Where) - if wherePart == "" { - return "", "", "", "", nil, nil, nil, errors.New("the sql doesn't have where condition") - } - // 将where字符串中的"where"字符串删除,便于对真正的条件部分增加括号 - wherePart = wherePart[strings.Index(wherePart, "where")+5:] - whereExpr = s.Where.Expr - - limitPart := sqlparser.String(s.Limit) - if limitPart != "" { - return "", "", "", "", nil, nil, nil, errors.New("the SQL should not have limit clause") - } - orderByPart := sqlparser.String(s.OrderBy) - if orderByPart != "" { - return "", "", "", "", nil, nil, nil, errors.New("the SQL should not have order by clause") - } - - default: - // todo feat support select...into, replace...into - return "", "", "", "", nil, nil, nil, errors.New("the type of sql is not supported") - } - - // 获得该DML所相关表的PK信息,将其中的PK列组成字符串pkPart,形如"PKCol1,PKCol2,PKCol3" - ctx := context.Background() - pkInfos, err = jc.getTablePkInfo(ctx, tableSchema, tableName) - if existUnSupportedPK(pkInfos) { - return "", "", "", "", nil, nil, nil, errors.New("the table has unsupported PK type") - } - if err != nil { - return "", "", "", "", nil, nil, nil, err - } - pkPart = "" - firstPK := true - for _, pkInfo := range pkInfos { - if !firstPK { - pkPart += "," - } - pkPart += pkInfo.pkName - firstPK = false - } - - // 将该DML的各部分信息组成batch select语句,用于生成每一个batch的pk范围 - selectSQL = fmt.Sprintf("select %s from %s.%s where %s order by %s", - pkPart, tableSchema, tableName, wherePart, pkPart) - - return selectSQL, tableName, wherePart, pkPart, whereExpr, pkInfos, stmt, err -} - -func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, tableName, wherePart, pkPart string, whereExpr sqlparser.Expr, stmt sqlparser.Statement, pkInfos []PKInfo, batchSize int64) (string, error) { +func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, tableName string, whereExpr sqlparser.Expr, stmt sqlparser.Statement, pkInfos []PKInfo, batchSize int64) (string, error) { ctx := context.Background() - // 执行selectSQL,获得有序的pk值结果集,以生成每一个batch要执行的batch SQL qr, err := jc.execQuery(ctx, "", selectSQL) if err != nil { @@ -1083,11 +888,10 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, return "", nil } + // todo feat 删除batchSQL,batchCountSQL,字段,在内存中生成具体的sql, mysql generate col 或者 go代码实现 // 为每一个DML job创建一张batch表,保存着该job被拆分成batches的具体信息。 // healthCheck协程会定时对处于结束状态(completed,canceled,failed)的job的batch表进行回收 batchTableName := "_vt_BATCH_" + strings.Replace(jobUUID, "-", "_", -1) - - // todo feat 删除batchSQL,batchCountSQL,字段,在内存中生成具体的sql, mysql generate col 或者 go代码实现 createTableSQL := fmt.Sprintf(sqlTemplateCreateBatchTable, batchTableName) _, err = jc.execQuery(ctx, tableSchema, createTableSQL) if err != nil { @@ -1100,7 +904,6 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, var currentBatchStart []sqltypes.Value var currentBatchEnd []sqltypes.Value currentBatchID := "1" - insertBatchSQLWithTableName := fmt.Sprintf(sqlTemplateInsertBatchEntry, batchTableName) for _, values := range qr.Rows { if currentBatchSize == 0 { @@ -1108,32 +911,12 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, } currentBatchEnd = values currentBatchSize++ - if currentBatchSize == batchSize { - batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) - if err != nil { - return "", err - } - countSQL := genCountSQL(tableSchema, tableName, finalWhereStr) - if err != nil { - return "", err - } - batchStartStr, batchEndStr, err := genBatchStartAndEndStr(currentBatchStart, currentBatchEnd) - if err != nil { - return "", err - } - currentBatchSize = 0 - insertBatchSQLQuery, err := sqlparser.ParseAndBind(insertBatchSQLWithTableName, - sqltypes.StringBindVariable(currentBatchID), - sqltypes.StringBindVariable(batchSQL), - sqltypes.StringBindVariable(countSQL), - sqltypes.Int64BindVariable(batchSize), - sqltypes.StringBindVariable(batchStartStr), - sqltypes.StringBindVariable(batchEndStr)) + batchSQL, countSQL, batchStartStr, batchEndStr, err := createBatchInfoTableEntry(tableSchema, tableName, sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) if err != nil { return "", err } - _, err = jc.execQuery(ctx, tableSchema, insertBatchSQLQuery) + err = jc.insertBatchInfoTableEntry(ctx, tableSchema, batchTableName, currentBatchID, batchSQL, countSQL, batchStartStr, batchEndStr, currentBatchSize) if err != nil { return "", err } @@ -1141,33 +924,16 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, if err != nil { return "", err } + currentBatchSize = 0 } } // 最后一个batch的行数不一定是batchSize,在循环结束时要将剩余的行数划分到最后一个batch中 if currentBatchSize != 0 { - batchSQL, finalWhereStr, err := genBatchSQL(sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) - if err != nil { - return "", err - } - countSQL := genCountSQL(tableSchema, tableName, finalWhereStr) - if err != nil { - return "", err - } - batchStartStr, batchEndStr, err := genBatchStartAndEndStr(currentBatchStart, currentBatchEnd) - if err != nil { - return "", err - } - insertBatchSQLQuery, err := sqlparser.ParseAndBind(insertBatchSQLWithTableName, - sqltypes.StringBindVariable(currentBatchID), - sqltypes.StringBindVariable(batchSQL), - sqltypes.StringBindVariable(countSQL), - sqltypes.Int64BindVariable(currentBatchSize), - sqltypes.StringBindVariable(batchStartStr), - sqltypes.StringBindVariable(batchEndStr)) + batchSQL, countSQL, batchStartStr, batchEndStr, err := createBatchInfoTableEntry(tableSchema, tableName, sql, stmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) if err != nil { return "", err } - _, err = jc.execQuery(ctx, tableSchema, insertBatchSQLQuery) + err = jc.insertBatchInfoTableEntry(ctx, tableSchema, batchTableName, currentBatchID, batchSQL, countSQL, batchStartStr, batchEndStr, currentBatchSize) if err != nil { return "", err } @@ -1175,6 +941,23 @@ func (jc *JobController) createBatchTable(jobUUID, selectSQL, tableSchema, sql, return batchTableName, nil } +func createBatchInfoTableEntry(tableSchema, tableName, sql string, sqlStmt sqlparser.Statement, whereExpr sqlparser.Expr, + currentBatchStart, currentBatchEnd []sqltypes.Value, pkInfos []PKInfo) (batchSQL, countSQL, batchStartStr, batchEndStr string, err error) { + batchSQL, finalWhereStr, err := genBatchSQL(sql, sqlStmt, whereExpr, currentBatchStart, currentBatchEnd, pkInfos) + if err != nil { + return "", "", "", "", err + } + countSQL = genCountSQL(tableSchema, tableName, finalWhereStr) + if err != nil { + return "", "", "", "", err + } + batchStartStr, batchEndStr, err = genBatchStartAndEndStr(currentBatchStart, currentBatchEnd) + if err != nil { + return "", "", "", "", err + } + return batchSQL, countSQL, batchStartStr, batchEndStr, nil +} + // 通知jobScheduler让它立刻开始一次调度。 func (jc *JobController) notifyJobScheduler() { if jc.schedulerNotifyChan == nil { diff --git a/go/vt/vttablet/jobcontroller/sqls.go b/go/vt/vttablet/jobcontroller/sqls.go index 42da7b3686..4e776a6ffe 100644 --- a/go/vt/vttablet/jobcontroller/sqls.go +++ b/go/vt/vttablet/jobcontroller/sqls.go @@ -31,13 +31,13 @@ const ( table_name, batch_info_table_schema, batch_info_table_name, - batch_interval_in_ms, - batch_size, status, status_set_time, fail_policy, running_time_period_start, - running_time_period_end) values(%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a)` + running_time_period_end, + batch_interval_in_ms, + batch_size) values(%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a,%a)` sqlDMLJobUpdateMessage = `update mysql.non_transactional_dml_jobs set message = %a diff --git a/go/vt/vttablet/jobcontroller/util.go b/go/vt/vttablet/jobcontroller/util.go index 585511e29c..f67dbce071 100644 --- a/go/vt/vttablet/jobcontroller/util.go +++ b/go/vt/vttablet/jobcontroller/util.go @@ -72,6 +72,87 @@ func (jc *JobController) execQuery(ctx context.Context, targetString, query stri } +func parseDML(sql string) (tableName string, whereExpr sqlparser.Expr, stmt sqlparser.Statement, err error) { + stmt, err = sqlparser.Parse(sql) + if err != nil { + return "", nil, nil, err + } + // 根据stmt,分析DML SQL的各个部分,包括涉及的表,where条件 + switch s := stmt.(type) { + case *sqlparser.Delete: + if len(s.TableExprs) != 1 { + return "", nil, nil, errors.New("the number of table is more than one") + } + tableExpr, ok := s.TableExprs[0].(*sqlparser.AliasedTableExpr) + // todo feat 目前暂不支持join和多表 + if !ok { + return "", nil, nil, errors.New("don't support join table now") + } + tableName = sqlparser.String(tableExpr) + // the sql should have where clause + if s.Where == nil { + return "", nil, nil, errors.New("the SQL should have where clause") + } + whereExpr = s.Where.Expr + // the sql should not have limit clause and order by clause + limitPart := sqlparser.String(s.Limit) + if limitPart != "" { + return "", nil, nil, errors.New("the SQL should not have limit clause") + } + orderByPart := sqlparser.String(s.OrderBy) + if orderByPart != "" { + return "", nil, nil, errors.New("the SQL should not have order by clause") + } + + case *sqlparser.Update: + if len(s.TableExprs) != 1 { + return "", nil, nil, errors.New("the number of table is more than one") + } + tableExpr, ok := s.TableExprs[0].(*sqlparser.AliasedTableExpr) + if !ok { + return "", nil, nil, errors.New("don't support join table now") + } + tableName = sqlparser.String(tableExpr) + // the sql should have where clause + if s.Where == nil { + return "", nil, nil, errors.New("the SQL should have where clause") + } + whereExpr = s.Where.Expr + // the sql should not have limit clause and order by clause + limitPart := sqlparser.String(s.Limit) + if limitPart != "" { + return "", nil, nil, errors.New("the SQL should not have limit clause") + } + orderByPart := sqlparser.String(s.OrderBy) + if orderByPart != "" { + return "", nil, nil, errors.New("the SQL should not have order by clause") + } + + default: + // todo feat support select...into, replace...into + return "", nil, nil, errors.New("the type of sql is not supported") + } + + if err != nil { + return "", nil, nil, err + } + + return tableName, whereExpr, stmt, err +} + +func getPKColsStr(pkInfos []PKInfo) string { + pkCols := "" + firstPK := true + for _, pkInfo := range pkInfos { + if !firstPK { + pkCols += "," + } + pkCols += pkInfo.pkName + firstPK = false + } + return pkCols +} + // 该函数拿锁 func (jc *JobController) updateJobMessage(ctx context.Context, uuid, message string) error { jc.tableMutex.Lock() @@ -351,3 +432,75 @@ func (jc *JobController) updateBatchStatus(batchTableSchema, batchTableName, sta _, err = jc.execQuery(context.Background(), batchTableSchema, query) return err } + +func (args *JobRunnerArgs) initArgsByQueryResult(row sqltypes.RowNamedValues) { + args.uuid = row["job_uuid"].ToString() + args.tableSchema = row["table_schema"].ToString() + args.table = row["table_name"].ToString() + args.batchInfoTable = row["batch_info_table_name"].ToString() + args.failPolicy = row["fail_policy"].ToString() + + batchInterval, _ := row["batch_interval_in_ms"].ToInt64() + batchSize, _ := row["batch_size"].ToInt64() + args.batchInterval = batchInterval + args.batchSize = batchSize + + runningTimePeriodStart := row["running_time_period_start"].ToString() + runningTimePeriodEnd := row["running_time_period_end"].ToString() + args.timePeriodStart, args.timePeriodEnd = getRunningPeriodTime(runningTimePeriodStart, runningTimePeriodEnd) + +} + +func (args *JobHealthCheckArgs) initArgsByQueryResult(row sqltypes.RowNamedValues) { + args.uuid = row["job_uuid"].ToString() + args.tableSchema = row["table_schema"].ToString() + args.batchInfoTable = row["batch_info_table_name"].ToString() + args.statusSetTime = row["status_set_time"].ToString() +} + +func (jc *JobController) insertBatchInfoTableEntry(ctx context.Context, tableSchema, batchTableName, currentBatchID, batchSQL, countSQL, batchStartStr, batchEndStr string, batchSize int64) (err error) { + insertBatchSQLWithTableName := fmt.Sprintf(sqlTemplateInsertBatchEntry, batchTableName) + insertBatchSQLQuery, err := sqlparser.ParseAndBind(insertBatchSQLWithTableName, + sqltypes.StringBindVariable(currentBatchID), + sqltypes.StringBindVariable(batchSQL), + sqltypes.StringBindVariable(countSQL), + sqltypes.Int64BindVariable(batchSize), + sqltypes.StringBindVariable(batchStartStr), + sqltypes.StringBindVariable(batchEndStr)) + if err != nil { + return err + } + _, err = jc.execQuery(ctx, tableSchema, insertBatchSQLQuery) + if err != nil { + return err + } + return nil +} + +func (jc *JobController) insertJobEntry(jobUUID, sql, tableSchema, tableName, batchInfoTableSchema, + batchInfoTable, jobStatus, statusSetTime, failPolicy, runningTimePeriodStart, runningTimePeriodEnd string, + timeGapInMs, batchSize int64) (err error) { + ctx := context.Background() + submitQuery, err := sqlparser.ParseAndBind(sqlDMLJobSubmit, + sqltypes.StringBindVariable(jobUUID), + sqltypes.StringBindVariable(sql), + sqltypes.StringBindVariable(tableSchema), + sqltypes.StringBindVariable(tableName), + sqltypes.StringBindVariable(batchInfoTableSchema), + sqltypes.StringBindVariable(batchInfoTable), + sqltypes.StringBindVariable(jobStatus), + sqltypes.StringBindVariable(statusSetTime), + sqltypes.StringBindVariable(failPolicy), + sqltypes.StringBindVariable(runningTimePeriodStart), + sqltypes.StringBindVariable(runningTimePeriodEnd), + sqltypes.Int64BindVariable(timeGapInMs), + sqltypes.Int64BindVariable(batchSize)) + if err != nil { + return err + } + _, err = jc.execQuery(ctx, "", submitQuery) + if err != nil { + return err + } + return nil +} diff --git a/go/vt/vttablet/jobcontroller/util_test.go b/go/vt/vttablet/jobcontroller/util_test.go new file mode 100644 index 0000000000..be551d669a --- /dev/null +++ b/go/vt/vttablet/jobcontroller/util_test.go @@ -0,0 +1,58 @@ +/* +Copyright ApeCloud, Inc. +Licensed under the Apache v2(found in the LICENSE file in the root directory). +*/ + +package jobcontroller + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/vt/sqlparser" +) + +func TestParseDML(t *testing.T) { + // DELETE + tableName, whereExpr, _, err := parseDML("delete from t1 where id=1") + whereExprStr := sqlparser.String(whereExpr) + expectedTableName := "t1" + expectedWhereExpr := "id = 1" + assert.Equalf(t, expectedTableName, tableName, "table name") + assert.Equalf(t, expectedWhereExpr, whereExprStr, "where expr") + assert.Equal(t, nil, err) + + // UPDATE + tableName, whereExpr, _, err = parseDML("update t2 set c1 = '123' where id=2") + whereExprStr = sqlparser.String(whereExpr) + expectedTableName = "t2" + expectedWhereExpr = "id = 2" + assert.Equalf(t, expectedTableName, tableName, "table name") + assert.Equalf(t, expectedWhereExpr, whereExprStr, "where expr") + assert.Equal(t, nil, err) + + // error: the type of sql is not supported + _, _, _, err = parseDML("select * from t1") + assert.Equalf(t, "the type of sql is not supported", err.Error(), "error message: %s", err.Error()) + + // error: don't support join table now + _, _, _, err = parseDML("update t1 join t2 on t1.id = t2.id set t1.c1 = '123' where t1.id = 1") + assert.Equalf(t, "don't support join table now", err.Error(), "error message: %s", err.Error()) + + // support alias + _, _, _, err = parseDML("update t1 as mytable set mytable.c1 = '123' where mytable.id = 1") + assert.Equal(t, err, nil) + + // error: the SQL should have where clause + _, _, _, err = parseDML("delete from t1") + assert.Equalf(t, "the SQL should have where clause", err.Error(), "error message: %s", err.Error()) + + // error: the SQL should not have limit clause + _, _, _, err = parseDML("delete from t1 where id=1 limit 1") + assert.Equalf(t, "the SQL should not have limit clause", err.Error(), "error message: %s", err.Error()) + + // error: the SQL should not have order clause + _, _, _, err = parseDML("delete from t1 where id=1 order by c1") + assert.Equalf(t, "the SQL should not have order by clause", err.Error(), "error message: %s", err.Error()) +}