diff --git a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go index 3c2a1800e31..6cf88a48f75 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go +++ b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan.go @@ -617,9 +617,10 @@ func valsEqual(v1, v2 sqltypes.Value) bool { // on the source: sum/count for aggregation queries, for example. func (tp *TablePlan) appendFromRow(buf *bytes2.Buffer, row *querypb.Row) error { bindLocations := tp.BulkInsertValues.BindLocations() - if len(tp.Fields) < len(bindLocations) { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "wrong number of fields: got %d fields for %d bind locations ", - len(tp.Fields), len(bindLocations)) + usedFieldCnt := len(tp.Fields) - len(tp.FieldsToSkip) + if usedFieldCnt != len(bindLocations) { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "wrong number of fields: got %d fields for %d bind locations", + usedFieldCnt, len(bindLocations)) } // Bind field values to locations. diff --git a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan_test.go b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan_test.go index 644b4585914..3b46a08a5cd 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/replicator_plan_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/replicator_plan_test.go @@ -21,17 +21,18 @@ import ( "strings" "testing" - vttablet "vitess.io/vitess/go/vt/vttablet/common" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/bytes2" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/binlog/binlogplayer" "vitess.io/vitess/go/vt/sqlparser" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" + vttablet "vitess.io/vitess/go/vt/vttablet/common" ) type TestReplicatorPlan struct { @@ -829,3 +830,114 @@ func TestBuildPlayerPlanExclude(t *testing.T) { wantPlan, _ := json.Marshal(want) assert.Equal(t, string(gotPlan), string(wantPlan)) } + +func TestAppendFromRow(t *testing.T) { + testCases := []struct { + name string + tp *TablePlan + row *querypb.Row + want string + wantErr string + }{ + { + name: "simple", + tp: &TablePlan{ + BulkInsertValues: sqlparser.BuildParsedQuery("values (%a, %a, %a)", + ":c1", ":c2", ":c3", + ), + Fields: []*querypb.Field{ + {Name: "c1", Type: querypb.Type_INT32}, + {Name: "c2", Type: querypb.Type_INT32}, + {Name: "c3", Type: querypb.Type_INT32}, + }, + }, + row: sqltypes.RowToProto3( + []sqltypes.Value{ + sqltypes.NewInt64(1), + sqltypes.NewInt64(2), + sqltypes.NewInt64(3), + }, + ), + want: "values (1, 2, 3)", + }, + { + name: "too few fields", + tp: &TablePlan{ + BulkInsertValues: sqlparser.BuildParsedQuery("values (%a, %a, %a)", + ":c1", ":c2", ":c3", + ), + Fields: []*querypb.Field{ + {Name: "c1", Type: querypb.Type_INT32}, + {Name: "c2", Type: querypb.Type_INT32}, + }, + }, + wantErr: "wrong number of fields: got 2 fields for 3 bind locations", + }, + { + name: "too few non-skipped fields", + tp: &TablePlan{ + BulkInsertValues: sqlparser.BuildParsedQuery("values (%a, %a, %a)", + ":c1", ":c2", ":c3", + ), + Fields: []*querypb.Field{ + {Name: "c1", Type: querypb.Type_INT32}, + {Name: "c2", Type: querypb.Type_INT32}, + {Name: "c3", Type: querypb.Type_INT32}, + {Name: "c4", Type: querypb.Type_INT32}, + }, + FieldsToSkip: map[string]bool{ + "c3": true, + "c4": true, + }, + }, + wantErr: "wrong number of fields: got 2 fields for 3 bind locations", + }, + { + name: "lots o skippin", + tp: &TablePlan{ + BulkInsertValues: sqlparser.BuildParsedQuery("values (%a, %a, %a)", + ":c1", ":c2", ":c4", + ), + Fields: []*querypb.Field{ + {Name: "c1", Type: querypb.Type_INT32}, + {Name: "c2", Type: querypb.Type_INT32}, + {Name: "c3", Type: querypb.Type_INT32}, + {Name: "c4", Type: querypb.Type_INT32}, + {Name: "c5", Type: querypb.Type_INT32}, + {Name: "c6", Type: querypb.Type_INT32}, + {Name: "c7", Type: querypb.Type_INT32}, + }, + FieldsToSkip: map[string]bool{ + "c3": true, + "c5": true, + "c6": true, + "c7": true, + }, + }, + row: sqltypes.RowToProto3( + []sqltypes.Value{ + sqltypes.NewInt64(1), + sqltypes.NewInt64(2), + sqltypes.NewInt64(3), + sqltypes.NewInt64(4), + sqltypes.NewInt64(5), + sqltypes.NewInt64(6), + sqltypes.NewInt64(7), + }, + ), + want: "values (1, 2, 4)", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bb := &bytes2.Buffer{} + err := tc.tp.appendFromRow(bb, tc.row) + if tc.wantErr != "" { + require.EqualError(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.want, bb.String()) + } + }) + } +}