diff --git a/docs/guide/impl-tips.rst b/docs/guide/impl-tips.rst
index 6bb7e7a73..0faa399cf 100644
--- a/docs/guide/impl-tips.rst
+++ b/docs/guide/impl-tips.rst
@@ -1,5 +1,5 @@
-Algorithm Implementation Tips
-=============================
+Model Implementation Tips
+=========================
 
 Implementing algorithms is fun, but there are a few things that are good to
 keep in mind.
@@ -9,5 +9,62 @@ In general, development follows the following:
 2. Clear
 3. Fast
 
-In that order. Further, we always want LensKit to be *usable* in an easy fashion. Code
-implementing algorithms, however, may be quite complex in order to achieve good performance.
+In that order. Further, we always want LensKit to be *usable* in an easy
+fashion. Code implementing commonly-used models, however, may be quite complex
+in order to achieve good performance.
+
+.. _iterative-training:
+
+Iterative Training
+~~~~~~~~~~~~~~~~~~
+
+The :class:`lenskit.training.IterativeTraining` class provides a standardized
+interface and training loop support for training models with iterative methods
+that pass through the training data in multiple *epochs*. Models that use this
+support extend :class:`~lenskit.training.IterativeTraining` in addition to
+:class:`~lenskit.pipeline.Component`, and implement the
+:meth:`~lenskit.training.IterativeTraining.training_loop` method instead of
+:meth:`~lenskit.training.Trainable.train`. Iteratively-trainable components
+should also have an ``epochs`` setting on their configuration class that
+specifies the number of training epochs to run.
+
+The :meth:`~lenskit.training.IterativeTraining.training_loop` method does three
+things:
+
+1. Sets up the initial data structures, preparation, etc. needed for model
+   training.
+2. Trains the model, yielding after each training epoch. It can optionally
+   yield a set of metrics, such as training loss or update magnitudes.
+3. Performs any final steps and training data cleanup.
+
+The model should be usable after each epoch, to support things like measuring
+performance on validation data.
+
+The training loop itself is represented as a Python iterator, so that a ``for``
+loop iterates through the training epochs. While the interface definition
+specifies the ``Iterator`` type in order to minimize restrictions on component
+implementers, we recommend that it actually be a ``Generator``, which allows the
+caller to request early termination (through the
+:meth:`~collections.abc.Generator.close` method). We also recommend that the
+``training_loop()`` method only return the generator after initial data
+preparation is complete, so that setup time is not included in the time taken
+for the first loop iteration. The easiest way to implement this is by
+delegating to an inner loop function, written as a Python generator:
+
+.. code:: python
+
+    def training_loop(self, data: Dataset, options: TrainingOptions):
+        # do initial data setup/prep for training
+        context = ...
+        # pass off to inner generator
+        return self._training_loop_impl(context)
+
+    def _training_loop_impl(self, context):
+        for i in range(self.config.epochs):
+            # do the model training for this epoch, and compute the metrics
+            loss = ...
+            try:
+                yield {'loss': loss}
+            except GeneratorExit:
+                # client code has requested early termination
+                break
+
+        # any final cleanup steps
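+
+Because the loop is a generator, calling code can also drive it directly,
+instead of going through :meth:`~lenskit.training.IterativeTraining.train`, to
+measure the model after each epoch and stop training early. A minimal sketch
+of such a loop (assuming a hypothetical ``validation_rmse`` helper and a
+held-out ``valid_data`` set):
+
+.. code:: python
+
+    last_rmse = float('inf')
+    loop = model.training_loop(data, options)
+    for metrics in loop:
+        # the partially-trained model is usable after each epoch
+        rmse = validation_rmse(model, valid_data)
+        if rmse > last_rmse:
+            # request early termination; a well-behaved loop catches
+            # GeneratorExit and finalizes the model
+            loop.close()
+            break
+        last_rmse = rmse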
diff --git a/lenskit/lenskit/als/_common.py b/lenskit/lenskit/als/_common.py
index 954adf181..4123cdc0d 100644
--- a/lenskit/lenskit/als/_common.py
+++ b/lenskit/lenskit/als/_common.py
@@ -7,21 +7,20 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from typing import Literal, TypeAlias
 
 import numpy as np
 import structlog
 import torch
 from pydantic import BaseModel
-from typing_extensions import Iterator, NamedTuple, Self, override
+from typing_extensions import NamedTuple, override
 
-from lenskit import util
 from lenskit.data import Dataset, ItemList, QueryInput, RecQuery, Vocabulary
 from lenskit.data.types import UIPair
-from lenskit.logging import item_progress
 from lenskit.parallel.config import ensure_parallel_init
 from lenskit.pipeline import Component
-from lenskit.training import Trainable, TrainingOptions
+from lenskit.training import IterativeTraining, TrainingOptions
 
 EntityClass: TypeAlias = Literal["user", "item"]
 
@@ -126,7 +125,7 @@ def to(self, device):
         return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))
 
 
-class ALSBase(ABC, Component[ItemList], Trainable):
+class ALSBase(IterativeTraining, Component[ItemList], ABC):
     """
     Base class for ALS models.
 
@@ -144,7 +143,9 @@ class ALSBase(ABC, Component[ItemList], Trainable):
     logger: structlog.stdlib.BoundLogger
 
     @override
-    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
+    def training_loop(
+        self, data: Dataset, options: TrainingOptions
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model.
 
@@ -154,49 +155,33 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) ->
         Returns:
-            ``True`` if the model was trained.
+            A generator that runs one training epoch per iteration, yielding
+            the user and item update magnitudes as metrics.
         """
-        if hasattr(self, "item_features_") and not options.retrain:
-            return False
-
         ensure_parallel_init()
-        timer = util.Stopwatch()
+
+        rng = options.random_generator()
+
+        train = self.prepare_data(data)
+        self.users_ = train.users
+        self.items_ = train.items
+
+        self.initialize_params(train, rng)
+
+        return self._training_loop_generator(train)
 
-        for algo in self.fit_iters(data, options):
-            pass  # we just need to do the iterations
-
-        if self.user_features_ is not None:
-            self.logger.info(
-                "trained model in %s (|P|=%f, |Q|=%f)",
-                timer,
-                torch.norm(self.user_features_, "fro"),
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-        else:
-            self.logger.info(
-                "trained model in %s (|Q|=%f)",
-                timer,
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-
-        return True
-
-    def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
+    def _training_loop_generator(
+        self, train: TrainingData
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model, yielding after each iteration.
 
         Args:
-            ratings: the ratings data frame.
+            train: the prepared training data.
""" - log = self.logger = self.logger.bind(features=self.config.features) - rng = options.random_generator() - - train = self.prepare_data(data) - self.users_ = train.users - self.items_ = train.items - - self.initialize_params(train, rng) assert self.user_features_ is not None assert self.item_features_ is not None @@ -207,27 +192,26 @@ def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]: "item", train.iu_rates, self.item_features_, self.user_features_, self.config.item_reg ) - log.info("beginning ALS model training") - - with item_progress("Training ALS", self.config.epochs) as epb: - for epoch in range(self.config.epochs): - log = log.bind(epoch=epoch) - epoch = epoch + 1 + for epoch in range(self.config.epochs): + log = log.bind(epoch=epoch) + epoch = epoch + 1 - du = self.als_half_epoch(epoch, u_ctx) - log.debug("finished user epoch") + du = self.als_half_epoch(epoch, u_ctx) + log.debug("finished user epoch") - di = self.als_half_epoch(epoch, i_ctx) - log.debug("finished item epoch") + di = self.als_half_epoch(epoch, i_ctx) + log.debug("finished item epoch") - log.info("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di) - epb.update() - yield self + log.debug("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di) + yield {"deltaP": du, "deltaQ": di} if not self.config.save_user_features: self.user_features_ = None self.user_ = None + log.debug("finalizing model training") + self.finalize_training() + @abstractmethod def prepare_data(self, data: Dataset) -> TrainingData: # pragma: no cover """ @@ -270,6 +254,9 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float: # pragma: """ ... + def finalize_training(self): + pass + @override def __call__(self, query: QueryInput, items: ItemList) -> ItemList: query = RecQuery.create(query) diff --git a/lenskit/lenskit/als/_implicit.py b/lenskit/lenskit/als/_implicit.py index c183618f4..78d02f69f 100644 --- a/lenskit/lenskit/als/_implicit.py +++ b/lenskit/lenskit/als/_implicit.py @@ -17,7 +17,6 @@ from lenskit.logging.progress import item_progress_handle, pbh_update from lenskit.math.solve import solve_cholesky from lenskit.parallel.chunking import WorkChunks -from lenskit.training import TrainingOptions from ._common import ALSBase, ALSConfig, TrainContext, TrainingData @@ -71,14 +70,6 @@ class ImplicitMFScorer(ALSBase): OtOr_: torch.Tensor - @override - def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()): - if super().train(data, options): - # compute OtOr and save it on the model - reg = self.config.user_reg - self.OtOr_ = _implicit_otor(self.item_features_, reg) - return True - @override def prepare_data(self, data: Dataset) -> TrainingData: if self.config.use_ratings: @@ -109,6 +100,12 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float: with item_progress_handle(f"epoch {epoch} {context.label}s", total=context.nrows) as pbh: return _train_implicit_cholesky_fanout(context, OtOr, chunks, pbh) + def finalize_training(self): + # compute OtOr and save it on the model + reg = self.config.user_reg + self.OtOr_ = _implicit_otor(self.item_features_, reg) + return True + @override def new_user_embedding( self, user_num: int | None, user_items: ItemList diff --git a/lenskit/lenskit/logging/progress/_dispatch.py b/lenskit/lenskit/logging/progress/_dispatch.py index 12567eb92..82ca8ea4f 100644 --- a/lenskit/lenskit/logging/progress/_dispatch.py +++ b/lenskit/lenskit/logging/progress/_dispatch.py @@ -36,7 +36,9 @@ def set_progress_impl(name: str | None, *options: Any): raise 
ValueError(f"unknown progress backend {name}") -def item_progress(label: str, total: int, fields: dict[str, str | None] | None = None) -> Progress: +def item_progress( + label: str, total: int | None = None, fields: dict[str, str | None] | None = None +) -> Progress: """ Create a progress bar for distinct, counted items. diff --git a/lenskit/lenskit/training.py b/lenskit/lenskit/training.py index 717b0ad6b..3603a843f 100644 --- a/lenskit/lenskit/training.py +++ b/lenskit/lenskit/training.py @@ -10,17 +10,19 @@ # pyright: strict from __future__ import annotations +from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass -from typing import ( - Protocol, - runtime_checkable, -) +from typing import Protocol, runtime_checkable import numpy as np from lenskit.data.dataset import Dataset +from lenskit.logging import get_logger, item_progress from lenskit.random import RNGInput, random_generator +_log = get_logger(__name__) + @dataclass(frozen=True) class TrainingOptions: @@ -94,3 +96,78 @@ def train(self, data: Dataset, options: TrainingOptions) -> None: The training options. """ raise NotImplementedError() + + +class IterativeTraining(ABC, Trainable): + """ + Base class for components that support iterative training. This both + automates the :meth:`Trainable.train` method for iterative training in terms + of initialization, epoch, and finalization methods, and exposes those + methods to client code that may wish to directly control the iterative + training process. + + Stability: + Full + """ + + trained_epochs: int = 0 + """ + The number of epochs for which this model has been trained. + """ + + @property + def expected_training_epochs(self) -> int | None: + """ + Get the number of training epochs expected to run. The default + implementation looks for an ``epochs`` attribute on the configuration + object (``self.config``). + """ + cfg = getattr(self, "config", None) + if cfg: + return getattr(cfg, "epochs", None) + + def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> None: + """ + Implementation of :meth:`Trainable.train` that uses the training loop. + It also uses the :attr:`trained_epochs` attribute to detect if the model + has already been trained for the purposes of honoring + :attr:`TrainingOptions.retrain`, and updates that attribute as model + training progresses. + """ + if self.trained_epochs > 0 and not options.retrain: + return + + self.trained_epochs = 0 + log = _log.bind(model=f"{self.__class__.__module__}.{self.__class__.__qualname__}") + log.info("training model") + n = self.expected_training_epochs + log.debug("creating training loop") + loop = self.training_loop(data, options) + log.debug("beginning training iterations") + with item_progress("Training iterations", total=n) as pb: + for i, metrics in enumerate(loop, 1): + metrics = metrics or {} + log.info("finished epoch", epoch=i, **metrics) + self.trained_epochs += 1 + pb.update() + + log.info("model training finished", epochs=self.trained_epochs) + + @abstractmethod + def training_loop( + self, data: Dataset, options: TrainingOptions + ) -> Iterator[dict[str, float] | None]: + """ + Training loop implementation, to be supplied by the derived class. This + method should return a iterator that, when iterated, will perform each + training epoch; when training is complete, it should finalize the model + and signal iteration completion. 
+
+        Each epoch can yield metrics, such as training or validation loss, that
+        will be logged with structured logging and can be used by calling code
+        for further analysis.
+
+        See :ref:`iterative-training` for more details on writing iterative
+        training loops.
+        """
+        raise NotImplementedError()