From ec150775860c7648ee31070af2fea90a9251b308 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Thu, 28 Dec 2023 08:59:45 +0900
Subject: [PATCH 01/13] Del useless conversion in tests

---
 tests/common_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/common_test.py b/tests/common_test.py
index 510babf..8cb6614 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -17,7 +17,7 @@ def test_conv(self, default_input_image_size):
                                                   print_per_layer_stat=False)
 
         assert params == 3 * 3 * 2 * 3 + 2
-        assert int(macs) == 2759904
+        assert macs == 2759904
 
     def test_fc(self):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))
@@ -26,7 +26,7 @@ def test_fc(self):
                                                   print_per_layer_stat=False)
 
         assert params == 3 * 2 + 2
-        assert int(macs) == 8
+        assert macs == 8
 
     def test_fc_multidim(self):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))
@@ -35,7 +35,7 @@ def test_fc_multidim(self):
                                                   print_per_layer_stat=False)
 
         assert params == (3 * 2 + 2)
-        assert int(macs) == (3 * 2 + 2) * 4 * 5
+        assert macs == (3 * 2 + 2) * 4 * 5
 
     def test_input_constructor_tensor(self):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))

From aa7a7f97b41814429f4490c84ac6f19578688351 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Thu, 4 Jan 2024 07:49:40 +0900
Subject: [PATCH 02/13] Add experimental aten backend

---
 ptflops/__init__.py      |   5 +-
 ptflops/aten_engine.py   | 112 +++++++++++++++++++++++++++++++++++++
 ptflops/aten_ops.py      | 118 +++++++++++++++++++++++++++++++++++++++
 ptflops/flops_counter.py |  24 +++++++-
 4 files changed, 254 insertions(+), 5 deletions(-)
 create mode 100644 ptflops/aten_engine.py
 create mode 100644 ptflops/aten_ops.py

diff --git a/ptflops/__init__.py b/ptflops/__init__.py
index 445292f..61560b3 100644
--- a/ptflops/__init__.py
+++ b/ptflops/__init__.py
@@ -1,5 +1,5 @@
 '''
-Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
+Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
  * You may use, distribute and modify this code under the
  * terms of the MIT license.
  * You should have received a copy of the MIT license with
@@ -7,11 +7,12 @@
 '''
 
 
-from .flops_counter import get_model_complexity_info
+from .flops_counter import FLOPS_BACKEND, get_model_complexity_info
 from .utils import flops_to_string, params_to_string
 
 __all__ = [
     "get_model_complexity_info",
     "flops_to_string",
     "params_to_string",
+    "FLOPS_BACKEND",
 ]

diff --git a/ptflops/aten_engine.py b/ptflops/aten_engine.py
new file mode 100644
index 0000000..68028d1
--- /dev/null
+++ b/ptflops/aten_engine.py
@@ -0,0 +1,112 @@
+'''
+Copyright (C) 2024 Sovrasov V. - All Rights Reserved
+ * You may use, distribute and modify this code under the
+ * terms of the MIT license.
+ * You should have received a copy of the MIT license with
+ * this file. If not visit https://opensource.org/licenses/MIT
+'''
+
+
+import sys
+import traceback
+from collections import defaultdict
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+
+from ptflops.pytorch_engine import get_model_parameters_number
+from .aten_ops import ATEN_OPS_MAPPING
+
+
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+class FlopCounterMode(TorchDispatchMode):
+    def __init__(self, module=None):
+        self.flop_counts = defaultdict(lambda: defaultdict(int))
+        self.parents = ['Global']
+        self._total_complexity = None
+        if module is not None:
+            for name, mod in dict(module.named_children()).items():
+                mod.register_forward_pre_hook(self.enter_module(name))
+                mod.register_forward_hook(self.exit_module(name))
+
+    @property
+    def complexity(self):
+        return self._total_complexity
+
+    def enter_module(self, name):
+        def f(*args):
+            self.parents.append(name)
+        return f
+
+    def exit_module(self, name):
+        def f(*args):
+            assert(self.parents[-1] == name)
+            self.parents.pop()
+        return f
+
+    def __enter__(self):
+        self.flop_counts.clear()
+        super().__enter__()
+
+    def __exit__(self, *args):
+        self._total_complexity = sum(self.flop_counts['Global'].values())
+        super().__exit__(*args)
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+
+        out = func(*args, **kwargs)
+        func_packet = func._overloadpacket
+        if func_packet in ATEN_OPS_MAPPING:
+            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
+            for par in self.parents:
+                self.flop_counts[par][func_packet] += flop_count
+
+        return out
+
+
+def get_flops_aten(model, input_res,
+                   print_per_layer_stat=True,
+                   input_constructor=None, ost=sys.stdout,
+                   verbose=False, ignore_modules=[],
+                   custom_modules_hooks={},
+                   output_precision=2,
+                   flops_units: Optional[str] = 'GMac',
+                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
+                                                              Union[int, None]]:
+
+    params_sum = get_model_parameters_number(model)
+
+    if input_constructor:
+        batch = input_constructor(input_res)
+    else:
+        try:
+            batch = torch.ones(()).new_empty((1, *input_res),
+                                             dtype=next(model.parameters()).dtype,
+                                             device=next(model.parameters()).device)
+        except StopIteration:
+            batch = torch.ones(()).new_empty((1, *input_res))
+
+    try:
+        counter = FlopCounterMode(model)
+        with counter:
+            if isinstance(batch, dict):
+                _ = model(**batch)
+            else:
+                _ = model(batch)
+        macs_count = counter.complexity
+
+    except Exception as e:
+        print("Flops estimation was not finished successfully because of"
+              f"the following exception:\n{type(e)} : {e}")
+        traceback.print_exc()
+
+        return None, None
+
+    return macs_count, params_sum

