Skip to content

Commit

Permalink
allownil: Allocate 0 length slices (#336)
Browse files Browse the repository at this point in the history
* allownil: Allocate 0 length slices

When `allownil` is enabled, always allocate zero length slices.

This ensures roundtrips with 0-length slices are not returned as nil.

Replaces #304

Adds tests. Bonus: Don't shell out to test issue 94.
  • Loading branch information
klauspost authored Feb 10, 2024
1 parent cabc832 commit e00f9b0
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 53 deletions.
57 changes: 57 additions & 0 deletions _generated/allownil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestAllownil(t *testing.T) {
tt := &NamedStructAN{
A: []string{},
B: nil,
}
var buf bytes.Buffer

err := msgp.Encode(&buf, tt)
if err != nil {
t.Fatal(err)
}
in := buf.Bytes()

for _, tnew := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
err = msgp.Decode(bytes.NewBuffer(in), tnew)
if err != nil {
t.Error(err)
}

if !reflect.DeepEqual(tt, tnew) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tnew)
t.Fatal("objects not equal")
}
}

in, err = tt.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
for _, tanother := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
var left []byte
left, err = tanother.UnmarshalMsg(in)
if err != nil {
t.Error(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left", len(left))
}

if !reflect.DeepEqual(tt, tanother) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tanother)
t.Fatal("objects not equal")
}
}
}
4 changes: 2 additions & 2 deletions _generated/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func Test1EncodeDecode(t *testing.T) {
}

if !tt.Equal(tnew) {
t.Logf("in: %v", tt)
t.Logf("out: %v", tnew)
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tnew)
t.Fatal("objects not equal")
}

