DiffTraj supports computing the gradient of a loss function defined on a trajectory. It first runs an NVE simulation (leap-frog Verlet) and then computes the gradient from the resulting trajectory.
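The NVE simulation uses the leap-frog Verlet integration method, as in OpenMM. The positions and velocities stored in the context are offset from each other by half a time step. In each step, they are updated as follows:

$$
\mathbf{v}_i(t + \Delta t/2) = \mathbf{v}_i(t - \Delta t/2) + \mathbf{f}_i(t)\,\Delta t / m_i
$$

$$
\mathbf{r}_i(t + \Delta t) = \mathbf{r}_i(t) + \mathbf{v}_i(t + \Delta t/2)\,\Delta t
$$

where $\mathbf{v}_i$ is the velocity of particle $i$, $\mathbf{r}_i$ is its position, $\mathbf{f}_i$ is the force acting on it, $m_i$ is its mass, and $\Delta t$ is the time step.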
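As a rough illustration of the update rule described above, a single leap-frog step could be sketched in JAX as follows. The harmonic `potential`, the spring constant `k`, and the function name `leapfrog_step` are assumptions made for this example only; they are not part of DiffTraj.

```python
import jax
import jax.numpy as jnp

def potential(pos, k):
    # Hypothetical harmonic potential standing in for a real force field.
    return 0.5 * k * jnp.sum(pos ** 2)

def leapfrog_step(pos, vel_half, mass, dt, k):
    # The force at time t is the negative gradient of the potential.
    force = -jax.grad(potential)(pos, k)
    # v(t + dt/2) = v(t - dt/2) + f(t) * dt / m
    vel_half = vel_half + force * dt / mass[:, None]
    # r(t + dt) = r(t) + v(t + dt/2) * dt
    pos = pos + vel_half * dt
    return pos, vel_half

# One particle starting at x = 1 with zero half-step velocity.
pos = jnp.array([[1.0, 0.0, 0.0]])
vel_half = jnp.zeros((1, 3))
mass = jnp.array([1.0])
new_pos, new_vel_half = leapfrog_step(pos, vel_half, mass, 0.01, 1.0)
```

Note that the velocity is updated first and the new half-step velocity is then used to advance the position, which is what keeps the two quantities offset by half a time step.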
Naively applying JAX automatic differentiation (`jax.grad`) to the whole trajectory may cause out-of-memory (OOM) errors. DiffTraj instead uses the adjoint method: it runs a reverse calculation that exploits the time reversibility of the NVE integrator to accumulate the gradient. The loss function is,
where
When the Loss function is a function of trajectory, the calculation follows:
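Independently of the formulas, the time reversibility that the adjoint pass relies on can be checked numerically. The sketch below is not DiffTraj's implementation; the anharmonic `potential` and the helper names `step_forward` and `step_backward` are hypothetical. It only shows that a leap-frog step has an exact algebraic inverse, so earlier states can be recomputed during a backward pass instead of being stored.

```python
import jax
import jax.numpy as jnp

def potential(pos):
    # Hypothetical anharmonic potential, used only for this illustration.
    return jnp.sum(0.5 * pos ** 2 + 0.1 * pos ** 4)

def force_fn(pos):
    return -jax.grad(potential)(pos)

def step_forward(state, dt, mass):
    pos, vel_half = state
    vel_half = vel_half + force_fn(pos) * dt / mass
    pos = pos + vel_half * dt
    return pos, vel_half

def step_backward(state, dt, mass):
    # Algebraic inverse of step_forward: undo the position update,
    # then undo the velocity update using the recovered position.
    pos, vel_half = state
    pos = pos - vel_half * dt
    vel_half = vel_half - force_fn(pos) * dt / mass
    return pos, vel_half

dt, mass, nsteps = 0.01, 1.0, 200
state0 = (jnp.array([1.0, -0.5, 0.3]), jnp.array([0.0, 0.2, -0.1]))

# Integrate forward, then walk the same number of steps backward.
state_T = jax.lax.fori_loop(
    0, nsteps, lambda i, s: step_forward(s, dt, mass), state0)
state_back = jax.lax.fori_loop(
    0, nsteps, lambda i, s: step_backward(s, dt, mass), state_T)
```

Up to floating-point rounding, `state_back` matches `state0`, which is the property that makes a memory-light reverse calculation possible.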
Class `Loss_Generator`:
- Sets the conditions of the simulation.
- Contains the leap-frog Verlet integration method.
Function `ode_fwd`:
- Runs the NVE simulation.
- Gets the trajectory.
Function `generate_Loss`:
- Generates the loss function.
This section explains how to use `Loss_Generator` to obtain gradients and walks through an example of using the module.
The module computes a scalar loss from a trajectory, together with its gradient with respect to both the initial state and the parameters. It relies on two user-defined functions: `f_nout`, which saves any properties from the trajectory, and `L`, which defines a scalar loss from those properties. First, define the initial conditions:
- Initialization: create an instance of `Loss_Generator`. Here the user-defined `f_nout` is a function of the state, and its result is saved.
Generator = Loss_Generator(f_nout, box, init_state['pos'][0], mass, dt, nsteps, nout, cov_map, rc, efunc)
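A user-defined `f_nout` might look like the minimal sketch below. It assumes that the state is a dictionary with a `'pos'` entry of shape `(n_atoms, 3)`; this layout is only a guess based on the `init_state['pos']` access above, so adapt it to the actual state structure.

```python
import jax.numpy as jnp

def f_nout(state):
    # Assumption: state is a dict whose 'pos' entry has shape (n_atoms, 3).
    # Keep only the positions of the first two atoms.
    return state['pos'][:2]

# Hypothetical four-atom state, used only to exercise the function.
state = {'pos': jnp.arange(12.0).reshape(4, 3)}
saved = f_nout(state)
```

Returning only the properties the loss needs keeps the saved trajectory small.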
You can use the Generator to run only an NVE simulation, or to run both the NVE simulation and the gradient calculation.
- Run only an NVE simulation. Here `traj` stores the result of `f_nout` every `nout` steps.
final_state, traj = Generator.ode_fwd(initial_state, params)
- Define the loss function and get the gradient. Here the user-defined `L` is a function of `traj` and returns the scalar loss. For example, the input of `L` can be the positions of certain atoms saved in `traj`, and its output can be the mean squared error of those positions with respect to the same atoms in another trajectory.
Loss = Generator.generate_Loss(L, has_aux=True, metadata=metadata)
v, g = value_and_grad(Loss, argnums=(0, 1))(init_state, params)
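The mean-squared-error loss described above could be sketched as follows. The reference array `ref_traj`, its shape `(n_frames, n_atoms, 3)`, and the choice of auxiliary output are assumptions for illustration. The example calls `jax.value_and_grad` directly on a toy trajectory rather than going through `generate_Loss`, since how `generate_Loss` passes `traj` and `metadata` to `L` is not specified here.

```python
import jax
import jax.numpy as jnp

# Hypothetical reference trajectory with shape (n_frames, n_atoms, 3).
ref_traj = jnp.zeros((5, 2, 3))

def L(traj):
    # Mean squared error between the saved positions and the reference.
    mse = jnp.mean((traj - ref_traj) ** 2)
    # Assumption: with has_aux=True, L returns (loss, auxiliary_data).
    return mse, traj

# Toy trajectory standing in for the output of a simulation.
traj = jnp.ones((5, 2, 3))
(loss, aux), grad = jax.value_and_grad(L, has_aux=True)(traj)
```

With `has_aux=True`, `jax.value_and_grad` differentiates only the first element of the returned pair and passes the second through unchanged.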