forked from CSUBioGroup/GraphLncLoc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
25 lines (17 loc) · 792 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from data.LncRNADataset import *
from models.classifier import *
from utils.config import *
import datetime
# Train a GraphClassifier with k-fold cross-validation on the lncRNA dataset
# and report the total wall-clock running time.
#
# Every hyperparameter comes from utils.config.config(); the attributes read
# here are k, d, hidden_dim, n_classes, device, batchSize, num_epochs, lr,
# kFold and savePath.

# Timezone-aware UTC timestamps: a naive local-time datetime.now() can jump by
# an hour across a DST transition and skew the reported duration.
starttime = datetime.datetime.now(datetime.timezone.utc)
params = config()
# Processed data is saved under a directory keyed on k and d, presumably so
# runs with different settings do not reuse each other's cache — confirm
# against LncRNADataset.
dataset = LncRNADataset(raw_dir='data/data.txt', save_dir=f'checkpoints/dglgraph/k{params.k}_d{params.d}')
# in_dim is tied to params.d; NOTE(review): assumes node features are
# d-dimensional — confirm against LncRNADataset.
model = GraphClassifier(in_dim=params.d, hidden_dim=params.hidden_dim, n_classes=params.n_classes, device=params.device)
model.cv_train(dataset, batchSize=params.batchSize,
               num_epochs=params.num_epochs,
               lr=params.lr,
               kFold=params.kFold,
               savePath=params.savePath,
               device=params.device
               )
endtime = datetime.datetime.now(datetime.timezone.utc)
# timedelta.seconds is only the seconds *component* (0-86399) and silently
# drops whole days, so runs longer than 24 h were under-reported.
# total_seconds() covers the full duration; int() keeps the original
# whole-second output format.
print(f'Total running time of all codes is {int((endtime - starttime).total_seconds())}s. ')