
Delegate to specialized function on CPU #24584

Answered by dfm
AdrienCorenflos asked this question in Q&A

It sounds to me like jax.lax.platform_dependent might do what you need here. For example, the code from the original post could be written as:

import jax

def func(a):
  # Calls func_cpu when lowered for CPU and func_gpu when lowered for CUDA GPUs
  return jax.lax.platform_dependent(a, cpu=func_cpu, cuda=func_gpu)
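A fuller, runnable sketch of the same idea follows. Here `func_cpu` and `func_gpu` are hypothetical stand-ins (the same prefix sum written two ways) for the specialized implementations in the original question, and the `default=` fallback is an addition so the snippet also runs on platforms that have no branch of their own, such as TPU. As with `jax.lax.switch`, the branches presumably need to return outputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical CPU implementation: a plain cumulative sum.
def func_cpu(a):
  return jnp.cumsum(a)

# Hypothetical GPU implementation: the same prefix sum as a parallel scan.
def func_gpu(a):
  return jax.lax.associative_scan(jnp.add, a)

@jax.jit
def func(a):
  # Pick the branch matching the platform the computation is lowered for;
  # `default` covers any platform without its own keyword argument.
  return jax.lax.platform_dependent(
      a, cpu=func_cpu, cuda=func_gpu, default=func_gpu)

print(func(jnp.arange(5.0)))
```

Without `default`, my understanding is that lowering for a platform with no matching keyword raises an error, so the fallback is worth adding if the code may run on TPU or ROCm.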

@dfm

dfm Oct 30, 2024
Collaborator

Answer selected by AdrienCorenflos
@jakevdp

@AdrienCorenflos

@AdrienCorenflos

Category
Q&A
3 participants