Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-add return type variable for Component #598

Merged
merged 2 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/guide/examples/blendcomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LinearBlendConfig(BaseModel):
"""


class LinearBlendScorer(Component):
class LinearBlendScorer(Component[ItemList]):
r"""
Score items with a linear blend of two other scores.

Expand Down
2 changes: 1 addition & 1 deletion lenskit-funksvd/lenskit/funksvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _align_add_bias(bias, index, keys, series):
return bias, series


class FunkSVDScorer(Trainable, Component):
class FunkSVDScorer(Trainable, Component[ItemList]):
"""
FunkSVD explicit-feedback matrix factoriation. FunkSVD is a regularized
biased matrix factorization technique trained with featurewise stochastic
Expand Down
2 changes: 1 addition & 1 deletion lenskit-hpf/lenskit/hpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_logger = logging.getLogger(__name__)


class HPFScorer(Component, Trainable):
class HPFScorer(Component[ItemList], Trainable):
"""
Hierarchical Poisson factorization, provided by
`hpfrec <https://hpfrec.readthedocs.io/en/latest/>`_.
Expand Down
2 changes: 1 addition & 1 deletion lenskit-implicit/lenskit/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ImplicitALSConfig(ImplicitConfig, extra="allow"):
weight: float = 40.0


class BaseRec(Component, Trainable):
class BaseRec(Component[ItemList], Trainable):
"""
Base class for Implicit-backed recommenders.

Expand Down
2 changes: 1 addition & 1 deletion lenskit-sklearn/lenskit/sklearn/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BiasedSVDConfig:
n_iter: int = 5


class BiasedSVDScorer(Component, Trainable):
class BiasedSVDScorer(Component[ItemList], Trainable):
"""
Biased matrix factorization for explicit feedback using SciKit-Learn's
:class:`~sklearn.decomposition.TruncatedSVD`. It operates by first
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/als/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def to(self, device):
return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))


class ALSBase(ABC, Component, Trainable):
class ALSBase(ABC, Component[ItemList], Trainable):
"""
Base class for ALS models.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def entity_damping(self, entity: Literal["user", "item"]) -> float:
return entity_damping(self.damping, entity)


class BiasScorer(Component):
class BiasScorer(Component[ItemList]):
"""
A user-item bias rating prediction model. This component uses
:class:`BiasModel` to predict ratings for users and items.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
_logger = logging.getLogger(__name__)


class TrainingCandidateSelectorBase(Component, Trainable):
class TrainingCandidateSelectorBase(Component[ItemList], Trainable):
"""
Base class for candidate selectors using the training data.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
_logger = logging.getLogger(__name__)


class FallbackScorer(Component):
class FallbackScorer(Component[ItemList]):
"""
Scoring component that fills in missing scores using a fallback.

Expand Down
4 changes: 2 additions & 2 deletions lenskit/lenskit/basic/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
_logger = logging.getLogger(__name__)


class UserTrainingHistoryLookup(Component, Trainable):
class UserTrainingHistoryLookup(Component[ItemList], Trainable):
"""
Look up a user's history from the training data.

Expand Down Expand Up @@ -57,7 +57,7 @@ def __str__(self):
return self.__class__.__name__


class KnownRatingScorer(Component, Trainable):
class KnownRatingScorer(Component[ItemList], Trainable):
"""
Score items by returning their values from the training data.

Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/popularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PopConfig(BaseModel):
"""


class PopScorer(Component, Trainable):
class PopScorer(Component[ItemList], Trainable):
"""
Score items by their popularity. Use with :py:class:`TopN` to get a
most-popular-items recommender.
Expand Down
4 changes: 2 additions & 2 deletions lenskit/lenskit/basic/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class RandomConfig(BaseModel, arbitrary_types_allowed=True):
"""


class RandomSelector(Component):
class RandomSelector(Component[ItemList]):
"""
Randomly select items from a candidate list.

