33# in this project's top-level directory, and also on-line at:
44# https://github.com/ipums/hlink
55
6+ from pyspark .sql .types import (
7+ FloatType ,
8+ IntegerType ,
9+ StringType ,
10+ StructField ,
11+ StructType ,
12+ )
13+
614from hlink .linking .link_step import LinkStep
715
816
@@ -100,20 +108,20 @@ def _run(self):
100108 gains = [raw_gains .get (key , 0.0 ) for key in keys ]
101109 label = "Feature importances (weights and gains)"
102110
103- features_df = self . task . spark . createDataFrame (
104- zip ( true_column_names , true_categories , weights , gains ),
105- "feature_name: string, category: int, weight: double, gain: double" ,
106- ). sort ( "feature_name" , "category" )
111+ importance_columns = [
112+ ( StructField ( "weight" , FloatType (), nullable = False ), weights ),
113+ ( StructField ( "gain" , FloatType (), nullable = False ), gains ) ,
114+ ]
107115 elif model_type == "lightgbm" :
108116 # The "weight" of a feature is the number of splits it causes.
109117 weights = model .getFeatureImportances ("split" )
110118 gains = model .getFeatureImportances ("gain" )
111119 label = "Feature importances (weights and gains)"
112120
113- features_df = self . task . spark . createDataFrame (
114- zip ( true_column_names , true_categories , weights , gains ),
115- "feature_name: string, category: int, weight: double, gain: double" ,
116- ). sort ( "feature_name" , "category" )
121+ importance_columns = [
122+ ( StructField ( "weight" , FloatType (), nullable = False ), weights ),
123+ ( StructField ( "gain" , FloatType (), nullable = False ), gains ) ,
124+ ]
117125 else :
118126 try :
119127 feature_imp = model .coefficients
@@ -135,13 +143,27 @@ def _run(self):
135143 feature_importances = [
136144 float (importance ) for importance in feature_imp .toArray ()
137145 ]
138- features_df = self .task .spark .createDataFrame (
139- zip (
140- true_column_names , true_categories , feature_importances , strict = True
146+
147+ importance_columns = [
148+ (
149+ StructField (
150+ "coefficient_or_importance" , FloatType (), nullable = False
151+ ),
152+ feature_importances ,
141153 ),
142- "feature_name: string, category: int, coefficient_or_importance: double" ,
143- ).sort ("feature_name" , "category" )
154+ ]
144155
156+ importance_schema , importance_data = zip (* importance_columns )
157+ features_df = self .task .spark .createDataFrame (
158+ zip (true_column_names , true_categories , * importance_data , strict = True ),
159+ StructType (
160+ [
161+ StructField ("feature_name" , StringType (), nullable = False ),
162+ StructField ("category" , IntegerType (), nullable = True ),
163+ * importance_schema ,
164+ ]
165+ ),
166+ ).sort ("feature_name" , "category" )
145167 feature_importances_table = (
146168 f"{ self .task .table_prefix } training_feature_importances"
147169 )
0 commit comments