Skip to content

Commit

Permalink
Merge pull request #142 from disktnk/feature/where
Browse files Browse the repository at this point in the history
Add Where op converter
  • Loading branch information
disktnk authored Apr 2, 2019
2 parents 942034e + d34da6b commit 305b3ea
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 4 deletions.
25 changes: 21 additions & 4 deletions onnx_chainer/context.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
import chainer

from onnx_chainer import onnx_helper


class Context(object):
"""Context of converter
This context shares names during exporting.
Attributes:
name_list (dict): list of being exported as ONNX node name keyed by
instance ID. When the target variable is ``chainer.Variable`` or
``chainer.Parameter``, instance ID of ``ndarray`` held by the
variable is also put as key, because some functions like
``F.where`` internally unwrap variable.
"""

def __init__(self, model):
self.name_list = dict()
for name, param in model.namedparams():
onnx_name = onnx_helper.cleanse_param_name(name)
self.name_list[str(id(param))] = onnx_name
self.set_name(param, onnx_name)

def get_name(self, variable):
str_id = str(id(variable))
str_id = id(variable)
if str_id in self.name_list:
return self.name_list[str_id]
else:
new_name = 'v{}'.format(len(self.name_list))
self.name_list[str_id] = new_name
self.set_name(variable, new_name)
return new_name

def set_name(self, variable, name):
str_id = str(id(variable))
str_id = id(variable)
self.name_list[str_id] = name
if isinstance(variable, (chainer.Variable, chainer.Parameter)):
array_id = id(variable.array)
self.name_list[array_id] = name
1 change: 1 addition & 0 deletions onnx_chainer/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from onnx_chainer.functions.array import convert_Squeeze # NOQA
from onnx_chainer.functions.array import convert_Tile # NOQA
from onnx_chainer.functions.array import convert_Transpose # NOQA
from onnx_chainer.functions.array import convert_Where # NOQA

from onnx_chainer.functions.connection import convert_Convolution2DFunction # NOQA
from onnx_chainer.functions.connection import convert_ConvolutionND # NOQA
Expand Down
7 changes: 7 additions & 0 deletions onnx_chainer/functions/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,10 @@ def convert_ExpandDims(func, opset_version, input_names,

return onnx_helper.make_node(
'Unsqueeze', input_names, num_outputs, axes=[axis]),


@support((9,))
def convert_Where(func, opset_version, input_names, num_outputs, context,
parameters):
input_names.insert(0, context.get_name(func.condition))
return onnx_helper.make_node('Where', input_names, num_outputs),
1 change: 1 addition & 0 deletions onnx_chainer/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'Squeeze',
'Tile',
'Transpose',
'Where',

# Connection
'Convolution2DFunction',
Expand Down
12 changes: 12 additions & 0 deletions tests/functions_tests/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,15 @@ def __call__(self, x1, x2):

def test_output(self):
self.expect(self.model, (self.x1, self.x2))


class TestWhere(ONNXModelTest):

def test_output(self):
model = chainer.Sequential(
F.where
)
cond = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.bool)
x = input_generator.increasing(2, 3)
y = np.zeros((2, 3), np.float32)
self.expect(model, (cond, x, y), skip_opset_version=[7, 8])

0 comments on commit 305b3ea

Please sign in to comment.