Skip to content

Commit

Permalink
Deal with ints and uints
Browse files Browse the repository at this point in the history
  • Loading branch information
morris-kelly committed Jul 12, 2024
1 parent 84014b3 commit b4e7840
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
20 changes: 19 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node)
if !v.Type().ConvertibleTo(typ) {

// Special case for "strings -> floats" aka scientific notation
// If the destination type is a float and the source type is a string, check if we can
// If the destination type is a float and the source type is a string, check if we can
// use strconv.ParseFloat to convert the string to a float.
if (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) &&
v.Type().Kind() == reflect.String {
Expand Down Expand Up @@ -892,6 +892,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
dst.SetInt(int64(vv))
return nil
}
case string: // handle scientific notation
if i, err := strconv.ParseFloat(vv, 64); err == nil {
if 0 <= i && i <= math.MaxUint64 && !dst.OverflowInt(int64(i)) {
dst.SetInt(int64(i))
return nil
}
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
Expand All @@ -914,6 +923,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
dst.SetUint(uint64(vv))
return nil
}
case string: // handle scientific notation
if i, err := strconv.ParseFloat(vv, 64); err == nil {
if 0 <= i && i <= math.MaxUint64 && !dst.OverflowUint(uint64(i)) {
dst.SetUint(uint64(i))
return nil
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
Expand Down
75 changes: 75 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ func TestDecoder(t *testing.T) {
"v: 4294967295",
map[string]uint{"v": math.MaxUint32},
},
{
"v: 1e3",
map[string]uint{"v": 1000},
},

// uint64
{
Expand All @@ -271,6 +275,10 @@ func TestDecoder(t *testing.T) {
"v: 9223372036854775807",
map[string]uint64{"v": math.MaxInt64},
},
{
"v: 1e3",
map[string]uint64{"v": 1000},
},

// float32
{
Expand Down Expand Up @@ -1101,6 +1109,73 @@ c:
}
}

func TestDecoder_ScientificNotation(t *testing.T) {
tests := []struct {
source string
value interface{}
}{
{
"v: 1e3",
map[string]uint{"v": 1000},
},
{
"v: 1e-3",
map[string]uint{"v": 0},
},
{
"v: 1e3",
map[string]int{"v": 1000},
},
{
"v: 1e-3",
map[string]int{"v": 0},
},
{
"v: 1e3",
map[string]float32{"v": 1000},
},
{
"v: 1.0e3",
map[string]float64{"v": 1000},
},
{
"v: 1e-3",
map[string]float64{"v": 0.001},
},
{
"v: 1.0e-3",
map[string]float64{"v": 0.001},
},
{
"v: 1.0e+3",
map[string]float64{"v": 1000},
},
{
"v: 1.0e+3",
map[string]float64{"v": 1000},
},
}
for _, test := range tests {
t.Run(test.source, func(t *testing.T) {
buf := bytes.NewBufferString(test.source)
dec := yaml.NewDecoder(buf)
typ := reflect.ValueOf(test.value).Type()
value := reflect.New(typ)
if err := dec.Decode(value.Interface()); err != nil {
if err == io.EOF {
return
}
t.Fatalf("%s: %+v", test.source, err)
}
actual := fmt.Sprintf("%+v", value.Elem().Interface())
expect := fmt.Sprintf("%+v", test.value)
if actual != expect {
t.Fatalf("failed to test [%s], actual=[%s], expect=[%s]", test.source, actual, expect)
}
})
}
}

func TestDecoder_TypeConversionError(t *testing.T) {
t.Run("type conversion for struct", func(t *testing.T) {
type T struct {
Expand Down

0 comments on commit b4e7840

Please sign in to comment.