Skip to content

Commit

Permalink
Enhance ValuesJoin engine and improve unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Florent Poinsard <[email protected]>
  • Loading branch information
frouioui committed Jan 15, 2025
1 parent 455fe86 commit bd6d67f
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 35 deletions.
34 changes: 13 additions & 21 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 10 additions & 3 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type fakePrimitive struct {
allResultsInOneCall bool

async bool

useNewPrintBindVars bool
}

func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) {
Expand All @@ -72,7 +74,12 @@ func (f *fakePrimitive) GetTableName() string {
}

func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields))
if f.useNewPrintBindVars {
f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields))
} else {
f.log = append(f.log, fmt.Sprintf("Execute %v %v", deprecatedPrintBindVars(bindVars), wantfields))
}

if f.results == nil {
return nil, f.sendErr
}
Expand All @@ -87,7 +94,7 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", deprecatedPrintBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
Expand Down Expand Up @@ -171,7 +178,7 @@ func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
f.log = append(f.log, fmt.Sprintf("GetFields %v", deprecatedPrintBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
}

Expand Down
47 changes: 43 additions & 4 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ func (f *loggingVCursor) Execute(ctx context.Context, method string, query strin
case vtgatepb.CommitOrder_AUTOCOMMIT:
name = "ExecuteAutocommit"
}
f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, printBindVars(bindvars), rollbackOnError))
f.log = append(f.log, fmt.Sprintf("%s %s %v %v", name, query, deprecatedPrintBindVars(bindvars), rollbackOnError))
return f.nextResult()
}

Expand All @@ -621,7 +621,7 @@ func (f *loggingVCursor) AutocommitApproval() bool {
}

func (f *loggingVCursor) ExecuteStandalone(ctx context.Context, _ Primitive, query string, bindvars map[string]*querypb.BindVariable, rs *srvtopo.ResolvedShard, fetchLastInsertID bool) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, printBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard))
f.log = append(f.log, fmt.Sprintf("ExecuteStandalone %s %v %s %s", query, deprecatedPrintBindVars(bindvars), rs.Target.Keyspace, rs.Target.Shard))
return f.nextResult()
}

Expand Down Expand Up @@ -943,6 +943,24 @@ func expectResultAnyOrder(t *testing.T, result, want *sqltypes.Result) {
}
}

// deprecatedPrintBindVars does not print bind variables, specifically tuples, correctly.
// We should use printBindVars instead.
func deprecatedPrintBindVars(bindvars map[string]*querypb.BindVariable) string {
var keys []string
for k := range bindvars {
keys = append(keys, k)
}
sort.Strings(keys)
buf := &bytes.Buffer{}
for i, k := range keys {
if i > 0 {
fmt.Fprintf(buf, " ")
}
fmt.Fprintf(buf, "%s: %v", k, bindvars[k])
}
return buf.String()
}

func printBindVars(bindvars map[string]*querypb.BindVariable) string {
var keys []string
for k := range bindvars {
Expand All @@ -954,6 +972,27 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string {
if i > 0 {
fmt.Fprintf(buf, " ")
}

if bindvars[k].Type == querypb.Type_TUPLE {
fmt.Fprintf(buf, "%s: [", k)
for _, val := range bindvars[k].Values {
if val.Type != querypb.Type_TUPLE {
fmt.Fprintf(buf, "[%s]", val.String())
continue
}
var s []string
v := sqltypes.ProtoToValue(val)
err := v.ForEachValue(func(bv sqltypes.Value) {
s = append(s, bv.String())
})
if err != nil {
panic(err)
}
fmt.Fprintf(buf, "[%s]", strings.Join(s, " "))
}
fmt.Fprintf(buf, "]")
continue
}
fmt.Fprintf(buf, "%s: %v", k, bindvars[k])
}
return buf.String()
Expand All @@ -962,15 +1001,15 @@ func printBindVars(bindvars map[string]*querypb.BindVariable) string {
func printResolvedShardQueries(rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery) string {
buf := &bytes.Buffer{}
for i, rs := range rss {
fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, printBindVars(queries[i].BindVariables))
fmt.Fprintf(buf, "%s.%s: %s {%s} ", rs.Target.Keyspace, rs.Target.Shard, queries[i].Sql, deprecatedPrintBindVars(queries[i].BindVariables))
}
return buf.String()
}

func printResolvedShardsBindVars(rss []*srvtopo.ResolvedShard, bvs []map[string]*querypb.BindVariable) string {
buf := &bytes.Buffer{}
for i, rs := range rss {
fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, printBindVars(bvs[i]))
fmt.Fprintf(buf, "%s.%s: {%v} ", rs.Target.Keyspace, rs.Target.Shard, deprecatedPrintBindVars(bvs[i]))
}
return buf.String()
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field
fields := make([]*querypb.Field, len(cols))
for i, index := range cols {
if index < 0 {
fields[i] = lfields[-index-1]
fields[i] = lfields[-index-1].CloneVT()
continue
}
fields[i] = rfields[index-1]
fields[i] = rfields[index-1].CloneVT()
}
return fields
}
Expand Down
41 changes: 36 additions & 5 deletions go/vt/vtgate/engine/join_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vterrors"
)

