Skip to content

Commit cedd7ea

Browse files
authored
Merge pull request #553 from chenyushuo/master
Merge 0.1.x into master
2 parents c8f1ca7 + 90d5c3f commit cedd7ea

File tree

29 files changed

+18720
-467
lines changed

29 files changed

+18720
-467
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# CI workflow: installs the dependencies and runs the RecBole test suites
# on every pull request.
name: RecBole tests

on:
- pull_request

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML keeps the version as a string instead of a float.
        python-version: ["3.8"]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        pip install dgl
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

    # Use "python -m pytest" instead of "pytest" to fix imports
    - name: Test metrics
      run: |
        python -m pytest -v tests/metrics
    # This suite is run exactly once; a second identical step was redundant.
    - name: Test evaluation_setting
      run: |
        python -m pytest -v tests/evaluation_setting
    - name: Test model
      run: |
        python -m pytest -v tests/model/test_model_auto.py
    # test_command_line.py is executed as a plain script rather than through
    # pytest; PYTHONPATH is exported first, presumably so the package in the
    # repository root can be imported by that script.
    - name: Test config
      run: |
        python -m pytest -v tests/config/test_config.py
        export PYTHONPATH=.
        python tests/config/test_command_line.py --use_gpu=False --valid_metric=Recall@10 --split_ratio=[0.7,0.2,0.1] --metrics=['Recall@10'] --epochs=200 --eval_setting='LO_RS' --learning_rate=0.3
45+

recbole/data/dataloader/knowledge_dataloader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def _next_batch_data(self):
182182
elif self.state == KGDataLoaderState.RS:
183183
return self.general_dataloader._next_batch_data()
184184
elif self.state == KGDataLoaderState.RSKG:
185+
if self.kg_dataloader.pr >= self.kg_dataloader.pr_end:
186+
self.kg_dataloader.pr = 0
185187
kg_data = self.kg_dataloader._next_batch_data()
186188
rec_data = self.general_dataloader._next_batch_data()
187189
rec_data.update(kg_data)

