Skip to content

Commit

Permalink
support A -> B -> A recursive types (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed May 5, 2024
1 parent d5fb452 commit b2544f3
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 164 deletions.
30 changes: 8 additions & 22 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func marshal(v interface{}, flags encodeFlags) (*dynamodb.AttributeValue, error)
if err != nil {
return nil, err
}
enc, err := def.encodeType(rt, flags)
enc, err := def.encodeType(rt, flags, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -90,7 +90,9 @@ func encodeItem(fields []structField, rv reflect.Value) (Item, error) {
continue
}
}

if field.enc == nil {
continue
}
av, err := field.enc(fv, field.flags)
if err != nil {
return nil, err
Expand Down Expand Up @@ -165,27 +167,11 @@ func isZeroIface[T any](rt reflect.Type, isZero func(v T) bool) func(rv reflect.
}

func (def *typedef) isZeroStruct(rt reflect.Type) func(rv reflect.Value) bool {
fields, err := def.structFields(rt, false)
if err != nil {
return nil
}
return func(rv reflect.Value) bool {
for _, info := range *fields {
if info.isZero == nil {
continue
}

field := dig(rv, info.index)
if !field.IsValid() {
return true
}

if !info.isZero(field) {
return false
}
}
return true
if fn := def.info.findZero(rt); fn != nil {
return fn
}
child, _ := def.structInfo(rt, def.info)
return child.isZero
}

func (def *typedef) isZeroArray(rt reflect.Type) func(reflect.Value) bool {
Expand Down
70 changes: 0 additions & 70 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,73 +187,3 @@ func TestMarshalItemBypass(t *testing.T) {
t.Error("bad unmarshal")
}
}

func TestMarshalRecursive(t *testing.T) {
t.SkipNow()

type Person struct {
Spouse *Person
Children []Person
Name string
}
type Friend struct {
ID int
Person Person
Nickname string
}
children := []Person{
{Name: "Bobby"},
}

hank := Person{
Spouse: &Person{
Name: "Peggy",
Children: children,
},
Children: children,
Name: "Hank",
}

t.Run("self-recursive", func(t *testing.T) {

want := map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Hank")},
"Spouse": {M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Peggy")},
"Children": {L: []*dynamodb.AttributeValue{
{M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Bobby")},
"Children": {L: []*dynamodb.AttributeValue{}},
}},
},
},
}},
"Children": {L: []*dynamodb.AttributeValue{
{M: map[string]*dynamodb.AttributeValue{
"Name": {S: aws.String("Bobby")},
"Children": {L: []*dynamodb.AttributeValue{}},
}},
}},
}

got, err := MarshalItem(hank)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Error("bad", got)
}
})

t.Run("field is recursive", func(t *testing.T) {
friend := Friend{
Person: hank,
Nickname: "H-Dawg",
}
got, err := MarshalItem(friend)
if err != nil {
t.Fatal(err)
}
t.Fatal(got)
})
}
50 changes: 25 additions & 25 deletions encodefunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ import (

type encodeFunc func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error)

