adding linear and utilization plotter
jtchilders committed Dec 3, 2019
1 parent 39e7c92 commit 1c6214d
Showing 3 changed files with 125 additions and 0 deletions.
68 changes: 68 additions & 0 deletions linear/linear_run.py
@@ -0,0 +1,68 @@
import time
from ptflops import get_model_complexity_info

def run(point):
    start = time.time()
    try:
        batch_size = point['batch_size']
        in_features = point['in_features']
        out_features = point['out_features']
        bias = int(point['bias']) == 1
        omp_num_threads = point['omp_num_threads']
        print(point)

        # set threading and affinity controls before torch is imported so they take effect
        import os
        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
        os.environ['MKL_NUM_THREADS'] = str(omp_num_threads)
        os.environ['KMP_HW_SUBSET'] = '1s,%sc,2t' % str(omp_num_threads)
        os.environ['KMP_AFFINITY'] = 'granularity=fine,verbose,compact,1,0'
        os.environ['KMP_BLOCKTIME'] = str(0)
        os.environ['MKLDNN_VERBOSE'] = str(1)
        os.environ['MKL_VERBOSE'] = str(1)
        import torch

        print('torch version: ', torch.__version__, ' torch file: ', torch.__file__)

        inputs = torch.arange(batch_size * in_features, dtype=torch.float).view((batch_size, in_features))

        layer = torch.nn.Linear(in_features, out_features, bias=bias)
        flops, params = get_model_complexity_info(layer, tuple(inputs.shape), as_strings=False)
        print(flops)

        # warm-up forward pass
        outputs = layer(inputs)

        # average the forward-pass time over several runs
        runs = 5
        tot_time = 0.
        tt = time.time()
        for _ in range(runs):
            outputs = layer(inputs)
            tot_time += time.time() - tt
            tt = time.time()

        ave_time = tot_time / runs

        print('flop = ', flops, 'ave_time = ', ave_time)

        ave_flops = flops / ave_time  # * batch_size

        print('runtime=', time.time() - start)
        return ave_flops
    except Exception as e:
        import traceback
        print('received exception: ', str(e))
        traceback.print_exc()
        print('runtime=', time.time() - start)
        return 0.


if __name__ == '__main__':
    point = {
        'batch_size': 10,
        'in_features': 512,
        'out_features': 512,
        'bias': 1,
        'omp_num_threads': 64,
    }

    print('flops for this setting =', run(point))
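
Not part of the commit, but useful for sanity-checking the numbers above: a single forward pass of torch.nn.Linear does roughly batch_size * in_features * out_features multiply-accumulates (about twice that in raw FLOPs, plus batch_size * out_features additions when bias is enabled). The sketch below is a minimal, hypothetical cross-check of the rate returned by run(); the helper name linear_flops_estimate and the factor-of-two convention are assumptions, since counting conventions (MACs vs. FLOPs) differ between tools such as ptflops. The import assumes the repository root is on PYTHONPATH.

    # hypothetical sanity check (not part of this commit): compare an analytic
    # FLOP estimate for one forward pass of a Linear layer against run()'s rate
    from linear.linear_run import run

    def linear_flops_estimate(batch_size, in_features, out_features, bias=True):
        macs = batch_size * in_features * out_features    # multiply-accumulate count
        flops = 2 * macs                                   # assume 1 mul + 1 add per MAC
        if bias:
            flops += batch_size * out_features             # bias additions
        return flops

    if __name__ == '__main__':
        point = {'batch_size': 10, 'in_features': 512, 'out_features': 512,
                 'bias': 1, 'omp_num_threads': 64}
        per_pass = linear_flops_estimate(point['batch_size'], point['in_features'],
                                         point['out_features'], bias=point['bias'] == 1)
        print('analytic flops per forward pass =', per_pass)
        print('measured flops/s from run()     =', run(point))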

13 changes: 13 additions & 0 deletions linear/problem.py
@@ -0,0 +1,13 @@
from deephyper.benchmark import HpProblem

# hyperparameter search space for the Linear-layer benchmark in linear_run.py
Problem = HpProblem()
Problem.add_dim('batch_size', (1, 8192))
Problem.add_dim('in_features', (128, 8192))
Problem.add_dim('out_features', (128, 8192))
Problem.add_dim('omp_num_threads', (8, 64))
Problem.add_dim('bias', [0, 1])

Problem.add_starting_point(batch_size=128, in_features=1024, out_features=512, omp_num_threads=64, bias=0)

if __name__ == '__main__':
    print(Problem)
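
As an aside (not part of the diff), this search space pairs with run() in linear_run.py: a search framework such as DeepHyper samples points from these dimensions and maximizes the returned FLOP rate. Below is a minimal sketch of exercising the space by hand; the random sampling merely stands in for a real search, the ranges are copied from problem.py, and the import path assumes the repository root is on PYTHONPATH.

    # minimal hand-rolled sampler (illustrative only): draw a random point from
    # the same ranges defined in problem.py and evaluate it with linear_run.run
    import random
    from linear.linear_run import run

    point = {
        'batch_size': random.randint(1, 8192),
        'in_features': random.randint(128, 8192),
        'out_features': random.randint(128, 8192),
        'omp_num_threads': random.randint(8, 64),
        'bias': random.choice([0, 1]),
    }
    print('sampled point:', point)
    print('achieved flops/s:', run(point))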
44 changes: 44 additions & 0 deletions utilization.py
@@ -0,0 +1,44 @@
#!/usr/bin/env python
import argparse,logging
from balsam.core.models import utilization_report, BalsamJob, process_job_times
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)


def main():
    ''' plot node utilization over time for the Balsam jobs in a given workflow. '''
    logging_format = '%(asctime)s %(levelname)s:%(name)s:%(message)s'
    logging_datefmt = '%Y-%m-%d %H:%M:%S'
    logging_level = logging.INFO

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-w', '--workflow', help='workflow name to analyze', required=True)
    parser.add_argument('--debug', dest='debug', default=False, action='store_true', help="Set Logger to DEBUG")
    parser.add_argument('--error', dest='error', default=False, action='store_true', help="Set Logger to ERROR")
    parser.add_argument('--warning', dest='warning', default=False, action='store_true', help="Set Logger to WARNING")
    parser.add_argument('--logfilename', dest='logfilename', default=None, help='if set, logging information will go to file')
    args = parser.parse_args()

    if args.debug and not args.error and not args.warning:
        logging_level = logging.DEBUG
    elif not args.debug and args.error and not args.warning:
        logging_level = logging.ERROR
    elif not args.debug and not args.error and args.warning:
        logging_level = logging.WARNING

    logging.basicConfig(level=logging_level,
                        format=logging_format,
                        datefmt=logging_datefmt,
                        filename=args.logfilename)

    # collect all Balsam jobs in the workflow and build the utilization time series
    qs = BalsamJob.objects.filter(workflow=args.workflow)
    dat = process_job_times(qs)
    times, utils = utilization_report(dat)

    # step plot of concurrently running jobs versus time; wait for a click before exiting
    plt.step(times, utils, where="post")
    plt.waitforbuttonpress()


if __name__ == "__main__":
    main()
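
One possible extension (not in this commit): for batch use on a machine without a display, the interactive plt.waitforbuttonpress() call could be replaced by writing the figure to disk. A minimal sketch under that assumption; the helper name, axis labels, and output filename are illustrative, not taken from the repository.

    # illustrative alternative ending for main(): label the axes and save the
    # utilization plot to a PNG instead of waiting for a button press
    import matplotlib
    matplotlib.use('Agg')   # headless backend; must be selected before pyplot is used
    import matplotlib.pyplot as plt

    def save_utilization_plot(times, utils, filename='utilization.png'):
        plt.step(times, utils, where='post')
        plt.xlabel('time')
        plt.ylabel('running jobs')
        plt.title('Balsam workflow utilization')
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        print('wrote', filename)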