diff --git a/ptflops/aten_ops.py b/ptflops/aten_ops.py
new file mode 100644
index 0000000..0e1e949
--- /dev/null
+++ b/ptflops/aten_ops.py
@@ -0,0 +1,118 @@
+'''
+Copyright (C) 2023 Sovrasov V. - All Rights Reserved
+ * You may use, distribute and modify this code under the
+ * terms of the MIT license.
+ * You should have received a copy of the MIT license with
+ * this file. If not visit https://opensource.org/licenses/MIT
+'''
+
+from typing import Any, List
+
+import torch
+
+aten = torch.ops.aten
+
+
+def get_shape(i):
+    return i.shape
+
+
+def prod(x):
+    res = 1
+    for i in x:
+        res *= i
+    return res
+
+
+def matmul_flop(inputs: List[Any], outputs: List[Any]) -> int:
+    """
+    Count flops for matmul.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contains the shapes of two matrices.
+    input_shapes = [get_shape(v) for v in inputs]
+    assert len(input_shapes) == 2, input_shapes
+    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
+    return flop
+
+
+def addmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
+    """
+    Count flops for fully connected layers.
+    """
+    # Count flop for nn.Linear
+    # inputs is a list of length 3.
+    input_shapes = [get_shape(v) for v in inputs[1:3]]
+    # input_shapes[0]: [batch size, input feature dimension]
+    # input_shapes[1]: [batch size, output feature dimension]
+    assert len(input_shapes[0]) == 2, input_shapes[0]
+    assert len(input_shapes[1]) == 2, input_shapes[1]
+    batch_size, input_dim = input_shapes[0]
+    output_dim = input_shapes[1][1]
+    flops = batch_size * input_dim * output_dim
+    return flops
+
+
+def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
+    """
+    Count flops for the bmm operation.
+    """
+    # Inputs should be a list of length 2.
+    # Inputs contains the shapes of two tensor.
+    assert len(inputs) == 2, len(inputs)
+    input_shapes = [get_shape(v) for v in inputs]
+    n, c, t = input_shapes[0]
+    d = input_shapes[-1][-1]
+    flop = n * c * t * d
+    return flop
+
+
+def conv_flop_count(
+    x_shape: List[int],
+    w_shape: List[int],
+    out_shape: List[int],
+    transposed: bool = False,
+) -> int:
+    """
+    Count flops for convolution. Note only multiplication is
+    counted. Computation for addition and bias is ignored.
+    Flops for a transposed convolution are calculated as
+    flops = (x_shape[2:] * prod(w_shape) * batch_size).
+    Args:
+        x_shape (list(int)): The input shape before convolution.
+        w_shape (list(int)): The filter shape.
+        out_shape (list(int)): The output shape after convolution.
+        transposed (bool): is the convolution transposed
+    Returns:
+        int: the number of flops
+    """
+    batch_size = x_shape[0]
+    conv_shape = (x_shape if transposed else out_shape)[2:]
+    flop = batch_size * prod(w_shape) * prod(conv_shape)
+    return flop
+
+
+def conv_flop(inputs: List[Any], outputs: List[Any]):
+    """
+    Count flops for convolution.
+    """
+    x, w = inputs[:2]
+    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
+    transposed = inputs[6]
+
+    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
+
+
+def transpose_shape(shape):
+    return [shape[1], shape[0]] + list(shape[2:])
+
+
+ATEN_OPS_MAPPING = {
+    aten.mm: matmul_flop,
+    aten.matmul: matmul_flop,
+    aten.addmm: addmm_flop,
+    aten.bmm: bmm_flop,
+    aten.convolution: conv_flop,
+    aten._convolution: conv_flop,
+}

diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py
index 20caf34..d43e403 100644
--- a/ptflops/flops_counter.py
+++ b/ptflops/flops_counter.py
@@ -1,5 +1,5 @@
 '''
-Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
+Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
  * You may use, distribute and modify this code under the
  * terms of the MIT license.
  * You should have received a copy of the MIT license with
@@ -7,14 +7,21 @@
 '''
 
 import sys
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union
 
 import torch.nn as nn
 
+from .aten_engine import get_flops_aten
 from .pytorch_engine import get_flops_pytorch
 from .utils import flops_to_string, params_to_string
 
 
+class FLOPS_BACKEND(Enum):
+    PYTORCH = 'pytorch'
+    ATEN = 'aten'
+
+
 def get_model_complexity_info(model: nn.Module,
                               input_res: Tuple[int, ...],
                               print_per_layer_stat: bool = True,
@@ -24,7 +31,7 @@ def get_model_complexity_info(model: nn.Module,
                               verbose: bool = False,
                               ignore_modules: List[nn.Module] = [],
                               custom_modules_hooks: Dict[nn.Module, Any] = {},
-                              backend: str = 'pytorch',
+                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.PYTORCH,
                               flops_units: Optional[str] = None,
                               param_units: Optional[str] = None,
                               output_precision: int = 2) -> Tuple[Union[str, int, None],
@@ -58,6 +65,8 @@ def get_model_complexity_info(model: nn.Module,
     :type ignore_modules: nn.Module
     :param custom_modules_hooks: A dict that contains custom hooks on torch modules.
     :type custom_modules_hooks: Dict[nn.Module, Any]
+    :param backend: Backend that is used for evaluating model complexity.
+    :type backend: FLOPS_BACKEND
     :param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
     :type flops_units: Optional[str]
     :param param_units: Units for string representation of params (M, K or B).
@@ -75,7 +84,7 @@ def get_model_complexity_info(model: nn.Module,
     assert len(input_res) >= 1
     assert isinstance(model, nn.Module)
 
-    if backend == 'pytorch':
+    if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
         flops_count, params_count = get_flops_pytorch(model, input_res,
                                                       print_per_layer_stat,
                                                       input_constructor, ost,
@@ -84,6 +93,15 @@ def get_model_complexity_info(model: nn.Module,
                                                       output_precision=output_precision,
                                                       flops_units=flops_units,
                                                       param_units=param_units)
+    elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
+        flops_count, params_count = get_flops_aten(model, input_res,
+                                                   print_per_layer_stat,
+                                                   input_constructor, ost,
+                                                   verbose, ignore_modules,
+                                                   custom_modules_hooks,
+                                                   output_precision=output_precision,
+                                                   flops_units=flops_units,
+                                                   param_units=param_units)
     else:
         raise ValueError('Wrong backend name')

From 48432922df5e96fb395ab4df589f12a9da2ceb89 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Thu, 4 Jan 2024 08:57:45 +0900
Subject: [PATCH 03/13] Forward printing parameters to aten backend

---
 ptflops/aten_engine.py | 40 ++++++++++++++++++++++++++++++++--------
 ptflops/aten_ops.py    |  4 ----
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/ptflops/aten_engine.py b/ptflops/aten_engine.py
index 68028d1..ae61351 100644
--- a/ptflops/aten_engine.py
+++ b/ptflops/aten_engine.py
@@ -7,6 +7,7 @@
 '''
 
 
+from functools import partial
 import sys
 import traceback
 from collections import defaultdict
 from typing import Optional, Tuple, Union
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 
 from ptflops.pytorch_engine import get_model_parameters_number
+from ptflops.utils import flops_to_string
 from .aten_ops import ATEN_OPS_MAPPING
 
 
-def normalize_tuple(x):
-    if not isinstance(x, tuple):
-        return (x,)
-    return x
-
-
 class FlopCounterMode(TorchDispatchMode):
-    def __init__(self, module=None):
+    def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
+                 output_params=None):
+        self.verbose = verbose
+        if output_params is None:
+            output_params = defaultdict(dict)
+        self.output_params = output_params
+        self.print_fn = partial(print,
+                                **self.output_params['print_params'])
+
+        self.print_per_layer_stat = print_per_layer_stat
         self.flop_counts = defaultdict(lambda: defaultdict(int))
         self.parents = ['Global']
         self._total_complexity = None
         if module is not None:
@@ -56,9 +60,24 @@ def __enter__(self):
 
     def __exit__(self, *args):
         self._total_complexity = sum(self.flop_counts['Global'].values())
+        if self.print_per_layer_stat:
+            self.print_fn('Total:' +
+                          flops_to_string(self._total_complexity,
+                                          **self.output_params['serialize_params']))
+            for mod in self.flop_counts.keys():
+                self.print_fn("Module: ", mod)
+                for k, v in self.flop_counts[mod].items():
+                    self.print_fn(
+                        f'{k}: ' +
+                        flops_to_string(v, **self.output_params['serialize_params']))
+                self.print_fn()
         super().__exit__(*args)
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        def normalize_tuple(x):
+            if not isinstance(x, tuple):
+                return (x,)
+            return x
         kwargs = kwargs if kwargs else {}
 
         out = func(*args, **kwargs)
         func_packet = func._overloadpacket
         if func_packet in ATEN_OPS_MAPPING:
             flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
             for par in self.parents:
                 self.flop_counts[par][func_packet] += flop_count
+        elif self.verbose:
+            self.print_fn(f'Warning: {func_packet} operation is treated as a zero-op')
 
         return out
 
@@ -82,6 +103,9 @@ def get_flops_aten(model, input_res,
                    Union[int, None]]:
 
     params_sum = get_model_parameters_number(model)
+    output_params = {'serialize_params':
+                     {'units': flops_units, 'precision': output_precision},
+                     'print_params': {'file': ost}}
 
     if input_constructor:
         batch = input_constructor(input_res)
@@ -94,7 +118,7 @@ def get_flops_aten(model, input_res,
             batch = torch.ones(()).new_empty((1, *input_res))
 
     try:
-        counter = FlopCounterMode(model)
+        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
         with counter:
             if isinstance(batch, dict):
                 _ = model(**batch)

diff --git a/ptflops/aten_ops.py b/ptflops/aten_ops.py
index 0e1e949..0282936 100644
--- a/ptflops/aten_ops.py
+++ b/ptflops/aten_ops.py
@@ -104,10 +104,6 @@ def conv_flop(inputs: List[Any], outputs: List[Any]):
     return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
 
 
-def transpose_shape(shape):
-    return [shape[1], shape[0]] + list(shape[2:])
-
-
 ATEN_OPS_MAPPING = {
     aten.mm: matmul_flop,
     aten.matmul: matmul_flop,
     aten.addmm: addmm_flop,
     aten.bmm: bmm_flop,

From 759e15fc1848a1b2249970b9ee77084490c0a12c Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Thu, 4 Jan 2024 09:00:19 +0900
Subject: [PATCH 04/13] Add backend parameter to the sample

---
 samples/classification.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/samples/classification.py b/samples/classification.py
index 5b155af..87a5dfa 100644
--- a/samples/classification.py
+++ b/samples/classification.py
@@ -27,6 +27,8 @@
                         help='Device to store the model.')
     parser.add_argument('--model', choices=list(pt_models.keys()),
                         type=str, default='resnet18')
+    parser.add_argument('--backend', choices=list(['pytorch', 'aten']),
+                        type=str, default='pytorch')
     parser.add_argument('--result', type=str, default=None)
     args = parser.parse_args()
 
@@ -42,6 +44,7 @@
 
     macs, params = get_model_complexity_info(net, (3, 224, 224),
                                              as_strings=True,
+                                             backend=args.backend,
                                              print_per_layer_stat=True,
                                              ost=ost)
     print('{:<30} {:<8}'.format('Computational complexity: ', macs))

From bff19ea94f21506c37d5b558ed32629b4a4a6201 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Mon, 5 Feb 2024 14:34:03 +0900
Subject: [PATCH 05/13] Fix isort

---
 ptflops/aten_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ptflops/aten_engine.py b/ptflops/aten_engine.py
index ae61351..554a01e 100644
--- a/ptflops/aten_engine.py
+++ b/ptflops/aten_engine.py
@@ -7,10 +7,10 @@
 '''
 
 
-from functools import partial
 import sys
 import traceback
 from collections import defaultdict
+from functools import partial
 from typing import Optional, Tuple, Union
 
 import torch

From fe667a2d4c638ad287437097b3826a600c9885a2 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Mon, 5 Feb 2024 15:18:10 +0900
Subject: [PATCH 06/13] Align conv between backends

---
 ptflops/aten_ops.py  | 26 ++++++++++++++------------
 tests/common_test.py |  4 +++-
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/ptflops/aten_ops.py b/ptflops/aten_ops.py
index 0282936..6b86b55 100644
--- a/ptflops/aten_ops.py
+++ b/ptflops/aten_ops.py
@@ -17,7 +17,7 @@ def get_shape(i):
     return i.shape
 
 
-def prod(x):
+def prod(x) -> int:
     res = 1
     for i in x:
         res *= i
@@ -59,7 +59,7 @@ def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
     Count flops for the bmm operation.
     """
     # Inputs should be a list of length 2.