func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) {
encKey := encodeKey{rt, flags}
if fn := info.findEncoder(encKey); fn != nil {
return fn, nil
}

try := rt
for {
switch try {
Expand Down Expand Up @@ -53,7 +58,7 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc,

switch rt.Kind() {
case reflect.Pointer:
return def.encodePtr(rt, flags)
return def.encodePtr(rt, flags, info)

// BOOL
case reflect.Bool:
Expand Down Expand Up @@ -83,19 +88,19 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc,
return encodeSet(rt, flags)
}
// lists (L)
return def.encodeList(rt, flags)
return def.encodeList(rt, flags, info)

case reflect.Map:
// sets (NS, SS, BS)
if flags&flagSet != 0 {
return encodeSet(rt, flags)
}
// M
return def.encodeMapM(rt, flags)
return def.encodeMapM(rt, flags, info)

// M
case reflect.Struct:
return def.encodeStruct(rt)
return def.encodeStruct(rt, flags, info)

case reflect.Interface:
if rt.NumMethod() == 0 {
Expand All @@ -105,8 +110,8 @@ func (def *typedef) encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc,
return nil, fmt.Errorf("dynamo marshal: unsupported type %s", rt.String())
}

func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
elem, err := def.encodeType(rt.Elem(), flags)
func (def *typedef) encodePtr(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) {
elem, err := def.encodeType(rt.Elem(), flags, info)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -209,24 +214,19 @@ func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc {
}
}

func (def *typedef) encodeStruct(rt reflect.Type) (encodeFunc, error) {
var fields *[]structField
var err error
if def.sameAsRoot(rt) {
fields, err = def.structFields(rt, false)
} else {
var subdef *typedef
subdef, err = typedefOf(rt)
if subdef != nil {
fields = &subdef.fields
}
}
func (def *typedef) encodeStruct(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) {
info2, err := def.structInfo(rt, info)
if err != nil {
return nil, err
}

var fields []structField
for _, field := range info2.fields {
fields = append(fields, *field)
}

return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) {
item, err := encodeItem(*fields, rv)
item, err := encodeItem(fields, rv)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,7 +313,7 @@ func encodeSliceBS(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValu
return &dynamodb.AttributeValue{BS: bs}, nil
}

func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) {
keyString := encodeMapKeyFunc(rt)
if keyString == nil {
return nil, fmt.Errorf("dynamo marshal: map key type must be string or encoding.TextMarshaler, have %v", rt)
Expand All @@ -329,7 +329,7 @@ func (def *typedef) encodeMapM(rt reflect.Type, flags encodeFlags) (encodeFunc,
subflags |= flagOmitEmpty
}

valueEnc, err := def.encodeType(rt.Elem(), subflags)
valueEnc, err := def.encodeType(rt.Elem(), subflags, info)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -526,7 +526,7 @@ func encodeSet(rt /* []T | map[T]bool | map[T]struct{} */ reflect.Type, flags en
return nil, fmt.Errorf("dynamo: marshal: invalid type for set %s", rt.String())
}

func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc, error) {
func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags, info *structInfo) (encodeFunc, error) {
// lists CAN be empty
subflags := flagNone
if flags&flagOmitEmptyElem == 0 {
Expand All @@ -540,7 +540,7 @@ func (def *typedef) encodeList(rt reflect.Type, flags encodeFlags) (encodeFunc,
subflags |= flagAllowEmptyElem
}

valueEnc, err := def.encodeType(rt.Elem(), subflags)
valueEnc, err := def.encodeType(rt.Elem(), subflags, info)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -577,7 +577,7 @@ func (def *typedef) encodeAny(rv reflect.Value, flags encodeFlags) (*dynamodb.At
}
return nil, nil
}
enc, err := def.encodeType(rv.Elem().Type(), flags)
enc, err := def.encodeType(rv.Elem().Type(), flags, nil)
if err != nil {
return nil, err
}
Expand Down
56 changes: 10 additions & 46 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type typedef struct {
decoders map[unmarshalKey]decodeFunc
fields []structField
root reflect.Type
info *structInfo
}

func newTypedef(rt reflect.Type) (*typedef, error) {
Expand Down Expand Up @@ -46,11 +47,15 @@ func (def *typedef) init(rt reflect.Type) error {
return nil
}

fieldptr, err := def.structFields(rt, true)
if fieldptr != nil {
def.fields = *fieldptr
info, err := def.structInfo(rt, nil)
if err != nil {
return err
}
for _, field := range info.fields {
def.fields = append(def.fields, *field)
}
return err
def.info = info
return nil
}

func registerTypedef(gotype reflect.Type, def *typedef) *typedef {
Expand Down Expand Up @@ -99,7 +104,7 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) {
case reflect.Struct:
return encodeItem(def.fields, rv)
case reflect.Map:
enc, err := def.encodeMapM(rv.Type(), flagNone)
enc, err := def.encodeMapM(rv.Type(), flagNone, def.info)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -399,47 +404,6 @@ type structField struct {
isZero func(reflect.Value) bool
}

// type encodeKey struct {
// rt reflect.Type
// flags encodeFlags
// }

func (def *typedef) sameAsRoot(rt reflect.Type) bool {
switch {
case rt == def.root:
return true
case def.root.Kind() == reflect.Pointer && rt.Kind() != reflect.Pointer:
return def.root.Elem() == rt
case def.root.Kind() != reflect.Pointer && rt.Kind() == reflect.Pointer:
return rt.Elem() == def.root
}
return false
}

func (def *typedef) structFields(rt reflect.Type, isRoot bool) (*[]structField, error) {
if !isRoot && def.sameAsRoot(rt) {
return &def.fields, nil
}

var fields []structField
err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error {
enc, err := def.encodeType(vt, flags)
if err != nil {
return err
}
field := structField{
index: index,
name: name,
flags: flags,
enc: enc,
isZero: def.isZeroFunc(vt),
}
fields = append(fields, field)
return nil
})
return &fields, err
}

var (
nullAV = &dynamodb.AttributeValue{NULL: aws.Bool(true)}
emptyB = &dynamodb.AttributeValue{B: []byte("")}
Expand Down
Loading

0 comments on commit b2544f3

Please sign in to comment.