Skip to content

Commit 4015fa0

Browse files
committed
Add basic API
1 parent e7155ba commit 4015fa0

File tree

8 files changed

+88
-7
lines changed

8 files changed

+88
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ images/*
1111
.ipynb_checkpoints/
1212
tmp_data/
1313
.DS_Store
14+
api_log/

api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pickle
2+
import subprocess
3+
4+
from flask import Flask, request, jsonify
5+
from flask_restful import Resource, Api
6+
from json import dumps
7+
8+
# Module-level singletons: the WSGI application and its flask-restful
# API wrapper. Both are referenced below when registering resources and
# when running the development server.
app = Flask(__name__)
api = Api(app)
10+
11+
class Predict(Resource):
    """REST resource: run the prediction pipeline for one stock symbol.

    GET /predict/<sym> shells out to ``main.py --write`` (which pickles
    its prediction into ``api_log/<sym>.pkl``), then loads that pickle
    and returns its string form as a JSON response.
    """

    def get(self, sym):
        # Build argv as a list (no shell, no .split()) so a symbol
        # containing whitespace or shell metacharacters cannot smuggle in
        # extra arguments.
        # BUGFIX: the original passed 'stock_symbol=%s' without the
        # leading '--'; tf.app.flags only recognizes '--flag=value', so
        # the symbol argument was silently ignored (compare the file's
        # own '--write' flag, which does carry the dashes).
        command = ['python2', 'main.py', '--stock_symbol=%s' % sym, '--write']
        subprocess.call(command)
        # main.py --write leaves its prediction pickled under api_log/.
        # NOTE(review): unpickling a file we just produced ourselves is
        # acceptable; never point this at data from untrusted sources.
        with open('api_log/' + sym + '.pkl', 'rb') as f:
            prediction = str(pickle.load(f))
        return jsonify(prediction)
18+
19+
# Map GET /predict/<sym> onto Predict.get(sym).
api.add_resource(Predict, '/predict/<sym>')

if __name__ == "__main__":
    # Bind on all interfaces. BUGFIX: pass the port as an int — werkzeug
    # expects an integer port and a string ('5002') can raise a TypeError
    # depending on the werkzeug version.
    app.run(host='0.0.0.0', port=5002)

check.png

108 KB
Loading
File renamed without changes.

main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22
import pandas as pd
3+
import pickle
34
import pprint
45

6+
import matplotlib.pyplot as plt
57
import tensorflow as tf
68
import tensorflow.contrib.slim as slim
79

@@ -24,6 +26,7 @@
2426
flags.DEFINE_string("stock_symbol", None, "Target stock symbol [None]")
2527
flags.DEFINE_integer("sample_size", 4, "Number of stocks to plot during training. [4]")
2628
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
29+
flags.DEFINE_boolean("write", False, "True for writing contents to the file of the same name")
2730

2831
FLAGS = flags.FLAGS
2932

@@ -104,6 +107,14 @@ def main(_):
104107
if FLAGS.train:
105108
rnn_model.train(stock_data_list, FLAGS)
106109
else:
110+
test_prediction, test_loss = rnn_model.predict(stock_data_list, 50, FLAGS)
111+
if FLAGS.write:
112+
with open('api_log/'+FLAGS.stock_symbol+".pkl", 'wb') as f:
113+
pickle.dump(test_prediction, f)
114+
115+
#rnn_model.plot_samples(test_prediction, test_prediction, 'check.png', 'GOOG')
116+
#plt.show(block=True)
117+
print '-'*33
107118
if not rnn_model.load()[0]:
108119
raise Exception("[!] Train a model first, then run test mode")
109120

model_rnn.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from tensorflow.contrib.tensorboard.plugins import projector
1515

16+
from restore_model import prediction_by_trained_graph
17+
1618

1719
class LstmRNN(object):
1820
def __init__(self, sess, stock_count,
@@ -161,7 +163,7 @@ def train(self, dataset_list, config):
161163
merged_test_X = []
162164
merged_test_y = []
163165
merged_test_labels = []
164-
166+
165167
for label_, d_ in enumerate(dataset_list):
166168
merged_test_X += list(d_.test_X)
167169
merged_test_y += list(d_.test_y)
@@ -181,7 +183,8 @@ def train(self, dataset_list, config):
181183
self.targets: merged_test_y,
182184
self.symbols: merged_test_labels,
183185
}
184-
186+
print 'merged_test_X', merged_test_X
187+
print 'merged_test_y', merged_test_y
185188
global_step = 0
186189

187190
num_batches = sum(len(d_.train_X) for d_ in dataset_list) // config.batch_size
@@ -196,7 +199,7 @@ def train(self, dataset_list, config):
196199
i for i, sym_label in enumerate(merged_test_labels)
197200
if sym_label[0] == l])
198201
sample_indices[sym] = target_indices
199-
print sample_indices
202+
200203

201204
print "Start training for stocks:", [d.stock_sym for d in dataset_list]
202205
for epoch in xrange(config.max_epoch):
@@ -222,7 +225,7 @@ def train(self, dataset_list, config):
222225

223226
if np.mod(global_step, len(dataset_list) * 100 / config.input_size) == 1:
224227
test_loss, test_pred = self.sess.run([self.loss, self.pred], test_data_feed)
225-
228+
226229
print "Step:%d [Epoch:%d] [Learning rate: %.6f] train_loss:%.6f test_loss:%.6f" % (
227230
global_step, epoch, learning_rate, train_loss, test_loss)
228231

@@ -238,7 +241,7 @@ def train(self, dataset_list, config):
238241
self.save(global_step)
239242

240243
final_pred, final_loss = self.sess.run([self.pred, self.loss], test_data_feed)
241-
244+
242245
# Save the final model
243246
self.save(global_step)
244247
return final_pred
@@ -275,6 +278,8 @@ def save(self, step):
275278
global_step=step
276279
)
277280

281+
print os.path.join(self.model_logs_dir, model_name)
282+
278283
def load(self):
279284
print(" [*] Reading checkpoints...")
280285
ckpt = tf.train.get_checkpoint_state(self.model_logs_dir)
@@ -292,7 +297,7 @@ def load(self):
292297
def plot_samples(self, preds, targets, figname, stock_sym=None):
293298
def _flatten(seq):
294299
return [x for y in seq for x in y]
295-
300+
296301
truths = _flatten(targets)[-200:]
297302

298303
preds = _flatten(preds)[-200:]
@@ -337,4 +342,44 @@ def _flatten(seq):
337342
plt.title(stock_sym + " | Last %d days in test" % len(truths))
338343

339344
plt.savefig(figname.split('.')[0]+'_normalized.png', format='png', bbox_inches='tight', transparent=True)
340-
plt.close()
345+
plt.close()
346+
347+
348+
def predict(self, dataset_list, max_epoch, config):
    """Restore the latest checkpoint and run inference on the test split.

    Args:
        dataset_list: list of dataset objects exposing ``test_X``,
            ``test_y`` and ``stock_sym`` (same shape ``train()`` consumes).
        max_epoch: unused; kept for interface compatibility with callers.
        config: FLAGS-like object; only ``config.sample_size`` is read.

    Returns:
        (test_pred, test_loss): predictions (see NOTE below on which rows)
        and the MSE loss over the merged test set.

    Raises:
        RuntimeError: if no checkpoint could be restored.
    """
    # Merge the test splits of every dataset, tagging each row with the
    # index of the dataset it came from (mirrors train()).
    merged_test_X, merged_test_y, merged_test_labels = [], [], []
    for label_, d_ in enumerate(dataset_list):
        merged_test_X += list(d_.test_X)
        merged_test_y += list(d_.test_y)
        merged_test_labels += [[label_]] * len(d_.test_X)

    test_X = np.array(merged_test_X)
    test_y = np.array(merged_test_y)

    status, _ = self.load()
    if not status:
        # BUGFIX: the original fell through on a failed restore and died
        # later with a NameError on the feed dict; fail loudly instead.
        raise RuntimeError("[!] No checkpoint found - train a model first")

    test_data_feed = {
        self.learning_rate: 0.0,
        self.inputs: test_X,
        self.targets: test_y,
    }
    test_prediction, test_loss = self.sess.run(
        [self.pred, self.loss], test_data_feed)

    # Select the per-symbol row indices for the plotted sample symbols
    # (same bookkeeping train() performs).
    sample_labels = range(min(config.sample_size, len(dataset_list)))
    sample_indices = {}
    for l in sample_labels:
        sym = dataset_list[l].stock_sym
        target_indices = np.array([
            i for i, sym_label in enumerate(merged_test_labels)
            if sym_label[0] == l])
        sample_indices[sym] = target_indices

    # NOTE(review): as in the original, only the predictions for the last
    # symbol iterated survive this loop; with a single dataset that is all
    # predictions. BUGFIX: seed test_pred so the return below cannot hit a
    # NameError when sample_indices is empty. TODO: return a dict keyed by
    # symbol instead of the last slice.
    test_pred = test_prediction
    for _, indices in sample_indices.iteritems():
        test_pred = test_prediction[indices]

    return test_pred, test_loss

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ scikit-learn==0.16.1
55
scipy==0.19.1
66
tensorflow==1.2.1
77
urllib3==1.8
8+
flask
9+
flask_restful
File renamed without changes.

0 commit comments

Comments
 (0)