diff --git a/examples/fill_in_reduction.rs b/examples/fill_in_reduction.rs index f8c3932d..4b8001f8 100644 --- a/examples/fill_in_reduction.rs +++ b/examples/fill_in_reduction.rs @@ -61,7 +61,7 @@ fn save_gray_image( ) -> ImageResult<()> { let height = image.shape()[0]; let width = image.shape()[1]; - let im: Option, _>> = image.to_slice().map(|slice| { + let im: Option, _>> = image.as_slice().map(|slice| { ImageBuffer::from_raw(width as u32, height as u32, slice) .expect("failed to create image from slice") }); diff --git a/src/sparse/csmat.rs b/src/sparse/csmat.rs index dc650738..3c3b13db 100644 --- a/src/sparse/csmat.rs +++ b/src/sparse/csmat.rs @@ -1951,6 +1951,58 @@ where } } +impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot> + for ArrayBase +where + N: 'a + Copy + Num + Default + std::fmt::Debug, + I: 'a + SpIndex, + IpS: 'a + Deref, + IS: 'a + Deref, + DS: 'a + Deref, + DS2: 'b + ndarray::Data, +{ + type Output = Array; + + fn dot(&self, rhs: &CsMatBase) -> Array { + let rhs_t = rhs.transpose_view(); + let lhs_t = self.t(); + + let rows = rhs_t.rows(); + #[allow(deprecated)] + let cols = lhs_t.cols(); + // when the number of colums is small, it is more efficient + // to perform the product by iterating over the columns of + // the rhs, otherwise iterating by rows can take advantage of + // vectorized axpy. + let rres = match (rhs_t.storage(), cols >= 8) { + (CSR, true) => { + let mut res = Array::zeros((rows, cols)); + prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSR, false) => { + let mut res = Array::zeros((rows, cols).f()); + prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSC, true) => { + let mut res = Array::zeros((rows, cols)); + prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + (CSC, false) => { + let mut res = Array::zeros((rows, cols).f()); + prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut()); + res.reversed_axes() + } + }; + + assert_eq!(self.shape()[0], rres.shape()[0]); + assert_eq!(rhs.cols(), rres.shape()[1]); + rres + } +} + impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot> for CsMatBase where diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 82bcc51f..57bd1c08 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -258,12 +258,15 @@ where /// CSR-dense rowmaj multiplication /// /// Performs better if rhs has a decent number of colums. -pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csr_mulacc_dense_rowmaj<'a, N1, N2, I, Iptr, NOut>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, Iptr: 'a + SpIndex, { @@ -297,12 +300,15 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>( /// CSC-dense rowmaj multiplication /// /// Performs better if rhs has a decent number of colums. -pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csc_mulacc_dense_rowmaj<'a, N1, N2, I, Iptr, NOut>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, Iptr: 'a + SpIndex, { @@ -333,12 +339,15 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>( /// CSC-dense colmaj multiplication /// /// Performs better if rhs has few columns. -pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csc_mulacc_dense_colmaj<'a, N1, N2, I, Iptr, NOut>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, Iptr: 'a + SpIndex, { @@ -370,12 +379,15 @@ pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>( /// CSR-dense colmaj multiplication /// /// Performs better if rhs has few columns. -pub fn csr_mulacc_dense_colmaj<'a, N, I, Iptr>( - lhs: CsMatViewI, - rhs: ArrayView, - mut out: ArrayViewMut<'a, N, Ix2>, +pub fn csr_mulacc_dense_colmaj<'a, N1, N2, I, Iptr, NOut>( + lhs: CsMatViewI, + rhs: ArrayView, + mut out: ArrayViewMut<'a, NOut, Ix2>, ) where - N: 'a + Num + Copy, + N1: 'a + Num + Copy, + N2: 'a + Num + Copy, + NOut: 'a + Num + Copy, + N1: std::ops::Mul, I: 'a + SpIndex, Iptr: 'a + SpIndex, { @@ -407,7 +419,8 @@ pub fn csr_mulacc_dense_colmaj<'a, N, I, Iptr>( #[cfg(test)] mod test { use super::*; - use ndarray::{arr2, Array, ShapeBuilder}; + use ndarray::linalg::Dot; + use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder}; use sparse::csmat::CompressedStorage::{CSC, CSR}; use sparse::{CsMat, CsMatView, CsVec}; use test_data::{ @@ -571,7 +584,7 @@ mod test { #[test] fn mul_csr_dense_rowmaj() { - let a = Array::eye(3); + let a: Array2 = Array::eye(3); let e: CsMat = CsMat::eye(3); let mut res = Array::zeros((3, 3)); super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut()); @@ -663,4 +676,93 @@ mod test { let c = &a * &b; assert_eq!(c, expected_output); } + + // stolen from ndarray - not currently exported. + fn assert_close(a: ArrayView, b: ArrayView) + where + D: Dimension, + { + let diff = (&a - &b).mapv_into(f64::abs); + + let rtol = 1e-7; + let atol = 1e-12; + let crtol = b.mapv(|x| x.abs() * rtol); + let tol = crtol + atol; + let tol_m_diff = &diff - &tol; + let maxdiff = tol_m_diff.fold(0. / 0., |x, y| f64::max(x, *y)); + println!("diff offset from tolerance level= {:.2e}", maxdiff); + if maxdiff > 0. { + println!("{:.4?}", a); + println!("{:.4?}", b); + panic!("results differ"); + } + } + + #[test] + fn test_sparse_dot_dense() { + let sparse = [ + mat1(), + mat1_csc(), + mat2(), + mat2().transpose_into(), + mat4(), + mat5(), + ]; + let dense = [ + mat_dense1(), + mat_dense1_colmaj(), + mat_dense1().reversed_axes(), + mat_dense2(), + mat_dense2().reversed_axes(), + ]; + + // test sparse.dot(dense) + for s in sparse.iter() { + for d in dense.iter() { + if d.shape()[0] < s.cols() { + continue; + } + + let d = d.slice(s![0..s.cols(), ..]); + + let truth = s.to_dense().dot(&d); + let test = s.dot(&d); + assert_close(test.view(), truth.view()); + } + } + } + + #[test] + fn test_dense_dot_sparse() { + let sparse = [ + mat1(), + mat1_csc(), + mat2(), + mat2().transpose_into(), + mat4(), + mat5(), + ]; + let dense = [ + mat_dense1(), + mat_dense1_colmaj(), + mat_dense1().reversed_axes(), + mat_dense2(), + mat_dense2().reversed_axes(), + ]; + + // test sparse.ldot(dense) + for s in sparse.iter() { + for d in dense.iter() { + if d.shape()[1] < s.rows() { + continue; + } + + let d = d.slice(s![.., 0..s.rows()]); + + let truth = d.dot(&s.to_dense()); + let test = d.dot(s); + assert_close(test.view(), truth.view()); + } + } + } }