Skip to content

Commit

Permalink
skip scanning struct fields for ItemMarshaler structs (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed May 4, 2024
1 parent 0b7cb6b commit 67f288f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
54 changes: 54 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,57 @@ func TestMarshalItemAsymmetric(t *testing.T) {
})
}
}

type isValue_Kind interface {
isValue_Kind()
}

type myStruct struct {
OK bool
Value isValue_Kind
}

func (ms *myStruct) MarshalDynamoItem() (map[string]*dynamodb.AttributeValue, error) {
world := "world"
return map[string]*dynamodb.AttributeValue{
"hello": {S: &world},
}, nil
}

func (ms *myStruct) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error {
hello := item["hello"]
if hello == nil || hello.S == nil || *hello.S != "world" {
ms.OK = false
} else {
ms.OK = true
}
return nil
}

var _ ItemMarshaler = &myStruct{}
var _ ItemUnmarshaler = &myStruct{}

func TestMarshalItemBypass(t *testing.T) {
something := &myStruct{}
got, err := MarshalItem(something)
if err != nil {
t.Fatal(err)
}

world := "world"
expect := map[string]*dynamodb.AttributeValue{
"hello": {S: &world},
}
if !reflect.DeepEqual(got, expect) {
t.Error("bad marshal. want:", expect, "got:", got)
}

var dec myStruct
err = UnmarshalItem(got, &dec)
if err != nil {
t.Fatal(err)
}
if !dec.OK {
t.Error("bad unmarshal")
}
}
11 changes: 9 additions & 2 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import (
var typeCache sync.Map // unmarshalKey → *typedef

type typedef struct {
decoders map[unmarshalKey]decodeFunc
fields []structField
decoders map[unmarshalKey]decodeFunc
fields []structField
marshaler bool
}

func newTypedef(rt reflect.Type) (*typedef, error) {
Expand All @@ -27,6 +28,7 @@ func newTypedef(rt reflect.Type) (*typedef, error) {
}

func (def *typedef) init(rt reflect.Type) error {
rt0 := rt
for rt.Kind() == reflect.Pointer {
rt = rt.Elem()
}
Expand All @@ -37,6 +39,11 @@ func (def *typedef) init(rt reflect.Type) error {
return nil
}

// skip visiting struct fields if encoding will be bypassed by a custom marshaler
if shouldBypassEncodeItem(rt0) || shouldBypassEncodeItem(rt) {
return nil
}

var err error
def.fields, err = structFields(rt)
return err
Expand Down

0 comments on commit 67f288f

Please sign in to comment.