Skip to content

Commit 5ef0879

Browse files
committed
[#161, #162] Unify feature importances for XGBoost and LightGBM
We now compute two feature importances for each model. - weight: the number of splits that each feature causes - gain: the total gain across all of each feature's splits
1 parent 7f7afe7 commit 5ef0879

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

hlink/linking/training/link_step_save_model_metadata.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _run(self):
9393

9494
if model_type == "xgboost":
9595
raw_weights = model.get_feature_importances("weight")
96-
raw_gains = model.get_feature_importances("gain")
96+
raw_gains = model.get_feature_importances("total_gain")
9797
keys = [f"f{index}" for index in range(len(true_cols))]
9898

9999
weights = [raw_weights.get(key, 0.0) for key in keys]
@@ -102,16 +102,17 @@ def _run(self):
102102

103103
features_df = self.task.spark.createDataFrame(
104104
zip(true_column_names, true_categories, weights, gains),
105-
"feature_name: string, category: int, weight: double, average_gain_per_split: double",
105+
"feature_name: string, category: int, weight: double, gain: double",
106106
).sort("feature_name", "category")
107107
elif model_type == "lightgbm":
108-
num_splits = model.getFeatureImportances("split")
109-
total_gains = model.getFeatureImportances("gain")
110-
label = "Feature importances (number of splits and total gains)"
108+
# The "weight" of a feature is the number of splits it causes.
109+
weights = model.getFeatureImportances("split")
110+
gains = model.getFeatureImportances("gain")
111+
label = "Feature importances (weights and gains)"
111112

112113
features_df = self.task.spark.createDataFrame(
113-
zip(true_column_names, true_categories, num_splits, total_gains),
114-
"feature_name: string, category: int, num_splits: double, total_gain: double",
114+
zip(true_column_names, true_categories, weights, gains),
115+
"feature_name: string, category: int, weight: double, gain: double",
115116
).sort("feature_name", "category")
116117
else:
117118
try:

hlink/tests/training_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,8 @@ def test_step_3_with_lightgbm_model(
490490
assert importances_df.columns == [
491491
"feature_name",
492492
"category",
493-
"num_splits",
494-
"total_gain",
493+
"weight",
494+
"gain",
495495
]
496496

497497

@@ -548,8 +548,8 @@ def test_lightgbm_with_interacted_features(
548548
assert importances_df.columns == [
549549
"feature_name",
550550
"category",
551-
"num_splits",
552-
"total_gain",
551+
"weight",
552+
"gain",
553553
]
554554

555555

@@ -601,7 +601,7 @@ def test_step_3_with_xgboost_model(
601601
"feature_name",
602602
"category",
603603
"weight",
604-
"average_gain_per_split",
604+
"gain",
605605
]
606606

607607

0 commit comments

Comments
 (0)