Skip to content

Commit bb9cad5

Browse files
authored
Merge pull request #86 from sovrasov/jit_support
Fix already defined warning behavior
2 parents 0970db1 + 8ea2c06 commit bb9cad5

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## v 0.6.8
44
- Add support of GELU activation.
55
- Fix per layer statistic output in case of zero parameters number.
6+
- Cleanup flops and params attrs after ptflops has finished counting.
67

78
## v 0.6.7
89
- Add batch_first flag support in MultiheadAttention hook

ptflops/pytorch_engine.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def stop_flops_count(self):
197197
"""
198198
remove_batch_counter_hook_function(self)
199199
self.apply(remove_flops_counter_hook_function)
200+
self.apply(remove_flops_counter_variables)
200201

201202

202203
def reset_flops_count(self):
@@ -250,6 +251,8 @@ def add_flops_counter_variable_or_reset(module):
250251
print('Warning: variables __flops__ or __params__ are already '
251252
'defined for the module' + type(module).__name__ +
252253
' ptflops can affect your code!')
254+
module.__ptflops_backup_flops__ = module.__flops__
255+
module.__ptflops_backup_params__ = module.__params__
253256
module.__flops__ = 0
254257
module.__params__ = get_model_parameters_number(module)
255258

@@ -265,3 +268,15 @@ def remove_flops_counter_hook_function(module):
265268
if hasattr(module, '__flops_handle__'):
266269
module.__flops_handle__.remove()
267270
del module.__flops_handle__
271+
272+
273+
def remove_flops_counter_variables(module):
274+
if is_supported_instance(module):
275+
if hasattr(module, '__flops__'):
276+
del module.__flops__
277+
if hasattr(module, '__ptflops_backup_flops__'):
278+
module.__flops__ = module.__ptflops_backup_flops__
279+
if hasattr(module, '__params__'):
280+
del module.__params__
281+
if hasattr(module, '__ptflops_backup_params__'):
282+
module.__params__ = module.__ptflops_backup_params__

0 commit comments

Comments
 (0)