diff --git a/README.md b/README.md index dcdfe6d3..0441ce66 100644 --- a/README.md +++ b/README.md @@ -346,7 +346,7 @@ fn test_matmul_square_matrix() { |Sin|7|✅|✅| |Sinh|9|✅|✅| |Size|13, 1|✅|✅| -|Slice|13, 11, 10, 1||✅| +|Slice|13, 11, 10, 1|✅|✅| |Softplus|1|✅| |Softsign|1|✅| |SpaceToDepth|13, 1| @@ -398,6 +398,8 @@ fn test_matmul_square_matrix() { * For `MatMul` and `Gemm`, the matrix dimensions must be divisible by 2, or the output matrix must be of size (1, N). Matrix multiplication only supports floats, not integers (this is a WebGPU/WGSL limitation). +* The Slice operator can only be computed for axes of length one. (i.e., there must always be exactly one axis.) + ### Shape inference WONNX needs to know the shape of input and output tensors for each operation in order to generate shader code for executing diff --git a/data/images/dog.jpg b/data/images/dog.jpg new file mode 100644 index 00000000..77b03812 Binary files /dev/null and b/data/images/dog.jpg differ diff --git a/data/models/coco-classes.txt b/data/models/coco-classes.txt new file mode 100644 index 00000000..16315f2b --- /dev/null +++ b/data/models/coco-classes.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/data/models/yolox_nano.onnx b/data/models/yolox_nano.onnx new file mode 100644 index 00000000..03adac0e Binary files /dev/null and b/data/models/yolox_nano.onnx differ diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml index b24f5a53..ed1f92dd 100644 --- a/wonnx/Cargo.toml +++ b/wonnx/Cargo.toml @@ -40,7 +40,8 @@ futures = "^0.3.26" parking_lot = { version = "0.11.1", features = ["wasm-bindgen"] } [dev-dependencies] -image = "0.24.2" +image = "0.25.1" +imageproc = "0.24.0" ndarray = "0.15.4" approx = "0.5.1" pollster = "0.3.0" diff --git a/wonnx/examples/yolox_nano.rs b/wonnx/examples/yolox_nano.rs new file mode 100644 index 00000000..f491126f --- /dev/null +++ b/wonnx/examples/yolox_nano.rs @@ -0,0 +1,306 @@ +use image::imageops; +use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb}; +use imageproc::drawing::draw_hollow_rect_mut; +use imageproc::rect::Rect; +use log::info; +use std::collections::HashMap; +use std::convert::TryInto; +use std::time::Instant; +use std::vec; +use std::{ + fs, + io::{BufRead, BufReader}, + path::Path, +}; +use wonnx::WonnxError; + +/*----------------------------------------------------------------------------- + Post processing +--------------------------------------------------------------------------------*/ +fn draw_rect(image: &mut ImageBuffer, Vec>, x1: f32, y1: f32, x2: f32, y2: f32) { + let x1 = x1 as u32; + let y1 = y1 as u32; + let x2 = x2 as u32; + let y2 = y2 as u32; + let rect = Rect::at(x1 as i32, y1 as i32).of_size(x2 - x1 as u32, (y2 - y1) as u32); + draw_hollow_rect_mut(image, rect, Rgb([255, 0, 0])); +} + +fn calc_loc(positions: &Vec<(f32, f32, f32, f32)>) -> Vec<(f32, f32, f32, f32)> { + let mut locs = vec![]; + + // calc girds + let (h, w) = (416, 416); + let strides = vec![8, 16, 32]; + let mut h_grids = vec![]; + let mut w_grids = vec![]; + + for stride in strides.iter() { + let mut h_grid = vec![0.0; h / stride]; + let mut w_grid = vec![0.0; w / stride]; + + for i in 0..h / stride { + h_grid[i] = i as f32; + } + for i in 0..w / stride { + w_grid[i] = i as f32; + } + h_grids.push(h_grid); + w_grids.push(w_grid); + } + let acc = vec![0, 52 * 52, 52 * 52 + 26 * 26, 52 * 52 + 26 * 26 + 13 * 13]; + + for (i, stride) in strides.iter().enumerate() { + let h_grid = &h_grids[i]; + let w_grid = &w_grids[i]; + let idx = acc[i]; + + for (i, y) in h_grid.iter().enumerate() { + for (j, x) in w_grid.iter().enumerate() { + let p = idx + i * w / stride + j; + let (px, py, pw, ph) = positions[p]; + let (x, y) = ((x + px) * *stride as f32, (y + py) * *stride as f32); + let (ww, hh) = (pw.exp() * *stride as f32, ph.exp() * *stride as f32); + let loc = (x - ww / 2.0, y - hh / 2.0, x + ww / 2.0, y + hh / 2.0); + locs.push(loc); + } + } + } + locs +} + +fn non_max_suppression( + boxes: &Vec<(f32, f32, f32, f32)>, + scores: &Vec, + score_threshold: f32, + iou_threshold: f32, +) -> Vec<(usize, (f32, f32, f32, f32))> { + let mut new_boxes = vec![]; + let mut sorted_indices = (0..boxes.len()).collect::>(); + sorted_indices.sort_by(|a, b| scores[*a].partial_cmp(&scores[*b]).unwrap()); + + while let Some(last) = sorted_indices.pop() { + let mut remove_list = vec![]; + let score = scores[last]; + let bbox = boxes[last]; + let mut numerator = ( + bbox.0 * score, + bbox.1 * score, + bbox.2 * score, + bbox.3 * score, + ); + let mut denominator = score; + + for i in 0..sorted_indices.len() { + let idx = sorted_indices[i]; + let (x1, y1, x2, y2) = boxes[idx]; + let (x1_, y1_, x2_, y2_) = boxes[last]; + let box1_area = (x2 - x1) * (y2 - y1); + + let inter_x1 = x1.max(x1_); + let inter_y1 = y1.max(y1_); + let inter_x2 = x2.min(x2_); + let inter_y2 = y2.min(y2_); + let inter_w = (inter_x2 - inter_x1).max(0.0); + let inter_h = (inter_y2 - inter_y1).max(0.0); + let inter_area = inter_w * inter_h; + let area1 = (x2 - x1) * (y2 - y1); + let area2 = (x2_ - x1_) * (y2_ - y1_); + let union_area = area1 + area2 - inter_area; + let iou = inter_area / union_area; + + if scores[idx] < score_threshold { + remove_list.push(i); + } else if iou > iou_threshold { + remove_list.push(i); + let w = scores[idx] * iou; + numerator = ( + numerator.0 + boxes[idx].0 * w, + numerator.1 + boxes[idx].1 * w, + numerator.2 + boxes[idx].2 * w, + numerator.3 + boxes[idx].3 * w, + ); + denominator += w; + } else if inter_area / box1_area > 0.7 { + remove_list.push(i); + } + } + for i in remove_list.iter().rev() { + sorted_indices.remove(*i); + } + let new_bbox = ( + numerator.0 / denominator, + numerator.1 / denominator, + numerator.2 / denominator, + numerator.3 / denominator, + ); + new_boxes.push((last, new_bbox)); + } + new_boxes +} + +fn post_process(preds: &[f32]) -> Vec<(String, f32, f32, f32, f32, f32)> { + let labels = get_coco_labels(); + let mut positions = vec![]; + let mut classes = vec![]; + let mut objectnesses = vec![]; + for i in 0..3549 { + let offset = i * 85; + let objectness = preds[offset + 4]; + + let (class, score) = preds[offset + 5..offset + 85] + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + let class = labels[class].clone(); + let x1 = preds[offset]; + let y1 = preds[offset + 1]; + let x2 = preds[offset + 2]; + let y2 = preds[offset + 3]; + classes.push((class, score)); + positions.push((x1, y1, x2, y2)); + objectnesses.push(objectness); + } + + let locs = calc_loc(&positions); + + let mut result = vec![]; + // filter by objectness + let indices = non_max_suppression(&locs, &objectnesses, 0.5, 0.3); + for bbox in indices { + let (i, (x, y, w, h)) = bbox; + let (class, &score) = &classes[i]; + result.push((class.clone(), score, x, y, w, h)); + } + result +} + +/*----------------------------------------------------------------------------- + Pre processing +--------------------------------------------------------------------------------*/ +fn padding_image(image: ImageBuffer, Vec>) -> ImageBuffer, Vec> { + let (width, height) = image.dimensions(); + let target_size = if width > height { width } else { height }; + let mut new_image = ImageBuffer::new(target_size as u32, target_size as u32); + let x_offset = (target_size as u32 - width) / 2; + let y_offset = (target_size as u32 - height) / 2; + for j in 0..height { + for i in 0..width { + let pixel = image.get_pixel(i, j); + new_image.put_pixel(i + x_offset, j + y_offset, *pixel); + } + } + new_image +} + +fn load_image() -> (Vec, ImageBuffer, Vec>) { + let args: Vec = std::env::args().collect(); + let image_path = if args.len() == 2 { + Path::new(&args[1]).to_path_buf() + } else { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../data/images") + .join("dog.jpg") + }; + + let image_buffer: ImageBuffer, Vec> = image::open(image_path).unwrap().to_rgb8(); + let image_buffer = padding_image(image_buffer); + let image_buffer = imageops::resize(&image_buffer, 416, 416, FilterType::Nearest); + + // convert image to Vec with channel first format + let mut image = vec![0.0; 3 * 416 * 416]; + for j in 0..416 { + for i in 0..416 { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + for c in 0..3 { + image[c * 416 * 416 + j * 416 + i] = channels[c] as f32; + } + } + } + return (image, image_buffer); +} + +fn get_coco_labels() -> Vec { + // Download the ImageNet class labels, matching SqueezeNet's classes. + let labels_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../data/models") + .join("coco-classes.txt"); + let file = BufReader::new(fs::File::open(labels_path).unwrap()); + + file.lines().map(|line| line.unwrap()).collect() +} + +/*----------------------------------------------------------------------------- + Main +--------------------------------------------------------------------------------*/ +// Hardware management +async fn execute_gpu() -> Result, WonnxError> { + let mut input_data = HashMap::new(); + let (image, _) = load_image(); + let images = image.as_slice().try_into().unwrap(); + input_data.insert("images".to_string(), images); + + let model_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../data/models") + .join("yolox_nano.onnx"); + let session = wonnx::Session::from_path(model_path).await?; + let time_pre_compute = Instant::now(); + + info!("Start Compute"); + let result = session.run(&input_data).await?; + let time_post_compute = Instant::now(); + println!( + "time: first_prediction: {:#?}", + time_post_compute - time_pre_compute + ); + + info!("Start Post Processing"); + let time_pre_compute = Instant::now(); + let output = result.get("output").unwrap(); + let output = output.try_into().unwrap(); + let positions = post_process(output); + let time_post_compute = Instant::now(); + println!( + "time: post_processing: {:#?}", + time_post_compute - time_pre_compute + ); + + Ok(positions) +} + +async fn run() { + // Output shape is [1, 3549, 85] + // 85 = 4 (bounding box) + 1 (objectness) + 80 (class probabilities) + let preds = execute_gpu().await.unwrap(); + + let (_, image_buffer) = load_image(); + let mut image_buffer = image_buffer; + for (class, score, x0, y0, x1, y1) in preds.iter() { + println!( + "class: {}, score: {}, x0: {}, y0: {}, x1: {}, y1: {}", + class, *score, *x0, *y0, *x1, *y1 + ); + draw_rect(&mut image_buffer, *x0, *y0, *x1, *y1); + } + image_buffer.save("yolox_predict.jpg").unwrap(); +} + +fn main() { + #[cfg(not(target_arch = "wasm32"))] + { + env_logger::init(); + let time_pre_compute = Instant::now(); + + pollster::block_on(run()); + let time_post_compute = Instant::now(); + println!("time: main: {:#?}", time_post_compute - time_pre_compute); + } + #[cfg(target_arch = "wasm32")] + { + // std::panic::set_hook(Box::new(console_error_panic_hook::hook)); + // console_log::init().expect("could not initialize logger"); + wasm_bindgen_futures::spawn_local(run()); + } +} diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 009f4a42..e95268bd 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -52,6 +52,11 @@ fn get_templates() -> &'static Tera { include_str!("../templates/endomorphism/cast.wgsl"), ) .unwrap(); + tera.add_raw_template( + "endomorphism/slice.wgsl", + include_str!("../templates/endomorphism/slice.wgsl"), + ) + .unwrap(); tera.add_raw_template( "matrix/concat.wgsl", include_str!("../templates/matrix/concat.wgsl"), @@ -470,6 +475,60 @@ pub fn compile( } } + "Slice" => { + // Assume that starts, endsm, axes, and steps are all defined. + if node.get_attribute().len() != 4 { + return Err(CompileError::UnimplementedOp( + "Slice (supports only 4 attributes: starts, ends, axes, and steps)".to_string(), + )); + } + + let starts = node.get_attribute_value::>("starts", Some(vec![0]))?; + if starts.len() != 1 { + return Err(CompileError::UnimplementedOp( + "Slice (supports only that the length of start is 1".to_string(), + )); + } + context.insert("starts", &starts); + + let ends = node.get_attribute_value::>("ends", Some(vec![2147483647]))?; + if ends.len() != 1 { + return Err(CompileError::UnimplementedOp( + "Slice (supports only that the length of end is 1".to_string(), + )); + } + context.insert("ends", &ends); + + let axes = node.get_attribute_value::>("axes", Some(vec![0]))?; + if axes.len() != 1 { + return Err(CompileError::UnimplementedOp( + "Slice (supports only that the length of axes is 1".to_string(), + )); + } + context.insert("axes", &axes); + + let steps = node.get_attribute_value::>("steps", Some(vec![1]))?; + if steps.len() != 1 { + return Err(CompileError::UnimplementedOp( + "Slice (supports only that the length of steps is 1".to_string(), + )); + } + context.insert("steps", &steps); + + let (x_threads, workgroup_size_x) = workgroup_size( + input_lengths[0], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + context.insert("workgroup_size_x", &workgroup_size_x); + + NodeTemplate { + scalar_type: agreed_type(&input_shapes[0..1], output_shapes)?, + template: "endomorphism/slice.wgsl", + threads: (x_threads, 1, 1), + } + } + "Softmax" => { let default_axis = match opset_version { 1..=10 => 1, // https://github.com/onnx/onnx/blob/master/docs/Changelog.md#softmax-1 @@ -804,6 +863,9 @@ pub fn compile( } context.insert("cum_len", &input_cumulative_len); + let axis = node.get_attribute_value("axis", Some(0))? as usize; + context.insert("axis", &axis); + let root = output_lengths[0].sqrt() + 1; let per_dim = ceil(root, 16) + 1; diff --git a/wonnx/src/optimizer.rs b/wonnx/src/optimizer.rs index 7f270762..1983cbe8 100644 --- a/wonnx/src/optimizer.rs +++ b/wonnx/src/optimizer.rs @@ -561,7 +561,7 @@ impl<'model> Optimizer<'model> { op @ ("Clip" | "Pad" | "Split" | "Resize" | "Reshape" | "ReduceMean" | "ReduceSum" | "ReduceMin" | "ReduceMax" | "ReduceSumSquare" | "ReduceLogSumExp" | "ReduceLogSum" | "ReduceL2" | "ReduceL1" - | "ReduceProd") => { + | "ReduceProd" | "Slice") => { if new_inputs.is_empty() { return Err(OptimizerError::NoInputs); } @@ -583,6 +583,7 @@ impl<'model> Optimizer<'model> { "ReduceMin" => REDUCE_OPS_INPUT_NAMES, "ReduceProd" => REDUCE_OPS_INPUT_NAMES, "ReduceSumSquare" => REDUCE_OPS_INPUT_NAMES, + "Slice" => SLICE_INPUT_NAMES, _ => unreachable!(), }; @@ -614,7 +615,8 @@ impl<'model> Optimizer<'model> { ) | ("Pad", "pads") | ("Resize", "scales") - | ("Clip", "min" | "max") => match data_type { + | ("Clip", "min" | "max") + | ("Slice", "starts" | "ends" | "axes" | "steps") => match data_type { ScalarType::F32 => { let value: Vec = if tensor_proto .get_float_data() @@ -640,13 +642,21 @@ impl<'model> Optimizer<'model> { let value = if tensor_proto .get_int64_data() .is_empty() - { + { pod_collect_to_vec(tensor_proto.get_raw_data()) } else { tensor_proto.get_int64_data().to_vec() }; + // If values is larger than i32::MAX, we need to convert it to i32::MAX + let value = value.iter().map(|x| { + if *x > i32::MAX as i64 { + i32::MAX as i64 + } else { + *x + } + }).collect::>(); log::info!( - "transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}", + "transferring input \"{}\" for op \"{}\" to i64 attribute (initializer data type: {:?}): {:?}", attr_name, op, data_type, @@ -702,7 +712,6 @@ impl<'model> Optimizer<'model> { Ok(Arc::new(new_node)) } - _ => Ok(Arc::new(Node { inputs: new_inputs, definition: NodeDefinition::Operator(op_def.clone()), @@ -819,6 +828,7 @@ static RESHAPE_INPUT_NAMES: &[&str] = &["data", "shape"]; static CLIP_INPUT_NAMES: &[&str] = &["input", "min", "max"]; static REDUCE_OPS_INPUT_NAMES: &[&str] = &["input", "axes"]; static PAD_INPUT_NAMES: &[&str] = &["data", "pads", "constant_value"]; +static SLICE_INPUT_NAMES: &[&str] = &["data", "starts", "ends", "axes", "steps"]; /// Generate the output for a ConstantOfShape node pub fn constant_of_shape_output( diff --git a/wonnx/templates/endomorphism/slice.wgsl b/wonnx/templates/endomorphism/slice.wgsl new file mode 100644 index 00000000..5dbc6f31 --- /dev/null +++ b/wonnx/templates/endomorphism/slice.wgsl @@ -0,0 +1,113 @@ +{%- include "structs.wgsl" -%} + +@group(0) @binding(0) +var input_0: Array; // data + +@group(0) @binding(1) +var output_0: Array; + +@compute @workgroup_size({{ workgroup_size_x }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let gidx = global_id.x; + let x0 = {{ i_shape[0][0] }}i; + let x1 = {{ i_shape[0][1] }}i; + + {%- if i_shape[0] | length == 3 -%} + let x2 = {{ i_shape[0][2] }}i; + let x3 = 1i; + {%- elif i_shape[0] | length == 4 -%} + let x2 = {{ i_shape[0][2] }}i; + let x3 = {{ i_shape[0][3] }}i; + {%- else -%} + let x2 = 1i; + let x3 = 1i; + {%- endif -%} + + let a = i32(gidx) / (x1 * x2 * x3); + let b = (i32(gidx) % (x1 * x2 * x3)) / (x2 * x3); + let c = (i32(gidx) % (x2 * x3)) / x3; + let d = i32(gidx) % x3; + + // Assume that starts, ends, axes and steps are only 1 element + let start = {{ starts[0] }}i; + var end = {{ ends[0] }}i; + let ax = {{ axes[0] }}i; + let step = {{ steps[0] }}i; + + // I'm not sure this if statement is moved to the compiler or not + {% if i_shape[0] | length == 4 %} + let end2 = {{ i_shape[0][2] }}i; + let end3 = {{ i_shape[0][3] }}i; + {% elif i_shape[0] | length == 3 %} + let end2 = {{ i_shape[0][2] }}i; + let end3 = 0i; + {% else %} + let end2 = 0i; // unreachable + let end3 = 0i; // unreachable + {% endif %} + + if end == 2147483647i { + if ax == 0 { + end = {{ i_shape[0][0] }}i; + } else if ax == 1 { + end = {{ i_shape[0][1] }}i; + } else if ax == 2 { + end = end2; + } else if ax == 3 { + end = end3; + } + } + + if ax == 0 { + if start <= a && a < end { + let reminder = (a - start) % step; + if reminder == 0 { + let j = (a - start) / step; + let idx = j * x1 * x2 * x3 + + b * x2 * x3 + + c * x3 + + d; + output_0.data[idx] = input_0.data[gidx]; + } + } + } else if ax == 1 { + if start <= b && b < end { + let reminder = (b - start) % step; + let ceil = (end - start + step - 1) / step; + if reminder == 0 { + let j = (b - start) / step; + let idx = a * ceil * x2 * x3 + + j * x2 * x3 + + c * x3 + + d; + output_0.data[idx] = input_0.data[gidx]; + } + } + } else if ax == 2 { + if start <= c && c < end { + let reminder = (c - start) % step; + let ceil = (end - start + step - 1) / step; + if reminder == 0 { + let j = (c - start) / step; + let idx = a * x1 * ceil * x3 + + b * ceil * x3 + + j * x3 + + d; + output_0.data[idx] = input_0.data[gidx]; + } + } + } else if ax == 3 { + if start <= d && d < end { + let reminder = (d - start) % step; + let ceil = (end - start + step - 1) / step; + if reminder == 0 { + let j = (d - start) / step; + let idx = a * x1 * x2 * ceil + + b * x2 * ceil + + c * ceil + + j; + output_0.data[idx] = input_0.data[gidx]; + } + } + } +} diff --git a/wonnx/templates/matrix/concat.wgsl b/wonnx/templates/matrix/concat.wgsl index dfba2016..e5f4a0a8 100644 --- a/wonnx/templates/matrix/concat.wgsl +++ b/wonnx/templates/matrix/concat.wgsl @@ -15,23 +15,77 @@ var output_0: Array; @compute @workgroup_size(16, 16, 1) fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) { let gidx = global_id.x; - let gidy = global_id.y; + let gidy = global_id.y; - let x_executions = num_workgroups.x * 16u; + let x_executions = num_workgroups.x * 16u; - let actual_idx = gidx + gidy * x_executions; + let actual_idx = gidx + gidy * x_executions; + let ax = {{ axis }}u; + {% for input in i_lens %} {% if loop.first %} if (actual_idx < {{ i_lens[0] }}u) { - output_0.data[actual_idx] = input_0.data[actual_idx]; + var global_input_idx = actual_idx; + var input_indices: array = array(); + {% set in_shape = i_shape[0] %} + var input_shape: array = array({{ in_shape | join(sep=', ')}}); + let output_shape_length = {{ o_shape[0] | length }}u; + + // calculate input indices + for (var i = 0u; i < output_shape_length; i = i + 1u) { + input_indices[output_shape_length - i - 1] = global_input_idx % input_shape[output_shape_length - i - 1]; + global_input_idx = global_input_idx / input_shape[output_shape_length - i - 1]; + } + + // calculate output index + {% set out_shape = o_shape[0] %} + var output_shape = array({{ out_shape | join(sep=', ')}}); + + var output_idx = 0u; + for (var i = 0u; i < output_shape_length; i = i + 1u) { + output_idx = output_idx * output_shape[i] + input_indices[i]; + } + output_0.data[output_idx] = input_0.data[actual_idx]; } {% else %} - if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) { - output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u]; + if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) { + var global_input_idx = actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u; + var input_indices: array = array(); + {% set shape = i_shape[loop.index0] %} + var input_shape: array = array({{ shape | join(sep=', ')}}); + let output_shape_length = {{ o_shape[0] | length }}u; + + // calculate input indices + var input_idx = global_input_idx; + for (var i = 0u; i < output_shape_length; i = i + 1u) { + input_indices[output_shape_length - i - 1] = input_idx % input_shape[output_shape_length - i - 1]; + input_idx = input_idx / input_shape[output_shape_length - i - 1]; + } + + + // calculate the offset + var offset = 0u; + {% for i in range(end=loop.index0) %} + offset = offset + {{ i_shape[i][axis] }}u; + {% endfor %} + + // add offset to input indices + input_indices[ax] = input_indices[ax] + offset; + + // calculate output index + {% set out_shape = o_shape[0] %} + var output_shape = array({{ out_shape | join(sep=', ')}}); + + var output_idx = 0u; + for (var i = 0u; i < output_shape_length; i = i + 1u) { + output_idx = output_idx * output_shape[i] + input_indices[i]; + } + + output_0.data[output_idx] = input_{{ loop.index0 }}.data[global_input_idx]; } - + {% endif %} {% endfor %} } diff --git a/wonnx/tests/concat.rs b/wonnx/tests/concat.rs index 994b087b..fc7dc32c 100644 --- a/wonnx/tests/concat.rs +++ b/wonnx/tests/concat.rs @@ -1,5 +1,8 @@ -use std::{collections::HashMap, convert::TryInto}; -use wonnx::utils::{graph, model, node, tensor}; +use std::{collections::HashMap, convert::TryInto, vec}; +use wonnx::{ + onnx::AttributeProto, + utils::{attribute, graph, model, node, tensor}, +}; mod common; #[test] @@ -117,3 +120,291 @@ fn test_concat4() { common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result); } + +#[test] +fn test_concat_axis1() { + let xdata = vec![ + vec![vec![1., 2.], vec![3., 4.]].concat(), + vec![vec![5., 6.], vec![7., 8.]].concat(), + ] + .concat(); + let x_input_dims = vec![2, 2, 2]; + let ydata = vec![ + vec![vec![9., 10.], vec![11., 12.]].concat(), + vec![vec![13., 14.], vec![15., 16.]].concat(), + ] + .concat(); + let y_input_dims = vec![2, 2, 2]; + let zdata = vec![ + vec![vec![17., 18.], vec![19., 20.]].concat(), + vec![vec![21., 22.], vec![23., 24.]].concat(), + ] + .concat(); + let z_input_dims = vec![2, 2, 2]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ("Z".into(), zdata.as_slice().into()), + ]); + + let attributes: Vec = vec![attribute("axis", 1)]; + + let model = model(graph( + vec![ + tensor("X", &x_input_dims), + tensor("Y", &y_input_dims), + tensor("Z", &z_input_dims), + ], + vec![tensor("W", &vec![2, 6, 2])], + vec![], + vec![], + vec![node( + vec!["X", "Y", "Z"], + vec!["W"], + "a", + "Concat", + attributes, + )], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + let expected_result = vec![ + vec![ + vec![1., 2.], + vec![3., 4.], + vec![9., 10.], + vec![11., 12.], + vec![17., 18.], + vec![19., 20.], + ] + .concat(), + vec![ + vec![5., 6.], + vec![7., 8.], + vec![13., 14.], + vec![15., 16.], + vec![21., 22.], + vec![23., 24.], + ] + .concat(), + ] + .concat(); + + common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result); +} + +#[test] +fn test_concat_axis2() { + let xdata = vec![ + vec![ + vec![1., 2., 3., 4.], + vec![5., 6., 7., 8.], + vec![9., 10., 11., 12.], + ] + .concat(), + vec![ + vec![13., 14., 15., 16.], + vec![17., 18., 19., 20.], + vec![21., 22., 23., 24.], + ] + .concat(), + ] + .concat(); + let x_input_dims = vec![1, 2, 3, 4]; + let ydata = vec![ + vec![ + vec![25., 26., 27., 28.], + vec![29., 30., 31., 32.], + vec![33., 34., 35., 36.], + ] + .concat(), + vec![ + vec![37., 38., 39., 40.], + vec![41., 42., 43., 44.], + vec![45., 46., 47., 48.], + ] + .concat(), + ] + .concat(); + let y_input_dims = vec![1, 2, 3, 4]; + let zdata = vec![ + vec![ + vec![49., 50., 51., 52.], + vec![53., 54., 55., 56.], + vec![57., 58., 59., 60.], + ] + .concat(), + vec![ + vec![61., 62., 63., 64.], + vec![65., 66., 67., 68.], + vec![69., 70., 71., 72.], + ] + .concat(), + ] + .concat(); + let z_input_dims = vec![1, 2, 3, 4]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ("Z".into(), zdata.as_slice().into()), + ]); + + let attributes: Vec = vec![attribute("axis", 2)]; + + let model = model(graph( + vec![ + tensor("X", &x_input_dims), + tensor("Y", &y_input_dims), + tensor("Z", &z_input_dims), + ], + vec![tensor("W", &vec![1, 2, 9, 4])], + vec![], + vec![], + vec![node( + vec!["X", "Y", "Z"], + vec!["W"], + "a", + "Concat", + attributes, + )], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + // 2x6x3x4 + let expected_result = vec![ + vec![ + vec![1., 2., 3., 4.], + vec![5., 6., 7., 8.], + vec![9., 10., 11., 12.], + vec![25., 26., 27., 28.], + vec![29., 30., 31., 32.], + vec![33., 34., 35., 36.], + vec![49., 50., 51., 52.], + vec![53., 54., 55., 56.], + vec![57., 58., 59., 60.], + ] + .concat(), + vec![ + vec![13., 14., 15., 16.], + vec![17., 18., 19., 20.], + vec![21., 22., 23., 24.], + vec![37., 38., 39., 40.], + vec![41., 42., 43., 44.], + vec![45., 46., 47., 48.], + vec![61., 62., 63., 64.], + vec![65., 66., 67., 68.], + vec![69., 70., 71., 72.], + ] + .concat(), + ] + .concat(); + common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result); +} + +#[test] +fn test_concat_axis3() { + let xdata = vec![ + vec![ + vec![1., 2., 3., 4.], + vec![5., 6., 7., 8.], + vec![9., 10., 11., 12.], + ] + .concat(), + vec![ + vec![13., 14., 15., 16.], + vec![17., 18., 19., 20.], + vec![21., 22., 23., 24.], + ] + .concat(), + ] + .concat(); + let x_input_dims = vec![1, 2, 3, 4]; + let ydata = vec![ + vec![ + vec![25., 26., 27., 28.], + vec![29., 30., 31., 32.], + vec![33., 34., 35., 36.], + ] + .concat(), + vec![ + vec![37., 38., 39., 40.], + vec![41., 42., 43., 44.], + vec![45., 46., 47., 48.], + ] + .concat(), + ] + .concat(); + let y_input_dims = vec![1, 2, 3, 4]; + let zdata = vec![ + vec![ + vec![49., 50., 51., 52.], + vec![53., 54., 55., 56.], + vec![57., 58., 59., 60.], + ] + .concat(), + vec![ + vec![61., 62., 63., 64.], + vec![65., 66., 67., 68.], + vec![69., 70., 71., 72.], + ] + .concat(), + ] + .concat(); + let z_input_dims = vec![1, 2, 3, 4]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ("Z".into(), zdata.as_slice().into()), + ]); + + let attributes: Vec = vec![attribute("axis", 3)]; + + let model = model(graph( + vec![ + tensor("X", &x_input_dims), + tensor("Y", &y_input_dims), + tensor("Z", &z_input_dims), + ], + vec![tensor("W", &vec![1, 2, 3, 12])], + vec![], + vec![], + vec![node( + vec!["X", "Y", "Z"], + vec!["W"], + "a", + "Concat", + attributes, + )], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + // 2x6x3x4 + let expected_result = vec![ + vec![ + vec![1., 2., 3., 4., 25., 26., 27., 28., 49., 50., 51., 52.], + vec![5., 6., 7., 8., 29., 30., 31., 32., 53., 54., 55., 56.], + vec![9., 10., 11., 12., 33., 34., 35., 36., 57., 58., 59., 60.], + ] + .concat(), + vec![ + vec![13., 14., 15., 16., 37., 38., 39., 40., 61., 62., 63., 64.], + vec![17., 18., 19., 20., 41., 42., 43., 44., 65., 66., 67., 68.], + vec![21., 22., 23., 24., 45., 46., 47., 48., 69., 70., 71., 72.], + ] + .concat(), + ] + .concat(); + common::assert_eq_vector((&result["W"]).try_into().unwrap(), &expected_result); +} diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs new file mode 100644 index 00000000..543ba5f2 --- /dev/null +++ b/wonnx/tests/slice.rs @@ -0,0 +1,581 @@ +use std::{collections::HashMap, convert::TryInto, vec}; +use wonnx::{ + onnx::AttributeProto, + utils::{ attribute, graph, model, node, tensor, InputTensor}, +}; +mod common; + +fn test_slice( + input: &[f32], + input_shape: &[i64], + starts: &Vec, + ends: &Vec, + axes: &Vec, + steps: &Vec, + output: &[f32], + output_shape: &[i64], +) { + let mut input_data = HashMap::::new(); + input_data.insert("X".to_string(), input.into()); + + let attributes: Vec = vec![ + attribute("starts", starts.clone()), + attribute("ends", ends.clone()), + attribute("axes", axes.clone()), + attribute("steps", steps.clone()), + ]; + + let model = model(graph( + vec![tensor("X", input_shape)], + vec![tensor("Y", output_shape)], + vec![], + vec![], + vec![node( + vec!["X"], + vec!["Y"], + "slice", + "Slice", + attributes, + )], + )); + + let session = match pollster::block_on(wonnx::Session::from_model(model)) { + Ok(session) => session, + Err(e) => { + panic!("Failed to create session: {:?}", e); + } + }; + let result = pollster::block_on(session.run(&input_data)).unwrap(); + log::info!("OUT: {:?}", result["Y"]); + common::assert_eq_vector((&result["Y"]).try_into().unwrap(), output); +} + +#[test] +fn slice_step1() { + let _ = env_logger::builder().is_test(true).try_init(); + + // This test is the most simple case + // Note that each interval is half-open, i.e. it includes the start index but excludes the end index. + // axes == 0 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![0], + &vec![1], + &vec![0], + &vec![1], + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + ); + + // axes == 1 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![0], + &vec![2], + &vec![1], + &vec![1], + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + ); + + // axes == 2 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![0], + &vec![4], + &vec![2], + &vec![1], + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + ); + + // axes == 1 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![0], + &vec![1], + &vec![1], + &vec![1], + &[[1., 2., 3., 4.]].concat(), + &[1, 1, 4], + ); + + // axes == 1 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![1], + &vec![2], + &vec![1], + &vec![1], + &[[5., 6., 7., 8.]].concat(), + &[1, 1, 4], + ); + + // axes == 2 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![0], + &vec![2], + &vec![2], + &vec![1], + &[[1., 2.], [5., 6.]].concat(), + &[1, 2, 2], + ); + + // axes == 2 + test_slice( + &[[1., 2., 3., 4.], [5., 6., 7., 8.]].concat(), + &[1, 2, 4], + &vec![2], + &vec![4], + &vec![2], + &vec![1], + &[[3., 4.], [7., 8.]].concat(), + &[1, 2, 2], + ); +} + +#[test] +fn slice_step2() { + // axes == 0 + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![0], + &vec![1], + &vec![0], + &vec![2], + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![0], + &vec![4], + &vec![1], + &vec![2], + &[ + [ 1., 2., 3., 4.], + [ 9., 10., 11., 12.], + ].concat(), + &[1, 2, 4, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![1], + &vec![4], + &vec![1], + &vec![2], + &[ + [ 5., 6., 7., 8.], + [13., 14., 15., 16.] + ].concat(), + &[1, 2, 4, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![0], + &vec![2], + &vec![1], + &vec![2], + &[ + [ 1., 2., 3., 4.], + ].concat(), + &[1, 1, 4, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![3], + &vec![4], + &vec![1], + &vec![2], + &[ + [13., 14., 15., 16.] + ].concat(), + &[1, 1, 4, 1], + ); + + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![0], + &vec![4], + &vec![2], + &vec![2], + &[ + [ 1., 3.], + [ 5., 7.], + [ 9., 11.], + [13., 15.] + ].concat(), + &[1, 4, 2, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![0], + &vec![2], + &vec![2], + &vec![2], + &[ + [ 1.], + [ 5.], + [ 9.], + [13.] + ].concat(), + &[1, 4, 1, 1], + ); + + + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[1, 4, 4, 1], + &vec![2], + &vec![4], + &vec![2], + &vec![2], + &[ + [ 3.], + [ 7.], + [11.], + [15.] + ].concat(), + &[1, 4, 1, 1], + ); +} + + +#[test] +fn slice_1x3x416x416_step2() { + let input0 = vec![1.0; 1 * 3 * 416 * 416]; + let output0 = vec![1.0; 1 * 3 * 208 * 416]; + #[rustfmt::skip] + test_slice( + &input0, + &[1, 3, 416, 416], + &vec![0], + &vec![i32::MAX as i64], + &vec![2], + &vec![2], + &output0, + &[1, 3, 208, 416], + ); +} + +#[test] +fn slice_1x3x416x416_step2_axes2_start0() { + let mut input2 = vec![]; + for _ in 0..3 { + let mut row = vec![]; + for j in 0..416 { + let mut col = vec![]; + for _ in 0..416 { + if j % 2 == 0 { + col.push(1.0); + } else { + col.push(0.0); + + } + } + row.push(col); + } + input2.push(row); + } + // flatten the input + let input2 = input2.iter().flatten().flatten().copied().collect::>(); + + + let output2 = vec![1.0; 1 * 3 * 208 * 416]; + #[rustfmt::skip] + test_slice( + &input2, + &[1, 3, 416, 416], + &vec![0], + &vec![i32::MAX as i64], + &vec![2], + &vec![2], + &output2, + &[1, 3, 208, 416], + ); +} + + +#[test] +fn slice_1x3x416x416_step2_axes2_start1() { + let mut input2 = vec![]; + for _ in 0..3 { + let mut row = vec![]; + for j in 0..416 { + let mut col = vec![]; + for _ in 0..416 { + if j % 2 == 0 { + col.push(0.0); + } else { + col.push(1.0); + + } + } + row.push(col); + } + input2.push(row); + } + // flatten the input + let input2 = input2.iter().flatten().flatten().copied().collect::>(); + + + let output2 = vec![1.0; 1 * 3 * 208 * 416]; + #[rustfmt::skip] + test_slice( + &input2, + &[1, 3, 416, 416], + &vec![1], + &vec![i32::MAX as i64], + &vec![2], + &vec![2], + &output2, + &[1, 3, 208, 416], + ); +} + +#[test] +fn slice_1x3x8x8_step2_axes2_start1() { + let input2 = vec![ + vec![ + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + ], + vec![ + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + ], + vec![ + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + vec![1., 1., 1., 1., 1., 1., 1., 1.], + vec![0., 0., 0., 0., 0., 0., 0., 0.], + ] + ]; + + // flatten the input + let input2 = input2.iter().flatten().flatten().copied().collect::>(); + + let output2 = vec![0.0; 1 * 3 * 4 * 8]; + #[rustfmt::skip] + test_slice( + &input2, + &[1, 3, 8, 8], + &vec![1], + &vec![i32::MAX as i64], + &vec![2], + &vec![2], + &output2, + &[1, 3, 4, 8], + ); + + +} + +#[test] +fn slice_1x3x8x8_step2_axes3_start1() { + let input2 = vec![ + vec![ + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + ], + vec![ + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + ], + vec![ + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + vec![0., 1., 0., 1., 0., 1., 0., 1.], + ] + ]; + // flatten the input + let input2 = input2.iter().flatten().flatten().copied().collect::>(); + + let output2 = vec![1.0; 1 * 3 * 8 * 4]; + #[rustfmt::skip] + test_slice( + &input2, + &[1, 3, 8, 8], + &vec![1], + &vec![i32::MAX as i64], + &vec![3], + &vec![2], + &output2, + &[1, 3, 8, 4], + ); +} + +#[test] +fn slice_none_axes_and_steps() { + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[4, 4, 1], + &vec![0], + &vec![2], + &vec![0], + &vec![1], + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + ].concat(), + &[2, 4, 1], + ); +} + +#[test] +fn slice_ends_max_i32() { + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[4, 4, 1], + &vec![0], + &vec![i32::MAX as i64], + &vec![0], + &vec![2], + &[ + [ 1., 2., 3., 4.], + [ 9., 10., 11., 12.], + ].concat(), + &[2, 4, 1], + ); + + #[rustfmt::skip] + test_slice( + &[ + [ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.], + [13., 14., 15., 16.] + ].concat(), + &[4, 4, 1], + &vec![0], + &vec![i32::MAX as i64], + &vec![1], + &vec![2], + &[ + [ 1., 3.], + [ 5., 7.], + [ 9., 11.], + [13., 15.] + ].concat(), + &[4, 2, 1], + ); +}