Skip to content

Introduce jsontext.RawToken API for better number parsing #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions arshal_any.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package json
import (
"cmp"
"reflect"
"strconv"

"github.com/go-json-experiment/json/internal"
"github.com/go-json-experiment/json/internal/jsonflags"
Expand Down Expand Up @@ -83,9 +82,9 @@ func unmarshalValueAny(dec *jsontext.Decoder, uo *jsonopts.Struct) (any, error)
if uo.Flags.Get(jsonflags.UnmarshalAnyWithRawNumber) {
return internal.RawNumberOf(val), nil
}
fv, ok := jsonwire.ParseFloat(val, 64)
if !ok {
return fv, newUnmarshalErrorAfterWithValue(dec, float64Type, strconv.ErrRange)
fv, err := jsonwire.ParseFloat(val, 64)
if err != nil {
return nil, newUnmarshalErrorAfterWithValue(dec, float64Type, err)
}
return fv, nil
default:
Expand Down Expand Up @@ -196,13 +195,13 @@ func unmarshalObjectAny(dec *jsontext.Decoder, uo *jsonopts.Struct) (map[string]
}

val, err := unmarshalValueAny(dec, uo)
obj[name] = val
if err != nil {
if isFatalError(err, uo.Flags) {
return obj, err
}
errUnmarshal = cmp.Or(err, errUnmarshal)
}
obj[name] = val
}
if _, err := dec.ReadToken(); err != nil {
return obj, err
Expand Down Expand Up @@ -266,13 +265,13 @@ func unmarshalArrayAny(dec *jsontext.Decoder, uo *jsonopts.Struct) ([]any, error
var errUnmarshal error
for dec.PeekKind() != ']' {
val, err := unmarshalValueAny(dec, uo)
arr = append(arr, val)
if err != nil {
if isFatalError(err, uo.Flags) {
return arr, err
}
errUnmarshal = cmp.Or(errUnmarshal, err)
}
arr = append(arr, val)
}
if _, err := dec.ReadToken(); err != nil {
return arr, err
Expand Down
64 changes: 27 additions & 37 deletions arshal_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ var (

bytesType = reflect.TypeFor[[]byte]()
emptyStructType = reflect.TypeFor[struct{}]()

nanString = jsontext.String("NaN")
pinfString = jsontext.String("Infinity")
ninfString = jsontext.String("-Infinity")
)

const startDetectingCyclesAfter = 1000
Expand Down Expand Up @@ -479,28 +483,11 @@ func makeIntArshaler(t reflect.Type) *arshaler {
if stringify && k == '0' {
break
}
var negOffset int
neg := len(val) > 0 && val[0] == '-'
if neg {
negOffset = 1
}
n, ok := jsonwire.ParseUint(val[negOffset:])
maxInt := uint64(1) << (bits - 1)
overflow := (neg && n > maxInt) || (!neg && n > maxInt-1)
if !ok {
if n != math.MaxUint64 {
return newUnmarshalErrorAfterWithValue(dec, t, strconv.ErrSyntax)
}
overflow = true
}
if overflow {
return newUnmarshalErrorAfterWithValue(dec, t, strconv.ErrRange)
}
if neg {
va.SetInt(int64(-n))
} else {
va.SetInt(int64(+n))
n, err := jsonwire.ParseInt(val, bits)
if err != nil {
return newUnmarshalErrorAfterWithValue(dec, t, err)
}
va.SetInt(n)
return nil
}
return newUnmarshalErrorAfter(dec, t, nil)
Expand Down Expand Up @@ -566,17 +553,9 @@ func makeUintArshaler(t reflect.Type) *arshaler {
if stringify && k == '0' {
break
}
n, ok := jsonwire.ParseUint(val)
maxUint := uint64(1) << bits
overflow := n > maxUint-1
if !ok {
if n != math.MaxUint64 {
return newUnmarshalErrorAfterWithValue(dec, t, strconv.ErrSyntax)
}
overflow = true
}
if overflow {
return newUnmarshalErrorAfterWithValue(dec, t, strconv.ErrRange)
n, err := jsonwire.ParseUint(val, bits)
if err != nil {
return newUnmarshalErrorAfterWithValue(dec, t, err)
}
va.SetUint(n)
return nil
Expand Down Expand Up @@ -606,7 +585,18 @@ func makeFloatArshaler(t reflect.Type) *arshaler {
err := fmt.Errorf("unsupported value: %v", fv)
return newMarshalErrorBefore(enc, t, err)
}
return enc.WriteToken(jsontext.Float(fv))
var token jsontext.Token
switch {
case math.IsInf(fv, 1):
token = pinfString
case math.IsInf(fv, -1):
token = ninfString
case math.IsNaN(fv):
token = nanString
default:
panic("unreachable")
}
return enc.WriteToken(token)
}

// Optimize for marshaling without preceding whitespace or string escaping.
Expand Down Expand Up @@ -679,11 +669,11 @@ func makeFloatArshaler(t reflect.Type) *arshaler {
if stringify && k == '0' {
break
}
fv, ok := jsonwire.ParseFloat(val, bits)
va.SetFloat(fv)
if !ok {
return newUnmarshalErrorAfterWithValue(dec, t, strconv.ErrRange)
fv, err := jsonwire.ParseFloat(val, bits)
if err != nil {
return newUnmarshalErrorAfterWithValue(dec, t, err)
}
va.SetFloat(fv)
return nil
}
return newUnmarshalErrorAfter(dec, t, nil)
Expand Down
6 changes: 3 additions & 3 deletions arshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5434,7 +5434,7 @@ func TestUnmarshal(t *testing.T) {
name: jsontest.Name("Floats/Float32/Overflow"),
inBuf: `-1e1000`,
inVal: addr(float32(32.32)),
want: addr(float32(-math.MaxFloat32)),
want: addr(float32(32.32)),
wantErr: EU(strconv.ErrRange).withVal(`-1e1000`).withType('0', T[float32]()),
}, {
name: jsontest.Name("Floats/Float64/Pi"),
Expand All @@ -5450,13 +5450,13 @@ func TestUnmarshal(t *testing.T) {
name: jsontest.Name("Floats/Float64/Overflow"),
inBuf: `-1e1000`,
inVal: addr(float64(64.64)),
want: addr(float64(-math.MaxFloat64)),
want: addr(float64(64.64)),
wantErr: EU(strconv.ErrRange).withVal(`-1e1000`).withType('0', T[float64]()),
}, {
name: jsontest.Name("Floats/Any/Overflow"),
inBuf: `1e1000`,
inVal: new(any),
want: addr(any(float64(math.MaxFloat64))),
want: new(any),
wantErr: EU(strconv.ErrRange).withVal(`1e1000`).withType('0', T[float64]()),
}, {
name: jsontest.Name("Floats/Named"),
Expand Down
45 changes: 24 additions & 21 deletions arshal_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,18 +419,19 @@ func appendDurationBase10(b []byte, d time.Duration, pow10 uint64) []byte {
func parseDurationBase10(b []byte, pow10 uint64) (time.Duration, error) {
suffix, neg := consumeSign(b) // consume sign
wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
whole, okWhole := jsonwire.ParseUint(wholeBytes) // parse whole field; may overflow
whole, err := jsonwire.ParseUint(wholeBytes, 64) // parse whole field; may overflow
frac, okFrac := parseFracBase10(fracBytes, pow10) // parse frac field
hi, lo := bits.Mul64(whole, uint64(pow10)) // overflow if hi > 0
sum, co := bits.Add64(lo, uint64(frac), 0) // overflow if co > 0
switch d := mayApplyDurationSign(sum, neg); { // overflow if neg != (d < 0)
case (!okWhole && whole != math.MaxUint64) || !okFrac:
return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrSyntax)
case !okWhole || hi > 0 || co > 0 || neg != (d < 0):
return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrRange)
default:
// A malformed fraction is reported as a syntax error and
// takes precedence over any range error detected below.
case !okFrac:
err = strconv.ErrSyntax
// Overflow detected while scaling, summing, or applying the sign.
case hi > 0 || co > 0 || neg != (d < 0):
err = strconv.ErrRange
// No error from ParseUint and none of the checks above fired.
case err == nil:
return d, nil
}
// Reached with err set by one of the cases above,
// or left over from ParseUint on the whole field.
return 0, fmt.Errorf("invalid duration %q: %w", b, err)
}

// appendDurationBase60 appends d formatted with H:MM:SS.SSS notation.
Expand All @@ -455,21 +456,22 @@ func parseDurationBase60(b []byte) (time.Duration, error) {
hourBytes, suffix := bytesCutByte(suffix, ':', false) // consume hour field
minBytes, suffix := bytesCutByte(suffix, ':', false) // consume min field
secBytes, nsecBytes := bytesCutByte(suffix, '.', true) // consume sec and nsec fields
hour, okHour := jsonwire.ParseUint(hourBytes) // parse hour field; may overflow
hour, err := jsonwire.ParseUint(hourBytes, 64) // parse hour field; may overflow
min := parseDec2(minBytes) // parse min field
sec := parseDec2(secBytes) // parse sec field
nsec, okNsec := parseFracBase10(nsecBytes, 1e9) // parse nsec field
n := uint64(min)*60*1e9 + uint64(sec)*1e9 + uint64(nsec) // cannot overflow
hi, lo := bits.Mul64(hour, 60*60*1e9) // overflow if hi > 0
sum, co := bits.Add64(lo, n, 0) // overflow if co > 0
switch d := mayApplyDurationSign(sum, neg); { // overflow if neg != (d < 0)
case (!okHour && hour != math.MaxUint64) || !checkBase60(minBytes) || !checkBase60(secBytes) || !okNsec:
return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrSyntax)
case !okHour || hi > 0 || co > 0 || neg != (d < 0):
return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrRange)
default:
case !checkBase60(minBytes) || !checkBase60(secBytes) || !okNsec:
err = strconv.ErrSyntax
case hi > 0 || co > 0 || neg != (d < 0):
err = strconv.ErrRange
case err == nil:
return d, nil
}
return 0, fmt.Errorf("invalid duration %q: %w", b, err)
}

// mayAppendDurationSign appends a negative sign if n is negative.
Expand Down Expand Up @@ -517,19 +519,19 @@ func appendTimeUnix(b []byte, t time.Time, pow10 uint64) []byte {
func parseTimeUnix(b []byte, pow10 uint64) (time.Time, error) {
suffix, neg := consumeSign(b) // consume sign
wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
whole, okWhole := jsonwire.ParseUint(wholeBytes) // parse whole field; may overflow
whole, err := jsonwire.ParseUint(wholeBytes, 64) // parse whole field; may overflow
frac, okFrac := parseFracBase10(fracBytes, 1e9/pow10) // parse frac field
var sec, nsec int64
switch {
case pow10 == 1e0: // fast case where units is in seconds
sec = int64(whole) // check overflow later after negation
nsec = int64(frac) // cannot overflow
case okWhole: // intermediate case where units is not seconds, but no overflow
case err == nil: // intermediate case where units is not seconds, but no overflow
sec = int64(whole / pow10) // check overflow later after negation
nsec = int64((whole%pow10)*(1e9/pow10) + frac) // cannot overflow
case !okWhole && whole == math.MaxUint64: // slow case where units is not seconds and overflow occurred
case err == strconv.ErrRange: // slow case where units is not seconds and overflow occurred
width := int(math.Log10(float64(pow10))) // compute len(strconv.Itoa(pow10-1))
whole, okWhole = jsonwire.ParseUint(wholeBytes[:len(wholeBytes)-width]) // parse the upper whole field
whole, err = jsonwire.ParseUint(wholeBytes[:len(wholeBytes)-width], 64) // parse the upper whole field
mid, _ := parsePaddedBase10(wholeBytes[len(wholeBytes)-width:], pow10) // parse the lower whole field
sec = int64(whole) // check overflow later after negation
nsec = int64(mid*(1e9/pow10) + frac) // cannot overflow
Expand All @@ -538,13 +540,14 @@ func parseTimeUnix(b []byte, pow10 uint64) (time.Time, error) {
sec, nsec = negateSecNano(sec, nsec)
}
switch t := time.Unix(sec, nsec).UTC(); {
case (!okWhole && whole != math.MaxUint64) || !okFrac:
return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrSyntax)
case !okWhole || neg != (t.Unix() < 0):
return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrRange)
default:
case !okFrac:
err = strconv.ErrSyntax
case neg != (t.Unix() < 0):
err = strconv.ErrRange
case err == nil:
return t, nil
}
return time.Time{}, fmt.Errorf("invalid time %q: %w", b, err)
}

// negateSecNano negates a Unix timestamp, where nsec must be within [0, 1e9).
Expand Down
8 changes: 6 additions & 2 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,13 @@ func mustDecodeTokens(t testing.TB, data []byte) []jsontext.Token {
case '"':
tokens = append(tokens, jsontext.String(tok.String()))
case '0':
tokens = append(tokens, jsontext.Float(tok.Float()))
v, err := tok.ParseFloat(64)
if err != nil {
t.Fatalf("ParseFloat error: %v", err)
}
tokens = append(tokens, jsontext.Float(v))
default:
tokens = append(tokens, tok.Clone())
tokens = append(tokens, jsontext.Raw(tok.Clone()))
}
}
return tokens
Expand Down
71 changes: 45 additions & 26 deletions internal/jsonwire/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package jsonwire

import (
"io"
"math"
"slices"
"strconv"
"unicode/utf16"
Expand Down Expand Up @@ -586,42 +585,62 @@ func parseHexUint16[Bytes ~[]byte | ~string](b Bytes) (v uint16, ok bool) {

// ParseUint parses b as a decimal unsigned integer according to
// a strict subset of the JSON number grammar, returning the value if valid.
// It returns (0, false) if there is a syntax error and
// returns (math.MaxUint64, false) if there is an overflow.
func ParseUint(b []byte) (v uint64, ok bool) {
// It returns (0, strconv.ErrSyntax) if there is a syntax error and
// returns (max, strconv.ErrRange) if there is an overflow.
// Here max is the largest value representable as an unsigned integer
// of the given bit size.
func ParseUint(b []byte, bits int) (uint64, error) {
const unsafeWidth = 20 // len(fmt.Sprint(uint64(math.MaxUint64)))
var n int
var v uint64
// Accumulate the leading decimal digits. For long inputs v may silently
// wrap around; the switch below detects that case.
for ; len(b) > n && ('0' <= b[n] && b[n] <= '9'); n++ {
v = 10*v + uint64(b[n]-'0')
}

// Largest value that fits in the requested bit size. When bits is 64,
// the shift produces 0 and the subtraction wraps to the maximum uint64.
max := uint64(1)<<uint(bits) - 1
switch {
// Empty input, trailing non-digit bytes, or a leading zero
// on a multi-digit number.
case n == 0 || len(b) != n || (b[0] == '0' && string(b) != "0"):
return 0, false
return 0, strconv.ErrSyntax
// With unsafeWidth or more digits, v may have wrapped around.
// Only an input of exactly unsafeWidth digits that starts with '1'
// and still yields v >= 1e19 is accepted as not having overflowed uint64.
case n >= unsafeWidth && (b[0] != '1' || v < 1e19 || n > unsafeWidth):
return math.MaxUint64, false
return max, strconv.ErrRange
// The value fits in a uint64 but exceeds the requested bit size.
case v > max:
return max, strconv.ErrRange
}
return v, nil
}

// ParseInt parses b as a decimal signed integer: an optional leading '-'
// followed by digits that are parsed by ParseUint.
// If ParseUint fails with a zero value (as it does for syntax errors),
// it returns 0 and that error.
// If the magnitude does not fit in a signed integer of the given bit size,
// it returns the nearest representable bound and strconv.ErrRange.
func ParseInt(b []byte, bits int) (int64, error) {
negOffset := 0
neg := len(b) > 0 && b[0] == '-'
if neg {
negOffset = 1
}
// Parse the magnitude after skipping the sign, if any.
n, err := ParseUint(b[negOffset:], bits)
if err != nil && n == 0 {
return 0, err
}

// maxInt is the magnitude of the most negative representable value;
// the largest positive value is maxInt-1.
maxInt := uint64(1) << (bits - 1)
if neg && n > maxInt {
return -int64(maxInt), strconv.ErrRange
} else if !neg && n > maxInt-1 {
return int64(maxInt - 1), strconv.ErrRange
}

if neg {
// Negate in unsigned arithmetic before converting so that
// a magnitude of exactly maxInt maps to the minimum value.
return int64(-n), nil
} else {
return int64(+n), nil
}
return v, true
}

// ParseFloat parses a floating point number according to the Go float grammar.
// Note that the JSON number grammar is a strict subset.
//
// If the number overflows the finite representation of a float,
// then we return MaxFloat since any finite value will always be infinitely
// more accurate at representing another finite value than an infinite value.
func ParseFloat(b []byte, bits int) (v float64, ok bool) {
fv, err := strconv.ParseFloat(string(b), bits)
if math.IsInf(fv, 0) {
switch {
case bits == 32 && math.IsInf(fv, +1):
fv = +math.MaxFloat32
case bits == 64 && math.IsInf(fv, +1):
fv = +math.MaxFloat64
case bits == 32 && math.IsInf(fv, -1):
fv = -math.MaxFloat32
case bits == 64 && math.IsInf(fv, -1):
fv = -math.MaxFloat64
}
func ParseFloat(b []byte, bits int) (float64, error) {
// Note that the JSON number grammar is a strict subset.
// We have ensured the input is a valid JSON number in [ConsumeNumberResumable],
// so we may take advantage of the simpler grammar and
// replace this with a more efficient implementation in the future.
v, err := strconv.ParseFloat(string(b), bits)
if err != nil {
// strconv.ParseFloat documents that its errors have concrete type
// *strconv.NumError; unwrap it so that callers receive the bare
// underlying error (e.g. strconv.ErrRange).
err = err.(*strconv.NumError).Err
}
return fv, err == nil
return v, err
}
Loading
Loading