
Patch for differentiable code with dynamic shapes #1303

Open
unalmis opened this issue Oct 15, 2024 · 4 comments
Labels
AD (related to automatic differentiation), speed (New feature or request to make the code faster)

Comments

unalmis (Collaborator) commented Oct 15, 2024

A fundamental limitation of our autodiff tool JAX is that it cannot handle dynamically sized arrays or jagged tensor operations. Both are supported by PyTorch, whose documentation I reference to better describe what these operations are:

  1. Dynamic shapes
  2. Jagged tensor operations

These operations would improve the performance of bounce integrals.
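
For context, here is a minimal sketch of limitation 1 (it uses the documented size and fill_value arguments of jnp.nonzero; the function name is only illustrative): outside of jit, the number of returned indices can depend on the data, but under jit output shapes must be static, so at best one pads to a fixed length.

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 0.0, 3.0])

# Outside of jit, the number of returned indices can depend on the data.
print(jnp.nonzero(x))  # (Array([0, 2], dtype=int32),)

@jax.jit
def static_nonzero(x):
    # Under jit, output shapes must be static; without the size argument this
    # call fails at trace time (a ConcretizationTypeError in current JAX).
    return jnp.nonzero(x, size=x.shape[0], fill_value=-1)

print(static_nonzero(x))  # (Array([0, 2, -1], dtype=int32),)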

Patch

So, unlike in C code, in JAX we can't use jnp.nonzero (its output shape depends on the data), which constrains the algorithms we can write. Sometimes it is possible to work around this with a shape-independent algorithm, for example performing least squares regression on a variable number of points through equation 12 instead of equation 13 in the least squares example. In other cases, it is not possible to write a shape-independent algorithm without sacrificing performance (at least on paper; sometimes JIT and GPUs overcome this).
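
To illustrate the shape-independent idea (this is only a minimal sketch of the principle, not the equations in the linked example): fit a polynomial to a variable subset of points by zeroing the weights of the excluded points, so every array keeps a fixed shape.

import jax.numpy as jnp

def masked_polyfit(x, y, mask, deg=1):
    # Weighted normal equations: excluded points get weight 0, so the shapes
    # of A, b, and the solution never depend on how many points are used.
    w = mask.astype(x.dtype)
    V = jnp.vander(x, deg + 1)      # (n, deg+1) Vandermonde matrix
    A = (V * w[:, None]).T @ V      # (deg+1, deg+1)
    b = (V * w[:, None]).T @ y      # (deg+1,)
    return jnp.linalg.solve(A, b)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2 * x + 1
mask = jnp.array([True, True, True, False])  # fit only the first 3 points
print(masked_polyfit(x, y, mask))  # ~[2., 1.]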

We can implement a patch in desc.backend to work around JAX not being able to do 1 and 2. For example, the code below is a differentiable version of jnp.nonzero:

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def expensive_function(x):
    # pretend this is some non-trivial operation
    return jnp.dot(x, x) * x

grad_expensive_function = jnp.vectorize(jax.grad(expensive_function))

def nonzero_where(x):
    dummy = -1.0
    # np.nonzero requires concrete values, so x cannot be a traced array here.
    idx = np.nonzero(x)[0]
    y = jnp.full(x.shape, fill_value=dummy)
    y = y.at[idx].set(expensive_function(x[idx]))
    return y

def forward(x):
    # custom_vjp forward rule: return the primal output and the residual
    # (here just x) needed by the backward rule.
    return nonzero_where(x), x

def backward(x, gradient):
    # Propagate the cotangent through expensive_function on the nonzero
    # entries only; every other entry gets zero gradient.
    idx = np.nonzero(x)[0]
    grad = jnp.zeros(x.shape)
    grad = grad.at[idx].set(grad_expensive_function(x[idx]) * gradient[idx])
    return (grad,)

fun = jax.custom_vjp(nonzero_where)
fun.defvjp(forward, backward)

x = jnp.array([1.0, 0.0, 3.0])
y = fun(x)

example_optimization_gradient = jax.grad(lambda x: fun(x).sum())
np.testing.assert_allclose(example_optimization_gradient(x), [3, 0, 27])

where the intention is to have a function in desc.backend so that we can replace this logic

dummy = -1.0
# this computes the expensive function for each element in x
y = jnp.where(x != 0, expensive_function(x), dummy) 

with

# this applies expensive function on nonzero elements only
y = nonzero_where(x, expensive_function, dummy)
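
A minimal sketch of what such a desc.backend helper could look like, generalizing the snippet above (the signature and the jax.vjp-based backward rule are only a suggestion, not existing DESC code):

import jax
import jax.numpy as jnp
import numpy as np

def nonzero_where(x, f, fill_value=0.0):
    """Apply f only to the nonzero entries of x, differentiably in x."""

    @jax.custom_vjp
    def _apply(x):
        idx = np.nonzero(x)[0]  # concrete indices; x must not be a tracer
        return jnp.full(x.shape, fill_value).at[idx].set(f(x[idx]))

    def _fwd(x):
        return _apply(x), x

    def _bwd(x, cotangent):
        idx = np.nonzero(x)[0]
        # Pull the cotangent back through f on the nonzero entries only.
        _, vjp_f = jax.vjp(f, x[idx])
        return (jnp.zeros(x.shape).at[idx].set(vjp_f(cotangent[idx])[0]),)

    _apply.defvjp(_fwd, _bwd)
    return _apply(x)

Like the snippet above, this relies on np.nonzero seeing concrete values, so it cannot be traced by jit or vmap.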

np.nonzero has built-in NumPy vectorization, so we could rely on that if we need to make multiple calls to expensive_function.

However, we can't use JAX vectorization on nonzero_where because it uses a NumPy function. Hence, to vectorize, one is limited to Python for loops or list comprehensions (which use a C loop, but are still not ideal) instead of vmap, map, and scan. If the loop size is not large, this may not be that bad. (We already use list comprehensions in other parts of the code that run when we compute things.)
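
For example, a minimal sketch of the loop-based batching this implies, reusing the hypothetical nonzero_where(x, f, fill_value) helper and the expensive_function and dummy from above:

x2d = jnp.array([[1.0, 0.0, 3.0],
                 [0.0, 2.0, 0.0]])
# Python-level batching: vmap cannot trace through np.nonzero, so we loop
# over the leading axis with a comprehension instead.
y2d = jnp.stack([nonzero_where(row, expensive_function, dummy) for row in x2d])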

unalmis changed the title from "Differentiable code with dynamic shapes" to "patch for differentiable code with dynamic shapes" on Oct 15, 2024
unalmis changed the title from "patch for differentiable code with dynamic shapes" to "Patch for differentiable code with dynamic shapes" on Oct 15, 2024
unalmis (Collaborator, Author) commented Oct 15, 2024

For bounce integrals, the amount of computation is proportional to the number of nonzero elements in an array (e.g. the number of wells). This isn't solved by the num_well parameter, since there is a different number of wells for each pitch angle. Right now we do root-finding, interpolation, and quadratures as if each pitch had the maximum number of wells, and the JIT compiler does not optimize this out. See test_bounce1d_checks.
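
As a toy illustration of that padding issue (the numbers here are hypothetical), each pitch is padded to the maximum number of wells, and the per-well computation runs on the padded slots as well:

import jax.numpy as jnp

# Wells per pitch angle, padded with zeros up to the maximum well count.
padded_wells = jnp.array([[0.1, 0.2, 0.0, 0.0],
                          [0.3, 0.0, 0.0, 0.0],
                          [0.4, 0.5, 0.6, 0.0]])
mask = padded_wells != 0.0
# The "quadrature" (here jnp.sin) is evaluated for every padded slot too,
# because the shapes must be static; the mask only zeros out the result.
per_pitch = jnp.sum(jnp.where(mask, jnp.sin(padded_wells), 0.0), axis=-1)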

unalmis added the speed (New feature or request to make the code faster) label on Oct 16, 2024
YigitElma (Collaborator) commented Oct 16, 2024

It will probably be slow on GPU for large problems, but this can also be considered:

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax.lax import cond, fori_loop

@jax.jit
def expensive_function(x):
    # pretend this is some non-trivial operation
    return jnp.dot(x, x) * x

@jax.jit
def compute_nonzero_only(i, args):
    x, out = args

    def falseFun(args):
        _, out = args
        return out

    def trueFun(args):
        xi, out = args
        out = out.at[i].set(expensive_function(xi))
        return out

    # cond evaluates only the selected branch, so the expensive function
    # runs solely where x[i] > 0
    out = cond(x[i] > 0, trueFun, falseFun, (x[i], out))
    return (x, out)

@jax.jit
def fun(x):
    out = jnp.zeros(x.shape)
    _, out = fori_loop(0, x.shape[0], compute_nonzero_only, (x, out))
    return out

x = jnp.array([1.0, 0.0, 3.0])
example_optimization_gradient = jit(jax.grad(lambda x: fun(x).sum()))
np.testing.assert_allclose(example_optimization_gradient(x), [3, 0, 27])

It is jittable and doesn't require a custom derivative. You can try it with your problem and do some profiling. But this won't work properly if put inside vmap, since a batched cond is lowered to a select and both branches get evaluated: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html

dpanici (Collaborator) commented Oct 16, 2024

^ Could use a method similar to the above, plus a chunked-vmap type of vectorization, to improve speed.
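
A minimal sketch of the chunking idea (generic, not an existing DESC helper; the names are illustrative): split the batch into chunks, vmap within each chunk, and loop over the chunks in Python so peak memory scales with the chunk size rather than the full batch.

import jax
import jax.numpy as jnp
import numpy as np

def chunked_vmap(f, x, chunk_size):
    # vmap over each chunk of rows; the Python loop over chunks keeps peak
    # memory proportional to chunk_size instead of the full batch size.
    chunks = [x[i:i + chunk_size] for i in range(0, x.shape[0], chunk_size)]
    return jnp.concatenate([jax.vmap(f)(c) for c in chunks])

x2d = jnp.arange(12.0).reshape(6, 2)
out = chunked_vmap(lambda row: row ** 2, x2d, chunk_size=2)
np.testing.assert_allclose(out, x2d ** 2)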

unalmis added the AD (related to automatic differentiation) label on Oct 19, 2024
unalmis (Collaborator, Author) commented Oct 20, 2024

Some comments now that I have analyzed the memory effect of this.

Regarding the bounce integrals: for Bounce1D, the performance profiling I discuss here still holds on my machine. There I mention that both the memory and the speed of the computation under JIT on CPU were independent of what I specified for num_well (and num_pitch). The transforms were a bottleneck that was masking the effect.

In Bounce2D, implemented in #1119, we don't create high-resolution 3D transforms. This now allows us to confirm that there is a linear relationship between the upper bound on the number of wells and both speed and memory.

If we want more efficient automatically differentiable bounce integrals, then the options are:

  1. Implement the nonzero_where patch above.
  2. Use PyTorch.
  3. Something like using a compiled language and hooking it in with custom derivatives through the JAX FFI.
  4. Simulate jagged arrays with pytrees (i.e. tuples of lists); this is compatible with jax.vmap (see the sketch at the end of this comment).

Regarding 2 and 3: now that the infrastructure, API, and testing to do this in DESC are done, it's simple to rewrite in PyTorch; just replace calls to jnp with torch. Likewise, I believe the documentation quality makes it easy to rewrite DESC's bounce integrals in a compiled language. However, figuring out how to interface PyTorch or the FFI with our current AD machinery would take time and effort.

If there's a cleaner way to do option 1 with our AD machinery while not sacrificing GPU performance, that may be better. I have opened a question in JAX's discussions.
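
Regarding option 4, here is a minimal sketch of the pytree idea (the data and function names are only illustrative): store one array per field line or pitch angle in a plain tuple and process it leaf by leaf.

import jax
import jax.numpy as jnp

# A "jagged array" as a pytree: one 1D array per pitch angle, each with a
# different number of wells.
wells = (jnp.array([0.1, 0.2]), jnp.array([0.3]), jnp.array([0.4, 0.5, 0.6]))

def per_well(w):
    # pretend this is the per-well quadrature
    return jnp.sum(jnp.sin(w))

# tree_map applies the computation leaf by leaf, and JAX can differentiate
# with respect to the whole pytree of leaves.
totals = jax.tree_util.tree_map(per_well, wells)

grad_fn = jax.grad(lambda ws: sum(jax.tree_util.tree_map(per_well, ws)))
grads = grad_fn(wells)  # pytree with the same jagged structure as `wells`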
