diff --git a/validator/vars.go b/validator/vars.go index f20eaadc..d98e37e8 100644 --- a/validator/vars.go +++ b/validator/vars.go @@ -1,6 +1,7 @@ package validator import ( + "encoding/json" "fmt" "reflect" "strconv" @@ -29,6 +30,7 @@ func VariableValues(schema *ast.Schema, op *ast.OperationDefinition, variables m } val, hasValue := variables[v.Variable] + if !hasValue { if v.DefaultValue != nil { var err error @@ -50,6 +52,24 @@ func VariableValues(schema *ast.Schema, op *ast.OperationDefinition, variables m coercedVars[v.Variable] = nil } else { rv := reflect.ValueOf(val) + + jsonNumber, isJsonNumber := val.(json.Number) + if isJsonNumber { + if v.Type.NamedType == "Int" { + n, err := jsonNumber.Int64() + if err != nil { + return nil, gqlerror.ErrorPathf(validator.path, "cannot use value %s as %s", n, v.Type.NamedType) + } + rv = reflect.ValueOf(n) + } else if v.Type.NamedType == "Float" { + f, err := jsonNumber.Float64() + if err != nil { + return nil, gqlerror.ErrorPathf(validator.path, "cannot use value %f as %s", f, v.Type.NamedType) + } + rv = reflect.ValueOf(f) + + } + } if rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { rv = rv.Elem() } diff --git a/validator/vars_test.go b/validator/vars_test.go index 5836678d..1c27b019 100644 --- a/validator/vars_test.go +++ b/validator/vars_test.go @@ -1,7 +1,6 @@ package validator_test import ( - "encoding/json" "io/ioutil" "testing" @@ -285,10 +284,19 @@ func TestValidateVars(t *testing.T) { t.Run("Json Number -> Int", func(t *testing.T) { q := gqlparser.MustLoadQuery(schema, `query foo($var: Int) { optionalIntArg(i: $var) }`) vars, gerr := validator.VariableValues(schema, q.Operations.ForName(""), map[string]interface{}{ - "var": json.Number("10"), + "var": 10, }) require.Nil(t, gerr) - require.Equal(t, json.Number("10"), vars["var"]) + require.Equal(t, 10, vars["var"]) + }) + + t.Run("Json Number -> Float", func(t *testing.T) { + q := gqlparser.MustLoadQuery(schema, `query foo($var: Float!) { floatArg(i: $var) }`) + vars, gerr := validator.VariableValues(schema, q.Operations.ForName(""), map[string]interface{}{ + "var": 10.2, + }) + require.Nil(t, gerr) + require.Equal(t, 10.2, vars["var"]) }) t.Run("Nil -> Int", func(t *testing.T) {