-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[#162] Add a rough draft of a RenameVectorAttributes transformer
Usually we don't care about the names of the vector attributes. But LightGBM uses them as feature names and disallows some characters in the names. Unfortunately, one of these characters is :, and Spark's Interaction names the output of an interaction between A and B "A:B". I looked through the Spark code and didn't see any way to configure the names of these output features. So I think the easiest way forward here is to make a transformer that renames the attributes of a vector by removing some characters and replacing them with another.
- Loading branch information
1 parent
72fd83c
commit 062ad63
Showing
2 changed files
with
41 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from pyspark.ml import Transformer | ||
from pyspark.sql import DataFrame | ||
|
||
|
||
class RenameVectorAttributes(Transformer): | ||
""" | ||
A custom transformer which renames the attributes or "slot names" of a | ||
given input column of type vector. This is helpful when you don't have | ||
complete control over the names of the attributes, but you need them to | ||
look a certain way. | ||
For example, LightGBM can't handle vector attributes with colons in their | ||
names. But the Spark Interaction class creates vector attributes named with | ||
colons. So we need to rename the attributes and remove the colons before | ||
passing the feature vector to LightGBM for training. | ||
""" | ||
|
||
def __init__(self) -> None: ... | ||
|
||
def _transform(self, dataset: DataFrame) -> DataFrame: | ||
return dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from pyspark.sql import SparkSession | ||
from pyspark.ml.feature import VectorAssembler | ||
|
||
from hlink.linking.transformers.rename_vector_attributes import RenameVectorAttributes | ||
|
||
|
||
def test_rename_vector_attributes(spark: SparkSession) -> None: | ||
df = spark.createDataFrame( | ||
[[0.0, 1.0], [1.0, 2.0], [3.0, 4.0]], schema=["A", "regionf_0:namelast_jw"] | ||
) | ||
|
||
assembler = VectorAssembler( | ||
inputCols=["A", "regionf_0:namelast_jw"], outputCol="vectorized" | ||
) | ||
remove_colons = RenameVectorAttributes() | ||
transformed = remove_colons.transform(assembler.transform(df)) | ||
|
||
attrs = transformed.schema["vectorized"].metadata["ml_attr"]["attrs"]["numeric"] | ||
attr_names = [attr["name"] for attr in attrs] | ||
assert attr_names == ["A", "regionf_0_namelast_jw"] |