
Support for Multiple GPUs #1495 (Draft)

YigitElma wants to merge 33 commits into master
Conversation

@YigitElma (Collaborator) commented Dec 25, 2024

I will mostly work on this after my generals. In the meantime, anyone interested can play with it.

The idea is to:

  • shard the grid points across multiple GPUs (grid.num_nodes has to be divisible by the number of GPUs); see the sketch after this list

  • replicate everything that has to be used by all GPUs to every GPU (I tried this for the state vector in the recover method)

  • ideally, shard the grid points such that all grid points on a single flux surface end up on the same GPU

  • maybe use a different type of grid so it can be sharded across GPUs more easily

  • maybe also use pmap for the jvps (i.e., pmap over vmaps); see the sketch at the end of this comment
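
A minimal sketch of the shard-and-replicate pattern described in the list above, using jax.sharding (the sizes and the mesh axis name are made-up placeholders, not DESC code):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_nodes, state_dim = 1200, 50  # num_nodes must be divisible by the device count
mesh = Mesh(np.asarray(jax.devices()), axis_names=("grid",))

nodes = jnp.zeros((num_nodes, 3))  # (rho, theta, zeta) of each grid point
x = jnp.zeros(state_dim)           # state vector

# shard the grid points across GPUs along the node axis ...
nodes = jax.device_put(nodes, NamedSharding(mesh, P("grid", None)))
# ... and replicate the state vector to every device
x = jax.device_put(x, NamedSharding(mesh, P()))

jax.debug.visualize_array_sharding(nodes)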

We won't see any speed improvement for the trust_region_subproblem solvers, because JAX doesn't support distributed linear algebra yet.

Resolves #1071 (but without MPI4JAX)
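
On the pmap-over-vmap idea from the list above, a rough sketch of batching jvps across devices (f, the sizes, and the tangent layout are placeholders, and the batch of tangents has to divide evenly across devices):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
dim = 8  # toy size; must be divisible by n_dev

def f(x):
    return jnp.sin(x) * x

def jvp_one(x, v):
    # forward-mode derivative of f at x in the direction v
    return jax.jvp(f, (x,), (v,))[1]

# vmap over the per-device batch of tangents, pmap across devices
jvp_batched = jax.pmap(jax.vmap(jvp_one, in_axes=(None, 0)), in_axes=(None, 0))

x = jnp.ones(dim)
V = jnp.eye(dim).reshape(n_dev, dim // n_dev, dim)  # (devices, batch, dim)
dfs = jvp_batched(x, V)  # per-device blocks of Jacobian columns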

@YigitElma (Collaborator, Author) commented Dec 25, 2024

Some test scripts (run in IPython/Jupyter for the %timeit magics):

import nvgpu

devices = nvgpu.gpu_info()
print([dev["type"] for dev in devices])

# set_device must be called before the other desc imports
from desc import set_device
set_device("gpu", num_device=3)

# exact import paths may differ between DESC versions
import jax
from desc.examples import get
from desc.objectives import ForceBalance, ObjectiveFunction, get_fixed_boundary_constraints
from desc.objectives.getters import maybe_add_self_consistency
from desc.optimize import LinearConstraintProjection

eq = get("W7-X")
obj = ObjectiveFunction(ForceBalance(eq))
cons = get_fixed_boundary_constraints(eq)
cons = maybe_add_self_consistency(eq, cons)
cons = ObjectiveFunction(cons)
objective = LinearConstraintProjection(objective=obj, constraint=cons)
objective.build()

print(objective.compute_scaled_error(objective.x(eq)).shape)
print(objective.jac_scaled_error(objective.x(eq)).shape)
jac = objective.jac_scaled_error(objective.x(eq))
_, _ = jax.scipy.linalg.qr(jac, mode="economic")
%timeit _ = objective.compute_scaled_error(objective.x(eq)).block_until_ready()
%timeit _ = objective.jac_scaled_error(objective.x(eq)).block_until_ready()
%timeit _, _ = jax.scipy.linalg.qr(jac, mode="economic")
jax.debug.visualize_array_sharding(jac)

codecov bot commented Dec 26, 2024

Codecov Report

Attention: Patch coverage is 23.25581% with 33 lines in your changes missing coverage. Please review.

Project coverage is 95.69%. Comparing base (5322158) to head (57ab00c).
Report is 148 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| ------------------------ | ------- | ----- |
| desc/objectives/getters.py | 4.76% | 20 Missing ⚠️ |
| desc/backend.py | 33.33% | 6 Missing ⚠️ |
| desc/objectives/objective_funs.py | 28.57% | 5 Missing ⚠️ |
| desc/objectives/utils.py | 66.66% | 1 Missing ⚠️ |
| desc/optimize/_constraint_wrappers.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1495   +/-   ##
=======================================
  Coverage   95.68%   95.69%           
=======================================
  Files         101      100    -1     
  Lines       25604    25612    +8     
=======================================
+ Hits        24500    24509    +9     
+ Misses       1104     1103    -1     
| Files with missing lines | Coverage Δ |
| ------------------------ | ---------- |
| desc/objectives/utils.py | 99.32% <66.66%> (-0.68%) ⬇️ |
| desc/optimize/_constraint_wrappers.py | 96.96% <66.66%> (-0.27%) ⬇️ |
| desc/objectives/objective_funs.py | 94.10% <28.57%> (-0.64%) ⬇️ |
| desc/backend.py | 87.63% <33.33%> (-2.82%) ⬇️ |
| desc/objectives/getters.py | 76.69% <4.76%> (-18.43%) ⬇️ |

... and 7 files with indirect coverage changes

github-actions bot commented Dec 26, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres         |     +2.97 +/- 5.90     | +1.61e-02 +/- 3.19e-02 |  5.57e-01 +/- 2.6e-02  |  5.41e-01 +/- 1.9e-02  |
| test_equilibrium_init_medres            |     -0.42 +/- 3.61     | -1.81e-02 +/- 1.54e-01 |  4.25e+00 +/- 1.1e-01  |  4.26e+00 +/- 1.1e-01  |
| test_equilibrium_init_highres           |     +1.36 +/- 3.68     | +7.50e-02 +/- 2.03e-01 |  5.59e+00 +/- 8.1e-02  |  5.52e+00 +/- 1.9e-01  |
| test_objective_compile_dshape_current   |     +2.54 +/- 4.99     | +1.05e-01 +/- 2.06e-01 |  4.22e+00 +/- 1.7e-01  |  4.12e+00 +/- 1.1e-01  |
| test_objective_compute_dshape_current   |     +0.15 +/- 4.75     | +8.10e-06 +/- 2.49e-04 |  5.26e-03 +/- 2.2e-04  |  5.25e-03 +/- 1.2e-04  |
| test_objective_jac_dshape_current       |     -0.29 +/- 8.13     | -1.25e-04 +/- 3.51e-03 |  4.31e-02 +/- 2.4e-03  |  4.32e-02 +/- 2.5e-03  |
| test_perturb_2                          |     +2.22 +/- 2.21     | +4.39e-01 +/- 4.36e-01 |  2.02e+01 +/- 3.7e-01  |  1.98e+01 +/- 2.4e-01  |
| test_proximal_freeb_jac                 |     +0.62 +/- 1.99     | +4.59e-02 +/- 1.47e-01 |  7.46e+00 +/- 1.1e-01  |  7.41e+00 +/- 1.0e-01  |
| test_solve_fixed_iter                   |     +1.89 +/- 2.43     | +6.07e-01 +/- 7.80e-01 |  3.28e+01 +/- 5.7e-01  |  3.21e+01 +/- 5.3e-01  |
| test_LinearConstraintProjection_build   |     +1.99 +/- 2.23     | +2.07e-01 +/- 2.31e-01 |  1.06e+01 +/- 1.7e-01  |  1.04e+01 +/- 1.6e-01  |
| test_build_transform_fft_midres         |     -0.65 +/- 3.44     | -4.00e-03 +/- 2.11e-02 |  6.10e-01 +/- 8.3e-03  |  6.14e-01 +/- 1.9e-02  |
| test_build_transform_fft_highres        |     +0.67 +/- 4.50     | +6.52e-03 +/- 4.39e-02 |  9.83e-01 +/- 3.9e-02  |  9.77e-01 +/- 2.1e-02  |
| test_equilibrium_init_lowres            |     +0.40 +/- 2.28     | +1.57e-02 +/- 8.88e-02 |  3.91e+00 +/- 7.1e-02  |  3.89e+00 +/- 5.4e-02  |
| test_objective_compile_atf              |     +1.08 +/- 2.93     | +8.99e-02 +/- 2.44e-01 |  8.44e+00 +/- 2.2e-01  |  8.35e+00 +/- 1.2e-01  |
| test_objective_compute_atf              |     -0.07 +/- 2.89     | -1.07e-05 +/- 4.59e-04 |  1.59e-02 +/- 3.5e-04  |  1.59e-02 +/- 3.0e-04  |
| test_objective_jac_atf                  |     +0.31 +/- 3.25     | +6.15e-03 +/- 6.49e-02 |  2.00e+00 +/- 5.3e-02  |  2.00e+00 +/- 3.8e-02  |
| test_perturb_1                          |     -0.88 +/- 3.18     | -1.34e-01 +/- 4.80e-01 |  1.50e+01 +/- 2.7e-01  |  1.51e+01 +/- 4.0e-01  |
| test_proximal_jac_atf                   |     +1.78 +/- 1.35     | +1.47e-01 +/- 1.11e-01 |  8.43e+00 +/- 8.8e-02  |  8.28e+00 +/- 6.9e-02  |
| test_proximal_freeb_compute             |     +0.71 +/- 1.44     | +1.43e-03 +/- 2.90e-03 |  2.03e-01 +/- 2.2e-03  |  2.01e-01 +/- 1.8e-03  |
| test_solve_fixed_iter_compiled          |     -0.49 +/- 0.94     | -1.01e-01 +/- 1.95e-01 |  2.07e+01 +/- 6.1e-02  |  2.08e+01 +/- 1.8e-01  |

@YigitElma (Collaborator, Author) commented Dec 26, 2024

  • maybe the compute_* methods shouldn't even take Rb_lmn, Zb_lmn, etc. as params, since they are not used in the computation and their derivative is 0; they just increase memory use. For example:
import jax
import jax.numpy as jnp

N = 55000

@jax.jit
def dummy_fun(x):
    # scalar input: the tangent v is also a scalar
    return x * jnp.ones(N)

@jax.jit
def dummy_fun2(x):
    # array input: the tangent v must be a length-N array
    return x[0] * jnp.ones(N)

# x = jnp.ones(N)
# v = jnp.eye(N)[:, 1]
# f, df = jax.jvp(dummy_fun2, (x,), (v,))

x = 1.0
v = 1.0
f, df = jax.jvp(dummy_fun, (x,), (v,))

The second function sometimes throws an OOM error just because it cannot form v: materializing jnp.eye(N) means an N-by-N float64 array, which at N = 55000 is roughly 24 GB.
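
For comparison, the same tangent can be built without ever materializing the identity (a minimal sketch):

import jax
import jax.numpy as jnp

N = 55000
x = jnp.ones(N)
v = jnp.zeros(N).at[1].set(1.0)  # one-hot basis vector: O(N) memory, not O(N^2)
f, df = jax.jvp(lambda y: y[0] * jnp.ones(N), (x,), (v,))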

@dpanici (Collaborator) commented Jan 6, 2025

Check for overlap with #763.



- def set_device(kind="cpu", gpuid=None):
+ def set_device(kind="cpu", gpuid=None, num_device=1):
[flake8] <901> reported by reviewdog 🐶
'set_device' is too complex (22)
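
Hypothetically, the num_device argument could pick the least-loaded GPUs and expose only those to JAX before it initializes. A sketch under that assumption (not the actual DESC implementation; gpuid handling is omitted, and the nvgpu dict keys are from its documented output):

import os
import nvgpu

def set_device(kind="cpu", gpuid=None, num_device=1):
    # hypothetical sketch: expose the num_device least-loaded GPUs to JAX
    if kind == "gpu":
        gpus = sorted(nvgpu.gpu_info(), key=lambda dev: dev["mem_used_percent"])
        selected = [str(dev["index"]) for dev in gpus[:num_device]]
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(selected)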


Development

Successfully merging this pull request may close: Parallelize across multiple GPUs with MPI4Jax (#1071)

2 participants