From 0277d7d8a4d06e3c4fd2ab9fad495e3e0dd35f0d Mon Sep 17 00:00:00 2001 From: rileyh Date: Mon, 18 Nov 2024 12:53:03 -0600 Subject: [PATCH] [#161] Rename a variable in training step 3 --- .../training/link_step_save_model_metadata.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/hlink/linking/training/link_step_save_model_metadata.py b/hlink/linking/training/link_step_save_model_metadata.py index ed98274..88f82a4 100644 --- a/hlink/linking/training/link_step_save_model_metadata.py +++ b/hlink/linking/training/link_step_save_model_metadata.py @@ -58,10 +58,9 @@ def _run(self): raise new_error from e - # The pipeline model has three stages: vector assembler, classifier, post - # transformer. + # The pipeline model has three stages: vector assembler, model, and post transformer. vector_assembler = pipeline_model.stages[0] - classifier = pipeline_model.stages[1] + model = pipeline_model.stages[1] column_names = vector_assembler.getInputCols() tf_prepped = self.task.spark.table(f"{table_prefix}training_features_prepped") @@ -93,13 +92,13 @@ def _run(self): print("Retrieving model feature importances or coefficients...") if model_type == "xgboost": - raw_weights = classifier.get_feature_importances("weight") - raw_gains = classifier.get_feature_importances("gain") + raw_weights = model.get_feature_importances("weight") + raw_gains = model.get_feature_importances("gain") keys = [f"f{index}" for index in range(len(true_cols))] weights = [raw_weights.get(key, 0.0) for key in keys] gains = [raw_gains.get(key, 0.0) for key in keys] - label = "Feature importances (weights and gain)" + label = "Feature importances (weights and gains)" features_df = self.task.spark.createDataFrame( zip(true_column_names, true_categories, weights, gains), @@ -107,10 +106,10 @@ def _run(self): ).sort("feature_name", "category") else: try: - feature_imp = classifier.coefficients + feature_imp = model.coefficients except: try: - feature_imp = classifier.featureImportances + feature_imp = model.featureImportances except: print( "This model doesn't contain a coefficient or feature importances parameter -- check chosen model type."