diff --git a/preprocess_lastfm.ipynb b/preprocess_lastfm.ipynb new file mode 100644 index 0000000..a322a12 --- /dev/null +++ b/preprocess_lastfm.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "3d79b67d-629f-41b6-a7ad-abe97d5b25f1", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import random\n", + "\n", + "import numpy as np\n", + "import scipy.sparse\n", + "import dgl\n", + "import torch as th" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6ac6fa75-a37a-413a-ad38-a17c3efb5b1c", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "load_path = Path('data/lastfm/magnn')\n", + "adjM = scipy.sparse.load_npz(load_path / 'adjM.npz').toarray()\n", + "type_mask = np.load(load_path / 'node_types.npy')\n", + "user_artist = np.load(load_path / 'user_artist.npy')\n", + "train_val_test_idx = np.load(load_path / 'train_val_test_idx.npz')\n", + "\n", + "train_idx = train_val_test_idx['train_idx']\n", + "val_idx = train_val_test_idx['val_idx']\n", + "test_idx = train_val_test_idx['test_idx']\n", + "train_val_idx = np.sort(np.concatenate((train_idx, val_idx)))\n", + "\n", + "ntypes = ['user', 'artist', 'tag']\n", + "ntype_ids = {'user': 0, 'artist': 1, 'tag': 2}\n", + "num_nodes_dict = {ntype: (type_mask == ntype_ids[ntype]).sum() for ntype in ntypes}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5bbfa8fc-febb-4228-86c8-2afeca6ce041", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# g_train, g_val, g_test\n", + "data_dict_train = {\n", + " ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n", + " ('user', 'user-artist', 'artist'): (user_artist[train_idx][:, 0], user_artist[train_idx][:, 1]),\n", + " ('artist', 'artist-user', 'user'): (user_artist[train_idx][:, 1], user_artist[train_idx][:, 0]),\n", + " ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n", + " ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n", + "}\n", + "data_dict_val = {\n", + " ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n", + " ('user', 'user-artist', 'artist'): (user_artist[train_val_idx][:, 0], user_artist[train_val_idx][:, 1]),\n", + " ('artist', 'artist-user', 'user'): (user_artist[train_val_idx][:, 1], user_artist[train_val_idx][:, 0]),\n", + " ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n", + " ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n", + "}\n", + "data_dict_test = {\n", + " ('user', 'user-user', 'user'): adjM[type_mask == ntype_ids['user']][:,type_mask == ntype_ids['user']].nonzero(),\n", + " ('user', 'user-artist', 'artist'): (user_artist[:, 0], user_artist[:, 1]),\n", + " ('artist', 'artist-user', 'user'): (user_artist[:, 1], user_artist[:, 0]),\n", + " ('artist', 'artist-tag', 'tag'): adjM[type_mask == ntype_ids['artist']][:,type_mask == ntype_ids['tag']].nonzero(),\n", + " ('tag', 'tag-artist', 'artist'): adjM[type_mask == ntype_ids['tag']][:,type_mask == ntype_ids['artist']].nonzero(),\n", + "}\n", + "\n", + "g_train = dgl.heterograph(data_dict_train, num_nodes_dict, idtype=th.int64)\n", + "g_val = dgl.heterograph(data_dict_val, num_nodes_dict, idtype=th.int64)\n", + "g_test = dgl.heterograph(data_dict_test, num_nodes_dict, idtype=th.int64)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "baa17041-f14d-4b6c-be63-975d3d5cd99f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# sample hard negatives\n", + "# for a positive pair (u, a), sample (u, a') as a negative pair\n", + "# where a' associates with a via A-U-A metapath (i.e., a 2-hop neighbor) and (u, a') does not exist in the original graph\n", + "\n", + "# validation\n", + "g_val_AUA = dgl.metapath_reachable_graph(g_val, ['artist-user', 'user-artist'])\n", + "val_neg_user_artist = []\n", + "for u, a in user_artist[val_idx]:\n", + " neg_artists = list(set(g_val_AUA.out_edges(a)[1].tolist()) - set(g_val.out_edges(u, etype='user-artist')[1].tolist()))\n", + " neg_a = random.choice(neg_artists)\n", + " val_neg_user_artist.append([u, neg_a])\n", + "val_neg_user_artist = np.array(val_neg_user_artist)\n", + "\n", + "# testing\n", + "g_test_AUA = dgl.metapath_reachable_graph(g_test, ['artist-user', 'user-artist'])\n", + "test_neg_user_artist = []\n", + "for u, a in user_artist[test_idx]:\n", + " neg_artists = list(set(g_test_AUA.out_edges(a)[1].tolist()) - set(g_test.out_edges(u, etype='user-artist')[1].tolist()))\n", + " neg_a = random.choice(neg_artists)\n", + " test_neg_user_artist.append([u, neg_a])\n", + "test_neg_user_artist = np.array(test_neg_user_artist)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d076d803-d12f-419e-b720-ed6468665ca5", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# save graphs and hard negative pairs\n", + "save_path = load_path.parent\n", + "dgl.save_graphs(str(save_path / 'graph.bin'), [g_train, g_val, g_test])\n", + "np.save(save_path / 'val_neg_user_artist.npy', val_neg_user_artist)\n", + "np.save(save_path / 'test_neg_user_artist.npy', test_neg_user_artist)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dec8de09-fa56-4eda-a867-56f0313b41d1", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dgl-0.7-pytorch-1.10-py39", + "language": "python", + "name": "dgl-0.7-pytorch-1.10-py39" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file