Skip to content

dense.dot(sparse) impl #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/fill_in_reduction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn save_gray_image(
) -> ImageResult<()> {
let height = image.shape()[0];
let width = image.shape()[1];
let im: Option<ImageBuffer<Luma<u8>, _>> = image.to_slice().map(|slice| {
let im: Option<ImageBuffer<Luma<u8>, _>> = image.as_slice().map(|slice| {
ImageBuffer::from_raw(width as u32, height as u32, slice)
.expect("failed to create image from slice")
});
Expand Down
52 changes: 52 additions & 0 deletions src/sparse/csmat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,58 @@ where
}
}

impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
for ArrayBase<DS2, Ix2>
where
N: 'a + Copy + Num + Default + std::fmt::Debug,
I: 'a + SpIndex,
IpS: 'a + Deref<Target = [I]>,
IS: 'a + Deref<Target = [I]>,
DS: 'a + Deref<Target = [N]>,
DS2: 'b + ndarray::Data<Elem = N>,
{
type Output = Array<N, Ix2>;

fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
let rhs_t = rhs.transpose_view();
let lhs_t = self.t();

let rows = rhs_t.rows();
#[allow(deprecated)]
let cols = lhs_t.cols();
// when the number of colums is small, it is more efficient
// to perform the product by iterating over the columns of
// the rhs, otherwise iterating by rows can take advantage of
// vectorized axpy.
let rres = match (rhs_t.storage(), cols >= 8) {
(CSR, true) => {
let mut res = Array::zeros((rows, cols));
prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSR, false) => {
let mut res = Array::zeros((rows, cols).f());
prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSC, true) => {
let mut res = Array::zeros((rows, cols));
prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSC, false) => {
let mut res = Array::zeros((rows, cols).f());
prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
};

assert_eq!(self.shape()[0], rres.shape()[0]);
assert_eq!(rhs.cols(), rres.shape()[1]);
rres
}
}

impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
Expand Down
146 changes: 124 additions & 22 deletions src/sparse/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,15 @@ where
/// CSR-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of colums.
pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_rowmaj<'a, N1, N2, I, Iptr, NOut>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -297,12 +300,15 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>(
/// CSC-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of colums.
pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_rowmaj<'a, N1, N2, I, Iptr, NOut>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -333,12 +339,15 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>(
/// CSC-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_colmaj<'a, N1, N2, I, Iptr, NOut>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -370,12 +379,15 @@ pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>(
/// CSR-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csr_mulacc_dense_colmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_colmaj<'a, N1, N2, I, Iptr, NOut>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -407,7 +419,8 @@ pub fn csr_mulacc_dense_colmaj<'a, N, I, Iptr>(
#[cfg(test)]
mod test {
use super::*;
use ndarray::{arr2, Array, ShapeBuilder};
use ndarray::linalg::Dot;
use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder};
use sparse::csmat::CompressedStorage::{CSC, CSR};
use sparse::{CsMat, CsMatView, CsVec};
use test_data::{
Expand Down Expand Up @@ -571,7 +584,7 @@ mod test {

#[test]
fn mul_csr_dense_rowmaj() {
let a = Array::eye(3);
let a: Array2<f64> = Array::eye(3);
let e: CsMat<f64> = CsMat::eye(3);
let mut res = Array::zeros((3, 3));
super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut());
Expand Down Expand Up @@ -663,4 +676,93 @@ mod test {
let c = &a * &b;
assert_eq!(c, expected_output);
}

// stolen from ndarray - not currently exported.
fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
where
D: Dimension,
{
let diff = (&a - &b).mapv_into(f64::abs);

let rtol = 1e-7;
let atol = 1e-12;
let crtol = b.mapv(|x| x.abs() * rtol);
let tol = crtol + atol;
let tol_m_diff = &diff - &tol;
let maxdiff = tol_m_diff.fold(0. / 0., |x, y| f64::max(x, *y));
println!("diff offset from tolerance level= {:.2e}", maxdiff);
if maxdiff > 0. {
println!("{:.4?}", a);
println!("{:.4?}", b);
panic!("results differ");
}
}

#[test]
fn test_sparse_dot_dense() {
let sparse = [
mat1(),
mat1_csc(),
mat2(),
mat2().transpose_into(),
mat4(),
mat5(),
];
let dense = [
mat_dense1(),
mat_dense1_colmaj(),
mat_dense1().reversed_axes(),
mat_dense2(),
mat_dense2().reversed_axes(),
];

// test sparse.dot(dense)
for s in sparse.iter() {
for d in dense.iter() {
if d.shape()[0] < s.cols() {
continue;
}

let d = d.slice(s![0..s.cols(), ..]);

let truth = s.to_dense().dot(&d);
let test = s.dot(&d);
assert_close(test.view(), truth.view());
}
}
}

#[test]
fn test_dense_dot_sparse() {
let sparse = [
mat1(),
mat1_csc(),
mat2(),
mat2().transpose_into(),
mat4(),
mat5(),
];
let dense = [
mat_dense1(),
mat_dense1_colmaj(),
mat_dense1().reversed_axes(),
mat_dense2(),
mat_dense2().reversed_axes(),
];

// test sparse.ldot(dense)
for s in sparse.iter() {
for d in dense.iter() {
if d.shape()[1] < s.rows() {
continue;
}

let d = d.slice(s![.., 0..s.rows()]);

let truth = d.dot(&s.to_dense());
let test = d.dot(s);
assert_close(test.view(), truth.view());
}
}
}
}