diff --git a/backend/x/gorgonnx/gap.go b/backend/x/gorgonnx/gap.go index d1400e65..80d9677d 100644 --- a/backend/x/gorgonnx/gap.go +++ b/backend/x/gorgonnx/gap.go @@ -2,14 +2,9 @@ package gorgonnx import ( "errors" - "fmt" - "hash" - "hash/fnv" - "github.com/chewxy/hm" "github.com/owulveryck/onnx-go" "gorgonia.org/gorgonia" - "gorgonia.org/tensor" ) func init() { @@ -32,150 +27,10 @@ func (g *gap) apply(gg *Graph, ns ...*Node) error { // Temporary, waiting for the operator to be implemented in Gorgonia // see https://github.com/gorgonia/gorgonia/pull/302 - /* - n.gorgoniaNode, err = gorgonia.GlobalAveragePool2D(children[0].gorgoniaNode) - */ - n.gorgoniaNode, err = gorgonia.ApplyOp(g, children[0].gorgoniaNode) + n.gorgoniaNode, err = gorgonia.GlobalAveragePool2D(children[0].gorgoniaNode) return err } func (*gap) init(onnx.Operation) error { return nil } - -func (g *gap) Arity() int { - return 1 -} - -func (g *gap) Type() hm.Type { - t := gorgonia.TensorType{Dims: 4, Of: hm.TypeVariable('a')} - return hm.NewFnType(t, t) -} - -func (g *gap) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) { - b, err := inputs[0].DimSize(0) - if err != nil { - return nil, err - } - c, err := inputs[0].DimSize(1) - if err != nil { - return nil, err - } - // check if the shape is correct without doing type inference - if _, err := inputs[0].DimSize(2); err != nil { - return nil, err - } - if _, err := inputs[0].DimSize(3); err != nil { - return nil, err - } - return tensor.Shape{b, c, 1, 1}, nil -} - -func (g *gap) Do(inputs ...gorgonia.Value) (gorgonia.Value, error) { - im := inputs[0] - switch im.(type) { - case tensor.Tensor: - v := im.(tensor.Tensor) - B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3] - s, err := g.InferShape(v.Shape()) - if err != nil { - return nil, err - } - output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...)) - switch v.Dtype() { - case tensor.Float64: - err = setFloat64AtTensor(v, B, C, H, W, output) - if err != nil { - return nil, err - } - case tensor.Float32: - err = setFloat32AtTensor(v, B, C, H, W, output) - if err != nil { - return nil, err - } - default: - return nil, &onnx.ErrNotImplemented{ - Operator: "Global Average Pool", - Message: fmt.Sprintf("%v not implemented", v.Dtype()), - } - } - - return output, nil - - default: - return nil, &onnx.ErrNotImplemented{ - Operator: "Global Average Pool", - Message: fmt.Sprintf("invalid input %v", inputs), - } - } -} - -func setFloat64AtTensor(v tensor.Tensor, B, C, H, W int, output tensor.Tensor) error { - for b := 0; b < B; b++ { - for c := 0; c < C; c++ { - var sum float64 - for h := 0; h < H; h++ { - for w := 0; w < W; w++ { - val, err := v.At(b, c, h, w) - if err != nil { - return err - } - sum += val.(float64) - } - } - err := output.SetAt(sum/float64(H*W), b, c, 0, 0) - if err != nil { - return err - } - } - } - return nil -} - -func setFloat32AtTensor(v tensor.Tensor, B, C, H, W int, output tensor.Tensor) error { - for b := 0; b < B; b++ { - for c := 0; c < C; c++ { - var sum float32 - for h := 0; h < H; h++ { - for w := 0; w < W; w++ { - val, err := v.At(b, c, h, w) - if err != nil { - return err - } - sum += val.(float32) - } - } - err := output.SetAt(sum/float32(H*W), b, c, 0, 0) - if err != nil { - return err - } - } - } - return nil -} - -func (g *gap) ReturnsPtr() bool { - return false -} - -func (g *gap) CallsExtern() bool { - return false -} - -func (g *gap) OverwritesInput() int { - return -1 -} - -func (g *gap) WriteHash(h hash.Hash) { - fmt.Fprintf(h, "GlobalAveragePool") -} - -func (g *gap) Hashcode() uint32 { - h := fnv.New32a() - g.WriteHash(h) - return h.Sum32() -} - -func (g *gap) String() string { - return "GlobalAveragePool" -} diff --git a/go.mod b/go.mod index ea23f44f..a63cf966 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.12 require ( github.com/awalterschulze/gographviz v0.0.0-20190522210029-fa59802746ab // indirect github.com/chewxy/hm v1.0.0 - github.com/chewxy/math32 v1.0.4 // indirect github.com/davecgh/go-spew v1.1.0 github.com/disintegration/imaging v1.6.0 github.com/gogo/protobuf v1.2.1 @@ -22,6 +21,6 @@ require ( golang.org/x/net v0.0.0-20190611141213-3f473d35a33a // indirect golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b gonum.org/v1/gonum v0.0.0-20190606121551-14af50e936aa - gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696 //indirect + gorgonia.org/gorgonia v0.9.4 //indirect gorgonia.org/tensor v0.9.0-beta ) diff --git a/go.sum b/go.sum index d5e7a997..b015e055 100644 --- a/go.sum +++ b/go.sum @@ -79,10 +79,10 @@ gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc= gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8= gorgonia.org/dawson v1.1.0 h1:o7+eJ3SKi9sheH19lpOat//tDbg0Y+M9iY/lH79VHqY= gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws= -gorgonia.org/gorgonia v0.9.3 h1:IO3/7piSdJwbmmEXHJLArvWH+TjVfzEAVhd+pW8gLRk= -gorgonia.org/gorgonia v0.9.3/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec= gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696 h1:YZ6HgxoWjk0cslQKbLZHkgQxYRswn4IlMs2bvQLuDvI= gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec= +gorgonia.org/gorgonia v0.9.4 h1:msE583U+EuthijznpLPJ1Uk6LbNqZCWX/5mgq/L/EGg= +gorgonia.org/gorgonia v0.9.4/go.mod h1:4kWgOIjKmCaY1H4JbMfhF6JXXNcbLpbCZ7m9EjVyZOY= gorgonia.org/tensor v0.9.0-beta h1:16QQufB1vbJxVbIOaB5TwkerdlBWtw+AAnZHUZ531ZE= gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w= gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0=