Skip to content

Commit

Permalink
[#162] Integrate LightGBM with training step 3
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 21, 2024
1 parent 34b1a26 commit 7f7afe7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions hlink/linking/training/link_step_save_model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ def _run(self):
zip(true_column_names, true_categories, weights, gains),
"feature_name: string, category: int, weight: double, average_gain_per_split: double",
).sort("feature_name", "category")
elif model_type == "lightgbm":
num_splits = model.getFeatureImportances("split")
total_gains = model.getFeatureImportances("gain")
label = "Feature importances (number of splits and total gains)"

features_df = self.task.spark.createDataFrame(
zip(true_column_names, true_categories, num_splits, total_gains),
"feature_name: string, category: int, num_splits: double, total_gain: double",
).sort("feature_name", "category")
else:
try:
feature_imp = model.coefficients
Expand Down
16 changes: 16 additions & 0 deletions hlink/tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,14 @@ def test_step_3_with_lightgbm_model(

training.run_all_steps()

importances_df = spark.table("training_feature_importances")
assert importances_df.columns == [
"feature_name",
"category",
"num_splits",
"total_gain",
]


@requires_lightgbm
def test_lightgbm_with_interacted_features(
Expand Down Expand Up @@ -536,6 +544,14 @@ def test_lightgbm_with_interacted_features(

training.run_all_steps()

importances_df = spark.table("training_feature_importances")
assert importances_df.columns == [
"feature_name",
"category",
"num_splits",
"total_gain",
]


@requires_xgboost
def test_step_3_with_xgboost_model(
Expand Down

0 comments on commit 7f7afe7

Please sign in to comment.