diff --git a/encode.go b/encode.go index db1d775..8ace28b 100644 --- a/encode.go +++ b/encode.go @@ -52,7 +52,7 @@ func marshal(v interface{}, flags encodeFlags) (types.AttributeValue, error) { if err != nil { return nil, err } - enc, err := def.encodeType(rt, flags) + enc, err := def.encodeType(rt, flags, nil) if err != nil { return nil, err } @@ -90,7 +90,9 @@ func encodeItem(fields []structField, rv reflect.Value) (Item, error) { continue } } - + if field.enc == nil { + continue + } av, err := field.enc(fv, field.flags) if err != nil { return nil, err @@ -165,27 +167,11 @@ func isZeroIface[T any](rt reflect.Type, isZero func(v T) bool) func(rv reflect. } func (def *typedef) isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool { - fields, err := def.structFields(rt, false) - if err != nil { - return nil - } - return func(rv reflect.Value) bool { - for _, info := range *fields { - if info.isZero == nil { - continue - } - - field := dig(rv, info.index) - if !field.IsValid() { - return true - } - - if !info.isZero(field) { - return false - } - } - return true + if fn := def.info.findZero(rt); fn != nil { + return fn } + child, _ := def.structInfo(rt, def.info) + return child.isZero } func (def *typedef) isZeroArray(rt reflect.Type) func(reflect.Value) bool { diff --git a/encode_test.go b/encode_test.go index fe95dc2..0ceb94d 100644 --- a/encode_test.go +++ b/encode_test.go @@ -193,73 +193,3 @@ func TestMarshalItemBypass(t *testing.T) { t.Error("bad unmarshal") } } - -func TestMarshalRecursive(t *testing.T) { - t.SkipNow() - - type Person struct { - Spouse *Person - Children []Person - Name string - } - type Friend struct { - ID int - Person Person - Nickname string - } - children := []Person{ - {Name: "Bobby"}, - } - - hank := Person{ - Spouse: &Person{ - Name: "Peggy", - Children: children, - }, - Children: children, - Name: "Hank", - } - - t.Run("self-recursive", func(t *testing.T) { - - want := map[string]types.AttributeValue{ - "Name": &types.AttributeValueMemberS{Value: "Hank"}, - "Spouse": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ - "Name": &types.AttributeValueMemberS{Value: "Peggy"}, - "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ - &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ - "Name": &types.AttributeValueMemberS{Value: "Bobby"}, - "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, - }}, - }, - }, - }}, - "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{ - &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ - "Name": &types.AttributeValueMemberS{Value: "Bobby"}, - "Children": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, - }}, - }}, - } - - got, err := MarshalItem(hank) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, want) { - t.Error("bad", got) - } - }) - - t.Run("field is recursive", func(t *testing.T) { - friend := Friend{ - Person: hank, - Nickname: "H-Dawg", - } - got, err := MarshalItem(friend) - if err != nil { - t.Fatal(err) - } - t.Fatal(got) - }) -} diff --git a/encodefunc.go b/encodefunc.go index fad4c2f..a1a793c 100644 --- a/encodefunc.go +++ b/encodefunc.go @@ -13,7 +13,12 @@ import ( type encodeFunc func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) -func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { + encKey := encodeKey{rt, flags} + if fn := info.findEncoder(encKey); fn != nil { + return fn, nil + } + try := rt for { switch try { @@ -122,7 +127,7 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, switch rt.Kind() { case reflect.Pointer: - return def.encodePtr(rt, flags) + return def.encodePtr(rt, flags, info) // BOOL case reflect.Bool: @@ -152,7 +157,7 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, return encodeSet(rt, flags) } // lists (L) - return def.encodeList(rt, flags) + return def.encodeList(rt, flags, info) case reflect.Map: // sets (NS, SS, BS) @@ -160,11 +165,11 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, return encodeSet(rt, flags) } // M - return def.encodeMapM(rt, flags) + return def.encodeMapM(rt, flags, info) // M case reflect.Struct: - return def.encodeStruct(rt) + return def.encodeStruct(rt, flags, info) case reflect.Interface: if rt.NumMethod() == 0 { @@ -174,8 +179,8 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, return nil, fmt.Errorf("dynamo marshal: unsupported type %s", rt.String()) } -func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { - elem, err := def.encodeType(rt.Elem(), flags) +func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { + elem, err := def.encodeType(rt.Elem(), flags, info) if err != nil { return nil, err } @@ -278,23 +283,19 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { } } -func (def *typedef) encodeStruct(rt reflect.Type) (encodeFunc, error) { - var fields *[]structField - var err error - if def.sameAsRoot(rt) { - fields, err = def.structFields(rt, false) - } else { - var subdef *typedef - subdef, err = typedefOf(rt) - if subdef != nil { - fields = &subdef.fields - } - } +func (def *typedef) encodeStruct(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { + info2, err := def.structInfo(rt, info) if err != nil { return nil, err } + + var fields []structField + for _, field := range info2.fields { + fields = append(fields, *field) + } + return func(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { - item, err := encodeItem(*fields, rv) + item, err := encodeItem(fields, rv) if err != nil { return nil, err } @@ -381,7 +382,7 @@ func encodeSliceBS(rv reflect.Value, flags encodeFlags) (types.AttributeValue, e return &types.AttributeValueMemberBS{Value: bs}, nil } -func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { keyString := encodeMapKeyFunc(rt) if keyString == nil { return nil, fmt.Errorf("dynamo marshal: map key type must be string or encoding.TextMarshaler, have %v", rt) @@ -397,7 +398,7 @@ func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, subflags |= flagOmitEmpty } - valueEnc, err := def.encodeType(rt.Elem(), subflags) + valueEnc, err := def.encodeType(rt.Elem(), subflags, info) if err != nil { return nil, err } @@ -594,7 +595,7 @@ func encodeSet(rt /* []T | map[T]bool | map[T]struct{} */ reflect.Type, flags en return nil, fmt.Errorf("dynamo: marshal: invalid type for set %s", rt.String()) } -func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { +func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) { // lists CAN be empty subflags := flagNone if flags&flagOmitEmptyElem == 0 { @@ -608,7 +609,7 @@ func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, subflags |= flagAllowEmptyElem } - valueEnc, err := def.encodeType(rt.Elem(), subflags) + valueEnc, err := def.encodeType(rt.Elem(), subflags, info) if err != nil { return nil, err } @@ -645,7 +646,7 @@ func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (types.Attrib } return nil, nil } - enc, err := def.encodeType(rv.Elem().Type(), flags) + enc, err := def.encodeType(rv.Elem().Type(), flags, nil) if err != nil { return nil, err } diff --git a/encoding.go b/encoding.go index 599dd08..5077b0c 100644 --- a/encoding.go +++ b/encoding.go @@ -16,6 +16,7 @@ type typedef struct { decoders map[unmarshalKey]decodeFunc fields []structField root reflect.Type + info *structInfo } func newTypedef(rt reflect.Type) (*typedef, error) { @@ -45,11 +46,15 @@ func (def *typedef) init(rt reflect.Type) error { return nil } - fieldptr, err := def.structFields(rt, true) - if fieldptr != nil { - def.fields = *fieldptr + info, err := def.structInfo(rt, nil) + if err != nil { + return err + } + for _, field := range info.fields { + def.fields = append(def.fields, *field) } - return err + def.info = info + return nil } func registerTypedef(gotype reflect.Type, def *typedef) *typedef { @@ -98,7 +103,7 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) { case reflect.Struct: return encodeItem(def.fields, rv) case reflect.Map: - enc, err := def.encodeMapM(rv.Type(), flagNone) + enc, err := def.encodeMapM(rv.Type(), flagNone, def.info) if err != nil { return nil, err } @@ -448,47 +453,6 @@ type structField struct { isZero func(reflect.Value) bool } -// type encodeKey struct { -// rt reflect.Type -// flags encodeFlags -// } - -func (def *typedef) sameAsRoot(rt reflect.Type) bool { - switch { - case rt == def.root: - return true - case def.root.Kind() == reflect.Pointer && rt.Kind() != reflect.Pointer: - return def.root.Elem() == rt - case def.root.Kind() != reflect.Pointer && rt.Kind() == reflect.Pointer: - return rt.Elem() == def.root - } - return false -} - -func (def *typedef) structFields(rt reflect.Type, isRoot bool) (*[]structField, error) { - if !isRoot && def.sameAsRoot(rt) { - return &def.fields, nil - } - - var fields []structField - err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error { - enc, err := def.encodeType(vt, flags) - if err != nil { - return err - } - field := structField{ - index: index, - name: name, - flags: flags, - enc: enc, - isZero: def.isZeroFunc(vt), - } - fields = append(fields, field) - return nil - }) - return &fields, err -} - var ( nullAV = &types.AttributeValueMemberNULL{Value: true} emptyB = &types.AttributeValueMemberB{Value: []byte("")} diff --git a/encoding_test.go b/encoding_test.go index f266f1e..d07c512 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -779,6 +779,78 @@ var itemEncodingTests = []struct { }}, }, }, + { + name: "mega recursion A -> B -> *A -> B", + in: MegaRecursiveA{ + ID: 123, + Text: "hello", + Child: MegaRecursiveB{ + ID: "test", + Blah: 555, + Child: &MegaRecursiveA{ + ID: 222, + Text: "help", + Child: MegaRecursiveB{ + ID: "why", + Blah: 1337, + }, + Friends: []MegaRecursiveA{}, + Enemies: []MegaRecursiveB{}, + }, + }, + Friends: []MegaRecursiveA{ + {ID: 1, Text: "suffering", Child: MegaRecursiveB{ID: "pain"}, Friends: []MegaRecursiveA{}, Enemies: []MegaRecursiveB{}}, + {ID: 2, Text: "love", Child: MegaRecursiveB{ID: "understanding"}, Friends: []MegaRecursiveA{}, Enemies: []MegaRecursiveB{}}, + }, + Enemies: []MegaRecursiveB{ + {ID: "recursion", Blah: 30}, + }, + }, + out: Item{ + "ID": &types.AttributeValueMemberN{Value: "123"}, + "Text": &types.AttributeValueMemberS{Value: "hello"}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "1"}, + "Text": &types.AttributeValueMemberS{Value: "suffering"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "pain"}, + }}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "2"}, + "Text": &types.AttributeValueMemberS{Value: "love"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "understanding"}, + }}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + }}, + }}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "recursion"}, + "Blah": &types.AttributeValueMemberN{Value: "30"}, + }}, + }}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "test"}, + "Blah": &types.AttributeValueMemberN{Value: "555"}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberN{Value: "222"}, + "Text": &types.AttributeValueMemberS{Value: "help"}, + "Friends": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Enemies": &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + "Child": &types.AttributeValueMemberM{Value: Item{ + "ID": &types.AttributeValueMemberS{Value: "why"}, + "Blah": &types.AttributeValueMemberN{Value: "1337"}, + }}, + }}, + }}, + }, + }, } type embedded struct { @@ -879,6 +951,20 @@ type Friend struct { Nickname string } +type MegaRecursiveA struct { + ID int + Child MegaRecursiveB + Text string + Friends []MegaRecursiveA + Enemies []MegaRecursiveB +} + +type MegaRecursiveB struct { + ID string + Child *MegaRecursiveA + Blah int `dynamo:",omitempty"` +} + func byteSlicePtr(a []byte) *[]byte { return &a } diff --git a/query.go b/query.go index 3fbebf1..a2eacbc 100644 --- a/query.go +++ b/query.go @@ -187,10 +187,10 @@ func (q *Query) SearchLimit(limit int) *Query { } // RequestLimit specifies the maximum amount of requests to make against DynamoDB's API. -func (q *Query) RequestLimit(limit int) *Query { - q.reqLimit = limit - return q -} +// func (q *Query) RequestLimit(limit int) *Query { +// q.reqLimit = limit +// return q +// } // Order specifies the desired result order. // Requires a range key (a.k.a. partition key) to be specified. diff --git a/reflect.go b/reflect.go index 64debe8..cb3bc36 100644 --- a/reflect.go +++ b/reflect.go @@ -194,6 +194,201 @@ func visitFields(item map[string]types.AttributeValue, rv reflect.Value, seen ma return nil } +type encodeKey struct { + rt reflect.Type + flags encodeFlags +} + +type structInfo struct { + root reflect.Type + fields map[string]*structField // by name + refs map[encodeKey][]*structField + types map[encodeKey]encodeFunc + zeros map[reflect.Type]func(reflect.Value) bool + parent *structInfo + + seen map[encodeKey]struct{} + queue []encodeKey +} + +func (info *structInfo) encode(rv reflect.Value, flags encodeFlags) (types.AttributeValue, error) { + item := make(Item, len(info.fields)) + for _, field := range info.fields { + fv := dig(rv, field.index) + if !fv.IsValid() { + // TODO: encode NULL? + continue + } + + if field.flags&flagOmitEmpty != 0 && field.isZero != nil { + if field.isZero(fv) { + continue + } + } + + av, err := field.enc(fv, field.flags) + if err != nil { + return nil, err + } + if av == nil { + if field.flags&flagNull != 0 { + item[field.name] = nullAV + } + continue + } + item[field.name] = av + } + return &types.AttributeValueMemberM{Value: item}, nil +} + +func (info *structInfo) isZero(rv reflect.Value) bool { + if info == nil { + return false + } + for _, field := range info.fields { + fv := dig(rv, field.index) + if !fv.IsValid() { + // TODO: encode NULL? + continue + } + if !field.isZero(fv) { + return false + } + } + return true +} + +func (info *structInfo) findEncoder(key encodeKey) encodeFunc { + if info == nil { + return nil + } + if key.rt == info.root { + return info.encode + } + if enc, ok := info.types[key]; ok { + return enc + } + return info.parent.findEncoder(key) +} + +func (info *structInfo) findZero(rt reflect.Type) func(reflect.Value) bool { + if info == nil { + return nil + } + if rt == info.root { + return info.isZero + } + if isZero, ok := info.zeros[rt]; ok { + return isZero + } + return info.parent.findZero(rt) +} + +func (def *typedef) structInfo(rt reflect.Type, parent *structInfo) (*structInfo, error) { + rti := rt + for rti.Kind() == reflect.Pointer { + rti = rti.Elem() + } + if rti.Kind() != reflect.Struct { + return nil, nil + } + + info := &structInfo{ + root: rt, + parent: parent, + fields: make(map[string]*structField), + refs: make(map[encodeKey][]*structField), + types: make(map[encodeKey]encodeFunc), + zeros: make(map[reflect.Type]func(reflect.Value) bool), + seen: make(map[encodeKey]struct{}), + } + + collectTypes(rt, info, nil) + + for _, key := range info.queue { + fn, err := def.encodeType(key.rt, key.flags, info) + if err != nil { + return info, err + } + isZero := info.findZero(key.rt) + if isZero == nil { + isZero = def.isZeroFunc(key.rt) + } + for _, sf := range info.refs[key] { + sf.enc = fn + sf.isZero = isZero + } + info.types[key] = fn + info.zeros[key.rt] = isZero + } + + // don't need these anymore + info.queue = nil + info.seen = nil + + return info, nil +} + +func collectTypes(rt reflect.Type, info *structInfo, trail []int) *structInfo { + for rt.Kind() == reflect.Pointer { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + panic("not a struct") + } + + // fields := make(map[string]reflect.Value) + for i := 0; i < rt.NumField(); i++ { + field := rt.Field(i) + ft := field.Type + isPtr := ft.Kind() == reflect.Ptr + + name, flags := fieldInfo(field) + if name == "-" { + // skip + continue + } + + key := encodeKey{ + rt: ft, + flags: flags, + } + + idx := field.Index + if len(trail) > 0 { + idx = append(trail, idx...) + } + + sf := &structField{ + index: idx, + name: name, + flags: flags, + } + public := field.IsExported() + if _, ok := info.fields[name]; !ok { + if public { + info.fields[name] = sf + } + info.refs[key] = append(info.refs[key], sf) + } + + // embed anonymous structs, they could be pointers so test that too + if (ft.Kind() == reflect.Struct || isPtr && ft.Elem().Kind() == reflect.Struct) && field.Anonymous { + collectTypes(ft, info, idx) + continue + } + + if !public { + continue + } + if _, ok := info.seen[key]; ok { + continue + } + info.queue = append(info.queue, key) + } + return info +} + func visitTypeFields(rt reflect.Type, seen map[string]struct{}, trail []int, fn func(name string, index []int, flags encodeFlags, vt reflect.Type) error) error { for rt.Kind() == reflect.Pointer { rt = rt.Elem() @@ -263,7 +458,7 @@ func reallocMap(v reflect.Value, size int) { type decodeKeyFunc func(reflect.Value, string) error func decodeMapKeyFunc(rt reflect.Type) decodeKeyFunc { - if reflect.PtrTo(rt.Key()).Implements(rtypeTextUnmarshaler) { + if reflect.PointerTo(rt.Key()).Implements(rtypeTextUnmarshaler) { return func(kv reflect.Value, s string) error { tm := kv.Interface().(encoding.TextUnmarshaler) if err := tm.UnmarshalText([]byte(s)); err != nil { diff --git a/scan.go b/scan.go index 796b2c5..58b1655 100644 --- a/scan.go +++ b/scan.go @@ -138,10 +138,10 @@ func (s *Scan) SearchLimit(limit int) *Scan { } // RequestLimit specifies the maximum amount of requests to make against DynamoDB's API. -func (s *Scan) RequestLimit(limit int) *Scan { - s.reqLimit = limit - return s -} +// func (s *Scan) RequestLimit(limit int) *Scan { +// s.reqLimit = limit +// return s +// } // ConsumedCapacity will measure the throughput capacity consumed by this operation and add it to cc. func (s *Scan) ConsumedCapacity(cc *ConsumedCapacity) *Scan {