Skip to content

Commit

Permalink
chore: update tract to 0.21.8-pre (#878)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Dec 3, 2024
1 parent 64cbcb3 commit 5e169bd
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 97 deletions.
51 changes: 35 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch =
"tokio-runtime",
], default-features = false, optional = true }
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
Expand Down
111 changes: 44 additions & 67 deletions examples/onnx/scatter_nd/gen.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,52 @@
import torch
import torch.nn as nn
import sys
from torch import nn
import json
import numpy as np
import tf2onnx

sys.path.append("..")

class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len

# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)

def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]

class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual

model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3

configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)

x = torch.randn(1, seq_len, pred_len)


torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model


# gather_nd in tf then export to onnx
x = in1 = Input((4, 1), dtype=tf.int32)
w = in2 = Input((4, ), dtype=tf.int32)

class MyLayer(Layer):
def call(self, x, w):
shape = tf.constant([8])
return tf.scatter_nd(x, w, shape)

x = MyLayer()(x, w)



tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')

shape = [1, 4, 1]
index_shape = [1, 4]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = np.random.randint(0, 4, shape)
# w = random int tensor
w = np.random.randint(0, 4, index_shape)

spec = tf.TensorSpec(shape, tf.int32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')

model_path = "network.onnx"

tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)


d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()


data = dict(
input_data=[d1],
input_data=[d, d1],
)

# Serialize data into file:
Expand Down
17 changes: 16 additions & 1 deletion examples/onnx/scatter_nd/input.json
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
{
"input_data": [
[
0,
1,
2,
3
],
[
1,
0,
2,
1
]
]
}
Binary file modified examples/onnx/scatter_nd/network.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
Ok(output.into())
}

// assumes unique values in fullset
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
Expand All @@ -1699,6 +1700,8 @@ pub(crate) fn get_missing_set_elements<
let set_len = fullset.len();
input.flatten();

// while fullset is less than len of input concat

let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;

let mut claimed_output: ValTensor<F> = if is_assigned {
Expand Down
12 changes: 6 additions & 6 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ impl Model {

let mut symbol_values = SymbolValues::default();
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
let symbol = model.symbols.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
}
Expand Down Expand Up @@ -1199,9 +1199,9 @@ impl Model {
// Then number of columns in the circuits
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
trace!("input indices: {:?}", node.inputs());
trace!("output scales: {:?}", node.out_scales());
trace!(
"input scales: {:?}",
node.inputs()
.iter()
Expand All @@ -1220,8 +1220,8 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
trace!("output dims: {:?}", node.out_dims());
trace!(
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
Expand Down
10 changes: 5 additions & 5 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,29 +1007,29 @@ pub fn new_op_from_onnx(
op
}
"Iff" => SupportedOp::Linear(PolyOp::Iff),
"Less" => {
"<" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
"<=" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
">" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
} else {
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
">=" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
Expand Down Expand Up @@ -1250,7 +1250,7 @@ pub fn new_op_from_onnx(
"And" => SupportedOp::Linear(PolyOp::And),
"Or" => SupportedOp::Linear(PolyOp::Or),
"Xor" => SupportedOp::Linear(PolyOp::Xor),
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
"==" => SupportedOp::Hybrid(HybridOp::Equals),
"Deconv" => {
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,
Expand Down
7 changes: 7 additions & 0 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,13 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
pub fn expand(&self, shape: &[usize]) -> Result<Self, TensorError> {
// if both have length 1 then we can just return the tensor
if self.dims().iter().product::<usize>() == 1 && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}

if self.dims().len() > shape.len() {
return Err(TensorError::DimError(format!(
"Cannot expand {:?} to the smaller shape {:?}",
Expand Down
1 change: 1 addition & 0 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();

let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok::<_, TensorError>(())
Expand Down
Loading

0 comments on commit 5e169bd

Please sign in to comment.