Skip to content

Commit

Permalink
[#161, #162] Unify feature importances for XGBoost and LightGBM
Browse files Browse the repository at this point in the history
We now compute two feature importances for each model.

- weight: the number of splits that each feature causes
- gain: the total gain across all of each feature's splits
  • Loading branch information
riley-harper committed Nov 21, 2024
1 parent 7f7afe7 commit 5ef0879
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
15 changes: 8 additions & 7 deletions hlink/linking/training/link_step_save_model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _run(self):

if model_type == "xgboost":
raw_weights = model.get_feature_importances("weight")
raw_gains = model.get_feature_importances("gain")
raw_gains = model.get_feature_importances("total_gain")
keys = [f"f{index}" for index in range(len(true_cols))]

weights = [raw_weights.get(key, 0.0) for key in keys]
Expand All @@ -102,16 +102,17 @@ def _run(self):

features_df = self.task.spark.createDataFrame(
zip(true_column_names, true_categories, weights, gains),
"feature_name: string, category: int, weight: double, average_gain_per_split: double",
"feature_name: string, category: int, weight: double, gain: double",
).sort("feature_name", "category")
elif model_type == "lightgbm":
num_splits = model.getFeatureImportances("split")
total_gains = model.getFeatureImportances("gain")
label = "Feature importances (number of splits and total gains)"
# The "weight" of a feature is the number of splits it causes.
weights = model.getFeatureImportances("split")
gains = model.getFeatureImportances("gain")
label = "Feature importances (weights and gains)"

features_df = self.task.spark.createDataFrame(
zip(true_column_names, true_categories, num_splits, total_gains),
"feature_name: string, category: int, num_splits: double, total_gain: double",
zip(true_column_names, true_categories, weights, gains),
"feature_name: string, category: int, weight: double, gain: double",
).sort("feature_name", "category")
else:
try:
Expand Down
10 changes: 5 additions & 5 deletions hlink/tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ def test_step_3_with_lightgbm_model(
assert importances_df.columns == [
"feature_name",
"category",
"num_splits",
"total_gain",
"weight",
"gain",
]


Expand Down Expand Up @@ -548,8 +548,8 @@ def test_lightgbm_with_interacted_features(
assert importances_df.columns == [
"feature_name",
"category",
"num_splits",
"total_gain",
"weight",
"gain",
]


Expand Down Expand Up @@ -601,7 +601,7 @@ def test_step_3_with_xgboost_model(
"feature_name",
"category",
"weight",
"average_gain_per_split",
"gain",
]


Expand Down

0 comments on commit 5ef0879

Please sign in to comment.