add support for PubMed
cynricfu committed Jul 21, 2022
1 parent 160b343 commit 183ef07
Showing 9 changed files with 204 additions and 67 deletions.
5 changes: 5 additions & 0 deletions configs/HGT.json
@@ -18,5 +18,10 @@
"lastfm": {
"n_layers": 1,
"lr": 0.01
},
"pubmed": {
"n_layers": 3,
"lr": 0.01,
"weight_decay": 0.0
}
}
8 changes: 8 additions & 0 deletions configs/MECCH.json
@@ -30,5 +30,13 @@
"weight_decay": 0.0,
"batch_size": 102400,
"exclude": false
},
"pubmed": {
"max_mp_length": 1,
"n_layers": 3,
"lr": 0.02,
"weight_decay": 0.0,
"batch_size": 102400,
"exclude": false
}
}
4 changes: 4 additions & 0 deletions configs/RGCN.json
@@ -17,5 +17,9 @@
"n_layers": 2,
"lr": 0.02,
"dropout": 0.0
},
"pubmed": {
"n_layers": 2,
"dropout": 0.0
}
}
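
These per-dataset blocks override the model's "default" entry at load time: the defaults are read first, then any matching keys from the dataset-specific entry (such as "pubmed" above) replace them. A minimal sketch of that merge, mirroring the load_model_config helper that appears as removed code in the main.py diff below (this commit relocates it to utils.py), with the log prints omitted:

import json

def load_model_config(path, dataset):
    with open(path) as f:
        config = json.load(f)
    if dataset in config:
        config_out = config['default']
        config_out.update(config[dataset])
        return config_out
    return config['default']

config = load_model_config('./configs/MECCH.json', 'pubmed')
# config['lr'] == 0.02 comes from the "pubmed" block; unlisted keys keep their "default" values.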
83 changes: 23 additions & 60 deletions main.py
@@ -1,5 +1,6 @@
 import argparse
 import json
+import pickle
 from pathlib import Path

 import dgl
 import numpy as np
@@ -12,62 +13,8 @@
from model.baselines.HGT import HGT
from model.baselines.HAN import HAN, HAN_lp
from model.modules import LinkPrediction_minibatch, LinkPrediction_fullbatch
-from utils import metapath2str, add_metapath_connection, get_all_metapaths, load_data_nc, load_data_lp, \
-    metapath_dict2list, select_metapaths, get_save_path
-
-
-def load_base_config(path='./configs/base.json'):
-    with open(path) as f:
-        config = json.load(f)
-    print('Base configs loaded.')
-    return config
-
-
-def load_model_config(path, dataset):
-    with open(path) as f:
-        config = json.load(f)
-    print('Model configs loaded.')
-    if dataset in config:
-        config_out = config['default']
-        config_out.update(config[dataset])
-        print('{} dataset configs for this model loaded, override defaults.'.format(dataset))
-        return config_out
-    else:
-        print('Model do not have hyperparameter configs for {} dataset, use defaults.'.format(dataset))
-        return config['default']
-
-
-def get_metapath_g(g, args):
-    # Generate the metapath neighbor graphs of all possible metapaths
-    # and integrate them into one dgl.DGLGraph -- metapath_g
-    all_metapaths_dict = get_all_metapaths(g, max_length=args.max_mp_length)
-    all_metapaths_list = metapath_dict2list(all_metapaths_dict)
-    metapath_g = None
-    for mp in all_metapaths_list:
-        metapath_g = add_metapath_connection(g, mp, metapath_g)
-    # copy features and labels
-    metapath_g.ndata["x"] = g.ndata["x"]
-    metapath_g.ndata["y"] = g.ndata["y"]
-    # select only max-length metapath
-    selected_metapaths = select_metapaths(all_metapaths_list, length=args.max_mp_length)
-
-    return metapath_g, selected_metapaths
-
-
-def get_khop_g(g, args):
-    homo_g = dgl.to_homogeneous(g)
-    temp_homo_g = dgl.to_homogeneous(g)
-    homo_g.edata[dgl.ETYPE][:] = 0
-    homo_g.edata[dgl.EID] = th.arange(homo_g.num_edges())
-    for k in range(2, args.max_mp_length + 1):
-        edges = dgl.khop_graph(temp_homo_g, k).edges()
-        etypes = th.full((edges[0].shape[0],), k - 1)
-        eids = th.arange(edges[0].shape[0])
-        homo_g.add_edges(edges[0], edges[1], {dgl.ETYPE: etypes, dgl.EID: eids})
-    hetero_g = dgl.to_heterogeneous(homo_g, g.ntypes, ['{}-hop'.format(i + 1) for i in range(args.max_mp_length)])
-    hetero_g.ndata['x'] = g.ndata['x']
-    hetero_g.ndata['y'] = g.ndata['y']
-    return hetero_g
+from utils import metapath2str, get_metapath_g, get_khop_g, load_data_nc, load_data_lp, \
+    get_save_path, load_base_config, load_model_config


def main_nc(args):
@@ -192,6 +139,7 @@ def main_lp(args):
     # load data
     (g_train, g_val, g_test), in_dim_dict, (train_eid_dict, val_eid_dict, test_eid_dict), (
         val_neg_uv, test_neg_uv) = load_data_lp(args.dataset)
+    print("Loaded data from dataset: {}".format(args.dataset))

     # check cuda
     use_cuda = args.gpu >= 0 and th.cuda.is_available()
@@ -232,9 +180,22 @@ def main_lp(args):
     test_eid_dict = {metapath2str([g_test.to_canonical_etype(k)]): v for k, v in test_eid_dict.items()}
     target_etype = list(train_eid_dict.keys())[0]

-    g_train, _ = get_metapath_g(g_train, args)
-    g_val, _ = get_metapath_g(g_val, args)
-    g_test, selected_metapaths = get_metapath_g(g_test, args)
+    # cache metapath_g
+    load_path = Path('./data') / args.dataset / 'metapath_g-max_mp={}'.format(args.max_mp_length)
+    if load_path.is_dir():
+        g_list, _ = dgl.load_graphs(str(load_path / 'graph.bin'))
+        g_train, g_val, g_test = g_list
+        with open(load_path / 'selected_metapaths.pkl', 'rb') as in_file:
+            selected_metapaths = pickle.load(in_file)
+    else:
+        g_train, _ = get_metapath_g(g_train, args)
+        g_val, _ = get_metapath_g(g_val, args)
+        g_test, selected_metapaths = get_metapath_g(g_test, args)
+        load_path.mkdir()
+        dgl.save_graphs(str(load_path / 'graph.bin'), [g_train, g_val, g_test])
+        with open(load_path / 'selected_metapaths.pkl', 'wb') as out_file:
+            pickle.dump(selected_metapaths, out_file)

     n_heads_list = [args.n_heads] * args.n_layers
     model = MECCH(
         g_train,
@@ -290,6 +251,8 @@ def main_lp(args):
         minibatch_flag = False
     elif args.model == 'HAN':
         # assume the target node type has attributes
+        # Note: this HAN version from DGL conducts full-batch training with online metapath_reachable_graph,
+        # preprocessing needed for the PubMed dataset
         assert args.hidden_dim % args.n_heads == 0
         n_heads_list = [args.n_heads] * args.n_layers
         model_lp = HAN_lp(
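
The caching block added above follows a plain load-or-build pattern: reuse the serialized metapath graphs when a cache directory for the current max_mp_length exists, otherwise build and persist them. As an illustration only (the helper name and signature below are hypothetical, not part of the repo), the same logic factored into a reusable function:

import pickle
from pathlib import Path

import dgl

def load_or_build_metapath_graphs(dataset, max_mp_length, build_fn):
    # Hypothetical helper: return cached metapath graphs if present,
    # otherwise build them via build_fn() and cache the result.
    cache_dir = Path('./data') / dataset / 'metapath_g-max_mp={}'.format(max_mp_length)
    if cache_dir.is_dir():
        g_list, _ = dgl.load_graphs(str(cache_dir / 'graph.bin'))
        with open(cache_dir / 'selected_metapaths.pkl', 'rb') as f:
            selected_metapaths = pickle.load(f)
    else:
        g_list, selected_metapaths = build_fn()
        cache_dir.mkdir(parents=True)
        dgl.save_graphs(str(cache_dir / 'graph.bin'), list(g_list))
        with open(cache_dir / 'selected_metapaths.pkl', 'wb') as f:
            pickle.dump(selected_metapaths, f)
    return g_list, selected_metapaths

Because the directory name encodes max_mp_length, changing that hyperparameter builds a fresh cache rather than silently reusing a stale one.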
31 changes: 24 additions & 7 deletions model/MECCH.py
@@ -141,15 +141,16 @@ def __init__(self, n_metapaths, in_dim, out_dim, fusion_type="conv"):

    def forward(self, h_list):
        if self.fusion_type == "mean":
-            return self.linear(th.mean(th.stack(h_list), dim=0))
+            fused = th.mean(th.stack(h_list), dim=0)
        elif self.fusion_type == "weight":
-            return self.linear(th.sum(th.stack(h_list) * self.weight[:, None, None], dim=0))
+            fused = th.sum(th.stack(h_list) * self.weight[:, None, None], dim=0)
        elif self.fusion_type == "conv":
-            return self.linear(th.sum(th.stack(h_list).transpose(0, 1) * self.conv, dim=1))
+            fused = th.sum(th.stack(h_list).transpose(0, 1) * self.conv, dim=1)
        elif self.fusion_type == "cat":
-            return self.linear(th.hstack(h_list))
+            fused = th.hstack(h_list)
        else:
            raise NotImplementedError
+        return self.linear(fused), fused


class MECCHLayer(nn.Module):
@@ -210,12 +211,13 @@ def forward(self, block, h_dict):
                block.dstnodes[ntype].data["h_dst"] = h_dict[ntype][:block.num_dst_nodes(ntype)]

        out_h_dict = {}
+        out_embs_dict = {}
        for ntype in block.dsttypes:
            if block.num_dst_nodes(ntype) > 0:
                metapath_outs = []
                for metapath_str in self.metapaths_dict[ntype]:
                    metapath_outs.append(self.context_encoders[metapath_str](block, h_dict, metapath_str))
-                out_h_dict[ntype] = self.metapath_fuse[ntype](metapath_outs)
+                out_h_dict[ntype], out_embs_dict[ntype] = self.metapath_fuse[ntype](metapath_outs)

        for ntype in out_h_dict:
            if self.residual is not None:
@@ -228,7 +230,7 @@ def forward(self, block, h_dict):
                out_h_dict[ntype] = self.activation(out_h_dict[ntype])
            out_h_dict[ntype] = self.dropout(out_h_dict[ntype])

-        return out_h_dict
+        return out_h_dict, out_embs_dict


class MECCH(nn.Module):
@@ -299,10 +301,25 @@ def forward(self, blocks, x_dict):
        h_dict = h_embed_dict | h_linear_dict

        for block, layer in zip(blocks, self.MECCH_layers):
-            h_dict = layer(block, h_dict)
+            h_dict, _ = layer(block, h_dict)

        return h_dict

+    # used to get node representations for node classification tasks
+    # (i.e., the node vectors just before applying the final linear layer of the last MECCH layer)
+    def get_embs(self, blocks, x_dict):
+        nids_dict = {ntype: nids for ntype, nids in blocks[0].srcdata[dgl.NID].items() if self.in_dim_dict[ntype] < 0}
+
+        # ntype-specific embedding/projection
+        h_embed_dict = self.embed_layer(nids_dict)
+        h_linear_dict = self.linear_layer(x_dict)
+        h_dict = h_embed_dict | h_linear_dict
+
+        for block, layer in zip(blocks, self.MECCH_layers):
+            h_dict, embs_dict = layer(block, h_dict)
+
+        return h_dict, embs_dict


class khopMECCHLayer(nn.Module):
def __init__(
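
The interface change threading through this file: the metapath fusion module now returns a pair, the post-linear output plus the fused pre-projection embedding, and every caller up the stack (MECCHLayer.forward, MECCH.forward, and the new MECCH.get_embs) is updated to unpack it. A self-contained toy of the "mean" branch showing the pattern (the class name is illustrative, not from the repo):

import torch as th
import torch.nn as nn

class MeanFusionToy(nn.Module):
    # Average the per-metapath representations, then return both the
    # projected output and the pre-projection ("embedding") tensor.
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, h_list):
        fused = th.mean(th.stack(h_list), dim=0)
        return self.linear(fused), fused

h_list = [th.randn(4, 8) for _ in range(3)]  # 3 metapaths, 4 nodes, dim 8
out, embs = MeanFusionToy(8, 16)(h_list)
assert out.shape == (4, 16) and embs.shape == (4, 8)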
8 changes: 8 additions & 0 deletions model/baselines/HAN.py
@@ -112,6 +112,14 @@ def forward(self, g, x_dict):

return {self.target_ntype: self.predict(h)}

def get_embs(self, g, x_dict):
h = x_dict[self.target_ntype]

for gnn in self.layers:
h = gnn(g, h)

return {self.target_ntype: self.predict(h)}, {self.target_ntype: h}


class HAN_lp(nn.Module):
def __init__(self, g, metapaths_u, target_ntype_u, in_size_u, metapaths_v, target_ntype_v, in_size_v, hidden_size,
34 changes: 34 additions & 0 deletions model/baselines/HGT.py
@@ -211,3 +211,37 @@ def forward(self, G, x_dict):
for i in range(self.n_layers):
h_dict = self.gcs[i](G, h_dict)
return {ntype: self.out(h) for ntype, h in h_dict.items()}

def get_embs(self, G, x_dict):
h_dict = {}
if isinstance(G, list):
# minibatch
nids_dict = {
ntype: nids
for ntype, nids in G[0].srcdata[dgl.NID].items()
if self.in_dim_dict[ntype] < 0
}
h_embed_dict = self.embed_layer(nids_dict)
h_linear_dict = self.linear_layer(x_dict)
h_dict = h_embed_dict | h_linear_dict
for ntype in h_dict:
h_dict[ntype] = F.gelu(h_dict[ntype])

for layer, block in zip(self.gcs, G):
h_dict = layer(block, h_dict)
else:
# full batch
nids_dict = {
ntype: G.nodes(ntype)
for ntype in G.ntypes
if self.in_dim_dict[ntype] < 0
}
h_embed_dict = self.embed_layer(nids_dict)
h_linear_dict = self.linear_layer(x_dict)
h_dict = h_embed_dict | h_linear_dict
for ntype in h_dict:
h_dict[ntype] = F.gelu(h_dict[ntype])

for i in range(self.n_layers):
h_dict = self.gcs[i](G, h_dict)
return {ntype: self.out(h) for ntype, h in h_dict.items()}, h_dict
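
Both branches of get_embs start from the same input convention used across this repo: node types whose entry in in_dim_dict is negative carry no input features, so they receive learnable embeddings looked up by node ID, while featured node types are linearly projected; the two dicts are then merged with the | operator. A minimal sketch of that convention, under the stated assumption that embed_layer and linear_layer behave like the modules below (all names here are illustrative):

import torch as th
import torch.nn as nn

class HeteroInputToy(nn.Module):
    # Embed featureless node types (in_dim < 0); project the rest.
    def __init__(self, num_nodes_dict, in_dim_dict, hidden_dim):
        super().__init__()
        self.embeds = nn.ModuleDict({
            ntype: nn.Embedding(num_nodes_dict[ntype], hidden_dim)
            for ntype, d in in_dim_dict.items() if d < 0
        })
        self.projs = nn.ModuleDict({
            ntype: nn.Linear(d, hidden_dim)
            for ntype, d in in_dim_dict.items() if d >= 0
        })

    def forward(self, nids_dict, x_dict):
        h_embed = {ntype: self.embeds[ntype](nids) for ntype, nids in nids_dict.items()}
        h_linear = {ntype: self.projs[ntype](x) for ntype, x in x_dict.items()}
        return h_embed | h_linear  # same dict-merge idiom as the code above

layer = HeteroInputToy({'author': 10}, {'author': -1, 'paper': 8}, 16)
h = layer({'author': th.arange(10)}, {'paper': th.randn(5, 8)})
assert h['author'].shape == (10, 16) and h['paper'].shape == (5, 16)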
32 changes: 32 additions & 0 deletions model/baselines/RGCN.py
@@ -237,3 +237,35 @@ def forward(self, g=None, x_dict=None):
h_dict = layer(g, h_dict)

return h_dict

def get_embs(self, g=None, x_dict=None):
if isinstance(g, list):
# minibatch forward
nids_dict = {
ntype: nids
for ntype, nids in g[0].srcdata[dgl.NID].items()
if self.in_dim_dict[ntype] < 0
}
h_embed_dict = self.embed_layer(nids_dict)
h_linear_dict = self.linear_layer(x_dict)
h_dict = h_embed_dict | h_linear_dict

for layer, block in zip(self.layers, g):
embs_dict = h_dict
h_dict = layer(block, h_dict)
else:
# full graph forward
nids_dict = {
ntype: g.nodes(ntype)
for ntype in g.ntypes
if self.in_dim_dict[ntype] < 0
}
h_embed_dict = self.embed_layer(nids_dict)
h_linear_dict = self.linear_layer(x_dict)
h_dict = h_embed_dict | h_linear_dict

for layer in self.layers:
embs_dict = h_dict
h_dict = layer(g, h_dict)

return h_dict, embs_dict
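
Across MECCH, HAN, HGT, and RGCN the new get_embs methods share one convention: same inputs as forward, but the return value is a pair of (predictions, node representations taken near the output: just before the final projection, or before the final layer, depending on the model). A toy model illustrating the convention and how a caller might use it (all names here are illustrative, not from the repo):

import torch as th
import torch.nn as nn

class TwoLayerToy(nn.Module):
    # forward() returns predictions only; get_embs() additionally exposes
    # the hidden representation feeding the output layer.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(8, 16)
        self.out = nn.Linear(16, 3)

    def forward(self, x):
        return self.out(th.relu(self.hidden(x)))

    def get_embs(self, x):
        h = th.relu(self.hidden(x))
        return self.out(h), h

model = TwoLayerToy()
model.eval()
with th.no_grad():  # inference only, e.g. to export embeddings for visualization
    logits, embs = model.get_embs(th.randn(5, 8))
assert logits.shape == (5, 3) and embs.shape == (5, 16)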