Skip to content

Commit

Permalink
add benchmarks for different spmv kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 6, 2023
1 parent f631dff commit 7faf1da
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
4 changes: 3 additions & 1 deletion vbjax/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
_to_jax = lambda x: jax.dlpack.from_dlpack(x.__dlpack__())


def make_spmv(A, is_symmetric=False):
def make_spmv(A, is_symmetric=False, use_scipy=False):
"""
Make a closure for a general sparse matrix-vector multiplication.
Expand All @@ -17,6 +17,8 @@ def make_spmv(A, is_symmetric=False):
Constant sparse matrix.
is_symmetric : bool, optional, default False
Whether matrix is symmetric.
use_scipy: bool, optional, default False
Use scipy.
Returns
-------
Expand Down
49 changes: 46 additions & 3 deletions vbjax/tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.experimental.sparse as jsp
import scipy.sparse
import vbjax
import pytest


def _test_spmv(spmv, A, n):
Expand Down Expand Up @@ -68,7 +69,49 @@ def bench_csr_to_bcoo():
jb2 = f(jx)
t1 = time.time()
print(f'{name}: {t1 - t0:.3f} s')


if __name__ == '__main__':
bench_csr_to_bcoo()

def create_sample_data(n=1000, density_pct=10):
A = scipy.sparse.random(n, n, density=density_pct/100).tocsr()
jx = jax.numpy.r_[:n].astype('f')
return A, jx


# some performance testing values
_perf_args = 'n,density_pct,grad,impl,jit'
_perf_values = [(1000, 10), (10_000, 0.02)]

# we want to test each of the above values with grad on and off
_perf_values = [vals + (flag, impl, jit)
for flag in (True, False)
for impl in 'scipy jaxbcoo'.split(' ')
for jit in (True, False)
for vals in _perf_values]

@pytest.mark.parametrize(_perf_args, _perf_values)
def test_perf_jbcoo(benchmark, n, density_pct, grad, impl, jit):
A, x = create_sample_data(n=n, density_pct=density_pct)

if impl == 'scipy': # TODO enum
spmv1 = vbjax.make_spmv(A)
elif impl == 'jaxbcoo':
jA = _csr_to_jax_bcoo(A)
spmv1 = lambda x: jA @ x
else:
raise ValueError(impl)
assert callable(spmv1)

if grad:
spmv2 = jax.grad(lambda x: jnp.sum(spmv1(x)))
else:
spmv2 = spmv1
assert callable(spmv2)

if jit and impl not in ('scipy', ):
spmv3 = jax.jit(spmv2)
spmv3(x)
else:
spmv3 = spmv2
assert callable(spmv3)

benchmark(lambda : spmv3(x))

0 comments on commit 7faf1da

Please sign in to comment.