
Commit 71024ac

complete modularized rnn
1 parent 6d02eff commit 71024ac

18 files changed: 423 additions, 394 deletions

nn4n/criterion/firing_rate_loss.py

Lines changed: 3 additions & 16 deletions
@@ -2,17 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-
-class CustomLoss(nn.Module):
-    def __init__(self, batch_first=True):
-        super().__init__()
-        self.batch_first = batch_first
-
-    def forward(self, **kwargs):
-        pass
-
-
-class FiringRateLoss(CustomLoss):
+class FiringRateLoss(nn.Module):
     def __init__(self, metric="l2", **kwargs):
         super().__init__(**kwargs)
         assert metric in ["l1", "l2"], "metric must be either l1 or l2"
@@ -29,7 +19,7 @@ def forward(self, states, **kwargs):
         return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
 
 
-class FiringRateDistLoss(CustomLoss):
+class FiringRateDistLoss(nn.Module):
     def __init__(self, metric="sd", **kwargs):
         super().__init__(**kwargs)
         valid_metrics = ["sd", "cv", "mean_ad", "max_ad"]
@@ -63,15 +53,12 @@ def forward(self, states, **kwargs):
         return torch.max(torch.abs(mean_fr - avg_mean_fr))
 
 
-class StatePredictionLoss(CustomLoss):
+class StatePredictionLoss(nn.Module):
     def __init__(self, tau=1, **kwargs):
         super().__init__(**kwargs)
         self.tau = tau
 
     def forward(self, states, **kwargs):
-        if not self.batch_first:
-            states = states.transpose(0, 1)
-
         # Ensure the sequence is long enough for the prediction window
         assert (
             states.shape[1] > self.tau
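
Note that the loss classes now subclass nn.Module directly and no longer transpose their input, so hidden states are expected batch-first. A minimal usage sketch (the tensor shapes and the direct module import path are illustrative, not taken from this commit):

import torch
from nn4n.criterion.firing_rate_loss import FiringRateLoss, FiringRateDistLoss

# Hidden states assumed batch-first: (batch_size, n_timesteps, hidden_size)
states = torch.randn(8, 100, 64, requires_grad=True)

fr_loss = FiringRateLoss(metric="l2")       # penalize overall mean firing rate
fr_dist = FiringRateDistLoss(metric="sd")   # penalize uneven firing across units

total = fr_loss(states) + fr_dist(states)
total.backward()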

nn4n/criterion/rnn_loss.py

Lines changed: 0 additions & 7 deletions
@@ -35,7 +35,6 @@ class RNNLoss(nn.Module):
     def __init__(self, model, **kwargs):
         super().__init__()
         self.model = model
-        self.batch_first = model.batch_first
         if type(self.model) != CTRNN:
             raise TypeError("model must be CTRNN")
         self._init_losses(**kwargs)
@@ -103,8 +102,6 @@ def _loss_fr(self, states, **kwargs):
         This compute the L2 norm (for now) of the hidden states across all timesteps and batch_size
         Then take the square of the mean of the norm
         """
-        if not self.batch_first:
-            states = states.transpose(0, 1)
         mean_fr = torch.mean(states, dim=(0, 1))
         # return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct
         # return torch.norm(states, p='fro')**2/states.numel() # this might not be correct
@@ -119,8 +116,6 @@ def _loss_fr_sd(self, states, **kwargs):
         Parameters:
             - states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
         """
-        if not self.batch_first:
-            states = states.transpose(0, 1)
         avg_fr = torch.mean(states, dim=(0, 1))
         return avg_fr.std()
 
@@ -133,8 +128,6 @@ def _loss_fr_cv(self, states, **kwargs):
         Parameters:
             - states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
         """
-        if not self.batch_first:
-            states = states.transpose(0, 1)
         avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
         return avg_fr.std() / avg_fr.mean()
 
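With the batch_first flag removed from RNNLoss, every _loss_fr* helper now assumes states of shape (batch_size, n_timesteps, hidden_size). A rough standalone sketch of what the remaining helpers compute (plain tensors here for illustration; the real class wraps a CTRNN model):

import torch

states = torch.randn(8, 100, 64)           # batch-first hidden states

mean_fr = torch.mean(states, dim=(0, 1))   # per-unit mean activity
loss_fr_sd = mean_fr.std()                 # what _loss_fr_sd returns

avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
loss_fr_cv = avg_fr.std() / avg_fr.mean()  # what _loss_fr_cv returns
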
nn4n/layer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .linear_layer import LinearLayer
 from .hidden_layer import HiddenLayer
 from .recurrent_layer import RecurrentLayer
-from .rnn import RNN
+from .rnn import RNN
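
The layer package thus re-exports the new RNN container alongside the existing layers, so downstream code can import everything from one place; for example (assuming the package layout shown in this commit):

from nn4n.layer import LinearLayer, HiddenLayer, RecurrentLayer, RNN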

nn4n/layer/base_layer.py

Lines changed: 4 additions & 282 deletions
@@ -6,295 +6,17 @@
 
 class BaseLayer(nn.Module):
     """
-    Linear Layer with optional sparsity, excitatory/inhibitory, and plasticity constraints.
-    The layer is initialized by passing specs in layer_struct.
-
-    Required keywords in layer_struct:
-        - input_dim: dimension of input
-        - output_dim: dimension of output
-        - weight: weight matrix init method/init weight matrix, default: 'uniform'
-        - bias: bias vector init method/init bias vector, default: 'uniform'
-        - sparsity_mask: mask for sparse connectivity
-        - ei_mask: mask for Dale's law
-        - plasticity_mask: mask for plasticity
+    nn4n Layer class
     """
 
-    def __init__(
-        self,
-        input_dim: int,
-        output_dim: int,
-        weight: str = "uniform",
-        bias: str = "uniform",
-        ei_mask: torch.Tensor = None,
-        sparsity_mask: torch.Tensor = None,
-        plasticity_mask: torch.Tensor = None,
-    ):
+    def __init__(self):
         super().__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.weight_dist = weight
-        self.bias_dist = bias
-        self.weight = self._generate_weight(self.weight_dist)
-        self.bias = self._generate_bias(self.bias_dist)
-        self.ei_mask = ei_mask.T if ei_mask is not None else None
-        self.sparsity_mask = sparsity_mask.T if sparsity_mask is not None else None
-        self.plasticity_mask = (
-            plasticity_mask.T if plasticity_mask is not None else None
-        )
-        # All unique plasticity values in the plasticity mask
-        self.plasticity_scales = (
-            torch.unique(self.plasticity_mask)
-            if self.plasticity_mask is not None
-            else None
-        )
-
-        self._init_trainable()
-        self._check_layer()
-
-    # INITIALIZATION
-    # ======================================================================================
-    @staticmethod
-    def _check_keys(layer_struct):
-        required_keys = ["input_dim", "output_dim"]
-        for key in required_keys:
-            if key not in layer_struct:
-                raise ValueError(f"Key '{key}' is missing in layer_struct")
-
-        valid_keys = ["input_dim", "output_dim", "weight", "bias", "ei_mask", "sparsity_mask", "plasticity_mask"]
-        for key in layer_struct.keys():
-            if key not in valid_keys:
-                raise ValueError(f"Key '{key}' is not a valid key in layer_struct")
-
-    @classmethod
-    def from_dict(cls, layer_struct):
-        """
-        Alternative constructor to initialize LinearLayer from a dictionary.
-        """
-        # Create an instance using the dictionary values
-        cls._check_keys(layer_struct)
-        return cls(
-            input_dim=layer_struct["input_dim"],
-            output_dim=layer_struct["output_dim"],
-            weight=layer_struct.get("weight", "uniform"),
-            bias=layer_struct.get("bias", "uniform"),
-            ei_mask=layer_struct.get("ei_mask"),
-            sparsity_mask=layer_struct.get("sparsity_mask"),
-            plasticity_mask=layer_struct.get("plasticity_mask"),
-        )
-
-    def _check_layer(self):
-        """
-        Check if the layer is initialized properly
-        """
-        # TODO: Implement this
-        pass
-
-    # INIT TRAINABLE
-    # ======================================================================================
-    def _init_trainable(self):
-        # Enfore constraints
-        self._init_constraints()
-        # Convert weight and bias to learnable parameters
-        self.weight = nn.Parameter(
-            self.weight, requires_grad=self.weight_dist is not None
-        )
-        self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)
-
-    def _init_constraints(self):
-        """
-        Initialize constraints
-        It will also balance excitatory and inhibitory neurons
-        """
-        if self.sparsity_mask is not None:
-
-            self.weight *= self.sparsity_mask
-        if self.ei_mask is not None:
-            # Apply Dale's law
-            self.weight[self.ei_mask == 1] = torch.clamp(
-                self.weight[self.ei_mask == 1], min=0
-            )  # For excitatory neurons, set negative weights to 0
-            self.weight[self.ei_mask == -1] = torch.clamp(
-                self.weight[self.ei_mask == -1], max=0
-            )  # For inhibitory neurons, set positive weights to 0
-
-            # Balance excitatory and inhibitory neurons weight magnitudes
-            self._balance_excitatory_inhibitory()
-
-    def _generate_bias(self, bias_init):
-        """Generate random bias"""
-        if bias_init == "uniform":
-            # If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
-            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
-            b = torch.rand(self.output_dim) * sqrt_k
-            b = b * 2 - sqrt_k
-        elif bias_init == "normal":
-            b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim))
-        elif bias_init == "zero" or bias_init == None:
-            b = torch.zeros(self.output_dim)
-        elif type(bias_init) == np.ndarray:
-            b = torch.from_numpy(bias_init)
-        else:
-            raise NotImplementedError
-        return b.float()
-
-    def _generate_weight(self, weight_init):
-        """Generate random weight"""
-        if weight_init == "uniform":
-            # If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
-            sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
-            w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
-            w = w * 2 - sqrt_k
-        elif weight_init == "normal":
-            w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
-                torch.tensor(self.input_dim)
-            )
-        elif weight_init == "zero":
-            w = torch.zeros((self.output_dim, self.input_dim))
-        elif type(weight_init) == np.ndarray:
-            w = torch.from_numpy(weight_init)
-        else:
-            raise NotImplementedError
-        return w.float()
-
-    def _balance_excitatory_inhibitory(self):
-        """Balance excitatory and inhibitory weights"""
-        scale_mat = torch.ones_like(self.weight)
-        ext_sum = self.weight[self.sparsity_mask == 1].sum()
-        inh_sum = self.weight[self.sparsity_mask == -1].sum()
-        if ext_sum == 0 or inh_sum == 0:
-            # Automatically stop balancing if one of the sums is 0
-            # devide by 10 to avoid recurrent explosion/decay
-            self.weight /= 10
-        else:
-            if ext_sum > abs(inh_sum):
-                _scale = abs(inh_sum).item() / ext_sum.item()
-                scale_mat[self.sparsity_mask == 1] = _scale
-            elif ext_sum < abs(inh_sum):
-                _scale = ext_sum.item() / abs(inh_sum).item()
-                scale_mat[self.sparsity_mask == -1] = _scale
-            # Apply scaling
-            self.weight *= scale_mat
-
-    # TRAINING
-    # ======================================================================================
-    def to(self, device):
-        """Move the network to the device (cpu/gpu)"""
-        super().to(device)
-        if self.sparsity_mask is not None:
-            self.sparsity_mask = self.sparsity_mask.to(device)
-        if self.ei_mask is not None:
-            self.ei_mask = self.ei_mask.to(device)
-        if self.bias.requires_grad:
-            self.bias = self.bias.to(device)
-        return self
-
-    def forward(self, x):
-        """
-        Forwardly update network
-
-        Inputs:
-            - x: input, shape: (batch_size, input_dim)
-
-        Returns:
-            - state: shape: (batch_size, hidden_size)
-        """
-        return x.float() @ self.weight.T + self.bias
-
-    def apply_plasticity(self):
-        """
-        Apply plasticity mask to the weight gradient
-        """
-        with torch.no_grad():
-            # assume the plasticity mask are all valid and being checked in ctrnn class
-            for scale in self.plasticity_scales:
-                if self.weight.grad is not None:
-                    self.weight.grad[self.plasticity_mask == scale] *= scale
-                else:
-                    raise RuntimeError(
-                        "Weight gradient is None, possibly because the forward loop is non-differentiable"
-                    )
-
-    def freeze(self):
-        """Freeze the layer"""
-        self.weight.requires_grad = False
-        self.bias.requires_grad = False
-
-    def unfreeze(self):
-        """Unfreeze the layer"""
-        self.weight.requires_grad = True
-        self.bias.requires_grad = True
-
-    # CONSTRAINTS
-    # ======================================================================================
-    def enforce_constraints(self):
-        """
-        Enforce constraints
-
-        The constraints are:
-            - sparsity_mask: mask for sparse connectivity
-            - ei_mask: mask for Dale's law
-        """
-        if self.sparsity_mask is not None:
-            self._enforce_sparsity()
-        if self.ei_mask is not None:
-            self._enforce_ei()
-
-    def _enforce_sparsity(self):
-        """Enforce sparsity"""
-        w = self.weight.detach().clone() * self.sparsity_mask
-        self.weight.data.copy_(torch.nn.Parameter(w))
-
-    def _enforce_ei(self):
-        """Enforce Dale's law"""
-        w = self.weight.detach().clone()
-        w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
-        w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
-        self.weight.data.copy_(torch.nn.Parameter(w))
-
-    # HELPER FUNCTIONS
-    # ======================================================================================
-    def set_weight(self, weight):
-        """Set the value of weight"""
-        assert (
-            weight.shape == self.weight.shape
-        ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
-        with torch.no_grad():
-            self.weight.copy_(weight)
-
-    def plot_layer(self):
-        """Plot the weights matrix and distribution of each layer"""
-        weight = (
-            self.weight.cpu()
-            if self.weight.device != torch.device("cpu")
-            else self.weight
-        )
-        utils.plot_connectivity_matrix_dist(
-            weight.detach().numpy(),
-            f"Weight",
-            False,
-            self.sparsity_mask is not None,
-        )
 
     def get_specs(self):
-        """Print the specs of each layer"""
-        return {
-            "input_dim": self.input_dim,
-            "output_dim": self.output_dim,
-            "weight_learnable": self.weight.requires_grad,
-            "weight_min": self.weight.min().item(),
-            "weight_max": self.weight.max().item(),
-            "bias_learnable": self.bias.requires_grad,
-            "bias_min": self.bias.min().item(),
-            "bias_max": self.bias.max().item(),
-            "sparsity": (
-                self.sparsity_mask.sum() / self.sparsity_mask.numel()
-                if self.sparsity_mask is not None
-                else 1
-            )
-        }
+        pass
 
     def print_layer(self):
         """
         Print the specs of the layer
         """
-        utils.print_dict("Layer Specs", self.get_specs())
+        utils.print_dict(f"{self.__class__.__name__} layer", self.get_specs())
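
BaseLayer is now a thin stub: get_specs() defaults to pass and print_layer() simply forwards whatever the subclass reports to utils.print_dict. A hypothetical minimal subclass, only to illustrate that contract (the real LinearLayer/HiddenLayer implementations live in their own files in this commit):

import torch
import torch.nn as nn

from nn4n.layer.base_layer import BaseLayer


class ToyLayer(BaseLayer):
    """Hypothetical layer showing the slimmed-down BaseLayer contract."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim) * 0.1)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        return x @ self.weight.T + self.bias

    def get_specs(self):
        # Consumed by BaseLayer.print_layer via utils.print_dict
        return {
            "input_dim": self.weight.shape[1],
            "output_dim": self.weight.shape[0],
            "weight_learnable": self.weight.requires_grad,
        }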
