Commit 1d844a7

refactored layers
1 parent b258dc0 commit 1d844a7

28 files changed: +974, -1721 lines changed

examples/CTRNN.ipynb

Lines changed: 0 additions & 696 deletions
This file was deleted.

examples/MultiArea.ipynb

Lines changed: 0 additions & 393 deletions
This file was deleted.

nn4n/criterion/composite_loss.py

Lines changed: 15 additions & 14 deletions

@@ -22,21 +22,21 @@ def __init__(self, loss_cfg):
 
         # Mapping of loss types to their respective classes or instances
         loss_types = {
-            'fr': FiringRateLoss,
-            'fr_dist': FiringRateDistLoss,
-            'rnn_conn': RNNConnectivityLoss,
-            'state_pred': StatePredictionLoss,
-            'entropy': EntropyLoss,
-            'mse': nn.MSELoss,
-            'hebbian': HebbianLoss,
+            "fr": FiringRateLoss,
+            "fr_dist": FiringRateDistLoss,
+            "rnn_conn": RNNConnectivityLoss,
+            "state_pred": StatePredictionLoss,
+            "entropy": EntropyLoss,
+            "mse": nn.MSELoss,
+            "hebbian": HebbianLoss,
         }
-        torch_losses = ['mse']
+        torch_losses = ["mse"]
 
         # Iterate over the loss_cfg to instantiate and store losses
        for loss_name, loss_spec in loss_cfg.items():
-            loss_type = loss_spec['type']
-            loss_params = loss_spec.get('params', {})
-            loss_lambda = loss_spec.get('lambda', 1.0)
+            loss_type = loss_spec["type"]
+            loss_params = loss_spec.get("params", {})
+            loss_lambda = loss_spec.get("lambda", 1.0)
 
             # Instantiate the loss function
             if loss_type in loss_types:
@@ -51,7 +51,9 @@ def __init__(self, loss_cfg):
                 # Store the loss instance and its weight in a dictionary
                 self.loss_components[loss_name] = (loss_instance, loss_lambda)
             else:
-                raise ValueError(f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}")
+                raise ValueError(
+                    f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}"
+                )
 
     def forward(self, loss_input_dict):
         """
@@ -70,8 +72,7 @@ def forward(self, loss_input_dict):
             loss_inputs = loss_input_dict[loss_name]
             if isinstance(loss_fn, nn.MSELoss):
                 # For MSELoss, assume the inputs are 'input' and 'target'
-                loss_value = loss_fn(
-                    loss_inputs['input'], loss_inputs['target'])
+                loss_value = loss_fn(loss_inputs["input"], loss_inputs["target"])
             else:
                 loss_value = loss_fn(**loss_inputs)
             loss_dict[loss_name] = loss_weight * loss_value

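For orientation, the refactored constructor reads a per-loss spec with the keys "type", "params", and "lambda" shown above. A minimal usage sketch follows; the loss names and tensor values are made up, and the CompositeLoss class name, import path, and return value are assumptions inferred from the file name and this hunk, not confirmed by the diff:

import torch
from nn4n.criterion.composite_loss import CompositeLoss  # assumed import path

pred = torch.rand(10, 4, 2)     # placeholder network output
target = torch.rand(10, 4, 2)   # placeholder labels
states = torch.rand(10, 4, 64)  # placeholder hidden states

loss_cfg = {
    "task": {"type": "mse", "lambda": 1.0},
    "fr_reg": {"type": "fr", "params": {"metric": "l2"}, "lambda": 0.01},
}
criterion = CompositeLoss(loss_cfg)

# forward() takes one input dict per configured loss, keyed by loss name:
# nn.MSELoss entries receive "input"/"target", custom losses receive kwargs.
losses = criterion(
    {
        "task": {"input": pred, "target": target},
        "fr_reg": {"states": states},
    }
)
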
nn4n/criterion/connectivity_loss.py

Lines changed: 15 additions & 12 deletions

@@ -4,34 +4,37 @@
 
 
 class RNNConnectivityLoss(nn.Module):
-    def __init__(self, layer, metric='fro', **kwargs):
+    def __init__(self, layer, metric="fro", **kwargs):
         super().__init__()
-        assert metric in ['l1', 'fro'], "metric must be either l1 or l2"
+        assert metric in ["l1", "fro"], "metric must be either l1 or l2"
         self.metric = metric
         self.layer = layer
 
     def forward(self, model, **kwargs):
-        if self.layer == 'all':
+        if self.layer == "all":
             weights = [
                 model.recurrent_layer.input_layer.weight,
                 model.recurrent_layer.hidden_layer.weight,
-                model.readout_layer.weight
+                model.readout_layer.weight,
             ]
 
-            loss = torch.sum(torch.stack(
-                [self._compute_norm(weight) for weight in weights]))
+            loss = torch.sum(
+                torch.stack([self._compute_norm(weight) for weight in weights])
+            )
             return loss
-        elif self.layer == 'input':
+        elif self.layer == "input":
             return self._compute_norm(model.recurrent_layer.input_layer.weight)
-        elif self.layer == 'hidden':
+        elif self.layer == "hidden":
             return self._compute_norm(model.recurrent_layer.hidden_layer.weight)
-        elif self.layer == 'readout':
+        elif self.layer == "readout":
             return self._compute_norm(model.readout_layer.weight)
         else:
-            raise ValueError(f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'")
+            raise ValueError(
+                f"Invalid layer '{self.layer}'. Available layers are: 'all', 'input', 'hidden', 'readout'"
+            )
 
     def _compute_norm(self, weight):
-        if self.metric == 'l1':
+        if self.metric == "l1":
             return torch.norm(weight, p=1)
         else:
-            return torch.norm(weight, p='fro')
+            return torch.norm(weight, p="fro")

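A note on the two metrics in _compute_norm: with no dim argument, torch.norm flattens the matrix, so "l1" sums absolute entries (pushing weights toward sparsity) while the default "fro" takes the square root of the summed squares (penalizing overall magnitude). A quick standalone illustration, using a random stand-in for a layer weight:

import torch

W = torch.randn(64, 64)       # stand-in for a layer's weight matrix
l1 = torch.norm(W, p=1)       # sum of |w_ij|, the metric="l1" branch
fro = torch.norm(W, p="fro")  # sqrt of sum of w_ij**2, the default branch
assert torch.isclose(l1, W.abs().sum())
assert torch.isclose(fro, W.pow(2).sum().sqrt())
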
nn4n/criterion/firing_rate_loss.py

Lines changed: 41 additions & 36 deletions

@@ -13,26 +13,26 @@ def forward(self, **kwargs):
 
 
 class FiringRateLoss(CustomLoss):
-    def __init__(self, metric='l2', **kwargs):
+    def __init__(self, metric="l2", **kwargs):
         super().__init__(**kwargs)
-        assert metric in ['l1', 'l2'], "metric must be either l1 or l2"
+        assert metric in ["l1", "l2"], "metric must be either l1 or l2"
         self.metric = metric
 
     def forward(self, states, **kwargs):
         # Calculate the mean firing rate across specified dimensions
         mean_fr = torch.mean(states, dim=(0, 1))
 
         # Replace custom norm calculation with PyTorch's built-in norm
-        if self.metric == 'l1':
-            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
+        if self.metric == "l1":
+            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
         else:
-            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
+            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
 
 
 class FiringRateDistLoss(CustomLoss):
-    def __init__(self, metric='sd', **kwargs):
+    def __init__(self, metric="sd", **kwargs):
         super().__init__(**kwargs)
-        valid_metrics = ['sd', 'cv', 'mean_ad', 'max_ad']
+        valid_metrics = ["sd", "cv", "mean_ad", "max_ad"]
         assert metric in valid_metrics, (
             "metric must be chosen from 'sd' (standard deviation), "
             "'cv' (coefficient of variation), 'mean_ad' (mean abs deviation), "
@@ -44,21 +44,21 @@ def forward(self, states, **kwargs):
         mean_fr = torch.mean(states, dim=(0, 1))
 
         # Standard deviation
-        if self.metric == 'sd':
+        if self.metric == "sd":
             return torch.std(mean_fr)
 
         # Coefficient of variation
-        elif self.metric == 'cv':
+        elif self.metric == "cv":
             return torch.std(mean_fr) / torch.mean(mean_fr)
 
         # Mean absolute deviation
-        elif self.metric == 'mean_ad':
+        elif self.metric == "mean_ad":
             avg_mean_fr = torch.mean(mean_fr)
             # Use F.l1_loss for mean absolute deviation
-            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction='mean')
+            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction="mean")
 
         # Maximum absolute deviation
-        elif self.metric == 'max_ad':
+        elif self.metric == "max_ad":
             avg_mean_fr = torch.mean(mean_fr)
             return torch.max(torch.abs(mean_fr - avg_mean_fr))
 
@@ -73,10 +73,12 @@ def forward(self, states, **kwargs):
         states = states.transpose(0, 1)
 
         # Ensure the sequence is long enough for the prediction window
-        assert states.shape[1] > self.tau, "The sequence length is shorter than the prediction window."
+        assert (
+            states.shape[1] > self.tau
+        ), "The sequence length is shorter than the prediction window."
 
         # Use MSE loss instead of manual difference calculation
-        return F.mse_loss(states[:-self.tau], states[self.tau:], reduction='mean')
+        return F.mse_loss(states[: -self.tau], states[self.tau :], reduction="mean")
 
 
 class HebbianLoss(nn.Module):
@@ -88,7 +90,7 @@ def forward(self, states, weights):
         # weights shape: (num_neurons, num_neurons)
 
         # Compute correlations by averaging over time steps
-        correlations = torch.einsum('bti,btj->btij', states, states)
+        correlations = torch.einsum("bti,btj->btij", states, states)
 
         # Apply weights to correlations and sum to get Hebbian loss for each batch
         hebbian_loss = torch.sum(weights * correlations, dim=(-1, -2))
@@ -114,8 +116,7 @@ def forward(self, states):
         prob_states = states / (states.sum(dim=-1, keepdim=True) + eps)
 
         # Compute the entropy of the neuron activations
-        entropy_loss = -torch.sum(prob_states *
-                                  torch.log(prob_states + eps), dim=-1)
+        entropy_loss = -torch.sum(prob_states * torch.log(prob_states + eps), dim=-1)
 
         # Take the mean entropy over batches and time steps
         mean_entropy = torch.mean(entropy_loss)
@@ -128,7 +129,7 @@ def forward(self, states):
 
 
 class PopulationKL(nn.Module):
-    def __init__(self, symmetric=True, reg=1e-3, reduction='mean'):
+    def __init__(self, symmetric=True, reg=1e-3, reduction="mean"):
         super().__init__()
         self.symmetric = symmetric
         self.reg = reg
@@ -140,45 +141,49 @@ def forward(self, states_0, states_1):
         mean_0 = torch.mean(states_0, dim=(0, 1), keepdim=True)
         # Shape: (1, 1, n_neurons)
         mean_1 = torch.mean(states_1, dim=(0, 1), keepdim=True)
-        var_0 = torch.var(states_0, dim=(0, 1), unbiased=False,
-                          keepdim=True)  # Shape: (1, 1, n_neurons)
-        var_1 = torch.var(states_1, dim=(0, 1), unbiased=False,
-                          keepdim=True)  # Shape: (1, 1, n_neurons)
+        var_0 = torch.var(
+            states_0, dim=(0, 1), unbiased=False, keepdim=True
+        )  # Shape: (1, 1, n_neurons)
+        var_1 = torch.var(
+            states_1, dim=(0, 1), unbiased=False, keepdim=True
+        )  # Shape: (1, 1, n_neurons)
 
         # Compute the KL divergence between the two populations (per neuron)
         # Shape: (1, 1, n_neurons)
-        kl_div = 0.5 * (torch.log(var_1 / var_0) +
-                        (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1)
+        kl_div = 0.5 * (
+            torch.log(var_1 / var_0) + (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1
+        )
 
         # Symmetric KL divergence: average the KL(P || Q) and KL(Q || P)
         if self.symmetric:
             # Shape: (1, 1, n_neurons)
-            reverse_kl_div = 0.5 * \
-                (torch.log(var_0 / var_1) +
-                 (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1)
+            reverse_kl_div = 0.5 * (
+                torch.log(var_0 / var_1) + (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1
+            )
             # Shape: (1, 1, n_neurons)
             kl_div = 0.5 * (kl_div + reverse_kl_div)
 
         # Apply reduction based on the reduction method
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             kl_loss = torch.mean(kl_div)  # Scalar value
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             kl_loss = torch.sum(kl_div)  # Scalar value
-        elif self.reduction == 'none':
+        elif self.reduction == "none":
             kl_loss = kl_div  # Shape: (1, 1, n_neurons)
         else:
             raise ValueError(f"Invalid reduction mode: {self.reduction}")
 
         # Regularization: L2 norm of the states across the neurons
-        reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + \
-            torch.mean(torch.norm(states_1, dim=-1) ** 2)
+        reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + torch.mean(
+            torch.norm(states_1, dim=-1) ** 2
+        )
 
         # Combine the KL divergence with the regularization term
-        if self.reduction == 'none':
+        if self.reduction == "none":
             # If no reduction, add regularization element-wise
-            total_loss = kl_loss + self.reg * \
-                (torch.norm(states_0, dim=-1) ** 2 +
-                 torch.norm(states_1, dim=-1) ** 2)
+            total_loss = kl_loss + self.reg * (
+                torch.norm(states_0, dim=-1) ** 2 + torch.norm(states_1, dim=-1) ** 2
+            )
         else:
             total_loss = kl_loss + self.reg * reg_loss
 

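The reshaped kl_div in PopulationKL is the closed-form KL divergence between univariate Gaussians fitted per neuron over batch and time, P_0 = N(mu_0, sigma_0^2) and P_1 = N(mu_1, sigma_1^2); written out in LaTeX:

D_{\mathrm{KL}}(P_0 \parallel P_1) = \frac{1}{2}\left(\log\frac{\sigma_1^2}{\sigma_0^2} + \frac{\sigma_0^2 + (\mu_0 - \mu_1)^2}{\sigma_1^2} - 1\right)

With symmetric=True the code averages this with the reverse direction, 0.5 * (D_KL(P_0 || P_1) + D_KL(P_1 || P_0)), exactly as the reformatted reverse_kl_div branch shows.
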
nn4n/criterion/mlp_loss.py

Lines changed: 5 additions & 5 deletions

@@ -30,7 +30,7 @@ def _init_losses(self, **kwargs):
         self.loss_list = loss_list
 
     def _loss_fr(self, states, **kwargs):
-        """ Compute the loss for firing rate """
+        """Compute the loss for firing rate"""
         # return torch.sqrt(torch.square(states)).mean()
         loss = []
         for s in states:
@@ -39,7 +39,7 @@ def _loss_fr(self, states, **kwargs):
         return torch.stack(loss).mean()
 
     def _loss_fr_sd(self, states, **kwargs):
-        """ Compute the loss for firing rate for each neuron in terms of SD """
+        """Compute the loss for firing rate for each neuron in terms of SD"""
         # return torch.sqrt(torch.square(states)).mean(dim=(0)).std()
         return torch.pow(torch.mean(states, dim=(0, 1)), 2).std()
 
@@ -50,17 +50,17 @@ def forward(self, pred, label, **kwargs):
         @param label: size=(-1, batch_size, 2), labels
         @param dur: duration of the trial
         """
-        loss = [self.lambda_mse * torch.square(pred-label).mean()]
+        loss = [self.lambda_mse * torch.square(pred - label).mean()]
         for i in range(len(self.loss_list)):
             if self.lambda_list[i] == 0:
                 continue
             else:
-                loss.append(self.lambda_list[i]*self.loss_list[i](**kwargs))
+                loss.append(self.lambda_list[i] * self.loss_list[i](**kwargs))
         loss = torch.stack(loss)
         return loss.sum(), loss
 
     def to(self, device):
-        """ Move to device """
+        """Move to device"""
         super().to(device)
         self.lambda_list = self.lambda_list.to(device)
         return self

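A shape walk-through of the reformatted _loss_fr_sd, assuming (per the forward docstring's size=(-1, batch_size, 2) convention) that states is laid out as (seq_len, batch_size, n_neurons); the tensor here is a placeholder:

import torch

states = torch.rand(100, 32, 64)          # (seq_len, batch_size, n_neurons), assumed layout
mean_fr = torch.mean(states, dim=(0, 1))  # (n_neurons,): per-neuron average activation
penalty = torch.pow(mean_fr, 2).std()     # spread of squared mean rates across neurons
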
nn4n/criterion/rnn_loss.py

Lines changed: 24 additions & 15 deletions

@@ -24,7 +24,7 @@ class RNNLoss(nn.Module):
     - lambda_hid: coefficient for the hidden layer loss, default: 0
     - lambda_out: coefficient for the readout layer loss, default: 0
     - lambda_fr: coefficient for the overall firing rate loss, default: 0
-    - lambda_fr_sd: coefficient for the standard deviation of firing rate 
+    - lambda_fr_sd: coefficient for the standard deviation of firing rate
         loss (to evenly distribute firing rate across neurons), default: 0
     - lambda_fr_cv: coefficient for the coefficient of variation of firing
         rate loss (to evenly distribute firing rate across neurons), default: 0
@@ -72,21 +72,30 @@ def _init_losses(self, **kwargs):
         n_in = self.model.recurrent_layer.input_layer.weight.shape[1]
         n_size = self.model.recurrent_layer.hidden_layer.weight.shape[0]
         n_out = self.model.readout_layer.weight.shape[0]
-        self.n_in_dividend = n_in*n_size
-        self.n_hid_dividend = n_size*n_size
-        self.n_out_dividend = n_out*n_size
+        self.n_in_dividend = n_in * n_size
+        self.n_hid_dividend = n_size * n_size
+        self.n_out_dividend = n_out * n_size
 
     def _loss_in(self, **kwargs):
-        """ Compute the loss for InputLayer """
-        return torch.norm(self.model.recurrent_layer.input_layer.weight, p='fro')**2/self.n_in_dividend
+        """Compute the loss for InputLayer"""
+        return (
+            torch.norm(self.model.recurrent_layer.input_layer.weight, p="fro") ** 2
+            / self.n_in_dividend
+        )
 
     def _loss_hid(self, **kwargs):
-        """ Compute the loss for RecurrentLayer """
-        return torch.norm(self.model.recurrent_layer.hidden_layer.weight, p='fro')**2/self.n_hid_dividend
+        """Compute the loss for RecurrentLayer"""
+        return (
+            torch.norm(self.model.recurrent_layer.hidden_layer.weight, p="fro") ** 2
+            / self.n_hid_dividend
+        )
 
     def _loss_out(self, **kwargs):
-        """ Compute the loss for ReadoutLayer """
-        return torch.norm(self.model.readout_layer.weight, p='fro')**2/self.n_out_dividend
+        """Compute the loss for ReadoutLayer"""
+        return (
+            torch.norm(self.model.readout_layer.weight, p="fro") ** 2
+            / self.n_out_dividend
+        )
 
     def _loss_fr(self, states, **kwargs):
         """
@@ -99,10 +108,10 @@ def _loss_fr(self, states, **kwargs):
         mean_fr = torch.mean(states, dim=(0, 1))
         # return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct
         # return torch.norm(states, p='fro')**2/states.numel() # this might not be correct
-        return torch.norm(mean_fr, p=2)**2/mean_fr.numel()
+        return torch.norm(mean_fr, p=2) ** 2 / mean_fr.numel()
 
     def _loss_fr_sd(self, states, **kwargs):
-        """ 
+        """
         Compute the loss for firing rate for each neuron in terms of SD
         This will take the average firing rate of each neuron across all timesteps and batch_size
         and compute the standard deviation of the firing rate across all neurons
@@ -127,7 +136,7 @@ def _loss_fr_cv(self, states, **kwargs):
         if not self.batch_first:
             states = states.transpose(0, 1)
         avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
-        return avg_fr.std()/avg_fr.mean()
+        return avg_fr.std() / avg_fr.mean()
 
     def forward(self, pred, label, **kwargs):
         """
@@ -139,11 +148,11 @@ def forward(self, pred, label, **kwargs):
 
             where -1 is the sequence length
         """
-        loss = [self.lambda_mse * torch.square(pred-label).mean()]
+        loss = [self.lambda_mse * torch.square(pred - label).mean()]
         for i in range(len(self.loss_list)):
             if self.lambda_list[i] == 0:
                 continue
             else:
-                loss.append(self.lambda_list[i]*self.loss_list[i](**kwargs))
+                loss.append(self.lambda_list[i] * self.loss_list[i](**kwargs))
         loss = torch.stack(loss)
         return loss.sum(), loss

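The three reformatted weight penalties (_loss_in, _loss_hid, _loss_out) follow one pattern: squared Frobenius norm divided by the matrix's entry count, which is just the mean squared weight (note n_in_dividend equals the input weight's numel, since that weight has shape (n_size, n_in)). A small equivalence check with a stand-in matrix, not the model's actual weights:

import torch

W = torch.randn(64, 16)  # e.g. an input weight of shape (n_size, n_in)
penalty = torch.norm(W, p="fro") ** 2 / W.numel()
assert torch.isclose(penalty, W.pow(2).mean())  # same quantity, written directly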