Skip to content

Commit

Permalink
wip on dot interop w/ ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarks authored and vbarrielle committed Jul 11, 2020
1 parent 8031081 commit 896c7f2
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 22 deletions.
51 changes: 51 additions & 0 deletions src/sparse/csmat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,57 @@ where
}
}

impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
for ArrayBase<DS2, Ix2>
where
N: 'a + Copy + Num + Default + std::fmt::Debug,
I: 'a + SpIndex,
IpS: 'a + Deref<Target = [I]>,
IS: 'a + Deref<Target = [I]>,
DS: 'a + Deref<Target = [N]>,
DS2: 'b + ndarray::Data<Elem = N>,
{
type Output = Array<N, Ix2>;

fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
let rhs_t = rhs.transpose_view();
let lhs_t = self.t();

let rows = rhs_t.rows();
let cols = lhs_t.ncols();
// when the number of colums is small, it is more efficient
// to perform the product by iterating over the columns of
// the rhs, otherwise iterating by rows can take advantage of
// vectorized axpy.
let rres = match (rhs_t.storage(), cols >= 8) {
(CSR, true) => {
let mut res = Array::zeros((rows, cols));
prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSR, false) => {
let mut res = Array::zeros((rows, cols).f());
prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSC, true) => {
let mut res = Array::zeros((rows, cols));
prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
(CSC, false) => {
let mut res = Array::zeros((rows, cols).f());
prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
res.reversed_axes()
}
};

assert_eq!(self.shape()[0], rres.shape()[0]);
assert_eq!(rhs.cols(), rres.shape()[1]);
rres
}
}

impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
Expand Down
146 changes: 124 additions & 22 deletions src/sparse/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,15 @@ where
/// CSR-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of colums.
pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_rowmaj<'a, N1, N2, NOut, I, Iptr>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -297,12 +300,15 @@ pub fn csr_mulacc_dense_rowmaj<'a, N, I, Iptr>(
/// CSC-dense rowmaj multiplication
///
/// Performs better if rhs has a decent number of colums.
pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_rowmaj<'a, N1, N2, NOut, I, Iptr>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -333,12 +339,15 @@ pub fn csc_mulacc_dense_rowmaj<'a, N, I, Iptr>(
/// CSC-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csc_mulacc_dense_colmaj<'a, N1, N2, NOut, I, Iptr>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -370,12 +379,15 @@ pub fn csc_mulacc_dense_colmaj<'a, N, I, Iptr>(
/// CSR-dense colmaj multiplication
///
/// Performs better if rhs has few columns.
pub fn csr_mulacc_dense_colmaj<'a, N, I, Iptr>(
lhs: CsMatViewI<N, I, Iptr>,
rhs: ArrayView<N, Ix2>,
mut out: ArrayViewMut<'a, N, Ix2>,
pub fn csr_mulacc_dense_colmaj<'a, N1, N2, NOut, I, Iptr>(
lhs: CsMatViewI<N1, I, Iptr>,
rhs: ArrayView<N2, Ix2>,
mut out: ArrayViewMut<'a, NOut, Ix2>,
) where
N: 'a + Num + Copy,
N1: 'a + Num + Copy,
N2: 'a + Num + Copy,
NOut: 'a + Num + Copy,
N1: std::ops::Mul<N2, Output = NOut>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
{
Expand Down Expand Up @@ -409,12 +421,13 @@ mod test {
use super::*;
use crate::sparse::csmat::CompressedStorage::{CSC, CSR};
use crate::sparse::{CsMat, CsMatView, CsVec};
use ndarray::linalg::Dot;
use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder};
use crate::test_data::{
mat1, mat1_csc, mat1_csc_matprod_mat4, mat1_matprod_mat2,
mat1_self_matprod, mat2, mat4, mat5, mat_dense1, mat_dense1_colmaj,
mat_dense2,
};
use ndarray::{arr2, Array, ShapeBuilder};

#[test]
fn test_csvec_dot_by_binary_search() {
Expand Down Expand Up @@ -571,7 +584,7 @@ mod test {

#[test]
fn mul_csr_dense_rowmaj() {
let a = Array::eye(3);
let a: Array2<f64> = Array::eye(3);
let e: CsMat<f64> = CsMat::eye(3);
let mut res = Array::zeros((3, 3));
super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut());
Expand Down Expand Up @@ -663,4 +676,93 @@ mod test {
let c = &a * &b;
assert_eq!(c, expected_output);
}

// stolen from ndarray - not currently exported.
fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
where
D: Dimension,
{
let diff = (&a - &b).mapv_into(f64::abs);

let rtol = 1e-7;
let atol = 1e-12;
let crtol = b.mapv(|x| x.abs() * rtol);
let tol = crtol + atol;
let tol_m_diff = &diff - &tol;
let maxdiff = tol_m_diff.fold(0. / 0., |x, y| f64::max(x, *y));
println!("diff offset from tolerance level= {:.2e}", maxdiff);
if maxdiff > 0. {
println!("{:.4?}", a);
println!("{:.4?}", b);
panic!("results differ");
}
}

#[test]
fn test_sparse_dot_dense() {
let sparse = [
mat1(),
mat1_csc(),
mat2(),
mat2().transpose_into(),
mat4(),
mat5(),
];
let dense = [
mat_dense1(),
mat_dense1_colmaj(),
mat_dense1().reversed_axes(),
mat_dense2(),
mat_dense2().reversed_axes(),
];

// test sparse.dot(dense)
for s in sparse.iter() {
for d in dense.iter() {
if d.shape()[0] < s.cols() {
continue;
}

let d = d.slice(s![0..s.cols(), ..]);

let truth = s.to_dense().dot(&d);
let test = s.dot(&d);
assert_close(test.view(), truth.view());
}
}
}

#[test]
fn test_dense_dot_sparse() {
let sparse = [
mat1(),
mat1_csc(),
mat2(),
mat2().transpose_into(),
mat4(),
mat5(),
];
let dense = [
mat_dense1(),
mat_dense1_colmaj(),
mat_dense1().reversed_axes(),
mat_dense2(),
mat_dense2().reversed_axes(),
];

// test sparse.ldot(dense)
for s in sparse.iter() {
for d in dense.iter() {
if d.shape()[1] < s.rows() {
continue;
}

let d = d.slice(s![.., 0..s.rows()]);

let truth = d.dot(&s.to_dense());
let test = d.dot(s);
assert_close(test.view(), truth.view());
}
}
}
}

0 comments on commit 896c7f2

Please sign in to comment.