Skip to content

Commit

Permalink
PubMed preprocessing code
Browse files Browse the repository at this point in the history
  • Loading branch information
cynricfu committed Jul 21, 2022
1 parent 183ef07 commit 6cf4aab
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
7 changes: 7 additions & 0 deletions data/pubmed/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## LastFM Dataset

From [HNE](https://github.com/yangji9181/HNE).

We further split the original training set into training and validation sets.

We samples hard negative edges for validtaion and testing by ourselves.
291 changes: 291 additions & 0 deletions preprocess_pubmed.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 101,
"id": "caacdd92-ba53-4aa9-bddd-94f034a710bd",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import csv\n",
"import random\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import dgl\n",
"import torch as th\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "bd7349d8-c2cd-4c74-bc76-e69b9817f2ba",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"load_path = Path('data/pubmed/raw')\n",
"df_node = pd.read_csv(load_path / 'node.dat', sep='\\t', names=['node_id', 'node_name', 'node_type', 'node_attributes'], quoting=csv.QUOTE_NONE) # must add this quoting argument\n",
"df_link = pd.read_csv(load_path / 'link.dat', sep='\\t', names=['src_id', 'dst_id', 'link_type', 'link_weight'])\n",
"df_link_test = pd.read_csv(load_path / 'link.dat.test', sep='\\t', names=['src_id', 'dst_id', 'link_status'])\n",
"\n",
"ntypes = ['GENE', 'DISEASE', 'CHEMICAL', 'SPECIES']\n",
"ntype_ids = {'GENE': 0, 'DISEASE': 1, 'CHEMICAL': 2, 'SPECIES': 3}\n",
"etypes = ['GENE-and-GENE',\n",
" 'GENE-causing-DISEASE',\n",
" 'DISEASE-and-DISEASE',\n",
" 'CHEMICAL-in-GENE',\n",
" 'CHEMICAL-in-DISEASE',\n",
" 'CHEMICAL-and-CHEMICAL',\n",
" 'CHEMICAL-in-SPECIES',\n",
" 'SPECIES-with-GENE',\n",
" 'SPECIES-with-DISEASE',\n",
" 'SPECIES-and-SPECIES']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "56815267-cc18-454f-8c43-003360629e39",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"type_mask = np.zeros(63109, dtype=int)\n",
"type_mask[:] = -1\n",
"type_mask[df_node['node_id']] = df_node['node_type']\n",
"num_nodes_dict = {ntype: (type_mask == ntype_ids[ntype]).sum() for ntype in ntypes}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6d3037ed-c5e2-4352-8e50-95be5c790f52",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"adjM = np.zeros((63109, 63109), dtype=int)\n",
"\n",
"# edges from link.dat\n",
"links = df_link[['src_id', 'dst_id']].to_numpy()\n",
"adjM[links[:, 0], links[:, 1]] = 1\n",
"\n",
"# positive edges from link.dat.test\n",
"links_test = df_link_test[df_link_test['link_status'] == 1].to_numpy()\n",
"adjM[links_test[:, 0], links_test[:, 1]] = 2\n",
"\n",
"# DISEASE-DISEASE matrix\n",
"DD_adjM = adjM[type_mask == ntype_ids['DISEASE']][:, type_mask == ntype_ids['DISEASE']]\n",
"DD_edges = DD_adjM.nonzero()\n",
"train_val_idx = (DD_adjM[DD_edges] == 1).nonzero()[0]\n",
"test_idx = (DD_adjM[DD_edges] == 2).nonzero()[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7c45c341-f994-4f8f-8f7c-957df04daf76",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# sample edges for validation\n",
"train_idx, val_idx = train_test_split(train_val_idx, test_size=0.125, random_state=1024)\n",
"train_idx.sort()\n",
"val_idx.sort()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72984f30-a2dd-490c-851b-77f42a7bf81e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# g_train\n",
"data_dict_train = {}\n",
"for etype in etypes:\n",
" srctype, _, dsttype = etype.split('-')\n",
" if srctype == dsttype:\n",
" if etype == 'DISEASE-and-DISEASE':\n",
" data_dict_train[('DISEASE', 'DISEASE-and-DISEASE', 'DISEASE')] = (DD_edges[0][train_idx], DD_edges[1][train_idx])\n",
" else:\n",
" data_dict_train[(srctype, etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[srctype]].nonzero()\n",
" else:\n",
" data_dict_train[(srctype, etype, dsttype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].nonzero()\n",
" data_dict_train[(dsttype, '^' + etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].transpose().nonzero()\n",
"\n",
"# g_val\n",
"data_dict_val = {}\n",
"for etype in etypes:\n",
" srctype, _, dsttype = etype.split('-')\n",
" if srctype == dsttype:\n",
" if etype == 'DISEASE-and-DISEASE':\n",
" data_dict_val[('DISEASE', 'DISEASE-and-DISEASE', 'DISEASE')] = (DD_edges[0][train_val_idx], DD_edges[1][train_val_idx])\n",
" else:\n",
" data_dict_val[(srctype, etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[srctype]].nonzero()\n",
" else:\n",
" data_dict_val[(srctype, etype, dsttype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].nonzero()\n",
" data_dict_val[(dsttype, '^' + etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].transpose().nonzero()\n",
"\n",
"# g_test\n",
"data_dict_test = {}\n",
"for etype in etypes:\n",
" srctype, _, dsttype = etype.split('-')\n",
" if srctype == dsttype:\n",
" if etype == 'DISEASE-and-DISEASE':\n",
" data_dict_test[('DISEASE', 'DISEASE-and-DISEASE', 'DISEASE')] = DD_edges\n",
" else:\n",
" data_dict_test[(srctype, etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[srctype]].nonzero()\n",
" else:\n",
" data_dict_test[(srctype, etype, dsttype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].nonzero()\n",
" data_dict_test[(dsttype, '^' + etype, srctype)] = adjM[type_mask == ntype_ids[srctype]][:, type_mask == ntype_ids[dsttype]].transpose().nonzero()\n",
"\n",
"g_train = dgl.heterograph(data_dict_train, num_nodes_dict, idtype=th.int64)\n",
"g_val = dgl.heterograph(data_dict_val, num_nodes_dict, idtype=th.int64)\n",
"g_test = dgl.heterograph(data_dict_test, num_nodes_dict, idtype=th.int64)\n",
"\n",
"x_dict = {ntype: np.genfromtxt(df_node[df_node['node_type'] == ntype_ids[ntype]]['node_attributes'].tolist(), delimiter=',') for ntype in ntypes}\n",
"for ntype in ntypes:\n",
" temp_tensor = th.from_numpy(x_dict[ntype]).float()\n",
" g_train.nodes[ntype].data['x'] = temp_tensor\n",
" g_val.nodes[ntype].data['x'] = temp_tensor\n",
" g_test.nodes[ntype].data['x'] = temp_tensor\n",
"\n",
"in_dim_dict = {ntype: x_dict[ntype].shape[1] for ntype in ntypes}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f559f20d-c11c-4348-826d-0b959aef737d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# sample hard negatives for validation\n",
"# for a positive pair (d1, d2), sample (d1, d') as a negative pair\n",
"# where d' associates with d1 via D-D-D metapath (i.e., a 2-hop neighbor) and (d1, d') does not exist in the original graph\n",
"# if such an edge not exist, opt to random edges not in the original graph\n",
"unique, counts = np.unique(DD_edges[0][val_idx], return_counts=True)\n",
"g_val_DDD = dgl.metapath_reachable_graph(g_val, ['DISEASE-and-DISEASE', 'DISEASE-and-DISEASE'])\n",
"val_neg_edges = []\n",
"for d, count in zip(unique, counts):\n",
" neg_diseases = list(set(g_val_DDD.out_edges(d)[1].tolist()) - set(g_val.out_edges(d, etype='DISEASE-and-DISEASE')[1].tolist()))\n",
" if count <= len(neg_diseases):\n",
" neg_ds = random.sample(neg_diseases, k=count)\n",
" else:\n",
" neg_ds = neg_diseases\n",
" to_sample = np.ones(num_nodes_dict['DISEASE'], dtype=int)\n",
" to_sample[g_val.out_edges(d, etype='DISEASE-and-DISEASE')[1].tolist()] = 0\n",
" neg_ds.extend(random.sample(to_sample.nonzero()[0].tolist(), k=count - len(neg_diseases)))\n",
" val_neg_edges.extend([[d, neg_d] for neg_d in neg_ds])\n",
"\n",
"val_neg_edges.sort()\n",
"val_neg_edges = np.array(val_neg_edges)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77493e3a-d4ff-4894-b319-5b3b4443b59e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# fix testing negative pairs\n",
"# randomly sample edges not in the original graph\n",
"DISEASE_id_map = {global_id: local_id for local_id, global_id in enumerate(df_node[df_node['node_type'] == ntype_ids['DISEASE']]['node_id'])}\n",
"df_link_test_mapped = df_link_test[['src_id', 'dst_id']].replace(DISEASE_id_map)\n",
"df_link_test_mapped_pos = df_link_test_mapped[df_link_test['link_status'] == 1]\n",
"df_link_test_mapped_neg = df_link_test_mapped[df_link_test['link_status'] == 0]\n",
"\n",
"test_pos_nums = df_link_test_mapped_pos.groupby('src_id').size().reindex(list(range(num_nodes_dict['DISEASE'])), fill_value=0).to_numpy()\n",
"test_neg_nums = df_link_test_mapped_neg.groupby('src_id').size().reindex(list(range(num_nodes_dict['DISEASE'])), fill_value=0).to_numpy()\n",
"neg_to_add = test_pos_nums - test_neg_nums\n",
"\n",
"test_neg_edges = df_link_test_mapped_neg.values.tolist()\n",
"for d in neg_to_add.nonzero()[0]:\n",
" to_sample = np.ones(num_nodes_dict['DISEASE'], dtype=int)\n",
" to_sample[g_test.out_edges(d, etype='DISEASE-and-DISEASE')[1].tolist()] = 0\n",
" neg_ds = random.sample(to_sample.nonzero()[0].tolist(), k=neg_to_add[d])\n",
" test_neg_edges.extend([[d, neg_d] for neg_d in neg_ds])\n",
"\n",
"test_neg_edges.sort()\n",
"test_neg_edges = np.array(test_neg_edges)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45377d14-57cb-4731-afc2-568e57dda4f0",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# save everything needed\n",
"save_path = load_path.parent\n",
"dgl.save_graphs(str(save_path / 'graph.bin'), [g_train, g_val, g_test])\n",
"np.savez(save_path / 'train_val_test_idx.npz',\n",
" train_idx=train_idx,\n",
" val_idx=val_idx,\n",
" test_idx=test_idx)\n",
"np.save(save_path / 'val_neg_edges.npy', val_neg_edges)\n",
"np.save(save_path / 'test_neg_edges.npy', test_neg_edges)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dgl-0.7-pytorch-1.10-py39",
"language": "python",
"name": "dgl-0.7-pytorch-1.10-py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 6cf4aab

Please sign in to comment.