Commit 2cb5bcb

Merge pull request #326 from IkemOkoh/patch-1

Update leakyparallel.py

2 parents e0968b2 + 2085cbc

File tree

1 file changed: +36 -19 lines changed

snntorch/_neurons/leakyparallel.py

Lines changed: 36 additions & 19 deletions
@@ -24,13 +24,24 @@ class LeakyParallel(nn.Module):
 
     Several differences between `LeakyParallel` and `Leaky` include:
 
-    * Negative hidden states are clipped due to the forced ReLU operation in RNN
-    * Linear weights are included in addition to recurrent weights
-    * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise
-    * There is no explicit reset mechanism
-    * Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel`
-    * Only the output spike is returned. Membrane potential is not accessible by default
-    * RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' the membrane potential between LIF neurons, and so the hidden matrix is forced to a diagonal matrix by default. This can be disabled by setting `weight_hh_enable=True`.
+    * Negative hidden states are clipped due to the
+      forced ReLU operation in RNN.
+    * Linear weights are included in addition to
+      recurrent weights.
+    * `beta` is clipped between [0,1] and cloned to
+      `weight_hh_l` only upon layer initialization.
+      It is unused otherwise.
+    * There is no explicit reset mechanism.
+    * Several functions such as `init_hidden`, `output`,
+      `inhibition`, and `state_quant` are unavailable
+      in `LeakyParallel`.
+    * Only the output spike is returned. Membrane potential
+      is not accessible by default.
+    * RNN uses a hidden matrix of size (num_hidden, num_hidden)
+      to transform the hidden state vector. This would 'leak'
+      the membrane potential between LIF neurons, and so the
+      hidden matrix is forced to a diagonal matrix by default.
+      This can be disabled by setting `weight_hh_enable=True`.
 
     Example::
 
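The bullet list above describes behavior rather than showing it. Here is a minimal sketch of those points in practice (not part of the commit; it assumes `snntorch.LeakyParallel` is constructed with `input_size`, `hidden_size`, and `beta` as in the surrounding docstring, and the sizes below are arbitrary):

    import torch
    import snntorch as snn

    seq_len, batch, num_inputs, num_hidden = 10, 4, 8, 16

    # beta is clipped to [0, 1] and copied into the recurrent weights
    # once, at layer initialization; it is unused afterwards.
    lif = snn.LeakyParallel(input_size=num_inputs,
                            hidden_size=num_hidden,
                            beta=0.9)

    x = torch.rand(seq_len, batch, num_inputs)  # (L, N, H_in)
    spk = lif(x)                                # (L, N, H_out)

    # Only spikes come back: there is no init_hidden(), no explicit
    # reset mechanism, and the membrane potential is not accessible
    # by default.
    print(spk.shape)  # torch.Size([10, 4, 16])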
@@ -117,22 +128,28 @@ def forward(self, x):
 
     where:
 
-    `L = sequence length`
+    * **`L** = sequence length`
 
-    `N = batch size`
+    * **`N** = batch size`
 
-    `H_{in} = input_size`
+    * **`H_{in}** = input_size`
 
-    `H_{out} = hidden_size`
+    * **`H_{out}** = hidden_size`
 
     Learnable Parameters:
-        - **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden weights of shape (hidden_size, input_size)
-        - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden weights of the k-th layer which are sampled from `beta` of shape (hidden_size, hidden_size)
-        - **bias_ih_l** - the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
-        - **bias_hh_l** - the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
-        - **threshold** (torch.Tensor) - optional learnable thresholds
-            must be manually passed in, of shape `1` or`` (input_size).
-        - **graded_spikes_factor** (torch.Tensor) - optional learnable graded spike factor
+        - **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden
+          weights of shape (hidden_size, input_size).
+        - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden
+          weights of the k-th layer which are sampled from `beta` of shape
+          (hidden_size, hidden_size).
+        - **bias_ih_l** - the learnable input-hidden bias of the k-th layer,
+          of shape (hidden_size).
+        - **bias_hh_l** - the learnable hidden-hidden bias of the k-th layer,
+          of shape (hidden_size).
+        - **threshold** (torch.Tensor) - optional learnable thresholds must be
+          manually passed in, of shape `1` or`` (input_size).
+        - **graded_spikes_factor** (torch.Tensor) - optional learnable graded
+          spike factor.
 
     """
 
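The parameter shapes listed above can be checked directly. A small sketch, assuming (as the docstring implies) that the layer wraps a `torch.nn.RNN` exposed as `.rnn`; note that PyTorch itself suffixes RNN parameter names with a layer index, so the docstring's `weight_hh_l` appears as `weight_hh_l0`:

    import torch
    import snntorch as snn

    lif = snn.LeakyParallel(input_size=8, hidden_size=16, beta=0.9)

    w_ih = lif.rnn.weight_ih_l0  # (hidden_size, input_size)  -> (16, 8)
    w_hh = lif.rnn.weight_hh_l0  # (hidden_size, hidden_size) -> (16, 16)
    print(w_ih.shape, w_hh.shape)

    # With the default diagonal constraint on the hidden-hidden matrix,
    # the off-diagonal entries carry no cross-neuron leakage. (Whether
    # the stored tensor is zeroed or merely masked at forward time is
    # an implementation detail the docstring does not fix.)
    off_diag = w_hh - torch.diag(torch.diagonal(w_hh))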

@@ -303,4 +320,4 @@ def _threshold_buffer(self, threshold, learn_threshold):
         if learn_threshold:
             self.threshold = nn.Parameter(threshold)
         else:
-            self.register_buffer("threshold", threshold)
+            self.register_buffer("threshold", threshold)

(The removed and added lines are textually identical; this hunk records a whitespace-only change, most likely a newline added at the end of the file.)
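The `_threshold_buffer` logic touched by this final hunk is a standard PyTorch pattern: a learnable threshold is wrapped in `nn.Parameter` so optimizers update it, while a fixed threshold is registered as a buffer so it still appears in `state_dict` and follows `.to(device)` without receiving gradient updates. A stripped-down sketch using a hypothetical demo class, not the snntorch code:

    import torch
    import torch.nn as nn

    class ThresholdDemo(nn.Module):
        def __init__(self, threshold=1.0, learn_threshold=False):
            super().__init__()
            threshold = torch.as_tensor(threshold, dtype=torch.float)
            if learn_threshold:
                # Tracked by optimizers; requires_grad=True.
                self.threshold = nn.Parameter(threshold)
            else:
                # Saved and moved with the module, but never optimized.
                self.register_buffer("threshold", threshold)

    print(list(ThresholdDemo(learn_threshold=True).parameters()))  # one Parameter
    print(list(ThresholdDemo(learn_threshold=False).buffers()))    # one plain tensor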
