Skip to content

Commit 99003d1

Browse files
committed
Added tool to optimize rnn with pybrain evolution.
1 parent 71b2c0c commit 99003d1

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

evolearn.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python2.6
2+
# -*- coding: utf-8 -*-
3+
4+
5+
from __future__ import division
6+
7+
8+
__author__ = 'Justin S Bayer, [email protected]'
9+
10+
11+
import glob
12+
import itertools
13+
import optparse
14+
import os
15+
import sys
16+
17+
import scipy
18+
from pybrain.optimization import PGPE
19+
20+
import rnn
21+
22+
23+
24+
def make_optparse():
    """Build the command-line option parser for this script.

    The parser understands three options:

    * ``--datapath`` (string, stored as ``datapath``)
    * ``--hidden`` (int, stored as ``n_hidden``)
    * ``--maxevals`` (int, stored as ``maxevals``, default 100)

    :returns: a configured ``optparse.OptionParser``.
    """
    # Each entry is (flag, keyword arguments for ``add_option``).
    option_specs = (
        ('--datapath', dict(
            dest='datapath', type='str',
            help='specify the directory where the data lies.')),
        ('--hidden', dict(
            dest='n_hidden', type='int',
            help='specify number of hiddens to use')),
        ('--maxevals', dict(
            dest='maxevals', type='int', default=100,
            help='specify number of maximum passes through data')),
    )

    cli = optparse.OptionParser()
    for flag, settings in option_specs:
        cli.add_option(flag, **settings)
    return cli
33+
34+
35+
def load_dataset(path):
    """Load (input, target) array pairs from a data directory.

    Files in ``path`` whose names start with ``input`` are paired, in sorted
    filename order, with the files whose names start with ``target``.  Every
    file is parsed with ``numpy.loadtxt`` and each target is reshaped into a
    single column.

    :param path: directory containing the ``input*`` and ``target*`` files.
    :returns: list of ``(inpt, target)`` tuples of arrays; ``target`` has
        shape ``(n, 1)``.  The list is empty if no files match.
    :raises ValueError: if the number of input files differs from the number
        of target files (pairing by sorted order would be unreliable).
    """
    # Imported locally so the block does not rely on ``scipy.loadtxt``, which
    # is only a deprecated alias of the NumPy function and is missing from
    # newer SciPy releases.
    import numpy

    def files_with_prefix(prefix):
        return sorted(glob.glob(os.path.join(path, '%s*' % prefix)))

    inptfiles = files_with_prefix('input')
    targetfiles = files_with_prefix('target')

    # Pairing relies purely on sorted order; a count mismatch used to be
    # truncated silently, which hides missing or misaligned files.
    if len(inptfiles) != len(targetfiles):
        raise ValueError(
            'found %d input file(s) but %d target file(s) in %r'
            % (len(inptfiles), len(targetfiles), path))

    data = []
    for infn, targetfn in zip(inptfiles, targetfiles):
        inpt = numpy.loadtxt(infn)
        # Force a column vector regardless of how loadtxt shaped the values.
        target = numpy.loadtxt(targetfn).reshape(-1, 1)
        data.append((inpt, target))
    return data
47+
48+
49+
def make_objective_func(net, data, errorfunc):
    """Create an objective function over the parameters of ``net``.

    :param net: callable with a sliceable ``parameters`` attribute; calling
        it on an input returns a pair whose second item is the output.
    :param data: iterable of ``(inpt, target)`` pairs.
    :param errorfunc: callable ``errorfunc(output, target)`` whose result can
        be converted with ``float``.
    :returns: function ``obj(x)`` that copies ``x`` into ``net.parameters``
        and returns the error summed over all of ``data``.
    """
    def obj(x):
        # Slice assignment writes into the existing parameter container
        # rather than rebinding the attribute.
        net.parameters[:] = x
        accumulated = 0
        for sequence, desired in data:
            _unused, prediction = net(sequence)
            accumulated = accumulated + float(errorfunc(prediction, desired))
        return accumulated

    return obj
58+
59+
60+
def stats(net, data):
    """Count how often thresholded network outputs coincide with targets.

    For every ``(inpt, target)`` pair the second item returned by
    ``net(inpt)`` is binarised at 0.5 and multiplied element-wise with
    ``target``; the sums of those products are accumulated.

    :param net: callable returning a pair whose second item is an array.
    :param data: iterable of ``(inpt, target)`` array pairs.
    :returns: tuple ``(true_positives, total)`` where ``total`` is the sum of
        ``target.shape[0]`` over all pairs.
    """
    hits = 0
    n_rows = 0
    for sequence, desired in data:
        _unused, activation = net(sequence)
        fired = (activation > 0.5).astype('float64')
        n_rows += desired.shape[0]
        hits += (fired * desired).sum()
    return hits, n_rows
69+
70+
71+
def main():
    """Train a recurrent network with PGPE and print summary statistics.

    Steps: parse the command line, load the dataset from ``--datapath``,
    build an ``rnn.RecurrentNetwork`` with ``--hidden`` hidden units,
    optimise its parameters with PGPE for at most ``--maxevals``
    evaluations, then print the fitness before and after and the
    thresholded-output statistics from ``stats``.

    Exits through ``parser.error`` (usage message, status 2) when a required
    option is missing or no data files are found.

    :returns: 0 on success, intended as the process exit status.
    """
    # Imported locally: ``scipy.random`` is only a re-export of
    # ``numpy.random`` and is missing from newer SciPy releases.
    import numpy

    parser = make_optparse()
    options, _args = parser.parse_args()
    # Neither option has a default; without this check ``None`` would be
    # passed on and fail later with an unrelated-looking error.
    if options.datapath is None:
        parser.error('--datapath is required')
    if options.n_hidden is None:
        parser.error('--hidden is required')

    print("Loading data")
    data = load_dataset(options.datapath)
    if not data:
        parser.error('no input*/target* files found in %r' % options.datapath)

    print("Building network")
    # NOTE(review): the input width 74 is hard-coded; presumably it matches
    # the number of columns in the input files -- confirm against the data.
    net = rnn.RecurrentNetwork(74, options.n_hidden, 1, outfunc='sig')
    n_params = len(net.parameters)
    print("Number of parameters: %i" % n_params)

    objfunc = make_objective_func(net, data, rnn.SumOfSquares())
    x0 = numpy.random.standard_normal(n_params) * 0.1
    optimizer = PGPE(objfunc, x0, minimize=True)
    optimizer.maxEvaluations = options.maxevals

    print("First fitness: %s" % (objfunc(x0),))
    print("Optimizing...")
    params, fitness = optimizer.learn()
    print("Last fitness: %s" % (fitness,))

    # The objective function leaves ``net`` holding whichever candidate was
    # evaluated last.  Load the parameters returned by the optimizer so the
    # statistics below describe the solution whose fitness was just printed.
    net.parameters[:] = params

    true_positives, total = stats(net, data)
    # Guard against targets with zero rows; float() keeps this a true
    # division independently of the module-level __future__ import.
    ratio = true_positives / float(total) if total else 0.0
    print("Total positives found: %i (%.2f)" % (true_positives, ratio))

    return 0
95+
96+
97+
if __name__ == '__main__':
    # Only run the tool when executed as a script; main()'s return value
    # becomes the process exit status.
    exit_status = main()
    sys.exit(exit_status)
99+

0 commit comments

Comments
 (0)