import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap, grad, vjp, tree_util, custom_jvp
from functools import partial
from .common.nblist import NeighborListFreud


class Loss_Generator:

    def __init__(self, f_nout, box, pos0, mass, dt, nsteps, nout, cov_map, rc, efunc):
        """ Constructor

        Parameters
        ----------
        f_nout: function
            User-defined function of the state; its value is recorded along the trajectory.
        box: jnp.ndarray
            Box vectors of the system, shape (3, 3).
        pos0: jnp.ndarray
            Initial positions, used to allocate the neighbor list.
        mass: jnp.ndarray
            Mass of each atom.
        dt: float
            Time step of the simulation.
        nsteps: int
            Total number of steps in the simulation.
        nout: int
            Record the state every nout steps.
        cov_map: jnp.ndarray
            Covalent map matrix.
        rc: float
            Cutoff distance used in the neighbor list.
        efunc: function
            Potential energy function.

        Examples
        ----------
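        A minimal sketch of constructing the generator; ``box``, ``pos0``,
        ``mass``, ``cov_map`` and ``efunc`` below are user-supplied
        placeholders, not objects provided by this module::

            # record the positions every `nout` steps
            def f_nout(state):
                return state['pos']

            generator = Loss_Generator(f_nout, box, pos0, mass, dt=0.001,
                                       nsteps=1000, nout=10, cov_map=cov_map,
                                       rc=0.6, efunc=efunc)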
        """

        self.f_nout = f_nout
        self.box = box
        # broadcast per-atom masses to shape (natoms, 3) for element-wise use
        mass = jnp.tile(mass.reshape([len(mass), 1]), (1, 3))
        self.dt = dt
        self.nsteps = nsteps
        self.nout = nout
        nbl = NeighborListFreud(box, rc, cov_map)
        nbl.allocate(pos0)

        def return_pairs(pos):
            # the neighbor list is rebuilt outside the autodiff graph
            pos = jax.lax.stop_gradient(pos)
            nbl.update(pos)
            return nbl.pairs
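
        # Group each atom with its covalent neighbors taken from cov_map;
        # regularize_pos below wraps every such group into the box as one
        # unit, so bonded atoms are never split across periodic boundaries.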
        bonds = []
        for i in range(len(cov_map)):
            bonds.append(jnp.concatenate([jnp.array([i]), jnp.where(cov_map[i] > 0)[0]]))

        @jit
        def regularize_pos(pos):
            # centroid of each bonded group
            cpos = jnp.stack([jnp.sum(pos[bond], axis=0) / len(bond) for bond in bonds])
            box_inv = jnp.linalg.inv(box)
            # wrap the centroids into the box in fractional coordinates ...
            spos = cpos.dot(box_inv)
            spos -= jnp.floor(spos)
            # ... and shift all atoms of a group by the same amount
            shift = spos.dot(box) - cpos
            return pos + shift

        self.states_axis = {'pos': 0, 'vel': 0}

        # leap-frog integration: (v_0.5, x1) -> (v_1.5, x2)
        @jit
        def vv_step(state, params, pairs):
            x0 = state['pos']
            v0 = state['vel']
            f0 = -grad(efunc, argnums=0)(x0, box, pairs, params)
            a0 = f0 / mass
            v1 = v0 + a0 * dt
            x1 = x0 + v1 * dt
            # remove center-of-mass velocity drift
            v1 = v1 - jnp.sum(v1 * mass, axis=0) / jnp.sum(mass, axis=0)
            x1 = regularize_pos(x1)
            return {'pos': x1, 'vel': v1}

        self.return_pairs = return_pairs
        self.regularize_pos = regularize_pos
        self.vv_step = vv_step

    def ode_fwd(self, state, params):
        """
        Run the simulation forward and record the trajectory.

        Parameters
        ----------
        state: dict
            Initial state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        params: dict
            Force field parameters

        Returns
        ----------
        state: dict
            Final state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        traj: dict
            Recorded trajectory, {'time': jnp.ndarray, 'state': jnp.ndarray},
            where 'state' stacks f_nout(state) at every output step.
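
        Examples
        ----------
        A minimal sketch; ``generator``, ``pos``, ``vel`` and ``params`` are
        user-supplied placeholders, and the state arrays carry a leading
        ensemble (batch) axis::

            state0 = {'pos': pos, 'vel': vel}
            final_state, traj = generator.ode_fwd(state0, params)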
        """

        def fwd(state):
            # advance the whole ensemble by nout leap-frog steps
            for i in range(self.nout):
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                state = vmap(self.vv_step, in_axes=(self.states_axis, None, 0), out_axes=(0))(state, params, pairs)
            return state

        traj = {}
        traj['time'] = jnp.zeros(self.nsteps//self.nout + 1)
        traj0 = self.f_nout(state)
        traj['state'] = jnp.repeat(traj0[jnp.newaxis, ...], self.nsteps//self.nout + 1, axis=0)
        for i in range(self.nsteps//self.nout):
            state = fwd(state)
            traj['time'] = traj['time'].at[i+1].set(self.nout*self.dt*(i+1))
            traj['state'] = traj['state'].at[i+1].set(self.f_nout(state))
        return state, traj


    def _ode_bwd(self, state, params, gradient_traj):
        """
        Run the trajectory backward to obtain the final adjoint state and the
        gradient with respect to the parameters.

        Parameters
        ----------
        state: dict
            Final state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        params: dict
            Force field parameters
        gradient_traj: jnp.ndarray
            Derivatives of the loss with respect to each recorded state in traj

        Returns
        ----------
        adjoint_state: dict
            Final adjoint state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
        gradient: dict
            Gradient of the loss with respect to params
        """
        def batch_vjp(state, params, pairs, adjoint_state):
            # pull the adjoint back through one integration step
            primals, vv_vjp = vjp(partial(self.vv_step, pairs=pairs), state, params)
            (grad_state, grad_params) = vv_vjp(adjoint_state)
            return grad_state, grad_params

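        # Leap-frog is time-reversible: running vv_step on a state with
        # negated velocities retraces the trajectory backward, so the
        # intermediate states needed for the vjps are regenerated on the fly
        # instead of being stored during the forward pass.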
        def bwd(state, adjoint_state, gradient):
            for i in range(self.nout):
                # integrate backward: (-v_{n+0.5}, x_n) -> (-v_{n-0.5}, x_{n-1})
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                state = vmap(self.vv_step, in_axes=(self.states_axis, None, 0), out_axes=(0))(state, params, pairs)
                # rebuild the forward-step input (v_{n-0.5}, x_n) at which the vjp is taken
                state['vel'] = - state['vel']
                state['pos'] = state['pos'] + state['vel'] * self.dt
                state['pos'] = vmap(self.regularize_pos)(state['pos'])
                pairs = jnp.stack([self.return_pairs(x) for x in state['pos']])
                (grad_state, grad_params) = vmap(batch_vjp, in_axes=(self.states_axis, None, 0, self.states_axis))(state, params, pairs, adjoint_state)
                # accumulate the parameter gradient over the ensemble
                gradient = tree_util.tree_map(lambda p, u: p + jnp.sum(u, axis=0), gradient, grad_params)
                adjoint_state = grad_state
                # restore the backward-time state (-v_{n-0.5}, x_{n-1}) for the next iteration
                state['pos'] = state['pos'] - state['vel'] * self.dt
                state['pos'] = vmap(self.regularize_pos)(state['pos'])
                state['vel'] = - state['vel']
            return state, adjoint_state, gradient

        # seed the adjoint with dL/d(state) at the last recorded frame
        primals, f_vjp = vjp(self.f_nout, state)
        adjoint_state = f_vjp(gradient_traj[-1])[0]
        gradient = tree_util.tree_map(jnp.zeros_like, params)
        # (v_1.5, x2) -> (-v_1.5, x1)
        state['pos'] = state['pos'] - state['vel'] * self.dt
        state['pos'] = vmap(self.regularize_pos)(state['pos'])
        state['vel'] = - state['vel']
        for i in range(self.nsteps//self.nout):
            state, adjoint_state, gradient = bwd(state, adjoint_state, gradient)
            # add the loss gradient contributed by the recorded frame reached here
            primals, f_vjp = vjp(self.f_nout, state)
            adjoint_state = {key: adjoint_state[key] + f_vjp(gradient_traj[-(i+2)])[0][key] for key in state}
        return adjoint_state, gradient

    def generate_Loss(self, L, has_aux=False, metadata=[]):
        """
        Generate the loss function.

        Parameters
        ----------
        L: function
            User-defined loss; takes traj['state'] and returns the loss
            (and auxiliary data if has_aux is True).
        has_aux: bool
            Whether L returns auxiliary data.
        metadata: list
            A list to which {'traj': traj, 'aux_data': aux_data} is appended
            on every evaluation, so the trajectory can be inspected afterwards.

        Returns
        ----------
        Loss: function
            Loss function with signature Loss(initial_state, params).

        Examples
        ----------
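        A minimal sketch; ``generator``, ``state0``, ``params`` and ``target``
        are user-supplied placeholders, and the mean-square deviation is only
        an illustrative choice of loss::

            def L(states):
                return jnp.mean((states - target)**2)

            metadata = []
            Loss = generator.generate_Loss(L, has_aux=False, metadata=metadata)
            loss, gradient = value_and_grad(Loss, argnums=1)(state0, params)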
        """
        @custom_jvp
        def Loss(initial_state, params):
            """
            Evaluate the loss for a given initial state and parameter set.

            Parameters
            ----------
            initial_state: dict
                Initial state, {'pos': jnp.ndarray, 'vel': jnp.ndarray}
            params: dict
                The parameter dictionary.

            Returns
            ----------
            loss: float
                Loss
            """
            final_state, traj = self.ode_fwd(initial_state, params)
            if has_aux:
                loss, aux_data = L(traj['state'])
                metadata.append({'traj': traj, 'aux_data': aux_data})
            else:
                loss = L(traj['state'])
                metadata.append({'traj': traj})
            return loss

        @Loss.defjvp
        def _f_jvp(primals, tangents):
            x, y = primals
            x_dot, y_dot = tangents
            final_state, traj = self.ode_fwd(x, y)
            if has_aux:
                (primal_out, aux_data), gradient_traj = value_and_grad(L, has_aux=True)(traj['state'])
                metadata.append({'traj': traj, 'aux_data': aux_data})
            else:
                primal_out, gradient_traj = value_and_grad(L)(traj['state'])
                metadata.append({'traj': traj})
            adjoint_state, gradient = self._ode_bwd(final_state, y, gradient_traj)
            # tangent = <dL/d(initial_state), x_dot> + <dL/d(params), y_dot>
            tangent_out = sum(tree_util.tree_leaves(tree_util.tree_map(
                lambda p, u: jnp.sum(p * u), adjoint_state, x_dot)))
            tangent_out += sum(tree_util.tree_leaves(tree_util.tree_map(
                lambda p, u: jnp.sum(p * u), gradient, y_dot)))
            return primal_out, tangent_out

        return Loss