Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
jtchilders committed Dec 4, 2019
1 parent 811b930 commit d1cc2b8
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions square_images/square_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import time

def run(point):
start = time.time()
try:
batch_size = point['batch_size']
image_size = point['image_size']
in_channels = point['in_channels']
out_channels = point['out_channels']
kernel_size = point['kernel_size']
omp_num_threads = point['omp_num_threads']

import os
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
os.environ['MKL_NUM_THREADS'] = str(omp_num_threads)
os.environ['KMP_HW_SUBSET'] = '1s,%sc,2t' % str(omp_num_threads)
os.environ['KMP_AFFINITY'] = 'granularity=fine,verbose,compact,1,0'
os.environ['KMP_BLOCKTIME'] = str(0)
#os.environ['MKLDNN_VERBOSE'] = str(1)
import torch

print('torch version: ',torch.__version__,' torch file: ',torch.__file__)


inputs = torch.arange(batch_size * image_size * image_size * in_channels,dtype=torch.float).view((batch_size,in_channels,image_size,image_size))

layer = torch.nn.Conv2d(in_channels,out_channels,(kernel_size,kernel_size),stride=1)
outputs = layer(inputs)

total_flop = kernel_size * kernel_size * in_channels * out_channels * outputs.shape[-1] * outputs.shape[-2] * batch_size

runs = 5
tot_time = 0.
tt = time.time()
for _ in range(runs):
outputs = layer(inputs)
tot_time += time.time() - tt
tt = time.time()

ave_time = tot_time / runs

print('total_flop = ',total_flop,'ave_time = ',ave_time)

ave_flops = total_flop / ave_time
runtime = time.time() - start
print('runtime=',runtime,'ave_flops=',ave_flops)

return ave_flops
except Exception as e:
import traceback
print('received exception: ',str(e))
print(traceback.print_exc())
print('runtime=',time.time() - start)
return 0.


if __name__ == '__main__':
point = {
'batch_size': 10,
'image_size': 512,
'in_channels': 3,
'out_channels': 64,
'kernel_size': 4,
'omp_num_threads':64,
}

print('flops for this setting =',run(point))

0 comments on commit d1cc2b8

Please sign in to comment.