Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Short Range Failure #193

Open
wants to merge 2 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions dmff/admp/pairwise.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def generate_pairwise_interaction(pair_int_kernel, static_args):
with the order in kernel
'''

def pair_int(positions, box, pairs, mScales, *atomic_params):
def pair_int(positions, box, pairs, mScales, *atomic_params):
# pairs = regularize_pairs(pairs)
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))

Expand All @@ -77,7 +77,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
buffer_scales = pair_buffer_scales(pairs)
mscales = mscales * buffer_scales
# mscales = mScales[nbonds-1]
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
box_inv = jnp.linalg.inv(box)
dr = ri - rj
dr = v_pbc_shift(dr, box, box_inv)
dr = jnp.linalg.norm(dr, axis=1)
Expand All @@ -89,7 +89,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
# pair_params.append(param[pairs[:, 0]])
# pair_params.append(param[pairs[:, 1]])

energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales)
energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales)
return energy

return pair_int
Expand Down Expand Up @@ -155,7 +155,9 @@ def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j):

@vmap
@jit_condition(static_argnums=())
def slater_sr_kernel(dr, m, ai, aj, bi, bj):
# with hardcore potential
def slater_sr_hc_kernel(dr, m, ai, aj, bi, bj):

'''
Slater-ISA type short range terms
see jctc 12 3851
Expand All @@ -165,5 +167,30 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj):
br = b * dr
br2 = br * br
P = 1/3 * br2 + br + 1
return a * P * jnp.exp(-br) * m

alpha = 0.24
beta = 14
x = alpha * br
x2 = x * x
x4 = x2 * x2
x8 = x4 * x4
x12 = x4 * x8
x14 = x12 * x2
HardCorePotential = a / x14 * m
return a * P * jnp.exp(-br) * m + HardCorePotential

@vmap
@jit_condition(static_argnums=())
def slater_sr_kernel(dr, m, ai, aj, bi, bj):

'''
Slater-ISA type short range terms
see jctc 12 3851
'''
b = jnp.sqrt(bi * bj)
a = ai * aj
br = b * dr
br2 = br * br
P = 1/3 * br2 + br + 1

return a * P * jnp.exp(-br) * m
11 changes: 11 additions & 0 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@

DEFAULT_THOLE_WIDTH = 5.0

# variables used in soft dipole truncation
MAX_DIP = 1.0
TRUNCATION_HARDNESS = 25 # the smaller, the softer
def SOFT_TRUNCATION(x):
x2 = x * x
x_abs = jnp.sqrt(x2 + 1e-6)
val = -1/TRUNCATION_HARDNESS * jnp.log(1 + jnp.exp(-TRUNCATION_HARDNESS*(x_abs-MAX_DIP))) + MAX_DIP
return val * x/x_abs


