Skip to content

Commit b9a794d

Browse files
committed
Add support for small (ABn-type) molecules for sGNN
1 parent 2126ba4 commit b9a794d

File tree

2 files changed

+152
-52
lines changed

2 files changed

+152
-52
lines changed

dmff/sgnn/gnn.py

100755100644
Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import jax.nn.initializers
99
import jax.numpy as jnp
1010
import numpy as np
11-
from .graph import MAX_VALENCE, TopGraph, from_pdb
11+
from .graph import TopGraph, from_pdb
12+
from .graph import MAX_VALENCE, ATYPE_INDEX, FSCALE_BOND, FSCALE_ANGLE
1213
from ..utils import jit_condition
1314
from jax import value_and_grad, vmap
1415

@@ -55,7 +56,12 @@ def __init__(self,
5556
nn=1,
5657
sigma=162.13039087945623,
5758
mu=117.41975505778706,
58-
seed=12345):
59+
seed=12345,
60+
max_valence=MAX_VALENCE,
61+
atype_index=ATYPE_INDEX,
62+
fscale_bond=FSCALE_BOND,
63+
fscale_angle=FSCALE_ANGLE
64+
):
5965
""" Constructor for MolGNNForce
6066
6167
Parameters
@@ -77,15 +83,25 @@ def __init__(self,
7783
mu: float, optional
7884
a constant shift
7985
the final total energy would be ${(E_{NN} + \mu) * \sigma}
80-
seed: int: optional
86+
seed: int, optional
8187
random seed used in network initialization
8288
default = 12345
83-
89+
max_valence: int, optional
90+
Maximal valence number for all atoms inside the graph, use the value in graph.py by default
91+
atype_index: dict, optional
92+
A dictionary that assign index to each relevant element: e.g., {'H': 0, 'C': 1, 'O': 2}, use the ATYPE_INDEX in graph.py by default
93+
fscale_bond: float, optional
94+
The scaling factor for bond features, use value in graph.py by default
95+
fscale_angle: float, optional
96+
The scaling factor for angle features, use value in graph.py by default
8497
"""
8598
self.nn = nn
8699
self.G = G
87100
self.G.get_all_subgraphs(nn, typify=True)
88-
self.G.prepare_subgraph_feature_calc()
101+
self.G.prepare_subgraph_feature_calc(max_valence=max_valence,
102+
atype_index=atype_index,
103+
fscale_bond=fscale_bond,
104+
fscale_angle=fscale_angle)
89105
params = OrderedDict()
90106
key = jax.random.PRNGKey(seed)
91107
params['w'] = jax.random.uniform(key)
@@ -151,14 +167,14 @@ def message_pass(f_in, nb_connect, w, nn):
151167
if nn == 0:
152168
return f_in[0]
153169
elif nn == 1:
154-
nb_connect0 = nb_connect[0:MAX_VALENCE - 1]
155-
nb_connect1 = nb_connect[MAX_VALENCE - 1:2 *
156-
(MAX_VALENCE - 1)]
170+
nb_connect0 = nb_connect[0:max_valence - 1]
171+
nb_connect1 = nb_connect[max_valence - 1:2 *
172+
(max_valence - 1)]
157173
nb0 = jnp.sum(nb_connect0)
158174
nb1 = jnp.sum(nb_connect1)
159175
f = f_in[0] * (1 - jnp.heaviside(nb0, 0)*w - jnp.heaviside(nb1, 0)*w) + \
160-
w * nb_connect0.dot(f_in[1:MAX_VALENCE, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \
161-
w * nb_connect1.dot(f_in[MAX_VALENCE:2*MAX_VALENCE-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x])
176+
w * nb_connect0.dot(f_in[1:max_valence, :]) / jnp.piecewise(nb0, [nb0<1e-5, nb0>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x]) + \
177+
w * nb_connect1.dot(f_in[max_valence:2*max_valence-1, :])/ jnp.piecewise(nb1, [nb1<1e-5, nb1>=1e-5], [lambda x: jnp.array(1e-5), lambda x: x])
162178
return f
163179

164180
features = fc0(features, params)

0 commit comments

Comments
 (0)