Skip to content

Commit

Permalink
add nearest neighbour
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Jan 28, 2024
1 parent 8871d57 commit 15d8dbd
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 20 deletions.
7 changes: 5 additions & 2 deletions examples/imgproc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let gray_resize = kornia_rs::resize::resize(
gray.clone(),
kornia_rs::image::ImageSize {
width: 224,
height: 224,
width: 1024,
height: 768,
},
kornia_rs::resize::ResizeOptions {
interpolation: kornia_rs::resize::InterpolationMode::Bilinear,
},
);

Expand Down
106 changes: 88 additions & 18 deletions src/resize.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
use crate::image::{Image, ImageSize};
use ndarray::{Array, Array2, Array3, Ix2, Zip};
use ndarray::{stack, Array, Array2, Array3, Ix2, Zip};

fn meshgrid(x: &Array<f32, Ix2>, y: &Array<f32, Ix2>) -> (Array2<f32>, Array2<f32>) {
let nx = x.len_of(ndarray::Axis(1));
let ny = y.len_of(ndarray::Axis(1));
println!("nx: {:?}", nx);
println!("ny: {:?}", ny);
//println!("nx: {:?}", nx);
//println!("ny: {:?}", ny);

println!("x: {:?}", x.shape());
//println!("x: {:?}", x.shape());
let xx = x.broadcast((ny, nx)).unwrap().to_owned();
println!("xx: {:?}", xx);
//println!("xx: {:?}", xx);

println!("y: {:?}", y.shape());
//println!("y: {:?}", y.shape());
let yy = y.broadcast((nx, ny)).unwrap().t().to_owned();
println!("yy: {:?}", yy);
//println!("yy: {:?}", yy);

(xx, yy)
}
Expand All @@ -27,7 +27,7 @@ fn bilinear_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 {
let iv = v.trunc() as usize;
let frac_u = u.fract();
let frac_v = v.fract();
let val00 = image.data[[iv, iu, 0]] as f32;
let val00 = image.data[[iv, iu, c]] as f32;
let val01 = if iu + 1 < width {
image.data[[iv, iu + 1, c]] as f32
} else {
Expand All @@ -50,7 +50,41 @@ fn bilinear_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 {
+ val11 * frac_u * frac_v
}

pub fn resize(image: Image, new_size: ImageSize) -> Image {
fn nearest_neighbor_interpolation(image: Image, u: f32, v: f32, c: usize) -> f32 {
let image_size = image.image_size();
let height = image_size.height;
let width = image_size.width;

let iu = u.round() as usize;
let iv = v.round() as usize;

let iu = iu.clamp(0, width - 1);
let iv = iv.clamp(0, height - 1);

let val = image.data[[iv, iu, c]] as f32;

val

Check warning on line 66 in src/resize.rs

View workflow job for this annotation

GitHub Actions / Clippy Output

returning the result of a `let` binding from a block

warning: returning the result of a `let` binding from a block --> src/resize.rs:66:5 | 64 | let val = image.data[[iv, iu, c]] as f32; | ----------------------------------------- unnecessary `let` binding 65 | 66 | val | ^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#let_and_return = note: `#[warn(clippy::let_and_return)]` on by default help: return the expression directly | 64 ~ 65 | 66 ~ image.data[[iv, iu, c]] as f32 |
}

#[derive(Debug, Clone, Copy)]
pub enum InterpolationMode {
Bilinear,
NearestNeighbor,
}

pub struct ResizeOptions {
pub interpolation: InterpolationMode,
}

impl Default for ResizeOptions {
fn default() -> Self {
ResizeOptions {
interpolation: InterpolationMode::Bilinear,
}
}
}

pub fn resize(image: Image, new_size: ImageSize, optional_args: ResizeOptions) -> Image {
let image_size = image.image_size();

// create the output image
Expand All @@ -68,20 +102,55 @@ pub fn resize(image: Image, new_size: ImageSize) -> Image {
//println!("yy: {:?}", yy);

// TODO: parallelize this
for i in 0..xx.shape()[0] {
for j in 0..xx.shape()[1] {
let x = xx[[i, j]];
let y = yy[[i, j]];
//for i in 0..xx.shape()[0] {
// for j in 0..xx.shape()[1] {
// let x = xx[[i, j]];
// let y = yy[[i, j]];
// //println!("x: {:?}", x);
// //println!("y: {:?}", y);
// //println!("###########3");

// for k in 0..3 {
// //output[[i, j, k]] = image_data[[y as usize, x as usize, k]];
// output[[i, j, k]] = bilinear_interpolation(image.clone(), x, y, k) as u8;
// }
// }
//}

// TODO: benchmark this
let xy = stack![ndarray::Axis(2), xx, yy];

Zip::from(xy.rows())
.and(output.rows_mut())
.par_for_each(|xy, mut out| {
assert_eq!(xy.len(), 2);
let x = xy[0];
let y = xy[1];
//println!("x: {:?}", x);
//println!("y: {:?}", y);
//println!("###########3");

for k in 0..3 {
//println!("out: {:?}", out.shape());
for k in [0, 1, 2].iter() {
//output[[i, j, k]] = image_data[[y as usize, x as usize, k]];
output[[i, j, k]] = bilinear_interpolation(image.clone(), x, y, k) as u8;
//out[*k] = bilinear_interpolation(image.clone(), x, y, *k) as u8;
//out[*k] = nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8;
//out[*k] = match interpolation {
// Interpolation::Bilinear => bilinear_interpolation(image.clone(), x, y, *k) as u8,
// Interpolation::NearestNeighbor => nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8,
//};
match optional_args.interpolation {
InterpolationMode::Bilinear => {
out[*k] = bilinear_interpolation(image.clone(), x, y, *k) as u8
}
InterpolationMode::NearestNeighbor => {
out[*k] = nearest_neighbor_interpolation(image.clone(), x, y, *k) as u8
}
}
}
}
}
//for k in 0..3 {
// out[k] = bilinear_interpolation(image.clone(), x, y, k) as u8;
//}
});

Image { data: output }
}
Expand All @@ -99,6 +168,7 @@ mod tests {
width: 2,
height: 3,
},
super::ResizeOptions::default(),
);
assert_eq!(image_resized.num_channels(), 3);
assert_eq!(image_resized.image_size().width, 2);
Expand Down

0 comments on commit 15d8dbd

Please sign in to comment.