diff --git a/src/tabmat/formula.py b/src/tabmat/formula.py index dc8ce6b6..ed7bd2dc 100644 --- a/src/tabmat/formula.py +++ b/src/tabmat/formula.py @@ -10,7 +10,7 @@ import pandas from formulaic import ModelMatrix, ModelSpec from formulaic.errors import FactorEncodingError -from formulaic.materializers import NarwhalsMaterializer +from formulaic.materializers import FormulaMaterializer from formulaic.materializers.types import FactorValues, ScopedTerm from formulaic.parser.types import Term from formulaic.transforms import stateful_transform @@ -31,7 +31,7 @@ from formulaic.materializers.types.formula_materializer import EncodedTermStructure -class TabmatMaterializer(NarwhalsMaterializer): +class TabmatMaterializer(FormulaMaterializer): """Materializer for pandas input and tabmat output.""" REGISTER_NAME = "tabmat" @@ -66,6 +66,13 @@ def _init(self): def data_context(self): return self.__data_context + @override + def _is_categorical(self, values: Any) -> bool: + if nw.dependencies.is_narwhals_series(values): + if not values.dtype.is_numeric(): + return True + return super()._is_categorical(values) + @override def _encode_constant(self, value, metadata, encoder_state, spec, drop_rows): series = value * numpy.ones(self.nrows - len(drop_rows))