v0.4.11.5 Breaks Pycharm Debugger Console #827
Comments
Ah, probably we just need to add that method and forward the method call on. I'll write a quick PR for that. Thanks for the report!
Alright, I've opened #828 that should fix this.
This seems to fix the problem. However, I now get warnings: "UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do." In the case it is warning about, it isn't "wrong". I presume you added this in the latest update.
Yup! We've added a warning for this -- not yet released. If you have a use-case for static arrays, I'd be curious to hear it. Also tagging @lockwo on this.
One example is that I am using torch2jax (https://github.com/rdyro/torch2jax) to call a complex PyTorch function via torch2jax_with_vjp. I wrap the function inside an Equinox module, as some of the inputs do not change each time the function is called (and it's buried deep inside the code tree, so passing them each time would be a pain). Thus it's easier to create the module, pass the JAX arrays in the init method, and store them as static JAX arrays.
I haven't used torch2jax before, so I can't comment on the specifics (I have used the new keras 3 interface and gotten old PyTorch code to interface with equinox, idk how similar that is). But what is unclear to me is what marking them as static here is doing?
I interpret this as meaning the inputs do not change as user parameters, but does this also mean they do not change as a result of optimization (i.e. they are not part of the gradient tree)? If they are not part of the gradient flow, then you could just not mark them as static and they shouldn't be differentiated. If that isn't a concern, and they are just "static" in the sense that they are user inputs that are unchanging, then it doesn't seem essential to mark them as actual static fields (unless I am missing some pytree manipulation or something else is going on). If you are confident this static approach is correct though, you can ignore the warning (after all, that's what warnings are for).
It is true I could pass the parameters in as a static, but it is quite annoying to then have to pass them down through a hierarchy of classes. I find it easier and neater just to initialise the class and store them as a static parameter. Maybe this is bad programming practice; it is just that in some of my models I can have a significant number of such things (and because I do research these can be in flux) that need to go to different parts of the model / loss functions, and so it quickly becomes quite messy to keep passing them about.
I think I get the motivation for having a class, but my question was more: why is it necessary to make the arrays static fields (as opposed to just member variables)? Is it for gradient reasons? Pytree manipulation reasons?
There is no gradient or pytree manipulation required, but if I don't set them to static they will be treated as parameters that should be optimised.
In general, if they shouldn't be optimized there are often two cases: 1) they don't actually have a gradient, in which case they won't be optimized even if you "let" them (since the gradients will always be 0/None); 2) they do have a gradient, but you don't want to propagate it, in which case
My particular setup is tricky. I might be able to use stop gradient if I am careful.
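The "stop gradient" mentioned above presumably refers to `jax.lax.stop_gradient`; that reading is an assumption, since the preceding reply is cut off. A minimal sketch of the idea, with hypothetical parameter names and a toy loss:

```python
import jax
import jax.numpy as jnp


def loss(params, x):
    w = params["w"]
    # Block gradient flow into `const`; its value is still used in the forward pass.
    c = jax.lax.stop_gradient(params["const"])
    return jnp.sum((w * x + c) ** 2)


def loss_plain(params, x):
    # Same computation without stop_gradient, for comparison.
    return jnp.sum((params["w"] * x + params["const"]) ** 2)


params = {"w": jnp.array([1.0, 2.0, 3.0]), "const": jnp.full(3, 0.5)}
x = jnp.ones(3)

grads = jax.grad(loss)(params, x)
grads_plain = jax.grad(loss_plain)(params, x)
# grads["const"] is all zeros; grads_plain["const"] is not.
```

With a zero gradient, a plain gradient-descent step leaves `const` unchanged, although optimisers that apply weight decay or similar terms could still modify it.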
It appears that v0.4.11.5 introduces an issue with the PyCharm debugging console. Normally you can set a breakpoint, the debugger stops there, and you can interact in a console, e.g. ask for shape information, print values, etc. With 0.4.11.5 (I have tested downgrading to 0.4.11.4, and everything works fine), if you set a breakpoint inside a JIT'ed function inside an Equinox Module, then when it is hit, typing anything in the console now gives the following error.
I believe the issue might well be related to the introduction of the _FilteredStderr class in the latest version, which might be clashing with the PyCharm debugger.
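To illustrate the kind of clash suspected here, and the "forward the method call on" fix suggested in the comments, below is a minimal sketch of a stream wrapper that filters writes and forwards everything else to the wrapped stream. `FilteredStream` is a hypothetical stand-in, not Equinox's actual `_FilteredStderr`, and the thread does not say which method was missing, so the generic `__getattr__` forwarding is an assumption rather than what #828 does.

```python
import io


class FilteredStream:
    """Hypothetical stderr-style wrapper; not Equinox's actual implementation."""

    def __init__(self, stream, blocked):
        self._stream = stream
        self._blocked = blocked  # substrings whose messages are dropped

    def write(self, data):
        # Drop messages containing a blocked substring; pass everything else on.
        if any(s in data for s in self._blocked):
            return len(data)
        return self._stream.write(data)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so flush(), isatty(), and any
        # other attribute a debugger console expects reach the real stream.
        return getattr(self._stream, name)


underlying = io.StringIO()
wrapped = FilteredStream(underlying, blocked=("noisy warning",))
wrapped.write("noisy warning: ignore me\n")
wrapped.write("real error\n")
wrapped.flush()  # forwarded through __getattr__
```

Without the forwarding, any tool that calls a method the wrapper does not define would raise `AttributeError`, which is consistent with the symptom reported above.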