From a59ef6bb76bcd42fc2288000c844df9aea7d248e Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 7 Feb 2022 02:22:07 +0900 Subject: [PATCH] improve embedded struct decoding, fixing aux unmarshaling Fixes #181. Aux decoding wasn't working because we were attempting to recreate the embedded (aux) struct iself, instead of following its pointer. This caused us to skip it. Now we zero out each field of the struct as needed. --- decode.go | 24 ++++--- decode_aux_test.go | 161 +++++++++++++++++++++++++++++++++++++++++++++ decode_test.go | 65 ++++++++++++++++++ 3 files changed, 242 insertions(+), 8 deletions(-) create mode 100644 decode_aux_test.go diff --git a/decode.go b/decode.go index 3e44984..660de03 100644 --- a/decode.go +++ b/decode.go @@ -344,14 +344,19 @@ func fieldsInStruct(rv reflect.Value) map[string]reflect.Value { // embed anonymous structs, they could be pointers so test that too if (fv.Type().Kind() == reflect.Struct || isPtr && fv.Type().Elem().Kind() == reflect.Struct) && field.Anonymous { if isPtr { - // need to protect from setting unexported pointers because it will panic - if !fv.CanSet() { - continue + if fv.CanSet() { + // set zero value for pointer + zero := reflect.New(fv.Type().Elem()) + fv.Set(zero) + fv = zero + } else { + fv = reflect.Indirect(fv) } - // set zero value for pointer - zero := reflect.New(fv.Type().Elem()) - fv.Set(zero) - fv = zero + } + + if !fv.IsValid() { + // inaccessible + continue } innerFields := fieldsInStruct(fv) @@ -394,13 +399,16 @@ func unmarshalItem(item map[string]*dynamodb.AttributeValue, out interface{}) er return unmarshalItem(item, rv.Elem().Interface()) case reflect.Struct: var err error - rv.Elem().Set(reflect.Zero(rv.Type().Elem())) fields := fieldsInStruct(rv.Elem()) for name, fv := range fields { if av, ok := item[name]; ok { if innerErr := unmarshalReflect(av, fv); innerErr != nil { err = innerErr } + } else { + // we need to zero-out omitted fields to avoid weird data sticking around + // when iterating by unmarshaling to the same object over and over + fv.Set(reflect.Zero(fv.Type())) } } return err diff --git a/decode_aux_test.go b/decode_aux_test.go new file mode 100644 index 0000000..8671ff5 --- /dev/null +++ b/decode_aux_test.go @@ -0,0 +1,161 @@ +package dynamo_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + + "github.com/guregu/dynamo" +) + +type Coffee struct { + Name string +} + +func TestEncodingAux(t *testing.T) { + // This tests behavior of embedded anonymous (unexported) structs + // using the "aux" unmarshaling trick. + // See: https://github.com/guregu/dynamo/issues/181 + + in := map[string]*dynamodb.AttributeValue{ + "ID": {S: aws.String("intenso")}, + "Name": {S: aws.String("Intenso 12")}, + } + + type coffeeItemDefault struct { + ID string + Coffee + } + + tests := []struct { + name string + out interface{} + }{ + {name: "no custom unmrashalling", out: coffeeItemDefault{ID: "intenso", Coffee: Coffee{Name: "Intenso 12"}}}, + {name: "AWS SDK pointer", out: coffeeItemSDKEmbeddedPointer{ID: "intenso", Coffee: &Coffee{Name: "Intenso 12"}}}, + {name: "flat", out: coffeeItemFlat{ID: "intenso", Name: "Intenso 12"}}, + {name: "flat (invalid)", out: coffeeItemInvalid{}}, // want to make sure this doesn't panic + {name: "embedded", out: coffeeItemEmbedded{ID: "intenso", Coffee: Coffee{Name: "Intenso 12"}}}, + {name: "embedded pointer", out: coffeeItemEmbeddedPointer{ID: "intenso", Coffee: &Coffee{Name: "Intenso 12"}}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + out := reflect.New(reflect.TypeOf(test.out)).Interface() + if err := dynamo.UnmarshalItem(in, out); err != nil { + t.Fatal(err) + } + got := reflect.ValueOf(out).Elem().Interface() + if !reflect.DeepEqual(test.out, got) { + t.Error("bad value. want:", test.out, "got:", got) + } + }) + } +} + +type coffeeItemFlat struct { + ID string + Name string +} + +func (c *coffeeItemFlat) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + type alias coffeeItemFlat + aux := struct { + *alias + }{ + alias: (*alias)(c), + } + if err := dynamo.UnmarshalItem(item, &aux); err != nil { + return err + } + return nil +} + +type coffeeItemInvalid struct { + ID string + Name string +} + +func (c *coffeeItemInvalid) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + type alias coffeeItemInvalid + aux := struct { + *alias + }{ + alias: (*alias)(nil), + } + if err := dynamo.UnmarshalItem(item, &aux); err != nil { + return err + } + return nil +} + +type coffeeItemEmbedded struct { + ID string + Coffee +} + +func (c *coffeeItemEmbedded) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + type alias coffeeItemEmbedded + aux := struct { + *alias + }{ + alias: (*alias)(c), + } + if err := dynamo.UnmarshalItem(item, &aux); err != nil { + return err + } + return nil +} + +type coffeeItemEmbeddedPointer struct { + ID string + *Coffee +} + +func (c *coffeeItemEmbeddedPointer) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + type alias coffeeItemEmbeddedPointer + aux := struct { + *alias + }{ + alias: (*alias)(c), + } + if err := dynamo.UnmarshalItem(item, &aux); err != nil { + return err + } + return nil +} + +func (c *coffeeItemEmbeddedPointer) UnmarshalJSON(data []byte) error { + type alias coffeeItemEmbeddedPointer + aux := struct { + *alias + }{ + alias: (*alias)(c), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + return nil +} + +type coffeeItemSDKEmbeddedPointer struct { + ID string + *Coffee +} + +func (c *coffeeItemSDKEmbeddedPointer) UnmarshalDynamoItem(item map[string]*dynamodb.AttributeValue) error { + type alias coffeeItemEmbeddedPointer + aux := struct { + *alias + }{ + alias: (*alias)(c), + } + if err := dynamodbattribute.UnmarshalMap(item, &aux); err != nil { + return err + } + return nil +} diff --git a/decode_test.go b/decode_test.go index 48363df..0686dfc 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3,6 +3,7 @@ package dynamo import ( "reflect" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" @@ -184,3 +185,67 @@ func TestUnmarshalNULL(t *testing.T) { t.Error("unmarshal null: bad result:", result, "≠", resultType{}) } } + +func TestUnmarshalMissing(t *testing.T) { + // This test makes sure we're zeroing out fields of structs even if the given data doesn't contain them + + type widget2 struct { + widget + Inner struct { + Blarg string + } + Foo *struct { + Bar int + } + } + + w := widget2{ + widget: widget{ + UserID: 111, + Time: time.Now().UTC(), + Msg: "hello", + }, + } + w.Inner.Blarg = "AHH" + w.Foo = &struct{ Bar int }{Bar: 1337} + + want := widget2{ + widget: widget{ + UserID: 112, + }, + } + + replace := map[string]*dynamodb.AttributeValue{ + "UserID": {N: aws.String("112")}, + } + + if err := UnmarshalItem(replace, &w); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(want, w) { + t.Error("bad unmarshal missing. want:", want, "got:", w) + } + + replace2 := map[string]*dynamodb.AttributeValue{ + "UserID": {N: aws.String("113")}, + "Foo": {M: map[string]*dynamodb.AttributeValue{ + "Bar": {N: aws.String("1338")}, + }}, + } + + want = widget2{ + widget: widget{ + UserID: 113, + }, + Foo: &struct{ Bar int }{Bar: 1338}, + } + + if err := UnmarshalItem(replace2, &w); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(want, w) { + t.Error("bad unmarshal missing. want:", want, "got:", w) + } +}