Skip to content

Commit

Permalink
add LastFM preprocessing code
Browse files Browse the repository at this point in the history
  • Loading branch information
cynricfu committed Jul 21, 2022
1 parent 66a8cc3 commit 160b343
Showing 1 changed file with 175 additions and 0 deletions.
175 changes: 175 additions & 0 deletions preprocess_lastfm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"id": "3d79b67d-629f-41b6-a7ad-abe97d5b25f1",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import random\n",
"\n",
"import numpy as np\n",
"import scipy.sparse\n",
"import dgl\n",
"import torch as th"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6ac6fa75-a37a-413a-ad38-a17c3efb5b1c",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"load_path = Path('data/lastfm/magnn')\n",
"adjM = scipy.sparse.load_npz(load_path / 'adjM.npz').toarray()\n",
"type_mask = np.load(load_path / 'node_types.npy')\n",
"user_artist = np.load(load_path / 'user_artist.npy')\n",
"train_val_test_idx = np.load(load_path / 'train_val_test_idx.npz')\n",
"\n",
"train_idx = train_val_test_idx['train_idx']\n",
"val_idx = train_val_test_idx['val_idx']\n",
"test_idx = train_val_test_idx['test_idx']\n",
"train_val_idx = np.sort(np.concatenate((train_idx, val_idx)))\n",
"\n",
"ntypes = ['user', 'artist', 'tag']\n",
"ntype_ids = {'user': 0, 'artist': 1, 'tag': 2}\n",
"num_nodes_dict = {ntype: (type_mask == ntype_ids[ntype]).sum() for ntype in ntypes}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5bbfa8fc-febb-4228-86c8-2afeca6ce041",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# g_train, g_val, g_test\n",
"data_dict_train = {\n",
" ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n",
" ('user', 'user-artist', 'artist'): (user_artist[train_idx][:, 0], user_artist[train_idx][:, 1]),\n",
" ('artist', 'artist-user', 'user'): (user_artist[train_idx][:, 1], user_artist[train_idx][:, 0]),\n",
" ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n",
" ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n",
"}\n",
"data_dict_val = {\n",
" ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n",
" ('user', 'user-artist', 'artist'): (user_artist[train_val_idx][:, 0], user_artist[train_val_idx][:, 1]),\n",
" ('artist', 'artist-user', 'user'): (user_artist[train_val_idx][:, 1], user_artist[train_val_idx][:, 0]),\n",
" ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n",
" ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n",
"}\n",
"data_dict_test = {\n",
" ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n",
" ('user', 'user-artist', 'artist'): (user_artist[:, 0], user_artist[:, 1]),\n",
" ('artist', 'artist-user', 'user'): (user_artist[:, 1], user_artist[:, 0]),\n",
" ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n",
" ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n",
"}\n",
"\n",
"g_train = dgl.heterograph(data_dict_train, num_nodes_dict, idtype=th.int64)\n",
"g_val = dgl.heterograph(data_dict_val, num_nodes_dict, idtype=th.int64)\n",
"g_test = dgl.heterograph(data_dict_test, num_nodes_dict, idtype=th.int64)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "baa17041-f14d-4b6c-be63-975d3d5cd99f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# sample hard negatives\n",
"# for a positive pair (u, a), sample (u, a') as a negative pair\n",
"# where a' associates with a via A-U-A metapath (i.e., a 2-hop neighbor) and (u, a') does not exist in the original graph\n",
"\n",
"# validation\n",
"g_val_AUA = dgl.metapath_reachable_graph(g_val, ['artist-user', 'user-artist'])\n",
"val_neg_user_artist = []\n",
"for u, a in user_artist[val_idx]:\n",
" neg_artists = list(set(g_val_AUA.out_edges(a)[1].tolist()) - set(g_val.out_edges(u, etype='user-artist')[1].tolist()))\n",
" neg_a = random.choice(neg_artists)\n",
" val_neg_user_artist.append([u, neg_a])\n",
"val_neg_user_artist = np.array(val_neg_user_artist)\n",
"\n",
"# testing\n",
"g_test_AUA = dgl.metapath_reachable_graph(g_test, ['artist-user', 'user-artist'])\n",
"test_neg_user_artist = []\n",
"for u, a in user_artist[test_idx]:\n",
" neg_artists = list(set(g_test_AUA.out_edges(a)[1].tolist()) - set(g_test.out_edges(u, etype='user-artist')[1].tolist()))\n",
" neg_a = random.choice(neg_artists)\n",
" test_neg_user_artist.append([u, neg_a])\n",
"test_neg_user_artist = np.array(test_neg_user_artist)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d076d803-d12f-419e-b720-ed6468665ca5",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# save graphs and hard negative pairs\n",
"save_path = load_path.parent\n",
"dgl.save_graphs(str(save_path / 'graph.bin'), [g_train, g_val, g_test])\n",
"np.save(save_path / 'val_neg_user_artist.npy', val_neg_user_artist)\n",
"np.save(save_path / 'test_neg_user_artist.npy', test_neg_user_artist)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dec8de09-fa56-4eda-a867-56f0313b41d1",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "dgl-0.7-pytorch-1.10-py39",
"language": "python",
"name": "dgl-0.7-pytorch-1.10-py39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 160b343

Please sign in to comment.