diff --git a/load_data.py b/load_data.py new file mode 100644 index 0000000..7b94597 --- /dev/null +++ b/load_data.py @@ -0,0 +1,9 @@ +import numpy as np + + +def load_data(size=2000,height=256,width=256,channels=3): + + inputs = np.random.randn((size,height,width,channels)) + + return inputs + diff --git a/model_run.py b/model_run.py new file mode 100644 index 0000000..8401847 --- /dev/null +++ b/model_run.py @@ -0,0 +1,49 @@ +import torch,time +from ptflops import get_model_complexity_info + + +def run(point): + + batch_size = point['batch_size'] + height = point['height'] + width = point['width'] + in_channels = point['in_channels'] + out_channels = point['out_channels'] + kernel_size = (point['kernel_size'],point['kernel_size']) + + + inputs = torch.arange(batch_size * height * width * in_channels,dtype=torch.float).view((batch_size,in_channels,height,width)) + + layer = torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1) + flops, params = get_model_complexity_info(layer, tuple(inputs.shape[1:]),as_strings=False) + print(flops) + + outputs = layer(inputs) + + runs = 5 + tot_time = 0. + tt = time.time() + for _ in range(runs): + outputs = layer(inputs) + tot_time += time.time() - tt + tt = time.time() + + ave_time = tot_time / runs + + ave_flops = flops / ave_time + + return ave_flops + + +if __name__ == '__main__': + point = { + 'batch_size': 10, + 'height': 512, + 'width': 512, + 'in_channels': 3, + 'out_channels': 64, + 'kernel_size': 4 + } + + print('flops for this setting =',run(point)) + diff --git a/problem.py b/problem.py new file mode 100644 index 0000000..61dca68 --- /dev/null +++ b/problem.py @@ -0,0 +1,14 @@ +from deephyper.benchmark import HpProblem + +Problem = HpProblem() +Problem.add_dim('batch_size',(1,128)) +Problem.add_dim('height',(128,512)) +Problem.add_dim('width',(128,512)) +Problem.add_dim('in_channels',(2,16)) +Problem.add_dim('out_channels',(2,16)) +Problem.add_dim('kernel_size',(2,8)) + +Problem.add_starting_point(batch_size=10,height=128,width=128,in_channels=3,out_channels=16,kernel_size=3) + +if __name__ == '__main__': + print(Problem)