diff --git a/examples/imgproc.rs b/examples/imgproc.rs index 3fb493dc..58f7d150 100644 --- a/examples/imgproc.rs +++ b/examples/imgproc.rs @@ -12,8 +12,11 @@ fn main() -> Result<(), Box> { let gray_resize = kornia_rs::resize::resize( gray.clone(), kornia_rs::image::ImageSize { - width: 224, - height: 224, + width: 1024, + height: 768, + }, + kornia_rs::resize::ResizeOptions { + interpolation: kornia_rs::resize::InterpolationMode::Bilinear, }, ); diff --git a/src/resize.rs b/src/resize.rs index 855d8f3c..da7e2e93 100644 --- a/src/resize.rs +++ b/src/resize.rs @@ -1,19 +1,19 @@ use crate::image::{Image, ImageSize}; -use ndarray::{Array, Array2, Array3, Ix2, Zip}; +use ndarray::{stack, Array, Array2, Array3, Ix2, Zip}; fn meshgrid(x: &Array, y: &Array) -> (Array2, Array2) { let nx = x.len_of(ndarray::Axis(1)); let ny = y.len_of(ndarray::Axis(1)); - println!("nx: {:?}", nx); - println!("ny: {:?}", ny); + //println!("nx: {:?}", nx); + //println!("ny: {:?}", ny); - println!("x: {:?}", x.shape()); + //println!("x: {:?}", x.shape()); let xx = x.broadcast((ny, nx)).unwrap().to_owned(); - println!("xx: {:?}", xx); + //println!("xx: {:?}", xx); - println!("y: {:?}", y.shape()); + //println!("y: {:?}", y.shape()); let yy = y.broadcast((nx, ny)).unwrap().t().to_owned(); - println!("yy: {:?}", yy); + //println!("yy: {:?}", yy); (xx, yy) } @@ -27,7 +27,7 @@ fn bilinear_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 { let iv = v.trunc() as usize; let frac_u = u.fract(); let frac_v = v.fract(); - let val00 = image.data[[iv, iu, 0]] as f32; + let val00 = image.data[[iv, iu, c]] as f32; let val01 = if iu + 1 < width { image.data[[iv, iu + 1, c]] as f32 } else { @@ -50,7 +50,41 @@ fn bilinear_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 { + val11 * frac_u * frac_v } -pub fn resize(image: Image, new_size: ImageSize) -> Image { +fn nearest_neighbor_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 { + let image_size = image.image_size(); + let height = image_size.height; + let width = image_size.width; + + let iu = u.round() as usize; + let iv = v.round() as usize; + + let iu = iu.clamp(0, width - 1); + let iv = iv.clamp(0, height - 1); + + let val = image.data[[iv, iu, c]] as f32; + + val +} + +#[derive(Debug, Clone, Copy)] +pub enum InterpolationMode { + Bilinear, + NearestNeighbor, +} + +pub struct ResizeOptions { + pub interpolation: InterpolationMode, +} + +impl Default for ResizeOptions { + fn default() -> Self { + ResizeOptions { + interpolation: InterpolationMode::Bilinear, + } + } +} + +pub fn resize(image: Image, new_size: ImageSize, optional_args: ResizeOptions) -> Image { let image_size = image.image_size(); // create the output image @@ -68,20 +102,55 @@ pub fn resize(image: Image, new_size: ImageSize) -> Image { //println!("yy: {:?}", yy); // TODO: parallelize this - for i in 0..xx.shape()[0] { - for j in 0..xx.shape()[1] { - let x = xx[[i, j]]; - let y = yy[[i, j]]; + //for i in 0..xx.shape()[0] { + // for j in 0..xx.shape()[1] { + // let x = xx[[i, j]]; + // let y = yy[[i, j]]; + // //println!("x: {:?}", x); + // //println!("y: {:?}", y); + // //println!("###########3"); + + // for k in 0..3 { + // //output[[i, j, k]] = image_data[[y as usize, x as usize, k]]; + // output[[i, j, k]] = bilinear_interpolation(image.clone(), x, y, k) as u8; + // } + // } + //} + + // TODO: benchmark this + let xy = stack![ndarray::Axis(2), xx, yy]; + + Zip::from(xy.rows()) + .and(output.rows_mut()) + .par_for_each(|xy, mut out| { + assert_eq!(xy.len(), 2); + let x = xy[0]; + let y = xy[1]; //println!("x: {:?}", x); //println!("y: {:?}", y); //println!("###########3"); - - for k in 0..3 { + //println!("out: {:?}", out.shape()); + for k in [0, 1, 2].iter() { //output[[i, j, k]] = image_data[[y as usize, x as usize, k]]; - output[[i, j, k]] = bilinear_interpolation(image.clone(), x, y, k) as u8; + //out[*k] = bilinear_interpolation(image.clone(), x, y, *k) as u8; + //out[*k] = nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8; + //out[*k] = match interpolation { + // Interpolation::Bilinear => bilinear_interpolation(image.clone(), x, y, *k) as u8, + // Interpolation::NearestNeighbor => nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8, + //}; + match optional_args.interpolation { + InterpolationMode::Bilinear => { + out[*k] = bilinear_interpolation(image.clone(), x, y, *k) as u8 + } + InterpolationMode::NearestNeighbor => { + out[*k] = nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8 + } + } } - } - } + //for k in 0..3 { + // out[k] = bilinear_interpolation(image.clone(), x, y, k) as u8; + //} + }); Image { data: output } } @@ -99,6 +168,7 @@ mod tests { width: 2, height: 3, }, + super::ResizeOptions::default(), ); assert_eq!(image_resized.num_channels(), 3); assert_eq!(image_resized.image_size().width, 2);