Delegate to specialized function on CPU #24584
-
Hi, I have a function … For the sake of the example, we can assume that I have two implementations … Note: I do not care about gradients here. Thanks!
-
Unfortunately we don't have a public API for this, though we've discussed in the past the possibility of providing one.

The mechanism we do have is to define a primitive and register device-specific lowering rules for it. This is a bit involved and requires some non-public APIs, but a simple example in JAX's codebase is in jax/jax/experimental/sparse/linalg.py (lines 592 to 598 at commit 7c4cc95). Note that the primitive defined there has different lowerings on CPU and on GPU.

You can find more discussion of implementing custom primitives at https://jax.readthedocs.io/en/latest/jax-primitives.html
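A minimal sketch of the primitive-plus-lowering approach, under some assumptions. The names `softplus_p`, `generic_softplus` and `cpu_softplus` are hypothetical stand-ins for your two implementations. It uses `jax.interpreters.mlir.register_lowering` and `mlir.lower_fun`, which are not part of JAX's stable public API. The import location of `Primitive` also varies between JAX versions. Routing eager calls through `jax.jit` in `def_impl` is one possible choice and is not necessarily how `linalg.py` does it.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir  # semi-internal API, may change between releases

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive


def generic_softplus(x):
    # Hypothetical portable implementation, used on every backend by default.
    return jnp.logaddexp(x, 0.0)


def cpu_softplus(x):
    # Hypothetical CPU-specialized implementation of the same function.
    return jnp.maximum(x, 0.0) + jnp.log1p(jnp.exp(-jnp.abs(x)))


softplus_p = Primitive("softplus")

# Abstract evaluation: the output has the same shape and dtype as the input.
softplus_p.def_abstract_eval(lambda x: x)

# Generic lowering (no platform given), used when no platform-specific rule applies.
mlir.register_lowering(
    softplus_p, mlir.lower_fun(generic_softplus, multiple_results=False))

# CPU-specific lowering, which takes precedence when compiling for CPU.
mlir.register_lowering(
    softplus_p, mlir.lower_fun(cpu_softplus, multiple_results=False),
    platform="cpu")

# Eager (non-jit) calls are routed through jit so the lowering rules above are used.
_jitted_bind = jax.jit(lambda x: softplus_p.bind(x))
softplus_p.def_impl(lambda x: _jitted_bind(x))


def softplus(x):
    """User-facing wrapper around the primitive."""
    return softplus_p.bind(x)
```

Calling `softplus(x)` or `jax.jit(softplus)(x)` on a CPU device should then compile `cpu_softplus`, while other backends fall back to `generic_softplus`. Since gradients aren't needed here, the sketch defines no JVP or transpose rules, so differentiating through `softplus` would fail.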
It sounds to me like `jax.lax.platform_dependent` might actually do what you need here. For example, the code from the original post might look like the following:
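As a separate, self-contained illustration (not the code from the original post), here is a minimal sketch of how `jax.lax.platform_dependent` selects a branch by platform name. The functions `generic_impl` and `cpu_impl` are hypothetical stand-ins for the two implementations. The branch functions receive the same positional arguments and are expected to return outputs of matching shapes and dtypes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def generic_impl(x):
    # Hypothetical portable implementation, used on every non-CPU backend.
    return jnp.logaddexp(x, 0.0)


def cpu_impl(x):
    # Hypothetical CPU-specialized implementation of the same function.
    return jnp.maximum(x, 0.0) + jnp.log1p(jnp.exp(-jnp.abs(x)))


@jax.jit
def f(x):
    # Branches are keyed by platform name; `default` covers any other platform.
    return lax.platform_dependent(x, cpu=cpu_impl, default=generic_impl)
```

Compared with a custom primitive, this needs no internal APIs. As we understand the documentation, every branch is traced, but the choice is made based on the platform the computation is compiled for, so only the matching implementation actually runs.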