xla memory leak on compiled jax function #1484

Open
unalmis opened this issue Dec 20, 2024 · 14 comments
Labels
AD (related to automatic differentiation) · bug (Something isn't working) · optimization (Adding or improving optimization methods) · question (Further information is requested)

Comments

@unalmis
Collaborator

unalmis commented Dec 20, 2024

Paste this into a Jupyter cell and run it twice. On the second run, once the function is compiled, an XlaRuntimeError is raised.

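import numpy as np

# Imports assumed from DESC's public modules (not shown in the original post).
from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import (
    EffectiveRipple,
    FixBoundaryR,
    FixBoundaryZ,
    FixIota,
    FixPressure,
    FixPsi,
    ObjectiveFunction,
)
from desc.optimize import Optimizer

eq0 = get("HELIOTRON")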
eq1 = eq0.copy()
k = 2  # which modes to unfix
print()
print("---------------------------------------")
print(f"Optimizing boundary modes M, N <= {k}")
print("---------------------------------------")
modes_R = np.vstack(
    (
        [0, 0, 0],
        eq1.surface.R_basis.modes[np.max(np.abs(eq1.surface.R_basis.modes), 1) > k, :],
    )
)
modes_Z = eq1.surface.Z_basis.modes[np.max(np.abs(eq1.surface.Z_basis.modes), 1) > k, :]
constraints = (
    FixBoundaryR(eq=eq1, modes=modes_R),
    FixBoundaryZ(eq=eq1, modes=modes_Z),
    FixPressure(eq=eq1),
    FixIota(eq=eq1),
    FixPsi(eq=eq1),
)
grid = LinearGrid(
    rho=np.linspace(0.2, 1, 5), M=eq1.M_grid, N=eq1.N_grid, NFP=eq1.NFP, sym=False
)
objective = ObjectiveFunction(
    (
        EffectiveRipple(
            eq1,
            grid=grid,
            X=16,
            Y=32,
            Y_B=128,
            num_transit=10,
            num_well=30 * 10,
            num_quad=32,
        ),
    )
)
optimizer = Optimizer("proximal-lsq-exact")
(eq1,), _ = optimizer.optimize(
    eq1,
    objective,
    constraints,
    ftol=1e-4,
    xtol=1e-6,
    gtol=1e-6,
    maxiter=1,
    verbose=3,
    options={"initial_trust_ratio": 2e-3},
)
print("Optimization complete!")
---------------------------------------
Optimizing boundary modes M, N <= 2
---------------------------------------

/home/kaya/Documents/project/DESC/desc/optimize/optimizer.py:466: UserWarning: No nonlinear constraints detected, ignoring wrapper method proximal.
  warnings.warn(

Building objective: Effective ripple
Precomputing transforms
Timer: Precomputing transforms = 123 ms
Timer: Objective build = 134 ms
Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed pressure
Building objective: fixed iota
Building objective: fixed Psi
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Timer: Objective build = 189 ms
Timer: Linear constraint projection build = 900 ms
Number of parameters: 1617
Number of objectives: 5
Timer: Initializing the optimization = 1.33 sec

Starting optimization
Using method: proximal-lsq-exact
   Iteration     Total nfev        Cost      Cost reduction    Step norm     Optimality   

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[3], line 48
     29 objective = ObjectiveFunction(
     30     (
     31         EffectiveRipple(
   (...)
     45     )
     46 )
     47 optimizer = Optimizer("proximal-lsq-exact")
---> 48 (eq1,), _ = optimizer.optimize(
     49     eq1,
     50     objective,
     51     constraints,
     52     ftol=1e-4,
     53     xtol=1e-6,
     54     gtol=1e-6,
     55     maxiter=1,  # increase maxiter to 50 for a better result
     56     verbose=3,
     57     options={"initial_trust_ratio": 2e-3},
     58 )
     59 print("Optimization complete!")

File ~/Documents/project/DESC/desc/optimize/optimizer.py:308, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    304     print("Using method: " + str(self.method))
    306 timer.start("Solution time")
--> 308 result = optimizers[method]["fun"](
    309     objective,
    310     nonlinear_constraint,
    311     x0,
    312     method,
    313     x_scale,
    314     verbose,
    315     stoptol,
    316     options,
    317 )
    319 if isinstance(objective, LinearConstraintProjection):
    320     # remove wrapper to get at underlying objective
    321     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/Documents/project/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/Documents/project/DESC/desc/optimize/least_squares.py:256, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    254 if verbose > 1:
    255     print_header_nonlinear()
--> 256     print_iteration_nonlinear(
    257         iteration, nfev, cost, actual_reduction, step_norm, g_norm
    258     )
    260 allx = [x]
    261 alltr = [trust_radius]

File ~/Documents/project/DESC/desc/optimize/utils.py:399, in print_iteration_nonlinear(iteration, nfev, cost, cost_reduction, step_norm, optimality, constr_violation, *args)
    397     optimality = " " * 15
    398 else:
--> 399     optimality = "{:^15.3e}".format(optimality)
    400 s = "{}{}{}{}{}{}".format(
    401     iteration, nfev, cost, cost_reduction, step_norm, optimality
    402 )
    403 if constr_violation is not None:

File ~/miniconda3/envs/desc-env/lib/python3.10/site-packages/jax/_src/array.py:314, in ArrayImpl.__format__(self, format_spec)
    311 def __format__(self, format_spec):
    312   # Simulates behavior of https://github.com/numpy/numpy/pull/9883
    313   if self.ndim == 0:
--> 314     return format(self._value[()], format_spec)
    315   else:
    316     return format(self._value, format_spec)

File ~/miniconda3/envs/desc-env/lib/python3.10/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    330 @wraps(func)
    331 def wrapper(*args, **kwargs):
    332   with TraceAnnotation(name, **decorator_kwargs):
--> 333     return func(*args, **kwargs)
    334   return wrapper

File ~/miniconda3/envs/desc-env/lib/python3.10/site-packages/jax/_src/array.py:613, in ArrayImpl._value(self)
    611 if self._npy_value is None:
    612   if self.is_fully_replicated:
--> 613     self._npy_value = self._single_device_array_to_np_array()
    614     self._npy_value.flags.writeable = False
    615     return cast(np.ndarray, self._npy_value)

XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError preparing computation: %sOut of memory allocating 216486693000 bytes.
@unalmis added the AD and bug labels Dec 20, 2024
@unalmis
Collaborator Author

unalmis commented Dec 20, 2024

The issue does not occur when you add the force balance objective ForceBalance(eq=eq1), which also avoids the warning:

/home/kaya/Documents/project/DESC/desc/optimize/optimizer.py:466: UserWarning: No nonlinear constraints detected, ignoring wrapper method proximal.
  warnings.warn(

@unalmis added the optimization label Dec 20, 2024
@unalmis changed the title from "xla runtime error on compiled jax function" to "xla memory leak on compiled jax function" Dec 20, 2024
@f0uriest
Member

Looks like it might just be an OOM issue? Does it happen if you reduce the resolution a lot? It likely doesn't happen with force balance included because in that case it's only taking derivatives with respect to the boundary modes, not the full set of internal degrees of freedom.

@f0uriest
Member

Also, which branch was this on?

@unalmis
Collaborator Author

unalmis commented Dec 20, 2024

This was on #1290

@YigitElma
Collaborator

This is different from the previous issue, right? The one that Rory solved by changing how we flatten self.

@YigitElma
Collaborator

YigitElma commented Dec 20, 2024

I think this is just an OOM issue, as Rory said. I get the same error on the first run (WSL uses 12 GB of memory). But I don't get the error if I use optimizer = Optimizer("lsq-auglag") instead. I think we use Optimizer("lsq-auglag") for optimization without the ForceBalance constraint, so maybe just use that? Then solve for the equilibrium after the optimization.

Still, running the same code shouldn't recompile anything, but maybe this is on the edge and one tiny extra recompilation causes it to OOM. I am not sure.
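
One way to check whether a repeated call really re-traces is to count traces with a Python side effect, which only executes while JAX is tracing. This is a minimal sketch, not DESC code: the function `residual` and the counter `trace_count` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter; a dict is used so the jitted function can mutate it.
trace_count = {"n": 0}


@jax.jit
def residual(x):
    # This Python-level side effect runs only while JAX traces the function,
    # not when a cached compiled version is reused.
    trace_count["n"] += 1
    return jnp.sum(x**2)


residual(jnp.ones(4))  # first call: traced and compiled
residual(jnp.ones(4))  # same shape and dtype: cached version reused
residual(jnp.ones(5))  # new input shape: traced again

print(trace_count["n"])
```

If the counter grows on a call that should have hit the cache, something about the inputs (shape, dtype, or a static argument) changed between calls.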

@unalmis
Collaborator Author

unalmis commented Dec 20, 2024

Well, the first run uses <= 5 GB of memory on my machine. After compilation, the error message states that it wanted to allocate > 200 GB. That indicates a memory leak to me.

This memory leak does not occur when I ADD force balance.

@YigitElma
Collaborator

> This memory leak does not occur when I ADD force balance.

To the constraints? Or objective?

@unalmis
Collaborator Author

unalmis commented Dec 20, 2024

constraints, didn't try objective

@YigitElma
Collaborator

YigitElma commented Dec 20, 2024

Well, when you add force balance to the constraints, the Jacobian size reduces to 5x24, but without it, it is 5x1617. I get the OOM even on the first run with the same message (>200 GB), so I am not sure what the 200 GB is for.
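
For scale, the sizes quoted in this thread can be compared directly. This is a rough back-of-the-envelope sketch that assumes dense storage with 8-byte float64 entries; the helper `dense_bytes` is a hypothetical name.

```python
BYTES_PER_FLOAT64 = 8  # assumption: double-precision entries


def dense_bytes(rows, cols, itemsize=BYTES_PER_FLOAT64):
    """Memory needed to store a dense rows x cols matrix."""
    return rows * cols * itemsize


jac_without_fb = dense_bytes(5, 1617)  # Jacobian without the ForceBalance constraint
jac_with_fb = dense_bytes(5, 24)       # Jacobian with the ForceBalance constraint
requested = 216_486_693_000            # bytes, taken from the XlaRuntimeError message

print(f"Jacobian without force balance: {jac_without_fb} bytes")
print(f"Jacobian with force balance:    {jac_with_fb} bytes")
print(f"Requested allocation:           {requested / 2**30:.1f} GiB")
print(f"Requested / Jacobian size:      {requested / jac_without_fb:.3g}")
```

Under these assumptions the final 5x1617 Jacobian needs well under 100 kB, so the requested allocation is more than three million times larger than the result itself. The allocation must therefore come from some intermediate of the computation, not from the Jacobian that is returned.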

@YigitElma
Collaborator

Also, the warning at the top basically says that you don't have nonlinear constraints, so it won't use ProximalProjection. Even though it later prints "Starting optimization" and "Using method: proximal-lsq-exact", it actually uses lsq-exact. Worth mentioning for debugging. Adding force balance is a big change in terms of which code actually runs.

@YigitElma
Collaborator

YigitElma commented Dec 20, 2024

398 else:
--> 399 optimality = "{:^15.3e}".format(optimality)

Also, I realized that the error occurs here, which is after the Jacobian calculation. Could it be a problem with the jax.Array to np.ndarray conversion? @f0uriest It already calculates the Jacobian once, but then runs out of memory when it tries to print the optimality value.

The same error happens if you add g_norm = np.asarray(g_norm), too.
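
JAX dispatches work asynchronously, so a runtime failure can be reported at the first line that pulls a value back to the host (here the format call or np.asarray) rather than at the line that launched the computation. This is a minimal sketch, not the DESC code, of forcing completion earlier so a failing computation is easier to localize. The function `grad_norm` and the array shapes are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


@jax.jit
def grad_norm(J, f):
    # Infinity norm of the gradient J^T f, used here as a stand-in
    # for an optimality measure.
    return jnp.max(jnp.abs(J.T @ f))


J = jnp.ones((5, 1617))
f = jnp.ones(5)

g_norm = grad_norm(J, f)    # may return before the computation has finished
g_norm.block_until_ready()  # wait here; a failed computation would raise at this line

# The value is already computed, so formatting only copies a finished scalar.
text = "{:^15.3e}".format(g_norm)
print(repr(text))
```

Placing block_until_ready() right after each suspect call narrows down which computation actually requested the memory, independent of where its result is first printed or compared.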

@YigitElma
Collaborator

YigitElma commented Dec 20, 2024

@unalmis Can you try running with these changes in print_iteration_nonlinear (desc/optimize/utils.py)?

if optimality is None or abs(optimality) == np.inf:

    # if optimality is None or abs(optimality) == np.inf:
    #     optimality = " " * 15
    # else:
    #     optimality = "{:^15.3e}".format(optimality)
    optimality = " " * 15

With this change I could run it once, but the second run fails again with the same error message at a different line.

@YigitElma
Collaborator

The error message:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[6], line 42
     40 # optimizer = Optimizer("lsq-auglag")
     41 optimizer = Optimizer("proximal-lsq-exact")
---> 42 (eq1,), _ = optimizer.optimize(
     43     eq1,
     44     objective,
     45     constraints,
     46     ftol=1e-4,
     47     xtol=1e-6,
     48     gtol=0,
     49     maxiter=1,
     50     verbose=3,
     51     options={"initial_trust_ratio": 2e-3},
     52 )
     53 print("Optimization complete!")

File /CODES/DESC/desc/optimize/optimizer.py:308, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    304     print("Using method: " + str(self.method))
    306 timer.start("Solution time")
--> 308 result = optimizers[method]["fun"](
    309     objective,
    310     nonlinear_constraint,
    311     x0,
    312     method,
    313     x_scale,
    314     verbose,
    315     stoptol,
    316     options,
    317 )
    319 if isinstance(objective, LinearConstraintProjection):
    320     # remove wrapper to get at underlying objective
    321     result["allx"] = [objective.recover(x) for x in result["allx"]]

File /CODES/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File /CODES/DESC/desc/optimize/least_squares.py:224, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    222 trust_radius = init_tr.get(trust_radius, trust_radius)
    223 trust_radius *= tr_ratio
--> 224 trust_radius = trust_radius if (trust_radius > 0) else 1.0
    226 max_trust_radius = options.pop("max_trust_radius", jnp.inf)
    227 min_trust_radius = options.pop("min_trust_radius", jnp.finfo(x0.dtype).eps)

File ~/anaconda3/envs/desc-env-cpu/lib/python3.12/site-packages/jax/_src/array.py:294, in ArrayImpl.__bool__(self)
    292 def __bool__(self):
    293   core.check_bool_conversion(self)
--> 294   return bool(self._value)

File ~/anaconda3/envs/desc-env-cpu/lib/python3.12/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    330 @wraps(func)
    331 def wrapper(*args, **kwargs):
    332   with TraceAnnotation(name, **decorator_kwargs):
--> 333     return func(*args, **kwargs)
    334   return wrapper

File ~/anaconda3/envs/desc-env-cpu/lib/python3.12/site-packages/jax/_src/array.py:628, in ArrayImpl._value(self)
    626 if self._npy_value is None:
    627   if self.is_fully_replicated:
--> 628     self._npy_value = self._single_device_array_to_np_array()
    629     self._npy_value.flags.writeable = False
    630     return cast(np.ndarray, self._npy_value)

XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError preparing computation: %sOut of memory allocating 216486938512 bytes.

@dpanici added the question label Jan 6, 2025
4 participants