Clarification re: supported data types in `jax.linearize` and `jax.linear_transpose`
#25517
Labels: enhancement (New feature or request)
Hi All,
IIUC, `jax.linearize` and `jax.linear_transpose` do not support taking derivatives with respect to integer values, i.e. there is no forward-mode analogue of `jax.grad(..., allow_int=True)`. In JAX forward-mode autodiff, integers are treated as constants. When transposing the linearised function taking an integer input, the following error is raised:
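For context, here is a minimal sketch of the current behaviour, assuming a recent JAX release (the function `f` and the input values are hypothetical, chosen only for illustration). It contrasts reverse mode with `allow_int=True`, where the integer input receives a gradient of the trivial dtype `float0`, with forward mode, where the integer must either be given a `float0` tangent in `jax.jvp` or be closed over as a constant before calling `jax.linearize` and `jax.linear_transpose`:

```python
import numpy as np
import jax
import jax.numpy as jnp


def f(x, n):
    # Hypothetical example: x is a float input, n is an integer input.
    return x * n


x = jnp.float32(3.0)
n = jnp.int32(2)

# Reverse mode: allow_int=True accepts the integer input. Its gradient has the
# trivial dtype float0, while x gets an ordinary float gradient (equal to n).
gx, gn = jax.grad(f, argnums=(0, 1), allow_int=True)(x, n)

# Forward mode: the integer input is a constant, so its tangent must be float0.
y, y_dot = jax.jvp(
    f, (x, n), (jnp.float32(1.0), np.zeros((), dtype=jax.dtypes.float0))
)

# Workaround: linearize and transpose only over the float input by closing
# over the integer, so no integer ever appears as a primal.
y_lin, f_jvp = jax.linearize(lambda x: f(x, n), x)
f_t = jax.linear_transpose(f_jvp, x)
(x_bar,) = f_t(jnp.float32(1.0))

print(gx, gn.dtype)
print(y, y_dot)
print(f_jvp(jnp.float32(1.0)), x_bar)
```

Closing over `n` sidesteps the integer input entirely; the case this issue is about is passing the integer as a primal to the linearised function and then transposing it.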
Could we change that error message to something more informative by special-casing integer inputs? I'm happy to do it and open a PR.
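As a rough, user-side sketch of what special-casing integer inputs could look like, one can inspect the primal dtypes and raise a clearer error before delegating to `jax.linear_transpose`. The wrapper `linear_transpose_checked` and its message are hypothetical and not part of the JAX API; an actual fix would presumably live inside JAX itself:

```python
import jax
import jax.numpy as jnp


def linear_transpose_checked(fun, *primals):
    """Hypothetical wrapper: reject integer/bool primals with a clearer error."""
    for i, leaf in enumerate(jax.tree_util.tree_leaves(primals)):
        dtype = jnp.asarray(leaf).dtype
        if jnp.issubdtype(dtype, jnp.integer) or jnp.issubdtype(dtype, jnp.bool_):
            raise TypeError(
                f"linear_transpose_checked: primal leaf {i} has dtype {dtype}. "
                "Integer and boolean inputs are treated as constants in "
                "forward-mode autodiff (their tangent dtype is float0), so they "
                "cannot be transposed; close over them instead of passing them "
                "as primals."
            )
    # All primals are float/complex, so defer to JAX's own implementation.
    return jax.linear_transpose(fun, *primals)


# Usage: transpose a linearised function of a float input only.
n = jnp.int32(2)
x = jnp.float32(3.0)
_, f_jvp = jax.linearize(lambda x: x * n, x)
(ct,) = linear_transpose_checked(f_jvp, x)(jnp.float32(1.0))
print(ct)
```

Checking dtypes up front means the user sees which input is the problem and why, instead of a generic failure from inside the transposition machinery.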
This issue came up downstream.