
Commit f7820a4

Author: Lilian Weng
Commit message: more on LstmRnn model
1 parent 6226c12 commit f7820a4

File tree: 9 files changed (+62 additions, -44 deletions)


.gitignore

Lines changed: 5 additions & 4 deletions
@@ -1,8 +1,9 @@
 *.*~
 *.pyc
 *.ipynb
-_data/*.tsv
-_data/*.csv
-_logs/*/
-_models/*/
+data/*.tsv
+data/*.csv
+logs/*
+models/*
+checkpoint/*
 .idea/

_logs/.placeholder

Whitespace-only changes.

_models/.placeholder

Whitespace-only changes.

config.py

Lines changed: 9 additions & 8 deletions
@@ -1,9 +1,9 @@
 class RNNConfig():
-    input_size=1
-    num_steps=30
-    lstm_size=128
-    num_layers=1
-    keep_prob=0.8
+    input_size = 1
+    num_steps = 30
+    lstm_size = 128
+    num_layers = 1
+    keep_prob = 0.8
 
     batch_size = 200
     init_learning_rate = 0.05
@@ -21,9 +21,10 @@ def __str__(self):
     def __repr__(self):
         return str(self.to_dict())
 
+
 DEFAULT_CONFIG = RNNConfig()
 print "Default configuration:", DEFAULT_CONFIG.to_dict()
 
-DATA_DIR = "_data"
-LOG_DIR = "_logs"
-MODEL_DIR = "_models"
+DATA_DIR = "data"
+LOG_DIR = "logs"
+MODEL_DIR = "models"
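Note: since RNNConfig keeps its hyperparameters as plain class attributes, variants can be expressed by subclassing; the SmallConfig name below is hypothetical and only illustrates the pattern, it is not part of this commit.

from config import RNNConfig, DEFAULT_CONFIG

class SmallConfig(RNNConfig):
    # hypothetical variant for quick experiments; not part of this commit
    lstm_size = 32
    num_steps = 10

print "default lstm_size:", DEFAULT_CONFIG.lstm_size   # 128
print "small lstm_size:  ", SmallConfig().lstm_size    # 32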
File renamed without changes.

data_wrapper.py renamed to data_model.py

Lines changed: 8 additions & 3 deletions
@@ -12,13 +12,14 @@
 class StockDataSet(object):
     def __init__(self,
                  stock_sym,
-                 config=DEFAULT_CONFIG,
+                 input_size=1,
+                 num_steps=30,
                  test_ratio=0.1,
                  normalized=True,
                  close_price_only=True):
         self.stock_sym = stock_sym
-        self.input_size = config.input_size
-        self.num_steps = config.num_steps
+        self.input_size = input_size
+        self.num_steps = num_steps
         self.test_ratio = test_ratio
         self.close_price_only = close_price_only
         self.normalized = normalized
@@ -34,6 +35,10 @@ def __init__(self,
 
         self.train_X, self.train_y, self.test_X, self.test_y = self._prepare_data(self.raw_seq)
 
+    def info(self):
+        return "StockDataSet [%s] train: %d test: %d" % (
+            self.stock_sym, len(self.train_X), len(self.test_y))
+
     def _prepare_data(self, seq):
         # split into items of input_size
         seq = [np.array(seq[i * self.input_size: (i + 1) * self.input_size])
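For reference, a minimal standalone sketch of the windowing that _prepare_data performs: chunk the raw price series into input_size-sized vectors, then slide a num_steps window over them. Only the first chunking line is visible in this diff, so everything past it is an assumption about the rest of the method.

import numpy as np

def make_windows(raw_seq, input_size=1, num_steps=30):
    # chunk the flat price series into vectors of length input_size
    seq = [np.array(raw_seq[i * input_size: (i + 1) * input_size])
           for i in range(len(raw_seq) // input_size)]
    # every num_steps consecutive chunks form one input sample;
    # the chunk right after the window is the prediction target (assumed)
    X = np.array([seq[i: i + num_steps] for i in range(len(seq) - num_steps)])
    y = np.array([seq[i + num_steps] for i in range(len(seq) - num_steps)])
    return X, y

X, y = make_windows(range(100), input_size=1, num_steps=30)
print X.shape, y.shape  # (70, 30, 1) (70, 1)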

main.py

Lines changed: 11 additions & 8 deletions
@@ -4,6 +4,7 @@
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
 
+from data_model import StockDataSet
 from model import LstmRNN
 
 flags = tf.app.flags
@@ -32,15 +33,8 @@ def show_all_variables():
 def main(_):
     pp.pprint(flags.FLAGS.__flags)
 
-    if FLAGS.input_width is None:
-        FLAGS.input_width = FLAGS.input_height
-    if FLAGS.output_width is None:
-        FLAGS.output_width = FLAGS.output_height
-
     if not os.path.exists(FLAGS.checkpoint_dir):
         os.makedirs(FLAGS.checkpoint_dir)
-    if not os.path.exists(FLAGS.sample_dir):
-        os.makedirs(FLAGS.sample_dir)
 
     # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
     run_config = tf.ConfigProto()
@@ -59,8 +53,17 @@ def main(_):
 
     show_all_variables()
 
+    stock_data = StockDataSet(
+        "GOOG",
+        input_size=FLAGS.input_size,
+        num_steps=FLAGS.num_steps,
+        test_ratio=0.1,
+        close_price_only=True
+    )
+    print stock_data.info()
+
     if FLAGS.train:
-        rnn_model.train(FLAGS)
+        rnn_model.train(stock_data, FLAGS)
     else:
         if not rnn_model.load()[0]:
             raise Exception("[!] Train a model first, then run test mode")
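The new StockDataSet call reads FLAGS.input_size and FLAGS.num_steps, so main.py presumably defines those flags earlier in the file. A hedged sketch of what such definitions look like with tf.app.flags; the defaults below simply mirror RNNConfig and are not taken from the diff.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("input_size", 1, "Size of each input vector fed to the LSTM [1]")
flags.DEFINE_integer("num_steps", 30, "Number of unrolled time steps [30]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory for saved checkpoints [checkpoint]")
flags.DEFINE_boolean("train", False, "True to train, False to load and test [False]")
FLAGS = flags.FLAGS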

model.py

Lines changed: 17 additions & 9 deletions
@@ -79,7 +79,7 @@ def _create_one_cell():
         self.pred_summ = tf.summary.histogram("pred", self.pred)
 
         # self.loss = -tf.reduce_sum(targets * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)))
-        self.loss = tf.reduce_mean(tf.square(self.pred - self.inputs), name="loss_mse")
+        self.loss = tf.reduce_mean(tf.square(self.pred - self.targets), name="loss_mse")
         self.optim = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, name="adam_optim")
 
         self.loss_sum = tf.summary.scalar("loss_mse", self.loss)
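The first hunk fixes the regression loss: the mean squared error had been computed against the inputs rather than the prediction targets. A tiny numpy sketch of what the corrected loss_mse measures, with made-up numbers:

import numpy as np

# made-up values: predicted vs. true next-step prices for a batch of two windows
pred    = np.array([[0.52], [0.49]])
targets = np.array([[0.50], [0.55]])

loss_mse = np.mean(np.square(pred - targets))  # mirrors tf.reduce_mean(tf.square(pred - targets))
print "MSE:", loss_mse                         # 0.002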
@@ -100,9 +100,12 @@ def train(self, dataset, config):
         self.writer = tf.summary.FileWriter(os.path.join("./logs", self.model_name))
         self.writer.add_graph(self.sess.graph)
 
-        step = 1
+        num_batches = int(len(dataset.train_X)) // config.batch_size
+        global_step = 1
+
+        for epoch in xrange(config.max_epoch):
+            epoch_step = 1
 
-        for epoch in xrange(config.epoch):
             learning_rate = config.init_learning_rate * (
                 config.learning_rate_decay ** max(float(epoch + 1 - config.init_epoch), 0.0)
             )
@@ -122,17 +125,19 @@ def train(self, dataset, config):
                     self.learning_rate: learning_rate,
                 }
                 train_loss, _ = self.sess.run([self.loss, self.optim], train_data_feed)
-                step += 1
+                global_step += 1
+                epoch_step += 1
 
-            if np.mod(epoch, 10) == 0:
+            if np.mod(epoch, 20) == 0:
                 test_loss, _pred, _merged_sum = self.sess.run(
                     [self.loss, self.pred, self.merged_sum], test_data_feed)
                 assert len(_pred) == len(dataset.test_y)
-                print "Epoch %d [%f]:" % (epoch, learning_rate), test_loss
-                self.writer.add_summary(_merged_sum, global_step=epoch)
+                print "Epoch %d [%d/%d][learning rate: %f]: %.6f" % (
+                    epoch, epoch_step, num_batches, learning_rate, test_loss)
+                self.writer.add_summary(_merged_sum, global_step=global_step)
 
-            if np.mod(step, 100) == 2:
-                self.save(self.checkpoint_dir, step)
+            if np.mod(global_step, 500) == 2:
+                self.save(global_step)
 
         print "Final Results:"
         final_pred, final_loss = self.sess.run([self.pred, self.loss], test_data_feed)
@@ -169,3 +174,6 @@ def load(self):
         else:
             print(" [*] Failed to find a checkpoint")
             return False, 0
+
+    def plot_samples(self):
+        pass
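The decay schedule in train() keeps the learning rate flat for the first init_epoch epochs and then shrinks it geometrically. A small standalone sketch: init_learning_rate = 0.05 comes from config.py, while init_epoch = 5 and learning_rate_decay = 0.99 are assumed values used here for illustration only.

init_learning_rate = 0.05   # from config.py
init_epoch = 5              # assumed for illustration
learning_rate_decay = 0.99  # assumed for illustration

for epoch in [0, 4, 5, 20, 50]:
    lr = init_learning_rate * (
        learning_rate_decay ** max(float(epoch + 1 - init_epoch), 0.0))
    print "epoch %2d -> learning rate %.5f" % (epoch, lr)
# epochs 0-4 stay at 0.05; after that the rate drops by 1% per epoch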

train_model.py

Lines changed: 12 additions & 12 deletions
@@ -4,16 +4,16 @@
 """
 import json
 import os
-import random
 import tensorflow as tf
 
 from build_graph import build_lstm_graph_with_config
 from config import DEFAULT_CONFIG, MODEL_DIR
-from data_wrapper import StockDataSet
+from data_model import StockDataSet
 
 
-def load_data(stock_name, config=DEFAULT_CONFIG):
-    stock_dataset = StockDataSet(stock_name, config, test_ratio=0.1, close_price_only=True)
+def load_data(stock_name, input_size, num_steps):
+    stock_dataset = StockDataSet(stock_name, input_size=input_size, num_steps=num_steps,
+                                 test_ratio=0.1, close_price_only=True)
     print "Train data size:", len(stock_dataset.train_X)
     print "Test data size:", len(stock_dataset.test_X)
     return stock_dataset
@@ -34,7 +34,7 @@ def train_lstm_graph(stock_name, lstm_graph, config=DEFAULT_CONFIG):
         stock_name (str)
         lstm_graph (tf.Graph)
     """
-    stock_dataset = load_data(stock_name, config=config)
+    stock_data = load_data(stock_name, input_size=config.input_size, num_steps=config.num_steps)
 
     final_prediction = []
     final_loss = None
@@ -61,8 +61,8 @@ def train_lstm_graph(stock_name, lstm_graph, config=DEFAULT_CONFIG):
         learning_rate = graph.get_tensor_by_name('learning_rate:0')
 
         test_data_feed = {
-            inputs: stock_dataset.test_X,
-            targets: stock_dataset.test_y,
+            inputs: stock_data.test_X,
+            targets: stock_data.test_y,
             learning_rate: 0.0
         }
 
@@ -73,7 +73,7 @@ def train_lstm_graph(stock_name, lstm_graph, config=DEFAULT_CONFIG):
         for epoch_step in range(config.max_epoch):
            current_lr = learning_rates_to_use[epoch_step]
 
-            for batch_X, batch_y in stock_dataset.generate_one_epoch(config.batch_size):
+            for batch_X, batch_y in stock_data.generate_one_epoch(config.batch_size):
                 train_data_feed = {
                     inputs: batch_X,
                     targets: batch_y,
@@ -83,13 +83,13 @@ def train_lstm_graph(stock_name, lstm_graph, config=DEFAULT_CONFIG):
 
             if epoch_step % 10 == 0:
                 test_loss, _pred, _summary = sess.run([loss, prediction, merged_summary], test_data_feed)
-                assert len(_pred) == len(stock_dataset.test_y)
+                assert len(_pred) == len(stock_data.test_y)
                 print "Epoch %d [%f]:" % (epoch_step, current_lr), test_loss
                 if epoch_step % 50 == 0:
                     print "Predictions:", [(
-                        map(lambda x: round(x, 4), _pred[-j]),
-                        map(lambda x: round(x, 4), stock_dataset.test_y[-j])
-                    ) for j in range(5)]
+                    map(lambda x: round(x, 4), _pred[-j]),
+                    map(lambda x: round(x, 4), stock_data.test_y[-j])
+                ) for j in range(5)]
 
                 writer.add_summary(_summary, global_step=epoch_step)
 
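For context, the functions above are presumably wired together by an entry point along these lines; the exact signature of build_lstm_graph_with_config is not shown in this diff, so the call below is only an assumption.

if __name__ == '__main__':
    # assumed wiring: build the static LSTM graph from the default config, then train it on one stock
    lstm_graph = build_lstm_graph_with_config(config=DEFAULT_CONFIG)
    train_lstm_graph("GOOG", lstm_graph, config=DEFAULT_CONFIG)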
