@@ -13,26 +13,26 @@ def forward(self, **kwargs):
 
 
 class FiringRateLoss(CustomLoss):
-    def __init__(self, metric='l2', **kwargs):
+    def __init__(self, metric="l2", **kwargs):
         super().__init__(**kwargs)
-        assert metric in ['l1', 'l2'], "metric must be either l1 or l2"
+        assert metric in ["l1", "l2"], "metric must be either l1 or l2"
         self.metric = metric
 
     def forward(self, states, **kwargs):
         # Calculate the mean firing rate across specified dimensions
         mean_fr = torch.mean(states, dim=(0, 1))
 
         # Replace custom norm calculation with PyTorch's built-in norm
-        if self.metric == 'l1':
-            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
+        if self.metric == "l1":
+            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
         else:
-            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
+            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")
 
 
 class FiringRateDistLoss(CustomLoss):
-    def __init__(self, metric='sd', **kwargs):
+    def __init__(self, metric="sd", **kwargs):
         super().__init__(**kwargs)
-        valid_metrics = ['sd', 'cv', 'mean_ad', 'max_ad']
+        valid_metrics = ["sd", "cv", "mean_ad", "max_ad"]
         assert metric in valid_metrics, (
             "metric must be chosen from 'sd' (standard deviation), "
             "'cv' (coefficient of variation), 'mean_ad' (mean abs deviation), "
@@ -44,21 +44,21 @@ def forward(self, states, **kwargs):
         mean_fr = torch.mean(states, dim=(0, 1))
 
         # Standard deviation
-        if self.metric == 'sd':
+        if self.metric == "sd":
             return torch.std(mean_fr)
 
         # Coefficient of variation
-        elif self.metric == 'cv':
+        elif self.metric == "cv":
             return torch.std(mean_fr) / torch.mean(mean_fr)
 
         # Mean absolute deviation
-        elif self.metric == 'mean_ad':
+        elif self.metric == "mean_ad":
             avg_mean_fr = torch.mean(mean_fr)
             # Use F.l1_loss for mean absolute deviation
-            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction='mean')
+            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction="mean")
 
         # Maximum absolute deviation
-        elif self.metric == 'max_ad':
+        elif self.metric == "max_ad":
             avg_mean_fr = torch.mean(mean_fr)
             return torch.max(torch.abs(mean_fr - avg_mean_fr))
 
@@ -73,10 +73,12 @@ def forward(self, states, **kwargs):
         states = states.transpose(0, 1)
 
         # Ensure the sequence is long enough for the prediction window
-        assert states.shape[1] > self.tau, "The sequence length is shorter than the prediction window."
+        assert (
+            states.shape[1] > self.tau
+        ), "The sequence length is shorter than the prediction window."
 
         # Use MSE loss instead of manual difference calculation
-        return F.mse_loss(states[:-self.tau], states[self.tau:], reduction='mean')
+        return F.mse_loss(states[: -self.tau], states[self.tau :], reduction="mean")
 
 
 class HebbianLoss(nn.Module):
@@ -88,7 +90,7 @@ def forward(self, states, weights):
         # weights shape: (num_neurons, num_neurons)
 
         # Compute correlations by averaging over time steps
-        correlations = torch.einsum('bti,btj->btij', states, states)
+        correlations = torch.einsum("bti,btj->btij", states, states)
 
         # Apply weights to correlations and sum to get Hebbian loss for each batch
         hebbian_loss = torch.sum(weights * correlations, dim=(-1, -2))
@@ -114,8 +116,7 @@ def forward(self, states):
         prob_states = states / (states.sum(dim=-1, keepdim=True) + eps)
 
         # Compute the entropy of the neuron activations
-        entropy_loss = -torch.sum(prob_states *
-                                  torch.log(prob_states + eps), dim=-1)
+        entropy_loss = -torch.sum(prob_states * torch.log(prob_states + eps), dim=-1)
 
         # Take the mean entropy over batches and time steps
         mean_entropy = torch.mean(entropy_loss)
@@ -128,7 +129,7 @@ def forward(self, states):
 
 
 class PopulationKL(nn.Module):
-    def __init__(self, symmetric=True, reg=1e-3, reduction='mean'):
+    def __init__(self, symmetric=True, reg=1e-3, reduction="mean"):
         super().__init__()
         self.symmetric = symmetric
         self.reg = reg
@@ -140,45 +141,49 @@ def forward(self, states_0, states_1):
         mean_0 = torch.mean(states_0, dim=(0, 1), keepdim=True)
         # Shape: (1, 1, n_neurons)
         mean_1 = torch.mean(states_1, dim=(0, 1), keepdim=True)
-        var_0 = torch.var(states_0, dim=(0, 1), unbiased=False,
-                          keepdim=True)  # Shape: (1, 1, n_neurons)
-        var_1 = torch.var(states_1, dim=(0, 1), unbiased=False,
-                          keepdim=True)  # Shape: (1, 1, n_neurons)
+        var_0 = torch.var(
+            states_0, dim=(0, 1), unbiased=False, keepdim=True
+        )  # Shape: (1, 1, n_neurons)
+        var_1 = torch.var(
+            states_1, dim=(0, 1), unbiased=False, keepdim=True
+        )  # Shape: (1, 1, n_neurons)
 
         # Compute the KL divergence between the two populations (per neuron)
         # Shape: (1, 1, n_neurons)
-        kl_div = 0.5 * (torch.log(var_1 / var_0) +
-                        (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1)
+        kl_div = 0.5 * (
+            torch.log(var_1 / var_0) + (var_0 + (mean_0 - mean_1) ** 2) / var_1 - 1
+        )
 
         # Symmetric KL divergence: average the KL(P || Q) and KL(Q || P)
         if self.symmetric:
             # Shape: (1, 1, n_neurons)
-            reverse_kl_div = 0.5 * \
-                (torch.log(var_0 / var_1) +
-                 (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1)
+            reverse_kl_div = 0.5 * (
+                torch.log(var_0 / var_1) + (var_1 + (mean_1 - mean_0) ** 2) / var_0 - 1
+            )
             # Shape: (1, 1, n_neurons)
             kl_div = 0.5 * (kl_div + reverse_kl_div)
 
         # Apply reduction based on the reduction method
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             kl_loss = torch.mean(kl_div)  # Scalar value
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             kl_loss = torch.sum(kl_div)  # Scalar value
-        elif self.reduction == 'none':
+        elif self.reduction == "none":
             kl_loss = kl_div  # Shape: (1, 1, n_neurons)
         else:
             raise ValueError(f"Invalid reduction mode: {self.reduction}")
 
         # Regularization: L2 norm of the states across the neurons
-        reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + \
-            torch.mean(torch.norm(states_1, dim=-1) ** 2)
+        reg_loss = torch.mean(torch.norm(states_0, dim=-1) ** 2) + torch.mean(
+            torch.norm(states_1, dim=-1) ** 2
+        )
 
         # Combine the KL divergence with the regularization term
-        if self.reduction == 'none':
+        if self.reduction == "none":
             # If no reduction, add regularization element-wise
-            total_loss = kl_loss + self.reg * \
-                (torch.norm(states_0, dim=-1) ** 2 +
-                 torch.norm(states_1, dim=-1) ** 2)
+            total_loss = kl_loss + self.reg * (
+                torch.norm(states_0, dim=-1) ** 2 + torch.norm(states_1, dim=-1) ** 2
+            )
 
         else:
             total_loss = kl_loss + self.reg * reg_loss