
Commit 3b22f14

Clean up output
1 parent 2facf41 commit 3b22f14

File tree

1 file changed: +23 -24 lines changed

hlink/linking/model_exploration/link_step_train_test_models.py

Lines changed: 23 additions & 24 deletions
@@ -168,7 +168,7 @@ def _collect_train_test_splits(
 cached_test_data = test_data.cache()

 split_start_info = f"Training and testing the model on train-test split {split_index} of {len(splits)}"
-print(split_start_info)
+# print(split_start_info)
 logger.debug(split_start_info)
 prauc = self._train_model(
     cached_training_data,
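
The hunk above leaves the split progress message going only to logger.debug once the print is commented out. A minimal sketch of how a driver script could still surface those messages on the console, assuming hlink's module loggers propagate to the root logger (the "hlink" logger name here is an assumption, not taken from this diff):

import logging

# Hypothetical driver-script setup: emit debug-level log records to the
# console so the information that used to be print()ed stays visible.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
logging.getLogger("hlink").setLevel(logging.DEBUG)
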
@@ -199,7 +199,7 @@ def _evaluate_hyperparam_combinations(
 results = []
 for index, params_combo in enumerate(all_model_parameter_combos, 1):
     eval_start_info = f"Starting run {index} of {len(all_model_parameter_combos)} with these parameters: {params_combo}"
-    print(eval_start_info)
+    # print(eval_start_info)
     logger.info(eval_start_info)
     # Copy because the params combo will get stripped of extra key-values
     # so only the hyperparams remain.
@@ -266,15 +266,15 @@ def _choose_best_training_results(self, evals: list[ModelEval]) -> ModelEval:
     raise RuntimeError(
         "No model evaluations provided, cannot choose the best one."
     )
-print("\n**************************************************")
+print("\n\n**************************************************")
 print(" All Model - hyper-parameter combinations")
 print("**************************************************\n")
 best_eval = evals[0]
 for e in evals:
     print(e)
     if best_eval.score < e.score:
         best_eval = e
-print("--------------------------------------------------\n")
+print("--------------------------------------------------\n\n")
 return best_eval

 def _evaluate_threshold_combinations(
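
The _choose_best_training_results hunk above keeps the evaluation with the highest score while printing each candidate. The selection itself is equivalent to a max() over the score attribute; a simplified, self-contained sketch (the ModelEval stand-in below is illustrative, not hlink's actual class):

from dataclasses import dataclass


@dataclass
class ModelEval:
    # Illustrative stand-in; hlink's ModelEval carries more fields than this.
    model_type: str
    score: float


def choose_best(evals: list[ModelEval]) -> ModelEval:
    if not evals:
        raise RuntimeError("No model evaluations provided, cannot choose the best one.")
    # Same result as the explicit loop in the diff: the highest score wins.
    return max(evals, key=lambda e: e.score)
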
@@ -295,9 +295,9 @@ def _evaluate_threshold_combinations(
 # but for now it's a single ModelEval instance -- the one with the highest score.
 best_results = self._choose_best_training_results(hyperparam_evaluation_results)

-print(f"======== Best Model and Parameters =========")
-print(f"{best_results}")
-print("==============================================================")
+print(f"\n======== Best Model and Parameters ========\n")
+print(f"\t{best_results}\n")
+print("=============================================\n]\n")

 # TODO check if we should make a different split, like starting from a different seed?
 # or just not re-using one we used in making the PR_AUC mean value?
@@ -306,6 +306,9 @@ def _evaluate_threshold_combinations(
 # thresholding_test_data = splits_for_thresholding_eval[1].cache()
 threshold_matrix = best_results.make_threshold_matrix()
 logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")
+print(
+    f"Testing the best model + parameters against all {len(threshold_matrix)} threshold combinations."
+)
 results_dfs: dict[int, pd.DataFrame] = {}
 for i in range(len(threshold_matrix)):
     results_dfs[i] = _create_results_df()
@@ -367,10 +370,6 @@ def _evaluate_threshold_combinations(
     config["id_column"],
 )

-print(
-    f"Capture results for threshold matrix entry {threshold_index} and split index {split_index}"
-)
-
 results_dfs[i] = self._capture_results(
     predictions,
     predict_train,
@@ -535,12 +534,12 @@ def _capture_results(
 # write to sql tables for testing
 predictions.createOrReplaceTempView(f"{table_prefix}predictions")
 predict_train.createOrReplaceTempView(f"{table_prefix}predict_train")
-print("------------------------------------------------------------")
-print(f"Capturing predictions:")
-predictions.show()
-print(f"Capturing predict_train:")
-predict_train.show()
-print("------------------------------------------------------------")
+# print("------------------------------------------------------------")
+# print(f"Capturing predictions:")
+# predictions.show()
+# print(f"Capturing predict_train:")
+# predict_train.show()
+# print("------------------------------------------------------------")

 (
     test_TP_count,
@@ -769,19 +768,19 @@ def _get_confusion_matrix(
 FP = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 1))
 FP_count = FP.count()

-print(
-    f"Confusion matrix -- true positives and false positivesTP {TP_count} FP {FP_count}"
-)
+# print(
+#     f"Confusion matrix -- true positives and false positivesTP {TP_count} FP {FP_count}"
+# )

 FN = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 0))
 FN_count = FN.count()

 TN = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 0))
 TN_count = TN.count()

-print(
-    f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
-)
+# print(
+#     f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
+# )

 if otd_data:
     id_a = otd_data["id_a"]
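
_get_confusion_matrix above derives the TP, FP, FN, and TN counts with four separate filter().count() passes over the predictions DataFrame; the commented-out prints only reported those counts. A self-contained PySpark sketch of the same four counts computed in one grouped aggregation (the toy data and the "match" column name are illustrative; only dep_var and prediction mirror the diff):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("confusion-matrix-sketch").getOrCreate()

# Toy stand-in for the predictions DataFrame used by _get_confusion_matrix.
dep_var = "match"
predictions = spark.createDataFrame(
    [(1, 1.0), (0, 1.0), (1, 0.0), (0, 0.0), (1, 1.0)],
    [dep_var, "prediction"],
)

# One grouped aggregation instead of four filter().count() jobs.
rows = predictions.groupBy(dep_var, "prediction").count().collect()
lookup = {(int(r[dep_var]), int(r["prediction"])): r["count"] for r in rows}
TP_count = lookup.get((1, 1), 0)
FP_count = lookup.get((0, 1), 0)
FN_count = lookup.get((1, 0), 0)
TN_count = lookup.get((0, 0), 0)
print(TP_count, FP_count, FN_count, TN_count)  # 2 1 1 1
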
@@ -829,7 +828,7 @@ def _get_aggregate_metrics(
 else:
     recall = TP_count / (TP_count + FN_count)
 mcc = _calc_mcc(TP_count, TN_count, FP_count, FN_count)
-print(f"XX Aggregates precision {precision} recall {recall}")
+# print(f"XX Aggregates precision {precision} recall {recall}")
 return precision, recall, mcc
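
_get_aggregate_metrics computes precision and recall from the confusion-matrix counts and delegates the Matthews correlation coefficient to _calc_mcc, whose body is not part of this diff. A reference sketch using the standard MCC formula (the zero-denominator handling is an assumption for illustration, not necessarily hlink's behavior):

import math


def calc_mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    # Standard Matthews correlation coefficient; returning 0.0 on a zero
    # denominator is an assumption, since _calc_mcc is not shown in this diff.
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


def aggregate_metrics(tp: int, tn: int, fp: int, fn: int) -> tuple[float, float, float]:
    # The zero guards mirror the else branch visible in the hunk above;
    # returning NaN when a denominator is zero is an assumption.
    precision = tp / (tp + fp) if (tp + fp) else math.nan
    recall = tp / (tp + fn) if (tp + fn) else math.nan
    return precision, recall, calc_mcc(tp, tn, fp, fn)
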
