 import torch
-from typing import List
-from .linear_layer import LinearLayer
+from typing import List, Tuple
+from .recurrent_layer import RecurrentLayer
 
-
-class BlockMatrix:
+class BlockMatrix(torch.nn.Module):
     def __init__(self, n_blocks: int):
-        """
-        Initializes a BlockMatrix.
-
-        Args:
-            n (int): Size of the matrix (n x n).
-        """
+        super().__init__()
         self.n_blocks = n_blocks
-        self.matrix = [[None for _ in range(n_blocks)] for _ in range(n_blocks)]
+        self.matrix = torch.nn.ModuleList(
+            [torch.nn.ModuleList([None for _ in range(n_blocks)]) for _ in range(n_blocks)]
+        )
 
-    def __getitem__(self, idx):
-        """
-        Get item using matrix-style indexing.
-
-        Parameters:
-        - idx (tuple): A tuple (i, j) representing the row and column indices.
-
-        Returns:
-        The module or value at position (i, j).
-        """
+    def __getitem__(self, idx: tuple):
         if not isinstance(idx, tuple) or len(idx) != 2:
             raise IndexError("Index must be a tuple (i, j)")
-
         i, j = idx
         if not (0 <= i < self.n_blocks) or not (0 <= j < self.n_blocks):
             raise IndexError("Index out of bounds")
-
         return self.matrix[i][j]
 
-    def __setitem__(self, idx, value):
-        """
-        Set item using matrix-style indexing.
-
-        Args:
-            idx (tuple): A tuple (i, j) representing the row and column indices.
-            value: The value to set at position (i, j).
-        """
+    def __setitem__(self, idx: tuple, value: torch.nn.Module):
         if not isinstance(idx, tuple) or len(idx) != 2:
             raise IndexError("Index must be a tuple (i, j)")
-
         i, j = idx
-        if not (0 <= i < self.n_blocks) or not (0 <= j < self.n_blocks):
-            raise IndexError("Index out of bounds")
-
+        if not (0 <= i < self.n_blocks):
+            raise IndexError(f"Index {i} out of bounds for n_blocks {self.n_blocks}")
+        if not (0 <= j < self.n_blocks):
+            raise IndexError(f"Index {j} out of bounds for n_blocks {self.n_blocks}")
+        if i == j:
+            assert isinstance(value, RecurrentLayer), "Diagonal blocks must be an instance of nn4n.nn.RecurrentLayer"
+        else:
+            assert isinstance(value, torch.nn.Module), "Off-diagonal blocks must be an instance of torch.nn.Module"
+            assert not isinstance(value, RecurrentLayer), "Off-diagonal blocks cannot be an instance of nn4n.nn.RecurrentLayer"
         self.matrix[i][j] = value
 
-
 class BlockRecurrentLayer(torch.nn.Module):
-    def __init__(
-        self,
-        n_blocks: int,
-    ):
-        """
-        Hidden layer of the network. The layer is initialized by passing specs in layer_struct.
-
-        Parameters:
-        - n_blocks: number of blocks in the layer
-        """
+    def __init__(self, n_blocks: int, **kwargs):
         super().__init__()
-        self.blocks = BlockMatrix(n_blocks)
-
+        self.block_recurrent = BlockMatrix(n_blocks=n_blocks)
+        self.initialized = False
+
     @property
-    def n_blocks(self) -> int:
-        return self.blocks.n_blocks
+    def size(self) -> int:
+        return sum(self.block_sizes())
 
     @property
-    def size(self) -> int:
-        return sum(self.list_sizes())
+    def n_blocks(self) -> int:
+        return self.block_recurrent.n_blocks
+
+    def block_indices(self, block_idx: int) -> torch.Tensor:
+        ranges = self.block_ranges[block_idx]
+        return torch.arange(ranges[0], ranges[1])
+
+    def _compute_block_ranges(self):
+        block_ranges = []
+        start_idx = 0
+        for block_idx in range(self.n_blocks):
+            block_size = self.block_sizes()[block_idx]
+            if block_size == 0:
+                block_ranges.append(None)
+            else:
+                block_ranges.append((start_idx, start_idx + block_size))
+                start_idx += block_size
+        return block_ranges
+
+    def block_sizes(self) -> List[int]:
+        block_sizes = []
+        for block_idx in range(self.n_blocks):
+            diagonal_block = self.block_recurrent[block_idx, block_idx]
+            block_size = diagonal_block.size if diagonal_block is not None else 0
+            block_sizes.append(block_size)
+        return block_sizes
+
+    def set_projection(self, from_idx, to_idx, layer):
+        # NOTE: the indexing is "inversed": the projection is stored at [to_idx, from_idx],
+        # i.e. row = destination block, column = source block.
+        self[to_idx, from_idx] = layer
+
+    def set_recurrent(self, idx, layer):
+        self[idx, idx] = layer
+
+    def __getitem__(self, idx: tuple):
+        return self.block_recurrent[idx]
+
+    def __setitem__(self, idx: tuple, value: torch.nn.Module):
+        # Set the value, then re-check initialization (all diagonal blocks set) and recompute block ranges
+        self.block_recurrent[idx] = value
+        self.initialized = all(isinstance(self.block_recurrent[i, i], RecurrentLayer) for i in range(self.n_blocks))
+        self.block_ranges = self._compute_block_ranges()
 
-    def list_sizes(self) -> List[int]:
+    # FORWARD
+    # =================================================================================
+    def _parse_fr_v(self, fr: torch.Tensor, v: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
-        Get the sizes of the blocks
+        Split the concatenated fr and v into per-block lists of tensors
         """
-        return [self.blocks[i, i].input_dim for i in range(self.n_blocks)]
-
-    def to(self, device):
-        """Move the network to the device (cpu/gpu)"""
-        super().to(device)
+        fr_list = []
+        v_list = []
         for i in range(self.n_blocks):
-            for j in range(self.n_blocks):
-                block = self.blocks[i, j]
-                if block is not None and isinstance(block, torch.nn.Module):
-                    block.to(device)
-        return self
+            fr_list.append(fr[:, self.block_indices(i)])
+            v_list.append(v[:, self.block_indices(i)])
+        return fr_list, v_list
 
-    # FORWARD
-    # =================================================================================
     def forward(
         self,
-        fr_hid_t: torch.Tensor,
-        v_hid_t: torch.Tensor,
-        u_in_t: torch.Tensor
+        fr: torch.Tensor,
+        v: torch.Tensor,
+        u_list: List[torch.Tensor]
     ) -> torch.Tensor:
         """
         Forwardly update network
 
         Parameters:
-        - fr_hid_t: hidden state (post-activation), shape: (batch_size, hidden_size)
-        - v_hid_t: hidden state (pre-activation), shape: (batch_size, hidden_size)
-        - u_in_t: input, shape: (batch_size, input_size)
+        - fr: hidden state (post-activation), shape: (batch_size, total_hidden_size)
+        - v: hidden state (pre-activation), shape: (batch_size, total_hidden_size)
+        - u_list: list of input tensors, each of shape (batch_size, input_dim)
 
         Returns:
-        - fr_t_next: hidden state (post-activation), shape: (batch_size, hidden_size)
-        - v_t_next: hidden state (pre-activation), shape: (batch_size, hidden_size)
+        - fr_t_next: hidden state (post-activation), shape: (batch_size, total_hidden_size)
+        - v_t_next: hidden state (pre-activation), shape: (batch_size, total_hidden_size)
         """
-        v_in_t = self.input_layer(u_in_t) if self.input_layer is not None else u_in_t
-        v_hid_t_next = self.linear_layer(fr_hid_t)
-        v_t_next = (1 - self.alpha) * v_hid_t + self.alpha * (v_hid_t_next + v_in_t)
-        if self.preact_noise > 0 and self.training:
-            _preact_noise = self._generate_noise(v_t_next.size(), self.preact_noise)
-            v_t_next = v_t_next + _preact_noise
-        fr_t_next = self.activation(v_t_next)
-        if self.postact_noise > 0 and self.training:
-            _postact_noise = self._generate_noise(fr_t_next.size(), self.postact_noise)
-            fr_t_next = fr_t_next + _postact_noise
-        return fr_t_next, v_t_next
+        if not self.initialized:
+            raise ValueError("BlockRecurrentLayer is not initialized. All diagonal blocks must be set before the forward pass.")
+
+        fr_list, v_list = self._parse_fr_v(fr, v)
+        u_aux_list = [torch.zeros_like(_fr, device=_fr.device) for _fr in fr_list]
+        fr_n_list, v_n_list = [None for _ in range(self.n_blocks)], [None for _ in range(self.n_blocks)]
+
+        # Accumulate cross-block input: project each source block's firing rate through the [to, from] layer
+        for from_idx in range(self.n_blocks):
+            for to_idx in range(self.n_blocks):
+                if to_idx == from_idx:
+                    continue
+                layer = self.block_recurrent[to_idx, from_idx]
+                if layer is not None:
+                    u_aux_list[to_idx] += layer(fr_list[from_idx])
+
+        # Update each diagonal block with its own state, external input, and accumulated cross-block input
+        for diag_idx in range(self.n_blocks):
+            layer = self.block_recurrent[diag_idx, diag_idx]
+            fr_n_list[diag_idx], v_n_list[diag_idx] = layer(fr_list[diag_idx], v_list[diag_idx], u_list[diag_idx], u_aux_list[diag_idx])
+
+        # Concatenate the per-block states back in block order
+        fr_next = torch.cat(fr_n_list, dim=-1)
+        v_next = torch.cat(v_n_list, dim=-1)
+        return fr_next, v_next
 
     # HELPER FUNCTIONS
     # ======================================================================================
     def plot_layer(self, **kwargs):
         """
         Plot the layer
         """
-        self.linear_layer.plot_layer(**kwargs)
-        if self.input_layer is not None:
-            self.input_layer.plot_layer(**kwargs)
+        raise NotImplementedError("Plotting is not implemented for BlockRecurrentLayer")
 
     def _get_specs(self):
         """
         Get specs of the layer
         """
-        return {
-            "input_dim": self.input_dim,
-            "output_dim": self.output_dim,
-            "hidden_size": self.hidden_size,
-            "alpha": self.alpha,
-            "learn_alpha": self.learn_alpha,
-            "preact_noise": self.preact_noise,
-            "postact_noise": self.postact_noise,
-        }
+        raise NotImplementedError("Getting specs is not implemented for BlockRecurrentLayer")
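
For orientation, a minimal usage sketch of the new block API follows. This is not part of the diff: the import path (nn4n.nn) is inferred from the assertion messages, and the RecurrentLayer constructor argument and the per-block input widths are hypothetical. The only facts taken from the code above are that diagonal blocks must be RecurrentLayer instances exposing .size and a (fr, v, u, u_aux) call returning (fr_next, v_next), and that off-diagonal blocks are plain torch.nn.Module projections from a source block's firing rate into the destination block's space.

import torch
import torch.nn as nn

# Assumed import path; the asserts in BlockMatrix refer to nn4n.nn.RecurrentLayer.
from nn4n.nn import RecurrentLayer, BlockRecurrentLayer

batch_size = 16
sizes = [64, 32]

block_rnn = BlockRecurrentLayer(n_blocks=2)

# Diagonal blocks: one RecurrentLayer per block. The constructor keyword below is
# hypothetical; BlockRecurrentLayer only relies on `.size` and the
# (fr, v, u, u_aux) call signature of each diagonal block.
block_rnn.set_recurrent(0, RecurrentLayer(hidden_size=sizes[0]))
block_rnn.set_recurrent(1, RecurrentLayer(hidden_size=sizes[1]))

# Off-diagonal blocks: any torch.nn.Module mapping the source block's firing rate
# into the destination block's space (stored internally at [to_idx, from_idx]).
block_rnn.set_projection(from_idx=0, to_idx=1, layer=nn.Linear(sizes[0], sizes[1], bias=False))
block_rnn.set_projection(from_idx=1, to_idx=0, layer=nn.Linear(sizes[1], sizes[0], bias=False))

assert block_rnn.initialized and block_rnn.size == sum(sizes)

# Hidden states are concatenated across blocks; external inputs are passed per block.
# Using each block's own width for its input is an assumption for illustration.
fr = torch.zeros(batch_size, block_rnn.size)
v = torch.zeros(batch_size, block_rnn.size)
u_list = [torch.randn(batch_size, sizes[0]), torch.randn(batch_size, sizes[1])]

fr_next, v_next = block_rnn(fr, v, u_list)  # each of shape (batch_size, block_rnn.size)

During forward, each destination block first accumulates u_aux[to] += block[to, from](fr[from]) over all source blocks, its diagonal RecurrentLayer then consumes (fr, v, u, u_aux) on that block's slice, and the per-block results are concatenated back in block order.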