var _ Primitive = (*ValuesJoin)(nil)
Expand All @@ -33,9 +34,10 @@ type ValuesJoin struct {
// of the Join. They can be any primitive.
Left, Right Primitive

Vars map[string]int
Columns []string
Vars []int
RowConstructorArg string
Cols []int
ColNames []string
}

// TryExecute performs a non-streaming exec.
Expand All @@ -62,11 +64,40 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars
return jv.Right.GetFields(ctx, vcursor, bindVars)
}

for _, row := range lresult.Rows {
bv.Values = append(bv.Values, sqltypes.TupleToProto(row))
for i, row := range lresult.Rows {
newRow := make(sqltypes.Row, 0, len(jv.Vars)+1) // +1 since we always add the row ID
newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID

for _, loffset := range jv.Vars {
newRow = append(newRow, row[loffset])
}

bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow))
}

bindVars[jv.RowConstructorArg] = bv
return vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields)
rresult, err := vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields)
if err != nil {
return nil, err
}

result := &sqltypes.Result{}

result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols)
for i := range result.Fields {
result.Fields[i].Name = jv.ColNames[i]
}

for _, rrow := range rresult.Rows {
lhsRowID, err := rrow[len(rrow)-1].ToCastInt64()
if err != nil {
return nil, vterrors.VT13001("values joins cannot fetch lhs row ID: " + err.Error())
}

result.Rows = append(result.Rows, joinRows(lresult.Rows[lhsRowID], rrow, jv.Cols))
}

return result, nil
}

// TryStreamExecute performs a streaming exec.
Expand Down
101 changes: 101 additions & 0 deletions go/vt/vtgate/engine/join_values_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright 2025 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

func TestJoinValuesExecute(t *testing.T) {

/*
select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4
LHS: select col1, col2, col3 from left
RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4
*/

leftPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
),
},
}
rightPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col5|col6|id",
"varchar|varchar|int64",
),
"d|dd|0",
"e|ee|1",
"f|ff|2",
"g|gg|3",
),
},
}

bv := map[string]*querypb.BindVariable{
"a": sqltypes.Int64BindVariable(10),
}

vjn := &ValuesJoin{
Left: leftPrim,
Right: rightPrim,
Vars: []int{0},
RowConstructorArg: "v",
Cols: []int{-1, -2, -3, -1, 1, 2},
ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"},
}

r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true)
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`Execute a: type:INT64 value:"10" true`,
})
rightPrim.ExpectLog(t, []string{
`Execute a: type:INT64 value:"10" v: [[INT64(0) INT64(1)][INT64(1) INT64(2)][INT64(2) INT64(3)][INT64(3) INT64(4)]] true`,
})

result := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3|col4|col5|col6",
"int64|varchar|varchar|int64|varchar|varchar",
),
"1|a|aa|1|d|dd",
"2|b|bb|2|e|ee",
"3|c|cc|3|f|ff",
"4|d|dd|4|g|gg",
)
expectResult(t, r, result)
}

0 comments on commit bd6d67f

Please sign in to comment.