From bd6d67f870a5b74c8e80c0df03485a663c61df8a Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 15 Jan 2025 09:35:59 -0600 Subject: [PATCH] Enhance ValuesJoin engine and improve unit test Signed-off-by: Florent Poinsard --- go/vt/vtgate/engine/cached_size.go | 34 +++---- go/vt/vtgate/engine/fake_primitive_test.go | 13 ++- go/vt/vtgate/engine/fake_vcursor_test.go | 47 +++++++++- go/vt/vtgate/engine/join.go | 4 +- go/vt/vtgate/engine/join_values.go | 41 ++++++++- go/vt/vtgate/engine/join_values_test.go | 101 +++++++++++++++++++++ 6 files changed, 205 insertions(+), 35 deletions(-) create mode 100644 go/vt/vtgate/engine/join_values_test.go diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index d176b99a839..4aec4b70ecc 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -1475,15 +1475,13 @@ func (cached *VStream) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Position))) return size } - -//go:nocheckptr func (cached *ValuesJoin) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) } size := int64(0) if alloc { - size += int64(80) + size += int64(128) } // field Left vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Left.(cachedObject); ok { @@ -1493,29 +1491,23 @@ func (cached *ValuesJoin) CachedSize(alloc bool) int64 { if cc, ok := cached.Right.(cachedObject); ok { size += cc.CachedSize(true) } - // field Vars map[string]int - if cached.Vars != nil { - size += int64(48) - hmap := reflect.ValueOf(cached.Vars) - numBuckets := int(math.Pow(2, float64((*(*uint8)(unsafe.Pointer(hmap.Pointer() + uintptr(9))))))) - numOldBuckets := (*(*uint16)(unsafe.Pointer(hmap.Pointer() + uintptr(10)))) - size += hack.RuntimeAllocSize(int64(numOldBuckets * 208)) - if len(cached.Vars) > 0 || numBuckets > 1 { - size += hack.RuntimeAllocSize(int64(numBuckets * 208)) - } - for k := range cached.Vars { - size += hack.RuntimeAllocSize(int64(len(k))) - } + // field Vars []int + { + size += hack.RuntimeAllocSize(int64(cap(cached.Vars)) * int64(8)) + } + // field RowConstructorArg string + size += hack.RuntimeAllocSize(int64(len(cached.RowConstructorArg))) + // field Cols []int + { + size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(8)) } - // field Columns []string + // field ColNames []string { - size += hack.RuntimeAllocSize(int64(cap(cached.Columns)) * int64(16)) - for _, elem := range cached.Columns { + size += hack.RuntimeAllocSize(int64(cap(cached.ColNames)) * int64(16)) + for _, elem := range cached.ColNames { size += hack.RuntimeAllocSize(int64(len(elem))) } } - // field RowConstructorArg string - size += hack.RuntimeAllocSize(int64(len(cached.RowConstructorArg))) return size } func (cached *Verify) CachedSize(alloc bool) int64 { diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index f3ab5ad5336..bddbca87664 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -46,6 +46,8 @@ type fakePrimitive struct { allResultsInOneCall bool async bool + + useNewPrintBindVars bool } func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) { @@ -72,7 +74,12 @@ func (f *fakePrimitive) GetTableName() string { } func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) + if f.useNewPrintBindVars { + f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) + } else { + f.log = append(f.log, fmt.Sprintf("Execute %v %v", deprecatedPrintBindVars(bindVars), wantfields)) + } + if f.results == nil { return nil, f.sendErr } @@ -87,7 +94,7 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { if !f.noLog { - f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", deprecatedPrintBindVars(bindVars), wantfields)) } if f.results == nil { return f.sendErr @@ -171,7 +178,7 @@ func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error { } func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars))) + f.log = append(f.log, fmt.Sprintf("GetFields %v", deprecatedPrintBindVars(bindVars))) return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */) } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index aac3e9b584c..3ac62ddffd9 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -597,7 +597,7 @@ func (f *loggingVCursor) Execute(ctx context.Context, method string, query strin case vtgatepb.CommitOrder_AUTOCOMMIT: name = "ExecuteAutocommit" } - f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, printBindVars(bindvars), rollbackOnError)) + f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, deprecatedPrintBindVars(bindvars), rollbackOnError)) return f.nextResult() } @@ -621,7 +621,7 @@ func (f *loggingVCursor) AutocommitApproval() bool { } func (f *loggingVCursor) ExecuteStandalone(ctx context.Context, _ Primitive, query string, bindvars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard, fetchLastInsertID bool) (*sqltypes.Result, error) { - f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, printBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard)) + f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, deprecatedPrintBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard)) return f.nextResult() } @@ -943,6 +943,24 @@ func expectResultAnyOrder(t *testing.T, result, want *sqltypes.Result) { } } +// deprecatedPrintBindVars does not print bind variables, specifically tuples, correctly. +// We should use printBindVars instead. +func deprecatedPrintBindVars(bindvars map[string]*querypb.BindVariable) string { + var keys []string + for k := range bindvars { + keys = append(keys, k) + } + sort.Strings(keys) + buf := &bytes.Buffer{} + for i, k := range keys { + if i > 0 { + fmt.Fprintf(buf, " ") + } + fmt.Fprintf(buf, "%s: %v", k, bindvars[k]) + } + return buf.String() +} + func printBindVars(bindvars map[string]*querypb.BindVariable) string { var keys []string for k := range bindvars { @@ -954,6 +972,27 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string { if i > 0 { fmt.Fprintf(buf, " ") } + + if bindvars[k].Type == querypb.Type_TUPLE { + fmt.Fprintf(buf, "%s: [", k) + for _, val := range bindvars[k].Values { + if val.Type != querypb.Type_TUPLE { + fmt.Fprintf(buf, "[%s]", val.String()) + continue + } + var s []string + v := sqltypes.ProtoToValue(val) + err := v.ForEachValue(func(bv sqltypes.Value) { + s = append(s, bv.String()) + }) + if err != nil { + panic(err) + } + fmt.Fprintf(buf, "[%s]", strings.Join(s, " ")) + } + fmt.Fprintf(buf, "]") + continue + } fmt.Fprintf(buf, "%s: %v", k, bindvars[k]) } return buf.String() @@ -962,7 +1001,7 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string { func printResolvedShardQueries(rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery) string { buf := &bytes.Buffer{} for i, rs := range rss { - fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, printBindVars(queries[i].BindVariables)) + fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, deprecatedPrintBindVars(queries[i].BindVariables)) } return buf.String() } @@ -970,7 +1009,7 @@ func printResolvedShardQueries(rss []*srvtopo.ResolvedShard, queries []*querypb. func printResolvedShardsBindVars(rss []*srvtopo.ResolvedShard, bvs []map[string]*querypb.BindVariable) string { buf := &bytes.Buffer{} for i, rs := range rss { - fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, printBindVars(bvs[i])) + fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, deprecatedPrintBindVars(bvs[i])) } return buf.String() } diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 51976396cba..8134d78ff4a 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -220,10 +220,10 @@ func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field fields := make([]*querypb.Field, len(cols)) for i, index := range cols { if index < 0 { - fields[i] = lfields[-index-1] + fields[i] = lfields[-index-1].CloneVT() continue } - fields[i] = rfields[index-1] + fields[i] = rfields[index-1].CloneVT() } return fields } diff --git a/go/vt/vtgate/engine/join_values.go b/go/vt/vtgate/engine/join_values.go index e35addd8c09..0d341c362df 100644 --- a/go/vt/vtgate/engine/join_values.go +++ b/go/vt/vtgate/engine/join_values.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vterrors" ) var _ Primitive = (*ValuesJoin)(nil) @@ -33,9 +34,10 @@ type ValuesJoin struct { // of the Join. They can be any primitive. Left, Right Primitive - Vars map[string]int - Columns []string + Vars []int RowConstructorArg string + Cols []int + ColNames []string } // TryExecute performs a non-streaming exec. @@ -62,11 +64,40 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars return jv.Right.GetFields(ctx, vcursor, bindVars) } - for _, row := range lresult.Rows { - bv.Values = append(bv.Values, sqltypes.TupleToProto(row)) + for i, row := range lresult.Rows { + newRow := make(sqltypes.Row, 0, len(jv.Vars)+1) // +1 since we always add the row ID + newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID + + for _, loffset := range jv.Vars { + newRow = append(newRow, row[loffset]) + } + + bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow)) } + bindVars[jv.RowConstructorArg] = bv - return vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields) + rresult, err := vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields) + if err != nil { + return nil, err + } + + result := &sqltypes.Result{} + + result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols) + for i := range result.Fields { + result.Fields[i].Name = jv.ColNames[i] + } + + for _, rrow := range rresult.Rows { + lhsRowID, err := rrow[len(rrow)-1].ToCastInt64() + if err != nil { + return nil, vterrors.VT13001("values joins cannot fetch lhs row ID: " + err.Error()) + } + + result.Rows = append(result.Rows, joinRows(lresult.Rows[lhsRowID], rrow, jv.Cols)) + } + + return result, nil } // TryStreamExecute performs a streaming exec. diff --git a/go/vt/vtgate/engine/join_values_test.go b/go/vt/vtgate/engine/join_values_test.go new file mode 100644 index 00000000000..068259a4e3e --- /dev/null +++ b/go/vt/vtgate/engine/join_values_test.go @@ -0,0 +1,101 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func TestJoinValuesExecute(t *testing.T) { + + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4 + */ + + leftPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + ), + }, + } + rightPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col5|col6|id", + "varchar|varchar|int64", + ), + "d|dd|0", + "e|ee|1", + "f|ff|2", + "g|gg|3", + ), + }, + } + + bv := map[string]*querypb.BindVariable{ + "a": sqltypes.Int64BindVariable(10), + } + + vjn := &ValuesJoin{ + Left: leftPrim, + Right: rightPrim, + Vars: []int{0}, + RowConstructorArg: "v", + Cols: []int{-1, -2, -3, -1, 1, 2}, + ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"}, + } + + r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute a: type:INT64 value:"10" true`, + }) + rightPrim.ExpectLog(t, []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(0) INT64(1)][INT64(1) INT64(2)][INT64(2) INT64(3)][INT64(3) INT64(4)]] true`, + }) + + result := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ) + expectResult(t, r, result) +}