8
8
import jax .nn .initializers
9
9
import jax .numpy as jnp
10
10
import numpy as np
11
- from .graph import MAX_VALENCE , TopGraph , from_pdb
11
+ from .graph import TopGraph , from_pdb
12
+ from .graph import MAX_VALENCE , ATYPE_INDEX , FSCALE_BOND , FSCALE_ANGLE
12
13
from ..utils import jit_condition
13
14
from jax import value_and_grad , vmap
14
15
@@ -55,7 +56,12 @@ def __init__(self,
55
56
nn = 1 ,
56
57
sigma = 162.13039087945623 ,
57
58
mu = 117.41975505778706 ,
58
- seed = 12345 ):
59
+ seed = 12345 ,
60
+ max_valence = MAX_VALENCE ,
61
+ atype_index = ATYPE_INDEX ,
62
+ fscale_bond = FSCALE_BOND ,
63
+ fscale_angle = FSCALE_ANGLE
64
+ ):
59
65
""" Constructor for MolGNNForce
60
66
61
67
Parameters
@@ -77,15 +83,25 @@ def __init__(self,
77
83
mu: float, optional
78
84
a constant shift
79
85
the final total energy would be ${(E_{NN} + \mu) * \sigma}
80
- seed: int: optional
86
+ seed: int, optional
81
87
random seed used in network initialization
82
88
default = 12345
83
-
89
+ max_valence: int, optional
90
+ Maximal valence number for all atoms inside the graph, use the value in graph.py by default
91
+ atype_index: dict, optional
92
+ A dictionary that assign index to each relevant element: e.g., {'H': 0, 'C': 1, 'O': 2}, use the ATYPE_INDEX in graph.py by default
93
+ fscale_bond: float, optional
94
+ The scaling factor for bond features, use value in graph.py by default
95
+ fscale_angle: float, optional
96
+ The scaling factor for angle features, use value in graph.py by default
84
97
"""
85
98
self .nn = nn
86
99
self .G = G
87
100
self .G .get_all_subgraphs (nn , typify = True )
88
- self .G .prepare_subgraph_feature_calc ()
101
+ self .G .prepare_subgraph_feature_calc (max_valence = max_valence ,
102
+ atype_index = atype_index ,
103
+ fscale_bond = fscale_bond ,
104
+ fscale_angle = fscale_angle )
89
105
params = OrderedDict ()
90
106
key = jax .random .PRNGKey (seed )
91
107
params ['w' ] = jax .random .uniform (key )
@@ -151,14 +167,14 @@ def message_pass(f_in, nb_connect, w, nn):
151
167
if nn == 0 :
152
168
return f_in [0 ]
153
169
elif nn == 1 :
154
- nb_connect0 = nb_connect [0 :MAX_VALENCE - 1 ]
155
- nb_connect1 = nb_connect [MAX_VALENCE - 1 :2 *
156
- (MAX_VALENCE - 1 )]
170
+ nb_connect0 = nb_connect [0 :max_valence - 1 ]
171
+ nb_connect1 = nb_connect [max_valence - 1 :2 *
172
+ (max_valence - 1 )]
157
173
nb0 = jnp .sum (nb_connect0 )
158
174
nb1 = jnp .sum (nb_connect1 )
159
175
f = f_in [0 ] * (1 - jnp .heaviside (nb0 , 0 )* w - jnp .heaviside (nb1 , 0 )* w ) + \
160
- w * nb_connect0 .dot (f_in [1 :MAX_VALENCE , :]) / jnp .piecewise (nb0 , [nb0 < 1e-5 , nb0 >= 1e-5 ], [lambda x : jnp .array (1e-5 ), lambda x : x ]) + \
161
- w * nb_connect1 .dot (f_in [MAX_VALENCE :2 * MAX_VALENCE - 1 , :])/ jnp .piecewise (nb1 , [nb1 < 1e-5 , nb1 >= 1e-5 ], [lambda x : jnp .array (1e-5 ), lambda x : x ])
176
+ w * nb_connect0 .dot (f_in [1 :max_valence , :]) / jnp .piecewise (nb0 , [nb0 < 1e-5 , nb0 >= 1e-5 ], [lambda x : jnp .array (1e-5 ), lambda x : x ]) + \
177
+ w * nb_connect1 .dot (f_in [max_valence :2 * max_valence - 1 , :])/ jnp .piecewise (nb1 , [nb1 < 1e-5 , nb1 >= 1e-5 ], [lambda x : jnp .array (1e-5 ), lambda x : x ])
162
178
return f
163
179
164
180
features = fc0 (features , params )
0 commit comments