@@ -197,6 +197,7 @@ def stop_flops_count(self):
197
197
"""
198
198
remove_batch_counter_hook_function (self )
199
199
self .apply (remove_flops_counter_hook_function )
200
+ self .apply (remove_flops_counter_variables )
200
201
201
202
202
203
def reset_flops_count (self ):
@@ -250,6 +251,8 @@ def add_flops_counter_variable_or_reset(module):
250
251
print ('Warning: variables __flops__ or __params__ are already '
251
252
'defined for the module' + type (module ).__name__ +
252
253
' ptflops can affect your code!' )
254
+ module .__ptflops_backup_flops__ = module .__flops__
255
+ module .__ptflops_backup_params__ = module .__params__
253
256
module .__flops__ = 0
254
257
module .__params__ = get_model_parameters_number (module )
255
258
@@ -265,3 +268,15 @@ def remove_flops_counter_hook_function(module):
265
268
if hasattr (module , '__flops_handle__' ):
266
269
module .__flops_handle__ .remove ()
267
270
del module .__flops_handle__
271
+
272
+
273
+ def remove_flops_counter_variables (module ):
274
+ if is_supported_instance (module ):
275
+ if hasattr (module , '__flops__' ):
276
+ del module .__flops__
277
+ if hasattr (module , '__ptflops_backup_flops__' ):
278
+ module .__flops__ = module .__ptflops_backup_flops__
279
+ if hasattr (module , '__params__' ):
280
+ del module .__params__
281
+ if hasattr (module , '__ptflops_backup_params__' ):
282
+ module .__params__ = module .__ptflops_backup_params__
0 commit comments