3
3
# in this project's top-level directory, and also on-line at:
4
4
# https://github.com/ipums/hlink
5
5
6
+ from pyspark .sql .types import (
7
+ FloatType ,
8
+ IntegerType ,
9
+ StringType ,
10
+ StructField ,
11
+ StructType ,
12
+ )
13
+
6
14
from hlink .linking .link_step import LinkStep
7
15
8
16
@@ -100,20 +108,20 @@ def _run(self):
100
108
gains = [raw_gains .get (key , 0.0 ) for key in keys ]
101
109
label = "Feature importances (weights and gains)"
102
110
103
- features_df = self . task . spark . createDataFrame (
104
- zip ( true_column_names , true_categories , weights , gains ),
105
- "feature_name: string, category: int, weight: double, gain: double" ,
106
- ). sort ( "feature_name" , "category" )
111
+ importance_columns = [
112
+ ( StructField ( "weight" , FloatType (), nullable = False ), weights ),
113
+ ( StructField ( "gain" , FloatType (), nullable = False ), gains ) ,
114
+ ]
107
115
elif model_type == "lightgbm" :
108
116
# The "weight" of a feature is the number of splits it causes.
109
117
weights = model .getFeatureImportances ("split" )
110
118
gains = model .getFeatureImportances ("gain" )
111
119
label = "Feature importances (weights and gains)"
112
120
113
- features_df = self . task . spark . createDataFrame (
114
- zip ( true_column_names , true_categories , weights , gains ),
115
- "feature_name: string, category: int, weight: double, gain: double" ,
116
- ). sort ( "feature_name" , "category" )
121
+ importance_columns = [
122
+ ( StructField ( "weight" , FloatType (), nullable = False ), weights ),
123
+ ( StructField ( "gain" , FloatType (), nullable = False ), gains ) ,
124
+ ]
117
125
else :
118
126
try :
119
127
feature_imp = model .coefficients
@@ -135,13 +143,27 @@ def _run(self):
135
143
feature_importances = [
136
144
float (importance ) for importance in feature_imp .toArray ()
137
145
]
138
- features_df = self .task .spark .createDataFrame (
139
- zip (
140
- true_column_names , true_categories , feature_importances , strict = True
146
+
147
+ importance_columns = [
148
+ (
149
+ StructField (
150
+ "coefficient_or_importance" , FloatType (), nullable = False
151
+ ),
152
+ feature_importances ,
141
153
),
142
- "feature_name: string, category: int, coefficient_or_importance: double" ,
143
- ).sort ("feature_name" , "category" )
154
+ ]
144
155
156
+ importance_schema , importance_data = zip (* importance_columns )
157
+ features_df = self .task .spark .createDataFrame (
158
+ zip (true_column_names , true_categories , * importance_data , strict = True ),
159
+ StructType (
160
+ [
161
+ StructField ("feature_name" , StringType (), nullable = False ),
162
+ StructField ("category" , IntegerType (), nullable = True ),
163
+ * importance_schema ,
164
+ ]
165
+ ),
166
+ ).sort ("feature_name" , "category" )
145
167
feature_importances_table = (
146
168
f"{ self .task .table_prefix } training_feature_importances"
147
169
)
0 commit comments