Commit 2b1da32

refactor LinearLayer and HiddenLayer
1 parent 1d844a7 commit 2b1da32

File tree

11 files changed: 246 additions & 645 deletions


nn4n/layer/base_layer.py

Lines changed: 205 additions & 10 deletions
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-
 import numpy as np
 import nn4n.utils as utils
 
@@ -73,7 +72,7 @@ def from_dict(cls, layer_struct):
         """
         # Create an instance using the dictionary values
         cls._check_keys(layer_struct)
-        instance = cls(
+        return cls(
             input_dim=layer_struct["input_dim"],
             output_dim=layer_struct["output_dim"],
             weight=layer_struct.get("weight", "uniform"),
@@ -82,20 +81,216 @@ def from_dict(cls, layer_struct):
             sparsity_mask=layer_struct.get("sparsity_mask"),
             plasticity_mask=layer_struct.get("plasticity_mask"),
         )
-        # Initialize the trainable parameters then check the layer
-        instance._init_trainable()
-        instance._check_layer()
 
-        return instance
+    def _check_layer(self):
+        pass
 
+    # INIT TRAINABLE
+    # ======================================================================================
     def _init_trainable(self):
-        # enfore constraints
+        # Enforce constraints
         self._init_constraints()
-        # convert weight and bias to torch tensor
+        # Convert weight and bias to learnable parameters
         self.weight = nn.Parameter(
             self.weight, requires_grad=self.weight_dist is not None
         )
         self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)
 
-    def _check_layer(self):
-        pass
+    def _init_constraints(self):
+        """
+        Initialize constraints
+        It will also balance excitatory and inhibitory neurons
+        """
+        if self.sparsity_mask is not None:
+
+            self.weight *= self.sparsity_mask
+        if self.ei_mask is not None:
+            # Apply Dale's law
+            self.weight[self.ei_mask == 1] = torch.clamp(
+                self.weight[self.ei_mask == 1], min=0
+            )  # For excitatory neurons, set negative weights to 0
+            self.weight[self.ei_mask == -1] = torch.clamp(
+                self.weight[self.ei_mask == -1], max=0
+            )  # For inhibitory neurons, set positive weights to 0
+
+            # Balance excitatory and inhibitory neurons weight magnitudes
+            self._balance_excitatory_inhibitory()
+
+    def _generate_bias(self, bias_init):
+        """Generate random bias"""
+        if bias_init == "uniform":
+            # If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
+            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
+            b = torch.rand(self.output_dim) * sqrt_k
+            b = b * 2 - sqrt_k
+        elif bias_init == "normal":
+            b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim))
+        elif bias_init == "zero" or bias_init == None:
+            b = torch.zeros(self.output_dim)
+        elif type(bias_init) == np.ndarray:
+            b = torch.from_numpy(bias_init)
+        else:
+            raise NotImplementedError
+        return b.float()
+
+    def _generate_weight(self, weight_init):
+        """Generate random weight"""
+        if weight_init == "uniform":
+            # If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
+            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
+            w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
+            w = w * 2 - sqrt_k
+        elif weight_init == "normal":
+            w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
+                torch.tensor(self.input_dim)
+            )
+        elif weight_init == "zero":
+            w = torch.zeros((self.output_dim, self.input_dim))
+        elif type(weight_init) == np.ndarray:
+            w = torch.from_numpy(weight_init)
+        else:
+            raise NotImplementedError
+        return w.float()
+
+    def _balance_excitatory_inhibitory(self):
+        """Balance excitatory and inhibitory weights"""
+        scale_mat = torch.ones_like(self.weight)
+        ext_sum = self.weight[self.sparsity_mask == 1].sum()
+        inh_sum = self.weight[self.sparsity_mask == -1].sum()
+        if ext_sum == 0 or inh_sum == 0:
+            # Automatically stop balancing if one of the sums is 0
+            # divide by 10 to avoid recurrent explosion/decay
+            self.weight /= 10
+        else:
+            if ext_sum > abs(inh_sum):
+                _scale = abs(inh_sum).item() / ext_sum.item()
+                scale_mat[self.sparsity_mask == 1] = _scale
+            elif ext_sum < abs(inh_sum):
+                _scale = ext_sum.item() / abs(inh_sum).item()
+                scale_mat[self.sparsity_mask == -1] = _scale
+            # Apply scaling
+            self.weight *= scale_mat
+
+    # TRAINING
+    # ======================================================================================
+    def to(self, device):
+        """Move the network to the device (cpu/gpu)"""
+        super().to(device)
+        if self.sparsity_mask is not None:
+            self.sparsity_mask = self.sparsity_mask.to(device)
+        if self.ei_mask is not None:
+            self.ei_mask = self.ei_mask.to(device)
+        if self.bias.requires_grad:
+            self.bias = self.bias.to(device)
+        return self
+
+    def forward(self, x):
+        """
+        Forward pass
+
+        Inputs:
+            - x: input, shape: (batch_size, input_dim)
+
+        Returns:
+            - state: shape: (batch_size, output_dim)
+        """
+        return x.float() @ self.weight.T + self.bias
+
+    def apply_plasticity(self):
+        """
+        Apply plasticity mask to the weight gradient
+        """
+        with torch.no_grad():
+            # Assume the plasticity masks are all valid; they are checked in the CTRNN class
+            for scale in self.plasticity_scales:
+                if self.weight.grad is not None:
+                    self.weight.grad[self.plasticity_mask == scale] *= scale
+                else:
+                    raise RuntimeError(
+                        "Weight gradient is None, possibly because the forward loop is non-differentiable"
+                    )
+
+    def freeze(self):
+        """Freeze the layer"""
+        self.weight.requires_grad = False
+        self.bias.requires_grad = False
+
+    def unfreeze(self):
+        """Unfreeze the layer"""
+        self.weight.requires_grad = True
+        self.bias.requires_grad = True
+
+    # CONSTRAINTS
+    # ======================================================================================
+    def enforce_constraints(self):
+        """
+        Enforce constraints
+
+        The constraints are:
+            - sparsity_mask: mask for sparse connectivity
+            - ei_mask: mask for Dale's law
+        """
+        if self.sparsity_mask is not None:
+            self._enforce_sparsity()
+        if self.ei_mask is not None:
+            self._enforce_ei()
+
+    def _enforce_sparsity(self):
+        """Enforce sparsity"""
+        w = self.weight.detach().clone() * self.sparsity_mask
+        self.weight.data.copy_(torch.nn.Parameter(w))
+
+    def _enforce_ei(self):
+        """Enforce Dale's law"""
+        w = self.weight.detach().clone()
+        w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
+        w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
+        self.weight.data.copy_(torch.nn.Parameter(w))
+
+    # HELPER FUNCTIONS
+    # ======================================================================================
+    def set_weight(self, weight):
+        """Set the value of weight"""
+        assert (
+            weight.shape == self.weight.shape
+        ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
+        with torch.no_grad():
+            self.weight.copy_(weight)
+
+    def plot_layer(self):
+        """Plot the weights matrix and distribution of each layer"""
+        weight = (
+            self.weight.cpu()
+            if self.weight.device != torch.device("cpu")
+            else self.weight
+        )
+        utils.plot_connectivity_matrix_dist(
+            weight.detach().numpy(),
+            f"Weight",
+            False,
+            self.sparsity_mask is not None,
+        )
+
+    def get_specs(self):
+        """Return the specs of each layer"""
+        return {
+            "input_dim": self.input_dim,
+            "output_dim": self.output_dim,
+            "weight_learnable": self.weight.requires_grad,
+            "weight_min": self.weight.min().item(),
+            "weight_max": self.weight.max().item(),
+            "bias_learnable": self.bias.requires_grad,
+            "bias_min": self.bias.min().item(),
+            "bias_max": self.bias.max().item(),
+            "sparsity": (
+                self.sparsity_mask.sum() / self.sparsity_mask.numel()
+                if self.sparsity_mask is not None
+                else 1
+            )
+        }
+
+    def print_layer(self):
+        """
+        Print the specs of the layer
+        """
+        utils.print_dict("Layer Specs", self.get_specs())
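
Since from_dict() now simply returns cls(...) and no longer calls _init_trainable() or _check_layer(), those steps presumably move into the concrete subclasses this commit touches (LinearLayer, HiddenLayer). Below is a minimal usage sketch of the refactored API; the import path and the "bias" / "ei_mask" dict keys are assumptions, since only this file's diff is shown here:

import torch
from nn4n.layer import LinearLayer  # hypothetical import path

# Keys mirror what from_dict() reads in the diff above; "bias" and "ei_mask"
# are assumed keys based on the attributes the base layer uses elsewhere.
layer_struct = {
    "input_dim": 10,
    "output_dim": 5,
    "weight": "uniform",        # "uniform" | "normal" | "zero" | np.ndarray
    "bias": "zero",             # assumed key, handled by _generate_bias()
    "ei_mask": None,            # +1/-1 mask enforcing Dale's law
    "sparsity_mask": None,      # 0/1 mask for sparse connectivity
    "plasticity_mask": None,    # per-connection gradient scaling
}
layer = LinearLayer.from_dict(layer_struct)

x = torch.randn(32, 10)         # (batch_size, input_dim)
y = layer(x)                    # forward(): x @ weight.T + bias -> (32, 5)

y.pow(2).mean().backward()      # any scalar loss
# optimizer.step() would go here
layer.enforce_constraints()     # re-apply sparsity mask and Dale's-law clamping
layer.print_layer()             # prints get_specs() via utils.print_dict

In a training loop, enforce_constraints() would typically run after each optimizer step so that pruned connections stay at zero and excitatory/inhibitory signs survive the update, while apply_plasticity() would run between backward() and the step whenever a plasticity_mask is set.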