-    # Inputs contains the shapes of two tensor.
+    # Inputs contains the shapes of two tensors.
     assert len(inputs) == 2, len(inputs)
     input_shapes = [get_shape(v) for v in inputs]
     n, c, t = input_shapes[0]
@@ -73,23 +73,25 @@ def conv_flop_count(
     w_shape: List[int],
     out_shape: List[int],
     transposed: bool = False,
+    bias: bool = False,
 ) -> int:
     """
-    Count flops for convolution. Note only multiplication is
-    counted. Computation for addition and bias is ignored.
-    Flops for a transposed convolution are calculated as
-    flops = (x_shape[2:] * prod(w_shape) * batch_size).
+    Count MACs for convolution.
+    Summation is ignored when applying conv kernel, but counted for bias.
     Args:
         x_shape (list(int)): The input shape before convolution.
         w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
         transposed (bool): is the convolution transposed
+        bias (bool): is the bias counted
     Returns:
-        int: the number of flops
+        int: the number of MACs
     """
     batch_size = x_shape[0]
     conv_shape = (x_shape if transposed else out_shape)[2:]
     flop = batch_size * prod(w_shape) * prod(conv_shape)
+    if bias:
+        flop += batch_size * out_shape[1] * prod(out_shape[2:])
     return flop
 
 
