Description
The FLOPs count is wrong when the model is compiled with torch.compile:
- FLOPs of torch.nn modules (for example, nn.Linear and nn.Conv2d) are tripled.
- FLOPs of custom modules are not counted.
Here is a code example:
import torch.nn as nn
import ptflops
import torch


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x


class Test_model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_layer = nn.Linear(1000, 1000, bias=False)
        self.custom_layer = MyModule()

    def forward(self, x):
        out = self.linear_layer(x)
        out = self.custom_layer(out)
        return out


def mymodule_flops_counter_hook(module, input: torch.Tensor, output):
    input = input[0]
    mul_count = input.numel() * 1000
    module.__flops__ += int(mul_count)


MyModuleMapping = {
    MyModule: mymodule_flops_counter_hook
}

net = Test_model()
net = torch.compile(net)
print(ptflops.get_model_complexity_info(net, (1000,), custom_modules_hooks=MyModuleMapping, output_precision=3))
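For reference, the expected totals can be worked out by hand (assuming ptflops evaluates the model on a single input of shape (1, 1000)); the variable names below are only for illustration:

# Hand calculation of the expected MACs, not ptflops internals
linear_macs = 1000 * 1000                    # nn.Linear(1000, 1000, bias=False) -> 1.0 MMac
custom_macs = 1 * 1000 * 1000                # custom hook: input.numel() * 1000 -> 1.0 MMac
expected_total = linear_macs + custom_macs   # 2.0 MMac in total

This matches the 2.0 MMac reported without torch.compile below.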
Output:
without torch.compile:
Test_model(
1.0 M, 100.000% Params, 2.0 MMac, 100.000% MACs,
(linear_layer): Linear(1.0 M, 100.000% Params, 1.0 MMac, 50.000% MACs, in_features=1000, out_features=1000, bias=False)
(custom_layer): MyModule(0, 0.000% Params, 1.0 MMac, 50.000% MACs, )
)
('2.0 MMac', '1.0 M')
with torch.compile:
OptimizedModule(
1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs,
(_orig_mod): Test_model(
1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs,
(linear_layer): Linear(1.0 M, 100.000% Params, 3.0 MMac, 100.000% MACs, in_features=1000, out_features=1000, bias=False)
(custom_layer): MyModule(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
)
)
('3.0 MMac', '1.0 M')
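A possible workaround, sketched under the assumption that torch.compile leaves the wrapped module itself unchanged: keep a reference to the uncompiled model (or use the _orig_mod attribute of the OptimizedModule shown above) and measure complexity on that instead of the compiled wrapper.

net = Test_model()
compiled_net = torch.compile(net)
# Measure on the original module; use compiled_net only for training/inference.
print(ptflops.get_model_complexity_info(
    net,  # equivalently: compiled_net._orig_mod
    (1000,),
    custom_modules_hooks=MyModuleMapping,
    output_precision=3,
))

This is only a sketch; the triple counting when the compiled wrapper is passed directly still looks like a bug.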