Commit 1ff905f

add BlockRecurrent

1 parent b06af58

7 files changed: +185 additions, -125 deletions

nn4n/nn/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 from .linear_layer import LinearLayer
 from .leaky_linear_layer import LeakyLinearLayer
 from .recurrent_layer import RecurrentLayer
-from .rnn_layer import RNNLayer
-from .module import Module
+from .rnn import RNN, BlockRNN
+from .module import Module
+from .block_recurrent_layer import BlockRecurrentLayer

nn4n/nn/block_recurrent_layer.py

Lines changed: 107 additions & 97 deletions
@@ -1,146 +1,156 @@
 import torch
-from typing import List
-from .linear_layer import LinearLayer
+from typing import List, Tuple
+from .recurrent_layer import RecurrentLayer
 
-
-class BlockMatrix:
+class BlockMatrix(torch.nn.Module):
     def __init__(self, n_blocks: int):
-        """
-        Initializes a BlockMatrix.
-
-        Args:
-            n (int): Size of the matrix (n x n).
-        """
+        super().__init__()
         self.n_blocks = n_blocks
-        self.matrix = [[None for _ in range(n_blocks)] for _ in range(n_blocks)]
+        self.matrix = torch.nn.ModuleList(
+            [torch.nn.ModuleList([None for _ in range(n_blocks)]) for _ in range(n_blocks)]
+        )
 
-    def __getitem__(self, idx):
-        """
-        Get item using matrix-style indexing.
-
-        Parameters:
-            - idx (tuple): A tuple (i, j) representing the row and column indices.
-
-        Returns:
-            The module or value at position (i, j).
-        """
+    def __getitem__(self, idx: tuple):
         if not isinstance(idx, tuple) or len(idx) != 2:
             raise IndexError("Index must be a tuple (i, j)")
-
         i, j = idx
         if not (0 <= i < self.n_blocks) or not (0 <= j < self.n_blocks):
             raise IndexError("Index out of bounds")
-
         return self.matrix[i][j]
 
-    def __setitem__(self, idx, value):
-        """
-        Set item using matrix-style indexing.
-
-        Args:
-            idx (tuple): A tuple (i, j) representing the row and column indices.
-            value: The value to set at position (i, j).
-        """
+    def __setitem__(self, idx: tuple, value: torch.nn.Module):
         if not isinstance(idx, tuple) or len(idx) != 2:
             raise IndexError("Index must be a tuple (i, j)")
-
         i, j = idx
-        if not (0 <= i < self.n_blocks) or not (0 <= j < self.n_blocks):
-            raise IndexError("Index out of bounds")
-
+        if not (0 <= i < self.n_blocks):
+            raise IndexError(f"Index {i} out of bounds for n_blocks {self.n_blocks}")
+        if not (0 <= j < self.n_blocks):
+            raise IndexError(f"Index {j} out of bounds for n_blocks {self.n_blocks}")
+        if i == j:
+            assert isinstance(value, RecurrentLayer), "Diagonal blocks must be an instance of nn4n.nn.RecurrentLayer"
+        else:
+            assert isinstance(value, torch.nn.Module), "Off-diagonal blocks must be an instance of torch.nn.Module"
+            assert not isinstance(value, RecurrentLayer), "Off-diagonal blocks cannot be an instance of nn4n.nn.RecurrentLayer"
         self.matrix[i][j] = value
 
-
 class BlockRecurrentLayer(torch.nn.Module):
-    def __init__(
-        self,
-        n_blocks: int,
-    ):
-        """
-        Hidden layer of the network. The layer is initialized by passing specs in layer_struct.
-
-        Parameters:
-            - n_blocks: number of blocks in the layer
-        """
+    def __init__(self, n_blocks: int, **kwargs):
         super().__init__()
-        self.blocks = BlockMatrix(n_blocks)
-
+        self.block_recurrent = BlockMatrix(n_blocks=n_blocks)
+        self.initialized = False
+
     @property
-    def n_blocks(self) -> int:
-        return self.blocks.n_blocks
+    def size(self) -> int:
+        return sum(self.block_sizes())
 
     @property
-    def size(self) -> int:
-        return sum(self.list_sizes())
+    def n_blocks(self) -> int:
+        return self.block_recurrent.n_blocks
+
+    def block_indices(self, block_idx: int) -> torch.Tensor:
+        ranges = self.block_ranges[block_idx]
+        return torch.arange(ranges[0], ranges[1])
+
+    def _compute_block_ranges(self):
+        block_ranges = []
+        start_idx = 0
+        for block_idx in range(self.n_blocks):
+            block_size = self.block_sizes()[block_idx]
+            if block_size == 0:
+                block_ranges.append(None)
+            else:
+                block_ranges.append((start_idx, start_idx + block_size))
+            start_idx += block_size
+        return block_ranges
+
+    def block_sizes(self) -> List[int]:
+        block_sizes = []
+        for block_idx in range(self.n_blocks):
+            diagonal_block = self.block_recurrent[block_idx, block_idx]
+            block_size = diagonal_block.size if diagonal_block is not None else 0
+            block_sizes.append(block_size)
+        return block_sizes
+
+    def set_projection(self, from_idx, to_idx, layer):
+        # NOTE: this is "inversed" which is slightly confusing
+        self[to_idx, from_idx] = layer
+
+    def set_recurrent(self, idx, layer):
+        self[idx, idx] = layer
+
+    def __getitem__(self, idx: tuple):
+        return self.block_recurrent[idx]
+
+    def __setitem__(self, idx: tuple, value: torch.nn.Module):
+        # Set the value then check the network is initialized every time
+        self.block_recurrent[idx] = value
+        self.initialized = all(isinstance(self.block_recurrent[i, i], RecurrentLayer) for i in range(self.n_blocks))
+        self.block_ranges = self._compute_block_ranges()
 
-    def list_sizes(self) -> List[int]:
+    # FORWARD
+    # =================================================================================
+    def _parse_fr_v(self, fr: torch.Tensor, v: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
-        Get the sizes of the blocks
+        Parse the fr and v into a list of tensors
         """
-        return [self.blocks[i, i].input_dim for i in range(self.n_blocks)]
-
-    def to(self, device):
-        """Move the network to the device (cpu/gpu)"""
-        super().to(device)
+        fr_list = []
+        v_list = []
         for i in range(self.n_blocks):
-            for j in range(self.n_blocks):
-                block = self.blocks[i, j]
-                if block is not None and isinstance(block, torch.nn.Module):
-                    block.to(device)
-        return self
+            fr_list.append(fr[:, self.block_indices(i)])
+            v_list.append(v[:, self.block_indices(i)])
+        return fr_list, v_list
 
-    # FORWARD
-    # =================================================================================
     def forward(
         self,
-        fr_hid_t: torch.Tensor,
-        v_hid_t: torch.Tensor,
-        u_in_t: torch.Tensor
+        fr: torch.Tensor,
+        v: torch.Tensor,
+        u_list: List[torch.Tensor]
     ) -> torch.Tensor:
         """
         Forwardly update network
 
         Parameters:
-            - fr_hid_t: hidden state (post-activation), shape: (batch_size, hidden_size)
-            - v_hid_t: hidden state (pre-activation), shape: (batch_size, hidden_size)
-            - u_in_t: input, shape: (batch_size, input_size)
+            - fr: hidden state (post-activation), shape: (batch_size, total_hidden_size)
+            - v: hidden state (pre-activation), shape: (batch_size, total_hidden_size)
+            - u_list: list of input tensors, each of shape (batch_size, input_dim)
 
         Returns:
-            - fr_t_next: hidden state (post-activation), shape: (batch_size, hidden_size)
-            - v_t_next: hidden state (pre-activation), shape: (batch_size, hidden_size)
+            - fr_t_next: hidden state (post-activation), shape: (batch_size, total_hidden_size)
+            - v_t_next: hidden state (pre-activation), shape: (batch_size, total_hidden_size)
         """
-        v_in_t = self.input_layer(u_in_t) if self.input_layer is not None else u_in_t
-        v_hid_t_next = self.linear_layer(fr_hid_t)
-        v_t_next = (1 - self.alpha) * v_hid_t + self.alpha * (v_hid_t_next + v_in_t)
-        if self.preact_noise > 0 and self.training:
-            _preact_noise = self._generate_noise(v_t_next.size(), self.preact_noise)
-            v_t_next = v_t_next + _preact_noise
-        fr_t_next = self.activation(v_t_next)
-        if self.postact_noise > 0 and self.training:
-            _postact_noise = self._generate_noise(fr_t_next.size(), self.postact_noise)
-            fr_t_next = fr_t_next + _postact_noise
-        return fr_t_next, v_t_next
+        if not self.initialized:
+            raise ValueError("BlockRecurrentLayer is not initialized. All diagonal blocks must be set before forward pass.")
+
+        fr_list, v_list = self._parse_fr_v(fr, v)
+        u_aux_list = [torch.zeros_like(_fr, device=_fr.device) for _fr in fr_list]
+        fr_n_list, v_n_list = [None for _ in range(self.n_blocks)], [None for _ in range(self.n_blocks)]
+
+        for from_idx in range(self.n_blocks):
+            for to_idx in range(self.n_blocks):
+                if to_idx == from_idx:
+                    continue
+                layer = self.block_recurrent[to_idx, from_idx]
+                if layer is not None:
+                    u_aux_list[to_idx] += layer(fr_list[from_idx])
+
+        for diag_idx in range(self.n_blocks):
+            layer = self.block_recurrent[diag_idx, diag_idx]
+            fr_n_list[diag_idx], v_n_list[diag_idx] = layer(fr_list[diag_idx], v_list[diag_idx], u_list[diag_idx], u_aux_list[diag_idx])

+        fr_next = torch.cat(fr_n_list, dim=-1)
+        v_next = torch.cat(v_n_list, dim=-1)
+        return fr_next, v_next
 
     # HELPER FUNCTIONS
     # ======================================================================================
     def plot_layer(self, **kwargs):
         """
         Plot the layer
         """
-        self.linear_layer.plot_layer(**kwargs)
-        if self.input_layer is not None:
-            self.input_layer.plot_layer(**kwargs)
+        raise NotImplementedError("Plotting is not implemented for BlockRecurrentLayer")
 
     def _get_specs(self):
         """
         Get specs of the layer
         """
-        return {
-            "input_dim": self.input_dim,
-            "output_dim": self.output_dim,
-            "hidden_size": self.hidden_size,
-            "alpha": self.alpha,
-            "learn_alpha": self.learn_alpha,
-            "preact_noise": self.preact_noise,
-            "postact_noise": self.postact_noise,
-        }
+        raise NotImplementedError("Getting specs is not implemented for BlockRecurrentLayer")

nn4n/nn/linear_layer.py

Lines changed: 1 addition & 1 deletion
@@ -149,4 +149,4 @@ def print_layer(self):
         """
         Print the specs of the layer
         """
-        utils.print_dict(f"{self.__class__.__name__} layer", self.get_specs())
+        utils.print_dict(f"{self.__class__.__name__} layer", self.get_specs())

nn4n/nn/module.py

Lines changed: 18 additions & 6 deletions
@@ -1,7 +1,5 @@
 import nn4n
 import torch
-import numpy as np
-
 
 class Module(torch.nn.Module):
     """
@@ -35,8 +33,9 @@ def __init__(self,
         self._enforce_positivity()
         self._balance_excitatory_inhibitory()
 
-        # Register the forward hook
+        # Register the forward and backward hooks
         self.register_forward_pre_hook(self.enforce_constraints)
+        self._register_backward_hooks()
 
     # INIT MASKS
     # ======================================================================================
@@ -74,8 +73,7 @@ def _balance_excitatory_inhibitory(self):
         ext_sum = self.weight[self.positivity_mask == 1].sum()
         inh_sum = self.weight[self.positivity_mask == -1].sum()
         if ext_sum == 0 or inh_sum == 0:
-            # Automatically stop balancing if one of the sums is 0
-            # devide by 10 to avoid recurrent explosion/decay
+            # Avoid explosions/decay by scaling everything down
            self.weight /= 10
         else:
             if ext_sum > abs(inh_sum):
@@ -113,6 +111,20 @@ def _enforce_positivity(self):
         w[self.positivity_mask.T == -1] = torch.clamp(w[self.positivity_mask.T == -1], max=0)
         self.weight.data.copy_(torch.nn.Parameter(w))
 
+    # BACKWARD HOOK
+    # ======================================================================================
+    def _register_backward_hooks(self):
+        """
+        Register hooks to modify gradients during backprop.
+        For example, zero out gradients for masked-out weights
+        to prevent updates in those positions.
+        """
+        if self.sparsity_mask is not None:
+            def hook_fn(grad):
+                # If a weight is masked out, its gradient is zeroed.
+                return grad * (self.sparsity_mask.T > 0).float()
+            self.weight.register_hook(hook_fn)
+
     # UTILITIES
     # ======================================================================================
     def set_weight(self, weight):
@@ -141,4 +153,4 @@ def plot_layer(self, plot_type="weight"):
             w=weight.detach().numpy(),
             title=f"Weight",
             ignore_zeros=self.sparsity_mask is not None,
-        )
+        )
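
The _register_backward_hooks addition relies on a standard PyTorch tensor hook: the weight's gradient is multiplied by the sparsity mask, so masked-out positions never receive updates (the .T in the diff reflects how nn4n stores its masks; the standalone sketch below uses hypothetical names and a mask shaped like the weight).

import torch

torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(4, 4))
mask = (torch.rand(4, 4) > 0.5).float()   # 1 = trainable, 0 = frozen (hypothetical mask)

# Same idea as hook_fn in _register_backward_hooks: scale the gradient by the mask.
weight.register_hook(lambda grad: grad * mask)

before = weight.detach().clone()
opt = torch.optim.SGD([weight], lr=0.1)
loss = (weight ** 2).sum()
loss.backward()
opt.step()

# Masked entries received zero gradient and therefore did not move.
print(torch.all(weight.grad[mask == 0] == 0).item())                      # True
print(torch.all(weight.detach()[mask == 0] == before[mask == 0]).item())  # True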

nn4n/nn/recurrent_layer.py

Lines changed: 3 additions & 10 deletions
@@ -36,27 +36,20 @@ def output_dim(self) -> int:
     def size(self) -> int:
         return self.leaky_layer.input_dim
 
-    def to(self, device: torch.device):
-        """Move the network to the device (cpu/gpu)"""
-        super().to(device)
-        self.device = device
-        self.leaky_layer.to(device)
-        if self.projection_layer is not None:
-            self.projection_layer.to(device)
-        return self
-
-    def forward(self, fr: torch.Tensor, v: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
+    def forward(self, fr: torch.Tensor, v: torch.Tensor, u: torch.Tensor, u_aux: torch.Tensor = None):
         """
         Forwardly update network
 
         Parameters:
         - fr: hidden state (post-activation), shape: (batch_size, hidden_size)
         - v: hidden state (pre-activation), shape: (batch_size, hidden_size)
         - u: input, shape: (batch_size, input_size)
+        - u_aux: auxiliary input to be added after projection, shape: (batch_size, hidden_size)
 
         Returns:
         - fr_next: hidden state (post-activation), shape: (batch_size, hidden_size)
         - v_next: hidden state (pre-activation), shape: (batch_size, hidden_size)
         """
         u = self.projection_layer(u) if self.projection_layer is not None else u
+        u = u + u_aux if u_aux is not None else u
         return self.leaky_layer(fr, v, u)
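
The new u_aux argument is the hook that BlockRecurrentLayer uses to feed summed cross-block input into each diagonal block: the external input u still goes through projection_layer first, and u_aux is added afterwards. A minimal sketch of that ordering, with stand-ins for the projection and leaky layers (both hypothetical; the real classes are not shown in this diff):

import torch

proj = torch.nn.Linear(2, 4, bias=False)   # stand-in for projection_layer

def leaky(fr, v, u):
    # Toy update; the real leaky_layer also uses fr and a leak coefficient.
    v_next = v + u
    return torch.tanh(v_next), v_next

def forward(fr, v, u, u_aux=None):
    u = proj(u)                                # project the external input first
    u = u + u_aux if u_aux is not None else u  # then add the cross-block drive
    return leaky(fr, v, u)

fr = torch.zeros(1, 4)
v = torch.zeros(1, 4)
u = torch.randn(1, 2)
u_aux = torch.randn(1, 4)
fr_next, v_next = forward(fr, v, u, u_aux)
print(fr_next.shape, v_next.shape)             # torch.Size([1, 4]) torch.Size([1, 4])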