Expand Down
10 changes: 0 additions & 10 deletions _generated/issue94.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@ import (

//go:generate msgp

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.

//go:generate ./search.sh $GOFILE timetostr

//msgp:shim time.Time as:string using:timetostr/strtotime
type T struct {
T time.Time
Expand Down
25 changes: 25 additions & 0 deletions _generated/issue94_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package _generated

import (
"bytes"
"os"
"testing"
)

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.
func TestIssue94(t *testing.T) {
b, err := os.ReadFile("issue94_gen.go")
if err != nil {
t.Fatal(err)
}
const want = "timetostr"
if !bytes.Contains(b, []byte(want)) {
t.Errorf("generated code did not contain %q", want)
}
}
12 changes: 0 additions & 12 deletions _generated/search.sh

This file was deleted.

20 changes: 14 additions & 6 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ func (d *decodeGen) structAsTuple(s *Struct) {
if !d.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
d.p.print("\nif dc.IsNil() {")
d.p.print("\nerr = dc.ReadNil()")
d.p.wrapErrCheck(d.ctx.ArgsStr())
d.p.printf("\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
}
SetIsAllowNil(fieldElem, anField)
d.ctx.PushString(s.Fields[i].FieldName)
next(d, s.Fields[i].FieldElem)
next(d, fieldElem)
d.ctx.Pop()
if anField {
d.p.printf("\n}") // close if statement
Expand All @@ -112,14 +114,16 @@ func (d *decodeGen) structAsMap(s *Struct) {
for i := range s.Fields {
d.ctx.PushString(s.Fields[i].FieldName)
d.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag)
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
d.p.print("\nif dc.IsNil() {")
d.p.print("\nerr = dc.ReadNil()")
d.p.wrapErrCheck(d.ctx.ArgsStr())
d.p.printf("\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
d.p.printf("\n%s = nil\n} else {", fieldElem.Varname())
}
next(d, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(d, fieldElem)
d.ctx.Pop()
if !d.p.ok() {
return
Expand Down Expand Up @@ -215,7 +219,11 @@ func (d *decodeGen) gSlice(s *Slice) {
sz := randIdent()
d.p.declare(sz, u32)
d.assignAndCheck(sz, arrayHeader)
d.p.resizeSlice(sz, s)
if s.isAllowNil {
d.p.resizeSliceNoNil(sz, s)
} else {
d.p.resizeSlice(sz, s)
}
d.p.rangeBlock(d.ctx, s.Index, s.Varname(), d, s.Els)
}

Expand Down
28 changes: 23 additions & 5 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,10 @@ func (a *Array) IfZeroExpr() string { return "" }
// Map is a map[string]Elem
type Map struct {
common
Keyidx string // key variable name
Validx string // value variable name
Value Elem // value element
Keyidx string // key variable name
Validx string // value variable name
Value Elem // value element
isAllowNil bool
}

func (m *Map) SetVarname(s string) {
Expand Down Expand Up @@ -304,10 +305,14 @@ func (m *Map) IfZeroExpr() string { return m.Varname() + " == nil" }
// AllowNil is true for maps.
func (m *Map) AllowNil() bool { return true }

// SetIsAllowNil sets whether the map is allowed to be nil.
func (m *Map) SetIsAllowNil(b bool) { m.isAllowNil = b }

type Slice struct {
common
Index string
Els Elem // The type of each element
Index string
isAllowNil bool
Els Elem // The type of each element
}

func (s *Slice) SetVarname(a string) {
Expand Down Expand Up @@ -348,6 +353,19 @@ func (s *Slice) IfZeroExpr() string { return s.Varname() + " == nil" }
// AllowNil is true for slices.
func (s *Slice) AllowNil() bool { return true }

// SetIsAllowNil sets whether the slice is allowed to be nil.
func (s *Slice) SetIsAllowNil(b bool) { s.isAllowNil = b }

// SetIsAllowNil will set whether the element is allowed to be nil.
func SetIsAllowNil(e Elem, b bool) {
type i interface {
SetIsAllowNil(b bool)
}
if x, ok := e.(i); ok {
x.SetIsAllowNil(b)
}
}

type Ptr struct {
common
Value Elem
Expand Down
11 changes: 7 additions & 4 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ func (e *encodeGen) tuple(s *Struct) {
if !e.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
e.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
e.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
e.p.printf("\nerr = en.WriteNil(); if err != nil { return; }")
e.p.printf("\n} else {")
}
SetIsAllowNil(fieldElem, anField)
e.ctx.PushString(s.Fields[i].FieldName)
next(e, s.Fields[i].FieldElem)
e.ctx.Pop()
Expand Down Expand Up @@ -197,13 +199,14 @@ func (e *encodeGen) structmap(s *Struct) {
e.p.printf("\n// write %q", s.Fields[i].FieldTag)
e.Fuse(data)
e.fuseHook()

anField := !oeField && s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := !oeField && s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
e.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
e.p.printf("\nerr = en.WriteNil(); if err != nil { return; }")
e.p.printf("\n} else {")
}
SetIsAllowNil(fieldElem, anField)

e.ctx.PushString(s.Fields[i].FieldName)
next(e, s.Fields[i].FieldElem)
Expand Down
17 changes: 10 additions & 7 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,16 @@ func (m *marshalGen) tuple(s *Struct) {
if !m.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
m.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
m.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
m.p.printf("\no = msgp.AppendNil(o)")
m.p.printf("\n} else {")
}
m.ctx.PushString(s.Fields[i].FieldName)
next(m, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(m, fieldElem)
m.ctx.Pop()
if anField {
m.p.printf("\n}") // close if statement
Expand Down Expand Up @@ -194,15 +196,16 @@ func (m *marshalGen) mapstruct(s *Struct) {
m.Fuse(data)
m.fuseHook()

anField := !oeField && s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := !oeField && s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
m.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
m.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
m.p.printf("\no = msgp.AppendNil(o)")
m.p.printf("\n} else {")
}

m.ctx.PushString(s.Fields[i].FieldName)
next(m, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(m, fieldElem)
m.ctx.Pop()

if oeField || anField {
Expand Down
7 changes: 7 additions & 0 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ func (p *printer) resizeSlice(size string, s *Slice) {
p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

// resizeSliceNoNil will resize a slice and will not allow nil slices.
func (p *printer) resizeSliceNoNil(size string, s *Slice) {
p.printf("\nif %[1]s != nil && cap(%[1]s) >= int(%[2]s) {", s.Varname(), size)
p.printf("\n%[1]s = (%[1]s)[:%[2]s]", s.Varname(), size)
p.printf("\n} else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

func (p *printer) arrayCheck(want string, got string) {
p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want)
}
Expand Down
22 changes: 15 additions & 7 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ func (u *unmarshalGen) tuple(s *Struct) {
return
}
u.ctx.PushString(s.Fields[i].FieldName)
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname())
}
next(u, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if anField {
u.p.printf("\n}")
Expand All @@ -113,11 +115,13 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag)
u.ctx.PushString(s.Fields[i].FieldName)

anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname())
}
next(u, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if anField {
u.p.printf("\n}")
Expand Down Expand Up @@ -193,7 +197,11 @@ func (u *unmarshalGen) gSlice(s *Slice) {
sz := randIdent()
u.p.declare(sz, u32)
u.assignAndCheck(sz, arrayHeader)
u.p.resizeSlice(sz, s)
if s.isAllowNil {
u.p.resizeSliceNoNil(sz, s)
} else {
u.p.resizeSlice(sz, s)
}
u.p.rangeBlock(u.ctx, s.Index, s.Varname(), u, s.Els)
}

Expand Down

0 comments on commit e00f9b0

Please sign in to comment.