From b24bd898ff341bf695e36861071e6356a3836ab0 Mon Sep 17 00:00:00 2001 From: Xinyu Fu Date: Mon, 17 Jan 2022 18:54:01 +0800 Subject: [PATCH] First commit --- .gitignore | 152 +++++++++++ README.md | 24 ++ configs/HAN.json | 195 ++++++++++++++ configs/HGT.json | 22 ++ configs/MECCH.json | 34 +++ configs/RGCN.json | 21 ++ configs/base.json | 9 + data/acm-gtn/README.md | 7 + data/dblp-gtn/README.md | 7 + data/imdb-gtn/README.md | 7 + data/lastfm/README.md | 7 + experiment/__init__.py | 0 experiment/link_prediction.py | 270 +++++++++++++++++++ experiment/node_classification.py | 263 ++++++++++++++++++ experiment/utils.py | 100 +++++++ main.py | 354 ++++++++++++++++++++++++ model/MECCH.py | 434 ++++++++++++++++++++++++++++++ model/__init__.py | 0 model/baselines/HAN.py | 154 +++++++++++ model/baselines/HGT.py | 213 +++++++++++++++ model/baselines/RGCN.py | 239 ++++++++++++++++ model/baselines/__init__.py | 0 model/modules.py | 83 ++++++ model/utils.py | 12 + utils.py | 264 ++++++++++++++++++ 25 files changed, 2871 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 configs/HAN.json create mode 100644 configs/HGT.json create mode 100644 configs/MECCH.json create mode 100644 configs/RGCN.json create mode 100644 configs/base.json create mode 100644 data/acm-gtn/README.md create mode 100644 data/dblp-gtn/README.md create mode 100644 data/imdb-gtn/README.md create mode 100644 data/lastfm/README.md create mode 100644 experiment/__init__.py create mode 100644 experiment/link_prediction.py create mode 100644 experiment/node_classification.py create mode 100644 experiment/utils.py create mode 100644 main.py create mode 100644 model/MECCH.py create mode 100644 model/__init__.py create mode 100644 model/baselines/HAN.py create mode 100644 model/baselines/HGT.py create mode 100644 model/baselines/RGCN.py create mode 100644 model/baselines/__init__.py create mode 100644 model/modules.py create mode 100644 model/utils.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3eb56cb --- /dev/null +++ b/.gitignore @@ -0,0 +1,152 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..8b31319 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +## MAGNN + +This repository provides a reference implementation of MECCH as described in the paper: + +TODO + +### Dependencies + +* PyTorch 1.10 +* DGL 0.7 +* scikit-learn +* tqdm + +### Datasets + +TODO + +### Usage + +TODO + +### Citing + +TODO \ No newline at end of file diff --git a/configs/HAN.json b/configs/HAN.json new file mode 100644 index 0000000..a7419f3 --- /dev/null +++ b/configs/HAN.json @@ -0,0 +1,195 @@ +{ + "default": { + "hidden_dim": 64, + "n_heads": 8, + "n_layers": 2 + }, + "imdb-gtn": { + "n_layers": 3, + "metapaths": [ + [ + [ + "movie", + "movie-actor", + "actor" + ], + [ + "actor", + "actor-movie", + "movie" + ] + ], + [ + [ + "movie", + "movie-director", + "director" + ], + [ + "director", + "director-movie", + "movie" + ] + ] + ] + }, + "acm-gtn": { + "n_layers": 2, + "metapaths": [ + [ + [ + "paper", + "paper-author", + "author" + ], + [ + "author", + "author-paper", + "paper" + ] + ], + [ + [ + "paper", + "paper-subject", + "subject" + ], + [ + "subject", + "subject-paper", + "paper" + ] + ] + ] + }, + "dblp-gtn": { + "n_layers": 1, + "metapaths": [ + [ + [ + "author", + "author-paper", + "paper" + ], + [ + "paper", + "paper-author", + "author" + ] + ], + [ + [ + "author", + "author-paper", + "paper" + ], + [ + "paper", + "paper-conference", + "conference" + ], + [ + "conference", + "conference-paper", + "paper" + ], + [ + "paper", + "paper-author", + "author" + ] + ] + ] + }, + "lastfm": { + "n_layers": 1, + "metapaths_u": [ + [ + [ + "user", + "user-user", + "user" + ] + ], + [ + [ + "user", + "user-artist", + "artist" + ], + [ + "artist", + "artist-user", + "user" + ] + ], + [ + [ + "user", + "user-artist", + "artist" + ], + [ + "artist", + "artist-tag", + "tag" + ], + [ + "tag", + "tag-artist", + "artist" + ], + [ + "artist", + "artist-user", + "user" + ] + ] + ], + "metapaths_v": [ + [ + [ + "artist", + "artist-user", + "user" + ], + [ + "user", + "user-artist", + "artist" + ] + ], + [ + [ + "artist", + "artist-user", + "user" + ], + [ + "user", + "user-user", + "user" + ], + [ + "user", + "user-artist", + "artist" + ] + ], + [ + [ + "artist", + "artist-tag", + "tag" + ], + [ + "tag", + "tag-artist", + "artist" + ] + ] + ], + "lr": 0.02, + "weight_decay": 0.0 + } +} diff --git a/configs/HGT.json b/configs/HGT.json new file mode 100644 index 0000000..e30200b --- /dev/null +++ b/configs/HGT.json @@ -0,0 +1,22 @@ +{ + "default": { + "hidden_dim": 64, + "n_heads": 8, + "n_layers": 4 + }, + "imdb-gtn": { + "n_layers": 2, + "lr": 0.001 + }, + "acm-gtn": { + "n_layers": 3 + }, + "dblp-gtn": { + "n_layers": 4, + "lr": 0.001 + }, + "lastfm": { + "n_layers": 1, + "lr": 0.01 + } +} diff --git a/configs/MECCH.json b/configs/MECCH.json new file mode 100644 index 0000000..e0694d1 --- /dev/null +++ b/configs/MECCH.json @@ -0,0 +1,34 @@ +{ + "default": { + "ablation": false, + "context_encoder": "mean", + "use_v": false, + "n_heads": 8, + "metapath_fusion": "conv", + "residual": false, + "layer_norm": true, + "hidden_dim": 64, + "n_neighbor_samples": 0, + "batch_size": 128 + }, + "imdb-gtn": { + "max_mp_length": 5, + "n_layers": 1 + }, + "acm-gtn": { + "max_mp_length": 1, + "n_layers": 2 + }, + "dblp-gtn": { + "max_mp_length": 2, + "n_layers": 2 + }, + "lastfm": { + "max_mp_length": 1, + "n_layers": 2, + "lr": 0.01, + "weight_decay": 0.0, + "batch_size": 102400, + "exclude": false + } +} diff --git a/configs/RGCN.json b/configs/RGCN.json new file mode 100644 index 0000000..96f37c6 --- /dev/null +++ b/configs/RGCN.json @@ -0,0 +1,21 @@ +{ + "default": { + "use_self_loop": true, + "hidden_dim": 64, + "n_layers": 2 + }, + "imdb-gtn": { + "n_layers": 2 + }, + "acm-gtn": { + "n_layers": 2 + }, + "dblp-gtn": { + "n_layers": 5 + }, + "lastfm": { + "n_layers": 2, + "lr": 0.02, + "dropout": 0.0 + } +} diff --git a/configs/base.json b/configs/base.json new file mode 100644 index 0000000..6d3c8f6 --- /dev/null +++ b/configs/base.json @@ -0,0 +1,9 @@ +{ + "n_neighbor_samples": 0, + "dropout": 0.5, + "lr": 0.005, + "weight_decay": 0.001, + "n_epochs": 500, + "early_stopping_mode": "score", + "patience": 50 +} diff --git a/data/acm-gtn/README.md b/data/acm-gtn/README.md new file mode 100644 index 0000000..f4afbc9 --- /dev/null +++ b/data/acm-gtn/README.md @@ -0,0 +1,7 @@ +## ACM Dataset + +Originally from [HAN](https://github.com/Jhy1993/HAN). + +Preprocessed by [GTN](https://github.com/seongjunyun/Graph_Transformer_Networks). + +We use a version in DGLGraph format provided by [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN). diff --git a/data/dblp-gtn/README.md b/data/dblp-gtn/README.md new file mode 100644 index 0000000..342263b --- /dev/null +++ b/data/dblp-gtn/README.md @@ -0,0 +1,7 @@ +## DBLP Dataset + +Originally from [HAN](https://github.com/Jhy1993/HAN). + +Preprocessed by [GTN](https://github.com/seongjunyun/Graph_Transformer_Networks). + +We use the version provided by GTN. diff --git a/data/imdb-gtn/README.md b/data/imdb-gtn/README.md new file mode 100644 index 0000000..e16e3ad --- /dev/null +++ b/data/imdb-gtn/README.md @@ -0,0 +1,7 @@ +## IMDB Dataset + +Originally from [HAN](https://github.com/Jhy1993/HAN). + +Preprocessed by [GTN](https://github.com/seongjunyun/Graph_Transformer_Networks). + +We use a version in DGLGraph format provided by [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN). diff --git a/data/lastfm/README.md b/data/lastfm/README.md new file mode 100644 index 0000000..6dc4856 --- /dev/null +++ b/data/lastfm/README.md @@ -0,0 +1,7 @@ +## LastFM Dataset + +Originally from [HetRec 2011](https://grouplens.org/datasets/hetrec-2011/). + +Preprocessed by [MAGNN](https://github.com/cynricfu/MAGNN). + +We discard the negative edges provided by MAGNN and samples hard negative edges for validtaion and testing by ourselves. diff --git a/experiment/__init__.py b/experiment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiment/link_prediction.py b/experiment/link_prediction.py new file mode 100644 index 0000000..4beb2aa --- /dev/null +++ b/experiment/link_prediction.py @@ -0,0 +1,270 @@ +from collections import OrderedDict + +import tqdm +import numpy as np +import torch as th +import torch.nn.functional as F +import dgl + +import experiment.utils as utils + + +def link_prediction_minibatch(model, g_train, g_val, g_test, train_eid_dict, val_eid_dict, test_eid_dict, val_neg_uv, + test_neg_uv, dir_path, args): + model.to(args.device) + + target_etype = list(train_eid_dict.keys())[0] + + # GPU-based sampling results in an error, so only use CPU-based sampling here + num_workers = 4 + if args.n_neighbor_samples <= 0: + block_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args.n_layers) + else: + block_sampler = dgl.dataloading.MultiLayerNeighborSampler([{ + etype: args.n_neighbor_samples for etype in g_train.canonical_etypes}] * args.n_layers) + if args.exclude: + exclude = "reverse_types" + reverse_etypes = args.reverse_etypes + else: + exclude = None + reverse_etypes = None + val_eid2neg_uv = {eid: (u, v) for eid, (u, v) in + zip(val_eid_dict[target_etype].cpu().tolist(), val_neg_uv.cpu().tolist())} + test_eid2neg_uv = {eid: (u, v) for eid, (u, v) in + zip(test_eid_dict[target_etype].cpu().tolist(), test_neg_uv.cpu().tolist())} + train_dataloader = dgl.dataloading.EdgeDataLoader( + g_train, + {target_etype: g_train.edges(etype=target_etype, form='eid')}, + block_sampler, + exclude=exclude, + reverse_etypes=reverse_etypes, + negative_sampler=dgl.dataloading.negative_sampler.Uniform(1), + batch_size=args.batch_size, + shuffle=True, + drop_last=False, + num_workers=num_workers + ) + val_dataloader = dgl.dataloading.EdgeDataLoader( + g_test, + val_eid_dict, + block_sampler, + g_sampling=g_train, + negative_sampler=utils.FixedNegSampler(val_eid2neg_uv), + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers + ) + test_dataloader = dgl.dataloading.EdgeDataLoader( + g_test, + test_eid_dict, + block_sampler, + g_sampling=g_val, + negative_sampler=utils.FixedNegSampler(test_eid2neg_uv), + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers + ) + + optimizer = th.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + early_stopping = utils.EarlyStopping( + patience=args.patience, mode=args.early_stopping_mode, verbose=True, save_path=str(dir_path / "checkpoint.pt") + ) + + for epoch in range(args.n_epochs): + # training + model.train() + with tqdm.tqdm(train_dataloader) as tq: + for iteration, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(tq): + blocks = [b.to(args.device) for b in blocks] + positive_graph = positive_graph.to(args.device) + negative_graph = negative_graph.to(args.device) + + input_features = blocks[0].srcdata["x"] + pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features) + train_loss = F.binary_cross_entropy_with_logits(pos_score, th.ones_like( + pos_score)) + F.binary_cross_entropy_with_logits(neg_score, th.zeros_like(neg_score)) + + optimizer.zero_grad() + train_loss.backward() + optimizer.step() + + # print training info + tq.set_postfix( + {"loss": "{:.03f}".format(train_loss.item())}, refresh=False + ) + + # validation + model.eval() + with tqdm.tqdm(val_dataloader) as tq, th.no_grad(): + val_loss = 0 + pos_score_list = [] + neg_score_list = [] + for iteration, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(tq): + blocks = [b.to(args.device) for b in blocks] + positive_graph = positive_graph.to(args.device) + negative_graph = negative_graph.to(args.device) + + input_features = blocks[0].srcdata["x"] + pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features) + val_loss += F.binary_cross_entropy_with_logits(pos_score, th.ones_like(pos_score), + reduction="sum") + F.binary_cross_entropy_with_logits( + neg_score, th.zeros_like(neg_score), reduction="sum") + + pos_score_list.append(pos_score.cpu().numpy()) + neg_score_list.append(neg_score.cpu().numpy()) + val_loss = val_loss / val_eid_dict[target_etype].numel() + pos_scores = np.concatenate(pos_score_list, axis=0) + neg_scores = np.concatenate(neg_score_list, axis=0) + + val_auroc, val_ap = utils.link_prediction_scores(pos_scores, neg_scores) + + # print validation info + print( + "Epoch {:05d} | AUROC {:.4f} | AP {:.4f} | Val_Loss {:.4f}".format(epoch, val_auroc, val_ap, val_loss)) + + # early stopping + if args.early_stopping_mode == "score": + quantity = val_auroc + elif args.early_stopping_mode == "loss": + quantity = val_loss + else: + raise NotImplementedError + early_stopping(quantity, model) + if early_stopping.early_stop: + print("Early stopping!") + break + + # testing + model.load_state_dict(th.load(str(dir_path / "checkpoint.pt"))) + model.eval() + with tqdm.tqdm(test_dataloader) as tq, th.no_grad(): + pos_score_list = [] + neg_score_list = [] + for iteration, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(tq): + blocks = [b.to(args.device) for b in blocks] + positive_graph = positive_graph.to(args.device) + negative_graph = negative_graph.to(args.device) + + input_features = blocks[0].srcdata["x"] + pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features) + + pos_score_list.append(pos_score.cpu().numpy()) + neg_score_list.append(neg_score.cpu().numpy()) + pos_scores = np.concatenate(pos_score_list, axis=0) + neg_scores = np.concatenate(neg_score_list, axis=0) + + test_auroc, test_ap = utils.link_prediction_scores(pos_scores, neg_scores) + + # print testing info + print("Testing Evaluation Metrics") + print("AUROC: {:.4f}".format(test_auroc)) + print("AP: {:.4f}".format(test_ap)) + # save evaluation results + with dir_path.joinpath("result.txt").open("w") as f: + f.write("AUROC: {:.4f}\n".format(test_auroc)) + f.write("AP: {:.4f}\n".format(test_ap)) + return test_auroc, test_ap + + +def link_prediction_fullbatch(model, g_train, g_val, g_test, train_eid_dict, val_eid_dict, test_eid_dict, val_neg_uv, + test_neg_uv, dir_path, args): + model.to(args.device) + g_train = g_train.to(args.device) + g_val = g_val.to(args.device) + g_test = g_test.to(args.device) + train_eid_dict = {k: v.to(args.device) for k, v in train_eid_dict.items()} + val_eid_dict = {k: v.to(args.device) for k, v in val_eid_dict.items()} + test_eid_dict = {k: v.to(args.device) for k, v in test_eid_dict.items()} + + target_etype = list(train_eid_dict.keys())[0] + target_ntype_u, _, target_ntype_v = g_train.to_canonical_etype(target_etype) + + optimizer = th.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + early_stopping = utils.EarlyStopping( + patience=args.patience, mode=args.early_stopping_mode, verbose=True, save_path=str(dir_path / "checkpoint.pt") + ) + + with tqdm.tqdm(range(args.n_epochs)) as tq: + for epoch in tq: + # training + model.train() + + pos_edges = th.stack(g_train.edges(etype=target_etype), dim=1) + neg_edges = th.clone(pos_edges) + neg_edges[:, 1] = th.randint(0, g_train.num_nodes(target_ntype_v), (neg_edges.shape[0],)) + + pos_score, neg_score = model(pos_edges, neg_edges, g_train, g_train.ndata['x']) + train_loss = F.binary_cross_entropy_with_logits(pos_score, th.ones_like( + pos_score)) + F.binary_cross_entropy_with_logits(neg_score, th.zeros_like(neg_score)) + + optimizer.zero_grad() + train_loss.backward() + optimizer.step() + + # evaluation metrics + train_auroc, train_ap = utils.link_prediction_scores(pos_score.detach().cpu().numpy(), + neg_score.detach().cpu().numpy()) + + # set validation print info + print_info = OrderedDict() + print_info["train_loss"] = "{:.03f}".format(train_loss.item()) + print_info["train_auroc"] = "{:.4f}".format(train_auroc) + print_info["train_ap"] = "{:.4f}".format(train_ap) + + # validation + model.eval() + with th.no_grad(): + pos_edges = th.stack(g_test.find_edges(val_eid_dict[target_etype], etype=target_etype), dim=1) + pos_score, neg_score = model(pos_edges, val_neg_uv, g_train, g_train.ndata['x']) + val_loss = F.binary_cross_entropy_with_logits(pos_score, th.ones_like( + pos_score)) + F.binary_cross_entropy_with_logits(neg_score, th.zeros_like(neg_score)) + + # evaluation metrics + val_auroc, val_ap = utils.link_prediction_scores(pos_score.cpu().numpy(), neg_score.cpu().numpy()) + + # set validation print info + print_info["val_loss"] = "{:.03f}".format(val_loss.item()) + print_info["val_auroc"] = "{:.4f}".format(val_auroc) + print_info["val_ap"] = "{:.4f}".format(val_ap) + + # print training and validation info + tq.set_postfix(print_info, refresh=False) + + # early stopping + if args.early_stopping_mode == "score": + quantity = val_auroc + elif args.early_stopping_mode == "loss": + quantity = val_loss + else: + raise NotImplementedError + early_stopping(quantity, model) + if early_stopping.early_stop: + print("Early stopping!") + break + + # testing + model.load_state_dict(th.load(str(dir_path / "checkpoint.pt"))) + model.eval() + with th.no_grad(): + # forward + pos_edges = th.stack(g_test.find_edges(test_eid_dict[target_etype], etype=target_etype), dim=1) + pos_score, neg_score = model(pos_edges, test_neg_uv, g_val, g_val.ndata['x']) + + # evaluation metrics + test_auroc, test_ap = utils.link_prediction_scores(pos_score.cpu().numpy(), neg_score.cpu().numpy()) + + # print testing info + print("Testing Evaluation Metrics") + print("AUROC: {:.4f}".format(test_auroc)) + print("AP: {:.4f}".format(test_ap)) + # save evaluation results + with dir_path.joinpath("result.txt").open("w") as f: + f.write("AUROC: {:.4f}\n".format(test_auroc)) + f.write("AP: {:.4f}\n".format(test_ap)) + return test_auroc, test_ap diff --git a/experiment/node_classification.py b/experiment/node_classification.py new file mode 100644 index 0000000..62cd4ed --- /dev/null +++ b/experiment/node_classification.py @@ -0,0 +1,263 @@ +from collections import OrderedDict + +import tqdm +import numpy as np +import torch as th +import torch.nn.functional as F +import dgl + +import experiment.utils as utils + + +def node_classification_minibatch(model, g, train_nid_dict, val_nid_dict, test_nid_dict, dir_path, args): + model.to(args.device) + g = g.to(args.device) + train_nid_dict = {k: v.to(args.device) for k, v in train_nid_dict.items()} + val_nid_dict = {k: v.to(args.device) for k, v in val_nid_dict.items()} + test_nid_dict = {k: v.to(args.device) for k, v in test_nid_dict.items()} + + assert len(g.ndata["y"].keys()) == 1 + target_ntype = list(g.ndata["y"].keys())[0] + + # Use GPU-based neighborhood sampling if possible + num_workers = 4 if args.device.type == "CPU" else 0 + if args.n_neighbor_samples <= 0: + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args.n_layers) + else: + sampler = dgl.dataloading.MultiLayerNeighborSampler([{ + etype: args.n_neighbor_samples for etype in g.canonical_etypes}] * args.n_layers) + train_dataloader = dgl.dataloading.NodeDataLoader( + g, + train_nid_dict, + sampler, + batch_size=args.batch_size, + shuffle=True, + drop_last=False, + num_workers=num_workers, + device=args.device, + ) + val_dataloader = dgl.dataloading.NodeDataLoader( + g, + val_nid_dict, + sampler, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + device=args.device, + ) + test_dataloader = dgl.dataloading.NodeDataLoader( + g, + test_nid_dict, + sampler, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + device=args.device, + ) + + optimizer = th.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + early_stopping = utils.EarlyStopping( + patience=args.patience, mode=args.early_stopping_mode, verbose=True, save_path=str(dir_path / "checkpoint.pt") + ) + + for epoch in range(args.n_epochs): + # training + model.train() + with tqdm.tqdm(train_dataloader) as tq: + for iteration, (input_nodes, output_nodes, blocks) in enumerate(tq): + input_features = blocks[0].srcdata["x"] + output_labels = blocks[-1].dstdata["y"] + + logits_dict = model(blocks, input_features) + logp = F.log_softmax(logits_dict[target_ntype], dim=-1) + train_loss = F.nll_loss(logp, output_labels[target_ntype]) + + optimizer.zero_grad() + train_loss.backward() + optimizer.step() + + # print training info + tq.set_postfix( + {"loss": "{:.03f}".format(train_loss.item())}, refresh=False + ) + + # validation + model.eval() + with tqdm.tqdm(val_dataloader) as tq, th.no_grad(): + val_loss = 0 + logits_list = [] + y_true_list = [] + for iteration, (input_nodes, output_nodes, blocks) in enumerate(tq): + input_features = blocks[0].srcdata["x"] + output_labels = blocks[-1].dstdata["y"] + + logits_dict = model(blocks, input_features) + logp = F.log_softmax(logits_dict[target_ntype], dim=-1) + val_loss += F.nll_loss(logp, output_labels[target_ntype], reduction="sum") + + logits_list.append(logits_dict[target_ntype].cpu().numpy()) + y_true_list.append(output_labels[target_ntype].cpu().numpy()) + + val_loss = val_loss / val_nid_dict[target_ntype].numel() + logits = np.concatenate(logits_list, axis=0) + y_true = np.concatenate(y_true_list, axis=0) + + val_acc, val_auroc, val_macro_f1, val_micro_f1 = utils.classification_scores( + y_true, logits + ) + + # print validation info + print( + "Epoch {:05d} | Macro-F1 {:.4f} | Micro-F1 {:.4f} | Val_Loss {:.4f}".format( + epoch, val_macro_f1, val_micro_f1, val_loss.item() + ) + ) + + # early stopping + if args.early_stopping_mode == "score": + quantity = (val_macro_f1 + val_micro_f1) / 2 + elif args.early_stopping_mode == "loss": + quantity = val_loss + else: + raise NotImplementedError + early_stopping(quantity, model) + if early_stopping.early_stop: + print("Early stopping!") + break + + # testing + model.load_state_dict(th.load(str(dir_path / "checkpoint.pt"))) + model.eval() + with tqdm.tqdm(test_dataloader) as tq, th.no_grad(): + logits_list = [] + y_true_list = [] + for iteration, (input_nodes, output_nodes, blocks) in enumerate(tq): + input_features = blocks[0].srcdata["x"] + output_labels = blocks[-1].dstdata["y"] + + logits_dict = model(blocks, input_features) + + logits_list.append(logits_dict[target_ntype].cpu().numpy()) + y_true_list.append(output_labels[target_ntype].cpu().numpy()) + + logits = np.concatenate(logits_list, axis=0) + y_true = np.concatenate(y_true_list, axis=0) + test_acc, test_auroc, test_macro_f1, test_micro_f1 = utils.classification_scores( + y_true, logits + ) + + # print testing info + print("Testing Evaluation Metrics") + print("Macro-F1: {:.4f}".format(test_macro_f1)) + print("Micro-F1: {:.4f}".format(test_micro_f1)) + # save evaluation results + with dir_path.joinpath("result.txt").open("w") as f: + f.write("Macro-F1: {:.4f}\n".format(test_macro_f1)) + f.write("Micro-F1: {:.4f}\n".format(test_micro_f1)) + return test_macro_f1, test_micro_f1 + + +def node_classification_fullbatch(model, g, train_nid_dict, val_nid_dict, test_nid_dict, dir_path, args): + model.to(args.device) + g = g.to(args.device) + train_nid_dict = {k: v.to(args.device) for k, v in train_nid_dict.items()} + val_nid_dict = {k: v.to(args.device) for k, v in val_nid_dict.items()} + test_nid_dict = {k: v.to(args.device) for k, v in test_nid_dict.items()} + + assert len(g.ndata["y"].keys()) == 1 + target_ntype = list(g.ndata["y"].keys())[0] + + optimizer = th.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + early_stopping = utils.EarlyStopping( + patience=args.patience, mode=args.early_stopping_mode, verbose=True, save_path=str(dir_path / "checkpoint.pt") + ) + + with tqdm.tqdm(range(args.n_epochs)) as tq: + for epoch in tq: + # training + model.train() + + logits_dict = model(g, g.ndata['x']) + logits = logits_dict[target_ntype][train_nid_dict[target_ntype]] + y_true = g.ndata['y'][target_ntype][train_nid_dict[target_ntype]] + logp = F.log_softmax(logits, dim=-1) + train_loss = F.nll_loss(logp, y_true) + + optimizer.zero_grad() + train_loss.backward() + optimizer.step() + + # evaluation metrics + train_acc, train_auroc, train_macro_f1, train_micro_f1 = utils.classification_scores( + y_true.detach().cpu().numpy(), logits.detach().cpu().numpy() + ) + + # set training print info + print_info = OrderedDict() + print_info["train_loss"] = "{:.03f}".format(train_loss.item()) + print_info["train_macro_f1"] = "{:.4f}".format(train_macro_f1) + print_info["train_micro_f1"] = "{:.4f}".format(train_micro_f1) + + # validation + model.eval() + with th.no_grad(): + logits_dict = model(g, g.ndata['x']) + logits = logits_dict[target_ntype][val_nid_dict[target_ntype]] + y_true = g.ndata['y'][target_ntype][val_nid_dict[target_ntype]] + logp = F.log_softmax(logits, dim=-1) + val_loss = F.nll_loss(logp, y_true) + + # evaluation metrics + val_acc, val_auroc, val_macro_f1, val_micro_f1 = utils.classification_scores( + y_true.cpu().numpy(), logits.cpu().numpy() + ) + + # set validation print info + print_info["val_loss"] = "{:.03f}".format(val_loss.item()) + print_info["val_macro_f1"] = "{:.4f}".format(val_macro_f1) + print_info["val_micro_f1"] = "{:.4f}".format(val_micro_f1) + + # print training and validation info + tq.set_postfix(print_info, refresh=False) + + # early stopping + if args.early_stopping_mode == "score": + quantity = (val_macro_f1 + val_micro_f1) / 2 + elif args.early_stopping_mode == "loss": + quantity = val_loss + else: + raise NotImplementedError + early_stopping(quantity, model) + if early_stopping.early_stop: + print("Early stopping!") + break + + # testing + model.load_state_dict(th.load(str(dir_path / "checkpoint.pt"))) + model.eval() + with th.no_grad(): + # forward + logits_dict = model(g, g.ndata['x']) + logits = logits_dict[target_ntype][test_nid_dict[target_ntype]] + y_true = g.ndata['y'][target_ntype][test_nid_dict[target_ntype]] + + # evaluation metrics + test_acc, test_auroc, test_macro_f1, test_micro_f1 = utils.classification_scores( + y_true.cpu().numpy(), logits.cpu().numpy() + ) + + # print testing info + print("Testing Evaluation Metrics") + print("Macro-F1: {:.4f}".format(test_macro_f1)) + print("Micro-F1: {:.4f}".format(test_micro_f1)) + # save evaluation results + with dir_path.joinpath("result.txt").open("w") as f: + f.write("Macro-F1: {:.4f}\n".format(test_macro_f1)) + f.write("Micro-F1: {:.4f}\n".format(test_micro_f1)) + return test_macro_f1, test_micro_f1 diff --git a/experiment/utils.py b/experiment/utils.py new file mode 100644 index 0000000..6860cee --- /dev/null +++ b/experiment/utils.py @@ -0,0 +1,100 @@ +import numpy as np +from scipy.special import softmax +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score +import torch as th +import dgl + + +def classification_scores(y_true, logits): + y_pred = np.argmax(logits, axis=1) + y_score = softmax(logits, axis=1) + + accuracy = accuracy_score(y_true, y_pred) + auroc = roc_auc_score(y_true, y_score, multi_class="ovr") + macro_f1 = f1_score(y_true, y_pred, average="macro") + micro_f1 = f1_score(y_true, y_pred, average="micro") + + return accuracy, auroc, macro_f1, micro_f1 + + +def link_prediction_scores(pos_scores, neg_scores): + y_true = np.concatenate((np.ones_like(pos_scores, dtype=np.int64), np.zeros_like(neg_scores, dtype=np.int64)), + axis=0) + y_score = np.concatenate((pos_scores, neg_scores), axis=0) + + auroc = roc_auc_score(y_true, y_score) + ap = average_precision_score(y_true, y_score) + + return auroc, ap + + +class EarlyStopping: + """Early stops the training if validation score/loss doesn't improve after a given patience.""" + + def __init__( + self, + patience=10, + delta=0, + mode="score", + save_path="checkpoint.pt", + verbose=False, + ): + """ + Args: + patience (int): How long to wait after last time validation score/loss improved. + Default: 10 + verbose (bool): If True, prints a message for each validation score/loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + """ + self.patience = patience + self.delta = delta + self.mode = mode + self.save_path = save_path + self.verbose = verbose + self.counter = 0 + self.best_score = -np.Inf + self.early_stop = False + + def __call__(self, quantity, model): + if self.mode == "score": + score = quantity + elif self.mode == "loss": + score = -quantity + else: + raise NotImplementedError + + if score < self.best_score + self.delta: + self.counter += 1 + print(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.save_checkpoint(quantity, model) + self.best_score = score + self.counter = 0 + + def save_checkpoint(self, quantity, model): + """Saves model when validation score/loss improves.""" + if self.verbose: + if self.mode == "score": + print( + f"Validation score increased ({self.best_score:.6f} --> {quantity:.6f}). Saving model ..." + ) + elif self.mode == "loss": + print( + f"Validation loss decreased ({-self.best_score:.6f} --> {quantity:.6f}). Saving model ..." + ) + else: + raise NotImplementedError + th.save(model.state_dict(), self.save_path) + + +class FixedNegSampler(dgl.dataloading.negative_sampler._BaseNegativeSampler): + def __init__(self, eid2neg_uv): + self.eid2neg_uv = eid2neg_uv + + def _generate(self, g, eids, canonical_etype): + edges = th.tensor([self.eid2neg_uv[eid] for eid in eids.cpu().tolist()], dtype=g.idtype, device=g.device) + return edges[:, 0], edges[:, 1] diff --git a/main.py b/main.py new file mode 100644 index 0000000..a1a5e3b --- /dev/null +++ b/main.py @@ -0,0 +1,354 @@ +import argparse +import json + +import dgl +import numpy as np +import torch as th + +from experiment.node_classification import node_classification_minibatch, node_classification_fullbatch +from experiment.link_prediction import link_prediction_minibatch, link_prediction_fullbatch +from model.MECCH import MECCH, khopMECCH +from model.baselines.RGCN import RGCN +from model.baselines.HGT import HGT +from model.baselines.HAN import HAN, HAN_lp +from model.modules import LinkPrediction_minibatch, LinkPrediction_fullbatch +from utils import metapath2str, add_metapath_connection, get_all_metapaths, load_data_nc, load_data_lp, \ + metapath_dict2list, select_metapaths, get_save_path + + +def load_base_config(path='./configs/base.json'): + with open(path) as f: + config = json.load(f) + print('Base configs loaded.') + return config + + +def load_model_config(path, dataset): + with open(path) as f: + config = json.load(f) + print('Model configs loaded.') + if dataset in config: + config_out = config['default'] + config_out.update(config[dataset]) + print('{} dataset configs for this model loaded, override defaults.'.format(dataset)) + return config_out + else: + print('Model do not have hyperparameter configs for {} dataset, use defaults.'.format(dataset)) + return config['default'] + + +def get_metapath_g(g, args): + # Generate the metapath neighbor graphs of all possible metapaths + # and integrate them into one dgl.DGLGraph -- metapath_g + all_metapaths_dict = get_all_metapaths(g, max_length=args.max_mp_length) + all_metapaths_list = metapath_dict2list(all_metapaths_dict) + metapath_g = None + for mp in all_metapaths_list: + metapath_g = add_metapath_connection(g, mp, metapath_g) + # copy features and labels + metapath_g.ndata["x"] = g.ndata["x"] + metapath_g.ndata["y"] = g.ndata["y"] + # select only max-length metapath + selected_metapaths = select_metapaths(all_metapaths_list, length=args.max_mp_length) + + return metapath_g, selected_metapaths + + +def get_khop_g(g, args): + homo_g = dgl.to_homogeneous(g) + temp_homo_g = dgl.to_homogeneous(g) + homo_g.edata[dgl.ETYPE][:] = 0 + homo_g.edata[dgl.EID] = th.arange(homo_g.num_edges()) + for k in range(2, args.max_mp_length + 1): + edges = dgl.khop_graph(temp_homo_g, k).edges() + etypes = th.full((edges[0].shape[0],), k - 1) + eids = th.arange(edges[0].shape[0]) + homo_g.add_edges(edges[0], edges[1], {dgl.ETYPE: etypes, dgl.EID: eids}) + hetero_g = dgl.to_heterogeneous(homo_g, g.ntypes, ['{}-hop'.format(i + 1) for i in range(args.max_mp_length)]) + hetero_g.ndata['x'] = g.ndata['x'] + hetero_g.ndata['y'] = g.ndata['y'] + return hetero_g + + +def main_nc(args): + dir_path_list = [] + for _ in range(args.repeat): + dir_path_list.append(get_save_path(args)) + + test_macro_f1_list = [] + test_micro_f1_list = [] + for i in range(args.repeat): + # load data + g, in_dim_dict, out_dim, train_nid_dict, val_nid_dict, test_nid_dict = load_data_nc(args.dataset) + print("Loaded data from dataset: {}".format(args.dataset)) + + # check cuda + use_cuda = args.gpu >= 0 and th.cuda.is_available() + if use_cuda: + args.device = th.device('cuda', args.gpu) + else: + args.device = th.device('cpu') + + # create model + model-specific data preprocessing + if args.model == "MECCH": + if args.ablation: + g = get_khop_g(g, args) + model = khopMECCH( + g, + in_dim_dict, + args.hidden_dim, + out_dim, + args.n_layers, + dropout=args.dropout, + residual=args.residual, + layer_norm=args.layer_norm + ) + else: + g, selected_metapaths = get_metapath_g(g, args) + n_heads_list = [args.n_heads] * args.n_layers + model = MECCH( + g, + selected_metapaths, + in_dim_dict, + args.hidden_dim, + out_dim, + args.n_layers, + n_heads_list, + dropout=args.dropout, + context_encoder=args.context_encoder, + use_v=args.use_v, + metapath_fusion=args.metapath_fusion, + residual=args.residual, + layer_norm=args.layer_norm + ) + minibatch_flag = True + elif args.model == "RGCN": + assert args.n_layers >= 2 + model = RGCN( + g, + in_dim_dict, + args.hidden_dim, + out_dim, + num_bases=-1, + num_hidden_layers=args.n_layers - 2, + dropout=args.dropout, + use_self_loop=args.use_self_loop + ) + minibatch_flag = False + elif args.model == "HGT": + model = HGT( + g, + in_dim_dict, + args.hidden_dim, + out_dim, + args.n_layers, + args.n_heads + ) + minibatch_flag = False + elif args.model == "HAN": + # assume the target node type has attributes + assert args.hidden_dim % args.n_heads == 0 + target_ntype = list(g.ndata["y"].keys())[0] + n_heads_list = [args.n_heads] * args.n_layers + model = HAN( + args.metapaths, + target_ntype, + in_dim_dict[target_ntype], + args.hidden_dim // args.n_heads, + out_dim, + num_heads=n_heads_list, + dropout=args.dropout + ) + minibatch_flag = False + else: + raise NotImplementedError + + if minibatch_flag: + test_macro_f1, test_micro_f1 = node_classification_minibatch(model, g, train_nid_dict, val_nid_dict, + test_nid_dict, dir_path_list[i], args) + else: + test_macro_f1, test_micro_f1 = node_classification_fullbatch(model, g, train_nid_dict, val_nid_dict, + test_nid_dict, dir_path_list[i], args) + test_macro_f1_list.append(test_macro_f1) + test_micro_f1_list.append(test_micro_f1) + + print("--------------------------------") + if args.repeat > 1: + print("Macro-F1_MEAN\tMacro-F1_STDEV\tMicro-F1_MEAN\tMicro-F1_STDEV") + print("{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}".format(np.mean(test_macro_f1_list), np.std(test_macro_f1_list, ddof=0), + np.mean(test_micro_f1_list), np.std(test_micro_f1_list, ddof=0))) + else: + print("args.repeat <= 1, not calculating the average and the standard deviation of scores") + + +def main_lp(args): + dir_path_list = [] + for _ in range(args.repeat): + dir_path_list.append(get_save_path(args)) + + test_auroc_list = [] + test_ap_list = [] + for i in range(args.repeat): + # load data + (g_train, g_val, g_test), in_dim_dict, (train_eid_dict, val_eid_dict, test_eid_dict), ( + val_neg_uv, test_neg_uv) = load_data_lp(args.dataset) + + # check cuda + use_cuda = args.gpu >= 0 and th.cuda.is_available() + if use_cuda: + args.device = th.device('cuda', args.gpu) + else: + args.device = th.device('cpu') + + target_etype = list(train_eid_dict.keys())[0] + # create model + model-specific preprocessing + if args.model == 'MECCH': + if args.ablation: + # Note: here we assume there is only one edge type between users and items + train_eid_dict = {(g_train.to_canonical_etype(k)[0], '1-hop', g_train.to_canonical_etype(k)[2]): v for + k, v in train_eid_dict.items()} + val_eid_dict = {(g_val.to_canonical_etype(k)[0], '1-hop', g_val.to_canonical_etype(k)[2]): v for k, v + in val_eid_dict.items()} + test_eid_dict = {(g_test.to_canonical_etype(k)[0], '1-hop', g_test.to_canonical_etype(k)[2]): v for k, v + in test_eid_dict.items()} + target_etype = list(train_eid_dict.keys())[0] + + g_train = get_khop_g(g_train, args) + g_val = get_khop_g(g_val, args) + g_test = get_khop_g(g_test, args) + model = khopMECCH( + g_train, + in_dim_dict, + args.hidden_dim, + args.hidden_dim, + args.n_layers, + dropout=args.dropout, + residual=args.residual, + layer_norm=args.layer_norm + ) + else: + train_eid_dict = {metapath2str([g_train.to_canonical_etype(k)]): v for k, v in train_eid_dict.items()} + val_eid_dict = {metapath2str([g_val.to_canonical_etype(k)]): v for k, v in val_eid_dict.items()} + test_eid_dict = {metapath2str([g_test.to_canonical_etype(k)]): v for k, v in test_eid_dict.items()} + target_etype = list(train_eid_dict.keys())[0] + + g_train, _ = get_metapath_g(g_train, args) + g_val, _ = get_metapath_g(g_val, args) + g_test, selected_metapaths = get_metapath_g(g_test, args) + n_heads_list = [args.n_heads] * args.n_layers + model = MECCH( + g_train, + selected_metapaths, + in_dim_dict, + args.hidden_dim, + args.hidden_dim, + args.n_layers, + n_heads_list, + dropout=args.dropout, + context_encoder=args.context_encoder, + use_v=args.use_v, + metapath_fusion=args.metapath_fusion, + residual=args.residual, + layer_norm=args.layer_norm + ) + model_lp = LinkPrediction_minibatch(model, args.hidden_dim, target_etype) + minibatch_flag = True + elif args.model == 'RGCN': + assert args.n_layers >= 2 + model = RGCN( + g_train, + in_dim_dict, + args.hidden_dim, + args.hidden_dim, + num_bases=-1, + num_hidden_layers=args.n_layers - 2, + dropout=args.dropout, + use_self_loop=args.use_self_loop + ) + if hasattr(args, 'batch_size'): + model_lp = LinkPrediction_minibatch(model, args.hidden_dim, target_etype) + minibatch_flag = True + else: + srctype, _, dsttype = g_train.to_canonical_etype(target_etype) + model_lp = LinkPrediction_fullbatch(model, args.hidden_dim, srctype, dsttype) + minibatch_flag = False + elif args.model == 'HGT': + model = HGT( + g_train, + in_dim_dict, + args.hidden_dim, + args.hidden_dim, + args.n_layers, + args.n_heads + ) + if hasattr(args, 'batch_size'): + model_lp = LinkPrediction_minibatch(model, args.hidden_dim, target_etype) + minibatch_flag = True + else: + srctype, _, dsttype = g_train.to_canonical_etype(target_etype) + model_lp = LinkPrediction_fullbatch(model, args.hidden_dim, srctype, dsttype) + minibatch_flag = False + elif args.model == 'HAN': + # assume the target node type has attributes + assert args.hidden_dim % args.n_heads == 0 + n_heads_list = [args.n_heads] * args.n_layers + model_lp = HAN_lp( + g_train, + args.metapaths_u, + args.metapaths_u[0][0][0], + -1, + args.metapaths_v, + args.metapaths_v[0][0][0], + -1, + args.hidden_dim // args.n_heads, + args.hidden_dim, + num_heads=n_heads_list, + dropout=args.dropout + ) + minibatch_flag = False + else: + raise NotImplementedError + + if minibatch_flag: + test_auroc, test_ap = link_prediction_minibatch(model_lp, g_train, g_val, g_test, train_eid_dict, + val_eid_dict, test_eid_dict, val_neg_uv, test_neg_uv, + dir_path_list[i], args) + else: + test_auroc, test_ap = link_prediction_fullbatch(model_lp, g_train, g_val, g_test, train_eid_dict, + val_eid_dict, test_eid_dict, val_neg_uv, test_neg_uv, + dir_path_list[i], args) + test_auroc_list.append(test_auroc) + test_ap_list.append(test_ap) + + print("--------------------------------") + if args.repeat > 1: + print("ROC-AUC_MEAN\tROC-AUC_STDEV\tPR-AUC_MEAN\tPR-AUC_STDEV") + print("{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}".format(np.mean(test_auroc_list), np.std(test_auroc_list, ddof=0), + np.mean(test_ap_list), np.std(test_ap_list, ddof=0))) + else: + print("args.repeat <= 1, not calculating the average and the standard deviation of scores") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("My HGNNs") + parser.add_argument('--model', '-m', type=str, required=True, help='name of model') + parser.add_argument('--dataset', '-d', type=str, required=True, help='name of dataset') + parser.add_argument('--task', '-t', type=str, default='node_classification', help='type of task') + parser.add_argument("--gpu", '-g', type=int, default=-1, help="which gpu to use, specify -1 to use CPU") + parser.add_argument('--config', '-c', type=str, help='config file for model hyperparameters') + parser.add_argument('--repeat', '-r', type=int, default=1, help='repeat the training and testing for N times') + + args = parser.parse_args() + if args.config is None: + args.config = "./configs/{}.json".format(args.model) + + configs = load_base_config() + configs.update(load_model_config(args.config, args.dataset)) + configs.update(vars(args)) + args = argparse.Namespace(**configs) + print(args) + + if args.task == 'node_classification': + main_nc(args) + elif args.task == 'link_prediction': + main_lp(args) diff --git a/model/MECCH.py b/model/MECCH.py new file mode 100644 index 0000000..55a5271 --- /dev/null +++ b/model/MECCH.py @@ -0,0 +1,434 @@ +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import dgl +import dgl.function as fn +from dgl.nn.functional import edge_softmax + +from model.modules import HeteroEmbedLayer, HeteroLinearLayer +from model.utils import sub_metapaths, get_src_ntypes + + +class MetapathContextEncoder(nn.Module): + def __init__(self, in_dim, encoder_type="mean", use_v=False, n_heads=8): + assert in_dim % n_heads == 0 + super(MetapathContextEncoder, self).__init__() + + self.encoder_type = encoder_type + self.use_v = use_v + self.n_heads = n_heads + self.d_k = in_dim // n_heads + self.sqrt_dk = math.sqrt(self.d_k) + + if encoder_type == "mean": + pass + elif encoder_type == "attention": + self.k_linear = nn.Linear(in_dim, in_dim, False) + if use_v: + self.v_linear = nn.Linear(in_dim, in_dim, False) + self.q_linear = nn.Linear(in_dim, in_dim, False) + else: + raise NotImplementedError + + def forward(self, block, h_dict, metapath_str): + mp_list = sub_metapaths(metapath_str) + src_ntypes = get_src_ntypes(metapath_str) + _, _, dst_ntype = block.to_canonical_etype(metapath_str) + + with block.local_scope(): + if self.encoder_type == "mean": + funcs = {} + num_neigh = 0 + for mp in mp_list: + funcs[mp] = (fn.copy_u("h_src", "m"), fn.sum("m", "h_neigh")) + num_neigh += block.in_degrees(etype=mp) + block.multi_update_all(funcs, "sum") + block.dstnodes[dst_ntype].data["h_dst_out"] = (block.dstnodes[dst_ntype].data["h_neigh"] + + block.dstnodes[dst_ntype].data["h_dst"]) / th.unsqueeze( + num_neigh + 1, dim=-1) + elif self.encoder_type == "attention": + # K, V projections for source nodes + for ntype in src_ntypes: + block.srcnodes[ntype].data['k'] = self.k_linear( + block.srcdata["h_src"][ntype]).view(-1, self.n_heads, self.d_k) + if self.use_v: + block.srcnodes[ntype].data['v'] = self.v_linear( + block.srcdata["h_src"][ntype]).view(-1, self.n_heads, self.d_k) + else: + block.srcnodes[ntype].data['v'] = block.srcdata["h_src"][ntype].view(-1, self.n_heads, self.d_k) + # K, V, Q projections for destination nodes + dst_k = self.k_linear( + block.dstdata["h_dst"][dst_ntype]).view(-1, self.n_heads, self.d_k) + if self.use_v: + dst_v = self.v_linear( + block.dstdata["h_dst"][dst_ntype]).view(-1, self.n_heads, self.d_k) + else: + dst_v = block.dstdata["h_dst"][dst_ntype].view(-1, self.n_heads, self.d_k) + block.dstnodes[dst_ntype].data['q'] = self.q_linear( + block.dstdata["h_dst"][dst_ntype]).view(-1, self.n_heads, self.d_k) + + # compute dot product of k and q for destination nodes, for each head + # also divide by square root of per-head dim + dst_t = th.sum(dst_k * block.dstnodes[dst_ntype].data['q'], dim=-1, keepdim=True) / self.sqrt_dk + # compute dot product of k and q for all edges, for each head + # also divide by square root of per-head dim + for mp in mp_list: + block.apply_edges(fn.u_dot_v('k', 'q', 't'), etype=mp) + block.edges[mp].data['t'] = block.edges[mp].data['t'] / self.sqrt_dk + + # edge_softmax for all edges + # DGL do not support edge_softmax for all edges for heterograph currently + # 1) select etypes of interest + sub_hetero_g = dgl.edge_type_subgraph(block, etypes=mp_list) + # 2) convert to homogeneous graph + sub_homo_g = dgl.to_homogeneous(sub_hetero_g, edata=['t']) + # 3) add self loop + offset_node = [0] + [sub_hetero_g.num_nodes(sub_hetero_g.ntypes[i]) for i in + range(sub_hetero_g.get_ntype_id(dst_ntype))] + offset_node = np.sum(offset_node) + offset_edge = sub_homo_g.num_edges() + u = sub_homo_g.nodes()[offset_node:offset_node + block.num_dst_nodes(dst_ntype)] + sub_homo_g.add_edges(u, u, data={'t': dst_t}) + # 4) perform edge_softmax + sub_homo_g.edata['a'] = edge_softmax(sub_homo_g, sub_homo_g.edata['t'], norm_by='dst') + dst_a = sub_homo_g.edata['a'][-block.num_dst_nodes(dst_ntype):] + # 5) remove self loop and convert back to heterograph + sub_homo_g.remove_edges(offset_edge + th.arange(block.num_dst_nodes(dst_ntype)).to(block.device)) + sub_hetero_g2 = dgl.to_heterogeneous(sub_homo_g, sub_hetero_g.ntypes, sub_hetero_g.etypes) + # 6) copy data + for etype in sub_hetero_g2.canonical_etypes: + block.edges[etype].data['a'] = sub_hetero_g2.edges[etype].data['a'] + + # aggregate neighbors' v multiplied by attention scores + funcs = {mp: (fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'h_neigh')) for mp in mp_list} + block.multi_update_all(funcs, "sum") + + # consider self loop, aggregate destination nodes + if self.use_v: + block.dstnodes[dst_ntype].data['h_dst_out'] = F.relu( + (block.dstnodes[dst_ntype].data['h_neigh'] + dst_a * dst_v).view(-1, self.n_heads * self.d_k)) + else: + block.dstnodes[dst_ntype].data['h_dst_out'] = ( + block.dstnodes[dst_ntype].data['h_neigh'] + dst_a * dst_v).view(-1, self.n_heads * self.d_k) + else: + raise NotImplementedError + + return block.dstnodes[dst_ntype].data["h_dst_out"] + + +class MetapathFusion(nn.Module): + def __init__(self, n_metapaths, in_dim, out_dim, fusion_type="conv"): + super(MetapathFusion, self).__init__() + + self.n_metapaths = n_metapaths + self.fusion_type = fusion_type + + if fusion_type == "mean": + self.linear = nn.Linear(in_dim, out_dim) + elif fusion_type == "weight": + self.weight = nn.Parameter(th.full((n_metapaths,), 1 / n_metapaths, dtype=th.float32)) + self.linear = nn.Linear(in_dim, out_dim) + elif fusion_type == "conv": + self.conv = nn.Parameter(th.full((n_metapaths, in_dim), 1 / n_metapaths, dtype=th.float32)) + self.linear = nn.Linear(in_dim, out_dim) + elif fusion_type == "cat": + self.linear = nn.Linear(n_metapaths * in_dim, out_dim) + else: + raise NotImplementedError + + def forward(self, h_list): + if self.fusion_type == "mean": + return self.linear(th.mean(th.stack(h_list), dim=0)) + elif self.fusion_type == "weight": + return self.linear(th.sum(th.stack(h_list) * self.weight[:, None, None], dim=0)) + elif self.fusion_type == "conv": + return self.linear(th.sum(th.stack(h_list).transpose(0, 1) * self.conv, dim=1)) + elif self.fusion_type == "cat": + return self.linear(th.hstack(h_list)) + else: + raise NotImplementedError + + +class MECCHLayer(nn.Module): + def __init__( + self, + metapaths_dict, + in_dim, + out_dim, + n_heads=8, + dropout=0.5, + context_encoder="mean", + use_v=False, + metapath_fusion="conv", + residual=False, + layer_norm=False, + activation=None + ): + super(MECCHLayer, self).__init__() + + self.metapaths_dict = metapaths_dict + self.out_dim = out_dim + self.dropout = nn.Dropout(dropout) + + if residual: + self.alpha = nn.ParameterDict() + self.residual = nn.ModuleDict() + for ntype in metapaths_dict: + self.alpha[ntype] = nn.Parameter(th.tensor(0.)) + if in_dim == out_dim: + self.residual[ntype] = nn.Identity() + else: + self.residual[ntype] = nn.Linear(in_dim, out_dim, bias=False) + else: + self.residual = None + if layer_norm: + self.layer_norm = nn.ModuleDict() + for ntype in metapaths_dict: + self.layer_norm[ntype] = nn.LayerNorm(out_dim) + else: + self.layer_norm = None + self.activation = activation + + self.context_encoders = nn.ModuleDict() + for ntype in metapaths_dict: + for metapath_str in metapaths_dict[ntype]: + self.context_encoders[metapath_str] = MetapathContextEncoder(in_dim, context_encoder, use_v, n_heads) + + # Metapath fusion + self.metapath_fuse = nn.ModuleDict() + for ntype in metapaths_dict: + self.metapath_fuse[ntype] = MetapathFusion(len(metapaths_dict[ntype]), in_dim, out_dim, metapath_fusion) + + def forward(self, block, h_dict): + with block.local_scope(): + for ntype in block.srctypes: + if block.num_src_nodes(ntype) > 0: + block.srcnodes[ntype].data["h_src"] = h_dict[ntype] + block.dstnodes[ntype].data["h_dst"] = h_dict[ntype][:block.num_dst_nodes(ntype)] + + out_h_dict = {} + for ntype in block.dsttypes: + if block.num_dst_nodes(ntype) > 0: + metapath_outs = [] + for metapath_str in self.metapaths_dict[ntype]: + metapath_outs.append(self.context_encoders[metapath_str](block, h_dict, metapath_str)) + out_h_dict[ntype] = self.metapath_fuse[ntype](metapath_outs) + + for ntype in out_h_dict: + if self.residual is not None: + alpha = th.sigmoid(self.alpha[ntype]) + out_h_dict[ntype] = out_h_dict[ntype] * alpha + self.residual[ntype]( + h_dict[ntype][: block.num_dst_nodes(ntype)]) * (1 - alpha) + if self.layer_norm is not None: + out_h_dict[ntype] = self.layer_norm[ntype](out_h_dict[ntype]) + if self.activation is not None: + out_h_dict[ntype] = self.activation(out_h_dict[ntype]) + out_h_dict[ntype] = self.dropout(out_h_dict[ntype]) + + return out_h_dict + + +class MECCH(nn.Module): + def __init__( + self, + g, + metapaths_dict, + in_dim_dict, + hidden_dim, + out_dim, + n_layers, + n_heads_list, + dropout=0.5, + context_encoder="mean", + use_v=False, + metapath_fusion="conv", + residual=False, + layer_norm=True + ): + super(MECCH, self).__init__() + + self.in_dim_dict = in_dim_dict + self.n_layers = n_layers + + n_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes if in_dim_dict[ntype] < 0} + self.embed_layer = HeteroEmbedLayer(n_nodes_dict, hidden_dim) + self.linear_layer = HeteroLinearLayer(in_dim_dict, hidden_dim) + + self.MECCH_layers = nn.ModuleList() + for i in range(n_layers - 1): + self.MECCH_layers.append( + MECCHLayer( + metapaths_dict, + hidden_dim, + hidden_dim, + n_heads_list[i], + dropout=dropout, + context_encoder=context_encoder, + use_v=use_v, + metapath_fusion=metapath_fusion, + residual=residual, + layer_norm=layer_norm, + activation=F.relu, + ) + ) + self.MECCH_layers.append( + MECCHLayer( + metapaths_dict, + hidden_dim, + out_dim, + n_heads_list[-1], + dropout=0.0, + context_encoder=context_encoder, + use_v=use_v, + metapath_fusion=metapath_fusion, + residual=residual, + layer_norm=False, + activation=None, + ) + ) + + def forward(self, blocks, x_dict): + nids_dict = {ntype: nids for ntype, nids in blocks[0].srcdata[dgl.NID].items() if self.in_dim_dict[ntype] < 0} + + # ntype-specific embedding/projection + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + + for block, layer in zip(blocks, self.MECCH_layers): + h_dict = layer(block, h_dict) + + return h_dict + + +class khopMECCHLayer(nn.Module): + def __init__( + self, + ntypes, + in_dim, + out_dim, + dropout=0.5, + residual=False, + layer_norm=False, + activation=None + ): + super(khopMECCHLayer, self).__init__() + + self.out_dim = out_dim + self.dropout = nn.Dropout(dropout) + self.linear = nn.Linear(in_dim, out_dim) + + if residual: + self.alpha = nn.ParameterDict() + self.residual = nn.ModuleDict() + for ntype in ntypes: + self.alpha[ntype] = nn.Parameter(th.tensor(0.)) + if in_dim == out_dim: + self.residual[ntype] = nn.Identity() + else: + self.residual[ntype] = nn.Linear(in_dim, out_dim, bias=False) + else: + self.residual = None + if layer_norm: + self.layer_norm = nn.ModuleDict() + for ntype in ntypes: + self.layer_norm[ntype] = nn.LayerNorm(out_dim) + else: + self.layer_norm = None + self.activation = activation + + def forward(self, block, h_dict): + with block.local_scope(): + for ntype in block.srctypes: + if block.num_src_nodes(ntype) > 0: + block.srcnodes[ntype].data["h_src"] = h_dict[ntype] + block.dstnodes[ntype].data["h_dst"] = h_dict[ntype][:block.num_dst_nodes(ntype)] + funcs = {} + num_neigh = {ntype: 0 for ntype in block.dsttypes} + for etype in block.canonical_etypes: + if block.num_edges(etype=etype): + _, _, ntype = etype + funcs[etype] = (fn.copy_u("h_src", "m"), fn.sum("m", "h_neigh")) + num_neigh[ntype] = num_neigh[ntype] + block.in_degrees(etype=etype) + block.multi_update_all(funcs, "sum") + + out_h_dict = {} + for ntype in block.dsttypes: + if block.num_dst_nodes(ntype) > 0: + out_h_dict[ntype] = (block.dstnodes[ntype].data["h_neigh"] + block.dstnodes[ntype].data[ + "h_dst"]) / th.unsqueeze(num_neigh[ntype] + 1, dim=-1) + out_h_dict[ntype] = self.linear(out_h_dict[ntype]) + if self.residual is not None: + alpha = th.sigmoid(self.alpha[ntype]) + out_h_dict[ntype] = out_h_dict[ntype] * alpha + self.residual[ntype]( + h_dict[ntype][: block.num_dst_nodes(ntype)]) * (1 - alpha) + if self.layer_norm is not None: + out_h_dict[ntype] = self.layer_norm[ntype](out_h_dict[ntype]) + if self.activation is not None: + out_h_dict[ntype] = self.activation(out_h_dict[ntype]) + out_h_dict[ntype] = self.dropout(out_h_dict[ntype]) + + return out_h_dict + + +class khopMECCH(nn.Module): + def __init__( + self, + g, + in_dim_dict, + hidden_dim, + out_dim, + n_layers, + dropout=0.5, + residual=False, + layer_norm=True + ): + super(khopMECCH, self).__init__() + + self.in_dim_dict = in_dim_dict + self.n_layers = n_layers + + n_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes if in_dim_dict[ntype] < 0} + self.embed_layer = HeteroEmbedLayer(n_nodes_dict, hidden_dim) + self.linear_layer = HeteroLinearLayer(in_dim_dict, hidden_dim) + + self.khopMECCH_layers = nn.ModuleList() + for i in range(n_layers - 1): + self.khopMECCH_layers.append( + khopMECCHLayer( + g.ntypes, + hidden_dim, + hidden_dim, + dropout=dropout, + residual=residual, + layer_norm=layer_norm, + activation=F.relu, + ) + ) + self.khopMECCH_layers.append( + khopMECCHLayer( + g.ntypes, + hidden_dim, + out_dim, + dropout=0.0, + residual=residual, + layer_norm=False, + activation=None, + ) + ) + + def forward(self, blocks, x_dict): + nids_dict = {ntype: nids for ntype, nids in blocks[0].srcdata[dgl.NID].items() if self.in_dim_dict[ntype] < 0} + + # ntype-specific embedding/projection + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + + for block, layer in zip(blocks, self.khopMECCH_layers): + h_dict = layer(block, h_dict) + + return h_dict diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/baselines/HAN.py b/model/baselines/HAN.py new file mode 100644 index 0000000..d80f2a6 --- /dev/null +++ b/model/baselines/HAN.py @@ -0,0 +1,154 @@ +"""This model shows an example of using dgl.metapath_reachable_graph on the original heterogeneous +graph. + +Because the original HAN implementation only gives the preprocessed homogeneous graph, this model +could not reproduce the result in HAN as they did not provide the preprocessing code, and we +constructed another dataset from ACM with a different set of papers, connections, features and +labels. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl +from dgl.nn.pytorch import GATConv + + +class SemanticAttention(nn.Module): + def __init__(self, in_size, hidden_size=128): + super(SemanticAttention, self).__init__() + + self.project = nn.Sequential( + nn.Linear(in_size, hidden_size), + nn.Tanh(), + nn.Linear(hidden_size, 1, bias=False) + ) + + def forward(self, z): + w = self.project(z).mean(0) # (M, 1) + beta = torch.softmax(w, dim=0) # (M, 1) + beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1) + + return (beta * z).sum(1) # (N, D * K) + + +class HANLayer(nn.Module): + """ + HAN layer. + + Arguments + --------- + meta_paths : list of metapaths, each as a list of edge types + in_size : input feature dimension + out_size : output feature dimension + layer_num_heads : number of attention heads + dropout : Dropout probability + + Inputs + ------ + g : DGLHeteroGraph + The heterogeneous graph + h : tensor + Input features + + Outputs + ------- + tensor + The output feature + """ + + def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout): + super(HANLayer, self).__init__() + + # One GAT layer for each meta path based adjacency matrix + self.gat_layers = nn.ModuleList() + for i in range(len(meta_paths)): + self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads, + dropout, dropout, activation=F.elu, + allow_zero_in_degree=True)) + self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads) + self.meta_paths = list( + tuple(tuple(canonical_etype) for canonical_etype in meta_path) for meta_path in meta_paths) + + self._cached_graph = None + self._cached_coalesced_graph = {} + + def forward(self, g, h): + semantic_embeddings = [] + + if self._cached_graph is None or self._cached_graph is not g: + self._cached_graph = g + self._cached_coalesced_graph.clear() + for meta_path in self.meta_paths: + self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph( + g, meta_path) + + for i, meta_path in enumerate(self.meta_paths): + new_g = self._cached_coalesced_graph[meta_path] + semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1)) + semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K) + + return self.semantic_attention(semantic_embeddings) # (N, D * K) + + +class HAN(nn.Module): + def __init__(self, meta_paths, target_ntype, in_size, hidden_size, out_size, num_heads, dropout): + super(HAN, self).__init__() + + self.target_ntype = target_ntype + + self.layers = nn.ModuleList() + self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)) + for l in range(1, len(num_heads)): + self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l - 1], + hidden_size, num_heads[l], dropout)) + self.predict = nn.Linear(hidden_size * num_heads[-1], out_size) + + def forward(self, g, x_dict): + h = x_dict[self.target_ntype] + + for gnn in self.layers: + h = gnn(g, h) + + return {self.target_ntype: self.predict(h)} + + +class HAN_lp(nn.Module): + def __init__(self, g, metapaths_u, target_ntype_u, in_size_u, metapaths_v, target_ntype_v, in_size_v, hidden_size, + out_size, num_heads, dropout): + super(HAN_lp, self).__init__() + self.target_ntype_u = target_ntype_u + self.target_ntype_v = target_ntype_v + self.r = nn.Parameter(torch.Tensor(out_size)) + nn.init.ones_(self.r) + # initial node embeddings + if in_size_u < 0: + in_size_u = hidden_size * num_heads[0] + self.feats_u = nn.Parameter(torch.Tensor(g.num_nodes(target_ntype_u), in_size_u)) + nn.init.xavier_normal(self.feats_u) + if in_size_v < 0: + in_size_v = hidden_size * num_heads[0] + self.feats_v = nn.Parameter(torch.Tensor(g.num_nodes(target_ntype_v), in_size_v)) + nn.init.xavier_normal(self.feats_v) + + self.model_u = HAN(metapaths_u, target_ntype_u, in_size_u, hidden_size, out_size, num_heads, dropout) + self.model_v = HAN(metapaths_v, target_ntype_v, in_size_v, hidden_size, out_size, num_heads, dropout) + + def forward(self, pos_edges, neg_edges, g, x_dict): + # set initial node embeddings + if hasattr(self, 'feats_u'): + x_dict_u = {self.target_ntype_u: self.feats_u} + else: + x_dict_u = {self.target_ntype_u: x_dict[self.target_ntype_u]} + if hasattr(self, 'feats_v'): + x_dict_v = {self.target_ntype_v: self.feats_v} + else: + x_dict_v = {self.target_ntype_v: x_dict[self.target_ntype_v]} + + h_u = self.model_u(g, x_dict_u)[self.target_ntype_u] + h_v = self.model_v(g, x_dict_v)[self.target_ntype_v] + + pos_score = torch.sum(h_u[pos_edges[:, 0]] * self.r * h_v[pos_edges[:, 1]], dim=-1) + neg_score = torch.sum(h_u[neg_edges[:, 0]] * self.r * h_v[neg_edges[:, 1]], dim=-1) + + return pos_score, neg_score diff --git a/model/baselines/HGT.py b/model/baselines/HGT.py new file mode 100644 index 0000000..628aaaf --- /dev/null +++ b/model/baselines/HGT.py @@ -0,0 +1,213 @@ +import math + +import dgl +import dgl.function as fn +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl.nn.functional import edge_softmax + +from model.modules import HeteroEmbedLayer, HeteroLinearLayer + + +def softmax_among_all(hetgraph, weight_name, norm_by="dst"): + """Normalize edge weights by softmax across all neighbors. + + Parameters + ------------- + hetgraph : DGLGraph + The input heterogeneous graph. + weight_name : str + The name of the unnormalized edge weights. + """ + # Convert to homogeneous graph; DGL will copy the specified data to the new graph. + g = dgl.to_homogeneous(hetgraph, edata=[weight_name]) + # Call DGL's edge softmax + g.edata[weight_name] = edge_softmax(g, g.edata[weight_name], norm_by=norm_by) + # Convert it back; DGL again copies the data back to a heterogeneous storage. + hetg2 = dgl.to_heterogeneous(g, hetgraph.ntypes, hetgraph.etypes) + # Assign the normalized weights to the original graph + for etype in hetg2.canonical_etypes: + hetgraph.edges[etype].data[weight_name] = hetg2.edges[etype].data[weight_name] + + +class HGTLayer(nn.Module): + def __init__( + self, + in_dim, + out_dim, + node_dict, + edge_dict, + n_heads, + dropout=0.2, + use_norm=False, + ): + super(HGTLayer, self).__init__() + + self.in_dim = in_dim + self.out_dim = out_dim + self.node_dict = node_dict + self.edge_dict = edge_dict + self.num_types = len(node_dict) + self.num_relations = len(edge_dict) + self.total_rel = self.num_types * self.num_relations * self.num_types + self.n_heads = n_heads + self.d_k = out_dim // n_heads + self.sqrt_dk = math.sqrt(self.d_k) + self.att = None + + self.k_linears = nn.ModuleList() + self.q_linears = nn.ModuleList() + self.v_linears = nn.ModuleList() + self.a_linears = nn.ModuleList() + self.norms = nn.ModuleList() + self.use_norm = use_norm + + for t in range(self.num_types): + self.k_linears.append(nn.Linear(in_dim, out_dim)) + self.q_linears.append(nn.Linear(in_dim, out_dim)) + self.v_linears.append(nn.Linear(in_dim, out_dim)) + self.a_linears.append(nn.Linear(out_dim, out_dim)) + if use_norm: + self.norms.append(nn.LayerNorm(out_dim)) + + self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads)) + self.relation_att = nn.Parameter( + torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k) + ) + self.relation_msg = nn.Parameter( + torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k) + ) + self.skip = nn.Parameter(torch.ones(self.num_types)) + self.drop = nn.Dropout(dropout) + + nn.init.xavier_uniform_(self.relation_att) + nn.init.xavier_uniform_(self.relation_msg) + + def forward(self, G, h): + with G.local_scope(): + node_dict, edge_dict = self.node_dict, self.edge_dict + for srctype, etype, dsttype in G.canonical_etypes: + sub_graph = G[srctype, etype, dsttype] + + k_linear = self.k_linears[node_dict[srctype]] + v_linear = self.v_linears[node_dict[srctype]] + q_linear = self.q_linears[node_dict[dsttype]] + + k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k) + v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k) + q = q_linear(h[dsttype][:sub_graph.num_dst_nodes(dsttype)]).view(-1, self.n_heads, self.d_k) + + e_id = self.edge_dict[etype] + + relation_att = self.relation_att[e_id] + relation_pri = self.relation_pri[e_id] + relation_msg = self.relation_msg[e_id] + + k = torch.einsum("bij,ijk->bik", k, relation_att) + v = torch.einsum("bij,ijk->bik", v, relation_msg) + + sub_graph.srcdata["k"] = k + sub_graph.dstdata["q"] = q + sub_graph.srcdata["v_%d" % e_id] = v + + sub_graph.apply_edges(fn.v_dot_u("q", "k", "t")) + sub_graph.edata["t"] = sub_graph.edata["t"] * relation_pri / self.sqrt_dk + + softmax_among_all(G, "t", norm_by="dst") + + G.multi_update_all( + { + etype: (fn.u_mul_e("v_%d" % e_id, "t", "m"), fn.sum("m", "t")) + for etype, e_id in edge_dict.items() + }, + cross_reducer="sum", + ) + + new_h = {} + for ntype in G.dsttypes: + """ + Step 3: Target-specific Aggregation + x = norm( W[node_type] * gelu( Agg(x) ) + x ) + """ + if G.num_dst_nodes(ntype) > 0: + n_id = node_dict[ntype] + alpha = torch.sigmoid(self.skip[n_id]) + t = G.dstnodes[ntype].data["t"].view(-1, self.out_dim) + trans_out = self.drop(self.a_linears[n_id](t)) + trans_out = trans_out * alpha + h[ntype][:G.num_dst_nodes(ntype)] * (1 - alpha) + if self.use_norm: + new_h[ntype] = self.norms[n_id](trans_out) + else: + new_h[ntype] = trans_out + return new_h + + +class HGT(nn.Module): + def __init__(self, G, in_dim_dict, n_hid, n_out, n_layers, n_heads, use_norm=True): + super(HGT, self).__init__() + self.node_dict = {} + self.edge_dict = {} + for i, ntype in enumerate(G.ntypes): + self.node_dict[ntype] = i + for i, etype in enumerate(G.etypes): + self.edge_dict[etype] = i + self.gcs = nn.ModuleList() + self.in_dim_dict = in_dim_dict + self.n_hid = n_hid + self.n_out = n_out + self.n_layers = n_layers + + # input projection + n_nodes_dict = { + ntype: G.num_nodes(ntype) for ntype in G.ntypes if in_dim_dict[ntype] < 0 + } + self.embed_layer = HeteroEmbedLayer(n_nodes_dict, n_hid) + self.linear_layer = HeteroLinearLayer(in_dim_dict, n_hid) + + for _ in range(n_layers): + self.gcs.append( + HGTLayer( + n_hid, + n_hid, + self.node_dict, + self.edge_dict, + n_heads, + use_norm=use_norm, + ) + ) + self.out = nn.Linear(n_hid, n_out) + + def forward(self, G, x_dict): + h_dict = {} + if isinstance(G, list): + # minibatch + nids_dict = { + ntype: nids + for ntype, nids in G[0].srcdata[dgl.NID].items() + if self.in_dim_dict[ntype] < 0 + } + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + for ntype in h_dict: + h_dict[ntype] = F.gelu(h_dict[ntype]) + + for layer, block in zip(self.gcs, G): + h_dict = layer(block, h_dict) + else: + # full batch + nids_dict = { + ntype: G.nodes(ntype) + for ntype in G.ntypes + if self.in_dim_dict[ntype] < 0 + } + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + for ntype in h_dict: + h_dict[ntype] = F.gelu(h_dict[ntype]) + + for i in range(self.n_layers): + h_dict = self.gcs[i](G, h_dict) + return {ntype: self.out(h) for ntype, h in h_dict.items()} diff --git a/model/baselines/RGCN.py b/model/baselines/RGCN.py new file mode 100644 index 0000000..90ccc3d --- /dev/null +++ b/model/baselines/RGCN.py @@ -0,0 +1,239 @@ +"""RGCN Example Implementation from DGL""" +import dgl +import dgl.nn as dglnn +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from model.modules import HeteroEmbedLayer, HeteroLinearLayer + + +class RelGraphConvLayer(nn.Module): + r"""Relational graph convolution layer. + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + rel_names : list[str] + Relation names. + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: None. + weight : bool, optional + True if a linear layer is applied after message passing. Default: True + bias : bool, optional + True if bias is added. Default: True + activation : callable, optional + Activation function. Default: None + self_loop : bool, optional + True to include self loop message. Default: False + dropout : float, optional + Dropout rate. Default: 0.0 + """ + + def __init__( + self, + in_feat, + out_feat, + rel_names, + num_bases, + *, + weight=True, + bias=True, + activation=None, + self_loop=False, + dropout=0.0 + ): + super(RelGraphConvLayer, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.rel_names = rel_names + self.num_bases = num_bases + self.bias = bias + self.activation = activation + self.self_loop = self_loop + + self.conv = dglnn.HeteroGraphConv( + { + rel: dglnn.GraphConv( + in_feat, out_feat, norm="right", weight=False, bias=False + ) + for rel in rel_names + } + ) + + self.use_weight = weight + self.use_basis = num_bases < len(self.rel_names) and weight + if self.use_weight: + if self.use_basis: + self.basis = dglnn.WeightBasis( + (in_feat, out_feat), num_bases, len(self.rel_names) + ) + else: + self.weight = nn.Parameter( + th.Tensor(len(self.rel_names), in_feat, out_feat) + ) + nn.init.xavier_uniform_( + self.weight, gain=nn.init.calculate_gain("relu") + ) + + # bias + if bias: + self.h_bias = nn.Parameter(th.Tensor(out_feat)) + nn.init.zeros_(self.h_bias) + + # weight for self loop + if self.self_loop: + self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) + nn.init.xavier_uniform_( + self.loop_weight, gain=nn.init.calculate_gain("relu") + ) + + self.dropout = nn.Dropout(dropout) + + def forward(self, g, inputs): + """Forward computation + Parameters + ---------- + g : DGLHeteroGraph + Input graph. + inputs : dict[str, torch.Tensor] + Node feature for each node type. + Returns + ------- + dict[str, torch.Tensor] + New node features for each node type. + """ + g = g.local_var() + if self.use_weight: + weight = self.basis() if self.use_basis else self.weight + wdict = { + self.rel_names[i]: {"weight": w.squeeze(0)} + for i, w in enumerate(th.split(weight, 1, dim=0)) + } + else: + wdict = {} + + if g.is_block: + inputs_src = inputs + inputs_dst = {k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()} + else: + inputs_src = inputs_dst = inputs + + hs = self.conv(g, inputs, mod_kwargs=wdict) + + def _apply(ntype, h): + if self.self_loop: + h = h + th.matmul(inputs_dst[ntype], self.loop_weight) + if self.bias: + h = h + self.h_bias + if self.activation: + h = self.activation(h) + return self.dropout(h) + + return {ntype: _apply(ntype, h) for ntype, h in hs.items()} + + +class RGCN(nn.Module): + def __init__( + self, + g, + in_dim_dict, + h_dim, + out_dim, + num_bases, + num_hidden_layers=1, + dropout=0, + use_self_loop=False, + ): + super(RGCN, self).__init__() + + self.g = g + self.in_dim_dict = in_dim_dict + self.h_dim = h_dim + self.out_dim = out_dim + self.rel_names = list(set(g.etypes)) + self.rel_names.sort() + if num_bases < 0 or num_bases > len(self.rel_names): + self.num_bases = len(self.rel_names) + else: + self.num_bases = num_bases + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.use_self_loop = use_self_loop + + # input projection + n_nodes_dict = { + ntype: g.num_nodes(ntype) for ntype in g.ntypes if in_dim_dict[ntype] < 0 + } + self.embed_layer = HeteroEmbedLayer(n_nodes_dict, h_dim) + self.linear_layer = HeteroLinearLayer(in_dim_dict, h_dim) + self.layers = nn.ModuleList() + # i2h + self.layers.append( + RelGraphConvLayer( + self.h_dim, + self.h_dim, + self.rel_names, + self.num_bases, + activation=F.relu, + self_loop=self.use_self_loop, + dropout=self.dropout, + weight=False, + ) + ) + # h2h + for i in range(self.num_hidden_layers): + self.layers.append( + RelGraphConvLayer( + self.h_dim, + self.h_dim, + self.rel_names, + self.num_bases, + activation=F.relu, + self_loop=self.use_self_loop, + dropout=self.dropout, + ) + ) + # h2o + self.layers.append( + RelGraphConvLayer( + self.h_dim, + self.out_dim, + self.rel_names, + self.num_bases, + activation=None, + self_loop=self.use_self_loop, + ) + ) + + def forward(self, g=None, x_dict=None): + if isinstance(g, list): + # minibatch forward + nids_dict = { + ntype: nids + for ntype, nids in g[0].srcdata[dgl.NID].items() + if self.in_dim_dict[ntype] < 0 + } + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + + for layer, block in zip(self.layers, g): + h_dict = layer(block, h_dict) + else: + # full graph forward + nids_dict = { + ntype: g.nodes(ntype) + for ntype in g.ntypes + if self.in_dim_dict[ntype] < 0 + } + h_embed_dict = self.embed_layer(nids_dict) + h_linear_dict = self.linear_layer(x_dict) + h_dict = h_embed_dict | h_linear_dict + + for layer in self.layers: + h_dict = layer(g, h_dict) + + return h_dict diff --git a/model/baselines/__init__.py b/model/baselines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/modules.py b/model/modules.py new file mode 100644 index 0000000..d6f8c7a --- /dev/null +++ b/model/modules.py @@ -0,0 +1,83 @@ +import torch as th +import torch.nn as nn +import dgl.function as fn + + +class HeteroEmbedLayer(nn.Module): + def __init__(self, n_nodes_dict, embed_size): + super(HeteroEmbedLayer, self).__init__() + + self.embed_size = embed_size + + # embed nodes for each node type + self.embeds = nn.ModuleDict() + for ntype, num in n_nodes_dict.items(): + self.embeds[ntype] = nn.Embedding(num, embed_size) + + def forward(self, nids_dict): + return {ntype: self.embeds[ntype](nids) for ntype, nids in nids_dict.items()} + + +class HeteroLinearLayer(nn.Module): + def __init__(self, in_dim_dict, out_dim, bias=True): + super(HeteroLinearLayer, self).__init__() + + self.out_dim = out_dim + + # linear projection for each node type + self.linears = nn.ModuleDict() + for ntype, in_dim in in_dim_dict.items(): + if in_dim > 0: + self.linears[ntype] = nn.Linear(in_dim, out_dim, bias) + + def forward(self, h_dict): + return {ntype: self.linears[ntype](h) for ntype, h in h_dict.items()} + + +class ScorePredictor(nn.Module): + def __init__(self, dim, target_etype): + super(ScorePredictor, self).__init__() + self.target_etype = target_etype + self.r = nn.Parameter(th.Tensor(dim)) + nn.init.ones_(self.r) + + def forward(self, edge_subgraph, h_dict): + with edge_subgraph.local_scope(): + edge_subgraph.ndata['h'] = h_dict + edge_subgraph.apply_edges(fn.u_mul_v('h', 'h', 'score'), etype=self.target_etype) + edge_subgraph.edges[self.target_etype].data['score'] = th.sum( + edge_subgraph.edges[self.target_etype].data['score'] * self.r, dim=-1) + return edge_subgraph.edges[self.target_etype].data['score'] + + +class LinkPrediction_minibatch(nn.Module): + def __init__(self, emb_model, emb_dim, target_etype): + super(LinkPrediction_minibatch, self).__init__() + self.emb_model = emb_model + self.pred = ScorePredictor(emb_dim, target_etype) + + def forward(self, positive_graph, negative_graph, blocks, x_dict): + h_dict = self.emb_model(blocks, x_dict) + pos_score = self.pred(positive_graph, h_dict) + neg_score = self.pred(negative_graph, h_dict) + return pos_score, neg_score + + +class LinkPrediction_fullbatch(nn.Module): + def __init__(self, emb_model, emb_dim, target_ntype_u, target_ntype_v): + super(LinkPrediction_fullbatch, self).__init__() + self.emb_model = emb_model + self.r = nn.Parameter(th.Tensor(emb_dim)) + nn.init.ones_(self.r) + self.target_ntype_u = target_ntype_u + self.target_ntype_v = target_ntype_v + + def forward(self, positive_edges, negative_edges, g, x_dict): + h_dict = self.emb_model(g, x_dict) + h_u = h_dict[self.target_ntype_u] + h_v = h_dict[self.target_ntype_v] + + pos_score = th.sum(h_u[positive_edges[:, 0]] * self.r * h_v[positive_edges[:, 1]], dim=-1) + neg_score = th.sum(h_u[negative_edges[:, 0]] * self.r * h_v[negative_edges[:, 1]], dim=-1) + + return pos_score, neg_score diff --git a/model/utils.py b/model/utils.py new file mode 100644 index 0000000..0e2a37b --- /dev/null +++ b/model/utils.py @@ -0,0 +1,12 @@ +join_token = "=>" + + +def sub_metapaths(metapath_str): + tokens = metapath_str[3:].split(join_token) + return ["mp:" + join_token.join(tokens[i:]) for i in range(0, len(tokens) - 2, 2)] + + +def get_src_ntypes(metapath_str): + tokens = metapath_str[3:].split(join_token) + src_ntypes = [tokens[i] for i in range(0, len(tokens) - 2, 2)] + return list(set(src_ntypes)) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..80f57b2 --- /dev/null +++ b/utils.py @@ -0,0 +1,264 @@ +import pickle +from collections import defaultdict +from pathlib import Path +import shutil + +import dgl +import numpy as np +import torch as th + + +def get_all_metapaths(g, min_length=1, max_length=4): + etype_dict = {} + for src, e, dst in g.canonical_etypes: + if src in etype_dict: + etype_dict[src].append((e, dst)) + else: + etype_dict[src] = [(e, dst)] + + metapath_dict = {src: {i + 1: [] for i in range(max_length)} for src in etype_dict} + for src in etype_dict: + metapath_dict[src][1].extend([(src, e[0], e[1]) for e in etype_dict[src]]) + for i in range(1, max_length): + metapath_dict[src][i + 1].extend( + [mp + (e[0], e[1]) for mp in metapath_dict[src][i] for e in etype_dict[mp[-1]]]) + + return {src: {length: metapath_dict[src][length] for length in metapath_dict[src] if length >= min_length} for src + in metapath_dict} + + +def metapath_dict2list(metapath_dict): + return [ + [mp[i - 1: i + 2] for i in range(1, len(mp), 2)] + for src in metapath_dict + for length in metapath_dict[src] + for mp in metapath_dict[src][length] + ] + + +join_token = "=>" + + +def metapath2str(metapath): + metapath_str = "mp:" + join_token.join(metapath[0]) + for src_ntype, etype, dst_ntype in metapath[1:]: + metapath_str += join_token + etype + join_token + dst_ntype + return metapath_str + + +# Assume metapath_g already contains all the nodes +def add_metapath_connection(g, metapath, metapath_g, add_reverse=False): + if metapath_g is not None: + graph_data = {e: metapath_g.edges(etype=e) for e in metapath_g.canonical_etypes} + else: + graph_data = {} + num_nodes = {n: g.num_nodes(n) for n in g.ntypes} + + new_g = dgl.metapath_reachable_graph(g, metapath) + src_nodes = new_g.edges()[0] + dst_nodes = new_g.edges()[1] + canonical_etype = (new_g.srctypes[0], metapath2str(metapath), new_g.dsttypes[0]) + graph_data[canonical_etype] = (src_nodes, dst_nodes) + if add_reverse: + src_nodes, dst_nodes = dst_nodes, src_nodes + canonical_etype = (new_g.dsttypes[0], "mp:" + "rev:" + join_token.join(metapath), new_g.srctypes[0]) + graph_data[canonical_etype] = (src_nodes, dst_nodes) + + return dgl.heterograph(graph_data, num_nodes) + + +def select_metapaths(all_metapaths_list, length=4): + # select only max-length metapath + selected_metapaths = defaultdict(list) + for mp in all_metapaths_list: + if len(mp) == length: + selected_metapaths[mp[-1][-1]].append(metapath2str(mp)) + return dict(selected_metapaths) + + +def load_data_nc(dataset_name, prefix="./data"): + if dataset_name == "imdb-gtn": + # movie*, actor, director + glist, _ = dgl.load_graphs(str(Path(prefix, dataset_name, "graph.bin"))) + g = glist[0] + + x = g.ndata.pop('h') + y = g.ndata.pop('label') + train_mask = g.ndata.pop('train_mask') + val_mask = g.ndata.pop('valid_mask') + test_mask = g.ndata.pop('test_mask') + + g = g.long() + g.nodes['movie'].data['x'] = x['movie'].float() + g.nodes['actor'].data['x'] = x['actor'].float() + g.nodes['director'].data['x'] = x['director'].float() + g.nodes['movie'].data['y'] = y['movie'].long() + + in_dim_dict = { + "movie": x["movie"].shape[1], + "actor": x["actor"].shape[1], + "director": x["director"].shape[1], + } + out_dim = y["movie"].max().item() + 1 + + train_nid_dict = { + "movie": train_mask["movie"].nonzero().flatten().long() + } + val_nid_dict = { + "movie": val_mask["movie"].nonzero().flatten().long() + } + test_nid_dict = { + "movie": test_mask["movie"].nonzero().flatten().long() + } + elif dataset_name == 'acm-gtn': + # paper*, author, subject + glist, _ = dgl.load_graphs(str(Path(prefix, dataset_name, "graph.bin"))) + g = glist[0] + + x = g.ndata.pop('h') + y = g.ndata.pop('label') + train_mask = g.ndata.pop('train_mask') + val_mask = g.ndata.pop('valid_mask') + test_mask = g.ndata.pop('test_mask') + g.ndata.pop('pspap_m2v_emb') + g.ndata.pop('psp_m2v_emb') + g.ndata.pop('pap_m2v_emb') + + g = g.long() + g.nodes['paper'].data['x'] = x['paper'].float() + g.nodes['author'].data['x'] = x['author'].float() + g.nodes['subject'].data['x'] = x['subject'].float() + g.nodes['paper'].data['y'] = y['paper'].long() + + in_dim_dict = { + "paper": x["paper"].shape[1], + "author": x["author"].shape[1], + "subject": x["subject"].shape[1], + } + out_dim = y["paper"].max().item() + 1 + + train_nid_dict = { + "paper": train_mask["paper"].nonzero().flatten().long() + } + val_nid_dict = { + "paper": val_mask["paper"].nonzero().flatten().long() + } + test_nid_dict = { + "paper": test_mask["paper"].nonzero().flatten().long() + } + elif dataset_name == 'dblp-gtn': + # paper, author*, conference + dir_path = Path(prefix, dataset_name) + edges = pickle.load(dir_path.joinpath("edges.pkl").open("rb")) + labels = pickle.load(dir_path.joinpath("labels.pkl").open("rb")) + node_features = pickle.load( + dir_path.joinpath("node_features.pkl").open("rb")) + + num_nodes = edges[0].shape[0] + node_type = np.zeros(num_nodes, dtype=int) + node_type[:] = -1 + + assert len(edges) == 4 + assert len(edges[0].nonzero()) == 2 + node_type[edges[0].nonzero()[0]] = 0 + node_type[edges[0].nonzero()[1]] = 1 + node_type[edges[1].nonzero()[0]] = 1 + node_type[edges[1].nonzero()[1]] = 0 + node_type[edges[2].nonzero()[0]] = 0 + node_type[edges[2].nonzero()[1]] = 2 + node_type[edges[3].nonzero()[0]] = 2 + node_type[edges[3].nonzero()[1]] = 0 + assert (node_type == -1).sum() == 0 + + data_dict = { + ('paper', 'paper-author', 'author'): edges[0][node_type == 0, :][:, node_type == 1].nonzero(), + ('author', 'author-paper', 'paper'): edges[1][node_type == 1, :][:, node_type == 0].nonzero(), + ('paper', 'paper-conference', 'conference'): edges[2][node_type == 0, :][:, node_type == 2].nonzero(), + ('conference', 'conference-paper', 'paper'): edges[3][node_type == 2, :][:, node_type == 0].nonzero() + } + num_nodes_dict = { + 'paper': (node_type == 0).sum(), + 'author': (node_type == 1).sum(), + 'conference': (node_type == 2).sum() + } + g = dgl.heterograph(data_dict, num_nodes_dict, idtype=th.int64) + + train_nid_dict = { + 'author': th.from_numpy(np.array(labels[0])[:, 0]).long() + } + val_nid_dict = { + 'author': th.from_numpy(np.array(labels[1])[:, 0]).long() + } + test_nid_dict = { + 'author': th.from_numpy(np.array(labels[2])[:, 0]).long() + } + + g.nodes['paper'].data['x'] = th.from_numpy( + node_features[node_type == 0]).float() + g.nodes['author'].data['x'] = th.from_numpy( + node_features[node_type == 1]).float() + g.nodes['conference'].data['x'] = th.from_numpy( + node_features[node_type == 2]).float() + + y = np.zeros((g.num_nodes('author')), dtype=int) + y[train_nid_dict['author'].numpy()] = np.array(labels[0])[:, 1] + y[val_nid_dict['author'].numpy()] = np.array(labels[1])[:, 1] + y[test_nid_dict['author'].numpy()] = np.array(labels[2])[:, 1] + g.nodes['author'].data['y'] = th.from_numpy(y).long() + + in_dim_dict = { + 'paper': g.nodes['paper'].data['x'].shape[1], + 'author': g.nodes['author'].data['x'].shape[1], + 'conference': g.nodes['conference'].data['x'].shape[1] + } + out_dim = g.nodes['author'].data['y'].max().item() + 1 + else: + raise NotImplementedError + + return g, in_dim_dict, out_dim, train_nid_dict, val_nid_dict, test_nid_dict + + +def load_data_lp(dataset_name, prefix="./data"): + if dataset_name == 'lastfm': + load_path = Path(prefix, dataset_name) + g_list, _ = dgl.load_graphs(str(load_path / 'graph.bin')) + g_train, g_val, g_test = g_list + train_val_test_idx = np.load(str(load_path / 'train_val_test_idx.npz')) + train_eid_dict = {'user-artist': th.tensor(train_val_test_idx['train_idx'])} + val_eid_dict = {'user-artist': th.tensor(train_val_test_idx['val_idx'])} + test_eid_dict = {'user-artist': th.tensor(train_val_test_idx['test_idx'])} + val_neg_uv = th.tensor(np.load(str(load_path / 'val_neg_user_artist.npy'))) + test_neg_uv = th.tensor(np.load(str(load_path / 'test_neg_user_artist.npy'))) + in_dim_dict = {ntype: -1 for ntype in g_test.ntypes} + else: + raise NotImplementedError + + g_train = g_train.long() + g_val = g_val.long() + g_test = g_test.long() + train_eid_dict = {k: v.long() for k, v in train_eid_dict.items()} + val_eid_dict = {k: v.long() for k, v in val_eid_dict.items()} + test_eid_dict = {k: v.long() for k, v in test_eid_dict.items()} + val_neg_uv = val_neg_uv.long() + test_neg_uv = test_neg_uv.long() + + return (g_train, g_val, g_test), in_dim_dict, (train_eid_dict, val_eid_dict, test_eid_dict), ( + val_neg_uv, test_neg_uv) + + +def get_save_path(args, prefix="./saves"): + dir_path = Path(prefix, args.model, args.dataset) + dir_path.mkdir(parents=True, exist_ok=True) + old_saves = [int(str(x.name)) for x in dir_path.iterdir() if x.is_dir() and str(x.name).isdigit()] + if len(old_saves) == 0: + save_num = 1 + else: + save_num = max(old_saves) + 1 + dir_path = dir_path / str(save_num) + dir_path.mkdir() + + # copy config files to the save dir + shutil.copy("./configs/base.json", dir_path) + shutil.copy(args.config, dir_path) + + return dir_path