Commit

bench vmap perf cpu & gpu
marmaduke woodman committed Nov 14, 2023
1 parent 76da584 commit 87ad1a5
Showing 2 changed files with 83 additions and 8 deletions.
30 changes: 22 additions & 8 deletions README.md
@@ -119,8 +119,8 @@ pl.figure(figsize=(8,2))
for i, sig_i in enumerate([0.0, 0.2, 0.3, 0.4]):
pl.subplot(1, 4, i + 1)
log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()]
pars = pars.reshape((vb.cores, -1, 3))
pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
pars = pars.reshape((3, vb.cores))
result = run_batches(pars)
pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)
pl.ylabel('k') if i==0 else (), pl.xlabel('eta')
@@ -129,13 +129,27 @@ pl.show()
pl.savefig('example3.jpg')
```
![](example3.jpg)
On a single GPU, you no longer need `jax.pmap`: there is only one compute device,
so `jax.vmap` is enough,
```python
run_batches = jax.jit(jax.vmap(run, in_axes=1))
pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
result = run_batches(pars)
```
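For reference, `in_axes=1` means the mapped axis is the second one, so `pars` has shape `(3, n_sims)` and each column holds one `(k, sigma, eta)` triple. A minimal, self-contained sketch of that convention with a toy function (not from the repo):
```python
import jax, jax.numpy as jnp

def toy(p):
    k, sig, eta = p                # one (k, sigma, eta) column
    return k + sig + eta

pars = jnp.zeros((3, 512))         # parameter sets stacked along axis 1
out = jax.jit(jax.vmap(toy, in_axes=1))(pars)
assert out.shape == (512,)         # one result per parameter set
```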

#### Performance notes
If we bump the network size to 164 nodes (e.g. the Destrieux atlas from FreeSurfer)
and check the efficiency of the JAX code,

- Xeon-2133 uses 88 W to do 5.7 Miter/s = 65 Kiter/W
- Quadro RTX 5000 uses 200 W to do 26 Miter/s = 130 Kiter/W
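These efficiency figures are just measured throughput divided by power draw; a quick arithmetic check:
```python
# iterations per second divided by watts, reported here in Kiter/W
print(5.7e6 / 88 / 1e3)    # ~65  Kiter/W for the Xeon
print(26e6 / 200 / 1e3)    # ~130 Kiter/W for the Quadro
```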

#### Distributed

Tips
- If you are using a system like Dask or Slurm, you can then invoke
that `run_batches` function in a distributed setting as required,
without needing to manage a per core or per node for loop.
- On a single GPU, the `jax.pmap` is not needed to map grid elements to GPU
threads, it should *just work*.. examples forthcoming :D
If you are using a system like Dask or Slurm, you can then invoke the
`run_batches` function in a distributed setting as required, without
needing to manage a per-core or per-node for loop.
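As a rough illustration only (not part of this repo), handing `run_batches` to a Dask cluster could look like the sketch below; the chunk count, client setup, and the `pars` array (the `(3, n)` grid built above) are assumptions:
```python
# hypothetical sketch: fan parameter chunks out over a Dask cluster
import numpy as onp
from dask.distributed import Client

client = Client()                           # or connect to an existing scheduler
chunks = onp.array_split(pars, 8, axis=1)   # split the (3, n) grid into column chunks
futures = client.map(run_batches, chunks)   # one vmapped batch per worker task
results = onp.concatenate(client.gather(futures))
```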

### Simplest neural field

61 changes: 61 additions & 0 deletions vbjax/tests/bench_vmap.py
@@ -0,0 +1,61 @@
import os
# for no gpu
# os.environ['CUDA_VISIBLE_DEVICES']=''
import time
import jax, jax.numpy as np
import vbjax as vb

def net(x, p):
    r, v = x
    k, _, mpr_p = p
    c = k*r.sum(), k*v.sum()
    return vb.mpr_dfun(x, c, mpr_p)

def noise(_, p):
    _, sigma, _ = p
    return sigma

_, loop = vb.make_sde(0.01, net, noise)
n_nodes = 164
rv0 = vb.randn(2, n_nodes)
zs = vb.randn(1000, *rv0.shape)

def run(pars, mpr_p=vb.mpr_default_theta):
    k, sig, eta = pars                   # explored pars
    p = k, sig, mpr_p._replace(eta=eta)  # set mpr
    xs = loop(rv0, zs, p)                # run sim
    std = xs[400:, 0].std()              # eval metric
    return std

# batch over parameter columns on the default device
run_batches = jax.jit(jax.vmap(run, in_axes=1))


def bench_cpu():
    # time pmap-over-cores of the vmapped batch
    run_batches_cores = jax.pmap(jax.vmap(run, in_axes=1), in_axes=1)

    for cores in [12]*15:
        for n in [1, 2]: #[2,4,8,16]:
            log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:36j]
            pars = np.c_[np.exp(log_ks.ravel()),np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
            pars = pars.reshape((3, cores, -1))
            tic = time.time()
            for i in range(50):
                result = run_batches_cores(pars)
                result.block_until_ready()
            toc = time.time()
            iter = 50*log_ks.size*zs.shape[0]
            print(f'{cores} {n} {iter/1e6/(toc-tic):0.2f} Miter/s')
        print()


def bench_gpu():
    # time the jit(vmap) batch, intended for a single GPU
    for n in [32]*20: #[2,4,8,16,32,48,64]:
        log_ks, etas = np.mgrid[-9.0:0.0:1j*n, -5.0:-6.0:32j]
        pars = np.c_[np.exp(log_ks.ravel()),np.ones(log_ks.size)*0.2, etas.ravel()].T.copy()
        tic = time.time()
        for i in range(50):
            result = run_batches(pars)
            result.block_until_ready()
        toc = time.time()
        iter = 50*log_ks.size*zs.shape[0]
        print(f'{n} {iter/1e6/(toc-tic):0.2f} Miter/s')
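A hypothetical driver for these benchmarks (not part of the committed file), picking a benchmark by the active JAX backend:
```python
# hypothetical entry point, not in the commit
if __name__ == '__main__':
    if jax.default_backend() == 'gpu':
        bench_gpu()
    else:
        bench_cpu()
```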
