Document sampling functions
mdekstrand committed Mar 6, 2021
1 parent 9b481d6 commit cc86821
Showing 3 changed files with 46 additions and 4 deletions.
13 changes: 13 additions & 0 deletions doc/data.rst
@@ -10,3 +10,16 @@ Building Ratings Matrices

 .. autofunction:: sparse_ratings
 .. autoclass:: RatingMatrix
+
+Sampling Utilities
+~~~~~~~~~~~~~~~~~~
+
+.. module:: lenskit.data.sampling
+
+The :py:mod:`lenskit.data.sampling` module provides support functions for various
+data sampling procedures for use in model training.
+
+
+.. autofunction:: neg_sample
+.. autofunction:: sample_unweighted
+.. autofunction:: sample_weighted
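
Review note: the docstrings in this commit say that `sample_unweighted` draws items uniformly while `sample_weighted` draws them proportionally to popularity. The sketch below illustrates why indexing `colinds` at a random position gives popularity-proportional draws. It uses plain NumPy arrays as hypothetical stand-ins for the `ncols`, `nnz`, and `colinds` attributes and a seeded `Generator` instead of the Numba-compiled functions, so it is an illustration under those assumptions rather than the library code.

```python
import numpy as np

# Hypothetical stand-ins for the CSR attributes the samplers read
# (mat.ncols, mat.nnz, mat.colinds); this is not the csr.CSR class.
colinds = np.array([0, 1, 2, 0, 1, 0])  # item 0 has 3 ratings, item 1 has 2, item 2 has 1
ncols = 3
nnz = len(colinds)

rng = np.random.default_rng(42)
n = 60_000

# Uniform: every item column is equally likely.
uniform_draws = rng.integers(0, ncols, size=n)

# Popularity-weighted: pick a stored rating uniformly, then take its item,
# so an item is drawn in proportion to how many ratings it has.
weighted_draws = colinds[rng.integers(0, nnz, size=n)]

# Empirical frequency of each item under the two schemes.
uniform_freq = np.bincount(uniform_draws, minlength=ncols) / n
weighted_freq = np.bincount(weighted_draws, minlength=ncols) / n
print(uniform_freq)
print(weighted_freq)
```

With this toy data the weighted frequencies should approach 3/6, 2/6, and 1/6, while the uniform ones approach 1/3 each. One consequence of the weighted scheme is that an item with no ratings can never be drawn as a candidate.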
2 changes: 1 addition & 1 deletion lenskit/data/__init__.py
@@ -1 +1 @@
-from .matrix import sparse_ratings # noqa: F401
+from .matrix import RatingMatrix, sparse_ratings # noqa: F401
35 changes: 32 additions & 3 deletions lenskit/data/sampling.py
@@ -4,23 +4,52 @@

 @njit
 def sample_unweighted(mat):
+    """
+    Candidate sampling function for use with :py:func:`neg_sample`.
+    It samples items uniformly at random.
+    """
     return np.random.randint(0, mat.ncols)


 @njit
 def sample_weighted(mat):
+    """
+    Candidate sampling function for use with :py:func:`neg_sample`.
+    It samples items proportionally to their popularity.
+    """
     j = np.random.randint(0, mat.nnz)
     return mat.colinds[j]


 @njit(nogil=True)
 def neg_sample(mat, uv, sample):
     """
-    Sample the examples from a user-item matrix. For each user in uv, it samples an
-    item that they have not rated using rejection sampling.
+    Sample the examples from a user-item matrix. For each user in ``uv``, it samples
+    an item that they have not rated using rejection sampling.

-    While this is embarassingly parallel, we do not parallelize because it's usually
+    While this is embarassingly parallel, we do not parallelize because it's often
     used in parallel.
+
+    This returns both the items and the sample counts for debugging::
+
+        neg_items, counts = neg_sample(matrix, users, sample_unweighted)
+
+    Args:
+        mat(csr.CSR):
+            The user-item matrix. Its values are ignored and do not need to be present.
+        uv(numpy.ndarray):
+            An array of user IDs.
+        sample(function):
+            A sampling function to sample candidate negative items. Should be one of
+            :py:func:`sample_weighted` or :py:func:`sample_unweighted`.
+
+    Returns:
+        numpy.ndarray, numpy.ndarray:
+            Two arrays:
+
+            1. The sampled negative item IDs.
+            2. An array of sample counts, the number of samples required to sample each
+               item. This is useful for diagnosing sample inefficiency.
     """
     n = len(uv)
     jv = np.empty(n, dtype=np.int32)
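
Review note: the rest of the `neg_sample` body is not shown in this diff, so the following is only a sketch of the rejection sampling that the docstring describes, not the committed implementation. `TinyCSR`, `uniform_candidate`, `neg_sample_sketch`, and the `max_tries` guard are hypothetical names introduced for illustration. The sketch runs in plain NumPy without Numba, and it assumes that the sample count includes the accepted draw.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TinyCSR:
    """Hypothetical minimal CSR holder; not the csr.CSR class."""
    rowptrs: np.ndarray  # row start offsets, length nrows + 1
    colinds: np.ndarray  # item (column) index of each stored rating
    ncols: int           # number of items


def uniform_candidate(mat, rng):
    # Candidate sampler: any item column, uniformly at random.
    return int(rng.integers(0, mat.ncols))


def neg_sample_sketch(mat, users, sample, rng, max_tries=10_000):
    n = len(users)
    items = np.empty(n, dtype=np.int32)
    counts = np.zeros(n, dtype=np.int32)
    for i, u in enumerate(users):
        # Items this user has already rated (their CSR row).
        rated = set(mat.colinds[mat.rowptrs[u]:mat.rowptrs[u + 1]].tolist())
        for _ in range(max_tries):
            cand = sample(mat, rng)
            counts[i] += 1          # assumption: count every draw, accepted or not
            if cand not in rated:   # reject candidates the user has rated
                items[i] = cand
                break
        else:
            # Guard added for the sketch; avoids looping forever when a
            # user has rated every item the sampler can produce.
            raise RuntimeError(f"no unrated item found for user {u}")
    return items, counts


# Toy matrix: user 0 rated items {0, 1}, user 1 rated {1, 2, 3}, user 2 rated {0}.
mat = TinyCSR(
    rowptrs=np.array([0, 2, 5, 6]),
    colinds=np.array([0, 1, 1, 2, 3, 0]),
    ncols=4,
)
users = np.array([0, 1, 2, 1, 0])
neg_items, counts = neg_sample_sketch(mat, users, uniform_candidate, np.random.default_rng(7))
print(neg_items, counts)
```

In this toy matrix user 1 has rated three of the four items, so their only possible negative is item 0 and their entries in `counts` will tend to be larger. That is the kind of sample inefficiency the second return value is meant to expose.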
