diff --git a/endtoend/framework/onlineddl_utils.go b/endtoend/framework/onlineddl_utils.go deleted file mode 100644 index 83af4753b6..0000000000 --- a/endtoend/framework/onlineddl_utils.go +++ /dev/null @@ -1 +0,0 @@ -package framework diff --git a/endtoend/framework/onlineddl_utils/onlineddl_utils.go b/endtoend/framework/onlineddl_utils/onlineddl_utils.go new file mode 100644 index 0000000000..144842f3d5 --- /dev/null +++ b/endtoend/framework/onlineddl_utils/onlineddl_utils.go @@ -0,0 +1,47 @@ +package onlineddl_utils + +import ( + "context" + "database/sql" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func VtgateExecDDL(t *testing.T, db *sql.DB, ddlStrategy string, ddl string, expectError string) string { + t.Helper() + + ctx := context.Background() + conn, err := db.Conn(ctx) + assert.NoError(t, err) + defer conn.Close() + + // Read original DDL strategy + var originalStrategy string + err = conn.QueryRowContext(ctx, "SELECT @@ddl_strategy").Scan(&originalStrategy) + assert.NoError(t, err) + + // Set new DDL strategy + _, err = conn.ExecContext(ctx, fmt.Sprintf("SET @@ddl_strategy='%s'", ddlStrategy)) + assert.NoError(t, err) + + // Ensure strategy is reset after execution + defer func() { + _, err := conn.ExecContext(ctx, fmt.Sprintf("SET @@ddl_strategy='%s'", originalStrategy)) + assert.NoError(t, err) + }() + + // Execute DDL + var uuid string + err = conn.QueryRowContext(ctx, ddl).Scan(&uuid) + + // Handle expected errors + if expectError == "" { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.Contains(t, err.Error(), expectError) + } + + return uuid +} diff --git a/endtoend/framework/sql_utils.go b/endtoend/framework/sql_utils.go index f336bc9bd8..57691c3f33 100644 --- a/endtoend/framework/sql_utils.go +++ b/endtoend/framework/sql_utils.go @@ -20,19 +20,26 @@ func ExecNoError(t *testing.T, db *sql.DB, sql string, args ...any) { assert.NoError(t, err) } -func QueryNoError(t *testing.T, db *sql.DB, sql string, args ...any) *sql.Rows { +func ExecWithErrorContains(t *testing.T, db *sql.DB, contains string, sql string, args ...any) { + t.Helper() + log.Println(sql) + _, err := db.Exec(sql, args...) + assert.ErrorContains(t, err, contains) +} + +func Query(t *testing.T, db *sql.DB, sql string, args ...any) (*sql.Rows, error) { t.Helper() log.Println(sql) rows, err := db.Query(sql, args...) - assert.NoError(t, err) - return rows + return rows, err } -func ExecWithErrorContains(t *testing.T, db *sql.DB, contains string, sql string, args ...any) { +func QueryNoError(t *testing.T, db *sql.DB, sql string, args ...any) *sql.Rows { t.Helper() log.Println(sql) - _, err := db.Exec(sql, args...) - assert.ErrorContains(t, err, contains) + rows, err := db.Query(sql, args...) + assert.NoError(t, err) + return rows } func QueryWithErrorContains(t *testing.T, db *sql.DB, contains string, sql string, args ...any) { diff --git a/go/test/endtoend/onlineddl/scheduler/onlineddl_scheduler_test.go b/go/test/endtoend/onlineddl/scheduler/onlineddl_scheduler_test.go index f5eed2a6de..5c05d36046 100644 --- a/go/test/endtoend/onlineddl/scheduler/onlineddl_scheduler_test.go +++ b/go/test/endtoend/onlineddl/scheduler/onlineddl_scheduler_test.go @@ -82,7 +82,6 @@ var ( DBName = "test" cell = "zone1" schemaChangeDirectory = "" - overrideVtctlParams *cluster.VtctlClientParams ) func parseTableName(t *testing.T, sql string) (tableName string) { @@ -259,18 +258,6 @@ func checkArtifactsOfMigration(t *testing.T, uuid, expectArtifacts string) { assert.Equal(t, expectArtifacts, actualArtifacts) } -func checkStateOfVreplication(t *testing.T, uuid, expectState string) { - query, err := sqlparser.ParseAndBind("select state from mysql.vreplication where workflow=%a", - sqltypes.StringBindVariable(uuid), - ) - require.NoError(t, err) - rs := onlineddl.VtgateExecQuery(t, &vtParams, query, "") - require.NotNil(t, rs) - - assert.Equal(t, 1, len(rs.Named().Rows)) - assert.Equal(t, expectState, rs.Named().Rows[0].AsString("state", "")) -} - func WaitForVreplicationState(t *testing.T, vtParams *mysql.ConnParams, uuid string, timeout time.Duration, expectStates ...string) string { query, err := sqlparser.ParseAndBind("select state from mysql.vreplication where workflow=%a", sqltypes.StringBindVariable(uuid),