Skip to content

Commit

Permalink
feat: gather_nd/scatter_nd support (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Mar 11, 2024
1 parent ea535e2 commit a1450f8
Show file tree
Hide file tree
Showing 16 changed files with 945 additions and 16 deletions.
48 changes: 48 additions & 0 deletions examples/onnx/gather_nd/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx


import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model


# gather_nd in tf then export to onnx




x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')

shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1,*shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')

model_path = "network.onnx"

tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)


d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()


data = dict(
input_data=[d, d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/gather_nd/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[0.07409840936878616, 0.026765300859598586, 0.0016244073056684407, 0.0441650068276063, 0.07194419125310977, 0.06692585614676133, 0.09608227643863203, 0.06954370939385973, 0.08529500855856466, 0.027427555452308128, 0.09236151161537542, 0.011363978906049666, 0.04123485316293706, 0.052260053362345765, 0.07718738744438916, 0.08752350940581384, 0.014004981152922381, 0.010884932350631815, 0.005834255835543156, 0.023425642525204883, 0.02094299377523633, 0.0834802433553593, 0.09072429232993884, 0.08080799984467026, 0.08846240688137623, 0.005972032477683776, 0.09543121848890006, 0.009593637052989368, 0.011043162248260897, 0.07626296942645529, 0.08665096065055262, 0.05531860137087041, 0.08555080157962193, 0.06305421152360385, 0.044457749781945825, 0.03346308199928279, 0.08409485819282941, 0.08373187170722414, 0.09211645983725941, 0.04638676434612774, 0.04353660804787102, 0.02722920399751614, 0.00748668892398231, 0.05697824644237888, 0.03159872068355862, 0.042243840841574114, 0.0026578916765197524, 0.08241319697741795, 0.0011032522545437184, 0.02945550874511177, 0.03093435828119301, 0.021394384259703938, 0.01850611460160917, 0.01177980853456815, 0.0723251150078536, 0.09937202215615878, 0.07289788627537903, 0.016733957884392702, 0.09930083413669373, 0.04661307600925346, 0.034026939589230254, 0.0003860672352996697, 0.0009392416775760083, 0.04915507878759171, 0.005487484619969374, 0.027356480324101973, 0.004414615409163947, 0.005906496351017521, 0.03184624273247968, 0.03609610320817154, 0.02320452948484344, 0.0027679693675671693, 0.0999650222289449, 0.052656740685879555, 0.05811185074201652, 0.05893880051789002, 0.07099004703280616, 0.06140757104198268, 0.018656956890861054, 0.054860839759383696, 0.06859210259110152, 0.09826135863295309, 0.031221736672843795, 0.09801241172361014, 0.026099888760803925, 0.04535476360755811, 0.050322856130120525, 0.06421022948155175, 0.0693960218732169, 0.009933220605766869, 0.055240649637579, 0.09025087147436488, 0.0826740468672511, 0.08954953520373564, 0.04366470140397538, 0.008152418001653838, 0.07090095627388146, 0.05197436173795772, 0.02482902257077856, 0.02618333398385272, 0.00808403000729875, 0.06676702934215438, 0.0843839009603915, 0.04802880934097036, 0.028976376986277964, 0.05874038769751201, 0.06665917890346157, 0.08203598439224069, 0.030536369233358885, 0.08267264088368732, 0.03922695410226166, 0.017195715209570507, 0.07568728185339207, 0.010491948396854479, 0.07385285358026732, 0.033894790999440386, 0.037953270098610783, 0.09923025906124551, 0.030385707233350257, 0.05797656904214617, 0.03525801977611266, 0.06513307621541974, 0.055129691186989405, 0.06580208273455689, 0.05137272614163535, 0.036238086027517616, 0.04219439532052521, 0.0824490458502624, 0.0703967104204315, 0.08984124427156798, 0.039774971852938504, 0.01927020503597058, 0.0946327772890066, 0.08487201812772689, 0.09499811520198619, 0.04274917648390278, 0.09758553691237543, 0.09775122269151526, 0.015484284179940178, 0.00850299125383659, 0.05254198276402216, 0.03147624747001413, 0.05176079373282756, 0.029932922941023787, 0.05371068766425129, 0.05991055411320946, 0.04111917907889184, 0.005826076585261153, 0.006330736671178972, 0.07113532288996224, 0.05159552878432429, 0.07815869464893865, 0.018250884851669934, 0.03033538781990235, 0.08192114816681907, 0.0484394631760716, 0.08915656585376198, 0.057618037780829835, 0.08752603445545322, 0.024668608257092762, 0.05045358839836907, 0.08931253842621169, 0.002945524391599808, 0.007837583888122347, 0.024846839309055315, 0.06441353986205248, 0.039224816681025135, 0.029435747598787657, 0.07538516146684937, 0.051500724663495014, 0.09511367364928426, 0.07106042164186015, 0.026381734649478496, 0.04443156955747423, 0.04113957693690739, 0.009484913458110235, 0.03184215092612154, 0.0017834825596509951, 0.06451846572723677, 0.04121236074344475, 0.021126598414098863, 0.06443367043778918, 0.0038598957838039127, 0.05305618162284708, 0.030368222589400563, 0.03565353906284881, 0.07637705305532676, 0.07707681534499795, 0.046216264886290964, 0.07697258678509858, 0.08242934820015058, 0.05176302219659787, 0.043731487554577314, 0.013667058385289733, 0.05976389427244843, 0.06137920920119046, 0.0953044124521614, 0.014962285636962626, 0.08650210747754121, 0.07706347189367195, 0.06254120844587761, 0.07853864262402863, 0.09585377108693328, 0.05321118065283568, 0.07937354027356201, 0.04627424023748361, 0.0972385713632545, 0.025431124558871854, 0.060954848544736574, 0.0006064704940715982, 0.020168505086630983, 0.044439944739528406, 0.07428067636165084, 0.04515464279676548, 0.05894266295600956, 0.09802000717221132, 0.06266213767521421, 0.058502538947961494, 0.08980067812482097, 0.022917072940121665, 0.0562944410229134, 0.036419183934740675, 0.02985327712260656, 0.017925908520906333, 0.0620962168740018, 0.07725534932564043, 0.09273055155983036, 0.028056400984576124, 0.019224871514768704, 0.0758506668871381, 0.030115241411931185, 0.047380573821017506, 0.04996063932082685, 0.03430418629392269, 0.06402179753678135, 0.009087491329278176, 0.01179911973847121, 0.03609770346453922, 0.01170299478820912, 0.017048258668025718, 0.0020372346275399746, 0.08326540595055604, 0.07335490234268104, 0.07612654114257544, 0.04767277205444751, 0.06795876090635204, 0.0243758721930891, 0.01962491019071133, 0.09645366951075646, 0.025512697117130247, 0.06092853249806779, 0.022104090609437313, 0.07707209916730826, 0.0054029130552852855, 0.09195796692479014, 0.009668102602996732, 0.05605458805872368, 0.0774723677146252, 0.019456273043242045, 0.06440237440992606, 0.0072201142789202095, 0.07797957601222233, 0.03955274740052564, 0.007534793496305226, 0.04666461808237964, 0.044356613093418644, 0.06702610540427431, 0.09397051666390746, 0.02492758406138401, 0.021545571290444446], [5, 2, 1, 5, 8, 6, 9, 8, 2, 3, 8, 5, 9, 5, 3]]}
Binary file added examples/onnx/gather_nd/network.onnx
Binary file not shown.
76 changes: 76 additions & 0 deletions examples/onnx/scatter_nd/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json

sys.path.append("..")

class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len

# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)

def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]

class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual

model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3

configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)

x = torch.randn(1, seq_len, pred_len)


torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()


data = dict(
input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/scatter_nd/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
Binary file added examples/onnx/scatter_nd/network.onnx
Binary file not shown.
Loading

0 comments on commit a1450f8

Please sign in to comment.