13
13
14
14
from tensorflow .contrib .tensorboard .plugins import projector
15
15
16
+ from restore_model import prediction_by_trained_graph
17
+
16
18
17
19
class LstmRNN (object ):
18
20
def __init__ (self , sess , stock_count ,
@@ -161,7 +163,7 @@ def train(self, dataset_list, config):
161
163
merged_test_X = []
162
164
merged_test_y = []
163
165
merged_test_labels = []
164
-
166
+
165
167
for label_ , d_ in enumerate (dataset_list ):
166
168
merged_test_X += list (d_ .test_X )
167
169
merged_test_y += list (d_ .test_y )
@@ -181,7 +183,8 @@ def train(self, dataset_list, config):
181
183
self .targets : merged_test_y ,
182
184
self .symbols : merged_test_labels ,
183
185
}
184
-
186
+ print 'merged_test_X' , merged_test_X
187
+ print 'merged_test_y' , merged_test_y
185
188
global_step = 0
186
189
187
190
num_batches = sum (len (d_ .train_X ) for d_ in dataset_list ) // config .batch_size
@@ -196,7 +199,7 @@ def train(self, dataset_list, config):
196
199
i for i , sym_label in enumerate (merged_test_labels )
197
200
if sym_label [0 ] == l ])
198
201
sample_indices [sym ] = target_indices
199
- print sample_indices
202
+
200
203
201
204
print "Start training for stocks:" , [d .stock_sym for d in dataset_list ]
202
205
for epoch in xrange (config .max_epoch ):
@@ -222,7 +225,7 @@ def train(self, dataset_list, config):
222
225
223
226
if np .mod (global_step , len (dataset_list ) * 100 / config .input_size ) == 1 :
224
227
test_loss , test_pred = self .sess .run ([self .loss , self .pred ], test_data_feed )
225
-
228
+
226
229
print "Step:%d [Epoch:%d] [Learning rate: %.6f] train_loss:%.6f test_loss:%.6f" % (
227
230
global_step , epoch , learning_rate , train_loss , test_loss )
228
231
@@ -238,7 +241,7 @@ def train(self, dataset_list, config):
238
241
self .save (global_step )
239
242
240
243
final_pred , final_loss = self .sess .run ([self .pred , self .loss ], test_data_feed )
241
-
244
+
242
245
# Save the final model
243
246
self .save (global_step )
244
247
return final_pred
@@ -275,6 +278,8 @@ def save(self, step):
275
278
global_step = step
276
279
)
277
280
281
+ print os .path .join (self .model_logs_dir , model_name )
282
+
278
283
def load (self ):
279
284
print (" [*] Reading checkpoints..." )
280
285
ckpt = tf .train .get_checkpoint_state (self .model_logs_dir )
@@ -292,7 +297,7 @@ def load(self):
292
297
def plot_samples (self , preds , targets , figname , stock_sym = None ):
293
298
def _flatten (seq ):
294
299
return [x for y in seq for x in y ]
295
-
300
+
296
301
truths = _flatten (targets )[- 200 :]
297
302
298
303
preds = _flatten (preds )[- 200 :]
@@ -337,4 +342,44 @@ def _flatten(seq):
337
342
plt .title (stock_sym + " | Last %d days in test" % len (truths ))
338
343
339
344
plt .savefig (figname .split ('.' )[0 ]+ '_normalized.png' , format = 'png' , bbox_inches = 'tight' , transparent = True )
340
- plt .close ()
345
+ plt .close ()
346
+
347
+
348
+ def predict (self , dataset_list , max_epoch , config ):
349
+ merged_test_X , merged_test_y , merged_test_labels = [], [], []
350
+ for label_ , d_ in enumerate (dataset_list ):
351
+ merged_test_X += list (d_ .test_X )
352
+ merged_test_y += list (d_ .test_y )
353
+ merged_test_labels += [[label_ ]] * len (d_ .test_X )
354
+
355
+ test_X = np .array (merged_test_X )
356
+ test_y = np .array (merged_test_y )
357
+
358
+ status , counter = self .load ()
359
+ if status :
360
+ graph = tf .get_default_graph ()
361
+ test_data_feed = {
362
+ self .learning_rate : 0.0 ,
363
+ self .inputs : test_X ,
364
+ self .targets : test_y
365
+ }
366
+ #prediction = graph.get_tensor_by_name('output_layer/add:0')
367
+ #loss = graph.get_tensor_by_name('train/loss_mse:0')
368
+
369
+ # Select samples for plotting.
370
+ sample_labels = range (min (config .sample_size , len (dataset_list )))
371
+ sample_indices = {}
372
+ for l in sample_labels :
373
+ sym = dataset_list [l ].stock_sym
374
+ target_indices = np .array ([
375
+ i for i , sym_label in enumerate (merged_test_labels )
376
+ if sym_label [0 ] == l ])
377
+ sample_indices [sym ] = target_indices
378
+
379
+
380
+ test_prediction , test_loss = self .sess .run ([self .pred , self .loss ], test_data_feed )
381
+
382
+ for sample_sym , indices in sample_indices .iteritems ():
383
+ test_pred = test_prediction [indices ]
384
+
385
+ return test_pred , test_loss
0 commit comments