Skip to content

Commit

Permalink
chore: GAP is now implemented in Gorgonia's master branch (#167)
Browse files Browse the repository at this point in the history
* chore: GAP is now implemented in Gorgonia's master branch

* chore: go-modules..................... :sigh:
  • Loading branch information
owulveryck authored Nov 7, 2019
1 parent fb68709 commit 7767174
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 150 deletions.
147 changes: 1 addition & 146 deletions backend/x/gorgonnx/gap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@ package gorgonnx

import (
"errors"
"fmt"
"hash"
"hash/fnv"

"github.com/chewxy/hm"
"github.com/owulveryck/onnx-go"
"gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)

func init() {
Expand All @@ -32,150 +27,10 @@ func (g *gap) apply(gg *Graph, ns ...*Node) error {

// Temporary, waiting for the operator to be implemented in Gorgonia
// see https://github.com/gorgonia/gorgonia/pull/302
/*
n.gorgoniaNode, err = gorgonia.GlobalAveragePool2D(children[0].gorgoniaNode)
*/
n.gorgoniaNode, err = gorgonia.ApplyOp(g, children[0].gorgoniaNode)
n.gorgoniaNode, err = gorgonia.GlobalAveragePool2D(children[0].gorgoniaNode)
return err
}

func (*gap) init(onnx.Operation) error {
return nil
}

func (g *gap) Arity() int {
return 1
}

func (g *gap) Type() hm.Type {
t := gorgonia.TensorType{Dims: 4, Of: hm.TypeVariable('a')}
return hm.NewFnType(t, t)
}

func (g *gap) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
b, err := inputs[0].DimSize(0)
if err != nil {
return nil, err
}
c, err := inputs[0].DimSize(1)
if err != nil {
return nil, err
}
// check if the shape is correct without doing type inference
if _, err := inputs[0].DimSize(2); err != nil {
return nil, err
}
if _, err := inputs[0].DimSize(3); err != nil {
return nil, err
}
return tensor.Shape{b, c, 1, 1}, nil
}

func (g *gap) Do(inputs ...gorgonia.Value) (gorgonia.Value, error) {
im := inputs[0]
switch im.(type) {
case tensor.Tensor:
v := im.(tensor.Tensor)
B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3]
s, err := g.InferShape(v.Shape())
if err != nil {
return nil, err
}
output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...))
switch v.Dtype() {
case tensor.Float64:
err = setFloat64AtTensor(v, B, C, H, W, output)
if err != nil {
return nil, err
}
case tensor.Float32:
err = setFloat32AtTensor(v, B, C, H, W, output)
if err != nil {
return nil, err
}
default:
return nil, &onnx.ErrNotImplemented{
Operator: "Global Average Pool",
Message: fmt.Sprintf("%v not implemented", v.Dtype()),
}
}

return output, nil

default:
return nil, &onnx.ErrNotImplemented{
Operator: "Global Average Pool",
Message: fmt.Sprintf("invalid input %v", inputs),
}
}
}

func setFloat64AtTensor(v tensor.Tensor, B, C, H, W int, output tensor.Tensor) error {
for b := 0; b < B; b++ {
for c := 0; c < C; c++ {
var sum float64
for h := 0; h < H; h++ {
for w := 0; w < W; w++ {
val, err := v.At(b, c, h, w)
if err != nil {
return err
}
sum += val.(float64)
}
}
err := output.SetAt(sum/float64(H*W), b, c, 0, 0)
if err != nil {
return err
}
}
}
return nil
}

func setFloat32AtTensor(v tensor.Tensor, B, C, H, W int, output tensor.Tensor) error {
for b := 0; b < B; b++ {
for c := 0; c < C; c++ {
var sum float32
for h := 0; h < H; h++ {
for w := 0; w < W; w++ {
val, err := v.At(b, c, h, w)
if err != nil {
return err
}
sum += val.(float32)
}
}
err := output.SetAt(sum/float32(H*W), b, c, 0, 0)
if err != nil {
return err
}
}
}
return nil
}

func (g *gap) ReturnsPtr() bool {
return false
}

func (g *gap) CallsExtern() bool {
return false
}

func (g *gap) OverwritesInput() int {
return -1
}

func (g *gap) WriteHash(h hash.Hash) {
fmt.Fprintf(h, "GlobalAveragePool")
}

func (g *gap) Hashcode() uint32 {
h := fnv.New32a()
g.WriteHash(h)
return h.Sum32()
}

func (g *gap) String() string {
return "GlobalAveragePool"
}
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.12
require (
github.com/awalterschulze/gographviz v0.0.0-20190522210029-fa59802746ab // indirect
github.com/chewxy/hm v1.0.0
github.com/chewxy/math32 v1.0.4 // indirect
github.com/davecgh/go-spew v1.1.0
github.com/disintegration/imaging v1.6.0
github.com/gogo/protobuf v1.2.1
Expand All @@ -22,6 +21,6 @@ require (
golang.org/x/net v0.0.0-20190611141213-3f473d35a33a // indirect
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b
gonum.org/v1/gonum v0.0.0-20190606121551-14af50e936aa
gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696 //indirect
gorgonia.org/gorgonia v0.9.4 //indirect
gorgonia.org/tensor v0.9.0-beta
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc=
gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8=
gorgonia.org/dawson v1.1.0 h1:o7+eJ3SKi9sheH19lpOat//tDbg0Y+M9iY/lH79VHqY=
gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
gorgonia.org/gorgonia v0.9.3 h1:IO3/7piSdJwbmmEXHJLArvWH+TjVfzEAVhd+pW8gLRk=
gorgonia.org/gorgonia v0.9.3/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec=
gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696 h1:YZ6HgxoWjk0cslQKbLZHkgQxYRswn4IlMs2bvQLuDvI=
gorgonia.org/gorgonia v0.9.4-0.20191013102522-a6f8db2f4696/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec=
gorgonia.org/gorgonia v0.9.4 h1:msE583U+EuthijznpLPJ1Uk6LbNqZCWX/5mgq/L/EGg=
gorgonia.org/gorgonia v0.9.4/go.mod h1:4kWgOIjKmCaY1H4JbMfhF6JXXNcbLpbCZ7m9EjVyZOY=
gorgonia.org/tensor v0.9.0-beta h1:16QQufB1vbJxVbIOaB5TwkerdlBWtw+AAnZHUZ531ZE=
gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0=
Expand Down

0 comments on commit 7767174

Please sign in to comment.