nas_transductive.py
import argparse
import gc
import math
import os
import random
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T
from torch_geometric.utils import coalesce
from torch_sparse import SparseTensor

import utils
from utils import *
from models.basicgnn_large import GCN as GCN_PYG, GIN as GIN_PYG, SGC as SGC_PYG, GraphSAGE as SAGE_PYG, JKNet as JKNet_PYG
from models.mlp import MLP as MLP_PYG
from models.parametrized_adj import PGE
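
# Train a teacher GNN (GCN, SGC, or GraphSAGE by default) on the original
# graph in the transductive setting, reporting the test accuracy achieved
# at the best validation epoch.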
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=1, help='gpu id')
# argparse cannot parse `type=list` from the command line; take a
# space-separated list of ints instead
parser.add_argument('--parallel_gpu_ids', type=int, nargs='+', default=[0, 1], help='gpu ids for parallel training')
parser.add_argument('--dataset', type=str, default='ogbn-arxiv')
parser.add_argument('--seed', type=int, default=1, help='Random seed.')
# gnn
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--hidden', type=int, default=512)
parser.add_argument('--activation', type=str, default='elu')
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--dropout', type=float, default=0.5)
# `type=bool` would treat any non-empty string (including 'False') as True,
# so parse the flag explicitly
parser.add_argument('--normalize_features', type=lambda s: s.lower() in ('true', '1'), default=True)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--inference', action='store_true', help='use mini-batch inference instead of full-batch prediction')
parser.add_argument('--teacher_model', type=str, default='GCN')
parser.add_argument('--lr_teacher_model', type=float, default=0.01)  # arxiv: 0.01, cora: 0.001, pubmed: 0.001
parser.add_argument('--save', type=int, default=1)
# loop and validation
parser.add_argument('--teacher_model_loop', type=int, default=600)
parser.add_argument('--teacher_val_stage', type=int, default=10)
args = parser.parse_args()
print(args)
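
# Example invocation (hypothetical; all values shown are the defaults above):
#   python nas_transductive.py --dataset ogbn-arxiv --teacher_model GCN \
#       --gpu_id 1 --teacher_model_loop 600 --teacher_val_stage 10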
device = 'cuda'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
torch.cuda.set_device(args.gpu_id)
print("Let's use", torch.cuda.device_count(), "GPUs!")

# random seed setting
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
def train_teacher():
    """Train the teacher model on the full graph, tracking test accuracy at the best validation epoch."""
    start = time.perf_counter()
    optimizer_origin = torch.optim.Adam(teacher_model.parameters(), lr=args.lr_teacher_model)
    best_val = 0
    best_test = 0
    for it in range(args.teacher_model_loop + 1):
        # full-graph training step
        teacher_model.train()
        optimizer_origin.zero_grad()
        output = teacher_model.forward(feat.to(device), adj.to(device))[idx_train]
        loss = F.nll_loss(output, labels_train)
        loss.backward()
        optimizer_origin.step()
        if it % args.teacher_val_stage == 0:
            if args.inference:
                # note: inference_loader is not constructed in this file; it
                # must be defined when --inference is set
                output = teacher_model.inference(feat, inference_loader, device)
            else:
                output = teacher_model.predict(feat.to(device), adj.to(device))
            acc_train = utils.accuracy(output[idx_train], labels_train)
            acc_val = utils.accuracy(output[idx_val], labels_val)
            acc_test = utils.accuracy(output[idx_test], labels_test)
            print(f'Epoch: {it:02d}, '
                  f'Loss: {loss.item():.4f}, '
                  f'Train: {100 * acc_train.item():.2f}%, '
                  f'Valid: {100 * acc_val.item():.2f}%, '
                  f'Test: {100 * acc_test.item():.2f}%')
            # keep the test accuracy observed at the best validation accuracy
            if acc_val > best_val:
                best_val = acc_val
                best_test = acc_test
    end = time.perf_counter()
    print("Best Test:", best_test)
    print('Training on the Original Graph Duration:', round(end - start), 's')
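
# Note: train_teacher() reads module-level globals (teacher_model, feat, adj,
# labels_train/val/test, idx_train/val/test) that are defined in the
# __main__ block below, so it must run after that setup.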
if __name__ == '__main__':
    root = os.path.abspath(os.path.dirname(__file__))
    # get_dataset returns a Pyg2Dpr-style object holding the split indices,
    # adjacency, labels, and features
    data = get_dataset(args.dataset, args.normalize_features)
    feat = torch.FloatTensor(data.features).to('cpu')
    adj = utils.to_tensor(data.adj, device='cpu')
    labels = torch.LongTensor(data.labels).to(device)
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    labels_train, labels_val, labels_test = labels[idx_train], labels[idx_val], labels[idx_test]
    d = feat.shape[1]
    nclass = int(labels.max() + 1)
    del data
    gc.collect()

    if utils.is_sparse_tensor(adj):
        adj = utils.normalize_adj_tensor(adj, sparse=True)
    else:
        adj = utils.normalize_adj_tensor(adj)
    # convert the normalized adjacency to a torch_sparse.SparseTensor,
    # transposed into the adj_t form the PyG-style models consume
    adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
                       value=adj._values(), sparse_sizes=adj.size()).t()

    # teacher model
    if args.teacher_model == 'GCN':
        teacher_model = GCN_PYG(nfeat=d, nhid=args.hidden, nclass=nclass, dropout=args.dropout,
                                nlayers=args.nlayers, norm='BatchNorm', act=args.activation).to(device)
    elif args.teacher_model == 'SGC':
        teacher_model = SGC_PYG(nfeat=d, nhid=args.hidden, nclass=nclass, dropout=args.dropout,
                                nlayers=args.nlayers, norm=None, sgc=True, act=args.activation).to(device)
    else:
        teacher_model = SAGE_PYG(nfeat=d, nhid=args.hidden, nclass=nclass, dropout=args.dropout,
                                 nlayers=args.nlayers, norm='BatchNorm', act=args.activation).to(device)
    train_teacher()