Skip to content

Commit

Permalink
Update change_ordering code.
Browse files Browse the repository at this point in the history
  • Loading branch information
gmalivenko committed Jan 13, 2020
1 parent bc3fea9 commit 24f1eaf
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 54 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2019 Grigory Malivenko
Copyright (c) 2020 Grigory Malivenko

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
17 changes: 13 additions & 4 deletions onnx2keras/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,19 @@ def onnx_to_keras(onnx_model, input_names,
for layer in conf['layers']:
if 'function' in layer['config'] and layer['config']['function'][1] is not None:
f = list(layer['config']['function'])
if len(layer['config']['function'][1][0].shape) == 4:
f[1] = (np.transpose(layer['config']['function'][1][0], [0, 2, 3, 1]), f[1][1])
elif len(layer['config']['function'][1][0].shape) == 3:
f[1] = (np.transpose(layer['config']['function'][1][0], [0, 2, 1]), f[1][1])
try:
if len(layer['config']['function'][1][0].shape) == 4:
f[1] = (np.transpose(layer['config']['function'][1][0], [0, 2, 3, 1]), f[1][1])
elif len(layer['config']['function'][1][0].shape) == 3:
f[1] = (np.transpose(layer['config']['function'][1][0], [0, 2, 1]), f[1][1])
except Exception as e:
logger.warning('Error occured in basic change ordering mode. Use fallback.')

axes = np.array(layer['config']['function'][1][0])
axes_map = np.array([0, 3, 1, 2])
axes = axes_map[axes]
f[1] = (axes, f[1][1])

layer['config']['function'] = tuple(f)

keras.backend.set_image_data_format('channels_last')
Expand Down
113 changes: 66 additions & 47 deletions onnx2keras/reshape_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,62 +270,81 @@ def convert_slice(node, params, layers, node_name, keras_name):
"""
logger = logging.getLogger('onnx2keras:slice')

logger.debug('Convert inputs to Keras/TF layers if needed.')

input_0 = ensure_tf_type(layers[node.input[0]], layers[list(layers)[0]], name="%s_const" % keras_name)
layers[node_name] = input_0

if 'axes' in params:
axes = params["axes"][0]
ends = params["ends"][0]
starts = params["starts"][0]
else:
starts = ensure_numpy_type(layers[node.input[1]])
ends = ensure_numpy_type(layers[node.input[2]])
axes = ensure_numpy_type(layers[node.input[3]])

for i in range(len(starts)):
if axes[i] != i:
assert AttributeError('Cant slice permuted axes')

if isinstance(axes, list) or isinstance(axes, np.ndarray):
def target_layer(x, axes=axes, starts=starts, ends=ends):
import tensorflow as tf
return tf.strided_slice(x, starts, ends)
if is_numpy(layers[node.input[0]]):
logger.debug('Slice numpy constants')
if 'axes' in params:
axes = params["axes"][0]
ends = params["ends"][0]
starts = params["starts"][0]
else:
raise AttributeError('Not implemented')

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
else:
if axes == 0:
def target_layer(x):
layer = x[starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
layers[node_name] = layers[node.input[0]][starts:ends]
elif axes == 1:
def target_layer(x):
layer = x[:, starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
layers[node_name] = layers[node.input[0]][:, starts:ends]
elif axes == 2:
def target_layer(x):
layer = x[:, :, starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
layers[node_name] = layers[node.input[0]][:, :, starts:ends]
elif axes == 3:
def target_layer(x):
layer = x[:, :, :, starts:ends]
return layer
layers[node_name] = layers[node.input[0]][:, :, :, starts:ends]
else:
raise AttributeError('Not implemented')
else:
logger.debug('Convert inputs to Keras/TF layers if needed.')
input_0 = ensure_tf_type(layers[node.input[0]], layers[list(layers)[0]], name="%s_const" % keras_name)
layers[node_name] = input_0

if 'axes' in params:
axes = params["axes"][0]
ends = params["ends"][0]
starts = params["starts"][0]
else:
starts = ensure_numpy_type(layers[node.input[1]])
ends = ensure_numpy_type(layers[node.input[2]])
axes = ensure_numpy_type(layers[node.input[3]])

for i in range(len(starts)):
if axes[i] != i:
assert AttributeError('Cant slice permuted axes')

if isinstance(axes, list) or isinstance(axes, np.ndarray):
def target_layer(x, axes=axes, starts=starts, ends=ends):
import tensorflow as tf
return tf.strided_slice(x, starts, ends)

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
else:
raise AttributeError('Not implemented')
if axes == 0:
def target_layer(x):
layer = x[starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
elif axes == 1:
def target_layer(x):
layer = x[:, starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
elif axes == 2:
def target_layer(x):
layer = x[:, :, starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
elif axes == 3:
def target_layer(x):
layer = x[:, :, :, starts:ends]
return layer

lambda_layer = keras.layers.Lambda(target_layer, name=keras_name)
layers[node_name] = lambda_layer(input_0)
else:
raise AttributeError('Not implemented')


def convert_squeeze(node, params, layers, node_name, keras_name):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tensorflow>=2.0
tensorflow
numpy
onnx
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


setup(name='onnx2keras',
version='0.0.17',
version='0.0.18',
description='The deep learning models convertor',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 24f1eaf

Please sign in to comment.