|
  6 |   6 |
  7 |   7 |  class BaseLayer(nn.Module):
  8 |   8 |      """
  9 |     | -    Linear Layer with optional sparsity, excitatory/inhibitory, and plasticity constraints.
 10 |     | -    The layer is initialized by passing specs in layer_struct.
 11 |     | -
 12 |     | -    Required keywords in layer_struct:
 13 |     | -        - input_dim: dimension of input
 14 |     | -        - output_dim: dimension of output
 15 |     | -        - weight: weight matrix init method/init weight matrix, default: 'uniform'
 16 |     | -        - bias: bias vector init method/init bias vector, default: 'uniform'
 17 |     | -        - sparsity_mask: mask for sparse connectivity
 18 |     | -        - ei_mask: mask for Dale's law
 19 |     | -        - plasticity_mask: mask for plasticity
    |   9 | +    nn4n Layer class
 20 |  10 |      """
 21 |  11 |
 22 |     | -    def __init__(
 23 |     | -        self,
 24 |     | -        input_dim: int,
 25 |     | -        output_dim: int,
 26 |     | -        weight: str = "uniform",
 27 |     | -        bias: str = "uniform",
 28 |     | -        ei_mask: torch.Tensor = None,
 29 |     | -        sparsity_mask: torch.Tensor = None,
 30 |     | -        plasticity_mask: torch.Tensor = None,
 31 |     | -    ):
    |  12 | +    def __init__(self):
 32 |  13 |          super().__init__()
 33 |     | -        self.input_dim = input_dim
 34 |     | -        self.output_dim = output_dim
 35 |     | -        self.weight_dist = weight
 36 |     | -        self.bias_dist = bias
 37 |     | -        self.weight = self._generate_weight(self.weight_dist)
 38 |     | -        self.bias = self._generate_bias(self.bias_dist)
 39 |     | -        self.ei_mask = ei_mask.T if ei_mask is not None else None
 40 |     | -        self.sparsity_mask = sparsity_mask.T if sparsity_mask is not None else None
 41 |     | -        self.plasticity_mask = (
 42 |     | -            plasticity_mask.T if plasticity_mask is not None else None
 43 |     | -        )
 44 |     | -        # All unique plasticity values in the plasticity mask
 45 |     | -        self.plasticity_scales = (
 46 |     | -            torch.unique(self.plasticity_mask)
 47 |     | -            if self.plasticity_mask is not None
 48 |     | -            else None
 49 |     | -        )
 50 |     | -
 51 |     | -        self._init_trainable()
 52 |     | -        self._check_layer()
 53 |     | -
 54 |     | -    # INITIALIZATION
 55 |     | -    # ======================================================================================
 56 |     | -    @staticmethod
 57 |     | -    def _check_keys(layer_struct):
 58 |     | -        required_keys = ["input_dim", "output_dim"]
 59 |     | -        for key in required_keys:
 60 |     | -            if key not in layer_struct:
 61 |     | -                raise ValueError(f"Key '{key}' is missing in layer_struct")
 62 |     | -
 63 |     | -        valid_keys = ["input_dim", "output_dim", "weight", "bias", "ei_mask", "sparsity_mask", "plasticity_mask"]
 64 |     | -        for key in layer_struct.keys():
 65 |     | -            if key not in valid_keys:
 66 |     | -                raise ValueError(f"Key '{key}' is not a valid key in layer_struct")
 67 |     | -
 68 |     | -    @classmethod
 69 |     | -    def from_dict(cls, layer_struct):
 70 |     | -        """
 71 |     | -        Alternative constructor to initialize LinearLayer from a dictionary.
 72 |     | -        """
 73 |     | -        # Create an instance using the dictionary values
 74 |     | -        cls._check_keys(layer_struct)
 75 |     | -        return cls(
 76 |     | -            input_dim=layer_struct["input_dim"],
 77 |     | -            output_dim=layer_struct["output_dim"],
 78 |     | -            weight=layer_struct.get("weight", "uniform"),
 79 |     | -            bias=layer_struct.get("bias", "uniform"),
 80 |     | -            ei_mask=layer_struct.get("ei_mask"),
 81 |     | -            sparsity_mask=layer_struct.get("sparsity_mask"),
 82 |     | -            plasticity_mask=layer_struct.get("plasticity_mask"),
 83 |     | -        )
 84 |     | -
 85 |     | -    def _check_layer(self):
 86 |     | -        """
 87 |     | -        Check if the layer is initialized properly
 88 |     | -        """
 89 |     | -        # TODO: Implement this
 90 |     | -        pass
 91 |     | -
 92 |     | -    # INIT TRAINABLE
 93 |     | -    # ======================================================================================
 94 |     | -    def _init_trainable(self):
 95 |     | -        # Enfore constraints
 96 |     | -        self._init_constraints()
 97 |     | -        # Convert weight and bias to learnable parameters
 98 |     | -        self.weight = nn.Parameter(
 99 |     | -            self.weight, requires_grad=self.weight_dist is not None
100 |     | -        )
101 |     | -        self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)
102 |     | -
103 |     | -    def _init_constraints(self):
104 |     | -        """
105 |     | -        Initialize constraints
106 |     | -        It will also balance excitatory and inhibitory neurons
107 |     | -        """
108 |     | -        if self.sparsity_mask is not None:
109 |     | -
110 |     | -            self.weight *= self.sparsity_mask
111 |     | -        if self.ei_mask is not None:
112 |     | -            # Apply Dale's law
113 |     | -            self.weight[self.ei_mask == 1] = torch.clamp(
114 |     | -                self.weight[self.ei_mask == 1], min=0
115 |     | -            )  # For excitatory neurons, set negative weights to 0
116 |     | -            self.weight[self.ei_mask == -1] = torch.clamp(
117 |     | -                self.weight[self.ei_mask == -1], max=0
118 |     | -            )  # For inhibitory neurons, set positive weights to 0
119 |     | -
120 |     | -            # Balance excitatory and inhibitory neurons weight magnitudes
121 |     | -            self._balance_excitatory_inhibitory()
122 |     | -
123 |     | -    def _generate_bias(self, bias_init):
124 |     | -        """Generate random bias"""
125 |     | -        if bias_init == "uniform":
126 |     | -            # If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
127 |     | -            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
128 |     | -            b = torch.rand(self.output_dim) * sqrt_k
129 |     | -            b = b * 2 - sqrt_k
130 |     | -        elif bias_init == "normal":
131 |     | -            b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim))
132 |     | -        elif bias_init == "zero" or bias_init == None:
133 |     | -            b = torch.zeros(self.output_dim)
134 |     | -        elif type(bias_init) == np.ndarray:
135 |     | -            b = torch.from_numpy(bias_init)
136 |     | -        else:
137 |     | -            raise NotImplementedError
138 |     | -        return b.float()
139 |     | -
140 |     | -    def _generate_weight(self, weight_init):
141 |     | -        """Generate random weight"""
142 |     | -        if weight_init == "uniform":
143 |     | -            # If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
144 |     | -            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
145 |     | -            w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
146 |     | -            w = w * 2 - sqrt_k
147 |     | -        elif weight_init == "normal":
148 |     | -            w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
149 |     | -                torch.tensor(self.input_dim)
150 |     | -            )
151 |     | -        elif weight_init == "zero":
152 |     | -            w = torch.zeros((self.output_dim, self.input_dim))
153 |     | -        elif type(weight_init) == np.ndarray:
154 |     | -            w = torch.from_numpy(weight_init)
155 |     | -        else:
156 |     | -            raise NotImplementedError
157 |     | -        return w.float()
158 |     | -
159 |     | -    def _balance_excitatory_inhibitory(self):
160 |     | -        """Balance excitatory and inhibitory weights"""
161 |     | -        scale_mat = torch.ones_like(self.weight)
162 |     | -        ext_sum = self.weight[self.sparsity_mask == 1].sum()
163 |     | -        inh_sum = self.weight[self.sparsity_mask == -1].sum()
164 |     | -        if ext_sum == 0 or inh_sum == 0:
165 |     | -            # Automatically stop balancing if one of the sums is 0
166 |     | -            # devide by 10 to avoid recurrent explosion/decay
167 |     | -            self.weight /= 10
168 |     | -        else:
169 |     | -            if ext_sum > abs(inh_sum):
170 |     | -                _scale = abs(inh_sum).item() / ext_sum.item()
171 |     | -                scale_mat[self.sparsity_mask == 1] = _scale
172 |     | -            elif ext_sum < abs(inh_sum):
173 |     | -                _scale = ext_sum.item() / abs(inh_sum).item()
174 |     | -                scale_mat[self.sparsity_mask == -1] = _scale
175 |     | -            # Apply scaling
176 |     | -            self.weight *= scale_mat
177 |     | -
178 |     | -    # TRAINING
179 |     | -    # ======================================================================================
180 |     | -    def to(self, device):
181 |     | -        """Move the network to the device (cpu/gpu)"""
182 |     | -        super().to(device)
183 |     | -        if self.sparsity_mask is not None:
184 |     | -            self.sparsity_mask = self.sparsity_mask.to(device)
185 |     | -        if self.ei_mask is not None:
186 |     | -            self.ei_mask = self.ei_mask.to(device)
187 |     | -        if self.bias.requires_grad:
188 |     | -            self.bias = self.bias.to(device)
189 |     | -        return self
190 |     | -
191 |     | -    def forward(self, x):
192 |     | -        """
193 |     | -        Forwardly update network
194 |     | -
195 |     | -        Inputs:
196 |     | -            - x: input, shape: (batch_size, input_dim)
197 |     | -
198 |     | -        Returns:
199 |     | -            - state: shape: (batch_size, hidden_size)
200 |     | -        """
201 |     | -        return x.float() @ self.weight.T + self.bias
202 |     | -
203 |     | -    def apply_plasticity(self):
204 |     | -        """
205 |     | -        Apply plasticity mask to the weight gradient
206 |     | -        """
207 |     | -        with torch.no_grad():
208 |     | -            # assume the plasticity mask are all valid and being checked in ctrnn class
209 |     | -            for scale in self.plasticity_scales:
210 |     | -                if self.weight.grad is not None:
211 |     | -                    self.weight.grad[self.plasticity_mask == scale] *= scale
212 |     | -                else:
213 |     | -                    raise RuntimeError(
214 |     | -                        "Weight gradient is None, possibly because the forward loop is non-differentiable"
215 |     | -                    )
216 |     | -
217 |     | -    def freeze(self):
218 |     | -        """Freeze the layer"""
219 |     | -        self.weight.requires_grad = False
220 |     | -        self.bias.requires_grad = False
221 |     | -
222 |     | -    def unfreeze(self):
223 |     | -        """Unfreeze the layer"""
224 |     | -        self.weight.requires_grad = True
225 |     | -        self.bias.requires_grad = True
226 |     | -
227 |     | -    # CONSTRAINTS
228 |     | -    # ======================================================================================
229 |     | -    def enforce_constraints(self):
230 |     | -        """
231 |     | -        Enforce constraints
232 |     | -
233 |     | -        The constraints are:
234 |     | -            - sparsity_mask: mask for sparse connectivity
235 |     | -            - ei_mask: mask for Dale's law
236 |     | -        """
237 |     | -        if self.sparsity_mask is not None:
238 |     | -            self._enforce_sparsity()
239 |     | -        if self.ei_mask is not None:
240 |     | -            self._enforce_ei()
241 |     | -
242 |     | -    def _enforce_sparsity(self):
243 |     | -        """Enforce sparsity"""
244 |     | -        w = self.weight.detach().clone() * self.sparsity_mask
245 |     | -        self.weight.data.copy_(torch.nn.Parameter(w))
246 |     | -
247 |     | -    def _enforce_ei(self):
248 |     | -        """Enforce Dale's law"""
249 |     | -        w = self.weight.detach().clone()
250 |     | -        w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
251 |     | -        w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
252 |     | -        self.weight.data.copy_(torch.nn.Parameter(w))
253 |     | -
254 |     | -    # HELPER FUNCTIONS
255 |     | -    # ======================================================================================
256 |     | -    def set_weight(self, weight):
257 |     | -        """Set the value of weight"""
258 |     | -        assert (
259 |     | -            weight.shape == self.weight.shape
260 |     | -        ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
261 |     | -        with torch.no_grad():
262 |     | -            self.weight.copy_(weight)
263 |     | -
264 |     | -    def plot_layer(self):
265 |     | -        """Plot the weights matrix and distribution of each layer"""
266 |     | -        weight = (
267 |     | -            self.weight.cpu()
268 |     | -            if self.weight.device != torch.device("cpu")
269 |     | -            else self.weight
270 |     | -        )
271 |     | -        utils.plot_connectivity_matrix_dist(
272 |     | -            weight.detach().numpy(),
273 |     | -            f"Weight",
274 |     | -            False,
275 |     | -            self.sparsity_mask is not None,
276 |     | -        )
277 |  14 |
278 |  15 |      def get_specs(self):
279 |     | -        """Print the specs of each layer"""
280 |     | -        return {
281 |     | -            "input_dim": self.input_dim,
282 |     | -            "output_dim": self.output_dim,
283 |     | -            "weight_learnable": self.weight.requires_grad,
284 |     | -            "weight_min": self.weight.min().item(),
285 |     | -            "weight_max": self.weight.max().item(),
286 |     | -            "bias_learnable": self.bias.requires_grad,
287 |     | -            "bias_min": self.bias.min().item(),
288 |     | -            "bias_max": self.bias.max().item(),
289 |     | -            "sparsity": (
290 |     | -                self.sparsity_mask.sum() / self.sparsity_mask.numel()
291 |     | -                if self.sparsity_mask is not None
292 |     | -                else 1
293 |     | -            )
294 |     | -        }
    |  16 | +        pass
295 |  17 |
296 |  18 |      def print_layer(self):
297 |  19 |          """
298 |  20 |          Print the specs of the layer
299 |  21 |          """
300 |     | -        utils.print_dict("Layer Specs", self.get_specs())
    |  22 | +        utils.print_dict(f"{self.__class__.__name__} layer", self.get_specs())
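After this change, BaseLayer is a thin abstract base: its constructor only calls super().__init__(), get_specs() is left for subclasses to implement, and print_layer() reports whatever dict a subclass returns under that subclass's name. The sketch below shows what a conforming subclass could look like. It is an illustrative assumption, not part of this commit: the class name DenseLayer, its attributes, and the uniform init are invented here, and the code only relies on torch plus BaseLayer and utils as defined in this file.

import torch
import torch.nn as nn


class DenseLayer(BaseLayer):
    """Hypothetical subclass sketch: owns its own dims, weight, and bias."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Simple uniform init; weight/bias generation no longer lives in BaseLayer
        self.weight = nn.Parameter(torch.rand(output_dim, input_dim) * 0.2 - 0.1)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        # Same affine map the removed BaseLayer.forward computed
        return x.float() @ self.weight.T + self.bias

    def get_specs(self):
        # This dict is what BaseLayer.print_layer() passes to utils.print_dict(...)
        return {
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "weight_learnable": self.weight.requires_grad,
            "bias_learnable": self.bias.requires_grad,
        }

With this sketch, DenseLayer(3, 5).print_layer() would print the specs dict under the heading "DenseLayer layer", matching the new utils.print_dict call in the diff.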