diff --git a/functions.go b/functions.go index e9770e8..924ff73 100644 --- a/functions.go +++ b/functions.go @@ -361,7 +361,7 @@ func (a *argSpec) typeCheck(arg interface{}) error { return nil } case jpObject: - if _, ok := arg.(map[string]interface{}); ok { + if isObject(arg) { return nil } case jpArrayNumber: @@ -412,7 +412,7 @@ func jpfLength(arguments []interface{}) (interface{}, error) { } else if isSliceType(arg) { v := reflect.ValueOf(arg) return float64(v.Len()), nil - } else if c, ok := arg.(map[string]interface{}); ok { + } else if c := toObject(arg); c != nil { return float64(len(c)), nil } return nil, errors.New("could not compute length()") @@ -516,7 +516,7 @@ func jpfMax(arguments []interface{}) (interface{}, error) { func jpfMerge(arguments []interface{}) (interface{}, error) { final := make(map[string]interface{}) for _, m := range arguments { - mapped := m.(map[string]interface{}) + mapped := toObject(m) for key, value := range mapped { final[key] = value } @@ -696,7 +696,7 @@ func jpfType(arguments []interface{}) (interface{}, error) { return nil, errors.New("unknown type") } func jpfKeys(arguments []interface{}) (interface{}, error) { - arg := arguments[0].(map[string]interface{}) + arg := toObject(arguments[0]) collected := make([]interface{}, 0, len(arg)) for key := range arg { collected = append(collected, key) @@ -704,7 +704,7 @@ func jpfKeys(arguments []interface{}) (interface{}, error) { return collected, nil } func jpfValues(arguments []interface{}) (interface{}, error) { - arg := arguments[0].(map[string]interface{}) + arg := toObject(arguments[0]) collected := make([]interface{}, 0, len(arg)) for _, value := range arg { collected = append(collected, value) diff --git a/interpreter.go b/interpreter.go index 13c7460..ee571d6 100644 --- a/interpreter.go +++ b/interpreter.go @@ -3,8 +3,6 @@ package jmespath import ( "errors" "reflect" - "unicode" - "unicode/utf8" ) /* This is a tree based interpreter. It walks the AST and directly @@ -76,11 +74,11 @@ func (intr *treeInterpreter) Execute(node ASTNode, value interface{}) (interface } return intr.fCall.CallFunction(node.value.(string), resolvedArgs, intr) case ASTField: - if m, ok := value.(map[string]interface{}); ok { + if m := toObject(value); m != nil { key := node.value.(string) return m[key], nil } - return intr.fieldFromStruct(node.value.(string), value) + return nil, nil case ASTFilterProjection: left, err := intr.Execute(node.children[0], value) if err != nil { @@ -291,8 +289,8 @@ func (intr *treeInterpreter) Execute(node ASTNode, value interface{}) (interface if err != nil { return nil, nil } - mapType, ok := left.(map[string]interface{}) - if !ok { + mapType := toObject(left) + if mapType == nil { return nil, nil } values := make([]interface{}, len(mapType)) @@ -314,31 +312,6 @@ func (intr *treeInterpreter) Execute(node ASTNode, value interface{}) (interface return nil, errors.New("Unknown AST node: " + node.nodeType.String()) } -func (intr *treeInterpreter) fieldFromStruct(key string, value interface{}) (interface{}, error) { - rv := reflect.ValueOf(value) - first, n := utf8.DecodeRuneInString(key) - fieldName := string(unicode.ToUpper(first)) + key[n:] - if rv.Kind() == reflect.Struct { - v := rv.FieldByName(fieldName) - if !v.IsValid() { - return nil, nil - } - return v.Interface(), nil - } else if rv.Kind() == reflect.Ptr { - // Handle multiple levels of indirection? - if rv.IsNil() { - return nil, nil - } - rv = rv.Elem() - v := rv.FieldByName(fieldName) - if !v.IsValid() { - return nil, nil - } - return v.Interface(), nil - } - return nil, nil -} - func (intr *treeInterpreter) flattenWithReflection(value interface{}) (interface{}, error) { v := reflect.ValueOf(value) flattened := []interface{}{} diff --git a/interpreter_test.go b/interpreter_test.go index 76e0c5d..1fde97e 100644 --- a/interpreter_test.go +++ b/interpreter_test.go @@ -115,25 +115,6 @@ func TestCanSupportStructWithSlicePointer(t *testing.T) { assert.Equal("correct", result) } -func TestWillAutomaticallyCapitalizeFieldNames(t *testing.T) { - assert := assert.New(t) - s := scalars{Foo: "one", Bar: "bar"} - // Note that there's a lower cased "foo" instead of "Foo", - // but it should still correspond to the Foo field in the - // scalars struct - result, err := Search("foo", &s) - assert.Nil(err) - assert.Equal("one", result) -} - -func TestCanSupportStructWithSliceLowerCased(t *testing.T) { - assert := assert.New(t) - data := sliceType{A: "foo", B: []scalars{{"f1", "b1"}, {"correct", "b2"}}} - result, err := Search("b[-1].foo", data) - assert.Nil(err) - assert.Equal("correct", result) -} - func TestCanSupportStructWithNestedPointers(t *testing.T) { assert := assert.New(t) data := struct{ A *struct{ B int } }{} diff --git a/object.go b/object.go new file mode 100644 index 0000000..db33858 --- /dev/null +++ b/object.go @@ -0,0 +1,99 @@ +package jmespath + +import ( + "reflect" + "strings" +) + +type objectKind int + +const ( + objectKindNone objectKind = iota + objectKindStruct + objectKindMapStringInterface + objectKindMapStringOther +) + +func getObjectKind(value interface{}) (objectKind, reflect.Value) { + rv := reflect.Indirect(reflect.ValueOf(value)) + if rv.Kind() == reflect.Struct { + return objectKindStruct, rv + } + if rv.Kind() == reflect.Map { + rt := rv.Type() + if rt.Key().Kind() == reflect.String { + if rt.Elem().Kind() == reflect.Interface { + return objectKindMapStringInterface, rv + } + return objectKindMapStringOther, rv + } + } + return objectKindNone, rv +} + +func isObject(value interface{}) bool { + kind, _ := getObjectKind(value) + return kind != objectKindNone +} + +func toObject(value interface{}) map[string]interface{} { + kind, rv := getObjectKind(value) + switch kind { + case objectKindStruct: + // This does not flatten fields from anonymous embedded structs into the top-level struct + // the way the encoding/json package does, as this is quite complicated. These fields can + // still be accessed by specifying the full path to the embedded field. See the typeFields() + // function in https://go.dev/src/encoding/json/encode.go if you feel the need to do add + // flattening functionality. + ret := make(map[string]interface{}) + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + f := rt.Field(i) + if f.IsExported() { + key := f.Name + if t, ok := f.Tag.Lookup("jmes"); ok { + switch t { + case "": + // Leave the key set to the field name + break + case "-": + // Skip this field + continue + default: + // Set the key to the tag value + key = t + } + } else if t, ok := f.Tag.Lookup("json"); ok { + switch t { + case "", "-": + // Leave the key set to the field name + break + default: + if i := strings.IndexByte(t, ','); i >= 0 { + if i != 0 { + // Set the key to the tag value up to the comma + key = t[:i] + } // else leave the key set to the field name + } else { + // Set the key to the tag value + key = t + } + } + } + ret[key] = rv.Field(i).Interface() + } + } + return ret + case objectKindMapStringInterface: + return value.(map[string]interface{}) + case objectKindMapStringOther: + ret := make(map[string]interface{}) + iter := rv.MapRange() + for iter.Next() { + ret[iter.Key().String()] = iter.Value().Interface() + } + return ret + default: + return nil + } +}