import torch
import torch.nn as nn
-
import numpy as np
import nn4n.utils as utils

@@ -73,7 +72,7 @@ def from_dict(cls, layer_struct):
        """
        # Create an instance using the dictionary values
        cls._check_keys(layer_struct)
-        instance = cls(
+        return cls(
            input_dim=layer_struct["input_dim"],
            output_dim=layer_struct["output_dim"],
            weight=layer_struct.get("weight", "uniform"),
@@ -82,20 +81,216 @@ def from_dict(cls, layer_struct):
            sparsity_mask=layer_struct.get("sparsity_mask"),
            plasticity_mask=layer_struct.get("plasticity_mask"),
        )
-        # Initialize the trainable parameters then check the layer
-        instance._init_trainable()
-        instance._check_layer()

-        return instance
+    def _check_layer(self):
+        pass

+    # INIT TRAINABLE
+    # ======================================================================================
    def _init_trainable(self):
-        # enfore constraints
+        # Enforce constraints
        self._init_constraints()
-        # convert weight and bias to torch tensor
+        # Convert weight and bias to learnable parameters
        self.weight = nn.Parameter(
            self.weight, requires_grad=self.weight_dist is not None
        )
        self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)

-    def _check_layer(self):
-        pass
+    def _init_constraints(self):
+        """
+        Initialize constraints
+        It will also balance excitatory and inhibitory neurons
+        """
+        if self.sparsity_mask is not None:
+            self.weight *= self.sparsity_mask
+        if self.ei_mask is not None:
+            # Apply Dale's law
+            self.weight[self.ei_mask == 1] = torch.clamp(
+                self.weight[self.ei_mask == 1], min=0
+            )  # For excitatory neurons, set negative weights to 0
+            self.weight[self.ei_mask == -1] = torch.clamp(
+                self.weight[self.ei_mask == -1], max=0
+            )  # For inhibitory neurons, set positive weights to 0
+
+            # Balance excitatory and inhibitory neurons' weight magnitudes
+            self._balance_excitatory_inhibitory()
+
+    def _generate_bias(self, bias_init):
+        """Generate random bias"""
+        if bias_init == "uniform":
+            # If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
+            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
+            b = torch.rand(self.output_dim) * sqrt_k
+            b = b * 2 - sqrt_k
+        elif bias_init == "normal":
+            b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim))
+        elif bias_init == "zero" or bias_init is None:
+            b = torch.zeros(self.output_dim)
+        elif isinstance(bias_init, np.ndarray):
+            b = torch.from_numpy(bias_init)
+        else:
+            raise NotImplementedError
+        return b.float()
+
+    def _generate_weight(self, weight_init):
+        """Generate random weight"""
+        if weight_init == "uniform":
+            # If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
+            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
+            w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
+            w = w * 2 - sqrt_k
+        elif weight_init == "normal":
+            w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
+                torch.tensor(self.input_dim)
+            )
+        elif weight_init == "zero":
+            w = torch.zeros((self.output_dim, self.input_dim))
+        elif isinstance(weight_init, np.ndarray):
+            w = torch.from_numpy(weight_init)
+        else:
+            raise NotImplementedError
+        return w.float()
+
+    def _balance_excitatory_inhibitory(self):
+        """Balance excitatory and inhibitory weights"""
+        scale_mat = torch.ones_like(self.weight)
+        ext_sum = self.weight[self.sparsity_mask == 1].sum()
+        inh_sum = self.weight[self.sparsity_mask == -1].sum()
+        if ext_sum == 0 or inh_sum == 0:
+            # Automatically stop balancing if one of the sums is 0
+            # divide by 10 to avoid recurrent explosion/decay
+            self.weight /= 10
+        else:
+            # Scale down whichever population has the larger total magnitude
+            if ext_sum > abs(inh_sum):
+                _scale = abs(inh_sum).item() / ext_sum.item()
+                scale_mat[self.sparsity_mask == 1] = _scale
+            elif ext_sum < abs(inh_sum):
+                _scale = ext_sum.item() / abs(inh_sum).item()
+                scale_mat[self.sparsity_mask == -1] = _scale
+            # Apply scaling
+            self.weight *= scale_mat
+
+    # TRAINING
+    # ======================================================================================
+    def to(self, device):
+        """Move the network to the device (cpu/gpu)"""
+        super().to(device)
+        if self.sparsity_mask is not None:
+            self.sparsity_mask = self.sparsity_mask.to(device)
+        if self.ei_mask is not None:
+            self.ei_mask = self.ei_mask.to(device)
+        if self.bias.requires_grad:
+            self.bias = self.bias.to(device)
+        return self
+
+    def forward(self, x):
+        """
+        Forward pass of the layer
+
+        Inputs:
+            - x: input, shape: (batch_size, input_dim)
+
+        Returns:
+            - state: output, shape: (batch_size, output_dim)
+        """
+        return x.float() @ self.weight.T + self.bias
+
+    def apply_plasticity(self):
+        """
+        Apply the plasticity mask to the weight gradient
+        """
+        with torch.no_grad():
+            # Assume the plasticity masks are all valid and already checked in the CTRNN class
+            for scale in self.plasticity_scales:
+                if self.weight.grad is not None:
+                    self.weight.grad[self.plasticity_mask == scale] *= scale
+                else:
+                    raise RuntimeError(
+                        "Weight gradient is None, possibly because the forward loop is non-differentiable"
+                    )
+
+    def freeze(self):
+        """Freeze the layer"""
+        self.weight.requires_grad = False
+        self.bias.requires_grad = False
+
+    def unfreeze(self):
+        """Unfreeze the layer"""
+        self.weight.requires_grad = True
+        self.bias.requires_grad = True
+
+    # CONSTRAINTS
+    # ======================================================================================
+    def enforce_constraints(self):
+        """
+        Enforce constraints
+
+        The constraints are:
+            - sparsity_mask: mask for sparse connectivity
+            - ei_mask: mask for Dale's law
+        """
+        if self.sparsity_mask is not None:
+            self._enforce_sparsity()
+        if self.ei_mask is not None:
+            self._enforce_ei()
+
+    def _enforce_sparsity(self):
+        """Enforce sparsity"""
+        w = self.weight.detach().clone() * self.sparsity_mask
+        self.weight.data.copy_(w)
+
+    def _enforce_ei(self):
+        """Enforce Dale's law"""
+        w = self.weight.detach().clone()
+        w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
+        w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
+        self.weight.data.copy_(w)
+
+    # HELPER FUNCTIONS
+    # ======================================================================================
+    def set_weight(self, weight):
+        """Set the value of weight"""
+        assert (
+            weight.shape == self.weight.shape
+        ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
+        with torch.no_grad():
+            self.weight.copy_(weight)
+
+    def plot_layer(self):
+        """Plot the weight matrix and weight distribution of the layer"""
+        weight = (
+            self.weight.cpu()
+            if self.weight.device != torch.device("cpu")
+            else self.weight
+        )
+        utils.plot_connectivity_matrix_dist(
+            weight.detach().numpy(),
+            "Weight",
+            False,
+            self.sparsity_mask is not None,
+        )
+
+    def get_specs(self):
+        """Return the specs of the layer"""
+        return {
+            "input_dim": self.input_dim,
+            "output_dim": self.output_dim,
+            "weight_learnable": self.weight.requires_grad,
+            "weight_min": self.weight.min().item(),
+            "weight_max": self.weight.max().item(),
+            "bias_learnable": self.bias.requires_grad,
+            "bias_min": self.bias.min().item(),
+            "bias_max": self.bias.max().item(),
+            "sparsity": (
+                self.sparsity_mask.sum() / self.sparsity_mask.numel()
+                if self.sparsity_mask is not None
+                else 1
+            ),
+        }
+
+    def print_layer(self):
+        """
+        Print the specs of the layer
+        """
+        utils.print_dict("Layer Specs", self.get_specs())
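
A minimal usage sketch of the refactored constructor (not part of the commit): from_dict now builds and returns the instance directly instead of calling _init_trainable() and _check_layer() on it. The class name LinearLayer, the import path nn4n.layer, and any layer_struct keys beyond those visible in this diff are assumptions for illustration.

    from nn4n.layer import LinearLayer  # assumed import path

    layer = LinearLayer.from_dict(
        {
            "input_dim": 10,
            "output_dim": 20,
            "weight": "uniform",      # "uniform" | "normal" | "zero" | np.ndarray of shape (output_dim, input_dim)
            "sparsity_mask": None,    # optional connectivity mask
            "plasticity_mask": None,  # optional per-connection gradient scales
        }
    )
    layer.print_layer()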
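A sketch of where the constraint and plasticity hooks added in this commit would sit in a training step, assuming the layer is an nn.Module (as implied by nn.Parameter and super().to(device)). The optimizer, loss, and tensor shapes are placeholders; the ordering (apply_plasticity() after backward(), enforce_constraints() after the optimizer step) follows the method docstrings, not training code shown in this diff.

    import torch

    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)

    x = torch.randn(32, 10)       # (batch_size, input_dim)
    target = torch.randn(32, 20)  # (batch_size, output_dim)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(layer(x), target)
    loss.backward()
    if layer.plasticity_mask is not None:
        layer.apply_plasticity()  # rescale weight gradients by the plasticity mask
    optimizer.step()
    layer.enforce_constraints()   # reapply sparsity_mask and Dale's law (ei_mask)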