Aten backend #133

Merged
merged 13 commits into from
May 3, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# ptflops versions log

## v 0.7.3
- Add aten backend to collect the amount of flops on aten level.

## v 0.7.2.2
- Switch from setup.py to pyproject

89 changes: 57 additions & 32 deletions README.md
@@ -2,11 +2,26 @@
[![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
[![Build Status](https://travis-ci.com/sovrasov/flops-counter.pytorch.svg?branch=master)](https://travis-ci.com/sovrasov/flops-counter.pytorch)

This script is designed to compute the theoretical amount of multiply-add operations
in convolutional neural networks. It can also compute the number of parameters and
This tool is designed to compute the theoretical amount of multiply-add operations
in neural networks. It can also compute the number of parameters and
print per-layer computational cost of a given network.

Supported layers:
`ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is the legacy one and considers `nn.Modules` only. However,
it's still useful, since it provides better per-layer analytics for CNNs. In all other cases it's recommended to use
the `aten` backend, which considers aten operations and therefore covers more model architectures (including transformers).

## `aten` backend
### Operations considered:
- aten.mm, aten.matmul, aten.addmm, aten.bmm
- aten.convolution
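
As a rough illustration of what counting at the aten level means, the sketch below computes the textbook multiply-accumulate (MAC) count for the operation families listed above. The helper names (`mm_macs`, `bmm_macs`, `conv_macs`) are hypothetical and the formulas are the standard theoretical ones. The per-op rules that `ptflops` actually applies are not shown in this change, so treat this as an assumption rather than the library's implementation.

```python
from math import prod


def mm_macs(a_shape, b_shape):
    """MACs of a 2-D product (m, k) @ (k, n): one MAC per (m, k, n) triple."""
    m, k = a_shape
    k_b, n = b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    return m * k * n


def bmm_macs(a_shape, b_shape):
    """MACs of a batched product (b, m, k) @ (b, k, n)."""
    batch, m, k = a_shape
    batch_b, k_b, n = b_shape
    if batch != batch_b:
        raise ValueError(f"batch sizes differ: {batch} vs {batch_b}")
    return batch * mm_macs((m, k), (k_b, n))


def conv_macs(batch, weight_shape, out_spatial):
    """MACs of a convolution.

    weight_shape is (c_out, c_in // groups, *kernel_size). Every output
    position of every output channel needs (c_in // groups) * prod(kernel)
    multiply-accumulates.
    """
    return batch * prod(weight_shape) * prod(out_spatial)


# Assumed shapes for a ResNet-style stem: 3 -> 64 channels, 7x7 kernel,
# 112x112 output map, batch size 1.
stem_macs = conv_macs(1, (64, 3, 7, 7), (112, 112))
# A 1000-way classifier on a 512-dimensional feature vector.
fc_macs = mm_macs((1, 512), (512, 1000))
print(stem_macs, fc_macs)
```

Under these assumptions the stem convolution alone accounts for roughly 0.12 GMac, which shows why convolutions and matrix products dominate the totals reported by the tool.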

### Usage tips
- Use `verbose=True` to see the operations which were not considered during complexity computation.
- This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
Deeper modules at the second level of nesting are not shown in the per-layer statistics.

## `pytorch` backend
### Supported layers:
- Conv1d/2d/3d (including grouping)
- ConvTranspose1d/2d/3d (including grouping)
- BatchNorm1d/2d/3d, GroupNorm, InstanceNorm1d/2d/3d, LayerNorm
@@ -22,20 +37,20 @@ Experimental support:
- torchvision.ops.DeformConv2d
- visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)

Requirements: Pytorch >= 1.1, torchvision >= 0.3

Thanks to @warmspringwinds for the initial version of script.

## Usage tips
### Usage tips

- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
- This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore unsupported operations are
not contributing to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
- `ptflops` runs the given model on a random tensor and estimates the amount of computation during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input, use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with the named input arguments of the model. This dict is then passed to the model as keyword arguments.
- The `verbose` parameter prints information about modules that don't contribute to the final numbers.
- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
for research purposes. For instance, one can drop all convolutions from the counting process
specifying `ignore_modules=[torch.nn.Conv2d]`.
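
The `input_constructor` contract described above can be sketched without PyTorch. In this illustration shape tuples stand in for tensors, and `build_inputs`, `run_model`, `two_input_constructor` and `toy_model` are hypothetical names, not part of `ptflops`. The only behaviour taken from the source is that a dict result is unpacked into keyword arguments while anything else is passed as a single positional input.

```python
def build_inputs(input_res, input_constructor=None):
    """Return the object that will be fed to the model."""
    if input_constructor is not None:
        return input_constructor(input_res)
    # Default case: one batch of size 1 (a shape tuple stands in for a tensor).
    return (1, *input_res)


def run_model(model, batch):
    """Call the model the way the tips describe."""
    if isinstance(batch, dict):
        # Named inputs are passed as keyword arguments.
        return model(**batch)
    return model(batch)


def two_input_constructor(input_res):
    """Hypothetical constructor for a model that takes an image and a mask."""
    channels, height, width = input_res
    return {
        "image": (1, channels, height, width),
        "mask": (1, 1, height, width),
    }


def toy_model(image, mask=None):
    """Stand-in for an nn.Module: reports which inputs it received."""
    return {"image": image, "mask": mask}


with_constructor = run_model(
    toy_model, build_inputs((3, 224, 224), two_input_constructor))
without_constructor = run_model(toy_model, build_inputs((3, 224, 224)))
print(with_constructor)
print(without_constructor)
```

With a real model the constructor would return tensors on the model's device and dtype; the sketch only shows how the return type decides the calling convention.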

Requirements: Pytorch >= 1.1, torchvision >= 0.3

Thanks to @warmspringwinds and Horace He for the initial version of the script.

## Install the latest version
From PyPI:
```bash
@@ -55,7 +70,12 @@ from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = models.densenet161()
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch',
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))

    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten',
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
@@ -67,7 +87,7 @@ If ptflops was useful for your paper or tech report, please cite me:
@online{ptflops,
author = {Vladislav Sovrasov},
title = {ptflops: a flops counting tool for neural networks in pytorch framework},
year = 2018-2023,
year = 2018-2024,
url = {https://github.com/sovrasov/flops-counter.pytorch},
}
```
@@ -76,25 +96,30 @@ If ptflops was useful for your paper or tech report, please cite me:

### [torchvision](https://pytorch.org/vision/0.16/models.html)

Model | Input Resolution | Params(M) | MACs(G) (`pytorch`) | MACs(G) (`aten`)
--- |--- |--- |--- |---
alexnet | 224x224 | 61.10 | 0.72 | 0.71
convnext_base | 224x224 | 88.59 | 15.43 | 15.38
densenet121 | 224x224 | 7.98 | 2.90 |
efficientnet_b0 | 224x224 | 5.29 | 0.41 |
efficientnet_v2_m | 224x224 | 54.14 | 5.43 |
googlenet | 224x224 | 13.00 | 1.51 |
inception_v3 | 224x224 | 27.16 | 5.75 | 5.71
maxvit_t | 224x224 | 30.92 | 5.48 |
mnasnet1_0 | 224x224 | 4.38 | 0.33 |
mobilenet_v2 | 224x224 | 3.50 | 0.32 |
mobilenet_v3_large | 224x224 | 5.48 | 0.23 |
regnet_y_1_6gf | 224x224 | 11.20 | 1.65 |
resnet18 | 224x224 | 11.69 | 1.83 | 1.81
resnet50 | 224x224 | 25.56 | 4.13 | 4.09
resnext50_32x4d | 224x224 | 25.03 | 4.29 |
shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 |
squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82
vgg16 | 224x224 | 138.36 | 15.52 | 15.48
vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86
wide_resnet50_2 | 224x224 | 68.88 | 11.45 |


### [timm](https://github.com/huggingface/pytorch-image-models)

Model | Input Resolution | Params(M) | MACs(G)
--- |--- |--- |---
alexnet | 224x224 | 61.10 | 0.72
convnext_base | 224x224 | 88.59 | 15.43
densenet121 | 224x224 | 7.98 | 2.90
efficientnet_b0 | 224x224 | 5.29 | 0.41
efficientnet_v2_m | 224x224 | 54.14 | 5.43
googlenet | 224x224 | 13.00 | 1.51
inception_v3 | 224x224 | 27.16 | 2.86
maxvit_t | 224x224 | 30.92 | 5.48
mnasnet1_0 | 224x224 | 4.38 | 0.33
mobilenet_v2 | 224x224 | 3.50 | 0.32
mobilenet_v3_large | 224x224 | 5.48 | 0.23
regnet_y_1_6gf | 224x224 | 11.20 | 1.65
resnet18 | 224x224 | 11.69 | 1.83
resnet50 | 224x224 | 25.56 | 4.13
resnext50_32x4d | 224x224 | 25.03 | 4.29
shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
squeezenet1_0 | 224x224 | 1.25 | 0.84
vgg16 | 224x224 | 138.36 | 15.52
vit_b_16 | 224x224 | 86.57 | 17.60
wide_resnet50_2 | 224x224 | 68.88 | 11.45
5 changes: 3 additions & 2 deletions ptflops/__init__.py
@@ -1,17 +1,18 @@
'''
Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''


from .flops_counter import get_model_complexity_info
from .flops_counter import FLOPS_BACKEND, get_model_complexity_info
from .utils import flops_to_string, params_to_string

__all__ = [
"get_model_complexity_info",
"flops_to_string",
"params_to_string",
"FLOPS_BACKEND",
]
137 changes: 137 additions & 0 deletions ptflops/aten_engine.py
@@ -0,0 +1,137 @@
'''
Copyright (C) 2024 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''


import sys
import traceback
from collections import defaultdict
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from ptflops.pytorch_engine import get_model_parameters_number
from ptflops.utils import flops_to_string
from .aten_ops import ATEN_OPS_MAPPING


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
                 output_params=None):
        self.verbose = verbose
        if output_params is None:
            output_params = defaultdict(dict)
        self.output_params = output_params
        self.print_fn = partial(print, **self.output_params['print_params'])

        self.print_per_layer_stat = print_per_layer_stat
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']
        self._total_complexity = None
        if module is not None:
            for name, mod in dict(module.named_children()).items():
                mod.register_forward_pre_hook(self.enter_module(name))
                mod.register_forward_hook(self.exit_module(name))

    @property
    def complexity(self):
        return self._total_complexity

    def enter_module(self, name):
        def f(*args):
            self.parents.append(name)
        return f

    def exit_module(self, name):
        def f(*args):
            assert(self.parents[-1] == name)
            self.parents.pop()
        return f

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        self._total_complexity = sum(self.flop_counts['Global'].values())
        if self.print_per_layer_stat:
            self.print_fn('Total:' +
                          flops_to_string(self._total_complexity,
                                          **self.output_params['serialize_params']))
            for mod in self.flop_counts.keys():
                self.print_fn("Module: ", mod)
                for k, v in self.flop_counts[mod].items():
                    self.print_fn(
                        f'{k}: ' +
                        flops_to_string(v, **self.output_params['serialize_params']))
                self.print_fn()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def normalize_tuple(x):
            if not isinstance(x, tuple):
                return (x,)
            return x
        kwargs = kwargs if kwargs else {}

        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in ATEN_OPS_MAPPING:
            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        elif self.verbose:
            self.print_fn(f'Warning: {func_packet} operation is treated as a zero-op')

        return out


def get_flops_aten(model, input_res,
                   print_per_layer_stat=True,
                   input_constructor=None, ost=sys.stdout,
                   verbose=False, ignore_modules=[],
                   custom_modules_hooks={},
                   output_precision=2,
                   flops_units: Optional[str] = 'GMac',
                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
                                                              Union[int, None]]:

    params_sum = get_model_parameters_number(model)
    model.eval()
    output_params = {'serialize_params':
                     {'units': flops_units, 'precision': output_precision},
                     'print_params': {'file': ost}}

    if input_constructor:
        batch = input_constructor(input_res)
    else:
        try:
            batch = torch.ones(()).new_empty((1, *input_res),
                                             dtype=next(model.parameters()).dtype,
                                             device=next(model.parameters()).device)
        except StopIteration:
            batch = torch.ones(()).new_empty((1, *input_res))

    try:
        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
        with counter:
            if isinstance(batch, dict):
                _ = model(**batch)
            else:
                _ = model(batch)
        macs_count = counter.complexity

    except Exception as e:
        print("Flops estimation was not finished successfully because of"
              f" the following exception:\n{type(e)} : {e}")
        traceback.print_exc()

        return None, None

    return macs_count, params_sum
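
The attribution scheme used by `FlopCounterMode` (a stack of active module names, with every counted operation credited to each name on the stack) can be sketched in plain Python. `ScopedCounter` and its methods are hypothetical stand-ins for the forward hooks and `__torch_dispatch__`, and the MAC values are made up for illustration.

```python
from collections import defaultdict


class ScopedCounter:
    """Toy version of the parent-stack bookkeeping, without PyTorch."""

    def __init__(self):
        # 'Global' is always on the stack, so it accumulates every operation.
        self.parents = ["Global"]
        self.counts = defaultdict(lambda: defaultdict(int))

    def enter(self, name):
        # Plays the role of the forward pre-hook.
        self.parents.append(name)

    def exit(self, name):
        # Plays the role of the forward hook.
        assert self.parents[-1] == name
        self.parents.pop()

    def record(self, op, macs):
        # Credit the operation to every scope that is currently active.
        for scope in self.parents:
            self.counts[scope][op] += macs


counter = ScopedCounter()
counter.enter("features")
counter.record("convolution", 100)
counter.record("convolution", 50)
counter.exit("features")
counter.enter("classifier")
counter.record("addmm", 10)
counter.exit("classifier")
counter.record("mm", 5)  # executed directly in the root module's forward

total = sum(counter.counts["Global"].values())
print(total, dict(counter.counts["features"]))
```

Because hooks are registered only on `module.named_children()`, the stack never grows beyond the root and one direct child, which is why deeper modules do not get their own entries in the per-module report.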