Skip to content

Commit

Permalink
Use narwhals for formula materialization
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Dec 18, 2024
1 parent 56672d0 commit 70496f3
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 68 deletions.
56 changes: 21 additions & 35 deletions src/tabmat/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from collections.abc import Iterable
from typing import Any, Optional, Union

import narwhals.stable.v1 as nw
import numpy
import pandas
from formulaic import ModelMatrix, ModelSpec
from formulaic.errors import FactorEncodingError
from formulaic.materializers import FormulaMaterializer
from formulaic.materializers.types import FactorValues, NAAction, ScopedTerm
from formulaic.materializers import NarwhalsMaterializer
from formulaic.materializers.types import FactorValues, ScopedTerm
from formulaic.parser.types import Term
from formulaic.transforms import stateful_transform
from formulaic.utils.null_handling import drop_rows as drop_nulls
from interface_meta import override
from scipy import sparse as sps

Expand All @@ -29,7 +31,7 @@
from formulaic.materializers.types.formula_materializer import EncodedTermStructure


class TabmatMaterializer(FormulaMaterializer):
class TabmatMaterializer(NarwhalsMaterializer):
"""Materializer for pandas input and tabmat output."""

REGISTER_NAME = "tabmat"
Expand All @@ -52,34 +54,17 @@ def _init(self):
self.cat_missing_method = self.params.get("cat_missing_method", "fail")
self.cat_missing_name = self.params.get("cat_missing_name", "(MISSING)")

# Always convert input to narwhals DataFrame
self.__narwhals_data = nw.from_native(self.data, eager_only=True)
self.__data_context = self.__narwhals_data.to_dict()

# We can override formulaic's C() function here
self.context["C"] = _C

@override
def _is_categorical(self, values):
if isinstance(values, (pandas.Series, pandas.Categorical)):
return values.dtype == object or isinstance(
values.dtype, pandas.CategoricalDtype
)
return super()._is_categorical(values)

@override
def _check_for_nulls(self, name, values, na_action, drop_rows):
if na_action is NAAction.IGNORE:
return

if na_action is NAAction.RAISE:
if isinstance(values, pandas.Series) and values.isnull().values.any():
raise ValueError(f"`{name}` contains null values after evaluation.")

elif na_action is NAAction.DROP:
if isinstance(values, pandas.Series):
drop_rows.update(numpy.flatnonzero(values.isnull().values))

else:
raise ValueError(
f"Do not know how to interpret `na_action` = {repr(na_action)}."
)
@override # type: ignore
@property
def data_context(self):
return self.__data_context

@override
def _encode_constant(self, value, metadata, encoder_state, spec, drop_rows):
Expand All @@ -89,7 +74,7 @@ def _encode_constant(self, value, metadata, encoder_state, spec, drop_rows):
@override
def _encode_numerical(self, values, metadata, encoder_state, spec, drop_rows):
if drop_rows:
values = values.drop(index=values.index[drop_rows])
values = drop_nulls(values, indices=drop_rows)
if isinstance(values, pandas.Series):
values = values.to_numpy().astype(self.dtype)
if (values != 0).mean() <= self.sparse_threshold:
Expand All @@ -103,7 +88,7 @@ def _encode_categorical(
):
# We do not do any encoding here as it is handled by tabmat
if drop_rows:
values = values.drop(index=values.index[drop_rows])
values = drop_nulls(values, indices=drop_rows)
return encode_contrasts(
values,
reduced_rank=reduced_rank,
Expand Down Expand Up @@ -665,7 +650,7 @@ def _C(
data,
*,
levels: Optional[Iterable[str]] = None,
missing_method: str = "fail",
missing_method: Optional[str] = None,
missing_name: str = "(MISSING)",
spans_intercept: bool = True,
):
Expand All @@ -685,12 +670,13 @@ def encoder(
model_spec: ModelSpec,
):
if drop_rows:
values = values.drop(index=values.index[drop_rows])
values = drop_nulls(values, indices=drop_rows)
return encode_contrasts(
values,
levels=levels,
reduced_rank=reduced_rank,
missing_method=missing_method,
missing_method=missing_method
or model_spec.materializer_params.get("cat_missing_method", "fail"), # type: ignore
missing_name=missing_name,
_state=encoder_state,
_spec=model_spec,
Expand Down Expand Up @@ -737,14 +723,14 @@ def encode_contrasts(
# - missings are no problem in the other cases
unseen_categories = set(data.unique()) - set(levels)
else:
unseen_categories = set(data.dropna().unique()) - set(levels)
unseen_categories = set(data.drop_nulls().unique()) - set(levels)

if unseen_categories:
raise ValueError(
f"Column {data.name} contains unseen categories: {unseen_categories}."
)

cat = pandas.Categorical(data._values, categories=levels)
cat = pandas.Categorical(data.to_pandas()._values, categories=levels)
_state["categories"] = cat.categories
_state["add_missing_category"] = add_missing_category or (
missing_method == "convert" and cat.isna().any()
Expand Down
Loading

0 comments on commit 70496f3

Please sign in to comment.