Skip to content

Commit

Permalink
implement if, and pad reflect mode (huggingface#2251)
Browse files Browse the repository at this point in the history
* implement if, and pad reflect mode

The intent of this change is to allow evaluation of the current silero_vad.onnx (v4).
This ONNX file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now.

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
shua and LaurentMazare authored Jun 6, 2024
1 parent f65e90e commit b9fac7e
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 54 deletions.
101 changes: 98 additions & 3 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};

Expand Down Expand Up @@ -56,6 +56,15 @@ impl Attr for str {
}
}

impl Attr for GraphProto {
const TYPE: AttributeType = AttributeType::Graph;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
attr.g
.as_ref()
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
}
}

impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
Expand Down Expand Up @@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
simple_eval_(graph, &mut inputs)
}

fn simple_eval_(
graph: &onnx::GraphProto,
values: &mut HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
Expand Down Expand Up @@ -958,6 +973,86 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
//
// Evaluates exactly one of the node's two sub-graph attributes, chosen by the
// first element of input 0, and binds the sub-graph's outputs (by position) to
// this node's output names.
"If" => {
    // protobuf encodes boolean false as 0 and true as 1
    let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
    let attr_name = if cond != 0 {
        "then_branch"
    } else {
        "else_branch"
    };
    let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
    if sub_graph.output.len() != node.output.len() {
        bail!(
            "If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
            node.name,
            sub_graph.output.len(),
            node.output.len()
        );
    }
    // The branch is evaluated against the same value map as the enclosing
    // graph, so it can read values computed by earlier nodes.
    let branch_out = simple_eval_(sub_graph, values)?;
    // Lengths were checked above, so `zip` pairs every node output with the
    // branch output at the same position.
    for (out, sub_out) in node.output.iter().zip(sub_graph.output.iter()) {
        // Report a missing branch output as an error rather than panicking.
        let value = match branch_out.get(&sub_out.name) {
            Some(value) => value.clone(),
            None => bail!(
                "If node {:?}: branch {attr_name:?} did not produce output {:?}",
                node.name,
                sub_out.name
            ),
        };
        values.insert(out.clone(), value);
    }
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
//
// Pads `data` (input 0) according to `pads` (input 1), a 1-D i64 tensor whose
// first half holds the amount to add before each axis and whose second half
// holds the amount to add after each axis.
// Only `mode = "reflect"` is implemented; every other mode (including the
// default "constant") is rejected with an error.
"Pad" => {
    // Check arity before indexing `node.input` so a node with a single input
    // yields an error instead of an out-of-bounds panic.
    if node.input.len() != 2 {
        bail!(
            "unsupported number of inputs {} for Pad node {:?}, expected 2",
            node.input.len(),
            node.name
        );
    }
    let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
    let data = get(&node.input[0])?;
    let pads = get(&node.input[1])?;
    if pads.rank() != 1 {
        bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
    }
    if pads.dim(0)? != 2 * data.rank() {
        bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
    }

    let pads = pads.to_vec1::<i64>()?;
    // Negative pads would wrap to enormous values once cast to `usize`, so
    // reject them up front.
    if pads.iter().any(|&p| p < 0) {
        bail!(
            "negative pads are not supported for Pad node {:?}: {pads:?}",
            node.name
        );
    }
    let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);

    match mode {
        "reflect" => {
            // Source indices for reflect-padding one axis of length `dim >= 2`.
            //
            // The reflected index pattern is the triangular wave
            // 0, 1, .., dim-1, dim-2, .., 1, repeating with period 2 * (dim - 1).
            // Output position `j` reads source position `wave(j - pre)`, so the
            // stream is started at offset `(-pre) mod period`.
            fn reflect_indices(
                dim: usize,
                pre: usize,
                post: usize,
            ) -> impl Iterator<Item = i64> {
                let last = (dim - 1) as i64;
                let cycle_len = 2 * (dim - 1);
                let skip = (cycle_len - pre % cycle_len) % cycle_len;
                std::iter::repeat((0..last).chain((1..=last).rev()))
                    .flatten()
                    .skip(skip)
                    .take(pre + dim + post)
            }

            let mut out = data.clone();
            for (i, &dim) in data.dims().iter().enumerate().rev() {
                // Non-negativity was checked above.
                let pre = pads_pre[i] as usize;
                let post = pads_post[i] as usize;
                if pre == 0 && post == 0 {
                    continue;
                }
                let idx = match dim {
                    // There is nothing to reflect on an empty axis.
                    0 => bail!(
                        "cannot reflect-pad empty axis {i} for Pad node {:?}",
                        node.name
                    ),
                    // A single element can only be repeated; the index must
                    // still cover the full padded length.
                    1 => Tensor::full(0i64, (pre + 1 + post,), out.device())?,
                    _ => Tensor::from_iter(reflect_indices(dim, pre, post), out.device())?,
                };

                out = out.index_select(&idx, i)?;
            }

            values.insert(node.output[0].clone(), out);
        }
        _ => bail!(
            "unsupported 'mode' value {mode:?} for Pad node {:?}",
            node.name
        ),
    }
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
Expand Down
Loading

0 comments on commit b9fac7e

Please sign in to comment.