Support for Multiple GPUs #1495
base: master
Conversation
Some test scripts (imports added for completeness; the exact DESC import paths are assumed here):

```python
import jax
import nvgpu

devices = nvgpu.gpu_info()
print([dev["type"] for dev in devices])

from desc import set_device
set_device("gpu", num_device=3)  # must be called before other desc imports

from desc.examples import get
from desc.objectives import (
    ForceBalance,
    LinearConstraintProjection,
    ObjectiveFunction,
    get_fixed_boundary_constraints,
)
from desc.objectives.utils import maybe_add_self_consistency

eq = get("W7-X")
obj = ObjectiveFunction(ForceBalance(eq))
cons = get_fixed_boundary_constraints(eq)
cons = maybe_add_self_consistency(eq, cons)
cons = ObjectiveFunction(cons)
objective = LinearConstraintProjection(objective=obj, constraint=cons)
objective.build()

print(objective.compute_scaled_error(objective.x(eq)).shape)
print(objective.jac_scaled_error(objective.x(eq)).shape)

jac = objective.jac_scaled_error(objective.x(eq))
_, _ = jax.scipy.linalg.qr(jac, mode="economic")

%timeit _ = objective.compute_scaled_error(objective.x(eq)).block_until_ready()
%timeit _ = objective.jac_scaled_error(objective.x(eq)).block_until_ready()
%timeit _, _ = jax.scipy.linalg.qr(jac, mode="economic")

jax.debug.visualize_array_sharding(jac)
```
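For anyone wanting to experiment without the full DESC stack, here is a minimal, self-contained sketch of the same sharding idea using `jax.sharding.NamedSharding` on a small stand-in Jacobian. The `XLA_FLAGS` trick simulates multiple devices on CPU so it runs anywhere; on a real multi-GPU node that flag is unnecessary.

```python
import os
# Simulate 4 devices on CPU so the sketch runs without GPUs.
# This must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all visible devices, named "rows".
mesh = Mesh(np.array(jax.devices()), ("rows",))

jac = jnp.ones((8, 6))  # stand-in for objective.jac_scaled_error(...)
# Split rows (residuals / grid points) across devices; replicate columns.
sharded_jac = jax.device_put(jac, NamedSharding(mesh, P("rows", None)))

jax.debug.visualize_array_sharding(sharded_jac)
```

The row count must be divisible by the device count, mirroring the `grid.num_nodes` divisibility requirement discussed below.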
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```diff
@@            Coverage Diff            @@
##           master    #1495    +/-   ##
========================================
  Coverage   95.68%   95.69%
========================================
  Files         101      100        -1
  Lines       25604    25612        +8
========================================
+ Hits        24500    24509        +9
+ Misses       1104     1103        -1
```
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_lowres | +2.97 +/- 5.90 | +1.61e-02 +/- 3.19e-02 | 5.57e-01 +/- 2.6e-02 | 5.41e-01 +/- 1.9e-02 |
| test_equilibrium_init_medres | -0.42 +/- 3.61 | -1.81e-02 +/- 1.54e-01 | 4.25e+00 +/- 1.1e-01 | 4.26e+00 +/- 1.1e-01 |
| test_equilibrium_init_highres | +1.36 +/- 3.68 | +7.50e-02 +/- 2.03e-01 | 5.59e+00 +/- 8.1e-02 | 5.52e+00 +/- 1.9e-01 |
| test_objective_compile_dshape_current | +2.54 +/- 4.99 | +1.05e-01 +/- 2.06e-01 | 4.22e+00 +/- 1.7e-01 | 4.12e+00 +/- 1.1e-01 |
| test_objective_compute_dshape_current | +0.15 +/- 4.75 | +8.10e-06 +/- 2.49e-04 | 5.26e-03 +/- 2.2e-04 | 5.25e-03 +/- 1.2e-04 |
| test_objective_jac_dshape_current | -0.29 +/- 8.13 | -1.25e-04 +/- 3.51e-03 | 4.31e-02 +/- 2.4e-03 | 4.32e-02 +/- 2.5e-03 |
| test_perturb_2 | +2.22 +/- 2.21 | +4.39e-01 +/- 4.36e-01 | 2.02e+01 +/- 3.7e-01 | 1.98e+01 +/- 2.4e-01 |
| test_proximal_freeb_jac | +0.62 +/- 1.99 | +4.59e-02 +/- 1.47e-01 | 7.46e+00 +/- 1.1e-01 | 7.41e+00 +/- 1.0e-01 |
| test_solve_fixed_iter | +1.89 +/- 2.43 | +6.07e-01 +/- 7.80e-01 | 3.28e+01 +/- 5.7e-01 | 3.21e+01 +/- 5.3e-01 |
| test_LinearConstraintProjection_build | +1.99 +/- 2.23 | +2.07e-01 +/- 2.31e-01 | 1.06e+01 +/- 1.7e-01 | 1.04e+01 +/- 1.6e-01 |
| test_build_transform_fft_midres | -0.65 +/- 3.44 | -4.00e-03 +/- 2.11e-02 | 6.10e-01 +/- 8.3e-03 | 6.14e-01 +/- 1.9e-02 |
| test_build_transform_fft_highres | +0.67 +/- 4.50 | +6.52e-03 +/- 4.39e-02 | 9.83e-01 +/- 3.9e-02 | 9.77e-01 +/- 2.1e-02 |
| test_equilibrium_init_lowres | +0.40 +/- 2.28 | +1.57e-02 +/- 8.88e-02 | 3.91e+00 +/- 7.1e-02 | 3.89e+00 +/- 5.4e-02 |
| test_objective_compile_atf | +1.08 +/- 2.93 | +8.99e-02 +/- 2.44e-01 | 8.44e+00 +/- 2.2e-01 | 8.35e+00 +/- 1.2e-01 |
| test_objective_compute_atf | -0.07 +/- 2.89 | -1.07e-05 +/- 4.59e-04 | 1.59e-02 +/- 3.5e-04 | 1.59e-02 +/- 3.0e-04 |
| test_objective_jac_atf | +0.31 +/- 3.25 | +6.15e-03 +/- 6.49e-02 | 2.00e+00 +/- 5.3e-02 | 2.00e+00 +/- 3.8e-02 |
| test_perturb_1 | -0.88 +/- 3.18 | -1.34e-01 +/- 4.80e-01 | 1.50e+01 +/- 2.7e-01 | 1.51e+01 +/- 4.0e-01 |
| test_proximal_jac_atf | +1.78 +/- 1.35 | +1.47e-01 +/- 1.11e-01 | 8.43e+00 +/- 8.8e-02 | 8.28e+00 +/- 6.9e-02 |
| test_proximal_freeb_compute | +0.71 +/- 1.44 | +1.43e-03 +/- 2.90e-03 | 2.03e-01 +/- 2.2e-03 | 2.01e-01 +/- 1.8e-03 |
| test_solve_fixed_iter_compiled | -0.49 +/- 0.94 | -1.01e-01 +/- 1.95e-01 | 2.07e+01 +/- 6.1e-02 | 2.08e+01 +/- 1.8e-01 |
```python
import jax
import jax.numpy as jnp

N = 55000

@jax.jit
def dummy_fun(x):
    return x * jnp.ones(N)

@jax.jit
def dummy_fun2(x):
    return x[0] * jnp.ones(N)

# x = jnp.ones(N)
# v = jnp.eye(N)[:, 1]
# f, df = jax.jvp(dummy_fun2, (x,), (v,))

x = 1.0
v = 1.0
f, df = jax.jvp(dummy_fun, (x,), (v,))
```

The second function sometimes throws OOM simply because it cannot form the N x N identity matrix `jnp.eye(N)` needed to build the tangent vector.
#763
```diff
-def set_device(kind="cpu", gpuid=None):
+def set_device(kind="cpu", gpuid=None, num_device=1):
```
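A hypothetical sketch of what the new `num_device` argument could do (this is not the PR's actual implementation): expose the first `num_device` GPUs to JAX through `CUDA_VISIBLE_DEVICES`, keeping the existing single-GPU `gpuid` behavior.

```python
import os

def set_device_sketch(kind="cpu", gpuid=None, num_device=1):
    # Hypothetical helper, not DESC's real set_device.
    if kind == "gpu":
        if gpuid is not None:
            # Existing behavior: pin to one specific GPU.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpuid)
        else:
            # New behavior: expose the first num_device GPUs, e.g. "0,1,2".
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(i) for i in range(num_device)
            )
    return os.environ.get("CUDA_VISIBLE_DEVICES", "")

print(set_device_sketch("gpu", num_device=3))  # → 0,1,2
```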
[flake8] <901> reported by reviewdog 🐶
'set_device' is too complex (22)
I will mostly work on this after my generals. In the meantime, anyone interested can play with it.

The idea is to:

- shard the grid points across multiple GPUs (`grid.num_nodes` has to be divisible by the number of GPUs)
- replicate to all GPUs anything that every GPU needs (I tried this for the state vector in the `recover` method)
- ideally, shard the grid points such that all points on a single flux surface end up on the same GPU
- maybe use a different type of grid so it can be sharded across GPUs more easily
- maybe also use `pmap` for the jvps, i.e. `pmap` over `vmap`s

We won't see any speed improvement for the trust-region subproblem solvers, because JAX doesn't support distributed linear algebra yet.

Resolves #1071 (but without MPI4JAX).
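The "`pmap` over `vmap`s" idea above can be sketched in a few lines with a hypothetical per-grid-point residual function (a stand-in for the real physics computation). It runs on however many devices are visible; one CPU device is enough to demonstrate the pattern.

```python
import jax
import jax.numpy as jnp

n_dev = jax.device_count()

def residual(point):
    # Hypothetical stand-in for a per-grid-point computation.
    return jnp.sin(point).sum()

# 8 "grid points" in 3 coordinates; as noted above, the number of
# points must be divisible by the number of devices.
points = jnp.arange(24.0).reshape(8, 3)
per_dev = points.reshape(n_dev, 8 // n_dev, 3)  # leading axis = devices

# pmap splits across devices; vmap vectorizes over each device's points.
out = jax.pmap(jax.vmap(residual))(per_dev)
print(out.reshape(-1).shape)  # (8,)
```

The flattened result matches a plain `jax.vmap(residual)(points)`, so the device split changes only where the work runs, not the answer.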