Skip to content

Commit

Permalink
adding my own flops calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
jtchilders committed Dec 4, 2019
1 parent 1c6214d commit 3085c5f
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions linear/linear_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from ptflops import get_model_complexity_info
#from ptflops import get_model_complexity_info

def run(point):
start = time.time()
Expand All @@ -24,11 +24,10 @@ def run(point):
print('torch version: ',torch.__version__,' torch file: ',torch.__file__)

inputs = torch.arange(batch_size * in_features,dtype=torch.float).view((batch_size,in_features))

# using flops from here:
# https://machinethink.net/blog/how-fast-is-my-model/
total_flop = batch_size * (2*in_features - 1) * out_features
layer = torch.nn.Linear(in_features,out_features,bias=bias)
flops, params = get_model_complexity_info(layer, tuple(inputs.shape),as_strings=False)
print(flops)

outputs = layer(inputs)

runs = 5
Expand All @@ -41,11 +40,11 @@ def run(point):

ave_time = tot_time / runs

print('flop = ',flops,'ave_time = ',ave_time)
print('flop = ',total_flop,'ave_time = ',ave_time)

ave_flops = flops / ave_time #* batch_size
ave_flops = total_flop / ave_time

print('runtime=',time.time() - start)
print('runtime=',time.time() - start,'ave_flops=',ave_flops)
return ave_flops
except Exception as e:
import traceback
Expand Down

0 comments on commit 3085c5f

Please sign in to comment.