Clarification re: supported data types in jax.linearize and jax.linear_transpose #25517

Open
johannahaffner opened this issue Dec 16, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

@johannahaffner

Hi All,

IIUC, jax.linearize and jax.linear_transpose do not support taking derivatives with respect to integer values, i.e. there is no forward-mode analogue of jax.grad(..., allow_int=True). In JAX forward-mode autodiff, integers are treated as constants.
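For context, here is a minimal sketch of how integer inputs behave in reverse mode with allow_int=True and in forward mode with jax.jvp. The function fn below is illustrative only (it multiplies by a float so that the output is real-valued, which jax.grad requires), and the float0 tangent is built with NumPy on the assumption that this is the expected way to pair a tangent with an integer primal.

```python
import jax
import numpy as np

def fn(x):
    # Illustrative function: the float factor makes the output real-valued.
    return x * 2.0

# Reverse mode: allow_int=True accepts the integer input and returns a
# gradient with the trivial float0 dtype instead of a numeric derivative.
grad_int = jax.grad(fn, allow_int=True)(3)

# Forward mode: the tangent paired with an integer primal has dtype float0,
# so the integer contributes nothing to the output tangent.
tangent_in = np.zeros((), dtype=jax.dtypes.float0)
primal_out, tangent_out = jax.jvp(fn, (3,), (tangent_in,))
```

In both modes the integer input carries no derivative information; the difference is that jax.grad has an explicit opt-in flag for this case.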

Transposing the linearised function at an integer input raises the following error:

ValueError: linearized function called on tangent values inconsistent with the original primal values: got ShapedArray(int32[], weak_type=True) for primal aval ShapedArray(int32[], weak_type=True)

Could we make that error message more informative by special-casing integer inputs? As written, it reports the same aval for the tangent and the primal, so it does not say what is actually wrong. I'm happy to do it and open a PR.

import jax

def fn(x):
    return x**2

y_float = 3.0
y_int = 3

fn_eval_float, lin_fn_float = jax.linearize(fn, y_float)
fn_eval_int, lin_fn_int = jax.linearize(fn, y_int)

grad_float = jax.linear_transpose(lin_fn_float, y_float)(1.0)  # Works fine
grad_int = jax.linear_transpose(lin_fn_int, y_int)(1.0)  # Raises the ValueError above
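To make the request concrete, here is a rough user-level sketch of the kind of special-casing I have in mind. The helper check_inexact_primals and its message wording are hypothetical, not part of JAX, and the real fix would presumably live inside JAX's own checks rather than in a wrapper like this.

```python
import jax
import jax.numpy as jnp

def check_inexact_primals(*primals):
    """Hypothetical helper (not a JAX API): reject non-float primals early."""
    for leaf in jax.tree_util.tree_leaves(primals):
        dtype = jnp.result_type(leaf)
        if not jnp.issubdtype(dtype, jnp.inexact):
            raise TypeError(
                f"Expected float or complex primal values, got dtype {dtype}. "
                "Integer inputs are treated as constants in forward-mode "
                "autodiff; consider casting the input to a float dtype."
            )

def fn(x):
    return x**2

# Float primal: the check passes and the transpose proceeds as usual.
check_inexact_primals(3.0)
_, lin_fn_float = jax.linearize(fn, 3.0)
(grad_float,) = jax.linear_transpose(lin_fn_float, 3.0)(1.0)

# Integer primal: the check fails with a message that names the dtype.
try:
    check_inexact_primals(3)
except TypeError as err:
    message = str(err)
```

Note that jax.linear_transpose returns a tuple with one cotangent per primal, hence the unpacking of grad_float.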

This issue came up downstream.

@johannahaffner johannahaffner added the enhancement New feature or request label Dec 16, 2024