recbole/data/dataloader/neg_sample_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class NegSampleMixin(AbstractDataLoader):
2929
batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
3030
dl_format (InputType, optional): The input type of dataloader. Defaults to
3131
:obj:`~recbole.utils.InputType.POINTWISE`.
32-
shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaluts to ``False``.
32+
shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.
3333
"""
3434
dl_type = DataLoaderType.NEGSAMPLE
3535

recbole/data/dataloader/sequential_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def augmentation(self, uid_list, item_list_index, target_index, item_list_length
143143
new_dict = {
144144
self.uid_field: uid_list,
145145
self.item_list_field: np.zeros((new_length, self.max_item_list_len), dtype=np.int64),
146-
self.time_list_field: np.zeros((new_length, self.max_item_list_len), dtype=np.int64),
146+
self.time_list_field: np.zeros((new_length, self.max_item_list_len)),
147147
self.target_iid_field: self.dataset.inter_feat[self.iid_field][target_index].values,
148148
self.target_time_field: self.dataset.inter_feat[self.time_field][target_index].values,
149149
self.item_list_length_field: item_list_length,

recbole/data/dataset/dataset.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,16 @@ class Dataset(object):
5555
Specially, if feature is loaded from Arg ``additional_feat_suffix``, its source has type str,
5656
which is the suffix of its local file (also the suffix written in Arg ``additional_feat_suffix``).
5757
58-
field2id_token (dict): Dict mapping feature name (str) to a list, which stores the original token of
59-
this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
58+
field2id_token (dict): Dict mapping feature name (str) to a :class:`np.ndarray`, which stores the original token
59+
of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
6060
is remapped to 2. Then ``field2id_token['test'] = ['[PAD]', 'token_a', 'token_b']``. (Note that 0 is
6161
always PADDING for token-like features.)
6262
63+
field2token_id (dict): Dict mapping feature name (str) to a dict, which stores the token remap table
64+
of this feature. For example, if ``test`` is token-like feature, ``token_a`` is remapped to 1, ``token_b``
65+
is remapped to 2. Then ``field2token_id['test'] = {'[PAD]': 0, 'token_a': 1, 'token_b': 2}``.
66+
(Note that 0 is always PADDING for token-like features.)
67+
6368
field2seqlen (dict): Dict mapping feature name (str) to its sequence length (int).
6469
For sequence features, their length can be either set in config,
6570
or set to the max sequence length of this feature.
@@ -116,6 +121,7 @@ def _get_preset(self):
116121
self.field2type = {}
117122
self.field2source = {}
118123
self.field2id_token = {}
124+
self.field2token_id = {}
119125
self.field2seqlen = self.config['seq_len'] or {}
120126
self._preloaded_weight = {}
121127
self.benchmark_filename_list = self.config['benchmark_filename']
@@ -897,11 +903,13 @@ def _remap(self, remap_list):
897903
tokens, split_point = self._concat_remaped_tokens(remap_list)
898904
new_ids_list, mp = pd.factorize(tokens)
899905
new_ids_list = np.split(new_ids_list + 1, split_point)
900-
mp = ['[PAD]'] + list(mp)
906+
mp = np.array(['[PAD]'] + list(mp))
907+
token_id = {t: i for i, t in enumerate(mp)}
901908

902909
for (feat, field, ftype), new_ids in zip(remap_list, new_ids_list):
903-
if (field not in self.field2id_token):
910+
if field not in self.field2id_token:
904911
self.field2id_token[field] = mp
912+
self.field2token_id[field] = token_id
905913
if ftype == FeatureType.TOKEN:
906914
feat[field] = new_ids
907915
elif ftype == FeatureType.TOKEN_SEQ:
@@ -1010,6 +1018,46 @@ def copy_field_property(self, dest_field, source_field):
10101018
self.field2source[dest_field] = self.field2source[source_field]
10111019
self.field2seqlen[dest_field] = self.field2seqlen[source_field]
10121020

1021+
@dlapi.set()
1022+
def token2id(self, field, tokens):
    """Map external tokens to internal ids.

    Args:
        field (str): Field of external tokens.
        tokens (str, list or np.ndarray): External tokens.

    Returns:
        int or np.ndarray: The internal ids of external tokens.

    Raises:
        ValueError: If a token is not present in the remap table of ``field``.
        TypeError: If ``tokens`` is neither a str, a list nor an np.ndarray.
    """
    if isinstance(tokens, str):
        # Fetch the table outside of the membership test so that an unknown
        # ``field`` still surfaces as a KeyError, as it did before.
        token_id = self.field2token_id[field]
        if tokens in token_id:
            return token_id[tokens]
        # The placeholder is now filled in so the message names the token.
        raise ValueError('token [{}] is not existed'.format(tokens))
    if isinstance(tokens, (list, np.ndarray)):
        # Each element is resolved through the scalar branch above.
        return np.array([self.token2id(field, token) for token in tokens])
    raise TypeError('The type of tokens [{}] is not supported'.format(tokens))
1041+
1042+
@dlapi.set()
1043+
def id2token(self, field, ids):
    """Translate internal ids of a field back into their external tokens.

    Args:
        field (str): Field the internal ids belong to.
        ids (int, list, np.ndarray or torch.Tensor): Internal ids.

    Returns:
        str or np.ndarray: The external tokens of internal ids.

    Raises:
        ValueError: If indexing the token table of ``field`` with ``ids``
            raises an IndexError.
    """
    id_token = self.field2id_token[field]
    try:
        tokens = id_token[ids]
    except IndexError:
        # Lists get a dedicated message; everything else is reported as invalid ids.
        if isinstance(ids, list):
            message = '[{}] is not a one-dimensional list'
        else:
            message = '[{}] is not a valid ids'
        raise ValueError(message.format(ids))
    return tokens
1060+
10131061
@property
10141062
@dlapi.set()
10151063
def user_num(self):

recbole/data/dataset/kg_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def _remap_ID_all(self):
353353
item_tokens = self._get_rec_item_token()
354354
super()._remap_ID_all()
355355
self._sort_remaped_entities(item_tokens)
356-
self.field2id_token[self.relation_field].append('[UI-Relation]')
356+
self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]')
357357

358358
@property
359359
@dlapi.set()

recbole/model/general_recommender/dmf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def get_item_embedding(self):
170170
col = interaction_matrix.col
171171
i = torch.LongTensor([row, col])
172172
data = torch.FloatTensor(interaction_matrix.data)
173-
item_matrix = torch.sparse.FloatTensor(i, data).to(self.device).transpose(0, 1)
174-
173+
item_matrix = torch.sparse.FloatTensor(i, data, torch.Size(interaction_matrix.shape)).to(self.device).transpose(0, 1)
175174
item = torch.sparse.mm(item_matrix, self.item_linear.weight.t())
176175

177176
item = self.item_fc_layers(item)

recbole/model/general_recommender/fism.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
https://github.com/AaronHeee/Neural-Attentive-Item-Similarity-Model
1515
"""
1616

