diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 8746ddebaec94..3fe8f86cd379a 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -1838,7 +1838,7 @@ func (b *builtinCastTimeAsDurationSig) evalDuration(ctx EvalContext, row chunk.R if err != nil { return res, false, err } - res, err = res.RoundFrac(b.tp.GetDecimal(), location(ctx)) + res, err = res.RoundFrac(b.tp.GetDecimal()) return res, false, err } @@ -1857,7 +1857,7 @@ func (b *builtinCastDurationAsDurationSig) evalDuration(ctx EvalContext, row chu if isNull || err != nil { return res, isNull, err } - res, err = res.RoundFrac(b.tp.GetDecimal(), location(ctx)) + res, err = res.RoundFrac(b.tp.GetDecimal()) return res, false, err } @@ -1881,7 +1881,7 @@ func (b *builtinCastDurationAsIntSig) evalInt(ctx EvalContext, row chunk.Row) (r res, err = val.ConvertToYear(typeCtx(ctx)) } else { var dur types.Duration - dur, err = val.RoundFrac(types.DefaultFsp, location(ctx)) + dur, err = val.RoundFrac(types.DefaultFsp) if err != nil { return res, false, err } @@ -2207,7 +2207,7 @@ func (b *builtinCastJSONAsDurationSig) evalDuration(ctx EvalContext, row chunk.R if err != nil { return res, false, err } - res, err = res.RoundFrac(b.tp.GetDecimal(), location(ctx)) + res, err = res.RoundFrac(b.tp.GetDecimal()) return res, isNull, err case types.JSONTypeCodeDuration: res = val.GetDuration() diff --git a/pkg/expression/builtin_cast_vec.go b/pkg/expression/builtin_cast_vec.go index c73314d12ee8b..f56683a11bd72 100644 --- a/pkg/expression/builtin_cast_vec.go +++ b/pkg/expression/builtin_cast_vec.go @@ -350,7 +350,7 @@ func (b *builtinCastDurationAsIntSig) vecEvalInt(ctx EvalContext, input *chunk.C i64s[i], err = duration.ConvertToYear(tc) } else { var dur types.Duration - dur, err = duration.RoundFrac(types.DefaultFsp, location(ctx)) + dur, err = duration.RoundFrac(types.DefaultFsp) if err != nil { return err } @@ -1346,7 +1346,7 @@ func (b *builtinCastTimeAsDurationSig) vecEvalDuration(ctx EvalContext, input *c if err != nil { return err } - d, err = d.RoundFrac(b.tp.GetDecimal(), location(ctx)) + d, err = d.RoundFrac(b.tp.GetDecimal()) if err != nil { return err } @@ -1375,7 +1375,7 @@ func (b *builtinCastDurationAsDurationSig) vecEvalDuration(ctx EvalContext, inpu continue } dur.Duration = v - rd, err = dur.RoundFrac(b.tp.GetDecimal(), location(ctx)) + rd, err = dur.RoundFrac(b.tp.GetDecimal()) if err != nil { return err } @@ -1981,7 +1981,7 @@ func (b *builtinCastJSONAsDurationSig) vecEvalDuration(ctx EvalContext, input *c if err != nil { return err } - d, err = d.RoundFrac(b.tp.GetDecimal(), location(ctx)) + d, err = d.RoundFrac(b.tp.GetDecimal()) if err != nil { return err } diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index 4100f705d63ad..12c2b1497fee4 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -2060,7 +2060,7 @@ func (b *builtinSysDateWithFspSig) evalTime(ctx EvalContext, row chunk.Row) (val loc := location(ctx) now := time.Now().In(loc) - result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeHalfUp) + result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeTruncate) if err != nil { return types.ZeroTime, true, err } @@ -2082,7 +2082,7 @@ func (b *builtinSysDateWithoutFspSig) Clone() builtinFunc { func (b *builtinSysDateWithoutFspSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { tz := location(ctx) now := time.Now().In(tz) - result, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp) + result, err := convertTimeToMysqlTime(now, 0, types.ModeTruncate) if err != nil { return types.ZeroTime, true, err } @@ -2181,7 +2181,7 @@ func (b *builtinCurrentTime0ArgSig) evalDuration(ctx EvalContext, row chunk.Row) return types.Duration{}, true, err } dur := nowTs.In(tz).Format(types.TimeFormat) - res, _, err := types.ParseDuration(typeCtx(ctx), dur, types.MinFsp) + res, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), dur, types.MinFsp) if err != nil { return types.Duration{}, true, err } @@ -2209,8 +2209,7 @@ func (b *builtinCurrentTime1ArgSig) evalDuration(ctx EvalContext, row chunk.Row) return types.Duration{}, true, err } dur := nowTs.In(tz).Format(types.TimeFSPFormat) - tc := typeCtx(ctx) - res, _, err := types.ParseDuration(tc, dur, int(fsp)) + res, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), dur, int(fsp)) if err != nil { return types.Duration{}, true, err } @@ -2404,7 +2403,7 @@ func evalUTCTimestampWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) if err != nil { return types.ZeroTime, true, err } - result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfUp) + result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeTruncate) if err != nil { return types.ZeroTime, true, err } @@ -6492,7 +6491,7 @@ func (b *builtinUTCTimeWithoutArgSig) evalDuration(ctx EvalContext, row chunk.Ro if err != nil { return types.Duration{}, true, err } - v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), 0) + v, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), 0) return v, false, err } @@ -6523,7 +6522,7 @@ func (b *builtinUTCTimeWithArgSig) evalDuration(ctx EvalContext, row chunk.Row) if err != nil { return types.Duration{}, true, err } - v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) + v, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) return v, false, err } diff --git a/pkg/expression/builtin_time_vec.go b/pkg/expression/builtin_time_vec.go index b965952a7108e..f93e642ba264d 100644 --- a/pkg/expression/builtin_time_vec.go +++ b/pkg/expression/builtin_time_vec.go @@ -171,7 +171,7 @@ func (b *builtinSysDateWithoutFspSig) vecEvalTime(ctx EvalContext, input *chunk. result.ResizeTime(n, false) times := result.Times() - t, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp) + t, err := convertTimeToMysqlTime(now, 0, types.ModeTruncate) if err != nil { return err } @@ -411,7 +411,7 @@ func (b *builtinUTCTimeWithArgSig) vecEvalDuration(ctx EvalContext, input *chunk if fsp < int64(types.MinFsp) { return errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) } - res, _, err := types.ParseDuration(tc, utc, int(fsp)) + res, _, err := types.ParseDurationTruncateFsp(tc, utc, int(fsp)) if err != nil { return err } @@ -765,7 +765,7 @@ func (b *builtinSysDateWithFspSig) vecEvalTime(ctx EvalContext, input *chunk.Chu if result.IsNull(i) { continue } - t, err := convertTimeToMysqlTime(now, int(ds[i]), types.ModeHalfUp) + t, err := convertTimeToMysqlTime(now, int(ds[i]), types.ModeTruncate) if err != nil { return err } @@ -1959,7 +1959,7 @@ func (b *builtinUTCTimeWithoutArgSig) vecEvalDuration(ctx EvalContext, input *ch if err != nil { return err } - res, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), types.DefaultFsp) + res, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), types.DefaultFsp) if err != nil { return err } @@ -2362,7 +2362,7 @@ func (b *builtinCurrentTime0ArgSig) vecEvalDuration(ctx EvalContext, input *chun } tz := location(ctx) dur := nowTs.In(tz).Format(types.TimeFormat) - res, _, err := types.ParseDuration(typeCtx(ctx), dur, types.MinFsp) + res, _, err := types.ParseDurationTruncateFsp(typeCtx(ctx), dur, types.MinFsp) if err != nil { return err } @@ -2556,7 +2556,7 @@ func (b *builtinCurrentTime1ArgSig) vecEvalDuration(ctx EvalContext, input *chun result.ResizeGoDuration(n, false) durations := result.GoDurations() for i := 0; i < n; i++ { - res, _, err := types.ParseDuration(tc, dur, int(i64s[i])) + res, _, err := types.ParseDurationTruncateFsp(tc, dur, int(i64s[i])) if err != nil { return err } diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 64dd807171fbd..0433880da51a9 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -1421,30 +1421,79 @@ func TestIssue9710(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) - getSAndMS := func(str string) (int, int) { + getSAndMS := func(str string) (string, string) { results := strings.Split(str, ":") SAndMS := strings.Split(results[len(results)-1], ".") - var s, ms int - s, _ = strconv.Atoi(SAndMS[0]) if len(SAndMS) > 1 { - ms, _ = strconv.Atoi(SAndMS[1]) + return SAndMS[0], SAndMS[1] } - return s, ms + return SAndMS[0], "" } for { - rs := tk.MustQuery("select now(), now(6), unix_timestamp(), unix_timestamp(now())") - s, ms := getSAndMS(rs.Rows()[0][1].(string)) - if ms < 500000 { - time.Sleep(time.Second / 10) - continue + rs := tk.MustQuery("select now(), now(4), now(6), unix_timestamp(), unix_timestamp(now()), unix_timestamp(now(5)), unix_timestamp(now(6)), utc_timestamp(), utc_timestamp(3), utc_timestamp(6), sysdate(), sysdate(2), sysdate(6), curtime(), curtime(1), curtime(6), utc_time(), utc_time(5), utc_time(6)") + n0, nms0 := getSAndMS(rs.Rows()[0][0].(string)) + n4, nms4 := getSAndMS(rs.Rows()[0][1].(string)) + n6, nms6 := getSAndMS(rs.Rows()[0][2].(string)) + + unix0, unixms0 := getSAndMS(rs.Rows()[0][3].(string)) + unixn0, unixnms0 := getSAndMS(rs.Rows()[0][4].(string)) + unixn5, unixnms5 := getSAndMS(rs.Rows()[0][5].(string)) + unixn6, unixnms6 := getSAndMS(rs.Rows()[0][6].(string)) + + utc0, utcms0 := getSAndMS(rs.Rows()[0][7].(string)) + utc3, utcms3 := getSAndMS(rs.Rows()[0][8].(string)) + utc6, utcms6 := getSAndMS(rs.Rows()[0][9].(string)) + + sysDate0, sysDatems0 := getSAndMS(rs.Rows()[0][10].(string)) + sysDate2, sysDatems2 := getSAndMS(rs.Rows()[0][11].(string)) + sysDate6, sysDatems6 := getSAndMS(rs.Rows()[0][12].(string)) + + curTime0, curTimems0 := getSAndMS(rs.Rows()[0][13].(string)) + curTime1, curTimems1 := getSAndMS(rs.Rows()[0][14].(string)) + curTime6, curTimems6 := getSAndMS(rs.Rows()[0][15].(string)) + + utcT0, utcTms0 := getSAndMS(rs.Rows()[0][16].(string)) + utcT5, utcTms5 := getSAndMS(rs.Rows()[0][17].(string)) + utcT6, utcTms6 := getSAndMS(rs.Rows()[0][18].(string)) + + require.Equal(t, n0, n4) // now() will truncate the result instead of rounding it + require.Equal(t, n0, n6) // now() will truncate the result instead of rounding it + require.Equal(t, nms0, "") + require.LessOrEqual(t, nms4, nms6) + + require.Equal(t, unix0, unixn0) + require.Equal(t, unix0, unixn5) + require.Equal(t, unix0, unixn6) + require.Equal(t, unixms0, "") + require.Equal(t, unixnms0, "") + require.LessOrEqual(t, unixnms5, unixnms6) + + require.Equal(t, utc0, utc3) + require.Equal(t, utc0, utc6) + require.Equal(t, utcms0, "") + require.LessOrEqual(t, utcms3, utcms6) + + require.Equal(t, sysDate0, sysDate2) + require.Equal(t, sysDate0, sysDate6) + require.Equal(t, sysDatems0, "") + require.LessOrEqual(t, sysDatems2, sysDatems6) + + require.Equal(t, curTime0, curTime1) + require.Equal(t, curTime0, curTime6) + require.Equal(t, curTimems0, "") + require.LessOrEqual(t, curTimems1, curTimems6) + + require.Equal(t, utcT0, utcT5) + require.Equal(t, utcT0, utcT6) + require.Equal(t, utcTms0, "") + require.LessOrEqual(t, utcTms5, utcTms6) + + // We really want to test truncate when fsp >= .5 + if nms6 >= "500000" { + break } - - s1, _ := getSAndMS(rs.Rows()[0][0].(string)) - require.Equal(t, s, s1) // now() will truncate the result instead of rounding it - - require.Equal(t, rs.Rows()[0][2], rs.Rows()[0][3]) // unix_timestamp() will truncate the result - break + time.Sleep(time.Second / 10) } } diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 5eb5bb873aea1..6d8150f47f6ef 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -1501,13 +1501,13 @@ func (d *Datum) convertToMysqlDuration(typeCtx Context, target *FieldType) (Datu ret.SetMysqlDuration(dur) return ret, errors.Trace(err) } - dur, err = dur.RoundFrac(fsp, typeCtx.Location()) + dur, err = dur.RoundFrac(fsp) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) } case KindMysqlDuration: - dur, err := d.GetMysqlDuration().RoundFrac(fsp, typeCtx.Location()) + dur, err := d.GetMysqlDuration().RoundFrac(fsp) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) @@ -2037,7 +2037,7 @@ func (d *Datum) toSignedInteger(ctx Context, tp byte) (int64, error) { case KindMysqlDuration: // 11:11:11.999999 -> 111112 // 11:59:59.999999 -> 120000 - dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp, ctx.Location()) + dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/types/time.go b/pkg/types/time.go index e2d915df86489..55d1cf7db39a2 100644 --- a/pkg/types/time.go +++ b/pkg/types/time.go @@ -1510,7 +1510,7 @@ func (d Duration) ConvertToYear(ctx Context) (int64, error) { func (d Duration) ConvertToYearFromNow(ctx Context, now gotime.Time) (int64, error) { if ctx.Flags().CastTimeToYearThroughConcat() { // this error will never happen, because we always give a valid FSP - dur, _ := d.RoundFrac(DefaultFsp, ctx.Location()) + dur, _ := d.RoundFrac(DefaultFsp) // the range of a duration will never exceed the range of `mysql.TypeLonglong` ival, _ := dur.ToNumber().ToInt() @@ -1524,29 +1524,28 @@ func (d Duration) ConvertToYearFromNow(ctx Context, now gotime.Time) (int64, err return AdjustYear(int64(datePart.Year()), false) } +// TruncateFrac truncates the fractional second precision +// and returns a new Duration. +// so 10:10:10.999999 truncate 0 -> 10:10:10 +// and 10:10:10.66666 truncate 3 -> 10:10:10.666 +func (d Duration) TruncateFrac(fsp int) (Duration, error) { + fsp, err := CheckFsp(fsp) + if err != nil { + return d, errors.Trace(err) + } + return Duration{Duration: d.Truncate(gotime.Duration(math.Pow10(9 - fsp))), Fsp: fsp}, nil +} + // RoundFrac rounds fractional seconds precision with new fsp and returns a new one. // We will use the “round half up” rule, e.g, >= 0.5 -> 1, < 0.5 -> 0, // so 10:10:10.999999 round 0 -> 10:10:11 // and 10:10:10.000000 round 0 -> 10:10:10 -func (d Duration) RoundFrac(fsp int, loc *gotime.Location) (Duration, error) { - tz := loc - if tz == nil { - logutil.BgLogger().Warn("use gotime.local because sc.timezone is nil") - tz = gotime.Local - } - +func (d Duration) RoundFrac(fsp int) (Duration, error) { fsp, err := CheckFsp(fsp) if err != nil { return d, errors.Trace(err) } - - if fsp == d.Fsp { - return d, nil - } - - n := gotime.Date(0, 0, 0, 0, 0, 0, 0, tz) - nd := n.Add(d.Duration).Round(gotime.Duration(math.Pow10(9-fsp)) * gotime.Nanosecond).Sub(n) //nolint:durationcheck - return Duration{Duration: nd, Fsp: fsp}, nil + return Duration{Duration: d.Round(gotime.Duration(math.Pow10(9 - fsp))), Fsp: fsp}, nil } // Compare returns an integer comparing the Duration instant t to o. @@ -1813,10 +1812,7 @@ func canFallbackToDateTime(str string) bool { return len(rest) > 0 && (rest[0] == ' ' || rest[0] == 'T') } -// ParseDuration parses the time form a formatted string with a fractional seconds part, -// returns the duration type Time value and bool to indicate whether the result is null. -// See http://dev.mysql.com/doc/refman/5.7/en/fractional-seconds.html -func ParseDuration(ctx Context, str string, fsp int) (Duration, bool, error) { +func parseDurationNoRound(ctx Context, str string, fsp int) (Duration, bool, error) { rest := strings.TrimSpace(str) d, isNull, err := matchDuration(rest, fsp) if err == nil { @@ -1836,7 +1832,31 @@ func ParseDuration(ctx Context, str string, fsp int) (Duration, bool, error) { return ZeroDuration, true, ErrTruncatedWrongVal.GenWithStackByArgs("time", str) } - d, err = d.RoundFrac(fsp, ctx.Location()) + return d, false, nil +} + +// ParseDuration parses the time form a formatted string with a fractional seconds part, +// returns the duration type Time value and bool to indicate whether the result is null. +// See http://dev.mysql.com/doc/refman/5.7/en/fractional-seconds.html +func ParseDuration(ctx Context, str string, fsp int) (Duration, bool, error) { + d, isNull, err := parseDurationNoRound(ctx, str, fsp) + if err != nil { + return d, isNull, err + } + d, err = d.RoundFrac(fsp) + return d, false, err +} + +// ParseDurationTruncateFsp parses the time form a formatted string with a +// fractional seconds part, returns the duration type Time value and bool to indicate +// whether the result is null. It also truncates FSP part instead of rounding it as +// ParseDuration above does. +func ParseDurationTruncateFsp(ctx Context, str string, fsp int) (Duration, bool, error) { + d, isNull, err := parseDurationNoRound(ctx, str, fsp) + if err != nil { + return d, isNull, err + } + d, err = d.TruncateFrac(fsp) return d, false, err } diff --git a/pkg/types/time_test.go b/pkg/types/time_test.go index d015f5fb209b3..8889cd70cbaf1 100644 --- a/pkg/types/time_test.go +++ b/pkg/types/time_test.go @@ -912,7 +912,7 @@ func TestRoundFrac(t *testing.T) { for _, tt := range tbl { v, _, err := types.ParseDuration(typeCtx, tt.Input, types.MaxFsp) require.NoError(t, err) - nv, err := v.RoundFrac(tt.Fsp, typeCtx.Location()) + nv, err := v.RoundFrac(tt.Fsp) require.NoError(t, err) require.Equal(t, tt.Except, nv.String()) } @@ -1961,6 +1961,29 @@ func TestTruncateFrac(t *testing.T) { require.Equal(t, col.output.Second(), res.Second()) require.NoError(t, err) } + + // Copied from TestRoundFrac and adjusted! + tbl := []struct { + Input string + Fsp int + Except string + }{ + {"11:30:45.123456", 4, "11:30:45.1234"}, + {"11:30:45.123456", 6, "11:30:45.123456"}, + {"11:30:45.123456", 0, "11:30:45"}, + {"1 11:30:45.123456", 1, "35:30:45.1"}, + {"1 11:30:45.999999", 4, "35:30:45.9999"}, + {"-1 11:30:45.999999", 0, "-35:30:45"}, + } + + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) + for _, tt := range tbl { + v, _, err := types.ParseDuration(typeCtx, tt.Input, types.MaxFsp) + require.NoError(t, err) + nv, err := v.TruncateFrac(tt.Fsp) + require.NoError(t, err) + require.Equal(t, tt.Except, nv.String()) + } } func TestTimeSub(t *testing.T) {