Update backends #140

Merged: 5 commits, Jul 16, 2024
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# ptflops versions log

## v 0.7.4
- Switch to the `aten` backend by default.
- Add support for ignored modules and custom hooks in the `aten` backend.
- Add an option to disable counting of functional-style operations in the `pytorch` backend.

## v 0.7.3
- Add an `aten` backend to collect the amount of flops at the aten level.

8 changes: 7 additions & 1 deletion README.md
@@ -9,6 +9,7 @@ print per-layer computational cost of a given network.
`ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is a legacy one: it considers `nn.Module`s only. However,
it is still useful, since it provides better per-layer analytics for CNNs. In all other cases it is recommended to use the
`aten` backend, which considers aten operations and therefore covers more model architectures (including transformers).
The default backend is `aten`. Please don't use the `pytorch` backend for transformer architectures.
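A minimal usage sketch consistent with this change (the torchvision `resnet18` model is only an illustrative choice, not something this PR prescribes):

```python
import torchvision.models as models  # any nn.Module works; resnet18 is only an example
from ptflops import get_model_complexity_info

net = models.resnet18()

# backend='aten' is the new default introduced by this PR; 'pytorch' selects the legacy backend.
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         backend='aten',
                                         as_strings=True,
                                         print_per_layer_stat=True,
                                         verbose=True)
print(f'Computational complexity: {macs}')
print(f'Number of parameters: {params}')
```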

## `aten` backend
### Operations considered:
@@ -19,6 +20,9 @@ it's still useful, since it provides better per-layer analytics for CNNs. In a
- Use `verbose=True` to see the operations which were not considered during complexity computation.
- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
Deeper modules at the second level of nesting are not shown in the per-layer statistics.
- The `ignore_modules` option forces `ptflops` to ignore the listed aten operations. This can be useful
for research purposes. For instance, one can drop all convolutions from the counting process by
specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`, as in the sketch below.
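A short sketch of the ignore mechanism described above; the model is again an arbitrary example, and only the `ignore_modules` list comes from this README:

```python
import torch
import torchvision.models as models  # example model only
from ptflops import get_model_complexity_info

net = models.resnet18()

# Exclude every convolution from the aten-level count.
macs, params = get_model_complexity_info(
    net, (3, 224, 224),
    backend='aten',
    as_strings=False,
    print_per_layer_stat=False,
    ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution])
```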

## `pytorch` backend
### Supported layers:
@@ -41,7 +45,9 @@ Experimental support:

- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore, unsupported operations do
not contribute to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
Sometimes counting functional-style operations conflicts with hooks for `nn.Module` (for instance, custom ones). In that case, counting these ops can be disabled by
passing `backend_specific_config={"count_functional": False}`.
- `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next, this dict is passed to the model as keyword arguments.
- The `verbose` parameter allows one to get information about modules that don't contribute to the final numbers.
- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
for research purposes. For instance, one can drop all convolutions from the counting process
23 changes: 16 additions & 7 deletions ptflops/aten_engine.py
@@ -10,8 +10,9 @@
import sys
import traceback
from collections import defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import TorchDispatchMode
@@ -23,12 +24,15 @@

class FlopCounterMode(TorchDispatchMode):
def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
output_params=None, custom_hooks={}, ignored_ops=[]):
self.verbose = verbose
if output_params is None:
output_params = defaultdict(dict)
self.output_params = output_params
self.print_fn = partial(print, **self.output_params['print_params'])
self.all_ops = deepcopy(ATEN_OPS_MAPPING)
self.all_ops.update(custom_hooks)
self.ignored_ops = ignored_ops

self.print_per_layer_stat = print_per_layer_stat
self.flop_counts = defaultdict(lambda: defaultdict(int))
@@ -82,8 +86,11 @@ def normalize_tuple(x):

out = func(*args, **kwargs)
func_packet = func._overloadpacket

if func_packet in self.ignored_ops:
self.print_fn(f'Warning: {func_packet} operation is ignored')
elif func_packet in self.all_ops:
flop_count = self.all_ops[func_packet](args, normalize_tuple(out))
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count
elif self.verbose:
@@ -99,8 +106,9 @@ def get_flops_aten(model, input_res,
custom_modules_hooks={},
output_precision=2,
flops_units: Optional[str] = 'GMac',
param_units: Optional[str] = 'M',
extra_config: Dict = {}) -> Tuple[Union[int, None],
Union[int, None]]:

params_sum = get_model_parameters_number(model)
model.eval()
@@ -119,7 +127,8 @@ def get_flops_aten(model, input_res,
batch = torch.ones(()).new_empty((1, *input_res))

try:
counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params,
custom_modules_hooks, ignore_modules)
with counter:
if isinstance(batch, dict):
_ = model(**batch)
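The dispatch above consults `ignored_ops` first and then `all_ops`, which is the built-in `ATEN_OPS_MAPPING` merged with the user-supplied `custom_hooks`. A minimal sketch of supplying a custom aten hook through the public `custom_modules_hooks` argument, mirroring the `test_aten_custom` test added below (the toy module and the constant 42 are illustrative):

```python
import torch
import torch.nn as nn
from ptflops import get_model_complexity_info

class MatMulModel(nn.Module):
    def forward(self, x):
        return x.matmul(x.t())

# A hook receives (inputs, outputs) and returns the count to attribute to the call.
custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: 42}

macs, params = get_model_complexity_info(MatMulModel(), (10,),
                                         backend='aten',
                                         as_strings=False,
                                         print_per_layer_stat=False,
                                         custom_modules_hooks=custom_hooks)
assert macs == 42  # the custom hook overrides the built-in aten.mm counter
```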
44 changes: 26 additions & 18 deletions ptflops/flops_counter.py
@@ -29,13 +29,15 @@ def get_model_complexity_info(model: nn.Module,
input_constructor: Optional[Callable[[Tuple], Dict]] = None,
ost: TextIO = sys.stdout,
verbose: bool = False,
ignore_modules: List[Union[nn.Module, Any]] = [],
custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
flops_units: Optional[str] = None,
param_units: Optional[str] = None,
output_precision: int = 2,
backend_specific_config: Dict = {}) -> Tuple[
Union[str, int, None],
Union[str, int, None]]:
"""
Analyzes the input model and collects the amounts of parameters and MACs
required to make a forward pass of the model.
@@ -61,10 +63,11 @@
:type ost: TextIO
:param verbose: Parameter to control printing of extra information and warnings.
:type verbose: bool
:param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
:type ignore_modules: List[Union[nn.Module, Any]]
:param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
torch.ops.aten modules.
:type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
:param backend: Backend that is used for evaluating model complexity.
:type backend: FLOPS_BACKEND
:param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
@@ -74,6 +77,8 @@
:param output_precision: Floating point precision for representing MACs/params in
given units.
:type output_precision: int
:param backend_specific_config: Extra configuration for a specific backend.
:type backend_specific_config: dict

Returns:
Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
@@ -85,14 +90,16 @@
assert isinstance(model, nn.Module)

if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
flops_count, params_count = \
get_flops_pytorch(model, input_res,
print_per_layer_stat,
input_constructor, ost,
verbose, ignore_modules,
custom_modules_hooks,
output_precision=output_precision,
flops_units=flops_units,
param_units=param_units,
extra_config=backend_specific_config)
elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
flops_count, params_count = get_flops_aten(model, input_res,
print_per_layer_stat,
@@ -101,7 +108,8 @@
custom_modules_hooks,
output_precision=output_precision,
flops_units=flops_units,
param_units=param_units,
extra_config=backend_specific_config)
else:
raise ValueError('Wrong backend name')
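For reference, the dispatcher above converts `backend` via `FLOPS_BACKEND(backend)`, so an enum member and its string value are interchangeable. A small sketch under that assumption (the toy model mirrors the tests; the package-level `FLOPS_BACKEND` import is assumed, since the test imports are not shown in this diff):

```python
import torch.nn as nn
from ptflops import get_model_complexity_info, FLOPS_BACKEND  # FLOPS_BACKEND import path assumed

net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))  # toy model, as in the tests

common = dict(as_strings=False, print_per_layer_stat=False)
macs_str, _ = get_model_complexity_info(net, (3, 224, 224), backend='pytorch', **common)
macs_enum, _ = get_model_complexity_info(net, (3, 224, 224), backend=FLOPS_BACKEND.PYTORCH, **common)
assert macs_str == macs_enum  # same backend selected either way
```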

18 changes: 11 additions & 7 deletions ptflops/pytorch_engine.py
@@ -9,7 +9,7 @@
import sys
import traceback
from functools import partial
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -27,8 +27,9 @@ def get_flops_pytorch(model, input_res,
custom_modules_hooks={},
output_precision=2,
flops_units: Optional[str] = 'GMac',
param_units: Optional[str] = 'M',
extra_config: Dict = {}) -> Tuple[Union[int, None],
Union[int, None]]:
global CUSTOM_MODULES_MAPPING
CUSTOM_MODULES_MAPPING = custom_modules_hooks
flops_model = add_flops_counting_methods(model)
@@ -45,15 +46,18 @@
except StopIteration:
batch = torch.ones(()).new_empty((1, *input_res))

enable_func_ops_patching = extra_config.get('count_functional', True)
torch_functional_flops = []
torch_tensor_ops_flops = []
if enable_func_ops_patching:
patch_functional(torch_functional_flops)
patch_tensor_ops(torch_tensor_ops_flops)

def reset_environment():
flops_model.stop_flops_count()
if enable_func_ops_patching:
unpatch_functional()
unpatch_tensor_ops()
global CUSTOM_MODULES_MAPPING
CUSTOM_MODULES_MAPPING = {}
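The `count_functional` flag read above is the value carried by `backend_specific_config` from the public API. A minimal sketch of turning functional/tensor-op patching off, mirroring the `test_torch_ignore_func` test added in this PR (the toy module is illustrative):

```python
import torch.nn as nn
from ptflops import get_model_complexity_info

class MatMulModel(nn.Module):
    def forward(self, x):
        return x.matmul(x.t())

# With functional/tensor-op patching disabled, the pytorch backend sees no nn.Module work here,
# so the reported MAC count drops to zero.
macs, params = get_model_complexity_info(MatMulModel(), (10,),
                                         backend='pytorch',
                                         as_strings=False,
                                         print_per_layer_stat=False,
                                         backend_specific_config={'count_functional': False})
assert macs == 0 and params == 0
```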

70 changes: 57 additions & 13 deletions tests/common_test.py
@@ -11,6 +11,17 @@ class TestOperations:
def default_input_image_size(self):
return (3, 224, 224)

@pytest.fixture
def simple_model_mm(self):
class CustomModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.matmul(x.t())

return CustomModel()

@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
Expand Down Expand Up @@ -53,7 +64,8 @@ def input_constructor(input_res):
macs, params = get_model_complexity_info(net, (3,),
input_constructor=input_constructor,
as_strings=False,
print_per_layer_stat=False,
backend=FLOPS_BACKEND.PYTORCH)

assert (macs, params) == (8, 8)

@@ -73,7 +85,8 @@ def input_constructor(input_res):
get_model_complexity_info(CustomLinear(), (3,),
input_constructor=input_constructor,
as_strings=False,
print_per_layer_stat=False,
backend=FLOPS_BACKEND.PYTORCH)

assert (macs, params) == (8, 8)

@@ -89,7 +102,8 @@ def forward(self, x):
macs, params = \
get_model_complexity_info(CustomModel(), (3, 10, 10),
as_strings=False,
print_per_layer_stat=False,
backend=FLOPS_BACKEND.PYTORCH)
assert params == 0
assert macs > 0

@@ -99,22 +113,52 @@ def forward(self, x):
macs, params = \
get_model_complexity_info(CustomModel(), (3, 10, 10),
as_strings=False,
print_per_layer_stat=False,
backend=FLOPS_BACKEND.PYTORCH)
assert params == 0
assert macs > 0

def test_ten_matmul(self, simple_model_mm):
macs, params = \
get_model_complexity_info(simple_model_mm, (10, ),
as_strings=False,
print_per_layer_stat=False,
backend=FLOPS_BACKEND.PYTORCH)

assert params == 0
assert macs > 0

def test_aten_ignore(self, simple_model_mm):
ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
macs, params = \
get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
as_strings=False,
print_per_layer_stat=False,
ignore_modules=ignored_list)

assert params == 0
assert macs == 0

def test_aten_custom(self, simple_model_mm):
reference = 42
custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}

macs, params = \
get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
as_strings=False,
print_per_layer_stat=False,
custom_modules_hooks=custom_hooks)

assert params == 0
assert macs == reference

def test_torch_ignore_func(self, simple_model_mm):
macs, params = \
get_model_complexity_info(simple_model_mm, (10, ),
backend=FLOPS_BACKEND.PYTORCH,
as_strings=False,
print_per_layer_stat=False,
backend_specific_config={'count_functional': False})

assert params == 0
assert macs == 0