Skip to content

Commit

Permalink
npy: properly handle empty slices
Browse files Browse the repository at this point in the history
Fixes #15.
  • Loading branch information
sbinet committed Jan 26, 2022
1 parent 77aa330 commit 3f83362
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
3 changes: 3 additions & 0 deletions npy/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ func shapeFrom(rv reflect.Value) ([]int, error) {
rt := rv.Type()
switch rt.Kind() {
case reflect.Array, reflect.Slice:
if rv.Len() == 0 {
return []int{0}, nil
}
eshape, err := shapeFrom(rv.Index(0))
if err != nil {
return nil, err
Expand Down
109 changes: 109 additions & 0 deletions npy/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package npy

import (
"bytes"
"fmt"
"math"
"reflect"
"testing"
Expand Down Expand Up @@ -115,3 +116,111 @@ func TestWriterNaNsInf(t *testing.T) {
}
}
}

func TestShapeFrom(t *testing.T) {
for _, tc := range []struct {
v interface{}
want []int
err error
}{
{
v: "hello",
want: nil,
},
{
v: 1,
want: nil,
},
{
v: 0.1,
want: nil,
},
{
v: [0]int{},
want: []int{0},
},
{
v: []int{},
want: []int{0},
},
{
v: []int{1},
want: []int{1},
},
{
v: [3][]int{nil, nil, nil},
want: []int{3, 0},
},
{
v: [3][]int{{}, {}, {}},
want: []int{3, 0},
},
{
v: [3][]int{{1, 2}, {3, 4}, {5, 6}},
want: []int{3, 2},
},
{
v: [][]int{nil, nil, nil},
want: []int{3, 0},
},
{
v: [][]int{{}, {}, {}},
want: []int{3, 0},
},
{
v: [][]int{{1, 2}, {3, 4}, {5, 6}},
want: []int{3, 2},
},
{
v: [][][]int{{{1}, {2}}, {{3}, {4}}, {{5}, {6}}},
want: []int{3, 2, 1},
},
{
v: [][]float64{{1, 2}, {3, 4}, {5, 6}},
want: []int{3, 2},
},
{
v: mat.NewDense(2, 3, []float64{1, 2, 3, 4, 5, 6}),
want: nil, // shapeFrom takes a deref-iface
},
{
v: *mat.NewDense(2, 3, []float64{1, 2, 3, 4, 5, 6}),
want: []int{2, 3},
},
{
v: make(map[int]int),
err: fmt.Errorf("npy: type map[int]int not supported"),
},
{
v: make(chan int),
err: fmt.Errorf("npy: type chan int not supported"),
},
{
v: struct{}{},
err: fmt.Errorf("npy: type struct {} not supported"),
},
} {
t.Run("", func(t *testing.T) {
got, err := shapeFrom(reflect.ValueOf(tc.v))
switch {
case err != nil && tc.err != nil:
if err.Error() != tc.err.Error() {
t.Fatalf("invalid error:\ngot= %+v\nwant=%+v",
err, tc.err,
)
}
return
case err != nil && tc.err == nil:
t.Fatalf("unexpected error: %+v", err)
case err == nil && tc.err != nil:
t.Fatalf("expected an error")
case err == nil && tc.err == nil:
// ok.
}

if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("invalid shape.\ngot= %+v\nwant=%+v", got, tc.want)
}
})
}
}

0 comments on commit 3f83362

Please sign in to comment.