From f54cc031aff35fd4240f6649ac30a37b04edc961 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 25 Oct 2024 10:30:58 +0200 Subject: [PATCH] wip --- .../planbuilder/operators/SQL_builder.go | 13 +-- .../vtgate/planbuilder/operators/ast_to_op.go | 9 +- .../planbuilder/operators/expressions.go | 5 +- go/vt/vtgate/planbuilder/plan_test.go | 14 +++ .../plancontext/planning_context.go | 94 +++++++++++++++++-- .../testdata/foreignkey_cases.json | 6 +- .../planbuilder/testdata/from_cases.json | 4 +- .../vtgate/planbuilder/testdata/onecase.json | 75 ++++++++++++++- .../testdata/postprocess_cases.json | 2 +- .../planbuilder/testdata/reference_cases.json | 2 +- .../planbuilder/testdata/select_cases.json | 4 +- .../planbuilder/testdata/tpcc_cases.json | 4 +- 12 files changed, 196 insertions(+), 36 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index cd56fed05b2..b8456b5ec1f 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -88,7 +88,7 @@ func (qb *queryBuilder) addTableExpr( } func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { - if _, toBeSkipped := qb.ctx.SkipPredicates[expr]; toBeSkipped { + if qb.ctx.ShouldSkip(expr) { // This is a predicate that was added to the RHS of an ApplyJoin. // The original predicate will be added, so we don't have to add this here return @@ -566,21 +566,16 @@ func buildProjection(op *Projection, qb *queryBuilder) error { func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error { predicates := slice.Map(op.JoinPredicates, func(jc JoinColumn) sqlparser.Expr { // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done - qb.ctx.SkipPredicates[jc.RHSExpr] = nil - + qb.ctx.SkipJoinPredicates(jc.Original.Expr) return jc.Original.Expr }) + pred := sqlparser.AndExpressions(predicates...) 
err := buildQuery(op.LHS, qb) if err != nil { return err } - // If we are going to add the predicate used in join here - // We should not add the predicate's copy of when it was split into - // two parts. To avoid this, we use the SkipPredicates map. - for _, pred := range op.JoinPredicates { - qb.ctx.SkipPredicates[pred.RHSExpr] = nil - } + qbR := &queryBuilder{ctx: qb.ctx} err = buildQuery(op.RHS, qbR) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index e7628edacc5..328833683a1 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -213,7 +213,14 @@ func createOpFromStmt(ctx *plancontext.PlanningContext, stmt sqlparser.Statement newCtx.VerifyAllFKs = verifyAllFKs newCtx.ParentFKToIgnore = fkToIgnore - return PlanQuery(newCtx, stmt) + query, err := PlanQuery(newCtx, stmt) + if err != nil { + return nil, err + } + + ctx.KeepPredicateInfo(newCtx) + + return query, err } func getOperatorFromTableExpr(ctx *plancontext.PlanningContext, tableExpr sqlparser.TableExpr, onlyTable bool) (ops.Operator, error) { diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 4c03490317e..defe2e506ad 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -52,10 +52,7 @@ func BreakExpressionInLHSandRHS( cursor.Replace(arg) }, nil).(sqlparser.Expr) - if err != nil { - return JoinColumn{}, err - } - ctx.JoinPredicates[expr] = append(ctx.JoinPredicates[expr], rewrittenExpr) + ctx.AddJoinPredicates(expr, rewrittenExpr) col.RHSExpr = rewrittenExpr return } diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 579c271c2ca..7bec5f2b239 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -128,6 +128,20 @@ func (s *planTestSuite) 
TestForeignKeyPlanning() { s.testFile("foreignkey_cases.json", vschemaWrapper, false) } +// TestForeignKeyPlanningOne runs the single debugging case in onecase.json against the foreign-key schema, with operator rewriter debug printing enabled. +func (s *planTestSuite) TestForeignKeyPlanningOne() { + closer := oprewriters.EnableDebugPrinting() + defer closer() + vschema := loadSchema(s.T(), "vschemas/schema.json", true) + s.setFks(vschema) + vschemaWrapper := &vschemawrapper.VSchemaWrapper{ + V: vschema, + TestBuilder: TestBuilder, + } + + s.testFile("onecase.json", vschemaWrapper, false) +} + func (s *planTestSuite) setFks(vschema *vindexes.VSchema) { if vschema.Keyspaces["sharded_fk_allow"] != nil { // FK from multicol_tbl2 referencing multicol_tbl1 that is shard scoped. diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index d090a593a39..d9ebb95ff85 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -19,6 +19,7 @@ package plancontext import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -27,12 +28,16 @@ type PlanningContext struct { SemTable *semantics.SemTable VSchema VSchema - // here we add all predicates that were created because of a join condition - // e.g. [FROM tblA JOIN tblB ON a.colA = b.colB] will be rewritten to [FROM tblB WHERE :a_colA = b.colB], - // if we assume that tblB is on the RHS of the join. This last predicate in the WHERE clause is added to the - // map below - JoinPredicates map[sqlparser.Expr][]sqlparser.Expr - SkipPredicates map[sqlparser.Expr]any + // joinPredicates maps each original join predicate (key) to a slice of + // variations of the RHS predicates (value). 
This map is used to handle + // different scenarios in join planning, where the RHS predicates are + // modified to accommodate dependencies + joinPredicates map[sqlparser.Expr][]sqlparser.Expr + + // skipPredicates tracks predicates that should be skipped, typically when + // a join predicate is reverted to its original form during planning. + skipPredicates map[sqlparser.Expr]any + PlannerVersion querypb.ExecuteOptions_PlannerVersion // If we during planning have turned this expression into an argument name, @@ -79,8 +84,8 @@ func CreatePlanningContext(stmt sqlparser.Statement, ReservedVars: reservedVars, SemTable: semTable, VSchema: vschema, - JoinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, - SkipPredicates: map[sqlparser.Expr]any{}, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, PlannerVersion: version, ReservedArguments: map[sqlparser.Expr]string{}, }, nil @@ -116,3 +121,76 @@ func (ctx *PlanningContext) GetArgumentFor(expr sqlparser.Expr, f func() string) ctx.ReservedArguments[expr] = bvName return bvName } + +// ShouldSkip determines if a given expression should be ignored in the SQL output building. +// It checks against expressions that have been marked to be excluded from further processing. +func (ctx *PlanningContext) ShouldSkip(expr sqlparser.Expr) bool { + for k := range ctx.skipPredicates { + if ctx.SemTable.EqualsExpr(expr, k) { + return true + } + } + return false +} + +// AddJoinPredicates associates additional RHS predicates with an existing join predicate. +// This is used to dynamically adjust the RHS predicates based on evolving join conditions. +func (ctx *PlanningContext) AddJoinPredicates(joinPred sqlparser.Expr, predicates ...sqlparser.Expr) { + fn := func(original sqlparser.Expr, rhsExprs []sqlparser.Expr) { + ctx.joinPredicates[original] = append(rhsExprs, predicates...) 
+ } + if ctx.execOnJoinPredicateEqual(joinPred, fn) { + return + } + + // we didn't find an existing entry + ctx.joinPredicates[joinPred] = predicates +} + +// SkipJoinPredicates marks the predicates related to a specific join predicate as irrelevant +// for the current planning stage. This is used when a join has been pushed under a route and +// the original predicate will be used. +func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error { + fn := func(_ sqlparser.Expr, rhsExprs []sqlparser.Expr) { + ctx.skipThesePredicates(rhsExprs...) + } + if ctx.execOnJoinPredicateEqual(joinPred, fn) { + return nil + } + return vterrors.VT13001("predicate does not exist: " + sqlparser.String(joinPred)) +} + +// KeepPredicateInfo transfers join predicate information from another context. +// This is useful when nesting queries, ensuring consistent predicate handling across contexts. +func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) { + for k, v := range other.joinPredicates { + ctx.AddJoinPredicates(k, v...) 
+ } + for expr := range other.skipPredicates { + ctx.skipThesePredicates(expr) + } +} + +// skipThesePredicates is a utility function to exclude certain predicates from SQL building +func (ctx *PlanningContext) skipThesePredicates(preds ...sqlparser.Expr) { +outer: + for _, expr := range preds { + for k := range ctx.skipPredicates { + if ctx.SemTable.EqualsExpr(expr, k) { + // already skipped + continue outer + } + } + ctx.skipPredicates[expr] = nil + } +} + +func (ctx *PlanningContext) execOnJoinPredicateEqual(joinPred sqlparser.Expr, fn func(original sqlparser.Expr, rhsExprs []sqlparser.Expr)) bool { + for key, values := range ctx.joinPredicates { + if ctx.SemTable.EqualsExpr(joinPred, key) { + fn(key, values) + return true + } + } + return false +} diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index c9c0acb3cc7..02d0b047735 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -1132,7 +1132,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where u_tbl9.col9 is null and (u_tbl8.col8) in ::fkc_vals limit 1 lock in share mode", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where u_tbl9.col9 is null and (u_tbl8.col8) in ::fkc_vals and :u_tbl9_col9 = 'foo' limit 1 lock in share mode", "Table": "u_tbl8, u_tbl9" }, { @@ -1208,7 +1208,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals and 
:u_tbl3_col3 = 'foo' limit 1 lock in share mode", "Table": "u_tbl3, u_tbl4" }, { @@ -1297,7 +1297,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where u_tbl3.col3 is null and (u_tbl4.col4) in ::fkc_vals and :u_tbl3_col3 = :v1 limit 1 lock in share mode", "Table": "u_tbl3, u_tbl4" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 9e668bd68a2..eb5feecf7f6 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -1598,7 +1598,7 @@ "Sharded": true }, "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t, user_extra where 1 != 1", - "Query": "select t.id from (select id from `user` where id = 5 and id = :user_extra_user_id) as t, user_extra where t.id = user_extra.user_id", + "Query": "select t.id from (select id from `user` where id = 5) as t, user_extra where t.id = user_extra.user_id", "Table": "`user`, user_extra", "Values": [ "INT64(5)" @@ -1736,7 +1736,7 @@ "Sharded": true }, "FieldQuery": "select t.id from (select id, textcol1 as baz from `user` as route1 where 1 != 1) as t, (select id, textcol1 + textcol1 as baz from `user` where 1 != 1) as s where 1 != 1", - "Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3' and id = :t_id) as s where t.id = s.id", + "Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3') as s where t.id = s.id", "Table": "`user`" 
}, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..12ac8a0afef 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,9 +1,78 @@ [ { - "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "comment": "Update in a table with shard-scoped foreign keys with cascade that requires a validation of a different parent foreign key", + "query": "update u_tbl6 set col6 = 'foo'", "plan": { - + "QueryType": "UPDATE", + "Original": "update u_tbl6 set col6 = 'foo'", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select u_tbl6.col6 from u_tbl6 where 1 != 1", + "Query": "select u_tbl6.col6 from u_tbl6 for update", + "Table": "u_tbl6" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FKVerify", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Table": "u_tbl8, u_tbl9" + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl8 set col8 = 'foo' where (col8) in ::fkc_vals", + "Table": "u_tbl8" + } + 
] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl6 set col6 = 'foo'", + "Table": "u_tbl6" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl6", + "unsharded_fk_allow.u_tbl8", + "unsharded_fk_allow.u_tbl9" + ] } } ] \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 7f573227e65..a36ad580dc4 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -1052,7 +1052,7 @@ "Sharded": true }, "FieldQuery": "select * from (select user_id from user_extra where 1 != 1) as eu, `user` as u where 1 != 1", - "Query": "select * from (select user_id from user_extra where user_id = 5 and user_id = :u_id) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", + "Query": "select * from (select user_id from user_extra where user_id = 5) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", "Table": "`user`, user_extra", "Values": [ "INT64(5)" diff --git a/go/vt/vtgate/planbuilder/testdata/reference_cases.json b/go/vt/vtgate/planbuilder/testdata/reference_cases.json index 140ce9f7849..e55eed68678 100644 --- a/go/vt/vtgate/planbuilder/testdata/reference_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/reference_cases.json @@ -937,7 +937,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where 1 != 1", - "Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where sr.foo = :ue_foo and rr.bar = sr.bar and u.id = ue.user_id and sr.foo = ue.foo", + "Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where rr.bar = sr.bar and u.id = ue.user_id 
and sr.foo = ue.foo", "Table": "`user`, ref, ref_with_source, user_extra" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 6c255fd9d89..2412fbaadcf 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4110,7 +4110,7 @@ "Sharded": true }, "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", - "Query": "select music.id from (select id from music where music.user_id = 5 and id = :music_id) as other, music where other.id = music.id", + "Query": "select music.id from (select id from music where music.user_id = 5) as other, music where other.id = music.id", "Table": "music", "Values": [ "INT64(5)" @@ -4136,7 +4136,7 @@ "Sharded": true }, "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", - "Query": "select music.id from (select id from music where music.user_id in ::__vals and id = :music_id) as other, music where other.id = music.id", + "Query": "select music.id from (select id from music where music.user_id in ::__vals) as other, music where other.id = music.id", "Table": "music", "Values": [ "(INT64(5), INT64(6), INT64(7))" diff --git a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json index fa823b0ae59..2677deb2cab 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json @@ -13,7 +13,7 @@ "Sharded": true }, "FieldQuery": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where 1 != 1", - "Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where c_w_id = :w_id and c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id", + "Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where 
c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id", "Table": "customer1, warehouse1", "Values": [ "INT64(1)" @@ -947,7 +947,7 @@ "Sharded": true }, "FieldQuery": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where 1 != 1 group by o_c_id, o_d_id, o_w_id) as t, orders1 as o where 1 != 1", - "Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 and o_w_id = :o_o_w_id and o_d_id = :o_o_d_id and o_c_id = :o_o_c_id group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", + "Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", "Table": "orders1", "Values": [ "INT64(1)"