diff --git a/README.md b/README.md
index dcdfe6d3..0441ce66 100644
--- a/README.md
+++ b/README.md
@@ -346,7 +346,7 @@ fn test_matmul_square_matrix() {
|Sin|7|✅|✅|
|Sinh|9|✅|✅|
|Size|13, 1|✅|✅|
-|Slice|13, 11, 10, 1||✅|
+|Slice|13, 11, 10, 1|✅|✅|
|Softplus|1|✅|
|Softsign|1|✅|
|SpaceToDepth|13, 1|
@@ -398,6 +398,8 @@ fn test_matmul_square_matrix() {
* For `MatMul` and `Gemm`, the matrix dimensions must be divisible by 2, or the output matrix must be of size (1, N). Matrix
multiplication only supports floats, not integers (this is a WebGPU/WGSL limitation).
+* The `Slice` operator can only slice along a single axis: `starts`, `ends`, `axes` and `steps` must each contain exactly one element.
+
### Shape inference
WONNX needs to know the shape of input and output tensors for each operation in order to generate shader code for executing
diff --git a/data/images/dog.jpg b/data/images/dog.jpg
new file mode 100644
index 00000000..77b03812
Binary files /dev/null and b/data/images/dog.jpg differ
diff --git a/data/models/coco-classes.txt b/data/models/coco-classes.txt
new file mode 100644
index 00000000..16315f2b
--- /dev/null
+++ b/data/models/coco-classes.txt
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
\ No newline at end of file
diff --git a/data/models/yolox_nano.onnx b/data/models/yolox_nano.onnx
new file mode 100644
index 00000000..03adac0e
Binary files /dev/null and b/data/models/yolox_nano.onnx differ
diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml
index b24f5a53..ed1f92dd 100644
--- a/wonnx/Cargo.toml
+++ b/wonnx/Cargo.toml
@@ -40,7 +40,8 @@ futures = "^0.3.26"
parking_lot = { version = "0.11.1", features = ["wasm-bindgen"] }
[dev-dependencies]
-image = "0.24.2"
+image = "0.25.1"
+imageproc = "0.24.0"
ndarray = "0.15.4"
approx = "0.5.1"
pollster = "0.3.0"
diff --git a/wonnx/examples/yolox_nano.rs b/wonnx/examples/yolox_nano.rs
new file mode 100644
index 00000000..f491126f
--- /dev/null
+++ b/wonnx/examples/yolox_nano.rs
@@ -0,0 +1,306 @@
+use image::imageops;
+use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
+use imageproc::drawing::draw_hollow_rect_mut;
+use imageproc::rect::Rect;
+use log::info;
+use std::collections::HashMap;
+use std::convert::TryInto;
+use std::time::Instant;
+use std::vec;
+use std::{
+ fs,
+ io::{BufRead, BufReader},
+ path::Path,
+};
+use wonnx::WonnxError;
+
+/*-----------------------------------------------------------------------------
+ Post processing
+--------------------------------------------------------------------------------*/
+/// Draws a red hollow rectangle with corners `(x1, y1)` and `(x2, y2)` onto `image`.
+///
+/// Coordinates are truncated to whole pixels. The `as u32` casts saturate, so negative
+/// or NaN values clamp to 0. Boxes that end up with zero width or height (including
+/// inverted boxes, where `x2 < x1` or `y2 < y1`) are skipped instead of panicking.
+fn draw_rect(image: &mut ImageBuffer<Rgb<u8>, Vec<u8>>, x1: f32, y1: f32, x2: f32, y2: f32) {
+    let (left, top) = (x1 as u32, y1 as u32);
+    // `saturating_sub` avoids the u32 underflow the plain subtraction had on inverted boxes.
+    let width = (x2 as u32).saturating_sub(left);
+    let height = (y2 as u32).saturating_sub(top);
+    if width == 0 || height == 0 {
+        // `of_size` requires strictly positive dimensions.
+        return;
+    }
+    let rect = Rect::at(left as i32, top as i32).of_size(width, height);
+    draw_hollow_rect_mut(image, rect, Rgb([255, 0, 0]));
+}
+
+/// Decodes raw YOLOX box predictions into corner-format boxes in input-image pixels.
+///
+/// `positions` holds one `(dx, dy, log_w, log_h)` tuple per grid cell, laid out level by
+/// level for strides 8, 16 and 32 of a 416x416 input (52*52 + 26*26 + 13*13 = 3549 cells),
+/// row-major within each level. The returned vector has one `(x0, y0, x1, y1)` box per cell
+/// in the same order.
+///
+/// # Panics
+/// Panics if `positions` has fewer than 3549 entries.
+fn calc_loc(positions: &[(f32, f32, f32, f32)]) -> Vec<(f32, f32, f32, f32)> {
+    const INPUT_SIZE: usize = 416;
+    const STRIDES: [usize; 3] = [8, 16, 32];
+
+    let mut locs = Vec::with_capacity(positions.len());
+    // Index of the first cell of the current level inside `positions`.
+    let mut level_offset = 0;
+
+    for &stride in STRIDES.iter() {
+        let cells = INPUT_SIZE / stride;
+        let scale = stride as f32;
+
+        for row in 0..cells {
+            for col in 0..cells {
+                let (px, py, pw, ph) = positions[level_offset + row * cells + col];
+                // Centre = (cell index + predicted offset) * stride; size = exp(pred) * stride.
+                let (cx, cy) = ((col as f32 + px) * scale, (row as f32 + py) * scale);
+                let (bw, bh) = (pw.exp() * scale, ph.exp() * scale);
+                locs.push((cx - bw / 2.0, cy - bh / 2.0, cx + bw / 2.0, cy + bh / 2.0));
+            }
+        }
+        level_offset += cells * cells;
+    }
+    locs
+}
+
+/// Score-weighted non-maximum suppression over corner-format boxes.
+///
+/// Boxes scoring below `score_threshold` are discarded up front. The remaining boxes are
+/// visited from highest to lowest score; every later box whose IoU with the current best
+/// exceeds `iou_threshold` is merged into it (coordinates averaged with weight
+/// `score * iou`), and boxes that lie more than 70% inside the current best are dropped.
+///
+/// Returns `(index of the surviving box, merged (x0, y0, x1, y1))` in descending score order.
+///
+/// # Panics
+/// Panics if `scores` is shorter than `boxes`.
+fn non_max_suppression(
+    boxes: &[(f32, f32, f32, f32)],
+    scores: &[f32],
+    score_threshold: f32,
+    iou_threshold: f32,
+) -> Vec<(usize, (f32, f32, f32, f32))> {
+    // Filtering first guarantees that even the top-scoring box respects the threshold
+    // (it used to be emitted unconditionally) and drops NaN scores.
+    let mut candidates: Vec<usize> = (0..boxes.len())
+        .filter(|&i| scores[i] >= score_threshold)
+        .collect();
+    // Ascending order, so `pop` yields the best remaining box. `total_cmp` cannot panic.
+    candidates.sort_by(|&a, &b| scores[a].total_cmp(&scores[b]));
+
+    let mut kept = Vec::new();
+    while let Some(best) = candidates.pop() {
+        let (bx1, by1, bx2, by2) = boxes[best];
+        let best_area = (bx2 - bx1) * (by2 - by1);
+        let best_score = scores[best];
+
+        // Running weighted sum of the merged coordinates and of the weights.
+        let mut weighted = (
+            bx1 * best_score,
+            by1 * best_score,
+            bx2 * best_score,
+            by2 * best_score,
+        );
+        let mut weight_sum = best_score;
+
+        // `retain` visits candidates in order and removes in O(n) overall.
+        candidates.retain(|&idx| {
+            let (x1, y1, x2, y2) = boxes[idx];
+            let area = (x2 - x1) * (y2 - y1);
+            let inter_w = (x2.min(bx2) - x1.max(bx1)).max(0.0);
+            let inter_h = (y2.min(by2) - y1.max(by1)).max(0.0);
+            let inter_area = inter_w * inter_h;
+            let iou = inter_area / (area + best_area - inter_area);
+
+            if iou > iou_threshold {
+                let w = scores[idx] * iou;
+                weighted.0 += x1 * w;
+                weighted.1 += y1 * w;
+                weighted.2 += x2 * w;
+                weighted.3 += y2 * w;
+                weight_sum += w;
+                false
+            } else {
+                // Drop boxes mostly contained in the best box; keep everything else
+                // (a NaN ratio from a zero-area box compares false and is kept).
+                !(inter_area / area > 0.7)
+            }
+        });
+
+        kept.push((
+            best,
+            (
+                weighted.0 / weight_sum,
+                weighted.1 / weight_sum,
+                weighted.2 / weight_sum,
+                weighted.3 / weight_sum,
+            ),
+        ));
+    }
+    kept
+}
+
+/// Turns the raw model output into labelled detections.
+///
+/// `preds` is the flattened `[1, 3549, 85]` output: per candidate 4 box values,
+/// 1 objectness value and 80 class scores. Candidates are decoded with `calc_loc` and
+/// filtered/merged with `non_max_suppression` using objectness as the score
+/// (score threshold 0.5, IoU threshold 0.3).
+///
+/// Returns `(class label, best class score, x0, y0, x1, y1)` per surviving box.
+///
+/// # Panics
+/// Panics if `preds` holds fewer than `3549 * 85` values or if the label file has fewer
+/// entries than there are classes.
+fn post_process(preds: &[f32]) -> Vec<(String, f32, f32, f32, f32, f32)> {
+    const NUM_PREDICTIONS: usize = 3549;
+    const NUM_CLASSES: usize = 80;
+    // 4 box values + 1 objectness + class scores.
+    const ROW_LEN: usize = 5 + NUM_CLASSES;
+
+    assert!(
+        preds.len() >= NUM_PREDICTIONS * ROW_LEN,
+        "expected at least {} prediction values, got {}",
+        NUM_PREDICTIONS * ROW_LEN,
+        preds.len()
+    );
+
+    let labels = get_coco_labels();
+    let mut positions = Vec::with_capacity(NUM_PREDICTIONS);
+    let mut classes = Vec::with_capacity(NUM_PREDICTIONS);
+    let mut objectnesses = Vec::with_capacity(NUM_PREDICTIONS);
+
+    for row in preds.chunks_exact(ROW_LEN).take(NUM_PREDICTIONS) {
+        // Best class index and its score; `total_cmp` cannot panic on NaN.
+        let (class, score) = row[5..]
+            .iter()
+            .copied()
+            .enumerate()
+            .max_by(|a, b| a.1.total_cmp(&b.1))
+            .expect("every row holds at least one class score");
+        positions.push((row[0], row[1], row[2], row[3]));
+        objectnesses.push(row[4]);
+        // Keep only the index here; the label string is cloned for survivors only.
+        classes.push((class, score));
+    }
+
+    let locs = calc_loc(&positions);
+
+    // Filter by objectness, then attach the class label to each surviving box.
+    non_max_suppression(&locs, &objectnesses, 0.5, 0.3)
+        .into_iter()
+        .map(|(i, (x0, y0, x1, y1))| {
+            let (class, score) = classes[i];
+            (labels[class].clone(), score, x0, y0, x1, y1)
+        })
+        .collect()
+}
+
+/*-----------------------------------------------------------------------------
+ Pre processing
+--------------------------------------------------------------------------------*/
+/// Pads `image` to a square by centring it on a new buffer whose side is the longer of
+/// the image's width and height. The added border keeps the buffer's initial pixel value.
+fn padding_image(image: ImageBuffer<Rgb<u8>, Vec<u8>>) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
+    let (width, height) = image.dimensions();
+    let side = width.max(height);
+    let mut padded = ImageBuffer::new(side, side);
+    let (x_offset, y_offset) = ((side - width) / 2, (side - height) / 2);
+    for (x, y, pixel) in image.enumerate_pixels() {
+        padded.put_pixel(x + x_offset, y + y_offset, *pixel);
+    }
+    padded
+}
+
+/// Loads the input image and prepares it for the model.
+///
+/// The image path is the first command-line argument when given, otherwise
+/// `../data/images/dog.jpg` relative to the crate. The image is padded to a square and
+/// resized to 416x416 with nearest-neighbour filtering.
+///
+/// Returns the pixel data as `f32` in channel-first (CHW) order, with the raw 0-255
+/// channel values, together with the 416x416 RGB image itself.
+///
+/// # Panics
+/// Panics if the image cannot be opened or decoded.
+fn load_image() -> (Vec<f32>, ImageBuffer<Rgb<u8>, Vec<u8>>) {
+    const SIZE: u32 = 416;
+
+    let image_path = match std::env::args().nth(1) {
+        Some(arg) => Path::new(&arg).to_path_buf(),
+        None => Path::new(env!("CARGO_MANIFEST_DIR"))
+            .join("../data/images")
+            .join("dog.jpg"),
+    };
+
+    let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(image_path)
+        .expect("failed to open the input image")
+        .to_rgb8();
+    let image_buffer = padding_image(image_buffer);
+    let image_buffer = imageops::resize(&image_buffer, SIZE, SIZE, FilterType::Nearest);
+
+    // Convert the interleaved RGB pixels to a channel-first f32 tensor.
+    let plane = (SIZE * SIZE) as usize;
+    let mut image = vec![0.0; 3 * plane];
+    for (x, y, pixel) in image_buffer.enumerate_pixels() {
+        let base = (y * SIZE + x) as usize;
+        for (c, &value) in pixel.channels().iter().enumerate() {
+            image[c * plane + base] = f32::from(value);
+        }
+    }
+    (image, image_buffer)
+}
+
+/// Reads the COCO class names, one per line, from `../data/models/coco-classes.txt`
+/// relative to the crate. The line index is the class index.
+///
+/// # Panics
+/// Panics if the file cannot be opened or read.
+fn get_coco_labels() -> Vec<String> {
+    let labels_path = Path::new(env!("CARGO_MANIFEST_DIR"))
+        .join("../data/models")
+        .join("coco-classes.txt");
+    let file = BufReader::new(fs::File::open(labels_path).expect("failed to open the label file"));
+
+    file.lines()
+        .map(|line| line.expect("failed to read a line of the label file"))
+        .collect()
+}
+
+/*-----------------------------------------------------------------------------
+ Main
+--------------------------------------------------------------------------------*/
+// Hardware management
+/// Runs the YOLOX-nano model on the input image and post-processes the result.
+///
+/// Loads the image with `load_image`, feeds it to the model as the `images` input,
+/// reads the `output` tensor and passes it to `post_process`. Inference time and
+/// post-processing time are printed to stdout.
+///
+/// Returns one `(class label, score, x0, y0, x1, y1)` tuple per detection.
+///
+/// # Errors
+/// Returns a `WonnxError` if the session cannot be created or the inference fails.
+///
+/// # Panics
+/// Panics if the model has no `output` tensor or a tensor conversion fails.
+async fn execute_gpu() -> Result<Vec<(String, f32, f32, f32, f32, f32)>, WonnxError> {
+    let mut input_data = HashMap::new();
+    let (image, _) = load_image();
+    let images = image.as_slice().try_into().unwrap();
+    input_data.insert("images".to_string(), images);
+
+    let model_path = Path::new(env!("CARGO_MANIFEST_DIR"))
+        .join("../data/models")
+        .join("yolox_nano.onnx");
+    let session = wonnx::Session::from_path(model_path).await?;
+
+    info!("Start Compute");
+    let inference_started = Instant::now();
+    let result = session.run(&input_data).await?;
+    println!(
+        "time: first_prediction: {:#?}",
+        inference_started.elapsed()
+    );
+
+    info!("Start Post Processing");
+    let post_processing_started = Instant::now();
+    let output = result.get("output").unwrap();
+    let output = output.try_into().unwrap();
+    let positions = post_process(output);
+    println!(
+        "time: post_processing: {:#?}",
+        post_processing_started.elapsed()
+    );
+
+    Ok(positions)
+}
+
+/// Runs detection, prints every detection and saves the annotated image as
+/// `yolox_predict.jpg` in the current directory.
+///
+/// The model output has shape [1, 3549, 85], where
+/// 85 = 4 (bounding box) + 1 (objectness) + 80 (class probabilities).
+async fn run() {
+    let detections = execute_gpu().await.unwrap();
+
+    // Reload the 416x416 image to draw on; the tensor data is not needed here.
+    let (_, mut canvas) = load_image();
+    for detection in &detections {
+        let (class, score, x0, y0, x1, y1) = detection;
+        println!(
+            "class: {}, score: {}, x0: {}, y0: {}, x1: {}, y1: {}",
+            class, score, x0, y0, x1, y1
+        );
+        draw_rect(&mut canvas, *x0, *y0, *x1, *y1);
+    }
+    canvas.save("yolox_predict.jpg").unwrap();
+}
+
+/// Entry point: blocks on `run()` on native targets and prints the total wall time;
+/// on wasm32 the future is spawned with `wasm_bindgen_futures::spawn_local` instead.
+fn main() {
+    #[cfg(not(target_arch = "wasm32"))]
+    {
+        env_logger::init();
+        let started = Instant::now();
+
+        pollster::block_on(run());
+        println!("time: main: {:#?}", started.elapsed());
+    }
+    #[cfg(target_arch = "wasm32")]
+    {
+        // std::panic::set_hook(Box::new(console_error_panic_hook::hook));
+        // console_log::init().expect("could not initialize logger");
+        wasm_bindgen_futures::spawn_local(run());
+    }
+}
diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index 009f4a42..e95268bd 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -52,6 +52,11 @@ fn get_templates() -> &'static Tera {
include_str!("../templates/endomorphism/cast.wgsl"),
)
.unwrap();
+ tera.add_raw_template(
+ "endomorphism/slice.wgsl",
+ include_str!("../templates/endomorphism/slice.wgsl"),
+ )
+ .unwrap();
tera.add_raw_template(
"matrix/concat.wgsl",
include_str!("../templates/matrix/concat.wgsl"),
@@ -470,6 +475,60 @@ pub fn compile(
}
}
+        "Slice" => {
+            // `starts`, `ends`, `axes` and `steps` are read as node attributes and must all
+            // be present.
+            if node.get_attribute().len() != 4 {
+                return Err(CompileError::UnimplementedOp(
+                    "Slice (supports only 4 attributes: starts, ends, axes, and steps)".to_string(),
+                ));
+            }
+
+            // (attribute name, default value). The shader slices along exactly one axis, so
+            // every attribute must hold exactly one value.
+            let slice_attributes = [
+                ("starts", 0),
+                ("ends", i64::from(i32::MAX)),
+                ("axes", 0),
+                ("steps", 1),
+            ];
+            for (name, default) in slice_attributes {
+                let values = node.get_attribute_value::<Vec<i64>>(name, Some(vec![default]))?;
+                if values.len() != 1 {
+                    return Err(CompileError::UnimplementedOp(format!(
+                        "Slice (only a single value is supported for '{}')",
+                        name
+                    )));
+                }
+                // The shader does not resolve negative (from-the-end) indices or axes, and
+                // a zero step would divide by zero; reject both rather than emit a wrong result.
+                if values[0] < 0 || (name == "steps" && values[0] == 0) {
+                    return Err(CompileError::UnimplementedOp(format!(
+                        "Slice (negative values and zero steps are not supported for '{}')",
+                        name
+                    )));
+                }
+                // The shader works on i32, so clamp "to the end" sentinels such as i64::MAX.
+                let values: Vec<i64> = values
+                    .into_iter()
+                    .map(|value| value.min(i64::from(i32::MAX)))
+                    .collect();
+                context.insert(name, &values);
+            }
+
+            let (x_threads, workgroup_size_x) = workgroup_size(
+                input_lengths[0],
+                MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
+                MAX_WORKGROUP_SIZE_X,
+            )?;
+            context.insert("workgroup_size_x", &workgroup_size_x);
+
+            NodeTemplate {
+                scalar_type: agreed_type(&input_shapes[0..1], output_shapes)?,
+                template: "endomorphism/slice.wgsl",
+                threads: (x_threads, 1, 1),
+            }
+        }
+
"Softmax" => {
let default_axis = match opset_version {
1..=10 => 1, // https://github.com/onnx/onnx/blob/master/docs/Changelog.md#softmax-1
@@ -804,6 +863,9 @@ pub fn compile(
}
context.insert("cum_len", &input_cumulative_len);
+ let axis = node.get_attribute_value("axis", Some(0))? as usize;
+ context.insert("axis", &axis);
+
let root = output_lengths[0].sqrt() + 1;
let per_dim = ceil(root, 16) + 1;
diff --git a/wonnx/src/optimizer.rs b/wonnx/src/optimizer.rs
index 7f270762..1983cbe8 100644
--- a/wonnx/src/optimizer.rs
+++ b/wonnx/src/optimizer.rs
@@ -561,7 +561,7 @@ impl<'model> Optimizer<'model> {
op @ ("Clip" | "Pad" | "Split" | "Resize" | "Reshape" | "ReduceMean"
| "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare"
| "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1"
- | "ReduceProd") => {
+ | "ReduceProd" | "Slice") => {
if new_inputs.is_empty() {
return Err(OptimizerError::NoInputs);
}
@@ -583,6 +583,7 @@ impl<'model> Optimizer<'model> {
"ReduceMin" => REDUCE_OPS_INPUT_NAMES,
"ReduceProd" => REDUCE_OPS_INPUT_NAMES,
"ReduceSumSquare" => REDUCE_OPS_INPUT_NAMES,
+ "Slice" => SLICE_INPUT_NAMES,
_ => unreachable!(),
};
@@ -614,7 +615,8 @@ impl<'model> Optimizer<'model> {
)
| ("Pad", "pads")
| ("Resize", "scales")
- | ("Clip", "min" | "max") => match data_type {
+ | ("Clip", "min" | "max")
+ | ("Slice", "starts" | "ends" | "axes" | "steps") => match data_type {
ScalarType::F32 => {
let value: Vec = if tensor_proto
.get_float_data()
@@ -640,13 +642,21 @@ impl<'model> Optimizer<'model> {
let value = if tensor_proto
.get_int64_data()
.is_empty()
- {
+ {
pod_collect_to_vec(tensor_proto.get_raw_data())
} else {
tensor_proto.get_int64_data().to_vec()
};
+ // Clamp every value larger than i32::MAX down to i32::MAX
+ let value = value.iter().map(|x| {
+ if *x > i32::MAX as i64 {
+ i32::MAX as i64
+ } else {
+ *x
+ }
+ }).collect::>();
log::info!(
- "transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}",
+ "transferring input \"{}\" for op \"{}\" to i64 attribute (initializer data type: {:?}): {:?}",
attr_name,
op,
data_type,
@@ -702,7 +712,6 @@ impl<'model> Optimizer<'model> {
Ok(Arc::new(new_node))
}
-
_ => Ok(Arc::new(Node {
inputs: new_inputs,
definition: NodeDefinition::Operator(op_def.clone()),
@@ -819,6 +828,7 @@ static RESHAPE_INPUT_NAMES: &[&str] = &["data", "shape"];
static CLIP_INPUT_NAMES: &[&str] = &["input", "min", "max"];
static REDUCE_OPS_INPUT_NAMES: &[&str] = &["input", "axes"];
static PAD_INPUT_NAMES: &[&str] = &["data", "pads", "constant_value"];
+static SLICE_INPUT_NAMES: &[&str] = &["data", "starts", "ends", "axes", "steps"];
/// Generate the output for a ConstantOfShape node
pub fn constant_of_shape_output(
diff --git a/wonnx/templates/endomorphism/slice.wgsl b/wonnx/templates/endomorphism/slice.wgsl
new file mode 100644
index 00000000..5dbc6f31
--- /dev/null
+++ b/wonnx/templates/endomorphism/slice.wgsl
@@ -0,0 +1,113 @@
+{%- include "structs.wgsl" -%}
+
+@group(0) @binding(0)
+var<storage, read> input_0: Array; // data
+
+@group(0) @binding(1)
+var<storage, read_write> output_0: Array;
+
+// Slice of a rank 2-4 tensor along a single axis. One invocation handles one input
+// element and copies it to its place in the output when it falls inside the slice.
+@compute @workgroup_size({{ workgroup_size_x }})
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+    let gidx = global_id.x;
+
+    // Input extents, padded with trailing 1s up to rank 4.
+    let x0 = {{ i_shape[0][0] }}i;
+    let x1 = {{ i_shape[0][1] }}i;
+    {% if i_shape[0] | length == 3 %}
+    let x2 = {{ i_shape[0][2] }}i;
+    let x3 = 1i;
+    {% elif i_shape[0] | length == 4 %}
+    let x2 = {{ i_shape[0][2] }}i;
+    let x3 = {{ i_shape[0][3] }}i;
+    {% else %}
+    let x2 = 1i;
+    let x3 = 1i;
+    {% endif %}
+
+    // Invocations past the last input element have nothing to copy.
+    if i32(gidx) >= x0 * x1 * x2 * x3 {
+        return;
+    }
+
+    // Coordinates (a, b, c, d) of this input element.
+    let a = i32(gidx) / (x1 * x2 * x3);
+    let b = (i32(gidx) % (x1 * x2 * x3)) / (x2 * x3);
+    let c = (i32(gidx) % (x2 * x3)) / x3;
+    let d = i32(gidx) % x3;
+
+    // starts, ends, axes and steps each hold exactly one element.
+    let start = {{ starts[0] }}i;
+    let ax = {{ axes[0] }}i;
+    let step = {{ steps[0] }}i;
+
+    // Coordinate and extent along the sliced axis.
+    var coord = a;
+    var extent = x0;
+    if ax == 1 {
+        coord = b;
+        extent = x1;
+    } else if ax == 2 {
+        coord = c;
+        extent = x2;
+    } else if ax == 3 {
+        coord = d;
+        extent = x3;
+    } else if ax != 0 {
+        // Unsupported axis: nothing is written.
+        return;
+    }
+
+    // Clamp `end` to the axis extent. This covers the 2147483647 "to the end" sentinel as
+    // well as any other out-of-range end, so the output stride below is always correct.
+    let end = min({{ ends[0] }}i, extent);
+
+    if coord < start || coord >= end || (coord - start) % step != 0 {
+        return;
+    }
+
+    // Position along the sliced axis in the output, and the output extent of that axis.
+    let j = (coord - start) / step;
+    let count = (end - start + step - 1) / step;
+
+    // Output coordinates and extents: identical to the input except along `ax`.
+    var oa = a;
+    var ob = b;
+    var oc = c;
+    var od = d;
+    var o1 = x1;
+    var o2 = x2;
+    var o3 = x3;
+    if ax == 0 {
+        oa = j;
+    } else if ax == 1 {
+        ob = j;
+        o1 = count;
+    } else if ax == 2 {
+        oc = j;
+        o2 = count;
+    } else {
+        od = j;
+        o3 = count;
+    }
+
+    let idx = ((oa * o1 + ob) * o2 + oc) * o3 + od;
+    output_0.data[idx] = input_0.data[gidx];
+}
diff --git a/wonnx/templates/matrix/concat.wgsl b/wonnx/templates/matrix/concat.wgsl
index dfba2016..e5f4a0a8 100644
--- a/wonnx/templates/matrix/concat.wgsl
+++ b/wonnx/templates/matrix/concat.wgsl
@@ -15,23 +15,77 @@ var output_0: Array;
@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) {
let gidx = global_id.x;
- let gidy = global_id.y;
+ let gidy = global_id.y;
- let x_executions = num_workgroups.x * 16u;
+ let x_executions = num_workgroups.x * 16u;
- let actual_idx = gidx + gidy * x_executions;
+ let actual_idx = gidx + gidy * x_executions;
+ let ax = {{ axis }}u;
+
{% for input in i_lens %}
{% if loop.first %}
if (actual_idx < {{ i_lens[0] }}u) {
- output_0.data[actual_idx] = input_0.data[actual_idx];
+ var global_input_idx = actual_idx;
+ var input_indices: array = array();
+ {% set in_shape = i_shape[0] %}
+ var input_shape: array = array({{ in_shape | join(sep=', ')}});
+ let output_shape_length = {{ o_shape[0] | length }}u;
+
+ // calculate input indices
+ for (var i = 0u; i < output_shape_length; i = i + 1u) {
+ input_indices[output_shape_length - i - 1] = global_input_idx % input_shape[output_shape_length - i - 1];
+ global_input_idx = global_input_idx / input_shape[output_shape_length - i - 1];
+ }
+
+ // calculate output index
+ {% set out_shape = o_shape[0] %}
+ var output_shape = array({{ out_shape | join(sep=', ')}});
+
+ var output_idx = 0u;
+ for (var i = 0u; i < output_shape_length; i = i + 1u) {
+ output_idx = output_idx * output_shape[i] + input_indices[i];
+ }
+ output_0.data[output_idx] = input_0.data[actual_idx];
}
{% else %}
- if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
- output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u];
+ if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
+ var global_input_idx = actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u;
+ var input_indices: array = array();
+ {% set shape = i_shape[loop.index0] %}
+ var input_shape: array = array({{ shape | join(sep=', ')}});
+ let output_shape_length = {{ o_shape[0] | length }}u;
+
+ // calculate input indices
+ var input_idx = global_input_idx;
+ for (var i = 0u; i < output_shape_length; i = i + 1u) {
+ input_indices[output_shape_length - i - 1] = input_idx % input_shape[output_shape_length - i - 1];
+ input_idx = input_idx / input_shape[output_shape_length - i - 1];
+ }
+
+
+ // calculate the offset
+ var offset = 0u;
+ {% for i in range(end=loop.index0) %}
+ offset = offset + {{ i_shape[i][axis] }}u;
+ {% endfor %}
+
+ // add offset to input indices
+ input_indices[ax] = input_indices[ax] + offset;
+
+ // calculate output index
+ {% set out_shape = o_shape[0] %}
+ var output_shape = array({{ out_shape | join(sep=', ')}});
+
+ var output_idx = 0u;
+ for (var i = 0u; i < output_shape_length; i = i + 1u) {
+ output_idx = output_idx * output_shape[i] + input_indices[i];
+ }
+
+ output_0.data[output_idx] = input_{{ loop.index0 }}.data[global_input_idx];
}
-
+
{% endif %}
{% endfor %}
}
diff --git a/wonnx/tests/concat.rs b/wonnx/tests/concat.rs
index 994b087b..fc7dc32c 100644
--- a/wonnx/tests/concat.rs
+++ b/wonnx/tests/concat.rs
@@ -1,5 +1,8 @@
-use std::{collections::HashMap, convert::TryInto};
-use wonnx::utils::{graph, model, node, tensor};
+use std::{collections::HashMap, convert::TryInto, vec};
+use wonnx::{
+ onnx::AttributeProto,
+ utils::{attribute, graph, model, node, tensor},
+};
mod common;
#[test]
@@ -117,3 +120,291 @@ fn test_concat4() {
common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result);
}
+
+/// Concatenates three [2, 2, 2] tensors along axis 1 into a [2, 6, 2] tensor.
+#[test]
+fn test_concat_axis1() {
+    // Row-major contents: X = 1..=8, Y = 9..=16, Z = 17..=24.
+    let xdata: Vec<f32> = (1..=8).map(|v| v as f32).collect();
+    let ydata: Vec<f32> = (9..=16).map(|v| v as f32).collect();
+    let zdata: Vec<f32> = (17..=24).map(|v| v as f32).collect();
+    let x_input_dims = vec![2, 2, 2];
+    let y_input_dims = vec![2, 2, 2];
+    let z_input_dims = vec![2, 2, 2];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+        ("Z".into(), zdata.as_slice().into()),
+    ]);
+
+    let attributes: Vec<AttributeProto> = vec![attribute("axis", 1)];
+
+    let model = model(graph(
+        vec![
+            tensor("X", &x_input_dims),
+            tensor("Y", &y_input_dims),
+            tensor("Z", &z_input_dims),
+        ],
+        vec![tensor("W", &[2, 6, 2])],
+        vec![],
+        vec![],
+        vec![node(
+            vec!["X", "Y", "Z"],
+            vec!["W"],
+            "a",
+            "Concat",
+            attributes,
+        )],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    // 2x6x2: for each outer index, the rows of X, then Y, then Z.
+    let expected_result: Vec<f32> = [
+        [1., 2.],
+        [3., 4.],
+        [9., 10.],
+        [11., 12.],
+        [17., 18.],
+        [19., 20.],
+        [5., 6.],
+        [7., 8.],
+        [13., 14.],
+        [15., 16.],
+        [21., 22.],
+        [23., 24.],
+    ]
+    .concat();
+
+    common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result);
+}
+
+/// Concatenates three [1, 2, 3, 4] tensors along axis 2 into a [1, 2, 9, 4] tensor.
+#[test]
+fn test_concat_axis2() {
+    // Row-major contents: X = 1..=24, Y = 25..=48, Z = 49..=72.
+    let xdata: Vec<f32> = (1..=24).map(|v| v as f32).collect();
+    let ydata: Vec<f32> = (25..=48).map(|v| v as f32).collect();
+    let zdata: Vec<f32> = (49..=72).map(|v| v as f32).collect();
+    let x_input_dims = vec![1, 2, 3, 4];
+    let y_input_dims = vec![1, 2, 3, 4];
+    let z_input_dims = vec![1, 2, 3, 4];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+        ("Z".into(), zdata.as_slice().into()),
+    ]);
+
+    let attributes: Vec<AttributeProto> = vec![attribute("axis", 2)];
+
+    let model = model(graph(
+        vec![
+            tensor("X", &x_input_dims),
+            tensor("Y", &y_input_dims),
+            tensor("Z", &z_input_dims),
+        ],
+        vec![tensor("W", &[1, 2, 9, 4])],
+        vec![],
+        vec![],
+        vec![node(
+            vec!["X", "Y", "Z"],
+            vec!["W"],
+            "a",
+            "Concat",
+            attributes,
+        )],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    // 1x2x9x4: per channel, the three rows of X, then of Y, then of Z.
+    let expected_result: Vec<f32> = [
+        // channel 0
+        [1., 2., 3., 4.],
+        [5., 6., 7., 8.],
+        [9., 10., 11., 12.],
+        [25., 26., 27., 28.],
+        [29., 30., 31., 32.],
+        [33., 34., 35., 36.],
+        [49., 50., 51., 52.],
+        [53., 54., 55., 56.],
+        [57., 58., 59., 60.],
+        // channel 1
+        [13., 14., 15., 16.],
+        [17., 18., 19., 20.],
+        [21., 22., 23., 24.],
+        [37., 38., 39., 40.],
+        [41., 42., 43., 44.],
+        [45., 46., 47., 48.],
+        [61., 62., 63., 64.],
+        [65., 66., 67., 68.],
+        [69., 70., 71., 72.],
+    ]
+    .concat();
+    common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result);
+}
+
+/// Concatenates three [1, 2, 3, 4] tensors along axis 3 into a [1, 2, 3, 12] tensor.
+#[test]
+fn test_concat_axis3() {
+    // Row-major contents: X = 1..=24, Y = 25..=48, Z = 49..=72.
+    let xdata: Vec<f32> = (1..=24).map(|v| v as f32).collect();
+    let ydata: Vec<f32> = (25..=48).map(|v| v as f32).collect();
+    let zdata: Vec<f32> = (49..=72).map(|v| v as f32).collect();
+    let x_input_dims = vec![1, 2, 3, 4];
+    let y_input_dims = vec![1, 2, 3, 4];
+    let z_input_dims = vec![1, 2, 3, 4];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+        ("Z".into(), zdata.as_slice().into()),
+    ]);
+
+    let attributes: Vec<AttributeProto> = vec![attribute("axis", 3)];
+
+    let model = model(graph(
+        vec![
+            tensor("X", &x_input_dims),
+            tensor("Y", &y_input_dims),
+            tensor("Z", &z_input_dims),
+        ],
+        vec![tensor("W", &[1, 2, 3, 12])],
+        vec![],
+        vec![],
+        vec![node(
+            vec!["X", "Y", "Z"],
+            vec!["W"],
+            "a",
+            "Concat",
+            attributes,
+        )],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    // 1x2x3x12: every output row is the matching row of X, then Y, then Z.
+    let expected_result: Vec<f32> = [
+        // channel 0
+        [1., 2., 3., 4., 25., 26., 27., 28., 49., 50., 51., 52.],
+        [5., 6., 7., 8., 29., 30., 31., 32., 53., 54., 55., 56.],
+        [9., 10., 11., 12., 33., 34., 35., 36., 57., 58., 59., 60.],
+        // channel 1
+        [13., 14., 15., 16., 37., 38., 39., 40., 61., 62., 63., 64.],
+        [17., 18., 19., 20., 41., 42., 43., 44., 65., 66., 67., 68.],
+        [21., 22., 23., 24., 45., 46., 47., 48., 69., 70., 71., 72.],
+    ]
+    .concat();
+    common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result);
+}
diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs
new file mode 100644
index 00000000..543ba5f2
--- /dev/null
+++ b/wonnx/tests/slice.rs
@@ -0,0 +1,581 @@
+use std::{collections::HashMap, convert::TryInto, vec};
+use wonnx::{
+ onnx::AttributeProto,
+ utils::{ attribute, graph, model, node, tensor, InputTensor},
+};
+mod common;
+
+/// Builds a one-node `Slice` model, runs it and checks the values of `Y`.
+///
+/// `X` is fed with `input` (shape `input_shape`); `starts`, `ends`, `axes` and
+/// `steps` are passed to the node as attributes; output `Y` is declared with
+/// `output_shape` and compared to `output` via `common::assert_eq_vector`.
+///
+/// # Panics
+///
+/// Panics if the session cannot be created, if the run fails, or if `Y`
+/// cannot be converted for the comparison.
+fn test_slice(
+    input: &[f32],
+    input_shape: &[i64],
+    starts: &[i64],
+    ends: &[i64],
+    axes: &[i64],
+    steps: &[i64],
+    output: &[f32],
+    output_shape: &[i64],
+) {
+    let input_data: HashMap<String, InputTensor> =
+        HashMap::from([("X".to_string(), input.into())]);
+
+    let attributes: Vec<AttributeProto> = vec![
+        attribute("starts", starts.to_vec()),
+        attribute("ends", ends.to_vec()),
+        attribute("axes", axes.to_vec()),
+        attribute("steps", steps.to_vec()),
+    ];
+
+    let model = model(graph(
+        vec![tensor("X", input_shape)],
+        vec![tensor("Y", output_shape)],
+        vec![],
+        vec![],
+        vec![node(vec!["X"], vec!["Y"], "slice", "Slice", attributes)],
+    ));
+
+    let session = pollster::block_on(wonnx::Session::from_model(model))
+        .unwrap_or_else(|e| panic!("Failed to create session: {:?}", e));
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+    log::info!("OUT: {:?}", result["Y"]);
+    common::assert_eq_vector((&result["Y"]).try_into().unwrap(), output);
+}
+
+/// Exercises `Slice` with a step of 1 along each axis of a `1x2x4` tensor.
+#[test]
+fn slice_step1() {
+    /// One slicing scenario applied to the shared input tensor.
+    struct Case {
+        start: i64,
+        end: i64,
+        axis: i64,
+        expected: Vec<f32>,
+        expected_shape: [i64; 3],
+    }
+
+    let _ = env_logger::builder().is_test(true).try_init();
+
+    // Note that each interval is half-open, i.e. it includes the start index but excludes the end index.
+    let source: Vec<f32> = [[1., 2., 3., 4.], [5., 6., 7., 8.]].concat();
+    let source_shape: [i64; 3] = [1, 2, 4];
+
+    let cases = vec![
+        // axes == 0: the most simple case, the whole tensor is selected
+        Case {
+            start: 0,
+            end: 1,
+            axis: 0,
+            expected: [[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(),
+            expected_shape: [1, 2, 4],
+        },
+        // axes == 1: the window covers both rows
+        Case {
+            start: 0,
+            end: 2,
+            axis: 1,
+            expected: [[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(),
+            expected_shape: [1, 2, 4],
+        },
+        // axes == 2: the window covers all four columns
+        Case {
+            start: 0,
+            end: 4,
+            axis: 2,
+            expected: [[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(),
+            expected_shape: [1, 2, 4],
+        },
+        // axes == 1: first row only
+        Case {
+            start: 0,
+            end: 1,
+            axis: 1,
+            expected: [[1., 2., 3., 4.]].concat(),
+            expected_shape: [1, 1, 4],
+        },
+        // axes == 1: second row only
+        Case {
+            start: 1,
+            end: 2,
+            axis: 1,
+            expected: [[5., 6., 7., 8.]].concat(),
+            expected_shape: [1, 1, 4],
+        },
+        // axes == 2: first two columns
+        Case {
+            start: 0,
+            end: 2,
+            axis: 2,
+            expected: [[1., 2.], [5., 6.]].concat(),
+            expected_shape: [1, 2, 2],
+        },
+        // axes == 2: last two columns
+        Case {
+            start: 2,
+            end: 4,
+            axis: 2,
+            expected: [[3., 4.], [7., 8.]].concat(),
+            expected_shape: [1, 2, 2],
+        },
+    ];
+
+    for case in cases {
+        test_slice(
+            &source,
+            &source_shape,
+            &vec![case.start],
+            &vec![case.end],
+            &vec![case.axis],
+            &vec![1],
+            &case.expected,
+            &case.expected_shape,
+        );
+    }
+}
+
+/// Exercises `Slice` with a step of 2 on a `1x4x4x1` tensor.
+///
+/// Every case slices the same input, holding the values 1 to 16, along one
+/// axis. The cases differ in the axis and in the `[start, end)` window.
+#[test]
+fn slice_step2() {
+    // Input shared by all cases, laid out as four rows of four values.
+    #[rustfmt::skip]
+    let source: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 5.,  6.,  7.,  8.],
+        [ 9., 10., 11., 12.],
+        [13., 14., 15., 16.],
+    ].concat();
+    let source_shape: [i64; 4] = [1, 4, 4, 1];
+
+    // axes == 0: that axis has a single entry, so the whole tensor is
+    // expected back.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 5.,  6.,  7.,  8.],
+        [ 9., 10., 11., 12.],
+        [13., 14., 15., 16.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 4, 4, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![0],
+        &vec![1],
+        &vec![0],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 1, window [0, 4): every second row starting at row 0, i.e.
+    // rows 0 and 2.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 9., 10., 11., 12.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 2, 4, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![0],
+        &vec![4],
+        &vec![1],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 1, window [1, 4): every second row starting at row 1, i.e.
+    // rows 1 and 3.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 5.,  6.,  7.,  8.],
+        [13., 14., 15., 16.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 2, 4, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![1],
+        &vec![4],
+        &vec![1],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 1, window [0, 2): only row 0 falls inside the window.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 1, 4, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![0],
+        &vec![2],
+        &vec![1],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 1, window [3, 4): only row 3 falls inside the window.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [13., 14., 15., 16.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 1, 4, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![3],
+        &vec![4],
+        &vec![1],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 2, window [0, 4): every second column starting at column 0,
+    // i.e. columns 0 and 2.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.,  3.],
+        [ 5.,  7.],
+        [ 9., 11.],
+        [13., 15.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 4, 2, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![0],
+        &vec![4],
+        &vec![2],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 2, window [0, 2): column 0 is the only one picked, leaving a
+    // single value per row.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.],
+        [ 5.],
+        [ 9.],
+        [13.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 4, 1, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![0],
+        &vec![2],
+        &vec![2],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+
+    // axes == 2, window [2, 4): column 2 is the only one picked, leaving a
+    // single value per row.
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 3.],
+        [ 7.],
+        [11.],
+        [15.],
+    ].concat();
+    let expected_shape: [i64; 4] = [1, 4, 1, 1];
+    test_slice(
+        &source,
+        &source_shape,
+        &vec![2],
+        &vec![4],
+        &vec![2],
+        &vec![2],
+        &expected,
+        &expected_shape,
+    );
+}
+
+
+/// Slices an all-ones `1x3x416x416` tensor along axis 2 with a step of 2.
+#[test]
+fn slice_1x3x416x416_step2() {
+    let input_shape: [i64; 4] = [1, 3, 416, 416];
+    let output_shape: [i64; 4] = [1, 3, 208, 416];
+
+    // Both tensors contain only ones, so the values themselves cannot reveal
+    // which elements were picked.
+    let ones_in = vec![1.0_f32; 3 * 416 * 416];
+    let ones_out = vec![1.0_f32; 3 * 208 * 416];
+
+    let (starts, ends, axes, steps) = (vec![0], vec![i32::MAX as i64], vec![2], vec![2]);
+    test_slice(
+        &ones_in, &input_shape, &starts, &ends, &axes, &steps, &ones_out, &output_shape,
+    );
+}
+
+/// Slices a `1x3x416x416` tensor along axis 2 from index 0 with a step of 2.
+///
+/// Rows with an even index are filled with ones and rows with an odd index
+/// with zeros, so an output made of ones only shows that only rows with an
+/// even index were selected.
+#[test]
+fn slice_1x3x416x416_step2_axes2_start0() {
+    const CHANNELS: usize = 3;
+    const ROWS: usize = 416;
+    const COLS: usize = 416;
+
+    // Flat input in channel, row, column order: every value depends only on
+    // the parity of its row index.
+    let input2: Vec<f32> = (0..CHANNELS)
+        .flat_map(|_| 0..ROWS)
+        .flat_map(|row| {
+            let value = if row % 2 == 0 { 1.0 } else { 0.0 };
+            std::iter::repeat(value).take(COLS)
+        })
+        .collect();
+
+    // Half of the rows are expected to survive, all of them rows of ones.
+    let output2 = vec![1.0; CHANNELS * (ROWS / 2) * COLS];
+
+    // `ends` is `i32::MAX`, far beyond the 416 entries of axis 2.
+    test_slice(
+        &input2,
+        &[1, 3, 416, 416],
+        &vec![0],
+        &vec![i32::MAX as i64],
+        &vec![2],
+        &vec![2],
+        &output2,
+        &[1, 3, 208, 416],
+    );
+}
+
+
+/// Slices a `1x3x416x416` tensor along axis 2 from index 1 with a step of 2.
+///
+/// Rows with an odd index are filled with ones and rows with an even index
+/// with zeros, so an output made of ones only shows that only rows with an
+/// odd index were selected.
+#[test]
+fn slice_1x3x416x416_step2_axes2_start1() {
+    const CHANNELS: usize = 3;
+    const ROWS: usize = 416;
+    const COLS: usize = 416;
+
+    // Flat input in channel, row, column order: every value depends only on
+    // the parity of its row index.
+    let input2: Vec<f32> = (0..CHANNELS)
+        .flat_map(|_| 0..ROWS)
+        .flat_map(|row| {
+            let value = if row % 2 == 0 { 0.0 } else { 1.0 };
+            std::iter::repeat(value).take(COLS)
+        })
+        .collect();
+
+    // Half of the rows are expected to survive, all of them rows of ones.
+    let output2 = vec![1.0; CHANNELS * (ROWS / 2) * COLS];
+
+    // `ends` is `i32::MAX`, far beyond the 416 entries of axis 2.
+    test_slice(
+        &input2,
+        &[1, 3, 416, 416],
+        &vec![1],
+        &vec![i32::MAX as i64],
+        &vec![2],
+        &vec![2],
+        &output2,
+        &[1, 3, 208, 416],
+    );
+}
+
+/// Slices a `1x3x8x8` tensor along axis 2 from index 1 with a step of 2.
+///
+/// Rows with an even index hold ones and rows with an odd index hold zeros, so
+/// an all-zero output shows that only rows with an odd index were selected.
+#[test]
+fn slice_1x3x8x8_step2_axes2_start1() {
+    let input2: [[[f32; 8]; 8]; 3] = [
+        [
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+        ],
+        [
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+        ],
+        [
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+        ],
+    ];
+
+    // flatten the input
+    let input2: Vec<f32> = input2.iter().flatten().flatten().copied().collect();
+    let output2 = vec![0.0; 3 * 4 * 8];
+    test_slice(
+        &input2,
+        &[1, 3, 8, 8],
+        &vec![1],
+        &vec![i32::MAX as i64],
+        &vec![2],
+        &vec![2],
+        &output2,
+        &[1, 3, 4, 8],
+    );
+}
+
+/// Slices a `1x3x8x8` tensor along axis 3 from index 1 with a step of 2.
+/// Ones sit at odd column indices only, so an all-ones output is expected.
+#[test]
+fn slice_1x3x8x8_step2_axes3_start1() {
+    let input2: [[[f32; 8]; 8]; 3] = [
+        [
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+        ],
+        [
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+        ],
+        [
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+            [0., 1., 0., 1., 0., 1., 0., 1.],
+        ],
+    ];
+    // flatten the input
+    let input2: Vec<f32> = input2.iter().flatten().flatten().copied().collect();
+    let output2 = vec![1.0; 3 * 8 * 4];
+    test_slice(
+        &input2,
+        &[1, 3, 8, 8],
+        &vec![1],
+        &vec![i32::MAX as i64],
+        &vec![3],
+        &vec![2],
+        &output2,
+        &[1, 3, 8, 4],
+    );
+}
+
+/// Slices rows 0 and 1 (window `[0, 2)`, axis 0, step 1) out of a `4x4x1` tensor.
+// NOTE(review): despite the test name, `axes` and `steps` are passed explicitly.
+#[test]
+fn slice_none_axes_and_steps() {
+    #[rustfmt::skip]
+    let source: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 5.,  6.,  7.,  8.],
+        [ 9., 10., 11., 12.],
+        [13., 14., 15., 16.],
+    ].concat();
+    #[rustfmt::skip]
+    let expected: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 5.,  6.,  7.,  8.],
+    ].concat();
+
+    let (starts, ends, axes, steps) = (vec![0], vec![2], vec![0], vec![1]);
+    test_slice(
+        &source, &[4, 4, 1], &starts, &ends, &axes, &steps, &expected, &[2, 4, 1],
+    );
+}
+
+/// Slices a `4x4x1` tensor with a step of 2 and `ends` set to `i32::MAX`.
+///
+/// The end index lies far beyond the length of either sliced axis, and both
+/// cases expect the slice to run up to the last element of the axis.
+#[test]
+fn slice_ends_max_i32() {
+    #[rustfmt::skip]
+    let source: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 5.,  6.,  7.,  8.],
+        [ 9., 10., 11., 12.],
+        [13., 14., 15., 16.],
+    ].concat();
+    let source_shape: [i64; 3] = [4, 4, 1];
+
+    // Parameters shared by both cases; `ends` is far beyond either axis length.
+    let starts: Vec<i64> = vec![0];
+    let ends: Vec<i64> = vec![i32::MAX as i64];
+    let steps: Vec<i64> = vec![2];
+
+    // Axis 0 (rows): indices 0 and 2 are expected, as if `ends` were the
+    // axis length.
+    #[rustfmt::skip]
+    let even_rows: Vec<f32> = [
+        [ 1.,  2.,  3.,  4.],
+        [ 9., 10., 11., 12.],
+    ].concat();
+    test_slice(
+        &source, &source_shape, &starts, &ends, &vec![0], &steps, &even_rows, &[2, 4, 1],
+    );
+
+    // Axis 1 (columns): indices 0 and 2 are expected, as if `ends` were the
+    // axis length.
+    #[rustfmt::skip]
+    let even_columns: Vec<f32> = [
+        [ 1.,  3.],
+        [ 5.,  7.],
+        [ 9., 11.],
+        [13., 15.],
+    ].concat();
+    test_slice(
+        &source, &source_shape, &starts, &ends, &vec![1], &steps, &even_columns, &[4, 2, 1],
+    );
+}