diff --git a/README.md b/README.md
index 7fe6a9e..cbe991d 100644
--- a/README.md
+++ b/README.md
@@ -119,8 +119,8 @@ pl.figure(figsize=(8,2))
 for i, sig_i in enumerate([0.0, 0.2, 0.3, 0.4]):
     pl.subplot(1, 4, i + 1)
     log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
-    pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()]
-    pars = pars.reshape((vb.cores, -1, 3))
+    pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
+    pars = pars.reshape((3, vb.cores, -1))
     result = run_batches(pars)
     pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)
     pl.ylabel('k') if i==0 else (), pl.xlabel('eta')
@@ -129,13 +129,27 @@ pl.show()
 pl.savefig('example3.jpg')
 ```
 ![](example3.jpg)
+On a single GPU, `jax.pmap` is no longer needed, since there is only a single
+compute device; `jax.vmap` alone is enough:
+```python
+run_batches = jax.jit(jax.vmap(run, in_axes=1))
+pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
+result = run_batches(pars)
+```
+
+#### Performance notes
+If we increase the network size to 164 nodes (e.g. the Destrieux atlas from
+FreeSurfer) and measure the efficiency of the JAX code:
+
+- Xeon-2133: 88 W for 5.7 Miter/s ≈ 65 Kiter/s per W
+- Quadro RTX 5000: 200 W for 26 Miter/s = 130 Kiter/s per W
+
+#### Distributed
-Tips
-- If you are using a system like Dask or Slurm, you can then invoke
-  that `run_batches` function in a distributed setting as required,
-  without needing to manage a per core or per node for loop.
-- On a single GPU, the `jax.pmap` is not needed to map grid elements to GPU
-  threads, it should *just work*.. examples forthcoming :D
+If you are using a system like Dask or Slurm, you can invoke that
+`run_batches` function in a distributed setting as required, without
+needing to manage a per-core or per-node for loop.
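+
+As a minimal sketch of the Dask case (assuming a `dask.distributed` scheduler
+is reachable and that `run_batches` serializes cleanly to the workers; the
+chunk count below is illustrative):
+```python
+from dask.distributed import Client
+
+client = Client()                    # connect to (or start) a Dask cluster
+chunks = np.split(pars, 4, axis=1)   # split the (3, n) parameter grid over workers
+futures = [client.submit(run_batches, c) for c in chunks]
+result = np.concatenate(client.gather(futures))
+```
+Each worker compiles `run_batches` once and sweeps its own chunk of the grid;
+a Slurm array job can follow the same pattern with one chunk per task.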
 
 ### Simplest neural field
diff --git a/vbjax/tests/bench_vmap.py b/vbjax/tests/bench_vmap.py
new file mode 100644
index 0000000..01172db
--- /dev/null
+++ b/vbjax/tests/bench_vmap.py
@@ -0,0 +1,61 @@
+import os
+# for no gpu
+# os.environ['CUDA_VISIBLE_DEVICES']=''
+import time
+import jax, jax.numpy as np
+import vbjax as vb
+
+def net(x, p):
+    r, v = x
+    k, _, mpr_p = p
+    c = k*r.sum(), k*v.sum()
+    return vb.mpr_dfun(x, c, mpr_p)
+
+def noise(_, p):
+    _, sigma, _ = p
+    return sigma
+
+_, loop = vb.make_sde(0.01, net, noise)
+n_nodes = 164
+rv0 = vb.randn(2, n_nodes)
+zs = vb.randn(1000, *rv0.shape)
+
+def run(pars, mpr_p=vb.mpr_default_theta):
+    k, sig, eta = pars                   # explored pars
+    p = k, sig, mpr_p._replace(eta=eta)  # set mpr
+    xs = loop(rv0, zs, p)                # run sim
+    std = xs[400:, 0].std()              # eval metric
+    return std
+
+run_batches = jax.jit(jax.vmap(run, in_axes=1))
+
+
+def bench_cpu():
+    run_batches_cores = jax.pmap(jax.vmap(run, in_axes=1), in_axes=1)
+
+    for cores in [12]*15:
+        for n in [1, 2]: #[2,4,8,16]:
+            log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:36j]
+            pars = np.c_[np.exp(log_ks.ravel()),np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
+            pars = pars.reshape((3, cores, -1))
+            tic = time.time()
+            for i in range(50):
+                result = run_batches_cores(pars)
+                result.block_until_ready()
+            toc = time.time()
+            iter = 50*log_ks.size*zs.shape[0]
+            print(f'{cores} {n} {iter/1e6/(toc-tic):0.2f} Miter/s')
+        print()
+
+
+def bench_gpu():
+    for n in [32]*20: #[2,4,8,16,32,48,64]:
+        log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:32j]
+        pars = np.c_[np.exp(log_ks.ravel()),np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
+        tic = time.time()
+        for i in range(50):
+            result = run_batches(pars)
+            result.block_until_ready()
+        toc = time.time()
+        iter = 50*log_ks.size*zs.shape[0]
+        print(f'{n} {iter/1e6/(toc-tic):0.2f} Miter/s')
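+
+
+# A minimal entry-point sketch: run the benchmark matching the active backend;
+# jax.default_backend() reports 'cpu', 'gpu' or 'tpu'.
+if __name__ == '__main__':
+    if jax.default_backend() == 'gpu':
+        bench_gpu()
+    else:
+        bench_cpu()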