
Commit f8b121a

Merge pull request #128 from sovrasov/upd_docs

Update docs

2 parents 6e5b4d8 + f221d70

File tree

3 files changed: +61 additions, −95 deletions

README.md

Lines changed: 27 additions & 83 deletions
@@ -18,16 +18,18 @@ Supported layers:
 Experimental support:
 - RNN, LSTM, GRU (NLH layout is assumed)
 - RNNCell, LSTMCell, GRUCell
-- MultiheadAttention
+- torch.nn.MultiheadAttention
 - torchvision.ops.DeformConv2d
+- visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)

 Requirements: Pytorch >= 1.1, torchvision >= 0.3

 Thanks to @warmspringwinds for the initial version of the script.

 ## Usage tips

-- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
+- This tool doesn't take into account some of the `torch.nn.functional.*` and `tensor.*` operations. Therefore, unsupported operations do not
+  contribute to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
 - `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of them optional. To construct a non-trivial input, one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. This dict is then passed to the model as keyword arguments (see the sketch after this hunk).
 - The `verbose` parameter allows getting information about modules that don't contribute to the final numbers.
 - The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
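As a reference for the `input_constructor` tip above, here is a minimal sketch of measuring a model with a named optional input. The model, its argument names, and the tensor shapes are hypothetical; only `get_model_complexity_info` and its `input_constructor` argument come from the ptflops API referenced in the README.

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info


class TwoInputModel(nn.Module):
    """Hypothetical model with an optional second input."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x, mask=None):
        out = self.fc(x)
        if mask is not None:
            out = out * mask
        return out


def input_constructor(input_res):
    # Receives the resolution tuple passed to get_model_complexity_info and
    # returns a dict of keyword arguments for the model's forward().
    x = torch.ones((1, *input_res))
    return {'x': x, 'mask': torch.ones((1, 4))}


macs, params = get_model_complexity_info(TwoInputModel(), (16,),
                                         input_constructor=input_constructor,
                                         as_strings=False,
                                         print_per_layer_stat=False)
print(macs, params)
```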
@@ -72,84 +74,26 @@ If ptflops was useful for your paper or tech report, please cite me:

 ## Benchmark

-### [torchvision](https://pytorch.org/docs/1.0.0/torchvision/models.html)
-
-Model | Input Resolution | Params(M) | MACs(G) | Top-1 error | Top-5 error
---- |--- |--- |--- |--- |---
-alexnet |224x224 | 61.1 | 0.72 | 43.45 | 20.91
-vgg11 |224x224 | 132.86 | 7.63 | 30.98 | 11.37
-vgg13 |224x224 | 133.05 | 11.34 | 30.07 | 10.75
-vgg16 |224x224 | 138.36 | 15.5 | 28.41 | 9.62
-vgg19 |224x224 | 143.67 | 19.67 | 27.62 | 9.12
-vgg11_bn |224x224 | 132.87 | 7.64 | 29.62 | 10.19
-vgg13_bn |224x224 | 133.05 | 11.36 | 28.45 | 9.63
-vgg16_bn |224x224 | 138.37 | 15.53 | 26.63 | 8.50
-vgg19_bn |224x224 | 143.68 | 19.7 | 25.76 | 8.15
-resnet18 |224x224 | 11.69 | 1.82 | 30.24 | 10.92
-resnet34 |224x224 | 21.8 | 3.68 | 26.70 | 8.58
-resnet50 |224x224 | 25.56 | 4.12 | 23.85 | 7.13
-resnet101 |224x224 | 44.55 | 7.85 | 22.63 | 6.44
-resnet152 |224x224 | 60.19 | 11.58 | 21.69 | 5.94
-squeezenet1_0 |224x224 | 1.25 | 0.83 | 41.90 | 19.58
-squeezenet1_1 |224x224 | 1.24 | 0.36 | 41.81 | 19.38
-densenet121 |224x224 | 7.98 | 2.88 | 25.35 | 7.83
-densenet169 |224x224 | 14.15 | 3.42 | 24.00 | 7.00
-densenet201 |224x224 | 20.01 | 4.37 | 22.80 | 6.43
-densenet161 |224x224 | 28.68 | 7.82 | 22.35 | 6.20
-inception_v3 |224x224 | 27.16 | 2.85 | 22.55 | 6.44
-
-* Top-1 error - ImageNet single-crop top-1 error (224x224)
-* Top-5 error - ImageNet single-crop top-5 error (224x224)
-
-### [Cadene/pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch)
-
-Model | Input Resolution | Params(M) | MACs(G) | Acc@1 | Acc@5
---- |--- |--- |--- |--- |---
-alexnet | 224x224 | 61.1 | 0.72 | 56.432 | 79.194
-bninception | 224x224 | 11.3 | 2.05 | 73.524 | 91.562
-cafferesnet101 | 224x224 | 44.55 | 7.62 | 76.2 | 92.766
-densenet121 | 224x224 | 7.98 | 2.88 | 74.646 | 92.136
-densenet161 | 224x224 | 28.68 | 7.82 | 77.56 | 93.798
-densenet169 | 224x224 | 14.15 | 3.42 | 76.026 | 92.992
-densenet201 | 224x224 | 20.01 | 4.37 | 77.152 | 93.548
-dpn107 | 224x224 | 86.92 | 18.42 | 79.746 | 94.684
-dpn131 | 224x224 | 79.25 | 16.13 | 79.432 | 94.574
-dpn68 | 224x224 | 12.61 | 2.36 | 75.868 | 92.774
-dpn68b | 224x224 | 12.61 | 2.36 | 77.034 | 93.59
-dpn92 | 224x224 | 37.67 | 6.56 | 79.4 | 94.62
-dpn98 | 224x224 | 61.57 | 11.76 | 79.224 | 94.488
-fbresnet152 | 224x224 | 60.27 | 11.6 | 77.386 | 93.594
-inceptionresnetv2 | 299x299 | 55.84 | 13.22 | 80.17 | 95.234
-inceptionv3 | 299x299 | 27.16 | 5.73 | 77.294 | 93.454
-inceptionv4 | 299x299 | 42.68 | 12.31 | 80.062 | 94.926
-nasnetalarge | 331x331 | 88.75 | 24.04 | 82.566 | 96.086
-nasnetamobile | 224x224 | 5.29 | 0.59 | 74.08 | 91.74
-pnasnet5large | 331x331 | 86.06 | 25.21 | 82.736 | 95.992
-polynet | 331x331 | 95.37 | 34.9 | 81.002 | 95.624
-resnet101 | 224x224 | 44.55 | 7.85 | 77.438 | 93.672
-resnet152 | 224x224 | 60.19 | 11.58 | 78.428 | 94.11
-resnet18 | 224x224 | 11.69 | 1.82 | 70.142 | 89.274
-resnet34 | 224x224 | 21.8 | 3.68 | 73.554 | 91.456
-resnet50 | 224x224 | 25.56 | 4.12 | 76.002 | 92.98
-resnext101_32x4d | 224x224 | 44.18 | 8.03 | 78.188 | 93.886
-resnext101_64x4d | 224x224 | 83.46 | 15.55 | 78.956 | 94.252
-se_resnet101 | 224x224 | 49.33 | 7.63 | 78.396 | 94.258
-se_resnet152 | 224x224 | 66.82 | 11.37 | 78.658 | 94.374
-se_resnet50 | 224x224 | 28.09 | 3.9 | 77.636 | 93.752
-se_resnext101_32x4d | 224x224 | 48.96 | 8.05 | 80.236 | 95.028
-se_resnext50_32x4d | 224x224 | 27.56 | 4.28 | 79.076 | 94.434
-senet154 | 224x224 | 115.09 | 20.82 | 81.304 | 95.498
-squeezenet1_0 | 224x224 | 1.25 | 0.83 | 58.108 | 80.428
-squeezenet1_1 | 224x224 | 1.24 | 0.36 | 58.25 | 80.8
-vgg11 | 224x224 | 132.86 | 7.63 | 68.97 | 88.746
-vgg11_bn | 224x224 | 132.87 | 7.64 | 70.452 | 89.818
-vgg13 | 224x224 | 133.05 | 11.34 | 69.662 | 89.264
-vgg13_bn | 224x224 | 133.05 | 11.36 | 71.508 | 90.494
-vgg16 | 224x224 | 138.36 | 15.5 | 71.636 | 90.354
-vgg16_bn | 224x224 | 138.37 | 15.53 | 73.518 | 91.608
-vgg19 | 224x224 | 143.67 | 19.67 | 72.08 | 90.822
-vgg19_bn | 224x224 | 143.68 | 19.7 | 74.266 | 92.066
-xception | 299x299 | 22.86 | 8.42 | 78.888 | 94.292
-
-* Acc@1 - ImageNet single-crop top-1 accuracy on validation images of the same size used during the training process.
-* Acc@5 - ImageNet single-crop top-5 accuracy on validation images of the same size used during the training process.
+### [torchvision](https://pytorch.org/vision/0.16/models.html)
+
+Model | Input Resolution | Params(M) | MACs(G)
+--- |--- |--- |---
+alexnet | 224x224 | 61.10 | 0.72
+convnext_base | 224x224 | 88.59 | 15.43
+densenet121 | 224x224 | 7.98 | 2.90
+efficientnet_b0 | 224x224 | 5.29 | 0.41
+efficientnet_v2_m | 224x224 | 54.14 | 5.43
+googlenet | 224x224 | 13.00 | 1.51
+inception_v3 | 224x224 | 27.16 | 2.86
+maxvit_t | 224x224 | 30.92 | 5.48
+mnasnet1_0 | 224x224 | 4.38 | 0.33
+mobilenet_v2 | 224x224 | 3.50 | 0.32
+mobilenet_v3_large | 224x224 | 5.48 | 0.23
+regnet_y_1_6gf | 224x224 | 11.20 | 1.65
+resnet18 | 224x224 | 11.69 | 1.83
+resnet50 | 224x224 | 25.56 | 4.13
+resnext50_32x4d | 224x224 | 25.03 | 4.29
+shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15
+squeezenet1_0 | 224x224 | 1.25 | 0.84
+vgg16 | 224x224 | 138.36 | 15.52
+wide_resnet50_2 | 224x224 | 68.88 | 11.45
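Rows of the table above can be regenerated with ptflops itself. The exact flags used for the published numbers are not recorded in this commit, so the snippet below is only a sketch under assumed defaults (batch size 1, a 3x224x224 input); `resnet18` is picked as an example model.

```python
import torchvision.models as models

from ptflops import get_model_complexity_info

# Any torchvision classification model from the table could be substituted here.
net = models.resnet18()
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=True,
                                         print_per_layer_stat=False)
print(f'Computational complexity: {macs}')
print(f'Number of parameters: {params}')
```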

ptflops/pytorch_engine.py

Lines changed: 18 additions & 12 deletions
@@ -415,35 +415,41 @@ def unpatch_functional():
     F.interpolate = F.interpolate.op


+def wrap_tensor_op(op, collector):
+    tensor_op_handler = torch_function_wrapper(
+        op, TENSOR_OPS_MAPPING[op], collector)
+
+    def wrapper(*args, **kwargs):
+        return tensor_op_handler(*args, **kwargs)
+
+    wrapper.op = tensor_op_handler.op
+
+    return wrapper
+
+
 def patch_tensor_ops(collector):
     torch.matmul = torch_function_wrapper(
         torch.matmul, TENSOR_OPS_MAPPING[torch.matmul], collector)
-    torch.Tensor.matmul = torch_function_wrapper(
-        torch.Tensor.matmul, TENSOR_OPS_MAPPING[torch.Tensor.matmul], collector)
+    torch.Tensor.matmul = wrap_tensor_op(torch.Tensor.matmul, collector)
     torch.mm = torch_function_wrapper(
         torch.mm, TENSOR_OPS_MAPPING[torch.mm], collector)
-    torch.Tensor.mm = torch_function_wrapper(
-        torch.Tensor.mm, TENSOR_OPS_MAPPING[torch.Tensor.mm], collector)
+    torch.Tensor.mm = wrap_tensor_op(torch.Tensor.mm, collector)
     torch.bmm = torch_function_wrapper(
         torch.bmm, TENSOR_OPS_MAPPING[torch.bmm], collector)
-    torch.Tensor.bmm = torch_function_wrapper(
-        torch.Tensor.bmm, TENSOR_OPS_MAPPING[torch.Tensor.bmm], collector)
+    torch.Tensor.bmm = wrap_tensor_op(torch.Tensor.bmm, collector)

     torch.addmm = torch_function_wrapper(
         torch.addmm, TENSOR_OPS_MAPPING[torch.addmm], collector)
-    torch.Tensor.addmm = torch_function_wrapper(
-        torch.Tensor.addmm, TENSOR_OPS_MAPPING[torch.Tensor.addmm], collector)
+    torch.Tensor.addmm = wrap_tensor_op(torch.Tensor.addmm, collector)
     torch.baddbmm = torch_function_wrapper(
         torch.baddbmm, TENSOR_OPS_MAPPING[torch.baddbmm], collector)

     torch.mul = torch_function_wrapper(
         torch.mul, TENSOR_OPS_MAPPING[torch.mul], collector)
-    torch.Tensor.mul = torch_function_wrapper(
-        torch.Tensor.mul, TENSOR_OPS_MAPPING[torch.Tensor.mul], collector)
+    torch.Tensor.mul = wrap_tensor_op(torch.Tensor.mul, collector)
     torch.add = torch_function_wrapper(
         torch.add, TENSOR_OPS_MAPPING[torch.add], collector)
-    torch.Tensor.add = torch_function_wrapper(
-        torch.Tensor.add, TENSOR_OPS_MAPPING[torch.Tensor.add], collector)
+    torch.Tensor.add = wrap_tensor_op(torch.Tensor.add, collector)


 def unpatch_tensor_ops():
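The new `wrap_tensor_op` helper swaps the direct `torch_function_wrapper` assignment for `torch.Tensor.*` methods with a plain forwarding function that keeps the original operation on its `.op` attribute, presumably so the tensor methods can later be restored the same way `unpatch_functional` restores `F.interpolate` from `F.interpolate.op`. Below is a self-contained sketch of this patching pattern; `make_counting_wrapper` and the call counter are illustrative stand-ins, not the actual ptflops internals.

```python
import torch


def make_counting_wrapper(op, counter):
    def handler(*args, **kwargs):
        counter['calls'] += 1          # a real handler would accumulate estimated MACs here
        return op(*args, **kwargs)

    def wrapper(*args, **kwargs):      # plain function, so it binds like a normal method
        return handler(*args, **kwargs)

    wrapper.op = op                    # keep the original so unpatching can restore it
    return wrapper


counter = {'calls': 0}
torch.Tensor.matmul = make_counting_wrapper(torch.Tensor.matmul, counter)

x = torch.ones(4, 4)
_ = x.matmul(x)
print(counter['calls'])                # -> 1

# Unpatch, mirroring what unpatch_tensor_ops does via the .op attribute.
torch.Tensor.matmul = torch.Tensor.matmul.op
```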

tests/common_test.py

Lines changed: 16 additions & 0 deletions
@@ -95,3 +95,19 @@ def forward(self, x):
                                                   print_per_layer_stat=False)
         assert params == 0
         assert macs > 0
+
+    def test_ten_matmul(self):
+        class CustomModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.matmul(x.t())
+
+        macs, params = \
+            get_model_complexity_info(CustomModel(), (10, ),
+                                      as_strings=False,
+                                      print_per_layer_stat=False)
+
+        assert params == 0
+        assert macs > 0