Expand Down Expand Up @@ -74,7 +74,7 @@ def __call__(
return items[np.zeros(0, dtype=np.int32)]


class SoftmaxRanker(Component):
class SoftmaxRanker(Component[ItemList]):
"""
Stochastic top-N ranking with softmax sampling.

Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/topn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TopNConfig(BaseModel):
"""


class TopNRanker(Component):
class TopNRanker(Component[ItemList]):
"""
Rank scored items by their score and take the top *N*. The ranking length
can be passed either at runtime or at component instantiation time, with the
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/knn/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def explicit(self) -> bool:
return self.feedback == "explicit"


class ItemKNNScorer(Component, Trainable):
class ItemKNNScorer(Component[ItemList], Trainable):
"""
Item-item nearest-neighbor collaborative filtering feedback. This item-item
implementation is based on the description of item-based CF by
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/knn/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def explicit(self) -> bool:
return self.feedback == "explicit"


class UserKNNScorer(Component, Trainable):
class UserKNNScorer(Component[ItemList], Trainable):
"""
User-user nearest-neighbor collaborative filtering with ratings. This
user-user implementation is not terribly configurable; it hard-codes design
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/pipeline/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Pipeline:
_nodes: dict[str, Node[Any]]
_aliases: dict[str, Node[Any]]
_defaults: dict[str, Node[Any]]
_components: dict[str, PipelineFunction[Any]]
_components: dict[str, PipelineFunction[Any] | Component[Any]]
_hash: str | None = None
_last: Node[Any] | None = None
_anon_nodes: set[str]
Expand Down
8 changes: 4 additions & 4 deletions lenskit/lenskit/pipeline/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
Any,
Callable,
Generic,
Mapping,
ParamSpec,
Protocol,
Expand All @@ -35,7 +36,6 @@

P = ParamSpec("P")
T = TypeVar("T")
Cfg = TypeVar("Cfg")
# COut is only return, so Component[U] can be assigned to Component[T] if U ≼ T.
COut = TypeVar("COut", covariant=True)
PipelineFunction: TypeAlias = Callable[..., COut]
Expand Down Expand Up @@ -130,7 +130,7 @@ def load_params(self, params: dict[str, object]) -> None:
raise NotImplementedError()


class Component:
class Component(Generic[COut]):
"""
Base class for pipeline component objects. Any component that is not just a
function should extend this class.
Expand Down Expand Up @@ -260,7 +260,7 @@ def __repr__(self) -> str:


def instantiate_component(
comp: str | type | FunctionType, config: dict[str, Any] | None
comp: str | type | FunctionType, config: Mapping[str, Any] | None
) -> Callable[..., object]:
"""
Utility function to instantiate a component given its class, function, or
Expand All @@ -281,7 +281,7 @@ def instantiate_component(
return comp
elif issubclass(comp, Component):
cfg = comp.validate_config(config)
return comp(cfg)
return comp(cfg) # type: ignore
else: # pragma: nocover
return comp() # type: ignore

Expand Down
6 changes: 3 additions & 3 deletions lenskit/tests/pipeline/test_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ class PrefixConfigPYDC:
prefix: str = "UNDEFINED"


class PrefixerDC(Component):
class PrefixerDC(Component[str]):
config: PrefixConfigDC

def __call__(self, msg: str) -> str:
return self.config.prefix + msg


class PrefixerM(Component):
class PrefixerM(Component[str]):
config: PrefixConfigM

def __call__(self, msg: str) -> str:
Expand All @@ -51,7 +51,7 @@ class PrefixerM2(PrefixerM):
config: PrefixConfigM


class PrefixerPYDC(Component):
class PrefixerPYDC(Component[str]):
config: PrefixConfigPYDC

def __call__(self, msg: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion lenskit/tests/pipeline/test_pipeline_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PrefixConfig:
prefix: str


class Prefixer(Component):
class Prefixer(Component[str]):
config: PrefixConfig

def __call__(self, msg: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion lenskit/tests/pipeline/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class PrefixConfig:
prefix: str


class Prefixer(Component):
class Prefixer(Component[str]):
config: PrefixConfig

def __call__(self, msg: str) -> str:
Expand Down
Loading