diff --git a/linear/linear_run.py b/linear/linear_run.py index 85f31c8..adb7294 100644 --- a/linear/linear_run.py +++ b/linear/linear_run.py @@ -1,5 +1,5 @@ import time -from ptflops import get_model_complexity_info +#from ptflops import get_model_complexity_info def run(point): start = time.time() @@ -24,11 +24,10 @@ def run(point): print('torch version: ',torch.__version__,' torch file: ',torch.__file__) inputs = torch.arange(batch_size * in_features,dtype=torch.float).view((batch_size,in_features)) - + # using flops from here: + # https://machinethink.net/blog/how-fast-is-my-model/ + total_flop = batch_size * (2*in_features - 1) * out_features layer = torch.nn.Linear(in_features,out_features,bias=bias) - flops, params = get_model_complexity_info(layer, tuple(inputs.shape),as_strings=False) - print(flops) - outputs = layer(inputs) runs = 5 @@ -41,11 +40,11 @@ def run(point): ave_time = tot_time / runs - print('flop = ',flops,'ave_time = ',ave_time) + print('flop = ',total_flop,'ave_time = ',ave_time) - ave_flops = flops / ave_time #* batch_size + ave_flops = total_flop / ave_time - print('runtime=',time.time() - start) + print('runtime=',time.time() - start,'ave_flops=',ave_flops) return ave_flops except Exception as e: import traceback