From 7faf1da102a85f616c91dbfc4756a3d00cb63023 Mon Sep 17 00:00:00 2001 From: marmaduke woodman Date: Mon, 6 Nov 2023 13:23:09 +0100 Subject: [PATCH] add benchmarks for different spmv kernels --- vbjax/sparse.py | 4 +++- vbjax/tests/test_sparse.py | 49 +++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/vbjax/sparse.py b/vbjax/sparse.py index 33f1757..316e23f 100644 --- a/vbjax/sparse.py +++ b/vbjax/sparse.py @@ -7,7 +7,7 @@ _to_jax = lambda x: jax.dlpack.from_dlpack(x.__dlpack__()) -def make_spmv(A, is_symmetric=False): +def make_spmv(A, is_symmetric=False, use_scipy=False): """ Make a closure for a general sparse matrix-vector multiplication. @@ -17,6 +17,8 @@ def make_spmv(A, is_symmetric=False): Constant sparse matrix. is_symmetric : bool, optional, default False Whether matrix is symmetric. + use_scipy: bool, optional, default False + Use scipy. Returns ------- diff --git a/vbjax/tests/test_sparse.py b/vbjax/tests/test_sparse.py index 15b0a6b..b3123fb 100644 --- a/vbjax/tests/test_sparse.py +++ b/vbjax/tests/test_sparse.py @@ -7,6 +7,7 @@ import jax.experimental.sparse as jsp import scipy.sparse import vbjax +import pytest def _test_spmv(spmv, A, n): @@ -68,7 +69,49 @@ def bench_csr_to_bcoo(): jb2 = f(jx) t1 = time.time() print(f'{name}: {t1 - t0:.3f} s') - -if __name__ == '__main__': - bench_csr_to_bcoo() + +def create_sample_data(n=1000, density_pct=10): + A = scipy.sparse.random(n, n, density=density_pct/100).tocsr() + jx = jax.numpy.r_[:n].astype('f') + return A, jx + + +# some performance testing values +_perf_args = 'n,density_pct,grad,impl,jit' +_perf_values = [(1000, 10), (10_000, 0.02)] + +# we want to test each of the above values with grad on and off +_perf_values = [vals + (flag, impl, jit) + for flag in (True, False) + for impl in 'scipy jaxbcoo'.split(' ') + for jit in (True, False) + for vals in _perf_values] + +@pytest.mark.parametrize(_perf_args, _perf_values) +def test_perf_jbcoo(benchmark, n, density_pct, grad, impl, jit): + A, x = create_sample_data(n=n, density_pct=density_pct) + + if impl == 'scipy': # TODO enum + spmv1 = vbjax.make_spmv(A) + elif impl == 'jaxbcoo': + jA = _csr_to_jax_bcoo(A) + spmv1 = lambda x: jA @ x + else: + raise ValueError(impl) + assert callable(spmv1) + + if grad: + spmv2 = jax.grad(lambda x: jnp.sum(spmv1(x))) + else: + spmv2 = spmv1 + assert callable(spmv2) + + if jit and impl not in ('scipy', ): + spmv3 = jax.jit(spmv2) + spmv3(x) + else: + spmv3 = spmv2 + assert callable(spmv3) + + benchmark(lambda : spmv3(x))