[#161, #162] Refactor training step 3 to reduce duplication
I'm still not entirely happy with this, but it's a tricky point in the code:
most of the models behave one way, while xgboost and lightgbm are different.
Some more refactoring might be in order.
riley-harper committed Nov 21, 2024
1 parent 5ef0879 commit 010f46a
Showing 1 changed file with 35 additions and 13 deletions.
48 changes: 35 additions & 13 deletions hlink/linking/training/link_step_save_model_metadata.py
@@ -3,6 +3,14 @@
 # in this project's top-level directory, and also on-line at:
 # https://github.com/ipums/hlink
 
+from pyspark.sql.types import (
+    FloatType,
+    IntegerType,
+    StringType,
+    StructField,
+    StructType,
+)
+
 from hlink.linking.link_step import LinkStep


@@ -100,20 +108,20 @@ def _run(self):
             gains = [raw_gains.get(key, 0.0) for key in keys]
             label = "Feature importances (weights and gains)"
 
-            features_df = self.task.spark.createDataFrame(
-                zip(true_column_names, true_categories, weights, gains),
-                "feature_name: string, category: int, weight: double, gain: double",
-            ).sort("feature_name", "category")
+            importance_columns = [
+                (StructField("weight", FloatType(), nullable=False), weights),
+                (StructField("gain", FloatType(), nullable=False), gains),
+            ]
         elif model_type == "lightgbm":
             # The "weight" of a feature is the number of splits it causes.
             weights = model.getFeatureImportances("split")
             gains = model.getFeatureImportances("gain")
             label = "Feature importances (weights and gains)"
 
-            features_df = self.task.spark.createDataFrame(
-                zip(true_column_names, true_categories, weights, gains),
-                "feature_name: string, category: int, weight: double, gain: double",
-            ).sort("feature_name", "category")
+            importance_columns = [
+                (StructField("weight", FloatType(), nullable=False), weights),
+                (StructField("gain", FloatType(), nullable=False), gains),
+            ]
         else:
             try:
                 feature_imp = model.coefficients
@@ -135,13 +143,27 @@
             feature_importances = [
                 float(importance) for importance in feature_imp.toArray()
             ]
-            features_df = self.task.spark.createDataFrame(
-                zip(
-                    true_column_names, true_categories, feature_importances, strict=True
-                ),
-                "feature_name: string, category: int, coefficient_or_importance: double",
-            ).sort("feature_name", "category")
+
+            importance_columns = [
+                (
+                    StructField(
+                        "coefficient_or_importance", FloatType(), nullable=False
+                    ),
+                    feature_importances,
+                ),
+            ]
+
+        importance_schema, importance_data = zip(*importance_columns)
+        features_df = self.task.spark.createDataFrame(
+            zip(true_column_names, true_categories, *importance_data, strict=True),
+            StructType(
+                [
+                    StructField("feature_name", StringType(), nullable=False),
+                    StructField("category", IntegerType(), nullable=True),
+                    *importance_schema,
+                ]
+            ),
+        ).sort("feature_name", "category")
         feature_importances_table = (
             f"{self.task.table_prefix}training_feature_importances"
         )
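The refactor converges the three model-specific branches on a single
createDataFrame call fed by importance_columns. Below is a minimal,
self-contained sketch of that pattern, not hlink's actual code: the
SparkSession setup, feature names, and importance values are hypothetical
stand-ins for what _run() computes earlier in the step (Python 3.10+ for
zip(..., strict=True)).

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Stand-ins for the values the real step derives from the trained model.
true_column_names = ["age", "birthyr", "namefrst_jw"]
true_categories = [None, None, None]
weights = [3.0, 1.0, 7.0]
gains = [0.2, 0.05, 0.9]

# Branch-specific part: a tree model contributes two importance columns.
# The generic branch would instead supply one coefficient_or_importance pair.
importance_columns = [
    (StructField("weight", FloatType(), nullable=False), weights),
    (StructField("gain", FloatType(), nullable=False), gains),
]

# Shared part: unzip the (schema, data) pairs and build one DataFrame,
# regardless of how many importance columns the branch produced.
importance_schema, importance_data = zip(*importance_columns)
features_df = spark.createDataFrame(
    zip(true_column_names, true_categories, *importance_data, strict=True),
    StructType(
        [
            StructField("feature_name", StringType(), nullable=False),
            StructField("category", IntegerType(), nullable=True),
            *importance_schema,
        ]
    ),
).sort("feature_name", "category")
features_df.show()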

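On the commit message's point that xgboost and lightgbm differ from the other
models: judging only from the diff, xgboost reports raw importances as a dict
(hence raw_gains.get(key, 0.0)), lightgbm returns them positionally from
getFeatureImportances, and the remaining models expose a single vector such as
model.coefficients, which is why that branch writes one
coefficient_or_importance column instead of weight and gain. A tiny
hypothetical illustration of the dict-defaulting case:

# Not hlink code: assumed shape of a raw xgboost importance report, where
# features that never cause a split are simply absent from the dict.
raw_weights = {"f0": 3.0, "f2": 7.0}
keys = ["f0", "f1", "f2"]  # one key per feature in the model
weights = [raw_weights.get(key, 0.0) for key in keys]
assert weights == [3.0, 0.0, 7.0]  # "f1" never split, so it falls back to 0.0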