class ADMPPmeForce:
"""
Expand Down Expand Up @@ -416,6 +425,8 @@ def update_U(i, U):
dScales,
)
U = U - field * pol[:, jnp.newaxis] / DIELECTRIC
# soft truncation: stop polarization catastrophe
U = SOFT_TRUNCATION(U)
return U

U = jax.lax.fori_loop(0, steps_pol, update_U, U)
Expand Down
80 changes: 11 additions & 69 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
TT_damping_qq_c6_kernel,
generate_pairwise_interaction,
slater_disp_damping_kernel,
slater_sr_kernel,
slater_sr_kernel, ## no Hard Core Potential
slater_sr_hc_kernel, ## added Hard Core Potential
TT_damping_qq_kernel,
)
from ..admp.pme import ADMPPmeForce
Expand Down Expand Up @@ -759,20 +760,21 @@ def createPotential(

topdata._meta[self.name+"_map_atomtype"] = map_atomtype

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})
pot_fn_sr = generate_pairwise_interaction(slater_sr_hc_kernel, static_args={})
#slater_ex_sr_kernel: added Hard Core Potential

has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True

def potential_fn(positions, box, pairs, params, aux=None):
def potential_fn(positions, box, pairs, params, aux=None):
positions = positions * 10
box = box * 10
params = params[self.name]
a_list = params["A"][map_atomtype]
b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1

energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list)
energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list)
if has_aux:
return energy, aux
else:
Expand All @@ -790,6 +792,7 @@ def getJaxPotential(self):
_DMFFGenerators["SlaterExForce"] = SlaterExGenerator



# Here are all the short range "charge penetration" terms
# They all have the exchange form with minus sign
class SlaterSrEsGenerator(SlaterExGenerator):
Expand All @@ -798,7 +801,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet, default_name=None):
super().__init__(ffinfo, paramset, default_name="SlaterSrEsForce")
else:
super().__init__(ffinfo, paramset, default_name=default_name)

def createPotential(
self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs
):
Expand All @@ -812,14 +814,14 @@ def createPotential(

topdata._meta[self.name+"_map_atomtype"] = map_atomtype

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
static_args={})
pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})
## slater_sr_others_kernel: no Hard Core Potential

has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True

def potential_fn(positions, box, pairs, params, aux=None):
def potential_fn(positions, box, pairs, params, aux=None):
positions = positions * 10
box = box * 10
params = params[self.name]
Expand Down Expand Up @@ -934,10 +936,7 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
for node in self.ffinfo["Forces"][self.name]["node"]
if node["name"] in ["Multipole", "Atom"]
]
c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ, oXXX, oXXY, oXYY, oYYY, oXXZ, oXYZ, oYYZ, oXZZ, oYZZ, oZZZ = (
[],
[],
[],
c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ = (
[],
[],
[],
Expand All @@ -948,13 +947,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
[],
[],
[],
[],
[],
[],
[],
[],
[],
[]
)
kxs, kys, kzs = [], [], []
multipole_masks = []
Expand Down Expand Up @@ -997,29 +989,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
qXY.append(0.0)
qXZ.append(0.0)
qYZ.append(0.0)
if self.lmax >= 3:
oXXX.append(float(attribs["oXXX"]))
oXXY.append(float(attribs["oXXY"]))
oXYY.append(float(attribs["oXYY"]))
oYYY.append(float(attribs["oYYY"]))
oXXZ.append(float(attribs["oXXZ"]))
oXYZ.append(float(attribs["oXYZ"]))
oYYZ.append(float(attribs["oYYZ"]))
oXZZ.append(float(attribs["oXZZ"]))
oYZZ.append(float(attribs["oYZZ"]))
oZZZ.append(float(attribs["oZZZ"]))
else:
oXXX.append(0.0)
oXXY.append(0.0)
oXYY.append(0.0)
oYYY.append(0.0)
oXXZ.append(0.0)
oXYZ.append(0.0)
oYYZ.append(0.0)
oXZZ.append(0.0)
oYZZ.append(0.0)
oZZZ.append(0.0)

mask = 1.0
if "mask" in attribs and attribs["mask"].upper() == "TRUE":
mask = 0.0
Expand Down Expand Up @@ -1077,8 +1046,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
n_mtps = 4
elif self.lmax == 2:
n_mtps = 10
elif self.lmax == 3:
n_mtps = 20
Q = np.zeros((n_atoms, n_mtps))

# TDDO: unit conversion
Expand All @@ -1096,19 +1063,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
Q[:, 8] = qXZ
Q[:, 9] = qYZ
Q[:, 4:10] *= 300
if self.lmax >= 3:
Q[:, 10] = oXXX
Q[:, 11] = oXXY
Q[:, 12] = oXYY
Q[:, 13] = oYYY
Q[:, 14] = oXXZ
Q[:, 15] = oXYZ
Q[:, 16] = oYYZ
Q[:, 17] = oXZZ
Q[:, 18] = oYZZ
Q[:, 19] = oZZZ
# TO DO: To be decided
Q[:, 10:20] *= 15000

# add all differentiable params to self.params
Q_local = convert_cart2harm(jnp.array(Q), self.lmax)
Expand Down Expand Up @@ -1138,18 +1092,6 @@ def overwrite(self, paramset):
node["attrib"]["qXY"] = Q_global[n_multipole, 7] / 300.0
node["attrib"]["qXZ"] = Q_global[n_multipole, 8] / 300.0
node["attrib"]["qYZ"] = Q_global[n_multipole, 9] / 300.0
if self.lmax >= 3:
node["attrib"]["oXXX"] = Q_global[n_multipole, 10] / 15000.0
node["attrib"]["oXXY"] = Q_global[n_multipole, 11] / 15000.0
node["attrib"]["oXYY"] = Q_global[n_multipole, 12] / 15000.0
node["attrib"]["oYYY"] = Q_global[n_multipole, 13] / 15000.0
node["attrib"]["oXXZ"] = Q_global[n_multipole, 14] / 15000.0
node["attrib"]["oXYZ"] = Q_global[n_multipole, 15] / 15000.0
node["attrib"]["oYYZ"] = Q_global[n_multipole, 16] / 15000.0
node["attrib"]["oXZZ"] = Q_global[n_multipole, 17] / 15000.0
node["attrib"]["oYZZ"] = Q_global[n_multipole, 18] / 15000.0
node["attrib"]["oZZZ"] = Q_global[n_multipole, 19] / 15000.0

if q_local_masks[n_multipole] < 0.999:
node["mask"] = "true"
n_multipole += 1
Expand Down
Loading