From 27f2d736463364ead0bea817c0e73ed79be47883 Mon Sep 17 00:00:00 2001 From: Noble Mittal <62551163+beingnoble03@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:59:37 +0530 Subject: [PATCH] evalengine: Implement `PERIOD_ADD` (#16492) --- go/mysql/datetime/mydate.go | 45 ++++++++++++++++ go/mysql/datetime/mydate_test.go | 54 +++++++++++++++++++ go/vt/vtgate/evalengine/cached_size.go | 12 +++++ go/vt/vtgate/evalengine/compiler_asm.go | 20 +++++++ go/vt/vtgate/evalengine/fn_time.go | 57 ++++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 22 ++++++++ go/vt/vtgate/evalengine/testcases/inputs.go | 5 ++ go/vt/vtgate/evalengine/translate_builtin.go | 7 +++ 8 files changed, 222 insertions(+) diff --git a/go/mysql/datetime/mydate.go b/go/mysql/datetime/mydate.go index 62cbb3f2524..5b77082055a 100644 --- a/go/mysql/datetime/mydate.go +++ b/go/mysql/datetime/mydate.go @@ -89,3 +89,48 @@ func DateFromDayNumber(daynr int) Date { d.year, d.month, d.day = mysqlDateFromDayNumber(daynr) return d } + +// ValidatePeriod validates the MySQL period. +// Returns false if period is non-positive or contains incorrect month value. +func ValidatePeriod(period int64) bool { + if period <= 0 { + return false + } + month := period % 100 + if month == 0 || month > 12 { + return false + } + return true +} + +// PeriodToMonths converts a MySQL period into number of months. +// This is an algorithm that has been reverse engineered from MySQL. +func PeriodToMonths(period int64) int64 { + p := uint64(period) + if p == 0 { + return 0 + } + y := p / 100 + if y < 70 { + y += 2000 + } else if y < 100 { + y += 1900 + } + return int64(y*12 + p%100 - 1) +} + +// MonthsToPeriod converts number of months into MySQL period. +// This is an algorithm that has been reverse engineered from MySQL. +func MonthsToPeriod(months int64) int64 { + m := uint64(months) + if m == 0 { + return 0 + } + y := m / 12 + if y < 70 { + y += 2000 + } else if y < 100 { + y += 1900 + } + return int64(y*100 + m%12 + 1) +} diff --git a/go/mysql/datetime/mydate_test.go b/go/mysql/datetime/mydate_test.go index bb5073b8ff8..a743db60709 100644 --- a/go/mysql/datetime/mydate_test.go +++ b/go/mysql/datetime/mydate_test.go @@ -65,3 +65,57 @@ func TestDayNumberFields(t *testing.T) { assert.Equal(t, wantDate, got) } } + +func TestValidatePeriod(t *testing.T) { + testCases := []struct { + period int64 + want bool + }{ + {110112, true}, + {101122, false}, + {-1112212, false}, + {7110, true}, + } + + for _, tc := range testCases { + got := ValidatePeriod(tc.period) + assert.Equal(t, tc.want, got) + } +} + +func TestPeriodToMonths(t *testing.T) { + testCases := []struct { + period int64 + want int64 + }{ + {0, 0}, + {110112, 13223}, + {100112, 12023}, + {7112, 23663}, + {200112, 24023}, + {112, 24023}, + } + + for _, tc := range testCases { + got := PeriodToMonths(tc.period) + assert.Equal(t, tc.want, got) + } +} + +func TestMonthsToPeriod(t *testing.T) { + testCases := []struct { + months int64 + want int64 + }{ + {0, 0}, + {13223, 110112}, + {12023, 100112}, + {23663, 197112}, + {24023, 200112}, + } + + for _, tc := range testCases { + got := MonthsToPeriod(tc.months) + assert.Equal(t, tc.want, got) + } +} diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 6f447f0d1c1..9009b069f5a 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1397,6 +1397,18 @@ func (cached *builtinPad) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinPeriodAdd) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinPi) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 0cac66d9e5e..93781ed077b 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -4299,6 +4299,26 @@ func (asm *assembler) Fn_YEARWEEK() { }, "FN YEARWEEK DATE(SP-1)") } +func (asm *assembler) Fn_PERIOD_ADD() { + asm.adjustStack(-1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-2] == nil { + env.vm.sp-- + return 1 + } + period := env.vm.stack[env.vm.sp-2].(*evalInt64).i + months := env.vm.stack[env.vm.sp-1].(*evalInt64).i + res, err := periodAdd(period, months) + if err != nil { + env.vm.err = err + return 0 + } + env.vm.stack[env.vm.sp-2] = res + env.vm.sp-- + return 1 + }, "FN PERIOD_ADD INT64(SP-2) INT64(SP-1)") +} + func (asm *assembler) Interval(l int) { asm.adjustStack(-l) asm.emit(func(env *ExpressionEnv) int { diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 8d920e9e135..90fcda2c32a 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -26,6 +26,9 @@ import ( "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vterrors" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) var SystemTime = time.Now @@ -174,6 +177,10 @@ type ( CallExpr } + builtinPeriodAdd struct { + CallExpr + } + builtinDateMath struct { CallExpr sub bool @@ -214,6 +221,7 @@ var _ IR = (*builtinWeekDay)(nil) var _ IR = (*builtinWeekOfYear)(nil) var _ IR = (*builtinYear)(nil) var _ IR = (*builtinYearWeek)(nil) +var _ IR = (*builtinPeriodAdd)(nil) func (call *builtinNow) eval(env *ExpressionEnv) (eval, error) { now := env.time(call.utc) @@ -1964,6 +1972,55 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag | flagNullable}, nil } +func periodAdd(period, months int64) (*evalInt64, error) { + if !datetime.ValidatePeriod(period) { + return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongArguments, "Incorrect arguments to period_add") + } + return newEvalInt64(datetime.MonthsToPeriod(datetime.PeriodToMonths(period) + months)), nil +} + +func (b *builtinPeriodAdd) eval(env *ExpressionEnv) (eval, error) { + p, m, err := b.arg2(env) + if err != nil { + return nil, err + } + if p == nil || m == nil { + return nil, nil + } + period := evalToInt64(p) + months := evalToInt64(m) + return periodAdd(period.i, months.i) +} + +func (call *builtinPeriodAdd) compile(c *compiler) (ctype, error) { + period, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + months, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + skip := c.compileNullCheck2(period, months) + + switch period.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(2) + } + + switch months.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(1) + } + + c.asm.Fn_PERIOD_ADD() + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.Int64, Flag: period.Flag | months.Flag | flagNullable}, nil +} + func evalToInterval(itv eval, unit datetime.IntervalType, negate bool) *datetime.Interval { switch itv := itv.(type) { case *evalBytes: diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 003eb45c0a3..7d5305b21f7 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -155,6 +155,7 @@ var Cases = []TestCase{ {Run: FnWeekOfYear}, {Run: FnYear}, {Run: FnYearWeek}, + {Run: FnPeriodAdd}, {Run: FnInetAton}, {Run: FnInetNtoa}, {Run: FnInet6Aton}, @@ -2223,6 +2224,27 @@ func FnYearWeek(yield Query) { } } +func FnPeriodAdd(yield Query) { + for _, p := range inputBitwise { + for _, m := range inputBitwise { + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + } + } + for _, p := range inputPeriods { + for _, m := range inputBitwise { + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + } + } + + mysqlDocSamples := []string{ + `PERIOD_ADD(200801,2)`, + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnInetAton(yield Query) { for _, d := range ipInputs { yield(fmt.Sprintf("INET_ATON(%s)", d), nil) diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index eb94235d9b4..ac23281fd54 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -59,6 +59,11 @@ var inputBitwise = []string{ "64", "'64'", "_binary '64'", "X'40'", "_binary X'40'", } +var inputPeriods = []string{ + "110192", "'119812'", "2703", "7111", "200103", "200309", "0309", "-110102", "0", + "'032'", "223", "'-119812'", "-2703", "99999999999999999999999911", "'-0309'", +} + var radianInputs = []string{ "0", "1", diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index d4c6bcdae5a..2c4d887ff19 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -528,6 +528,13 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { default: return nil, argError(method) } + case "period_add": + switch len(args) { + case 2: + return &builtinPeriodAdd{CallExpr: call}, nil + default: + return nil, argError(method) + } case "inet_aton": if len(args) != 1 { return nil, argError(method)