Skip to content

Commit 010f46a

Browse files
committed
[#161, #162] Refactor training step 3 to reduce duplication
I'm still not entirely happy with this, but it's a tricky point in the code: most of the models behave one way, while xgboost and lightgbm behave differently. Some more refactoring might be in order.
1 parent 5ef0879 commit 010f46a

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

hlink/linking/training/link_step_save_model_metadata.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
# in this project's top-level directory, and also on-line at:
44
# https://github.com/ipums/hlink
55

6+
from pyspark.sql.types import (
7+
FloatType,
8+
IntegerType,
9+
StringType,
10+
StructField,
11+
StructType,
12+
)
13+
614
from hlink.linking.link_step import LinkStep
715

816

@@ -100,20 +108,20 @@ def _run(self):
100108
gains = [raw_gains.get(key, 0.0) for key in keys]
101109
label = "Feature importances (weights and gains)"
102110

103-
features_df = self.task.spark.createDataFrame(
104-
zip(true_column_names, true_categories, weights, gains),
105-
"feature_name: string, category: int, weight: double, gain: double",
106-
).sort("feature_name", "category")
111+
importance_columns = [
112+
(StructField("weight", FloatType(), nullable=False), weights),
113+
(StructField("gain", FloatType(), nullable=False), gains),
114+
]
107115
elif model_type == "lightgbm":
108116
# The "weight" of a feature is the number of splits it causes.
109117
weights = model.getFeatureImportances("split")
110118
gains = model.getFeatureImportances("gain")
111119
label = "Feature importances (weights and gains)"
112120

113-
features_df = self.task.spark.createDataFrame(
114-
zip(true_column_names, true_categories, weights, gains),
115-
"feature_name: string, category: int, weight: double, gain: double",
116-
).sort("feature_name", "category")
121+
importance_columns = [
122+
(StructField("weight", FloatType(), nullable=False), weights),
123+
(StructField("gain", FloatType(), nullable=False), gains),
124+
]
117125
else:
118126
try:
119127
feature_imp = model.coefficients
@@ -135,13 +143,27 @@ def _run(self):
135143
feature_importances = [
136144
float(importance) for importance in feature_imp.toArray()
137145
]
138-
features_df = self.task.spark.createDataFrame(
139-
zip(
140-
true_column_names, true_categories, feature_importances, strict=True
146+
147+
importance_columns = [
148+
(
149+
StructField(
150+
"coefficient_or_importance", FloatType(), nullable=False
151+
),
152+
feature_importances,
141153
),
142-
"feature_name: string, category: int, coefficient_or_importance: double",
143-
).sort("feature_name", "category")
154+
]
144155

156+
importance_schema, importance_data = zip(*importance_columns)
157+
features_df = self.task.spark.createDataFrame(
158+
zip(true_column_names, true_categories, *importance_data, strict=True),
159+
StructType(
160+
[
161+
StructField("feature_name", StringType(), nullable=False),
162+
StructField("category", IntegerType(), nullable=True),
163+
*importance_schema,
164+
]
165+
),
166+
).sort("feature_name", "category")
145167
feature_importances_table = (
146168
f"{self.task.table_prefix}training_feature_importances"
147169
)

0 commit comments

Comments (0)