Merge pull request #601 from mdekstrand/feature/incremental-training
Add iterative training API
mdekstrand authored Jan 12, 2025
2 parents 1954efb + 2eba78d commit f24d461
Showing 5 changed files with 186 additions and 66 deletions.
65 changes: 61 additions & 4 deletions docs/guide/impl-tips.rst
@@ -1,5 +1,5 @@
Algorithm Implementation Tips
=============================
Model Implementation Tips
=========================

Implementing algorithms is fun, but there are a few things that are good to keep in mind.

@@ -9,5 +9,62 @@ In general, development follows the following:
2. Clear
3. Fast

In that order. Further, we always want LensKit to be *usable* in an easy fashion. Code
implementing algorithms, however, may be quite complex in order to achieve good performance.
In that order. Further, we always want LensKit to be *usable* in an easy
fashion. Code implementing commonly-used models, however, may be quite complex
in order to achieve good performance.

.. _iterative-training:

Iterative Training
~~~~~~~~~~~~~~~~~~

The :class:`lenskit.training.IterativeTraining` class provides a standardized
interface and training loop support for training models with iterative methods
that pass through the training data in multiple *epochs*. Models that use this
support extend :class:`~lenskit.training.IterativeTraining` in addition to
:class:`~lenskit.pipeline.Component`, and implement the
:meth:`~lenskit.training.IterativeTraining.training_loop` method instead of
:meth:`~lenskit.training.Trainable.train`. Iteratively-trainable components
should also have an ``epochs`` setting on their configuration class that
specifies the number of training epochs to run.
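
For illustration, a minimal component following this pattern might look like
the sketch below. This is a sketch under assumptions, not LensKit code: the
``MyScorer`` class, its toy training state, and the ``_loop`` helper are
hypothetical, and a real component would do actual model updates in the loop
body.

.. code:: python

    from collections.abc import Generator

    import numpy as np
    from pydantic import BaseModel

    from lenskit.data import Dataset, ItemList, QueryInput
    from lenskit.pipeline import Component
    from lenskit.training import IterativeTraining, TrainingOptions


    class MyScorerConfig(BaseModel):
        epochs: int = 10
        """Number of training epochs to run."""


    class MyScorer(IterativeTraining, Component[ItemList]):
        config: MyScorerConfig

        def training_loop(
            self, data: Dataset, options: TrainingOptions
        ) -> Generator[dict[str, float], None, None]:
            # do eager setup here, then hand off to the inner generator so
            # that preparation time is not charged to the first epoch
            rng = options.random_generator()
            state = rng.standard_normal(8)  # toy stand-in for model state
            return self._loop(state)

        def _loop(self, state) -> Generator[dict[str, float], None, None]:
            for epoch in range(self.config.epochs):
                # real per-epoch training work goes here
                loss = float(np.abs(state).mean())
                yield {"loss": loss}

        def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
            return items  # scoring logic omitted in this sketch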

The :meth:`~lenskit.training.IterativeTraining.training_loop` method does 3 things:

1. Set up the initial data structures and preparation needed for model training.
2. Train the model, yielding after each training epoch. It can optionally
yield a set of metrics, such as training loss or update magnitudes.
3. Perform any final steps and training data cleanup.

The model should be usable after each epoch, to support things like measuring
performance on validation data (see the usage sketch at the end of this
section).

The training loop itself is represented as a Python iterator, so that a ``for``
loop steps through the training epochs. While the interface definition
specifies the ``Iterator`` type in order to minimize restrictions on component
implementers, we recommend that it actually be a ``Generator``, which allows the
caller to request early termination (through the
:meth:`~collections.abc.Generator.close` method). We also recommend that the
``training_loop()`` method only return the generator after initial data preparation
is complete, so that setup time is not included in the time taken for the first
loop iteration. The easiest way to implement this is to delegate to an inner
loop function, written as a Python generator:

.. code:: python

    def training_loop(self, data: Dataset, options: TrainingOptions):
        # do initial data setup/prep for training
        context = ...
        # pass off to inner generator
        return self._training_loop_impl(context)

    def _training_loop_impl(self, context):
        for i in range(self.config.epochs):
            # do the model training
            # compute the metrics
            try:
                yield {'loss': loss}
            except GeneratorExit:
                # client code has requested early termination
                break

        # any final cleanup steps
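
Client code that wants to control the process can drive the loop directly, for
example to measure validation performance after each epoch and stop early. The
following is a sketch only; ``model``, ``data``, ``valid_data``, and the
``evaluate`` function are hypothetical placeholders, not LensKit APIs:

.. code:: python

    from lenskit.training import TrainingOptions

    loop = model.training_loop(data, TrainingOptions())
    best = float("-inf")
    for epoch, metrics in enumerate(loop, 1):
        metrics = metrics or {}  # an epoch may yield None instead of metrics
        score = evaluate(model, valid_data)  # hypothetical validation measure
        print(f"epoch {epoch}: train={metrics}, validation={score:.3f}")
        if score <= best:
            loop.close()  # throws GeneratorExit into the loop to finish early
        else:
            best = score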
83 changes: 35 additions & 48 deletions lenskit/lenskit/als/_common.py
@@ -7,21 +7,20 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Literal, TypeAlias

import numpy as np
import structlog
import torch
from pydantic import BaseModel
from typing_extensions import Iterator, NamedTuple, Self, override
from typing_extensions import NamedTuple, override

from lenskit import util
from lenskit.data import Dataset, ItemList, QueryInput, RecQuery, Vocabulary
from lenskit.data.types import UIPair
from lenskit.logging import item_progress
from lenskit.parallel.config import ensure_parallel_init
from lenskit.pipeline import Component
from lenskit.training import Trainable, TrainingOptions
from lenskit.training import IterativeTraining, TrainingOptions

EntityClass: TypeAlias = Literal["user", "item"]

@@ -126,7 +125,7 @@ def to(self, device):
return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))


class ALSBase(ABC, Component[ItemList], Trainable):
class ALSBase(IterativeTraining, Component[ItemList], ABC):
"""
Base class for ALS models.
@@ -144,7 +143,9 @@ class ALSBase(ABC, Component[ItemList], Trainable):
logger: structlog.stdlib.BoundLogger

@override
def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
def training_loop(
self, data: Dataset, options: TrainingOptions
) -> Generator[dict[str, float], None, None]:
"""
Run ALS to train a model.
@@ -154,49 +155,33 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) ->
Returns:
    A generator that yields a dictionary of training metrics after each epoch.
"""
if hasattr(self, "item_features_") and not options.retrain:
return False

ensure_parallel_init()
timer = util.Stopwatch()

rng = options.random_generator()

train = self.prepare_data(data)
self.users_ = train.users
self.items_ = train.items

self.initialize_params(train, rng)

return self._training_loop_generator(train)

for algo in self.fit_iters(data, options):
pass # we just need to do the iterations

if self.user_features_ is not None:
self.logger.info(
"trained model in %s (|P|=%f, |Q|=%f)",
timer,
torch.norm(self.user_features_, "fro"),
torch.norm(self.item_features_, "fro"),
features=self.config.features,
)
else:
self.logger.info(
"trained model in %s (|Q|=%f)",
timer,
torch.norm(self.item_features_, "fro"),
features=self.config.features,
)

return True

def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
def _training_loop_generator(
self, train: TrainingData
) -> Generator[dict[str, float], None, None]:
"""
Run ALS to train a model, yielding metrics after each training epoch.

Args:
    train: the prepared training data.
"""

log = self.logger = self.logger.bind(features=self.config.features)
rng = options.random_generator()

train = self.prepare_data(data)
self.users_ = train.users
self.items_ = train.items

self.initialize_params(train, rng)

assert self.user_features_ is not None
assert self.item_features_ is not None
@@ -207,27 +192,26 @@ def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
"item", train.iu_rates, self.item_features_, self.user_features_, self.config.item_reg
)

log.info("beginning ALS model training")

with item_progress("Training ALS", self.config.epochs) as epb:
for epoch in range(self.config.epochs):
log = log.bind(epoch=epoch)
epoch = epoch + 1
for epoch in range(self.config.epochs):
log = log.bind(epoch=epoch)
epoch = epoch + 1

du = self.als_half_epoch(epoch, u_ctx)
log.debug("finished user epoch")
du = self.als_half_epoch(epoch, u_ctx)
log.debug("finished user epoch")

di = self.als_half_epoch(epoch, i_ctx)
log.debug("finished item epoch")
di = self.als_half_epoch(epoch, i_ctx)
log.debug("finished item epoch")

log.info("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
epb.update()
yield self
log.debug("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
yield {"deltaP": du, "deltaQ": di}

if not self.config.save_user_features:
self.user_features_ = None
self.users_ = None

log.debug("finalizing model training")
self.finalize_training()

@abstractmethod
def prepare_data(self, data: Dataset) -> TrainingData: # pragma: no cover
"""
@@ -270,6 +254,9 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float: # pragma:
"""
...

def finalize_training(self):
pass

@override
def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
query = RecQuery.create(query)
15 changes: 6 additions & 9 deletions lenskit/lenskit/als/_implicit.py
@@ -17,7 +17,6 @@
from lenskit.logging.progress import item_progress_handle, pbh_update
from lenskit.math.solve import solve_cholesky
from lenskit.parallel.chunking import WorkChunks
from lenskit.training import TrainingOptions

from ._common import ALSBase, ALSConfig, TrainContext, TrainingData

@@ -71,14 +70,6 @@ class ImplicitMFScorer(ALSBase):

OtOr_: torch.Tensor

@override
def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()):
if super().train(data, options):
# compute OtOr and save it on the model
reg = self.config.user_reg
self.OtOr_ = _implicit_otor(self.item_features_, reg)
return True

@override
def prepare_data(self, data: Dataset) -> TrainingData:
if self.config.use_ratings:
@@ -109,6 +100,12 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float:
with item_progress_handle(f"epoch {epoch} {context.label}s", total=context.nrows) as pbh:
return _train_implicit_cholesky_fanout(context, OtOr, chunks, pbh)

def finalize_training(self):
# compute OtOr and save it on the model
reg = self.config.user_reg
self.OtOr_ = _implicit_otor(self.item_features_, reg)

@override
def new_user_embedding(
self, user_num: int | None, user_items: ItemList
4 changes: 3 additions & 1 deletion lenskit/lenskit/logging/progress/_dispatch.py
@@ -36,7 +36,9 @@ def set_progress_impl(name: str | None, *options: Any):
raise ValueError(f"unknown progress backend {name}")


def item_progress(label: str, total: int, fields: dict[str, str | None] | None = None) -> Progress:
def item_progress(
label: str, total: int | None = None, fields: dict[str, str | None] | None = None
) -> Progress:
"""
Create a progress bar for distinct, counted items.
Expand Down
85 changes: 81 additions & 4 deletions lenskit/lenskit/training.py
Expand Up @@ -10,17 +10,19 @@
# pyright: strict
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import (
Protocol,
runtime_checkable,
)
from typing import Protocol, runtime_checkable

import numpy as np

from lenskit.data.dataset import Dataset
from lenskit.logging import get_logger, item_progress
from lenskit.random import RNGInput, random_generator

_log = get_logger(__name__)


@dataclass(frozen=True)
class TrainingOptions:
@@ -94,3 +96,78 @@ def train(self, data: Dataset, options: TrainingOptions) -> None:
The training options.
"""
raise NotImplementedError()


class IterativeTraining(ABC, Trainable):
"""
Base class for components that support iterative training. This both
automates the :meth:`Trainable.train` method for iterative training in terms
of initialization, epoch, and finalization methods, and exposes those
methods to client code that may wish to directly control the iterative
training process.

Stability:
Full
"""

trained_epochs: int = 0
"""
The number of epochs for which this model has been trained.
"""

@property
def expected_training_epochs(self) -> int | None:
"""
Get the number of training epochs expected to run. The default
implementation looks for an ``epochs`` attribute on the configuration
object (``self.config``).
"""
cfg = getattr(self, "config", None)
if cfg:
return getattr(cfg, "epochs", None)

def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> None:
"""
Implementation of :meth:`Trainable.train` that uses the training loop.
It also uses the :attr:`trained_epochs` attribute to detect if the model
has already been trained for the purposes of honoring
:attr:`TrainingOptions.retrain`, and updates that attribute as model
training progresses.
"""
if self.trained_epochs > 0 and not options.retrain:
return

self.trained_epochs = 0
log = _log.bind(model=f"{self.__class__.__module__}.{self.__class__.__qualname__}")
log.info("training model")
n = self.expected_training_epochs
log.debug("creating training loop")
loop = self.training_loop(data, options)
log.debug("beginning training iterations")
with item_progress("Training iterations", total=n) as pb:
for i, metrics in enumerate(loop, 1):
metrics = metrics or {}
log.info("finished epoch", epoch=i, **metrics)
self.trained_epochs += 1
pb.update()

log.info("model training finished", epochs=self.trained_epochs)

@abstractmethod
def training_loop(
self, data: Dataset, options: TrainingOptions
) -> Iterator[dict[str, float] | None]:
"""
Training loop implementation, to be supplied by the derived class. This
method should return an iterator that, when iterated, will perform each
training epoch; when training is complete, it should finalize the model
and signal iteration completion.

Each epoch can yield metrics, such as training or validation loss, to be
logged with structured logging and used by calling code for further
analysis.

See :ref:`iterative-training` for more details on writing iterative
training loops.
"""
raise NotImplementedError()
