Commit afb0e52

Merge pull request deepmodeling#165 from ice-hanbin/devel
Support to do gradient calculation on a loss function from trajectory
2 parents 90cfca1 + b7d0b1e commit afb0e52

File tree

12 files changed: +2253 -0 lines changed

dmff/difftraj.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap, grad, vjp, tree_util, custom_jvp
from functools import partial
from .common.nblist import NeighborListFreud


class Loss_Generator:

    def __init__(self, f_nout, box, pos0, mass, dt, nsteps, nout, cov_map, rc, efunc):
        """ Constructor

        Parameters
        ----------
        f_nout: function
            User-defined function of the state.
        box: jnp.ndarray
            Box of the system, 3*3.
        pos0: jnp.ndarray
            Initial positions used to allocate the neighbor list.
        mass: jnp.ndarray
            Mass of each atom.
        dt: float
            Time step of the simulation.
        nsteps: int
            Total number of steps in the simulation.
        nout: int
            Save the state every nout steps.
        cov_map: jnp.ndarray
            Covalent map matrix.
        rc: float
            Cutoff distance used in the neighbor list.
        efunc: function
            Potential energy function.

        Examples
        ----------

        """
        self.f_nout = f_nout
        self.box = box
        mass = jnp.tile(mass.reshape([len(mass), 1]), (1, 3))
        self.dt = dt
        self.nsteps = nsteps
        self.nout = nout
        nbl = NeighborListFreud(box, rc, cov_map)
        nbl.allocate(pos0)

        def return_pairs(pos):
            pos = jax.lax.stop_gradient(pos)
            nbl.update(pos)
            return nbl.pairs

        bonds = []
        for i in range(len(cov_map)):
            bonds.append(jnp.concatenate([jnp.array([i]), jnp.where(cov_map[i] > 0)[0]]))

        @jit
        def regularize_pos(pos):
            # wrap each bonded group back into the box by shifting its geometric center
            cpos = jnp.stack([jnp.sum(pos[bond], axis=0)/len(bond) for bond in bonds])
            box_inv = jnp.linalg.inv(box)
            spos = cpos.dot(box_inv)
            spos -= jnp.floor(spos)
            shift = spos.dot(box) - cpos
            return pos + shift

        self.states_axis = {'pos': 0, 'vel': 0}

        # use leap-frog Verlet integration method (v_0.5, x1) -> (v_1.5, x2)
        @jit
        def vv_step(state, params, pairs):
            x0 = state['pos']
            v0 = state['vel']
            f0 = -grad(efunc, argnums=(0))(x0, box, pairs, params)
            a0 = f0 / mass
            v1 = v0 + a0 * dt
            x1 = x0 + v1 * dt
            # remove the center-of-mass velocity
            v1 = v1 - jnp.sum(v1*mass, axis=0)/jnp.sum(mass, axis=0)
            x1 = regularize_pos(x1)
            return {'pos': x1, 'vel': v1}

        self.return_pairs = return_pairs
        self.regularize_pos = regularize_pos
        self.vv_step = vv_step

        return

    def ode_fwd(self, state, params):
        """
        Run the simulation forward to get the trajectory.

        Parameters
        ----------
        state: dict
            Initial state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        params: dict
            Force field parameters

        Returns
        ----------
        state: dict
            Final state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        traj: dict
            Stores each saved state along the trajectory, {'time': jnp.ndarray, 'state': jnp.ndarray}
        """
        def fwd(state):
            # advance the batched state by nout integration steps
            for i in range(self.nout):
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                state = vmap(self.vv_step, in_axes=(self.states_axis, None, 0), out_axes=(0))(state, params, pairs)
            return state

        traj = {}
        traj['time'] = jnp.zeros(self.nsteps//self.nout+1)
        traj0 = self.f_nout(state)
        traj['state'] = jnp.repeat(traj0[jnp.newaxis, ...], (self.nsteps//self.nout+1), axis=0)
        for i in range(self.nsteps//self.nout):
            state = fwd(state)
            traj['time'] = traj['time'].at[i+1].set(self.nout*self.dt*(i+1))
            traj['state'] = traj['state'].at[i+1].set(self.f_nout(state))
        return state, traj

    def _ode_bwd(self, state, params, gradient_traj):
        """
        Run backward to get the final adjoint state and the gradient.

        Parameters
        ----------
        state: dict
            Final state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        params: dict
            Force field parameters
        gradient_traj: jnp.ndarray
            Derivatives of the loss with respect to the saved states in traj

        Returns
        ----------
        adjoint_state: dict
            Final adjoint state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        gradient: dict
            Gradient of the loss with respect to params
        """
        def batch_vjp(state, params, pairs, adjoint_state):
            primals, vv_vjp = vjp(partial(self.vv_step, pairs=pairs), state, params)
            (grad_state, grad_params) = vv_vjp(adjoint_state)
            return grad_state, grad_params

        def bwd(state, adjoint_state, gradient):
            # propagate the time-reversed state back over nout steps, pulling the
            # adjoint state through the vjp of each forward step and accumulating
            # the parameter gradient
            for i in range(self.nout):
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                state = vmap(self.vv_step, in_axes=(self.states_axis, None, 0), out_axes=(0))(state, params, pairs)
                # recover the forward-time state for the vjp
                state['vel'] = - state['vel']
                state['pos'] = state['pos'] + state['vel']*self.dt
                state['pos'] = vmap(self.regularize_pos)(state['pos'])
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                (grad_state, grad_params) = vmap(batch_vjp, in_axes=(self.states_axis, None, 0, self.states_axis))(state, params, pairs, adjoint_state)
                gradient = tree_util.tree_map(lambda p, u: p + jnp.sum(u, axis=0), gradient, grad_params)
                adjoint_state = grad_state
                # switch back to the time-reversed representation
                state['pos'] = state['pos'] - state['vel']*self.dt
                state['pos'] = vmap(self.regularize_pos)(state['pos'])
                state['vel'] = - state['vel']
            return state, adjoint_state, gradient

        primals, f_vjp = vjp(self.f_nout, state)
        adjoint_state = f_vjp(gradient_traj[-1])[0]
        gradient = tree_util.tree_map(jnp.zeros_like, params)
        # (v_1.5, x2) -> (-v_1.5, x1)
        state['pos'] = state['pos'] - state['vel']*self.dt
        state['pos'] = vmap(self.regularize_pos)(state['pos'])
        state['vel'] = - state['vel']
        for i in range(self.nsteps//self.nout):
            state, adjoint_state, gradient = bwd(state, adjoint_state, gradient)
            # add the loss contribution of the snapshot reached at this boundary
            primals, f_vjp = vjp(self.f_nout, state)
            adjoint_state = {key: adjoint_state[key] + f_vjp(gradient_traj[-(i+2)])[0][key] for key in state}
        return adjoint_state, gradient

    def generate_Loss(self, L, has_aux=False, metadata=[]):
        """
        Generate the loss function.

        Parameters
        ----------
        L: function
            The user-defined loss function, input: traj['state'], output: loss
        has_aux: bool
            Whether the L function returns auxiliary data
        metadata: list
            Records the traj and auxiliary data, {'traj': traj, 'aux_data': aux_data}

        Returns:
        ----------
        Loss: function
            Loss function

        Examples:
        ----------
        """
        @custom_jvp
        def Loss(initial_state, params):
            """
            This function returns the loss.

            Parameters
            ----------
            initial_state: dict
                Initial state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
            params: dict
                The parameter dictionary.

            Returns:
            ----------
            loss: float
                Loss

            Examples:
            ----------
            """
            final_state, traj = self.ode_fwd(initial_state, params)
            if has_aux:
                loss, aux_data = L(traj['state'])
                metadata.append({'traj': traj, 'aux_data': aux_data})
            else:
                loss = L(traj['state'])
                metadata.append({'traj': traj})
            return loss

        @Loss.defjvp
        def _f_jvp(primals, tangents):
            x, y = primals
            x_dot, y_dot = tangents
            final_state, traj = self.ode_fwd(x, y)
            if has_aux:
                (primal_out, aux_data), gradient_traj = value_and_grad(L, has_aux=True)(traj['state'])
                metadata.append({'traj': traj, 'aux_data': aux_data})
            else:
                primal_out, gradient_traj = value_and_grad(L)(traj['state'])
                metadata.append({'traj': traj})
            adjoint_state, gradient = self._ode_bwd(final_state, y, gradient_traj)
            # tangent of the scalar loss: <dL/dstate, x_dot> + <dL/dparams, y_dot>
            tangent_out = sum(tree_util.tree_leaves(tree_util.tree_map(lambda p, u: jnp.sum(p * u), adjoint_state, x_dot))) + sum(tree_util.tree_leaves(tree_util.tree_map(lambda p, u: jnp.sum(p * u), gradient, y_dot)))
            return primal_out, tangent_out

        return Loss

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ DOI: 10.1021/acs.jctc.2c01297`
+ [Optimization](./user_guide/4.5Optimization.md)
+ [Mbar Estimator](./user_guide/4.6MBAR.md)
+ [OpenMM Plugin](./user_guide/4.7OpenMMplugin.md)
+ [DiffTraj](./user_guide/4.8DiffTraj.md)
+ [5. Advanced examples](./user_guide/DMFF_example.ipynb)

## Developer Guide

docs/user_guide/4.8DiffTraj.md

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# DiffTraj
## 1. Theory
DiffTraj provides support for computing the gradient of a loss function defined on a trajectory. First an NVE (leap-frog Verlet) simulation is run, and then the gradient is computed from the resulting trajectory.
### 1.1 NVE simulation
The NVE simulation uses the leap-frog Verlet integration method, just as in OpenMM. The positions and velocities stored in the context are offset from each other by half a time step. In each step, they are updated as follows:

```math
\begin{aligned}
\mathbf{v}_i(t+\Delta t / 2) & =\mathbf{v}_i(t-\Delta t / 2)+\mathbf{f}_i(t) \Delta t / m_i \\
\mathbf{r}_i(t+\Delta t) & =\mathbf{r}_i(t)+\mathbf{v}_i(t+\Delta t / 2) \Delta t
\end{aligned}
```
where $\mathbf{v}_i$ is the velocity of particle $i$, $\mathbf{r}_i$ is its position, $\mathbf{f}_i$ is the force acting on it (obtained by automatic differentiation of the energy), $m_i$ is its mass, and $\Delta t$ is the time step.
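As an illustration, a single leap-frog update can be written in JAX roughly as in the sketch below. This is a minimal toy example with placeholder names (`energy_fn`, `masses`), not the DMFF implementation itself, which lives in `dmff/difftraj.py`.

```python
import jax
import jax.numpy as jnp

def leapfrog_step(pos, vel_half, masses, dt, energy_fn):
    """One leap-frog step; vel_half holds the velocities at t - dt/2."""
    forces = -jax.grad(energy_fn)(pos)                # f_i(t) from autodiff of the energy
    vel_half_next = vel_half + forces / masses * dt   # v_i(t + dt/2)
    pos_next = pos + vel_half_next * dt               # r_i(t + dt)
    return pos_next, vel_half_next

# toy harmonic energy, standing in for a real force field
energy_fn = lambda x: 0.5 * jnp.sum(x ** 2)
pos = jnp.ones((4, 3))
vel = jnp.zeros((4, 3))
masses = jnp.ones((4, 1))
pos, vel = leapfrog_step(pos, vel, masses, 1e-3, energy_fn)
```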
### 1.2 Gradient calculation
Naively applying JAX automatic differentiation (`jax.grad`) through the whole trajectory may run out of memory (OOM). DiffTraj therefore uses the adjoint method: exploiting the time reversibility of the NVE integrator, it runs the simulation backwards and accumulates the gradient along the way. The loss function is

```math
L\left( \mathbf{z}\left( t_1 \right) \right) =L\left( \mathbf{z}\left( t_0 \right) +\int_{t_0}^{t_1}{f\left( \mathbf{z}\left( t \right) ,\ t,\ \theta \right) dt} \right) =L\left( \text{ODESolve}\left( \mathbf{z}\left( t_0 \right) ,\ f,\ t_0,\ t_1,\ \theta \right) \right)
```

where $L$ is the loss function, $\mathbf{z}\left( t_1 \right)$ is the state (velocity and position) at time $t_1$, $f$ is the integrator, and $\theta$ are the parameters. The gradient calculation starts from the final state. Defining the adjoint state as $\mathbf{a}\left( t \right) =\frac{\partial L}{\partial \mathbf{z}\left( t \right)}$, the gradient is obtained with the chain rule:

```math
\frac{\partial L}{\partial \mathbf{z}\left( t_0 \right)}=\frac{\partial L}{\partial \mathbf{z}\left( t_1 \right)}\frac{\partial \mathbf{z}\left( t_1 \right)}{\partial \mathbf{z}\left( t_0 \right)}
```

```math
\frac{\partial L}{\partial \theta}=\mathbf{a}\left( t_1 \right) ^T\frac{\partial f\left( \mathbf{z}\left( t_0 \right) ,\ \theta \right)}{\partial \theta}
```

When the loss is a function of the whole trajectory, the calculation proceeds as illustrated below:

[![grad.png](https://i.postimg.cc/pdm6fWX1/grad.png)](https://postimg.cc/nst2Ztmv)
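In code, the reverse accumulation can be sketched as a generic loop like the one below. This is a schematic under the assumption of a reversible step function `step` with a known inverse `reverse_step`; the actual DMFF routine is `Loss_Generator._ode_bwd` in `dmff/difftraj.py`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

def adjoint_backward(step, reverse_step, z_final, theta, dL_dz_final, nsteps):
    """Propagate the adjoint state backwards and accumulate dL/dtheta."""
    adjoint = dL_dz_final                                  # a(t_N) = dL/dz(t_N)
    grad_theta = tree_util.tree_map(jnp.zeros_like, theta)
    z = z_final
    for _ in range(nsteps):
        z = reverse_step(z, theta)                         # recover z_n from z_{n+1}
        _, step_vjp = jax.vjp(step, z, theta)              # linearize z_n -> z_{n+1}
        adjoint, dtheta = step_vjp(adjoint)                # a(t_n) and this step's dL/dtheta
        grad_theta = tree_util.tree_map(lambda g, d: g + d, grad_theta, dtheta)
    return adjoint, grad_theta
```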
### References
1. [Neural Ordinary Differential Equations](https://doi.org/10.48550/arXiv.1806.07366)

## 2. Function module

Class `Loss_Generator`:
- Sets up the simulation conditions.
- Contains the leap-frog Verlet integrator.

Function `ode_fwd`:
- Runs the NVE simulation.
- Returns the trajectory.

Function `generate_Loss`:
- Generates the loss function.
## 3. How to use it
Here we show how to use `Loss_Generator` to obtain gradients; an [example](../../examples/msd_opt/fit-msd.ipynb) notebook walks through the module in more detail.

The module computes a scalar loss from a trajectory, together with its gradient w.r.t. both the initial state and the parameters. Two user-defined functions are involved: `f_nout`, which extracts the properties you want to save from each state, and `L`, which turns those saved properties into a scalar loss. You first need to define the initial conditions.
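For instance, `f_nout` and `L` might look like the following minimal sketch (hypothetical observables and shapes, matching the `has_aux=True` usage shown below):

```python
import jax.numpy as jnp

def f_nout(state):
    # hypothetical observable: the centre of geometry of each batched frame;
    # state['pos'] is assumed to have shape (nbatch, natoms, 3)
    return jnp.mean(state['pos'], axis=1)          # (nbatch, 3)

def L(saved_states):
    # saved_states stacks f_nout's output over frames: (nframes+1, nbatch, 3)
    loss = jnp.mean(saved_states ** 2)             # scalar loss
    aux_data = {'final_value': saved_states[-1]}   # returned because has_aux=True
    return loss, aux_data
```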
- Initialization: create an instance of the `Loss_Generator`. The user-defined `f_nout` is a function of the state; its result is what gets saved along the trajectory.

```python
Generator = Loss_Generator(f_nout, box, init_state['pos'][0], mass, dt, nsteps, nout, cov_map, rc, efunc)
```
You can use the Generator to run only an NVE simulation, or to run both the NVE simulation and the gradient calculation.

- Run only an NVE simulation; here `traj` stores the result of `f_nout` every `nout` steps.
```python
final_state, traj = Generator.ode_fwd(initial_state, params)
```
- Define the loss function and get the gradient. The user-defined `L` is a function of `traj` and returns a scalar loss; for example, the input of `L` can be the positions of selected atoms saved in `traj`, and its output can be the mean squared error of those positions with respect to another trajectory.
```python
Loss = Generator.generate_Loss(L, has_aux=True, metadata=metadata)
v, g = value_and_grad(Loss, argnums=(0, 1))(init_state, params)
```
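Assuming the names used above, the results can then be unpacked as a usage note (here `metadata` is the list you passed in):

```python
loss_value = v                        # scalar loss
grad_initial_state, grad_params = g   # gradients w.r.t. init_state and params
record = metadata[-1]                 # {'traj': ..., 'aux_data': ...} recorded during the evaluation
```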

docs/user_guide/4.modules.md

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ In this part, you will see 7 modules of DMFF, some of which are newly released i
+ [Optimization](./4.5Optimization.md)
+ [Mbar Estimator](./4.6MBAR.md)
+ [OpenMM Plugin](./4.7OpenMMplugin.md)
+ [DiffTraj](./4.8DiffTraj.md)
