-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Using a lax.scan to run the solver #11
base: master
Are you sure you want to change the base?
Conversation
And actually ^^ it's generally not a good idea ^^' but if we want to, we can definitely write a custom CPU op that will dump the simulation in hdf5 from within jitted code, and from within the lax.scan. In this particular instance, I think it would be pretty cool |
So yeah I don't see any drawbacks of using a scan :-) |
Thanks! This was very much how it was done here. Also here for the adjoint. So it's good to know that XLA or JAX has gotten better on this.
I guess you meant nested scan's. We want interpolation between two steps. It looks like odeint is extrapolating from the last step? But interpolation should also be okay with nested scan's. |
In odeint they use a while inside the scan function yes. Would you be ok with an API with an argument which would be the array And then, I think it would be very cool to have the ability to do IO directly from jitted code :-) And I think I know how to do it, but probably that's for a different PR. |
418337c
to
a5329ae
Compare
Let's try switching to scan following the odeint way, once the checkpoint (exactly at a time step, directly copying disp and vel) and snapshot (interpolation between 2 steps) observables are implemented. @Yucheng-Zhang is working on those observables. Yes, it'd be super cool to have a custom IO op ^^ |
|
be9f8a4
to
52ec0b4
Compare
3e1b213
to
c2f0c24
Compare
This draft PR is in response to #9 and presents a prototype implementation of the leap-frog solver that uses a
lax.scan
instead of a for loop in thenbody
function.Here are the results on the baseline default configuration.
Current master
This PR (using scan, and I actually removed all lower level jit)
And here the notebook to reproduce this test (working off my fork):
https://gist.github.com/EiffL/aa6a651141f694ca257fb5ff83e829d6
So I would advocate using
lax.scan
.In this draft implementation, I chose not to output intermediate ptcl and obsvl, exactly like what is done on master, but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)
If you look at the implementation of odeint in jax, you can also have a slightly more complicated logic that exports the state of the system only at some desired pre-defined steps, and not necessarily at all time steps:
https://github.com/google/jax/blob/518fe6656ca2aab66dcfc8cd7866c10f476a17b1/jax/experimental/ode.py#L189
And finally, if you want to save the sims to disk, then nothing prevents you from using the nbody step function directly/manually in a for loop.