diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index bf810a5a319..9874d368dcd 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -170,6 +170,17 @@ func (shard *Shard) PrimaryTablet() *Vttablet { return shard.Vttablets[0] } +// FindPrimaryTablet finds the primary tablet in the shard. +func (shard *Shard) FindPrimaryTablet() *Vttablet { + for _, vttablet := range shard.Vttablets { + tabletType := vttablet.VttabletProcess.GetTabletType() + if tabletType == "primary" { + return vttablet + } + } + return nil +} + // Rdonly get the last tablet which is rdonly func (shard *Shard) Rdonly() *Vttablet { for idx, tablet := range shard.Vttablets { diff --git a/go/test/endtoend/cluster/reshard.go b/go/test/endtoend/cluster/reshard.go index af36d4543c8..3cec8c14e4d 100644 --- a/go/test/endtoend/cluster/reshard.go +++ b/go/test/endtoend/cluster/reshard.go @@ -95,7 +95,7 @@ func (rw *ReshardWorkflow) WaitForVreplCatchup(timeToWait time.Duration) { if !slices.Contains(targetShards, shard.Name) { continue } - vttablet := shard.PrimaryTablet().VttabletProcess + vttablet := shard.FindPrimaryTablet().VttabletProcess vttablet.WaitForVReplicationToCatchup(rw.t, rw.workflowName, fmt.Sprintf("vt_%s", vttablet.Keyspace), "", timeToWait) } } diff --git a/go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go b/go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go index c908a99e631..75bc46bacab 100644 --- a/go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go +++ b/go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path" + "slices" "strconv" "strings" "sync" @@ -126,7 +127,18 @@ func TestTwoPCFuzzTest(t *testing.T) { fz.start(t) // Wait for the timeForTesting so that the threads continue to run. - time.Sleep(tt.timeForTesting) + timeout := time.After(tt.timeForTesting) + loop := true + for loop { + select { + case <-timeout: + loop = false + case <-time.After(1 * time.Second): + if t.Failed() { + loop = false + } + } + } // Signal the fuzzer to stop. fz.stop() @@ -302,9 +314,11 @@ func (fz *fuzzer) generateAndExecuteTransaction(threadId int) { // for each update set ordered by the auto increment column will not be true. // That assertion depends on all the transactions running updates first to ensure that for any given update set, // no two transactions are running the insert queries. - queries := []string{"begin"} + var queries []string queries = append(queries, fz.generateUpdateQueries(updateSetVal, incrementVal)...) queries = append(queries, fz.generateInsertQueries(updateSetVal, threadId)...) + queries = fz.addRandomSavePoints(queries) + queries = append([]string{"begin"}, queries...) finalCommand := "commit" for _, query := range queries { _, err := conn.ExecuteFetch(query, 0, false) @@ -377,6 +391,45 @@ func (fz *fuzzer) runClusterDisruption(t *testing.T) { } } +// addRandomSavePoints will add random savepoints and queries to the list of queries. +// It still ensures that all the new queries added are rolledback so that the assertions of queries +// don't change. +func (fz *fuzzer) addRandomSavePoints(queries []string) []string { + savePointCount := 1 + for { + shouldAddSavePoint := rand.Intn(2) + if shouldAddSavePoint == 0 { + return queries + } + + savePointQueries := []string{"SAVEPOINT sp" + strconv.Itoa(savePointCount)} + randomDmlCount := rand.Intn(2) + 1 + for i := 0; i < randomDmlCount; i++ { + savePointQueries = append(savePointQueries, fz.randomDML()) + } + savePointQueries = append(savePointQueries, "ROLLBACK TO sp"+strconv.Itoa(savePointCount)) + savePointCount++ + + savePointPosition := rand.Intn(len(queries)) + newQueries := slices.Clone(queries[:savePointPosition]) + newQueries = append(newQueries, savePointQueries...) + newQueries = append(newQueries, queries[savePointPosition:]...) + queries = newQueries + } +} + +// randomDML generates a random DML to be used. +func (fz *fuzzer) randomDML() string { + queryType := rand.Intn(2) + if queryType == 0 { + // Generate INSERT + return fmt.Sprintf(insertIntoFuzzInsert, updateRowBaseVals[rand.Intn(len(updateRowBaseVals))], rand.Intn(fz.updateSets), rand.Intn(fz.threads)) + } + // Generate UPDATE + updateId := fz.updateRowsVals[rand.Intn(len(fz.updateRowsVals))][rand.Intn(len(updateRowBaseVals))] + return fmt.Sprintf(updateFuzzUpdate, rand.Intn(100000), updateId) +} + /* Cluster Level Disruptions for the fuzzer */ diff --git a/go/test/endtoend/transaction/twopc/main_test.go b/go/test/endtoend/transaction/twopc/main_test.go index 9a46562d1c7..eaf835e678e 100644 --- a/go/test/endtoend/transaction/twopc/main_test.go +++ b/go/test/endtoend/transaction/twopc/main_test.go @@ -24,16 +24,19 @@ import ( "fmt" "io" "os" + "strings" "sync" "testing" "time" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/endtoend/cluster" - "vitess.io/vitess/go/test/endtoend/transaction/twopc/utils" + twopcutil "vitess.io/vitess/go/test/endtoend/transaction/twopc/utils" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -42,6 +45,7 @@ import ( var ( clusterInstance *cluster.LocalProcessCluster + mysqlParams mysql.ConnParams vtParams mysql.ConnParams vtgateGrpcAddress string keyspaceName = "ks" @@ -81,6 +85,8 @@ func TestMain(m *testing.M) { "--twopc_enable", "--twopc_abandon_age", "1", "--queryserver-config-transaction-cap", "3", + "--queryserver-config-transaction-timeout", "400s", + "--queryserver-config-query-timeout", "9000s", ) // Start keyspace @@ -102,6 +108,15 @@ func TestMain(m *testing.M) { vtParams = clusterInstance.GetVTParams(keyspaceName) vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort) + // create mysql instance and connection parameters + conn, closer, err := utils.NewMySQL(clusterInstance, keyspaceName, SchemaSQL) + if err != nil { + fmt.Println(err) + return 1 + } + defer closer() + mysqlParams = conn + return m.Run() }() os.Exit(exitcode) @@ -121,8 +136,29 @@ func start(t *testing.T) (*mysql.Conn, func()) { func cleanup(t *testing.T) { cluster.PanicHandler(t) - utils.ClearOutTable(t, vtParams, "twopc_user") - utils.ClearOutTable(t, vtParams, "twopc_t1") + twopcutil.ClearOutTable(t, vtParams, "twopc_user") + twopcutil.ClearOutTable(t, vtParams, "twopc_t1") + sm.reset() +} + +func startWithMySQL(t *testing.T) (utils.MySQLCompare, func()) { + mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams) + require.NoError(t, err) + + deleteAll := func() { + tables := []string{"twopc_user"} + for _, table := range tables { + _, _ = mcmp.ExecAndIgnore("delete from " + table) + } + } + + deleteAll() + + return mcmp, func() { + deleteAll() + mcmp.Close() + cluster.PanicHandler(t) + } } type extractInterestingValues func(dtidMap map[string]string, vals []sqltypes.Value) []sqltypes.Value @@ -147,7 +183,8 @@ var tables = map[string]extractInterestingValues{ }, "ks.redo_statement": func(dtidMap map[string]string, vals []sqltypes.Value) (out []sqltypes.Value) { dtid := getDTID(dtidMap, vals[0].ToString()) - out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1:]...) + stmt := getStatement(vals[2].ToString()) + out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1], sqltypes.TestValue(sqltypes.Blob, stmt)) return }, "ks.twopc_user": func(_ map[string]string, vals []sqltypes.Value) []sqltypes.Value { return vals }, @@ -167,6 +204,28 @@ func getDTID(dtidMap map[string]string, dtKey string) string { return dtid } +func getStatement(stmt string) string { + var sKey string + var prefix string + switch { + case strings.HasPrefix(stmt, "savepoint"): + prefix = "savepoint-" + sKey = stmt[9:] + case strings.HasPrefix(stmt, "rollback to"): + prefix = "rollback-" + sKey = stmt[11:] + default: + return stmt + } + + sid, exists := sm.stmt[sKey] + if !exists { + sid = fmt.Sprintf("%d", len(sm.stmt)+1) + sm.stmt[sKey] = sid + } + return prefix + sid +} + func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, vtgateConn *vtgateconn.VTGateConn) { vgtid := &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{ @@ -272,3 +331,13 @@ func prettyPrint(v interface{}) string { } return string(b) } + +type stmtMapper struct { + stmt map[string]string +} + +var sm = &stmtMapper{stmt: make(map[string]string)} + +func (sm *stmtMapper) reset() { + sm.stmt = make(map[string]string) +} diff --git a/go/test/endtoend/transaction/twopc/twopc_test.go b/go/test/endtoend/transaction/twopc/twopc_test.go index cdb6b61f91a..dda655613de 100644 --- a/go/test/endtoend/transaction/twopc/twopc_test.go +++ b/go/test/endtoend/transaction/twopc/twopc_test.go @@ -47,7 +47,7 @@ func TestDTCommit(t *testing.T) { conn, closer := start(t) defer closer() - vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -183,7 +183,7 @@ func TestDTRollback(t *testing.T) { utils.Exec(t, conn, "insert into twopc_user(id, name) values(7,'foo'), (8,'bar')") // run vstream to stream binlogs - vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -232,7 +232,7 @@ func TestDTCommitDMLOnlyOnMM(t *testing.T) { conn, closer := start(t) defer closer() - vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -327,7 +327,7 @@ func TestDTCommitDMLOnlyOnRM(t *testing.T) { conn, closer := start(t) defer closer() - vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -435,7 +435,7 @@ func TestDTPrepareFailOnRM(t *testing.T) { conn, closer := start(t) defer closer() - vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "") + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -637,6 +637,8 @@ func TestDTResolveAfterMMCommit(t *testing.T) { // TestDTResolveAfterRMPrepare tests that transaction is rolled back on recovery // failure after RM prepare and before MM commit. func TestDTResolveAfterRMPrepare(t *testing.T) { + defer cleanup(t) + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -697,6 +699,8 @@ func TestDTResolveAfterRMPrepare(t *testing.T) { // TestDTResolveDuringRMPrepare tests that transaction is rolled back on recovery // failure after semi RM prepare. func TestDTResolveDuringRMPrepare(t *testing.T) { + defer cleanup(t) + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -844,6 +848,8 @@ func TestDTResolveDuringRMCommit(t *testing.T) { // TestDTResolveAfterTransactionRecord tests that transaction is rolled back on recovery // failure after TR created and before RM prepare. func TestDTResolveAfterTransactionRecord(t *testing.T) { + defer cleanup(t) + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") require.NoError(t, err) defer vtgateConn.Close() @@ -959,6 +965,8 @@ func testWarningAndTransactionStatus(t *testing.T, conn *vtgateconn.VTGateSessio // TestReadingUnresolvedTransactions tests the reading of unresolved transactions func TestReadingUnresolvedTransactions(t *testing.T) { + defer cleanup(t) + testcases := []struct { name string queries []string @@ -1023,6 +1031,322 @@ func TestReadingUnresolvedTransactions(t *testing.T) { } } +// TestDTSavepointWithVanilaMySQL ensures that distributed transactions should work with savepoint as with vanila MySQL +func TestDTSavepointWithVanilaMySQL(t *testing.T) { + mcmp, closer := startWithMySQL(t) + defer closer() + + // internal savepoint + mcmp.Exec("begin") + mcmp.Exec("insert into twopc_user(id, name) values(7,'foo'), (8,'bar')") + mcmp.Exec("commit") + mcmp.Exec("select * from twopc_user order by id") + + // external savepoint, single shard transaction. + mcmp.Exec("begin") + mcmp.Exec("savepoint a") + mcmp.Exec("insert into twopc_user(id, name) values(9,'baz')") + mcmp.Exec("savepoint b") + mcmp.Exec("rollback to b") + mcmp.Exec("commit") + mcmp.Exec("select * from twopc_user order by id") + + // external savepoint, multi-shard transaction. + mcmp.Exec("begin") + mcmp.Exec("savepoint a") + mcmp.Exec("insert into twopc_user(id, name) values(10,'apa')") + mcmp.Exec("savepoint b") + mcmp.Exec("update twopc_user set name = 'temp' where id = 7") + mcmp.Exec("rollback to a") + mcmp.Exec("commit") + mcmp.Exec("select * from twopc_user order by id") + + // external savepoint, multi-shard transaction. + mcmp.Exec("begin") + mcmp.Exec("savepoint a") + mcmp.Exec("insert into twopc_user(id, name) values(10,'apa')") + mcmp.Exec("savepoint b") + mcmp.Exec("update twopc_user set name = 'temp' where id = 7") + mcmp.Exec("rollback to b") + mcmp.Exec("commit") + mcmp.Exec("select * from twopc_user order by id") + + // external savepoint, multi-shard transaction. + mcmp.Exec("begin") + mcmp.Exec("update twopc_user set name = 'temp1' where id = 10") + mcmp.Exec("savepoint b") + mcmp.Exec("update twopc_user set name = 'temp2' where id = 7") + mcmp.Exec("commit") + mcmp.Exec("select * from twopc_user order by id") +} + +// TestDTSavepoint tests distributed transaction should work with savepoint. +func TestDTSavepoint(t *testing.T) { + defer cleanup(t) + + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") + require.NoError(t, err) + defer vtgateConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan *binlogdatapb.VEvent) + runVStream(t, ctx, ch, vtgateConn) + + ss := vtgateConn.Session("", nil) + + // internal savepoint + execute(ctx, t, ss, "begin") + execute(ctx, t, ss, "insert into twopc_user(id, name) values(7,'foo'), (8,'bar')") + execute(ctx, t, ss, "commit") + + tableMap := make(map[string][]*querypb.Field) + dtMap := make(map[string]string) + logTable := retrieveTransitions(t, ch, tableMap, dtMap) + expectations := map[string][]string{ + "ks.dt_participant:40-80": { + "insert:[VARCHAR(\"dtid-1\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + }, + "ks.dt_state:40-80": { + "insert:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + "update:[VARCHAR(\"dtid-1\") VARCHAR(\"COMMIT\")]", + "delete:[VARCHAR(\"dtid-1\") VARCHAR(\"COMMIT\")]", + }, + "ks.redo_state:80-": { + "insert:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + "delete:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + }, + "ks.redo_statement:80-": { + "insert:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (7, 'foo')\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (7, 'foo')\")]", + }, + "ks.twopc_user:40-80": {"insert:[INT64(8) VARCHAR(\"bar\")]"}, + "ks.twopc_user:80-": {"insert:[INT64(7) VARCHAR(\"foo\")]"}, + } + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) + + // external savepoint, single shard transaction. + execute(ctx, t, ss, "begin") + execute(ctx, t, ss, "savepoint a") + execute(ctx, t, ss, "insert into twopc_user(id, name) values(9,'baz')") + execute(ctx, t, ss, "savepoint b") + execute(ctx, t, ss, "rollback to b") + execute(ctx, t, ss, "commit") + + logTable = retrieveTransitions(t, ch, tableMap, dtMap) + expectations = map[string][]string{ + "ks.twopc_user:80-": {"insert:[INT64(9) VARCHAR(\"baz\")]"}} + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) + + // external savepoint, multi-shard transaction - rollback to a savepoint that leaves no change. + execute(ctx, t, ss, "begin") + execute(ctx, t, ss, "savepoint a") + execute(ctx, t, ss, "insert into twopc_user(id, name) values(10,'apa')") + execute(ctx, t, ss, "savepoint b") + execute(ctx, t, ss, "update twopc_user set name = 'temp' where id = 7") + execute(ctx, t, ss, "rollback to a") + execute(ctx, t, ss, "commit") + + logTable = retrieveTransitions(t, ch, tableMap, dtMap) + expectations = map[string][]string{ + "ks.dt_participant:-40": { + "insert:[VARCHAR(\"dtid-2\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + "delete:[VARCHAR(\"dtid-2\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + }, + "ks.dt_state:-40": { + "insert:[VARCHAR(\"dtid-2\") VARCHAR(\"PREPARE\")]", + "update:[VARCHAR(\"dtid-2\") VARCHAR(\"COMMIT\")]", + "delete:[VARCHAR(\"dtid-2\") VARCHAR(\"COMMIT\")]", + }, + } + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) + + // external savepoint, multi-shard transaction - rollback to a savepoint that leaves a change. + execute(ctx, t, ss, "begin") + execute(ctx, t, ss, "savepoint a") + execute(ctx, t, ss, "insert into twopc_user(id, name) values(10,'apa')") + execute(ctx, t, ss, "savepoint b") + execute(ctx, t, ss, "update twopc_user set name = 'temp' where id = 7") + execute(ctx, t, ss, "rollback to b") + execute(ctx, t, ss, "commit") + + logTable = retrieveTransitions(t, ch, tableMap, dtMap) + expectations = map[string][]string{ + "ks.dt_participant:-40": { + "insert:[VARCHAR(\"dtid-3\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + "delete:[VARCHAR(\"dtid-3\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + }, + "ks.dt_state:-40": { + "insert:[VARCHAR(\"dtid-3\") VARCHAR(\"PREPARE\")]", + "update:[VARCHAR(\"dtid-3\") VARCHAR(\"COMMIT\")]", + "delete:[VARCHAR(\"dtid-3\") VARCHAR(\"COMMIT\")]", + }, + "ks.twopc_user:-40": {"insert:[INT64(10) VARCHAR(\"apa\")]"}, + } + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) + + // external savepoint, multi-shard transaction - savepoint added later and rollback to it. + execute(ctx, t, ss, "begin") + execute(ctx, t, ss, "update twopc_user set name = 'temp1' where id = 7") + execute(ctx, t, ss, "savepoint c") + execute(ctx, t, ss, "update twopc_user set name = 'temp2' where id = 8") + execute(ctx, t, ss, "rollback to c") + execute(ctx, t, ss, "commit") + + logTable = retrieveTransitions(t, ch, tableMap, dtMap) + expectations = map[string][]string{ + "ks.dt_participant:40-80": { + "insert:[VARCHAR(\"dtid-4\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + "delete:[VARCHAR(\"dtid-4\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + }, + "ks.dt_state:40-80": { + "insert:[VARCHAR(\"dtid-4\") VARCHAR(\"PREPARE\")]", + "update:[VARCHAR(\"dtid-4\") VARCHAR(\"COMMIT\")]", + "delete:[VARCHAR(\"dtid-4\") VARCHAR(\"COMMIT\")]", + }, + "ks.redo_state:80-": { + "insert:[VARCHAR(\"dtid-4\") VARCHAR(\"PREPARE\")]", + "delete:[VARCHAR(\"dtid-4\") VARCHAR(\"PREPARE\")]", + }, + "ks.redo_statement:80-": { + "insert:[VARCHAR(\"dtid-4\") INT64(1) BLOB(\"update twopc_user set `name` = 'temp1' where id = 7 limit 10001 /* INT64 */\")]", + "delete:[VARCHAR(\"dtid-4\") INT64(1) BLOB(\"update twopc_user set `name` = 'temp1' where id = 7 limit 10001 /* INT64 */\")]", + }, + "ks.twopc_user:80-": {"update:[INT64(7) VARCHAR(\"temp1\")]"}, + } + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) +} + +func execute(ctx context.Context, t *testing.T, ss *vtgateconn.VTGateSession, sql string) { + t.Helper() + + err := executeReturnError(ctx, t, ss, sql) + require.NoError(t, err) +} + +func executeReturnError(ctx context.Context, t *testing.T, ss *vtgateconn.VTGateSession, sql string) error { + t.Helper() + + if sql == "commit" { + // sort by shard + sortShard(ss) + } + _, err := ss.Execute(ctx, sql, nil) + return err +} + +func sortShard(ss *vtgateconn.VTGateSession) { + sort.Slice(ss.SessionPb().ShardSessions, func(i, j int) bool { + return ss.SessionPb().ShardSessions[i].Target.Shard < ss.SessionPb().ShardSessions[j].Target.Shard + }) +} + +// TestDTSavepointResolveAfterMMCommit tests that transaction is committed on recovery +// failure after MM commit involving savepoint. +func TestDTSavepointResolveAfterMMCommit(t *testing.T) { + defer cleanup(t) + + vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "dt_user", "") + require.NoError(t, err) + defer vtgateConn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan *binlogdatapb.VEvent) + runVStream(t, ctx, ch, vtgateConn) + + conn := vtgateConn.Session("", nil) + qCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // initial insert + for i := 1; i <= 100; i++ { + execute(qCtx, t, conn, fmt.Sprintf("insert into twopc_user(id, name) values(%d,'foo')", 10*i)) + } + + // ignore initial change + tableMap := make(map[string][]*querypb.Field) + dtMap := make(map[string]string) + _ = retrieveTransitionsWithTimeout(t, ch, tableMap, dtMap, 2*time.Second) + + // Insert into multiple shards + execute(qCtx, t, conn, "begin") + execute(qCtx, t, conn, "insert into twopc_user(id, name) values(7,'foo'),(8,'bar')") + execute(qCtx, t, conn, "savepoint a") + for i := 1; i <= 100; i++ { + execute(qCtx, t, conn, fmt.Sprintf("insert ignore into twopc_user(id, name) values(%d,'baz')", 12+i)) + } + execute(qCtx, t, conn, "savepoint b") + execute(qCtx, t, conn, "insert into twopc_user(id, name) values(11,'apa')") + execute(qCtx, t, conn, "rollback to a") + + // The caller ID is used to simulate the failure at the desired point. + newCtx := callerid.NewContext(qCtx, callerid.NewEffectiveCallerID("MMCommitted_FailNow", "", ""), nil) + err = executeReturnError(newCtx, t, conn, "commit") + require.ErrorContains(t, err, "Fail After MM commit") + + testWarningAndTransactionStatus(t, conn, + "distributed transaction ID failed during metadata manager commit; transaction will be committed/rollbacked based on the state on recovery", + false, "COMMIT", "ks:40-80,ks:80-") + + // 2nd session to write something on different primary key, this should continue to work. + conn2 := vtgateConn.Session("", nil) + execute(qCtx, t, conn2, "insert into twopc_user(id, name) values(190001,'mysession')") + execute(qCtx, t, conn2, "insert into twopc_user(id, name) values(290001,'mysession')") + + // Below check ensures that the transaction is resolved by the resolver on receiving unresolved transaction signal from MM. + logTable := retrieveTransitionsWithTimeout(t, ch, tableMap, dtMap, 2*time.Second) + expectations := map[string][]string{ + "ks.dt_participant:-40": { + "insert:[VARCHAR(\"dtid-1\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"40-80\")]", + "insert:[VARCHAR(\"dtid-1\") INT64(2) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(1) VARCHAR(\"ks\") VARCHAR(\"40-80\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(2) VARCHAR(\"ks\") VARCHAR(\"80-\")]", + }, + "ks.dt_state:-40": { + "insert:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + "update:[VARCHAR(\"dtid-1\") VARCHAR(\"COMMIT\")]", + "delete:[VARCHAR(\"dtid-1\") VARCHAR(\"COMMIT\")]", + }, + "ks.redo_state:40-80": { + "insert:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + "delete:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + }, + "ks.redo_state:80-": { + "insert:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + "delete:[VARCHAR(\"dtid-1\") VARCHAR(\"PREPARE\")]", + }, + "ks.redo_statement:40-80": { + "insert:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (8, 'bar')\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (8, 'bar')\")]", + }, + "ks.redo_statement:80-": { + "insert:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (7, 'foo')\")]", + "delete:[VARCHAR(\"dtid-1\") INT64(1) BLOB(\"insert into twopc_user(id, `name`) values (7, 'foo')\")]", + }, + "ks.twopc_user:-40": { + "insert:[INT64(290001) VARCHAR(\"mysession\")]", + }, + "ks.twopc_user:40-80": { + "insert:[INT64(190001) VARCHAR(\"mysession\")]", + "insert:[INT64(8) VARCHAR(\"bar\")]", + }, + "ks.twopc_user:80-": { + "insert:[INT64(7) VARCHAR(\"foo\")]", + }, + } + assert.Equal(t, expectations, logTable, + "mismatch expected: \n got: %s, want: %s", prettyPrint(logTable), prettyPrint(expectations)) +} + // TestSemiSyncRequiredWithTwoPC tests that semi-sync is required when using two-phase commit. func TestSemiSyncRequiredWithTwoPC(t *testing.T) { // cleanup all the old data. diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 1cd58b31cc3..7151bf7b834 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -2807,17 +2807,6 @@ func TestExecutorRejectTwoPC(t *testing.T) { "1|2|0"), }, expErr: "VT12001: unsupported: atomic distributed transaction commit with consistent lookup vindex", - }, { - sqls: []string{ - `savepoint x`, - `insert into user_extra(user_id) values (1)`, - `insert into user_extra(user_id) values (3)`, - }, - testRes: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|unq_col|unchanged", "int64|int64|int64"), - "1|2|0"), - }, - expErr: "VT12001: unsupported: atomic distributed transaction commit with savepoint", }, } diff --git a/go/vt/vtgate/tx_conn.go b/go/vt/vtgate/tx_conn.go index 372c3fc6164..f7dc472df2e 100644 --- a/go/vt/vtgate/tx_conn.go +++ b/go/vt/vtgate/tx_conn.go @@ -280,9 +280,6 @@ func (txc *TxConn) checkValidCondition(session *SafeSession) error { if len(session.PreSessions) != 0 || len(session.PostSessions) != 0 { return vterrors.VT12001("atomic distributed transaction commit with consistent lookup vindex") } - if len(session.GetSavepoints()) != 0 { - return vterrors.VT12001("atomic distributed transaction commit with savepoint") - } if session.GetInReservedConn() { return vterrors.VT12001("atomic distributed transaction commit with system settings") } diff --git a/go/vt/vttablet/endtoend/savepoint_test.go b/go/vt/vttablet/endtoend/savepoint_test.go index 74572f2376f..90cf8fbd547 100644 --- a/go/vt/vttablet/endtoend/savepoint_test.go +++ b/go/vt/vttablet/endtoend/savepoint_test.go @@ -103,7 +103,7 @@ func TestSavepointInTransactionWithRelease(t *testing.T) { diff int }{{ tag: "Queries/Histograms/Savepoint/Count", - diff: 1, + diff: 2, // savepoint a (post-begin) and savepoint b }, { tag: "Queries/Histograms/Release/Count", diff: 1, diff --git a/go/vt/vttablet/tabletserver/dt_executor.go b/go/vt/vttablet/tabletserver/dt_executor.go index 1aaf75edc9e..823751df638 100644 --- a/go/vt/vttablet/tabletserver/dt_executor.go +++ b/go/vt/vttablet/tabletserver/dt_executor.go @@ -69,7 +69,8 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { } // If no queries were executed, we just rollback. - if len(conn.TxProperties().Queries) == 0 { + queries := conn.TxProperties().GetQueries() + if len(queries) == 0 { dte.te.txPool.RollbackAndRelease(dte.ctx, conn) return nil } @@ -90,7 +91,7 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { // Fail Prepare if any query rule disallows it. // This could be due to ongoing cutover happening in vreplication workflow // regarding OnlineDDL or MoveTables. - for _, query := range conn.TxProperties().Queries { + for _, query := range queries { qr := dte.qe.queryRuleSources.FilterByPlan(query.Sql, 0, query.Tables...) if qr != nil { act, _, _, _ := qr.GetAction("", "", nil, sqlparser.MarginComments{}) @@ -110,7 +111,7 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { // Recheck the rules. As some prepare transaction could have passed the first check. // If they are put in the prepared pool, then vreplication workflow waits. // This check helps reject the prepare that came later. - for _, query := range conn.TxProperties().Queries { + for _, query := range queries { qr := dte.qe.queryRuleSources.FilterByPlan(query.Sql, 0, query.Tables...) if qr != nil { act, _, _, _ := qr.GetAction("", "", nil, sqlparser.MarginComments{}) @@ -130,7 +131,7 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { } return dte.inTransaction(func(localConn *StatefulConnection) error { - return dte.te.twoPC.SaveRedo(dte.ctx, localConn, dtid, conn.TxProperties().Queries) + return dte.te.twoPC.SaveRedo(dte.ctx, localConn, dtid, queries) }) } @@ -312,7 +313,7 @@ func (dte *DTExecutor) ReadTwopcInflight() (distributed []*tx.DistributedTx, pre } func (dte *DTExecutor) inTransaction(f func(*StatefulConnection) error) error { - conn, _, _, err := dte.te.txPool.Begin(dte.ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := dte.te.txPool.Begin(dte.ctx, &querypb.ExecuteOptions{}, false, 0, nil) if err != nil { return err } diff --git a/go/vt/vttablet/tabletserver/planbuilder/plan.go b/go/vt/vttablet/tabletserver/planbuilder/plan.go index f18ea59a714..db17500ae19 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/plan.go +++ b/go/vt/vttablet/tabletserver/planbuilder/plan.go @@ -205,10 +205,7 @@ func (plan *Plan) TableNames() (names []string) { func Build(env *vtenv.Environment, statement sqlparser.Statement, tables map[string]*schema.Table, dbName string, viewsEnabled bool) (plan *Plan, err error) { switch stmt := statement.(type) { case *sqlparser.Union: - plan, err = &Plan{ - PlanID: PlanSelect, - FullQuery: GenerateLimitQuery(stmt), - }, nil + plan = &Plan{PlanID: PlanSelect, FullQuery: GenerateLimitQuery(stmt)} case *sqlparser.Select: plan, err = analyzeSelect(env, stmt, tables) case *sqlparser.Insert: @@ -218,39 +215,39 @@ func Build(env *vtenv.Environment, statement sqlparser.Statement, tables map[str case *sqlparser.Delete: plan, err = analyzeDelete(stmt, tables) case *sqlparser.Set: - plan, err = analyzeSet(stmt), nil + plan = analyzeSet(stmt) case sqlparser.DDLStatement: plan, err = analyzeDDL(stmt) case *sqlparser.AlterMigration: - plan, err = &Plan{PlanID: PlanAlterMigration, FullStmt: stmt}, nil + plan = &Plan{PlanID: PlanAlterMigration, FullStmt: stmt} case *sqlparser.RevertMigration: - plan, err = &Plan{PlanID: PlanRevertMigration, FullStmt: stmt}, nil + plan = &Plan{PlanID: PlanRevertMigration, FullStmt: stmt} case *sqlparser.ShowMigrationLogs: - plan, err = &Plan{PlanID: PlanShowMigrationLogs, FullStmt: stmt}, nil + plan = &Plan{PlanID: PlanShowMigrationLogs, FullStmt: stmt} case *sqlparser.ShowThrottledApps: - plan, err = &Plan{PlanID: PlanShowThrottledApps, FullStmt: stmt}, nil + plan = &Plan{PlanID: PlanShowThrottledApps, FullStmt: stmt} case *sqlparser.ShowThrottlerStatus: - plan, err = &Plan{PlanID: PlanShowThrottlerStatus, FullStmt: stmt}, nil + plan = &Plan{PlanID: PlanShowThrottlerStatus, FullStmt: stmt} case *sqlparser.Show: plan, err = analyzeShow(stmt, dbName) case *sqlparser.Analyze, sqlparser.Explain: - plan, err = &Plan{PlanID: PlanOtherRead}, nil + plan = &Plan{PlanID: PlanOtherRead} case *sqlparser.OtherAdmin: - plan, err = &Plan{PlanID: PlanOtherAdmin}, nil + plan = &Plan{PlanID: PlanOtherAdmin} case *sqlparser.Savepoint: - plan, err = &Plan{PlanID: PlanSavepoint}, nil + plan = &Plan{PlanID: PlanSavepoint, FullStmt: stmt} case *sqlparser.Release: - plan, err = &Plan{PlanID: PlanRelease}, nil + plan = &Plan{PlanID: PlanRelease} case *sqlparser.SRollback: - plan, err = &Plan{PlanID: PlanSRollback}, nil + plan = &Plan{PlanID: PlanSRollback, FullStmt: stmt} case *sqlparser.Load: - plan, err = &Plan{PlanID: PlanLoad}, nil + plan = &Plan{PlanID: PlanLoad} case *sqlparser.Flush: plan, err = analyzeFlush(stmt, tables) case *sqlparser.UnlockTables: - plan, err = &Plan{PlanID: PlanUnlockTables}, nil + plan = &Plan{PlanID: PlanUnlockTables} case *sqlparser.CallProc: - plan, err = &Plan{PlanID: PlanCallProc, FullQuery: GenerateFullQuery(stmt)}, nil + plan = &Plan{PlanID: PlanCallProc, FullQuery: GenerateFullQuery(stmt)} default: return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "invalid SQL") } diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 1318f2b90ab..d06953b3241 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -235,7 +235,7 @@ func (qre *QueryExecutor) execAutocommit(f func(conn *StatefulConnection) (*sqlt return nil, errTxThrottled } - conn, _, _, err := qre.tsv.te.txPool.Begin(qre.ctx, qre.options, false, 0, nil, qre.setting) + conn, _, _, err := qre.tsv.te.txPool.Begin(qre.ctx, qre.options, false, 0, qre.setting) if err != nil { return nil, err @@ -249,7 +249,7 @@ func (qre *QueryExecutor) execAsTransaction(f func(conn *StatefulConnection) (*s if qre.tsv.txThrottler.Throttle(qre.tsv.getPriorityFromOptions(qre.options), qre.options.GetWorkloadName()) { return nil, errTxThrottled } - conn, beginSQL, _, err := qre.tsv.te.txPool.Begin(qre.ctx, qre.options, false, 0, nil, qre.setting) + conn, beginSQL, _, err := qre.tsv.te.txPool.Begin(qre.ctx, qre.options, false, 0, qre.setting) if err != nil { return nil, err } @@ -287,8 +287,12 @@ func (qre *QueryExecutor) txConnExec(conn *StatefulConnection) (*sqltypes.Result return qre.execDMLLimit(conn) case p.PlanOtherRead, p.PlanOtherAdmin, p.PlanFlush, p.PlanUnlockTables: return qre.execStatefulConn(conn, qre.query, true) - case p.PlanSavepoint, p.PlanRelease, p.PlanSRollback: - return qre.execStatefulConn(conn, qre.query, true) + case p.PlanSavepoint: + return qre.execSavepointQuery(conn, qre.query, qre.plan.FullStmt) + case p.PlanSRollback: + return qre.execRollbackToSavepoint(conn, qre.query, qre.plan.FullStmt) + case p.PlanRelease: + return qre.execTxQuery(conn, qre.query, false) case p.PlanSelect, p.PlanSelectImpossible, p.PlanShow, p.PlanSelectLockFunc: maxrows := qre.getSelectLimit() qre.bindVars["#maxLimit"] = sqltypes.Int64BindVariable(maxrows + 1) @@ -790,6 +794,11 @@ func (qre *QueryExecutor) txFetch(conn *StatefulConnection, record bool) (*sqlty if err != nil { return nil, err } + return qre.execTxQuery(conn, sql, record) +} + +// execTxQuery executes the query provided and record in Tx Property if record is true. +func (qre *QueryExecutor) execTxQuery(conn *StatefulConnection, sql string, record bool) (*sqltypes.Result, error) { qr, err := qre.execStatefulConn(conn, sql, true) if err != nil { return nil, err @@ -801,6 +810,40 @@ func (qre *QueryExecutor) txFetch(conn *StatefulConnection, record bool) (*sqlty return qr, nil } +// execTxQuery executes the query provided and record in Tx Property if record is true. +func (qre *QueryExecutor) execSavepointQuery(conn *StatefulConnection, sql string, ast sqlparser.Statement) (*sqltypes.Result, error) { + qr, err := qre.execStatefulConn(conn, sql, true) + if err != nil { + return nil, err + } + + // Only record successful queries. + sp, ok := ast.(*sqlparser.Savepoint) + if !ok { + return nil, vterrors.VT13001("expected to get a savepoint statement") + } + conn.TxProperties().RecordSavePointDetail(sp.Name.String()) + + return qr, nil +} + +// execTxQuery executes the query provided and record in Tx Property if record is true. +func (qre *QueryExecutor) execRollbackToSavepoint(conn *StatefulConnection, sql string, ast sqlparser.Statement) (*sqltypes.Result, error) { + qr, err := qre.execStatefulConn(conn, sql, true) + if err != nil { + return nil, err + } + + // Only record successful queries. + sp, ok := ast.(*sqlparser.SRollback) + if !ok { + return nil, vterrors.VT13001("expected to get a rollback statement") + } + + _ = conn.TxProperties().RollbackToSavepoint(sp.Name.String()) + return qr, nil +} + func (qre *QueryExecutor) generateFinalSQL(parsedQuery *sqlparser.ParsedQuery, bindVars map[string]*querypb.BindVariable) (string, string, error) { query, err := parsedQuery.GenerateQuery(bindVars, nil) if err != nil { diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 75cd5dec08e..ad65f61cbfc 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -510,7 +510,14 @@ func (tsv *TabletServer) Begin(ctx context.Context, target *querypb.Target, opti return tsv.begin(ctx, target, nil, 0, nil, options) } -func (tsv *TabletServer) begin(ctx context.Context, target *querypb.Target, savepointQueries []string, reservedID int64, settings []string, options *querypb.ExecuteOptions) (state queryservice.TransactionState, err error) { +func (tsv *TabletServer) begin( + ctx context.Context, + target *querypb.Target, + postBeginQueries []string, + reservedID int64, + settings []string, + options *querypb.ExecuteOptions, +) (state queryservice.TransactionState, err error) { state.TabletAlias = tsv.alias err = tsv.execRequest( ctx, tsv.loadQueryTimeoutWithOptions(options), @@ -528,12 +535,43 @@ func (tsv *TabletServer) begin(ctx context.Context, target *querypb.Target, save return err } } - transactionID, beginSQL, sessionStateChanges, err := tsv.te.Begin(ctx, savepointQueries, reservedID, connSetting, options) + transactionID, beginSQL, sessionStateChanges, err := tsv.te.Begin(ctx, reservedID, connSetting, options) state.TransactionID = transactionID state.SessionStateChanges = sessionStateChanges logStats.TransactionID = transactionID logStats.ReservedID = reservedID + if err != nil { + return err + } + + targetType, err := tsv.resolveTargetType(ctx, target) + if err != nil { + return err + } + for _, query := range postBeginQueries { + plan, err := tsv.qe.GetPlan(ctx, logStats, query, true) + if err != nil { + return err + } + + qre := &QueryExecutor{ + ctx: ctx, + query: query, + connID: transactionID, + options: options, + plan: plan, + logStats: logStats, + tsv: tsv, + targetTabletType: targetType, + setting: connSetting, + } + _, err = qre.Execute() + if err != nil { + return err + } + } + // Record the actual statements that were executed in the logStats. // If nothing was actually executed, don't count the operation in // the tablet metrics, and clear out the logStats Method so that @@ -552,7 +590,7 @@ func (tsv *TabletServer) begin(ctx context.Context, target *querypb.Target, save return err }, ) - return state, err + return } func (tsv *TabletServer) getPriorityFromOptions(options *querypb.ExecuteOptions) int { @@ -978,7 +1016,7 @@ func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Targ } // BeginExecute combines Begin and Execute. -func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { +func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Target, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { // Disable hot row protection in case of reserve connection. if tsv.enableHotRowProtection && reservedID == 0 { @@ -991,7 +1029,7 @@ func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Targe } } - state, err := tsv.begin(ctx, target, preQueries, reservedID, nil, options) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, options) if err != nil { return state, nil, err } @@ -1004,14 +1042,14 @@ func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Targe func (tsv *TabletServer) BeginStreamExecute( ctx context.Context, target *querypb.Target, - preQueries []string, + postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (queryservice.TransactionState, error) { - state, err := tsv.begin(ctx, target, preQueries, reservedID, nil, options) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, options) if err != nil { return state, err } @@ -1237,8 +1275,8 @@ func (tsv *TabletServer) VStreamResults(ctx context.Context, target *querypb.Tar } // ReserveBeginExecute implements the QueryService interface -func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { - state, result, err = tsv.beginExecuteWithSettings(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) +func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { + state, result, err = tsv.beginExecuteWithSettings(ctx, target, settings, postBeginQueries, sql, bindVariables, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. // This is specially for get_lock function call from vtgate that needs a reserved connection. @@ -1266,12 +1304,35 @@ func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *queryp return err } defer tsv.stats.QueryTimingsByTabletType.Record(targetType.String(), time.Now()) - connID, sessionStateChanges, err = tsv.te.ReserveBegin(ctx, options, preQueries, postBeginQueries) + connID, sessionStateChanges, err = tsv.te.ReserveBegin(ctx, options, settings) + logStats.TransactionID = connID + logStats.ReservedID = connID if err != nil { return err } - logStats.TransactionID = connID - logStats.ReservedID = connID + + for _, query := range postBeginQueries { + plan, err := tsv.qe.GetPlan(ctx, logStats, query, true) + if err != nil { + return err + } + + qre := &QueryExecutor{ + ctx: ctx, + query: query, + connID: connID, + options: options, + plan: plan, + logStats: logStats, + tsv: tsv, + targetTabletType: targetType, + } + _, err = qre.Execute() + if err != nil { + return err + } + } + return nil }, ) @@ -1292,13 +1353,13 @@ func (tsv *TabletServer) ReserveBeginStreamExecute( ctx context.Context, target *querypb.Target, settings []string, - savepointQueries []string, + postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedTransactionState, err error) { - txState, err := tsv.begin(ctx, target, savepointQueries, 0, settings, options) + txState, err := tsv.begin(ctx, target, postBeginQueries, 0, settings, options) if err != nil { return txToReserveState(txState), err } @@ -1308,9 +1369,9 @@ func (tsv *TabletServer) ReserveBeginStreamExecute( } // ReserveExecute implements the QueryService interface -func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state queryservice.ReservedState, result *sqltypes.Result, err error) { +func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Target, settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state queryservice.ReservedState, result *sqltypes.Result, err error) { - result, err = tsv.executeWithSettings(ctx, target, preQueries, sql, bindVariables, transactionID, options) + result, err = tsv.executeWithSettings(ctx, target, settings, sql, bindVariables, transactionID, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. // This is specially for get_lock function call from vtgate that needs a reserved connection. @@ -1335,7 +1396,7 @@ func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Tar return err } defer tsv.stats.QueryTimingsByTabletType.Record(targetType.String(), time.Now()) - state.ReservedID, err = tsv.te.Reserve(ctx, options, transactionID, preQueries) + state.ReservedID, err = tsv.te.Reserve(ctx, options, transactionID, settings) if err != nil { return err } @@ -1357,14 +1418,14 @@ func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Tar func (tsv *TabletServer) ReserveStreamExecute( ctx context.Context, target *querypb.Target, - preQueries []string, + settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedState, err error) { - return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, preQueries, options, callback) + return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, settings, options, callback) } // Release implements the QueryService interface @@ -1404,8 +1465,8 @@ func (tsv *TabletServer) executeWithSettings(ctx context.Context, target *queryp return tsv.execute(ctx, target, sql, bindVariables, transactionID, 0, settings, options) } -func (tsv *TabletServer) beginExecuteWithSettings(ctx context.Context, target *querypb.Target, settings []string, savepointQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { - txState, err := tsv.begin(ctx, target, savepointQueries, 0, settings, options) +func (tsv *TabletServer) beginExecuteWithSettings(ctx context.Context, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { + txState, err := tsv.begin(ctx, target, postBeginQueries, 0, settings, options) if err != nil { return txToReserveState(txState), nil, err } diff --git a/go/vt/vttablet/tabletserver/tx/api.go b/go/vt/vttablet/tabletserver/tx/api.go index 48a1cc1107a..a7bc4389b89 100644 --- a/go/vt/vttablet/tabletserver/tx/api.go +++ b/go/vt/vttablet/tabletserver/tx/api.go @@ -21,19 +21,19 @@ import ( "strings" "time" + "vitess.io/vitess/go/slice" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" ) type ( // ConnID as type int64 ConnID = int64 - // DTID as type string DTID = string - // EngineStateMachine is used to control the state the transactional engine - // whether new connections and/or transactions are allowed or not. EngineStateMachine interface { @@ -42,10 +42,8 @@ type ( AcceptReadOnly() error StopGently() } - // ReleaseReason as type int ReleaseReason int - // Properties contains all information that is related to the currently running // transaction on the connection Properties struct { @@ -60,12 +58,17 @@ type ( Stats *servenv.TimingsWrapper } -) -type Query struct { - Sql string - Tables []string -} + // Query contains the query and involved tables executed inside transaction. + // A savepoint is represented by having only the Savepoint field set. + // This is used to rollback to a specific savepoint. + // The query log on commit, does not need to store the savepoint. + Query struct { + Savepoint string + Sql string + Tables []string + } +) const ( // TxClose - connection released on close. @@ -130,6 +133,30 @@ func (p *Properties) RecordQueryDetail(query string, tables []string) { }) } +// RecordQueryDetail records the query and tables against this transaction. +func (p *Properties) RecordSavePointDetail(savepoint string) { + if p == nil { + return + } + p.Queries = append(p.Queries, Query{ + Savepoint: savepoint, + }) +} + +func (p *Properties) RollbackToSavepoint(savepoint string) error { + if p == nil { + return nil + } + for i, query := range p.Queries { + if query.Savepoint == savepoint { + p.Queries = p.Queries[:i] + return nil + } + } + + return vterrors.VT13001(fmt.Sprintf("savepoint %s not found", savepoint)) +} + // RecordQuery records the query and extract tables against this transaction. func (p *Properties) RecordQuery(query string, parser *sqlparser.Parser) { if p == nil { @@ -181,3 +208,12 @@ func (p *Properties) String(sanitize bool, parser *sqlparser.Parser) string { printQueries(), ) } + +func (p *Properties) GetQueries() []Query { + if p == nil { + return nil + } + return slice.Filter(p.Queries, func(q Query) bool { + return q.Sql != "" + }) +} diff --git a/go/vt/vttablet/tabletserver/tx/api_test.go b/go/vt/vttablet/tabletserver/tx/api_test.go new file mode 100644 index 00000000000..cefb04c0391 --- /dev/null +++ b/go/vt/vttablet/tabletserver/tx/api_test.go @@ -0,0 +1,62 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tx + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/test/utils" +) + +/* + TestRollbackToSavePointQueryDetails tests the rollback to savepoint query details + +s1 +q1 +s2 +r1 +q2 +q3 +s3 +s4 +q4 +q5 +r2 -- error +r4 +*/ +func TestRollbackToSavePointQueryDetails(t *testing.T) { + p := &Properties{} + p.RecordSavePointDetail("s1") + p.RecordQueryDetail("select 1", nil) + p.RecordSavePointDetail("s2") + require.NoError(t, p.RollbackToSavepoint("s1")) + p.RecordQueryDetail("select 2", nil) + p.RecordQueryDetail("select 3", nil) + p.RecordSavePointDetail("s3") + p.RecordSavePointDetail("s4") + p.RecordQueryDetail("select 4", nil) + p.RecordQueryDetail("select 5", nil) + require.ErrorContains(t, p.RollbackToSavepoint("s2"), "savepoint s2 not found") + require.NoError(t, p.RollbackToSavepoint("s4")) + + utils.MustMatch(t, p.GetQueries(), []Query{ + {Sql: "select 2"}, + {Sql: "select 3"}, + }) +} diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go index 42bec29dfa3..d581fb79ae4 100644 --- a/go/vt/vttablet/tabletserver/tx_engine.go +++ b/go/vt/vttablet/tabletserver/tx_engine.go @@ -270,7 +270,7 @@ func (te *TxEngine) isTxPoolAvailable(addToWaitGroup func(int)) error { // statement(s) used to execute the begin (if any). // // Subsequent statements can access the connection through the transaction id. -func (te *TxEngine) Begin(ctx context.Context, savepointQueries []string, reservedID int64, setting *smartconnpool.Setting, options *querypb.ExecuteOptions) (int64, string, string, error) { +func (te *TxEngine) Begin(ctx context.Context, reservedID int64, setting *smartconnpool.Setting, options *querypb.ExecuteOptions) (int64, string, string, error) { span, ctx := trace.NewSpan(ctx, "TxEngine.Begin") defer span.Finish() @@ -285,7 +285,7 @@ func (te *TxEngine) Begin(ctx context.Context, savepointQueries []string, reserv } defer te.beginRequests.Done() - conn, beginSQL, sessionStateChanges, err := te.txPool.Begin(ctx, options, te.state == AcceptingReadOnly, reservedID, savepointQueries, setting) + conn, beginSQL, sessionStateChanges, err := te.txPool.Begin(ctx, options, te.state == AcceptingReadOnly, reservedID, setting) if err != nil { return 0, "", "", err } @@ -516,7 +516,7 @@ func (te *TxEngine) checkErrorAndMarkFailed(ctx context.Context, dtid string, re // Update the state of the transaction in the redo log. // Retryable Error: Update the message with error message. // Non-retryable Error: Along with message, update the state as RedoStateFailed. - conn, _, _, err := te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) if err != nil { log.Errorf("markFailed: Begin failed for dtid %s: %v", dtid, err) return @@ -608,7 +608,7 @@ func (te *TxEngine) stopTransactionWatcher() { } // ReserveBegin creates a reserved connection, and in it opens a transaction -func (te *TxEngine) ReserveBegin(ctx context.Context, options *querypb.ExecuteOptions, preQueries []string, savepointQueries []string) (int64, string, error) { +func (te *TxEngine) ReserveBegin(ctx context.Context, options *querypb.ExecuteOptions, preQueries []string) (int64, string, error) { span, ctx := trace.NewSpan(ctx, "TxEngine.ReserveBegin") defer span.Finish() err := te.isTxPoolAvailable(te.beginRequests.Add) @@ -622,7 +622,7 @@ func (te *TxEngine) ReserveBegin(ctx context.Context, options *querypb.ExecuteOp return 0, "", err } defer conn.UnlockUpdateTime() - _, sessionStateChanges, err := te.txPool.begin(ctx, options, te.state == AcceptingReadOnly, conn, savepointQueries) + _, sessionStateChanges, err := te.txPool.begin(ctx, options, te.state == AcceptingReadOnly, conn) if err != nil { conn.Close() conn.Release(tx.ConnInitFail) @@ -720,7 +720,7 @@ func (te *TxEngine) beginNewDbaConnection(ctx context.Context) (*StatefulConnect env: te.env, } - _, _, err = te.txPool.begin(ctx, nil, false, sc, nil) + _, _, err = te.txPool.begin(ctx, nil, false, sc) return sc, err } diff --git a/go/vt/vttablet/tabletserver/tx_engine_test.go b/go/vt/vttablet/tabletserver/tx_engine_test.go index a9958525587..be2531f1a41 100644 --- a/go/vt/vttablet/tabletserver/tx_engine_test.go +++ b/go/vt/vttablet/tabletserver/tx_engine_test.go @@ -62,11 +62,11 @@ func TestTxEngineClose(t *testing.T) { // Normal close with timeout wait. te.AcceptReadWrite() - c, beginSQL, _, err := te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, beginSQL, _, err := te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) require.Equal(t, "begin", beginSQL) c.Unlock() - c, beginSQL, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, beginSQL, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) require.Equal(t, "begin", beginSQL) c.Unlock() @@ -78,7 +78,7 @@ func TestTxEngineClose(t *testing.T) { // Immediate close. te.AcceptReadOnly() - c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) if err != nil { t.Fatal(err) } @@ -90,7 +90,7 @@ func TestTxEngineClose(t *testing.T) { // Normal close with short grace period. te.shutdownGracePeriod = 25 * time.Millisecond te.AcceptReadWrite() - c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) c.Unlock() start = time.Now() @@ -101,7 +101,7 @@ func TestTxEngineClose(t *testing.T) { // Normal close with short grace period, but pool gets empty early. te.shutdownGracePeriod = 25 * time.Millisecond te.AcceptReadWrite() - c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) c.Unlock() go func() { @@ -117,7 +117,7 @@ func TestTxEngineClose(t *testing.T) { // Immediate close, but connection is in use. te.AcceptReadOnly() - c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + c, _, _, err = te.txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) go func() { time.Sleep(100 * time.Millisecond) @@ -138,7 +138,7 @@ func TestTxEngineClose(t *testing.T) { te.AcceptReadWrite() _, err = te.Reserve(ctx, &querypb.ExecuteOptions{}, 0, nil) require.NoError(t, err) - _, _, err = te.ReserveBegin(ctx, &querypb.ExecuteOptions{}, nil, nil) + _, _, err = te.ReserveBegin(ctx, &querypb.ExecuteOptions{}, nil) require.NoError(t, err) start = time.Now() te.Close() @@ -159,11 +159,11 @@ func TestTxEngineBegin(t *testing.T) { for _, exec := range []func() (int64, string, error){ func() (int64, string, error) { - tx, _, schemaStateChanges, err := te.Begin(ctx, nil, 0, nil, &querypb.ExecuteOptions{}) + tx, _, schemaStateChanges, err := te.Begin(ctx, 0, nil, &querypb.ExecuteOptions{}) return tx, schemaStateChanges, err }, func() (int64, string, error) { - return te.ReserveBegin(ctx, &querypb.ExecuteOptions{}, nil, nil) + return te.ReserveBegin(ctx, &querypb.ExecuteOptions{}, nil) }, } { te.AcceptReadOnly() @@ -204,7 +204,7 @@ func TestTxEngineRenewFails(t *testing.T) { te := NewTxEngine(tabletenv.NewEnv(vtenv.NewTestEnv(), cfg, "TabletServerTest"), nil) te.AcceptReadOnly() options := &querypb.ExecuteOptions{} - connID, _, err := te.ReserveBegin(ctx, options, nil, nil) + connID, _, err := te.ReserveBegin(ctx, options, nil) require.NoError(t, err) conn, err := te.txPool.GetAndLock(connID, "for test") @@ -559,7 +559,7 @@ func startTx(te *TxEngine, writeTransaction bool) error { } else { options.TransactionIsolation = querypb.ExecuteOptions_CONSISTENT_SNAPSHOT_READ_ONLY } - _, _, _, err := te.Begin(context.Background(), nil, 0, nil, options) + _, _, _, err := te.Begin(context.Background(), 0, nil, options) return err } @@ -577,7 +577,7 @@ func TestTxEngineFailReserve(t *testing.T) { _, err := te.Reserve(ctx, options, 0, nil) assert.EqualError(t, err, "tx engine can't accept new connections in state NotServing") - _, _, err = te.ReserveBegin(ctx, options, nil, nil) + _, _, err = te.ReserveBegin(ctx, options, nil) assert.EqualError(t, err, "tx engine can't accept new connections in state NotServing") te.AcceptReadOnly() @@ -586,14 +586,14 @@ func TestTxEngineFailReserve(t *testing.T) { _, err = te.Reserve(ctx, options, 0, []string{"dummy_query"}) assert.EqualError(t, err, "unknown error: failed executing dummy_query (errno 1105) (sqlstate HY000) during query: dummy_query") - _, _, err = te.ReserveBegin(ctx, options, []string{"dummy_query"}, nil) + _, _, err = te.ReserveBegin(ctx, options, []string{"dummy_query"}) assert.EqualError(t, err, "unknown error: failed executing dummy_query (errno 1105) (sqlstate HY000) during query: dummy_query") nonExistingID := int64(42) _, err = te.Reserve(ctx, options, nonExistingID, nil) assert.EqualError(t, err, "transaction 42: not found (potential transaction timeout)") - txID, _, _, err := te.Begin(ctx, nil, 0, nil, options) + txID, _, _, err := te.Begin(ctx, 0, nil, options) require.NoError(t, err) conn, err := te.txPool.GetAndLock(txID, "for test") require.NoError(t, err) diff --git a/go/vt/vttablet/tabletserver/tx_pool.go b/go/vt/vttablet/tabletserver/tx_pool.go index 52f356e0cca..6d1f1dec3c2 100644 --- a/go/vt/vttablet/tabletserver/tx_pool.go +++ b/go/vt/vttablet/tabletserver/tx_pool.go @@ -230,7 +230,7 @@ func (tp *TxPool) Rollback(ctx context.Context, txConn *StatefulConnection) erro // the statements (if any) executed to initiate the transaction. In autocommit // mode the statement will be "". // The connection returned is locked for the callee and its responsibility is to unlock the connection. -func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, reservedID int64, savepointQueries []string, setting *smartconnpool.Setting) (*StatefulConnection, string, string, error) { +func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, reservedID int64, setting *smartconnpool.Setting) (*StatefulConnection, string, string, error) { span, ctx := trace.NewSpan(ctx, "TxPool.Begin") defer span.Finish() @@ -262,7 +262,7 @@ func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, re if err != nil { return nil, "", "", err } - sql, sessionStateChanges, err := tp.begin(ctx, options, readOnly, conn, savepointQueries) + sql, sessionStateChanges, err := tp.begin(ctx, options, readOnly, conn) if err != nil { conn.Close() conn.Release(tx.ConnInitFail) @@ -271,16 +271,14 @@ func (tp *TxPool) Begin(ctx context.Context, options *querypb.ExecuteOptions, re return conn, sql, sessionStateChanges, nil } -func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, conn *StatefulConnection, savepointQueries []string) (string, string, error) { +func (tp *TxPool) begin(ctx context.Context, options *querypb.ExecuteOptions, readOnly bool, conn *StatefulConnection) (string, string, error) { immediateCaller := callerid.ImmediateCallerIDFromContext(ctx) effectiveCaller := callerid.EffectiveCallerIDFromContext(ctx) - beginQueries, autocommit, sessionStateChanges, err := createTransaction(ctx, options, conn, readOnly, savepointQueries) + beginQueries, autocommit, sessionStateChanges, err := createTransaction(ctx, options, conn, readOnly) if err != nil { return "", "", err } - conn.txProps = tp.NewTxProps(immediateCaller, effectiveCaller, autocommit) - return beginQueries, sessionStateChanges, nil } @@ -306,7 +304,6 @@ func createTransaction( options *querypb.ExecuteOptions, conn *StatefulConnection, readOnly bool, - savepointQueries []string, ) (beginQueries string, autocommitTransaction bool, sessionStateChanges string, err error) { switch options.GetTransactionIsolation() { case querypb.ExecuteOptions_CONSISTENT_SNAPSHOT_READ_ONLY: @@ -344,12 +341,6 @@ func createTransaction( default: return "", false, "", vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] don't know how to open a transaction of this type: %v", options.GetTransactionIsolation()) } - - for _, savepoint := range savepointQueries { - if _, err = conn.Exec(ctx, savepoint, 1, false); err != nil { - return "", false, "", err - } - } return } diff --git a/go/vt/vttablet/tabletserver/tx_pool_test.go b/go/vt/vttablet/tabletserver/tx_pool_test.go index e80f1edb17f..c03cac92878 100644 --- a/go/vt/vttablet/tabletserver/tx_pool_test.go +++ b/go/vt/vttablet/tabletserver/tx_pool_test.go @@ -48,7 +48,7 @@ func TestTxPoolExecuteCommit(t *testing.T) { sql := "select 'this is a query'" // begin a transaction and then return the connection - conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) id := conn.ReservedID() @@ -83,7 +83,7 @@ func TestTxPoolExecuteRollback(t *testing.T) { db, txPool, _, closer := setup(t) defer closer() - conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) defer conn.Release(tx.TxRollback) @@ -104,7 +104,7 @@ func TestTxPoolExecuteRollbackOnClosedConn(t *testing.T) { db, txPool, _, closer := setup(t) defer closer() - conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) defer conn.Release(tx.TxRollback) @@ -125,9 +125,9 @@ func TestTxPoolRollbackNonBusy(t *testing.T) { defer closer() // start two transactions, and mark one of them as unused - conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) - conn2, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn2, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) conn2.Unlock() // this marks conn2 as NonBusy @@ -154,7 +154,7 @@ func TestTxPoolTransactionIsolation(t *testing.T) { db, txPool, _, closer := setup(t) defer closer() - c2, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{TransactionIsolation: querypb.ExecuteOptions_READ_COMMITTED}, false, 0, nil, nil) + c2, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{TransactionIsolation: querypb.ExecuteOptions_READ_COMMITTED}, false, 0, nil) require.NoError(t, err) c2.Release(tx.TxClose) @@ -172,7 +172,7 @@ func TestTxPoolAutocommit(t *testing.T) { // to mysql. // This test is meaningful because if txPool.Begin were to send a BEGIN statement to the connection, it will fatal // because is not in the list of expected queries (i.e db.AddQuery hasn't been called). - conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{TransactionIsolation: querypb.ExecuteOptions_AUTOCOMMIT}, false, 0, nil, nil) + conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{TransactionIsolation: querypb.ExecuteOptions_AUTOCOMMIT}, false, 0, nil) require.NoError(t, err) // run a query to see it in the query log @@ -204,7 +204,7 @@ func TestTxPoolBeginWithPoolConnectionError_Errno2006_Transient(t *testing.T) { err := db.WaitForClose(2 * time.Second) require.NoError(t, err) - txConn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + txConn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err, "Begin should have succeeded after the retry in DBConn.Exec()") txConn.Release(tx.TxCommit) } @@ -225,7 +225,7 @@ func primeTxPoolWithConnection(t *testing.T, ctx context.Context) (*fakesqldb.DB // reused by subsequent transactions. db.AddQuery("begin", &sqltypes.Result{}) db.AddQuery("rollback", &sqltypes.Result{}) - txConn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + txConn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) txConn.Release(tx.TxCommit) @@ -248,7 +248,7 @@ func TestTxPoolBeginWithError(t *testing.T) { } ctxWithCallerID := callerid.NewContext(ctx, ef, im) - _, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil, nil) + _, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil) require.Error(t, err) require.Contains(t, err.Error(), "error: rejected") require.Equal(t, vtrpcpb.Code_UNKNOWN, vterrors.Code(err), "wrong error code for Begin error") @@ -270,19 +270,6 @@ func TestTxPoolBeginWithError(t *testing.T) { }, limiter.Actions()) } -func TestTxPoolBeginWithPreQueryError(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - db, txPool, _, closer := setup(t) - defer closer() - db.AddRejectedQuery("pre_query", errRejected) - _, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, []string{"pre_query"}, nil) - require.Error(t, err) - require.Contains(t, err.Error(), "error: rejected") - require.Equal(t, vtrpcpb.Code_UNKNOWN, vterrors.Code(err), "wrong error code for Begin error") -} - func TestTxPoolCancelledContextError(t *testing.T) { // given db, txPool, _, closer := setup(t) @@ -291,7 +278,7 @@ func TestTxPoolCancelledContextError(t *testing.T) { cancel() // when - _, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + _, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) // then require.Error(t, err) @@ -312,12 +299,12 @@ func TestTxPoolWaitTimeoutError(t *testing.T) { defer closer() // lock the only connection in the pool. - conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) defer conn.Unlock() // try locking one more connection. - _, _, _, err = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + _, _, _, err = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) // then require.Error(t, err) @@ -337,7 +324,7 @@ func TestTxPoolRollbackFailIsPassedThrough(t *testing.T) { defer closer() db.AddRejectedQuery("rollback", errRejected) - conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn1, _, _, err := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) _, err = conn1.Exec(ctx, sql, 1, true) @@ -357,7 +344,7 @@ func TestTxPoolGetConnRecentlyRemovedTransaction(t *testing.T) { db, txPool, _, _ := setup(t) defer db.Close() - conn1, _, _, _ := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn1, _, _, _ := txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) id := conn1.ReservedID() conn1.Unlock() txPool.Close() @@ -380,7 +367,7 @@ func TestTxPoolGetConnRecentlyRemovedTransaction(t *testing.T) { params := dbconfigs.New(db.ConnParams()) txPool.Open(params, params, params) - conn1, _, _, _ = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn1, _, _, _ = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) id = conn1.ReservedID() _, err := txPool.Commit(ctx, conn1) require.NoError(t, err) @@ -396,7 +383,7 @@ func TestTxPoolGetConnRecentlyRemovedTransaction(t *testing.T) { txPool.Open(params, params, params) defer txPool.Close() - conn1, _, _, err = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn1, _, _, err = txPool.Begin(ctx, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err, "unable to start transaction: %v", err) conn1.Unlock() id = conn1.ReservedID() @@ -412,7 +399,7 @@ func TestTxPoolCloseKillsStrayTransactions(t *testing.T) { startingStray := txPool.env.Stats().InternalErrors.Counts()["StrayTransactions"] // Start stray transaction. - conn, _, _, err := txPool.Begin(context.Background(), &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(context.Background(), &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) conn.Unlock() @@ -443,7 +430,7 @@ func TestTxTimeoutKillsTransactions(t *testing.T) { ctxWithCallerID := callerid.NewContext(ctx, ef, im) // Start transaction. - conn, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) conn.Unlock() @@ -491,7 +478,7 @@ func TestTxTimeoutDoesNotKillShortLivedTransactions(t *testing.T) { ctxWithCallerID := callerid.NewContext(ctx, ef, im) // Start transaction. - conn, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) conn.Unlock() @@ -526,7 +513,7 @@ func TestTxTimeoutKillsOlapTransactions(t *testing.T) { // Start transaction. conn, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{ Workload: querypb.ExecuteOptions_OLAP, - }, false, 0, nil, nil) + }, false, 0, nil) require.NoError(t, err) conn.Unlock() @@ -561,11 +548,11 @@ func TestTxTimeoutNotEnforcedForZeroLengthTimeouts(t *testing.T) { ctxWithCallerID := callerid.NewContext(ctx, ef, im) // Start transactions. - conn0, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil, nil) + conn0, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, 0, nil) require.NoError(t, err) conn1, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{ Workload: querypb.ExecuteOptions_OLAP, - }, false, 0, nil, nil) + }, false, 0, nil) require.NoError(t, err) conn0.Unlock() conn1.Unlock() @@ -606,7 +593,7 @@ func TestTxTimeoutReservedConn(t *testing.T) { // Start OLAP transaction and return it to pool right away. conn0, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{ Workload: querypb.ExecuteOptions_OLAP, - }, false, 0, nil, nil) + }, false, 0, nil) require.NoError(t, err) // Taint the connection. conn0.Taint(ctxWithCallerID, nil) @@ -648,14 +635,14 @@ func TestTxTimeoutReusedReservedConn(t *testing.T) { // Start OLAP transaction and return it to pool right away. conn0, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{ Workload: querypb.ExecuteOptions_OLAP, - }, false, 0, nil, nil) + }, false, 0, nil) require.NoError(t, err) // Taint the connection. conn0.Taint(ctxWithCallerID, nil) conn0.Unlock() // Reuse underlying connection in an OLTP transaction. - conn1, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, conn0.ReservedID(), nil, nil) + conn1, _, _, err := txPool.Begin(ctxWithCallerID, &querypb.ExecuteOptions{}, false, conn0.ReservedID(), nil) require.NoError(t, err) require.Equal(t, conn1.ReservedID(), conn0.ReservedID()) conn1.Unlock() @@ -786,7 +773,7 @@ func TestTxPoolBeginStatements(t *testing.T) { TransactionIsolation: tc.txIsolationLevel, TransactionAccessMode: tc.txAccessModes, } - conn, beginSQL, _, err := txPool.Begin(ctx, options, tc.readOnly, 0, nil, nil) + conn, beginSQL, _, err := txPool.Begin(ctx, options, tc.readOnly, 0, nil) if tc.expErr != "" { require.Error(t, err) require.Contains(t, err.Error(), tc.expErr)