TPU hangs if using random inside a pure_callback call #24720
Comments
Fundamentally, I think the issue here is that you're misusing pure_callback. Does the issue persist if you change your callback function to be pure? And as a follow-up: if your callback is pure, is there any reason to use a callback, or could you just execute the code as part of the normal program flow?
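To illustrate what "pure" means here, below is a minimal sketch (not the reporter's code, which isn't shown in this thread) of a callback whose randomness depends only on an explicit seed argument, next to the same idea written in normal program flow with `jax.random`. The names `host_noise`, `with_callback`, and `without_callback` are hypothetical, and nothing in the thread confirms that this pattern avoids the TPU hang.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_noise(seed, x):
    # Pure host function: the output depends only on (seed, x).
    # No global RNG state, no I/O, and no exceptions used for control flow.
    rng = np.random.default_rng(int(seed))
    return (x + rng.standard_normal(x.shape)).astype(x.dtype)


@jax.jit
def with_callback(seed, x):
    # Declare the shape/dtype of the callback's result for tracing.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_noise, out_type, seed, x)


@jax.jit
def without_callback(key, x):
    # The same idea in normal program flow: no callback needed at all.
    return x + jax.random.normal(key, x.shape, x.dtype)


x = jnp.zeros(3, dtype=jnp.float32)
a = with_callback(jnp.int32(0), x)
b = with_callback(jnp.int32(0), x)  # same inputs, so the same outputs
c = without_callback(jax.random.PRNGKey(0), x)
```

Because `host_noise` is deterministic in its inputs, JAX is free to cache, reorder, or elide the call without changing results, which is the contract `pure_callback` assumes.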
Another issue with your code: your approach of setting … I would suggest switching to one of the JIT compilation strategies mentioned in that FAQ entry.
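Assuming "that FAQ entry" is the JAX FAQ section on using `jit` with methods, one of its strategies is to mark `self` as static and give the class `__hash__` and `__eq__` based on its attributes, so that changing an attribute triggers recompilation instead of silently reusing a stale trace. The class `Scaler` below is a hypothetical sketch of that strategy; its attributes must be hashable Python values, not arrays.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scaler:  # hypothetical example class
    def __init__(self, factor, use_bias):
        self.factor = factor      # plain Python float (hashable)
        self.use_bias = use_bias  # plain Python bool (hashable)

    # self is a static argument: its hash/equality form part of the jit cache key.
    @partial(jax.jit, static_argnums=0)
    def apply(self, x):
        y = x * self.factor
        return y + 1.0 if self.use_bias else y

    def _key(self):
        return (self.factor, self.use_bias)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, Scaler) and self._key() == other._key()


s = Scaler(2.0, True)
first = s.apply(jnp.arange(3.0))   # [1., 3., 5.]
s.use_bias = False                 # hash changes, so apply() is retraced
second = s.apply(jnp.arange(3.0))  # [0., 2., 4.]
```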
Thanks for the quick feedback @jakevdp! I will try to adapt my code; perhaps making the class a pytree is the best way. Let me give you a little more background. I am using Jetstream Pytorch, which has a JIT'ed …
OK, I rewrote the example code using pytrees and it worked, as you suggested. I will try to make it work with the base code I am working on. Thank you for your answer.
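The pytree approach mentioned above can be sketched as follows. This is an illustrative `Model` class, not the code from the thread: array attributes become pytree children (traced by `jit`, so new values don't force recompilation), while configuration goes into hashable auxiliary data (part of the cache key). It uses `jax.tree_util.register_pytree_node_class`, which expects `tree_flatten` and a `tree_unflatten` classmethod.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Model:  # hypothetical class for illustration
    def __init__(self, weights, use_bias):
        self.weights = weights    # array leaf: traced by jit
        self.use_bias = use_bias  # static aux data: must be hashable

    def tree_flatten(self):
        children = (self.weights,)
        aux_data = (self.use_bias,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    # self is flattened as a pytree argument, so no static_argnums is needed.
    @jax.jit
    def apply(self, x):
        y = x @ self.weights
        return y + 1.0 if self.use_bias else y


x = jnp.array([1.0, 2.0])
out1 = Model(jnp.eye(2), True).apply(x)       # [2., 3.]
out2 = Model(2 * jnp.eye(2), False).apply(x)  # [2., 4.]
```

Compared with a static `self`, this keeps large arrays out of the hash and lets the same compiled function run with different weights.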
Precisely – and raising any sort of error from Python is a side-effect that should not be part of a pure callback, and has undefined behavior. |
Description
I can reproduce this on a v5e-litepod8 TPU. It seems that random calls hang when made from inside a pure_callback. I have not found a workaround so far. Here's a snippet to reproduce the issue:
System info (python version, jaxlib version, accelerator, etc.)