-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict.py
75 lines (60 loc) · 2.6 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# 1.导入相关第三方库或者包(根据自己需求,可以增加、删除等改动)
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
# 2.导入 Notebook 使用的模型
# 建立一个稍微复杂的 LSTM 模型
class LSTM(nn.Module):
def __init__(self, num_hiddens, num_outputs):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(
input_size=1,
hidden_size=num_hiddens,
num_layers=1,
batch_first=True
)
self.fc = nn.Linear(num_hiddens, num_outputs)
def forward(self, x):
x = x.view(x.shape[0], -1, 1)
r_out, (h_n, h_c) = self.lstm(x, None)
out = self.fc(r_out[:, -1, :]) # 只需要最后一个的output
return out
# 输入的数量是前 14 个交易日的收盘价
num_inputs = 14
# 输出是下一个交易日的收盘价
num_outputs = 1
# ------------------------- 请加载您最满意的模型网络结构 -----------------------------
# 读取模型
model = LSTM(128, num_outputs)
model_path = 'results/model.pt'
model.load_state_dict(torch.load(model_path))
model.eval()
def predict(test_x):
'''
对于给定的 x 预测未来的 y 。
:param test_x: 给定的数据集合 x ,对于其中的每一个元素需要预测对应的 y 。e.g.:np.array([[6.69,6.72,6.52,6.66,6.74,6.55,6.35,6.14,6.18,6.17,5.72,5.78,5.69,5.67]]
:return: test_y 对于每一个 test_x 中的元素,给出一个对应的预测值。e.g.:np.array([[0.0063614]])
'''
# test 的数目
n_test = test_x.shape[0]
test_y = None
# --------------------------- 此处下方加入读入模型和预测相关代码 -------------------------------
# 此处为 Notebook 模型示范,你可以根据自己数据处理方式进行改动
# scaler = MinMaxScaler().fit(np.array([0, 300]).reshape(-1, 1))
# test_x = scaler.transform(test_x.reshape(-1, 1)).reshape(-1, 14)
test_x = torch.tensor(test_x, dtype=torch.float32)
test_y = model(test_x)
# 如果使用 MinMaxScaler 进行数据处理,预测后应使用下一句将预测值放缩到原范围内
# test_y = scaler.inverse_transform(test_y.detach().cpu())
test_y = test_y.detach().cpu().numpy()
# --------------------------- 此处上方加入读入模型和预测相关代码 -------------------------------
# 保证输出的是一个 numpy 数组
assert (type(test_y) == np.ndarray)
# 保证 test_y 的 shape 正确
assert (test_y.shape == (n_test, 1))
return test_y