Skip to content

Commit

Permalink
update implicit to configurable component
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jan 11, 2025
1 parent 5567998 commit 20ce979
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 55 deletions.
71 changes: 24 additions & 47 deletions lenskit-implicit/lenskit/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT

import inspect
import logging

import numpy as np
from implicit.als import AlternatingLeastSquares
from implicit.bpr import BayesianPersonalizedRanking
from implicit.recommender_base import RecommenderBase
from pydantic import BaseModel, JsonValue
from scipy.sparse import csr_matrix
from typing_extensions import override

Expand All @@ -26,26 +26,28 @@
]


# Base configuration model for the Implicit-backed recommenders.  It declares
# no fields of its own; ``extra="allow"`` makes pydantic keep undeclared fields
# instead of rejecting them.
class ImplicitConfig(BaseModel, extra="allow"):
# Annotates pydantic's storage for the extra fields.  The ``_construct``
# methods in this file unpack this mapping as keyword arguments to the
# underlying ``implicit`` model constructor.
__pydantic_extra__: dict[str, JsonValue]


# Configuration for the ALS recommender: extends ImplicitConfig with an
# explicitly declared ``weight`` field.
class ImplicitALSConfig(ImplicitConfig, extra="allow"):
# Exposed through ``ALS.weight``; ``train`` multiplies the interaction matrix
# by this value before fitting.  Defaults to 40.0.
# NOTE(review): as a declared field this is presumably kept out of
# ``__pydantic_extra__`` and so not forwarded to the implicit constructor — confirm.
weight: float = 40.0


class BaseRec(Component, Trainable):
"""
Base class for Implicit-backed recommenders.
Stability:
Caller
Args:
delegate:
The delegate algorithm.
"""

config: ImplicitConfig
delegate: RecommenderBase
"""
The delegate algorithm from :mod:`implicit`.
"""
weight: float
"""
The weight for positive examples (only used by some algorithms).
"""
weight: float = 1.0

matrix_: csr_matrix
"""
Expand All @@ -60,10 +62,6 @@ class BaseRec(Component, Trainable):
The item ID mapping from training.
"""

def __init__(self, delegate: RecommenderBase):
self.delegate = delegate
self.weight = 1.0

@property
def is_trained(self):
    """
    Whether this component has been trained.

    Training state is detected by the presence of the ``matrix_`` attribute.
    """
    try:
        self.matrix_
    except AttributeError:
        return False
    return True
Expand All @@ -72,9 +70,8 @@ def is_trained(self):
def train(self, data: Dataset):
matrix = data.interaction_matrix("scipy", layout="csr", legacy=True)
uir = matrix * self.weight
if getattr(self.delegate, "item_factors", None) is not None: # pragma: no cover
_logger.warning("implicit algorithm already trained, re-fit is usually a bug")

self.delegate = self._construct()
_logger.info("training %s on %s matrix (%d nnz)", self.delegate, uir.shape, uir.nnz)

self.delegate.fit(uir)
Expand All @@ -85,6 +82,9 @@ def train(self, data: Dataset):

return self

def _construct(self) -> RecommenderBase:
    """
    Create the delegate model that ``train`` will fit.

    This base version always raises; subclasses in this file override it to
    build their specific ``implicit`` model.

    Raises:
        NotImplementedError: always.
    """
    message = "implicit constructor not implemented"
    raise NotImplementedError(message)

@override
def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
query = RecQuery.create(query)
Expand Down Expand Up @@ -113,24 +113,6 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:

return ItemList(items, scores=scores)

def __getattr__(self, name):
    """
    Fall back to the delegate's instance attributes for unknown names.

    Python calls this only after normal attribute lookup has failed.  The
    lookup is restricted to the delegate's own ``__dict__``.

    Args:
        name: The attribute name being looked up.

    Returns:
        The value stored under ``name`` in the delegate's ``__dict__``.

    Raises:
        AttributeError: if no delegate has been set yet, or the delegate has
            no instance attribute called ``name``.
    """
    missing = f"{type(self).__name__!r} object has no attribute {name!r}"

    # Check the instance dict first: touching ``self.delegate`` before it is
    # set would re-enter ``__getattr__`` and recurse without bound.
    if "delegate" not in self.__dict__:
        raise AttributeError(missing)

    try:
        return self.delegate.__dict__[name]
    except KeyError:
        raise AttributeError(missing) from None

def get_params(self, deep=True):
    """
    Report the delegate's constructor parameters and their current values.

    Parameter names come from the signature of the delegate's class; values
    are read from the delegate's instance ``__dict__``.

    Args:
        deep: Accepted for interface compatibility; not used here.

    Returns:
        A dict mapping each constructor parameter name, in signature order,
        to the delegate attribute of the same name, or ``None`` when the
        delegate stores no such attribute.
    """
    state = self.delegate.__dict__
    ctor_params = inspect.signature(self.delegate.__class__).parameters
    return {param: state.get(param) for param in ctor_params}