17+
from logging import getLogger
18+
1719
import torch
1820
import torch.nn as nn
19-
from torch.nn.init import normal_
20-
2121
from recbole.model.abstract_recommender import GeneralRecommender
2222
from recbole.utils import InputType
23+
from torch.nn.init import normal_
2324

2425

2526
class FISM(GeneralRecommender):
@@ -36,6 +37,8 @@ def __init__(self, config, dataset):
3637

3738
# load dataset info
3839
self.LABEL = config['LABEL_FIELD']
40+
self.logger = getLogger()
41+
3942
# get all users' history interaction information. The history item
4043
# matrix is padded to the maximum number of a user's interactions
4144
self.history_item_matrix, self.history_lens, self.mask_mat = self.get_history_info(dataset)
@@ -49,6 +52,11 @@ def __init__(self, config, dataset):
4952
# split the too large dataset into the specified pieces
5053
if self.split_to > 0:
5154
self.group = torch.chunk(torch.arange(self.n_items).to(self.device), self.split_to)
55+
else:
56+
self.logger.warning('Pay Attetion!! the `split_to` is set to 0. If you catch a OMM error in this case, ' + \
57+
'you need to increase it \n\t\t\tuntil the error disappears. For example, ' + \
58+
'you can append it in the command line such as `--split_to=5`')
59+
5260

5361
# define layers and loss
5462
# construct source and destination item embedding matrix

recbole/model/general_recommender/gcmc.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,13 @@ def get_norm_adj_mat(self):
143143
# build adj matrix
144144
A = sp.dok_matrix((self.n_users + self.n_items,
145145
self.n_users + self.n_items), dtype=np.float32)
146-
A = A.tolil()
147-
A[:self.n_users, self.n_users:] = self.interaction_matrix
148-
A[self.n_users:, :self.n_users] = self.interaction_matrix.transpose()
149-
A = A.todok()
146+
inter_M = self.interaction_matrix
147+
inter_M_t = self.interaction_matrix.transpose()
148+
data_dict = dict(zip(zip(inter_M.row, inter_M.col+self.n_users),
149+
[1]*inter_M.nnz))
150+
data_dict.update(dict(zip(zip(inter_M_t.row+self.n_users, inter_M_t.col),
151+
[1]*inter_M_t.nnz)))
152+
A._update(data_dict)
150153
# norm adj matrix
151154
sumArr = (A > 0).sum(axis=1)
152155
# add epsilon to avoid divide-by-zero warning

recbole/model/general_recommender/lightgcn.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,13 @@ def get_norm_adj_mat(self):
8686
# build adj matrix
8787
A = sp.dok_matrix((self.n_users + self.n_items,
8888
self.n_users + self.n_items), dtype=np.float32)
89-
A = A.tolil()
90-
A[:self.n_users, self.n_users:] = self.interaction_matrix
91-
A[self.n_users:, :self.n_users] = self.interaction_matrix.transpose()
92-
A = A.todok()
89+
inter_M = self.interaction_matrix
90+
inter_M_t = self.interaction_matrix.transpose()
91+
data_dict = dict(zip(zip(inter_M.row, inter_M.col+self.n_users),
92+
[1]*inter_M.nnz))
93+
data_dict.update(dict(zip(zip(inter_M_t.row+self.n_users, inter_M_t.col),
94+
[1]*inter_M_t.nnz)))
95+
A._update(data_dict)
9396
# norm adj matrix
9497
sumArr = (A > 0).sum(axis=1)
9598
# add epsilon to avoid divide-by-zero warning

0 commit comments

Comments
 (0)