Skip to content

Commit

Permalink
skip if cannot get neg
Browse files Browse the repository at this point in the history
  • Loading branch information
kuroneko2828 committed Sep 10, 2024
1 parent dc3fdfd commit d45aaee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/multi_channel_bpr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def fit(self, lr, reg_params, n_epochs, neg_item_sampling_mode, verbose=False):
self.pos_level_dist,
self.train_inter_pos_dict,
mode=neg_item_sampling_mode)
if j is None:
continue
user_embed, pos_item_embed, neg_item_embed = \
perform_gradient_descent(self.user_reps[u]['embed'],
self.item_reps[i],
Expand Down
6 changes: 5 additions & 1 deletion src/multi_channel_bpr/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ def get_neg_item(user_rep, N, n, u, i, pos_level_dist, train_inter_pos_dict,
else:
if mode == 'uniform':
# sample item uniformly from unobserved channel
j = np.random.choice(np.setdiff1d(np.arange(n), user_rep['items']))
unknown_items = np.setdiff1d(np.arange(n), user_rep['items'])
if unknown_items.size == 0:
j = None
else:
j = np.random.choice(unknown_items)

elif mode == 'non-uniform':
# sample item non-uniformly from unobserved channel
Expand Down

0 comments on commit d45aaee

Please sign in to comment.