@@ -97,11 +99,11 @@ def conv_flop(inputs: List[Any], outputs: List[Any]):
     """
     Count flops for convolution.
     """
-    x, w = inputs[:2]
-    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
-    transposed = inputs[6]
-
-    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
+    (input, w, b, stride, pad, dilation,
+     transposed, _, groups) = inputs
+    output = outputs[0]
+    return conv_flop_count(input.shape, w.shape, output.shape,
+                           transposed=transposed, bias=b is not None)
 
 
 ATEN_OPS_MAPPING = {

diff --git a/tests/common_test.py b/tests/common_test.py
index 8cb6614..470db97 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -3,6 +3,7 @@
 import torch.nn as nn
 
 from ptflops import get_model_complexity_info
+from ptflops.flops_counter import FLOPS_BACKEND
 
 
 class TestOperations:
@@ -10,7 +11,8 @@ class TestOperations:
     def default_input_image_size(self):
         return (3, 224, 224)
 
-    def test_conv(self, default_input_image_size):
+    @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
+    def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
         macs, params = get_model_complexity_info(net, default_input_image_size,
                                                  as_strings=False,

From 786bea306a8aff05fc34e21af5bc4d906d425c60 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Mon, 5 Feb 2024 15:35:11 +0900
Subject: [PATCH 07/13] Align computations for FC

---
 ptflops/aten_ops.py  | 30 +++++++++++++++---------------
 tests/common_test.py | 13 ++++++++-----
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/ptflops/aten_ops.py b/ptflops/aten_ops.py
index 6b86b55..567ea88 100644
--- a/ptflops/aten_ops.py
+++ b/ptflops/aten_ops.py
@@ -13,11 +13,7 @@
 aten = torch.ops.aten
 
 
-def get_shape(i):
-    return i.shape
-
-
-def prod(x) -> int:
+def prod(x: torch.Size) -> int:
     res = 1
     for i in x:
         res *= i
@@ -30,7 +26,7 @@ def matmul_flop(inputs: List[Any], outputs: List[Any]) -> int:
     """
     # Inputs should be a list of length 2.
     # Inputs contains the shapes of two matrices.
-    input_shapes = [get_shape(v) for v in inputs]
+    input_shapes = [v.shape for v in inputs]
     assert len(input_shapes) == 2, input_shapes
     assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
     flop = prod(input_shapes[0]) * input_shapes[-1][-1]
@@ -39,11 +35,11 @@ def matmul_flop(inputs: List[Any], outputs: List[Any]) -> int:
 
 def addmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
     """
-    Count flops for fully connected layers.
+    Count flops for fully connected layers (nn.Linear).
+    Bias is considered if exists.
     """
-    # Count flop for nn.Linear
-    # inputs is a list of length 3.
-    input_shapes = [get_shape(v) for v in inputs[1:3]]
+    # inputs: bias, input, weight
+    input_shapes = [v.shape for v in inputs[1:3]]
     # input_shapes[0]: [batch size, input feature dimension]
     # input_shapes[1]: [batch size, output feature dimension]
     assert len(input_shapes[0]) == 2, input_shapes[0]
@@ -51,6 +47,10 @@ def addmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
     batch_size, input_dim = input_shapes[0]
     output_dim = input_shapes[1][1]
     flops = batch_size * input_dim * output_dim
+
+    if inputs[0] is not None:
+        flops += batch_size * output_dim
+
     return flops
 
 
@@ -61,7 +61,7 @@ def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
     # Inputs should be a list of length 2.
     # Inputs contains the shapes of two tensors.
     assert len(inputs) == 2, len(inputs)
-    input_shapes = [get_shape(v) for v in inputs]
+    input_shapes = [v.shape for v in inputs]
     n, c, t = input_shapes[0]
     d = input_shapes[-1][-1]
     flop = n * c * t * d
@@ -69,9 +69,9 @@ def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int:
 
 def conv_flop_count(
-    x_shape: List[int],
-    w_shape: List[int],
-    out_shape: List[int],
+    x_shape: torch.Size,
+    w_shape: torch.Size,
+    out_shape: torch.Size,
     transposed: bool = False,
     bias: bool = False,
 ) -> int:
@@ -95,7 +95,7 @@ def conv_flop_count(
     return flop
 
 
-def conv_flop(inputs: List[Any], outputs: List[Any]):
+def conv_flop(inputs: List[Any], outputs: List[Any]) -> int:
     """
     Count flops for convolution.
     """

diff --git a/tests/common_test.py b/tests/common_test.py
index 470db97..f383b55 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -16,25 +16,28 @@ def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
         macs, params = get_model_complexity_info(net, default_input_image_size,
                                                  as_strings=False,
-                                                 print_per_layer_stat=False)
+                                                 print_per_layer_stat=False,
+                                                 backend=backend)
 
         assert params == 3 * 3 * 2 * 3 + 2
         assert macs == 2759904
 
-    def test_fc(self):
+    @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
+    def test_fc(self, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))
         macs, params = get_model_complexity_info(net, (3,), as_strings=False,
-                                                 print_per_layer_stat=False)
+                                                 print_per_layer_stat=False, backend=backend)
 
         assert params == 3 * 2 + 2
         assert macs == 8
 
-    def test_fc_multidim(self):
+    @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
+    def test_fc_multidim(self, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))
         macs, params = get_model_complexity_info(net, (4, 5, 3), as_strings=False,
-                                                 print_per_layer_stat=False)
+                                                 print_per_layer_stat=False, backend=backend)
 
         assert params == (3 * 2 + 2)
         assert macs == (3 * 2 + 2) * 4 * 5

From 5258e5946f296afce1b1be39e81c3a8ce55e5958 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Mon, 5 Feb 2024 15:40:08 +0900
Subject: [PATCH 08/13] Switch off bias in FC test

---
 tests/common_test.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/common_test.py b/tests/common_test.py
index f383b55..e648b7d 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -27,20 +27,22 @@ def test_fc(self, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))
         macs, params = get_model_complexity_info(net, (3,), as_strings=False,
-                                                 print_per_layer_stat=False, backend=backend)
+                                                 print_per_layer_stat=False,
+                                                 backend=backend)
 
         assert params == 3 * 2 + 2
         assert macs == 8
 
     @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
     def test_fc_multidim(self, backend: FLOPS_BACKEND):
-        net = nn.Sequential(nn.Linear(3, 2, bias=True))
+        net = nn.Sequential(nn.Linear(3, 2, bias=False))
         macs, params = get_model_complexity_info(net, (4, 5, 3), as_strings=False,
-                                                 print_per_layer_stat=False, backend=backend)
+                                                 print_per_layer_stat=False,
+                                                 backend=backend)
 
-        assert params == (3 * 2 + 2)
-        assert macs == (3 * 2 + 2) * 4 * 5
+        assert params == 3 * 2
+        assert macs == (3 * 2) * 4 * 5
 
     def test_input_constructor_tensor(self):
         net = nn.Sequential(nn.Linear(3, 2, bias=True))

From c43ea25b2698bfca8685c041079e52a051812336 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Sat, 4 May 2024 10:41:48 +0900
Subject: [PATCH 09/13] Update readme

---
 README.md | 42 +++++++++++++++++++++++++++++++-----------
 1 file changed, 31 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index ce8d79d..4b1575d 100644
--- a/README.md
+++ b/README.md
@@ -2,11 +2,26 @@
 [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
 [![Build Status](https://travis-ci.com/sovrasov/flops-counter.pytorch.svg?branch=master)](https://travis-ci.com/sovrasov/flops-counter.pytorch)
 
-This script is designed to compute the theoretical amount of multiply-add operations
-in convolutional neural networks. It can also compute the number of parameters and
+This tool is designed to compute the theoretical amount of multiply-add operations
+in neural networks. It can also compute the number of parameters and
 print per-layer computational cost of a given network.
 
-Supported layers:
+`ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is a legacy one: it considers `nn.Modules` only. However,
+it's still useful, since it provides better per-layer analytics for CNNs. In all other cases it's recommended to use
+the `aten` backend, which considers aten operations and therefore covers more model architectures (including transformers).
+
+## `aten` backend
+### Operations considered:
+- aten.mm, aten.matmul, aten.addmm, aten.bmm
+- aten.convolution
+
+### Usage tips
+- Use `verbose=True` to see the operations which were not considered during complexity computation.
+- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
+Deeper modules at the second level of nesting are not shown in the per-layer statistics.
+
+## `pytorch` backend
+### Supported layers:
 - Conv1d/2d/3d (including grouping)
 - ConvTranspose1d/2d/3d (including grouping)
 - BatchNorm1d/2d/3d, GroupNorm, InstanceNorm1d/2d/3d, LayerNorm
@@ -22,13 +37,9 @@ Experimental support:
 - torchvision.ops.DeformConv2d
 - visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)
 
-Requirements: Pytorch >= 1.1, torchvision >= 0.3
+### Usage tips
 
-Thanks to @warmspringwinds for the initial version of script.
-
-## Usage tips
-
-- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
+- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
 not contributing to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
 - `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next this dict would be passed to the model as a keyword arguments.
 - `verbose` parameter allows to get information about modules that don't contribute to the final numbers.
@@ -36,6 +47,10 @@ not contributing to the final complexity estimation.
 for research purposes. For instance, one can drop all convolutions from the counting process
 specifying `ignore_modules=[torch.nn.Conv2d]`.
 
+Requirements: Pytorch >= 1.1, torchvision >= 0.3
+
+Thanks to @warmspringwinds and Horace He for the initial version of the script.
+
 ## Install the latest version
 From PyPI:
 ```bash
@@ -55,7 +70,12 @@ from ptflops import get_model_complexity_info
 
 with torch.cuda.device(0):
   net = models.densenet161()
-  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
+  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch',
+                                           print_per_layer_stat=True, verbose=True)
+  print('{:<30} {:<8}'.format('Computational complexity: ', macs))
+  print('{:<30} {:<8}'.format('Number of parameters: ', params))
+
+  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten',
                                            print_per_layer_stat=True, verbose=True)
   print('{:<30} {:<8}'.format('Computational complexity: ', macs))
   print('{:<30} {:<8}'.format('Number of parameters: ', params))
@@ -67,7 +87,7 @@ If ptflops was useful for your paper or tech report, please cite me:
 @online{ptflops,
   author = {Vladislav Sovrasov},
   title = {ptflops: a flops counting tool for neural networks in pytorch framework},
-  year = 2018-2023,
+  year = 2018-2024,
   url = {https://github.com/sovrasov/flops-counter.pytorch},
 }
 ```

From 119851bd3e09b530f27f644de6ec3d168b7783a8 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Sat, 4 May 2024 10:47:45 +0900
Subject: [PATCH 10/13] Update changelog

---
 CHANGELOG.md | 3 +++
 README.md    | 5 +++++
 2 files changed, 8 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea8d301..2a8dfbd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 # ptflops versions log
 
+## v 0.7.3
+- Add aten backend to collect the amount of flops on aten level.
+
 ## v 0.7.2.2
 - Switch from setup.py to pyproject
 
diff --git a/README.md b/README.md
index 4b1575d..f62ac41 100644
--- a/README.md
+++ b/README.md
@@ -118,3 +118,8 @@ squeezenet1_0 | 224x224 | 1.25 | 0.84
 vgg16 | 224x224 | 138.36 | 15.52
 vit_b_16 | 224x224 | 86.57 | 17.60
 wide_resnet50_2 | 224x224 | 68.88 | 11.45
+
+
+### [timm](https://github.com/huggingface/pytorch-image-models)
+
+Model | Input Resolution | Params(M) | MACs(G)

From 97251b5e94524617887d407b81a7ed5108c4751e Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Sat, 4 May 2024 11:05:49 +0900
Subject: [PATCH 11/13] Fix inception input in cls sample

---
 samples/classification.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/samples/classification.py b/samples/classification.py
index 87a5dfa..bca2950 100644
--- a/samples/classification.py
+++ b/samples/classification.py
@@ -15,7 +15,8 @@
                  'squeezenet': models.squeezenet1_0,
                  'densenet': models.densenet161,
                  'inception': models.inception_v3,
-                 'convnext_base': models.convnext_base}
+                 'convnext_base': models.convnext_base,
+                 'vit_b_16': models.vit_b_16}
 
 if version.parse(torchvision.__version__) > version.parse('0.15'):
     pt_models['vit_b_16'] = models.vit_b_16
@@ -42,7 +43,12 @@
     if torch.cuda.is_available():
         net.cuda(device=args.device)
 
-    macs, params = get_model_complexity_info(net, (3, 224, 224),
+    if args.model == 'inception':
+        input_res = (3, 299, 299)
+    else:
+        input_res = (3, 224, 224)
+
+    macs, params = get_model_complexity_info(net, input_res,
                                              as_strings=True,
                                              backend=args.backend,
                                              print_per_layer_stat=True,

From 606d30da6252b4a87bd8f60027d116751b4e9502 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Sat, 4 May 2024 11:06:20 +0900
Subject: [PATCH 12/13] Update error msg in engines

---
 ptflops/aten_engine.py    | 3 ++-
 ptflops/pytorch_engine.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/ptflops/aten_engine.py b/ptflops/aten_engine.py
index 554a01e..e5e6457 100644
--- a/ptflops/aten_engine.py
+++ b/ptflops/aten_engine.py
@@ -103,6 +103,7 @@ def get_flops_aten(model, input_res,
                    Union[int, None]]:
 
     params_sum = get_model_parameters_number(model)
+    model.eval()
     output_params = {'serialize_params':
                      {'units': flops_units, 'precision': output_precision},
                      'print_params': {'file': ost}}
@@ -128,7 +129,7 @@ def get_flops_aten(model, input_res,
 
     except Exception as e:
         print("Flops estimation was not finished successfully because of"
-              f"the following exception:\n{type(e)} : {e}")
+              f" the following exception:\n{type(e)} : {e}")
         traceback.print_exc()
 
         return None, None

diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py
index 9f6d125..bc24572 100644
--- a/ptflops/pytorch_engine.py
+++ b/ptflops/pytorch_engine.py
@@ -68,7 +68,7 @@ def reset_environment():
 
     except Exception as e:
         print("Flops estimation was not finished successfully because of"
-              f"the following exception:\n{type(e)} : {e}")
+              f" the following exception:\n{type(e)} : {e}")
         traceback.print_exc()
         reset_environment()

From 81ace22f0c80d3e10eed82fa94404ae9d3a12c77 Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Sat, 4 May 2024 11:10:21 +0900
Subject: [PATCH 13/13] Update benchmark in readme

---
 README.md | 44 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index f62ac41..54eddb6 100644
--- a/README.md
+++ b/README.md
@@ -96,28 +96,28 @@ If ptflops was useful for your paper or tech report, please cite me:
 
 ### [torchvision](https://pytorch.org/vision/0.16/models.html)
 
-Model | Input Resolution | Params(M) | MACs(G)
---- |--- |--- |---
-alexnet | 224x224 | 61.10 | 0.72
-convnext_base | 224x224 | 88.59 | 15.43
-densenet121 | 224x224 | 7.98 | 2.90
-efficientnet_b0 | 224x224 | 5.29 | 0.41
-efficientnet_v2_m | 224x224 | 54.14 | 5.43
-googlenet | 224x224 | 13.00 | 1.51
-inception_v3 | 224x224 | 27.16 | 2.86
-maxvit_t | 224x224 | 30.92 | 5.48
-mnasnet1_0 | 224x224 | 4.38 | 0.33
-mobilenet_v2 | 224x224 | 3.50 | 0.32
-mobilenet_v3_large | 224x224 | 5.48 | 0.23
-regnet_y_1_6gf | 224x224 | 11.20 | 1.65
-resnet18 | 224x224 | 11.69 | 1.83
-resnet50 | 224x224 | 25.56 | 4.13
-resnext50_32x4d | 224x224 | 25.03 | 4.29
-shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
-squeezenet1_0 | 224x224 | 1.25 | 0.84
-vgg16 | 224x224 | 138.36 | 15.52
-vit_b_16 | 224x224 | 86.57 | 17.60
-wide_resnet50_2 | 224x224 | 68.88 | 11.45
+Model | Input Resolution | Params(M) | MACs(G) (`pytorch`) | MACs(G) (`aten`)
+--- |--- |--- |--- |---
+alexnet | 224x224 | 61.10 | 0.72 | 0.71
+convnext_base | 224x224 | 88.59 | 15.43 | 15.38
+densenet121 | 224x224 | 7.98 | 2.90 |
+efficientnet_b0 | 224x224 | 5.29 | 0.41 |
+efficientnet_v2_m | 224x224 | 54.14 | 5.43 |
+googlenet | 224x224 | 13.00 | 1.51 |
+inception_v3 | 224x224 | 27.16 | 5.75 | 5.71
+maxvit_t | 224x224 | 30.92 | 5.48 |
+mnasnet1_0 | 224x224 | 4.38 | 0.33 |
+mobilenet_v2 | 224x224 | 3.50 | 0.32 |
+mobilenet_v3_large | 224x224 | 5.48 | 0.23 |
+regnet_y_1_6gf | 224x224 | 11.20 | 1.65 |
+resnet18 | 224x224 | 11.69 | 1.83 | 1.81
+resnet50 | 224x224 | 25.56 | 4.13 | 4.09
+resnext50_32x4d | 224x224 | 25.03 | 4.29 |
+shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 |
+squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82
+vgg16 | 224x224 | 138.36 | 15.52 | 15.48
+vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86
+wide_resnet50_2 | 224x224 | 68.88 | 11.45 |
 
 
 ### [timm](https://github.com/huggingface/pytorch-image-models)