Skip to content

Commit

Permalink
[#161] Rename a variable in training step 3
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Nov 18, 2024
1 parent ffba81a commit 0277d7d
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions hlink/linking/training/link_step_save_model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ def _run(self):

raise new_error from e

- # The pipeline model has three stages: vector assembler, classifier, post
- # transformer.
+ # The pipeline model has three stages: vector assembler, model, and post transformer.
vector_assembler = pipeline_model.stages[0]
- classifier = pipeline_model.stages[1]
+ model = pipeline_model.stages[1]

column_names = vector_assembler.getInputCols()
tf_prepped = self.task.spark.table(f"{table_prefix}training_features_prepped")
Expand Down Expand Up @@ -93,24 +92,24 @@ def _run(self):
print("Retrieving model feature importances or coefficients...")

if model_type == "xgboost":
- raw_weights = classifier.get_feature_importances("weight")
- raw_gains = classifier.get_feature_importances("gain")
+ raw_weights = model.get_feature_importances("weight")
+ raw_gains = model.get_feature_importances("gain")
keys = [f"f{index}" for index in range(len(true_cols))]

weights = [raw_weights.get(key, 0.0) for key in keys]
gains = [raw_gains.get(key, 0.0) for key in keys]
- label = "Feature importances (weights and gain)"
+ label = "Feature importances (weights and gains)"

features_df = self.task.spark.createDataFrame(
zip(true_column_names, true_categories, weights, gains),
"feature_name: string, category: int, weight: double, average_gain_per_split: double",
).sort("feature_name", "category")
else:
try:
- feature_imp = classifier.coefficients
+ feature_imp = model.coefficients
except:
try:
- feature_imp = classifier.featureImportances
+ feature_imp = model.featureImportances
except:
print(
"This model doesn't contain a coefficient or feature importances parameter -- check chosen model type."
Expand Down

0 comments on commit 0277d7d

Please sign in to comment.