[#162] Add a rough draft of a RenameVectorAttributes transformer
Usually we don't care about the names of the vector attributes. But LightGBM
uses them as feature names and disallows some characters in the names.
Unfortunately, one of these characters is the colon (:), and Spark's Interaction
names the output of an interaction between A and B "A:B". I looked through the
Spark code and didn't see any way to configure the names of these output
features. So I think the easiest way forward here is to make a transformer that
renames the attributes of a vector by replacing some characters with another.
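
As a rough illustration of the renaming described above, here is a minimal sketch of the per-name replacement in plain Python. The function name `rename_attribute` and its parameters `chars_to_replace` and `replacement` are hypothetical and do not appear in this commit; the transformer's eventual interface may well look different.

```python
def rename_attribute(
    name: str, chars_to_replace: str = ":", replacement: str = "_"
) -> str:
    """Replace each character in chars_to_replace with replacement.

    Hypothetical helper for illustration; not part of this commit.
    """
    # Build a translation table mapping every unwanted character
    # to the replacement string.
    table = str.maketrans({char: replacement for char in chars_to_replace})
    return name.translate(table)


# The name Spark's Interaction gives to an interaction between A and B
print(rename_attribute("A:B"))
# The attribute name used in this commit's test
print(rename_attribute("regionf_0:namelast_jw"))
```

With the defaults assumed here, only colons are replaced, which matches the LightGBM case motivating the commit.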
riley-harper committed Nov 20, 2024
1 parent 72fd83c commit 062ad63
Showing 2 changed files with 41 additions and 0 deletions.
21 changes: 21 additions & 0 deletions hlink/linking/transformers/rename_vector_attributes.py
@@ -0,0 +1,21 @@
from pyspark.ml import Transformer
from pyspark.sql import DataFrame


class RenameVectorAttributes(Transformer):
    """
    A custom transformer which renames the attributes or "slot names" of a
    given input column of type vector. This is helpful when you don't have
    complete control over the names of the attributes, but you need them to
    look a certain way.

    For example, LightGBM can't handle vector attributes with colons in their
    names. But the Spark Interaction class creates vector attributes named with
    colons. So we need to rename the attributes and remove the colons before
    passing the feature vector to LightGBM for training.
    """

    def __init__(self) -> None: ...

    def _transform(self, dataset: DataFrame) -> DataFrame:
        return dataset
20 changes: 20 additions & 0 deletions hlink/tests/transformers_test.py
@@ -0,0 +1,20 @@
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

from hlink.linking.transformers.rename_vector_attributes import RenameVectorAttributes


def test_rename_vector_attributes(spark: SparkSession) -> None:
    df = spark.createDataFrame(
        [[0.0, 1.0], [1.0, 2.0], [3.0, 4.0]], schema=["A", "regionf_0:namelast_jw"]
    )

    assembler = VectorAssembler(
        inputCols=["A", "regionf_0:namelast_jw"], outputCol="vectorized"
    )
    remove_colons = RenameVectorAttributes()
    transformed = remove_colons.transform(assembler.transform(df))

    attrs = transformed.schema["vectorized"].metadata["ml_attr"]["attrs"]["numeric"]
    attr_names = [attr["name"] for attr in attrs]
    assert attr_names == ["A", "regionf_0_namelast_jw"]
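
The `_transform` in this commit is still a stub that returns the dataset unchanged, so the test above describes the intended behavior, not the current one. As a sketch of what the renaming step might involve, the plain-Python function below rewrites attribute names in a metadata dictionary shaped like the one the test reads (`metadata["ml_attr"]["attrs"]["numeric"]`). The function `rename_ml_attrs` and its `old`/`new` parameters are hypothetical, and the sample metadata is hand-written for illustration, not captured from Spark.

```python
import copy
from typing import Any


def rename_ml_attrs(
    metadata: dict[str, Any], old: str = ":", new: str = "_"
) -> dict[str, Any]:
    """Return a copy of the column metadata with `old` replaced by `new`
    in every attribute name. Hypothetical helper, not part of this commit.
    """
    # Copy first so the caller's metadata is left untouched.
    renamed = copy.deepcopy(metadata)
    # Attributes are grouped by type (for example "numeric"); visit every group.
    attr_groups = renamed.get("ml_attr", {}).get("attrs", {})
    for attrs in attr_groups.values():
        for attr in attrs:
            if "name" in attr:
                attr["name"] = attr["name"].replace(old, new)
    return renamed


# Hand-written metadata in the shape the test reads from the "vectorized" column
metadata = {
    "ml_attr": {
        "attrs": {
            "numeric": [
                {"idx": 0, "name": "A"},
                {"idx": 1, "name": "regionf_0:namelast_jw"},
            ]
        },
        "num_attrs": 2,
    }
}

renamed = rename_ml_attrs(metadata)
attr_names = [attr["name"] for attr in renamed["ml_attr"]["attrs"]["numeric"]]
print(attr_names)
```

Inside a real transformer, the renamed metadata would still need to be reattached to the vector column, for example with `DataFrame.withMetadata` (available in Spark 3.3 and later) or `Column.alias(..., metadata=...)`. Which approach fits hlink best is an open question this sketch does not settle.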
