Skip to content

Commit

Permalink
adding files
Browse files Browse the repository at this point in the history
  • Loading branch information
jtchilders committed Nov 21, 2019
1 parent 2fbb1b8 commit a31f471
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
9 changes: 9 additions & 0 deletions load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import numpy as np


def load_data(size=2000,height=256,width=256,channels=3):

inputs = np.random.randn((size,height,width,channels))

return inputs

49 changes: 49 additions & 0 deletions model_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch,time
from ptflops import get_model_complexity_info


def run(point):

batch_size = point['batch_size']
height = point['height']
width = point['width']
in_channels = point['in_channels']
out_channels = point['out_channels']
kernel_size = (point['kernel_size'],point['kernel_size'])


inputs = torch.arange(batch_size * height * width * in_channels,dtype=torch.float).view((batch_size,in_channels,height,width))

layer = torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1)
flops, params = get_model_complexity_info(layer, tuple(inputs.shape[1:]),as_strings=False)
print(flops)

outputs = layer(inputs)

runs = 5
tot_time = 0.
tt = time.time()
for _ in range(runs):
outputs = layer(inputs)
tot_time += time.time() - tt
tt = time.time()

ave_time = tot_time / runs

ave_flops = flops / ave_time

return ave_flops


if __name__ == '__main__':
point = {
'batch_size': 10,
'height': 512,
'width': 512,
'in_channels': 3,
'out_channels': 64,
'kernel_size': 4
}

print('flops for this setting =',run(point))

14 changes: 14 additions & 0 deletions problem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from deephyper.benchmark import HpProblem

Problem = HpProblem()
Problem.add_dim('batch_size',(1,128))
Problem.add_dim('height',(128,512))
Problem.add_dim('width',(128,512))
Problem.add_dim('in_channels',(2,16))
Problem.add_dim('out_channels',(2,16))
Problem.add_dim('kernel_size',(2,8))

Problem.add_starting_point(batch_size=10,height=128,width=128,in_channels=3,out_channels=16,kernel_size=3)

if __name__ == '__main__':
print(Problem)

0 comments on commit a31f471

Please sign in to comment.