Skip to content

Commit 38ab72b

Browse files
Fix calculate_macs() for Linear layers. (#318)
* Fix calculate_macs() for Linear layers. Fix MACs in lst.out and lstm_half.out. * Add test for torch.nn.Linear. * Change groud-truth Total mult-adds in flan_t5_small.out. MACs increased from 280.27M to 18.25G because of the Linear layer fix. --------- Co-authored-by: Andrew Lavin <[email protected]>
1 parent 29166cc commit 38ab72b

File tree

6 files changed

+36
-5
lines changed

6 files changed

+36
-5
lines changed

tests/test_output/flan_t5_small.out

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ T5ForConditionalGeneration [3, 100, 512]
3737
Total params: 128,743,488
3838
Trainable params: 128,743,488
3939
Non-trainable params: 0
40-
Total mult-adds (M): 280.27
40+
Total mult-adds (G): 18.25
4141
==============================================================================================================
4242
Input size (MB): 0.01
4343
Forward/backward pass size (MB): 326.28

tests/test_output/linear.out

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
========================================================================================================================
2+
Layer (type:depth-idx) Input Shape Output Shape Param # Mult-Adds
3+
========================================================================================================================
4+
Linear [32, 16, 8] [32, 16, 64] 576 294,912
5+
========================================================================================================================
6+
Total params: 576
7+
Trainable params: 576
8+
Non-trainable params: 0
9+
Total mult-adds (M): 0.29
10+
========================================================================================================================
11+
Input size (MB): 0.02
12+
Forward/backward pass size (MB): 0.26
13+
Params size (MB): 0.00
14+
Estimated Total Size (MB): 0.28
15+
========================================================================================================================

tests/test_output/lstm.out

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ LSTMNet (LSTMNet) -- [100, 20]
1313
│ └─weight_hh_l1 [2048, 512] ├─1,048,576
1414
│ └─bias_ih_l1 [2048] ├─2,048
1515
│ └─bias_hh_l1 [2048] └─2,048
16-
├─Linear (decoder) -- [1, 100, 20] 10,260 10,260
16+
├─Linear (decoder) -- [1, 100, 20] 10,260 1,026,000
1717
│ └─weight [512, 20] ├─10,240
1818
│ └─bias [20] └─20
1919
========================================================================================================================
2020
Total params: 3,784,580
2121
Trainable params: 3,784,580
2222
Non-trainable params: 0
23-
Total mult-adds (M): 376.85
23+
Total mult-adds (M): 377.86
2424
========================================================================================================================
2525
Input size (MB): 0.00
2626
Forward/backward pass size (MB): 0.67

tests/test_output/lstm_half.out

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ Layer (type (var_name)) Kernel Shape Output Shape
44
LSTMNet (LSTMNet) -- [100, 20] -- --
55
├─Embedding (embedding) -- [1, 100, 300] 6,000 6,000
66
├─LSTM (encoder) -- [1, 100, 512] 3,768,320 376,832,000
7-
├─Linear (decoder) -- [1, 100, 20] 10,260 10,260
7+
├─Linear (decoder) -- [1, 100, 20] 10,260 1,026,000
88
========================================================================================================================
99
Total params: 3,784,580
1010
Trainable params: 3,784,580
1111
Non-trainable params: 0
12-
Total mult-adds (M): 376.85
12+
Total mult-adds (M): 377.86
1313
========================================================================================================================
1414
Input size (MB): 0.00
1515
Forward/backward pass size (MB): 0.33

tests/torchinfo_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,20 @@ def test_groups() -> None:
149149
)
150150

151151

152+
def test_linear() -> None:
153+
input_shape = (32, 16, 8)
154+
module = nn.Linear(8, 64)
155+
col_names = ("input_size", "output_size", "num_params", "mult_adds")
156+
input_data = torch.randn(*input_shape)
157+
summary(
158+
module,
159+
input_data=input_data,
160+
depth=1,
161+
col_names=col_names,
162+
col_width=20,
163+
)
164+
165+
152166
def test_single_input_batch_dim() -> None:
153167
model = SingleInputNet()
154168
col_names = ("kernel_size", "input_size", "output_size", "num_params", "mult_adds")

torchinfo/layer_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def calculate_macs(self) -> None:
244244
self.macs += int(
245245
cur_params * prod(self.output_size[:1] + self.output_size[2:])
246246
)
247+
elif "Linear" in self.class_name:
248+
self.macs += int(cur_params * prod(self.output_size[:-1]))
247249
else:
248250
self.macs += self.output_size[0] * cur_params
249251
# RNN modules have inner weights such as weight_ih_l0

0 commit comments

Comments
 (0)