Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit e46271c

Browse files
committed
feat: implement shape inference for Flatten operator
1 parent dade7eb commit e46271c

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ fn test_matmul_square_matrix() {
253253
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp">Exp</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-1">1</a>|||
254254
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand">Expand</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Expand-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Expand-8">8</a>|
255255
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike">EyeLike</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#EyeLike-9">9</a>|
256-
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten">Flatten</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9">9</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1">1</a>||
256+
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten">Flatten</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9">9</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1">1</a>|||
257257
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Floor">Floor</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-1">1</a>|||
258258
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU">GRU</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GRU-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GRU-7">7</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GRU-3">3</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GRU-1">1</a>|
259259
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather">Gather</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1">1</a>|✅ (axis=0)||

wonnx-preprocessing/src/shape_inference.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,33 @@ pub(crate) fn infer_output_shapes(
506506
Ok(vec![output_shape])
507507
}
508508

509+
("Flatten", 1, 1) => {
510+
let axis: usize = {
511+
let a = node.get_attribute_value("axis", Some(1)).unwrap();
512+
if a < 0 {
513+
(a + input_shapes[0].rank() as i64) as usize
514+
} else {
515+
a as usize
516+
}
517+
};
518+
if axis > input_shapes[0].rank() {
519+
return Err(ShapeInferenceError::InvalidNode(
520+
node.get_name().to_string(),
521+
format!("Flatten axis attribute ({axis}) should be less than or equal to rank of input ({})",input_shapes[0].rank()),
522+
));
523+
}
524+
let input_dims = &input_shapes[0].dims;
525+
let outer_dim = if axis == 0 {
526+
1
527+
} else {
528+
input_dims[0..=(axis - 1)].iter().product::<u64>() as i64
529+
};
530+
let inner_dim = input_dims[axis..].iter().product::<u64>() as i64;
531+
532+
let new_dims = vec![outer_dim, inner_dim];
533+
Ok(vec![Shape::from(input_shapes[0].data_type, &new_dims)])
534+
}
535+
509536
("GlobalAveragePool", 1, 1) => {
510537
let mut output_shape = input_shapes[0].clone();
511538
if output_shape.rank() < 2 {

0 commit comments

Comments
 (0)