Skip to content

Commit

Permalink
feat: support more loose type-casting (#294)
Browse files Browse the repository at this point in the history
* feat: support more losing type cast

* test: add loose casting tests

* format

* fmt: add license

* fmt: add comments
  • Loading branch information
AsterDY authored Sep 8, 2022
1 parent a48cad8 commit 2138136
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 66 deletions.
63 changes: 63 additions & 0 deletions ast/compat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2022 ByteDance Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ast

import (
`testing`

jsoniter `github.com/json-iterator/go`
`github.com/stretchr/testify/require`
`github.com/tidwall/gjson`
)

func TestNotFoud(t *testing.T) {
data := `{}`

ia := jsoniter.Get([]byte(data), "b")
require.Error(t, ia.LastError())
require.Equal(t, false, ia.ToBool())

ga := gjson.GetBytes([]byte(data), "b")
require.True(t, ga.Type == gjson.Null)
require.Equal(t, false, ga.Bool())

sa, err := NewSearcher(data).GetByPath("b")
require.True(t, sa.Type() == V_NONE)
require.Error(t, err)
sv, err := sa.Bool()
require.Error(t, err)
require.Equal(t, false, sv)
}

func TestNull(t *testing.T) {
data := `{"b": null}`

ia := jsoniter.Get([]byte(data), "b")
require.NoError(t, ia.LastError())
require.Equal(t, false, ia.ToBool())

ga := gjson.GetBytes([]byte(data), "b")
require.True(t, ga.Type == gjson.Null)
require.Equal(t, false, ga.Bool())

sa, err := NewSearcher(data).GetByPath("b")
require.True(t, sa.Type() == V_NULL)
require.NoError(t, err)
sv, err := sa.Bool()
require.NoError(t, err)
require.Equal(t, false, sv)
}
229 changes: 179 additions & 50 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package ast
import (
`encoding/json`
`fmt`
`strconv`
`unsafe`

`github.com/bytedance/sonic/decoder`
`github.com/bytedance/sonic/internal/native/types`
`github.com/bytedance/sonic/internal/rt`
Expand All @@ -37,7 +38,7 @@ const (

const (
_V_NONE types.ValueType = 0
_V_NODE_BASE types.ValueType = 1<<5
_V_NODE_BASE types.ValueType = 1 << 5
_V_LAZY types.ValueType = 1 << 7
_V_RAW types.ValueType = 1 << 8
_V_NUMBER = _V_NODE_BASE + 1
Expand Down Expand Up @@ -165,11 +166,9 @@ func (self *Node) checkRaw() error {
return nil
}

// Bool returns bool value represented by this node
//
// If node type is not types.V_TRUE or types.V_FALSE,
// V_RAW (must be a bool json value), or V_ANY (must be a bool type)
// it will return error
// Bool returns bool value represented by this node,
// including types.V_TRUE|V_FALSE|V_NUMBER|V_STRING|V_ANY|V_NULL,
// V_NONE will return error
func (self *Node) Bool() (bool, error) {
if err := self.checkRaw(); err != nil {
return false, err
Expand All @@ -178,40 +177,97 @@ func (self *Node) Bool() (bool, error) {
case types.V_TRUE : return true , nil
case types.V_FALSE : return false, nil
case types.V_NULL : return false, nil
case _V_ANY :
if v, ok := self.packAny().(bool); ok {
return v, nil
case _V_NUMBER :
if i, err := numberToInt64(self); err == nil {
return i != 0, nil
} else if f, err := numberToFloat64(self); err == nil {
return f != 0, nil
} else {
return false, ErrUnsupportType
return false, err
}
case types.V_STRING: return strconv.ParseBool(addr2str(self.p, self.v))
case _V_ANY :
any := self.packAny()
switch v := any.(type) {
case bool : return v, nil
case int : return v != 0, nil
case int8 : return v != 0, nil
case int16 : return v != 0, nil
case int32 : return v != 0, nil
case int64 : return v != 0, nil
case uint : return v != 0, nil
case uint8 : return v != 0, nil
case uint16 : return v != 0, nil
case uint32 : return v != 0, nil
case uint64 : return v != 0, nil
case float32: return v != 0, nil
case float64: return v != 0, nil
case string : return strconv.ParseBool(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return i != 0, nil
} else if f, err := v.Float64(); err == nil {
return f != 0, nil
} else {
return false, err
}
default: return false, ErrUnsupportType
}
default : return false, ErrUnsupportType
}
}

// Int64 casts the node to int64 value, including V_NUMBER, V_TRUE, V_FALSE, V_ANY,
// V_STRING of invalid digits
// Int64 casts the node to int64 value,
// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING
// V_NONE it will return error
func (self *Node) Int64() (int64, error) {
if err := self.checkRaw(); err != nil {
return 0, err
}
switch self.t {
case _V_NUMBER, types.V_STRING : return numberToInt64(self)
case _V_NUMBER, types.V_STRING :
if i, err := numberToInt64(self); err == nil {
return i, nil
} else if f, err := numberToFloat64(self); err == nil {
return int64(f), nil
} else {
return 0, err
}
case types.V_TRUE : return 1, nil
case types.V_FALSE : return 0, nil
case types.V_NULL : return 0, nil
case _V_ANY :
any := self.packAny()
switch v := any.(type) {
case int : return int64(v), nil
case int8 : return int64(v), nil
case int16 : return int64(v), nil
case int32 : return int64(v), nil
case int64 : return int64(v), nil
case uint : return int64(v), nil
case uint8 : return int64(v), nil
case uint16: return int64(v), nil
case uint32: return int64(v), nil
case uint64: return int64(v), nil
case bool : if v { return 1, nil } else { return 0, nil }
case int : return int64(v), nil
case int8 : return int64(v), nil
case int16 : return int64(v), nil
case int32 : return int64(v), nil
case int64 : return int64(v), nil
case uint : return int64(v), nil
case uint8 : return int64(v), nil
case uint16 : return int64(v), nil
case uint32 : return int64(v), nil
case uint64 : return int64(v), nil
case float32: return int64(v), nil
case float64: return int64(v), nil
case string :
if i, err := strconv.ParseInt(v, 10, 64); err == nil {
return i, nil
} else if f, err := strconv.ParseFloat(v, 64); err == nil {
return int64(f), nil
} else {
return 0, err
}
case json.Number:
if i, err := v.Int64(); err == nil {
return i, nil
} else if f, err := v.Float64(); err == nil {
return int64(f), nil
} else {
return 0, err
}
default: return 0, ErrUnsupportType
}
default : return 0, ErrUnsupportType
Expand All @@ -238,14 +294,29 @@ func (self *Node) StrictInt64() (int64, error) {
case uint16: return int64(v), nil
case uint32: return int64(v), nil
case uint64: return int64(v), nil
case json.Number:
if i, err := v.Int64(); err == nil {
return i, nil
} else {
return 0, err
}
default: return 0, ErrUnsupportType
}
default : return 0, ErrUnsupportType
}
}

// Number casts node to float64, including V_NUMBER, V_TRUE, V_FALSE, V_ANY of json.Number,
// V_STRING of invalid digits
func castNumber(v bool) json.Number {
if v {
return json.Number("1")
} else {
return json.Number("0")
}
}

// Number casts node to float64,
// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL,
// V_NONE it will return error
func (self *Node) Number() (json.Number, error) {
if err := self.checkRaw(); err != nil {
return json.Number(""), err
Expand All @@ -264,10 +335,29 @@ func (self *Node) Number() (json.Number, error) {
case types.V_FALSE : return json.Number("0"), nil
case types.V_NULL : return json.Number("0"), nil
case _V_ANY :
if v, ok := self.packAny().(json.Number); ok {
return v, nil
} else {
return json.Number(""), ErrUnsupportType
any := self.packAny()
switch v := any.(type) {
case bool : return castNumber(v), nil
case int : return castNumber(v != 0), nil
case int8 : return castNumber(v != 0), nil
case int16 : return castNumber(v != 0), nil
case int32 : return castNumber(v != 0), nil
case int64 : return castNumber(v != 0), nil
case uint : return castNumber(v != 0), nil
case uint8 : return castNumber(v != 0), nil
case uint16 : return castNumber(v != 0), nil
case uint32 : return castNumber(v != 0), nil
case uint64 : return castNumber(v != 0), nil
case float32: return castNumber(v != 0), nil
case float64: return castNumber(v != 0), nil
case string :
if _, err := strconv.ParseFloat(v, 64); err == nil {
return json.Number(v), nil
} else {
return json.Number(""), err
}
case json.Number: return v, nil
default: return json.Number(""), ErrUnsupportType
}
default : return json.Number(""), ErrUnsupportType
}
Expand All @@ -290,29 +380,38 @@ func (self *Node) StrictNumber() (json.Number, error) {
}
}

// String returns raw string value if node type is V_STRING.
// Or return the string representation of other types:
// V_NULL => "",
// V_TRUE => "true",
// V_FALSE => "false",
// V_NUMBER => "[0-9\.]*"
// V_ANY => interface{}.(string)
// String cast node to string,
// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL,
// V_NONE it will return error
func (self *Node) String() (string, error) {
if err := self.checkRaw(); err != nil {
return "", err
}
switch self.t {
case _V_NUMBER : return toNumber(self).String(), nil
case types.V_NULL : return "" , nil
case types.V_TRUE : return "true" , nil
case types.V_FALSE : return "false", nil
case types.V_STRING : return addr2str(self.p, self.v), nil
case _V_ANY :
if v, ok := self.packAny().(string); ok {
return v, nil
} else {
return "", ErrUnsupportType
}
case types.V_STRING, _V_NUMBER : return addr2str(self.p, self.v), nil
case _V_ANY :
any := self.packAny()
switch v := any.(type) {
case bool : return strconv.FormatBool(v), nil
case int : return strconv.Itoa(v), nil
case int8 : return strconv.Itoa(int(v)), nil
case int16 : return strconv.Itoa(int(v)), nil
case int32 : return strconv.Itoa(int(v)), nil
case int64 : return strconv.Itoa(int(v)), nil
case uint : return strconv.Itoa(int(v)), nil
case uint8 : return strconv.Itoa(int(v)), nil
case uint16 : return strconv.Itoa(int(v)), nil
case uint32 : return strconv.Itoa(int(v)), nil
case uint64 : return strconv.Itoa(int(v)), nil
case float32: return strconv.FormatFloat(float64(v), 'g', -1, 64), nil
case float64: return strconv.FormatFloat(float64(v), 'g', -1, 64), nil
case string : return v, nil
case json.Number: return v.String(), nil
default: return "", ErrUnsupportType
}
default : return "" , ErrUnsupportType
}
}
Expand All @@ -335,7 +434,9 @@ func (self *Node) StrictString() (string, error) {
}
}

// Float64 casts node to float64, includeing V_NUMBER, V_TRUE, V_FALSE, V_ANY
// Float64 cast node to float64,
// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL,
// V_NONE it will return error
func (self *Node) Float64() (float64, error) {
if err := self.checkRaw(); err != nil {
return 0.0, err
Expand All @@ -348,11 +449,39 @@ func (self *Node) Float64() (float64, error) {
case _V_ANY :
any := self.packAny()
switch v := any.(type) {
case float32 : return float64(v), nil
case float64 : return float64(v), nil
default : return 0, ErrUnsupportType
case bool :
if v {
return 1.0, nil
} else {
return 0.0, nil
}
case int : return float64(v), nil
case int8 : return float64(v), nil
case int16 : return float64(v), nil
case int32 : return float64(v), nil
case int64 : return float64(v), nil
case uint : return float64(v), nil
case uint8 : return float64(v), nil
case uint16 : return float64(v), nil
case uint32 : return float64(v), nil
case uint64 : return float64(v), nil
case float32: return float64(v), nil
case float64: return float64(v), nil
case string :
if f, err := strconv.ParseFloat(v, 64); err == nil {
return float64(f), nil
} else {
return 0, err
}
case json.Number:
if f, err := v.Float64(); err == nil {
return float64(f), nil
} else {
return 0, err
}
default : return 0, ErrUnsupportType
}
default : return 0.0, ErrUnsupportType
default : return 0.0, ErrUnsupportType
}
}

Expand Down
Loading

0 comments on commit 2138136

Please sign in to comment.