def __str__(self):
    """Return ``Implicit(<delegate>)``, embedding the delegate's string form."""
    return f"Implicit({self.delegate})"


class ALS(BaseRec):
"""
Expand All @@ -140,15 +122,14 @@ class ALS(BaseRec):
Caller
"""

def __init__(self, *args, weight=40.0, **kwargs):
"""
Construct an ALS recommender. The arguments are passed as-is to
:py:class:`implicit.als.AlternatingLeastSquares`. The `weight`
parameter controls the confidence weight for positive examples.
"""
config: ImplicitALSConfig

@property
def weight(self):
    """The ``weight`` value taken from this component's configuration."""
    settings = self.config
    return settings.weight

super().__init__(AlternatingLeastSquares(*args, **kwargs))
self.weight = weight
def _construct(self):
    """
    Build a new ``AlternatingLeastSquares`` model from the configuration.

    The configuration's extra fields are passed to the constructor as
    keyword arguments.
    """
    options = self.config.__pydantic_extra__
    return AlternatingLeastSquares(**options)  # type: ignore


class BPR(BaseRec):
Expand All @@ -159,9 +140,5 @@ class BPR(BaseRec):
Caller
"""

def __init__(self, *args, **kwargs):
"""
Construct a BPR recommender. The arguments are passed as-is to
:py:class:`implicit.bpr.BayesianPersonalizedRanking`.
"""
super().__init__(BayesianPersonalizedRanking(*args, **kwargs))
def _construct(self):
    """
    Build a new ``BayesianPersonalizedRanking`` model from the configuration.

    The configuration's extra fields are passed to the constructor as
    keyword arguments.
    """
    options = self.config.__pydantic_extra__
    return BayesianPersonalizedRanking(**options)  # type: ignore
16 changes: 8 additions & 8 deletions lenskit-implicit/tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class TestImplicitBPR(BasicComponentTests, ScorerTests):

@mark.slow
def test_implicit_als_train_rec(ml_ds):
algo = ALS(25)
assert algo.factors == 25
algo = ALS(factors=25)
assert algo.config.factors == 25

ret = algo.train(ml_ds)
assert ret is algo
Expand All @@ -55,7 +55,7 @@ def test_implicit_als_train_rec(ml_ds):
@mark.parametrize("n_jobs", [1, None])
def test_implicit_als_batch_accuracy(ml_100k, n_jobs):
ds = from_interactions_df(ml_100k)
results = quick_measure_model(ALS(25), ds, n_jobs=n_jobs)
results = quick_measure_model(ALS(factors=25), ds, n_jobs=n_jobs)

ndcg = results.list_summary().loc["NDCG", "mean"]
_log.info("nDCG for %d users is %.4f", len(results.list_metrics()), ndcg)
Expand All @@ -64,8 +64,8 @@ def test_implicit_als_batch_accuracy(ml_100k, n_jobs):

@mark.slow
def test_implicit_bpr_train_rec(ml_ds):
algo = BPR(25, use_gpu=False)
assert algo.factors == 25
algo = BPR(factors=25, use_gpu=False)
assert algo.config.factors == 25

algo.train(ml_ds)

Expand All @@ -89,7 +89,7 @@ def test_implicit_bpr_train_rec(ml_ds):
@mark.parametrize("n_jobs", [1, None])
def test_implicit_bpr_batch_accuracy(ml_100k, n_jobs):
ds = from_interactions_df(ml_100k)
results = quick_measure_model(BPR(25), ds, n_jobs=n_jobs)
results = quick_measure_model(BPR(factors=25), ds, n_jobs=n_jobs)

ndcg = results.list_summary().loc["NDCG", "mean"]
_log.info("nDCG for %d users is %.4f", len(results.list_metrics()), ndcg)
Expand All @@ -98,7 +98,7 @@ def test_implicit_bpr_batch_accuracy(ml_100k, n_jobs):

def test_implicit_pickle_untrained(tmp_path):
mf = tmp_path / "bpr.dat"
algo = BPR(25, use_gpu=False)
algo = BPR(factors=25, use_gpu=False)

with mf.open("wb") as f:
pickle.dump(algo, f)
Expand All @@ -107,4 +107,4 @@ def test_implicit_pickle_untrained(tmp_path):
a2 = pickle.load(f)

assert a2 is not algo
assert a2.factors == 25
assert a2.config.factors == 25

0 comments on commit 20ce979

Please sign in to comment.