1. Theory

DiffTraj supports computing the gradient of a loss function defined on a trajectory. It first runs an NVE simulation with the leap-frog Verlet integrator, then computes the gradient from the resulting trajectory.

1.1 NVE simulation

The NVE simulation uses the leap-frog Verlet integration method, as in OpenMM. The positions and velocities stored in the context are offset from each other by half a time step. In each step, they are updated as follows:

$$\begin{aligned} \mathbf{v}_i(t+\Delta t / 2) & =\mathbf{v}_i(t-\Delta t / 2)+\mathbf{f}_i(t) \Delta t / m_i \\\ \mathbf{r}_i(t+\Delta t) & =\mathbf{r}_i(t)+\mathbf{v}_i(t+\Delta t / 2) \Delta t \end{aligned}$$

where $\mathbf{v}_i$ is the velocity of particle $i$, $\mathbf{r}_i$ is its position, $\mathbf{f}_i$ is the force acting on it, obtained by automatic differentiation of the energy, $m_i$ is its mass, and $\Delta t$ is the time step.
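The update rule above can be sketched in a few lines of JAX. This is a minimal illustration, not DiffTraj's own integrator: the harmonic `energy` function, the `params` dictionary and the helper names `force` and `leapfrog_step` are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def energy(pos, params):
    # Hypothetical potential for illustration: E = 0.5 * k * sum(r^2)
    return 0.5 * params["k"] * jnp.sum(pos ** 2)

def force(pos, params):
    # The force is minus the gradient of the energy w.r.t. the positions.
    return -jax.grad(energy)(pos, params)

def leapfrog_step(pos, vel_half, params, mass, dt):
    """Advance r(t), v(t - dt/2) to r(t + dt), v(t + dt/2)."""
    vel_half = vel_half + force(pos, params) * dt / mass  # velocity update
    pos = pos + vel_half * dt                             # position update
    return pos, vel_half

params = {"k": 1.0}
mass = jnp.array([[1.0], [2.0]])  # shape (n_atoms, 1), broadcasts over x, y, z
pos = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.5, 0.0]])
vel_half = jnp.zeros_like(pos)    # v(-dt/2), taken as zero here
dt = 0.01

for _ in range(200):
    pos, vel_half = leapfrog_step(pos, vel_half, params, mass, dt)
```

Note that the velocity is updated first and the new half-step velocity is then used to move the positions, which is what keeps the two quantities offset by half a time step.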

1.2 Gradient calculation

Naively applying automatic differentiation (jax.grad) through the whole trajectory can run out of memory (OOM), because reverse-mode differentiation has to keep every intermediate state for the backward pass. DiffTraj instead uses the adjoint method: it exploits the time reversibility of the NVE integrator to run the calculation in reverse and accumulate the gradient. The loss function is
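The time reversibility mentioned above can be checked directly: a leap-frog step has an exact algebraic inverse, so earlier states can be rebuilt from the final one instead of being stored. The sketch below assumes a hypothetical anharmonic `energy` and the helper names `step_forward` and `step_backward`; it is not taken from the DiffTraj source.

```python
import jax
import jax.numpy as jnp

def energy(pos, k):
    # Hypothetical anharmonic potential, used only for illustration.
    return 0.5 * k * jnp.sum(pos ** 2) + 0.1 * jnp.sum(pos ** 4)

def force(pos, k):
    return -jax.grad(energy)(pos, k)

@jax.jit
def step_forward(pos, vel_half, k, mass, dt):
    vel_half = vel_half + force(pos, k) * dt / mass
    pos = pos + vel_half * dt
    return pos, vel_half

@jax.jit
def step_backward(pos, vel_half, k, mass, dt):
    # Inverse of step_forward: undo the position update, then the velocity update.
    pos = pos - vel_half * dt
    vel_half = vel_half - force(pos, k) * dt / mass
    return pos, vel_half

k, mass, dt, nsteps = 1.0, 1.0, 0.01, 200
pos0 = jnp.array([1.0, -0.5, 0.25])
vel0 = jnp.array([0.0, 0.3, -0.1])

pos, vel = pos0, vel0
for _ in range(nsteps):
    pos, vel = step_forward(pos, vel, k, mass, dt)
for _ in range(nsteps):
    pos, vel = step_backward(pos, vel, k, mass, dt)

# The round trip differs from the start only through floating-point round-off.
max_err = float(jnp.max(jnp.abs(pos - pos0)))
```

Because only the current state is held at any time, the memory cost of the round trip does not grow with the number of steps.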

$$L\left( \mathbf{z}\left( t_1 \right) \right) =L\left( \mathbf{z}\left( t_0 \right) +\int_{t_0}^{t_1}{f\left( \mathbf{z}\left( t \right) ,\ t,\ \theta \right) dt} \right) =L\left( \text{ODESolve}\left( \mathbf{z}\left( t_0 \right) ,\ f,\ t_0,\ t_1,\ \theta \right) \right)$$

where $L$ is the loss function, $\mathbf{z}\left( t_1 \right)$ is the state (velocity and position) at time $t_1$, $f$ is the dynamics function that the integrator advances, and $\theta$ denotes the parameters. The gradient calculation starts from the final state. With the adjoint state defined as $\mathbf{a}\left( t \right) =\frac{\partial L}{\partial \mathbf{z}\left( t \right)}$, the gradient follows from the chain rule:

$$\frac{\partial L}{\partial \mathbf{z}\left( t_0 \right)}=\frac{\partial L}{\partial \mathbf{z}\left( t_1 \right)}\frac{\partial \mathbf{z}\left( t_1 \right)}{\partial \mathbf{z}\left( t_0 \right)}$$ $$\frac{\partial L}{\partial \theta}=\mathbf{a}\left( t_1 \right) ^T\frac{\partial f\left( \mathbf{z}\left( t_0 \right) ,\ \theta \right)}{\partial \theta}$$
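To make the chain rule concrete, the sketch below accumulates the gradient of a terminal loss step by step in reverse, rebuilding each earlier state with the inverse leap-frog update and pulling the adjoint back with `jax.vjp`. It is a discrete illustration under assumed names (`energy`, `step`, `step_inverse`, `loss`, `adjoint_grad`) and a hypothetical one-parameter potential, not DiffTraj's implementation. The result is compared against `jax.grad` through the unrolled trajectory.

```python
import jax
import jax.numpy as jnp

def energy(pos, theta):
    # Hypothetical one-parameter potential: harmonic well with stiffness theta.
    return 0.5 * theta * jnp.sum(pos ** 2)

def force(pos, theta):
    return -jax.grad(energy)(pos, theta)

def step(state, theta, dt=0.01, mass=1.0):
    pos, vel = state
    vel = vel + force(pos, theta) * dt / mass
    return pos + vel * dt, vel

def step_inverse(state, theta, dt=0.01, mass=1.0):
    pos, vel = state
    pos = pos - vel * dt
    return pos, vel - force(pos, theta) * dt / mass

def loss(state):
    # Example terminal loss L(z(t1)): squared norm of the final positions.
    return jnp.sum(state[0] ** 2)

def rollout(state, theta, nsteps):
    for _ in range(nsteps):
        state = step(state, theta)
    return state

def adjoint_grad(state0, theta, nsteps):
    state = rollout(state0, theta, nsteps)    # forward pass, trajectory not stored
    adj = jax.grad(loss)(state)               # a(t1) = dL/dz(t1)
    g_theta = jnp.zeros_like(theta)
    for _ in range(nsteps):
        state = step_inverse(state, theta)    # rebuild the previous state
        _, vjp = jax.vjp(step, state, theta)  # linearize one forward step there
        adj, g_step = vjp(adj)                # pull the adjoint back one step
        g_theta = g_theta + g_step            # accumulate dL/dtheta
    return adj, g_theta                       # dL/dz(t0), dL/dtheta

state0 = (jnp.array([1.0, -0.5]), jnp.array([0.0, 0.2]))
theta = jnp.array(1.5)
nsteps = 50

adj0, g_theta = adjoint_grad(state0, theta, nsteps)
# Reference: differentiate straight through the unrolled trajectory.
ref_state, ref_theta = jax.grad(
    lambda s, th: loss(rollout(s, th, nsteps)), argnums=(0, 1)
)(state0, theta)
```

The two routes should agree up to round-off. The reverse accumulation only ever holds one state and one adjoint, whereas the reference keeps the whole unrolled computation for its backward pass.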

When the Loss function is a function of trajectory, the calculation follows:



  1. Neural Ordinary Differential Equations

2. Function module

Class Loss_Generator:

  • Sets up the simulation conditions.
  • Contains the leap-frog Verlet integrator.

Function ode_fwd:

  • Runs the NVE simulation.
  • Returns the final state and the trajectory.

Function generate_Loss:

  • Generate the Loss function.

3. How to use it

This section explains how to use Loss_Generator to obtain gradients, and gives an example of how the module is used.

The module calculates a scalar loss from a trajectory, together with its gradient w.r.t. both the initial state and the parameters. It relies on two user-defined functions: f_nout selects which properties of the trajectory to save, and L defines a scalar loss from those properties. First, define the initial conditions:

  • Initialization: create an instance of Loss_Generator. The user-defined f_nout is a function of the state, and its result is saved.
Generator = Loss_Generator(f_nout, box, init_state['pos'][0], mass, dt, nsteps, nout, cov_map, rc, efunc)

You can use the Generator to run only an NVE simulation, or to run an NVE simulation and then calculate the gradient.

  • Run only an NVE simulation. Here traj holds the result of f_nout, saved every nout steps.
final_state, traj = Generator.ode_fwd(init_state, params)
  • Define the loss function and get the gradient. The user-defined L is a function of traj and returns the scalar loss. For example, the input of L can be the positions of certain atoms saved in traj, and its output can be the mean squared error of those positions w.r.t. the positions in another trajectory.
Loss = Generator.generate_Loss(L, has_aux=True, metadata=metadata)
v, g = value_and_grad(Loss, argnums=(0, 1))(init_state, params)
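As an illustration of the two user-defined functions, the sketch below writes an f_nout that saves the positions of selected atoms and an L that returns their mean squared error against a reference trajectory. The exact signatures DiffTraj expects are assumptions here: the state is taken to be a dict with a 'pos' array of shape (n_atoms, 3), traj is taken to be the stacked f_nout outputs, and how has_aux and metadata change the expected form of L is not covered. The helper `make_L` and the stand-in data are hypothetical.

```python
import jax.numpy as jnp

def f_nout(state):
    # Assumed layout: state["pos"] has shape (n_atoms, 3).
    # Save only the positions of the first two atoms.
    return state["pos"][:2]

def make_L(ref_traj):
    # Hypothetical helper: binds a reference trajectory into the loss.
    def L(traj):
        # Assumed layout: traj has shape (n_frames, 2, 3), one f_nout output per frame.
        return jnp.mean((traj - ref_traj) ** 2)
    return L

# Stand-in frames to check shapes without running a simulation.
frames = [{"pos": jnp.full((4, 3), float(i))} for i in range(3)]
traj = jnp.stack([f_nout(s) for s in frames])
ref_traj = jnp.zeros_like(traj)

L = make_L(ref_traj)
value = L(traj)  # scalar mean squared error
```

Keeping f_nout small matters because its output is what gets saved along the trajectory, while L only ever sees those saved properties.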