Skip to content

Commit

Permalink
Flatten Support for ONNX(#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshmiKumar23 authored and kiritigowda committed Apr 12, 2019
1 parent f55c266 commit f05574d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion model_compiler/python/nnir.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def updateLocals(self):
elif node.type in ['reshape']:
input = self.tensor_dict[node.inputs[0]]
param = node.attr.get('shape')
if not param:
param = input.shape
icount = 1
ocount = 1
out_shape = [0,0,0,0]
Expand Down Expand Up @@ -468,8 +470,12 @@ def updateLocals(self):
elif node.type in ['flatten']:
input = self.tensor_dict[node.inputs[0]]
axis = node.attr.get("axis")
if axis == 1:
if axis == 0:
shape = [1, input.shape[0]*input.shape[1]*input.shape[2]*input.shape[3], 1, 1]
elif axis == 1:
shape = [input.shape[0], input.shape[1]*input.shape[2]*input.shape[3], 1, 1]
else:
raise ValueError("Flatten: unsupoorted flatten: " + str(axis))
local = IrTensor()
local.setName(output)
local.setInfo(input.type, shape)
Expand Down
3 changes: 2 additions & 1 deletion model_compiler/python/onnx_to_nnir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
'GlobalAveragePool' : 'global_avg_pool',
'Softmax' : 'softmax',
'Reshape' : 'reshape',
'Transpose' : 'transpose'
'Transpose' : 'transpose',
'Flatten' : 'flatten'
}

onnx2ir_data_type = [
Expand Down

0 comments on commit f05574d

Please sign in to comment.