Merge pull request #155 from ipums/model_exploration_output
Improve model_exploration step 2 output
riley-harper authored Oct 10, 2024
2 parents 69307e0 + 4933368 commit da0b4fc
Showing 1 changed file with 121 additions and 44 deletions.
165 changes: 121 additions & 44 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -4,21 +4,28 @@
# https://github.com/ipums/hlink

import itertools
import logging
import math
import re
from time import perf_counter
from typing import Any
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve, auc
from pyspark.ml import Model, Transformer
import pyspark.sql
from pyspark.sql.functions import count, mean

import hlink.linking.core.threshold as threshold_core
import hlink.linking.core.classifier as classifier_core

from hlink.linking.link_step import LinkStep

logger = logging.getLogger(__name__)
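
The logger.info and logger.debug calls added in this commit go through this module-level logger, so their messages only appear if the calling application configures logging. A minimal sketch of one way to surface them from a driver script (illustrative only; hlink itself may configure logging differently):

import logging

# Send log records to the console; root at INFO shows the new per-run messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)

# Turn on the more detailed per-split DEBUG messages for just this module.
# The logger name follows from logging.getLogger(__name__) and the module
# path hlink/linking/model_exploration/link_step_train_test_models.py.
logging.getLogger(
    "hlink.linking.model_exploration.link_step_train_test_models"
).setLevel(logging.DEBUG)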


class LinkStepTrainTestModels(LinkStep):
def __init__(self, task):
def __init__(self, task) -> None:
super().__init__(
task,
"train test models",
@@ -35,7 +42,7 @@ def __init__(self, task):
],
)

def _run(self):
def _run(self) -> None:
training_conf = str(self.task.training_conf)
table_prefix = self.task.table_prefix
config = self.task.link_run.config
@@ -61,7 +68,15 @@ def _run(self):
splits = self._get_splits(prepped_data, id_a, n_training_iterations, seed)

model_parameters = self._get_model_parameters(config)
for run in model_parameters:

logger.info(
f"There are {len(model_parameters)} sets of model parameters to explore; "
f"each of these has {n_training_iterations} train-test splits to test on"
)
for run_index, run in enumerate(model_parameters, 1):
run_start_info = f"Starting run {run_index} of {len(model_parameters)} with these parameters: {run}"
print(run_start_info)
logger.info(run_start_info)
params = run.copy()
model_type = params.pop("type")

@@ -80,20 +95,31 @@ def _run(self):
threshold_ratio = False

threshold_matrix = _calc_threshold_matrix(alpha_threshold, threshold_ratio)
results_dfs = {}
logger.debug(f"The threshold matrix has {len(threshold_matrix)} entries")

results_dfs: dict[int, pd.DataFrame] = {}
for i in range(len(threshold_matrix)):
results_dfs[i] = _create_results_df()

first = True
for training_data, test_data in splits:
for split_index, (training_data, test_data) in enumerate(splits, 1):
split_start_info = f"Training and testing the model on train-test split {split_index} of {n_training_iterations}"
print(split_start_info)
logger.debug(split_start_info)
training_data.cache()
test_data.cache()

classifier, post_transformer = classifier_core.choose_classifier(
model_type, params, dep_var
)

logger.debug("Training the model on the training data split")
start_train_time = perf_counter()
model = classifier.fit(training_data)
end_train_time = perf_counter()
logger.debug(
f"Successfully trained the model in {end_train_time - start_train_time:.2f}s"
)

predictions_tmp = _get_probability_and_select_pred_columns(
test_data, model, post_transformer, id_a, id_b, dep_var
@@ -113,7 +139,7 @@ def _run(self):
param_text = np.full(precision.shape, f"{model_type}_{params}")

pr_auc = auc(recall, precision)
print(f"Area under PR curve: {pr_auc}")
print(f"The area under the precision-recall curve is {pr_auc}")

if first:
prc = pd.DataFrame(
@@ -134,18 +160,24 @@ def _run(self):
first = False

i = 0
for at, tr in threshold_matrix:
for threshold_index, (alpha_threshold, threshold_ratio) in enumerate(
threshold_matrix, 1
):
logger.debug(
f"Predicting with threshold matrix entry {threshold_index} of {len(threshold_matrix)}: "
f"{alpha_threshold=} and {threshold_ratio=}"
)
predictions = threshold_core.predict_using_thresholds(
predictions_tmp,
at,
tr,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
predict_train = threshold_core.predict_using_thresholds(
predict_train_tmp,
at,
tr,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
)
@@ -157,8 +189,8 @@ def _run(self):
model,
results_dfs[i],
otd_data,
at,
tr,
alpha_threshold,
threshold_ratio,
pr_auc,
)
i += 1
@@ -175,7 +207,19 @@ def _run(self):
self._save_otd_data(otd_data, self.task.spark)
self.task.spark.sql("set spark.sql.shuffle.partitions=200")

def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):
def _get_splits(
self,
prepped_data: pyspark.sql.DataFrame,
id_a: str,
n_training_iterations: int,
seed: int,
) -> list[list[pyspark.sql.DataFrame]]:
"""
Get a list of random splits of the prepped_data into two DataFrames.
There are n_training_iterations elements in the list. Each element is
itself a list of two DataFrames which are the splits of prepped_data.
The split DataFrames are roughly equal in size.
"""
if self.task.link_run.config[f"{self.task.training_conf}"].get(
"split_by_id_a", False
):
@@ -200,7 +244,7 @@ def _get_splits(self, prepped_data, id_a, n_training_iterations, seed):

return splits
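
The actual split logic sits in the collapsed lines above. Purely as an illustration of the "roughly equal in size" behavior the docstring describes (and not of the split_by_id_a branch), a single 50/50 split of a PySpark DataFrame can be made with randomSplit:

# Illustrative sketch only, using the method's own prepped_data and seed
# arguments; the real implementation repeats this for each training
# iteration and also supports splitting on the id_a column.
split_a, split_b = prepped_data.randomSplit([0.5, 0.5], seed=seed)
example_splits = [[split_a, split_b]]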

def _custom_param_grid_builder(self, conf):
def _custom_param_grid_builder(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
print("Building param grid for models")
given_parameters = conf[f"{self.task.training_conf}"]["model_parameters"]
new_params = []
@@ -231,19 +275,18 @@ def _custom_param_grid_builder(self, conf):

def _capture_results(
self,
predictions,
predict_train,
dep_var,
model,
results_df,
otd_data,
at,
tr,
pr_auc,
):
predictions: pyspark.sql.DataFrame,
predict_train: pyspark.sql.DataFrame,
dep_var: str,
model: Model,
results_df: pd.DataFrame,
otd_data: dict[str, Any] | None,
alpha_threshold: float,
threshold_ratio: float,
pr_auc: float,
) -> pd.DataFrame:
table_prefix = self.task.table_prefix

print("Evaluating model performance...")
# write to sql tables for testing
predictions.createOrReplaceTempView(f"{table_prefix}predictions")
predict_train.createOrReplaceTempView(f"{table_prefix}predict_train")
@@ -278,13 +321,13 @@ def _capture_results(
"test_mcc": [test_mcc],
"train_mcc": [train_mcc],
"model_id": [model],
"alpha_threshold": [at],
"threshold_ratio": [tr],
"alpha_threshold": [alpha_threshold],
"threshold_ratio": [threshold_ratio],
},
)
return pd.concat([results_df, new_results], ignore_index=True)

def _get_model_parameters(self, conf):
def _get_model_parameters(self, conf: dict[str, Any]) -> list[dict[str, Any]]:
training_conf = str(self.task.training_conf)

model_parameters = conf[training_conf]["model_parameters"]
@@ -296,7 +339,9 @@ def _get_model_parameters(self, conf):
)
return model_parameters

def _save_training_results(self, desc_df, spark):
def _save_training_results(
self, desc_df: pd.DataFrame, spark: pyspark.sql.SparkSession
) -> None:
table_prefix = self.task.table_prefix

if desc_df.empty:
@@ -310,7 +355,9 @@ def _save_training_results(self, desc_df, spark):
f"Training results saved to Spark table '{table_prefix}training_results'."
)

def _prepare_otd_table(self, spark, df, id_a, id_b):
def _prepare_otd_table(
self, spark: pyspark.sql.SparkSession, df: pd.DataFrame, id_a: str, id_b: str
) -> pyspark.sql.DataFrame:
spark_df = spark.createDataFrame(df)
counted = (
spark_df.groupby(id_a, id_b)
@@ -323,7 +370,9 @@ def _prepare_otd_table(self, spark, df, id_a, id_b):
)
return counted

def _save_otd_data(self, otd_data, spark):
def _save_otd_data(
self, otd_data: dict[str, Any] | None, spark: pyspark.sql.SparkSession
) -> None:
table_prefix = self.task.table_prefix

if otd_data is None:
@@ -379,7 +428,7 @@ def _save_otd_data(self, otd_data, spark):
else:
print("There were no true negatives recorded.")

def _create_otd_data(self, id_a, id_b):
def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
"""Output Suspicous Data (OTD): used to check config to see if you should find sketchy training data that the models routinely mis-classify"""
training_conf = str(self.task.training_conf)
config = self.task.link_run.config
@@ -400,7 +449,12 @@ def _create_otd_data(self, id_a, id_b):
return None


def _calc_mcc(TP, TN, FP, FN):
def _calc_mcc(TP: int, TN: int, FP: int, FN: int) -> float:
"""
Given the counts of true positives (TP), true negatives (TN), false
positives (FP), and false negatives (FN) for a model run, compute the
Matthews Correlation Coefficient (MCC).
"""
if (math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))) != 0:
mcc = ((TP * TN) - (FP * FN)) / (
math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
@@ -410,7 +464,9 @@ def _calc_mcc(TP, TN, FP, FN):
return mcc
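
For reference, the branch above (taken when the denominator is nonzero) computes the standard Matthews Correlation Coefficient:

    \mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}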


def _calc_threshold_matrix(alpha_threshold, threshold_ratio):
def _calc_threshold_matrix(
alpha_threshold: float | list[float], threshold_ratio: float | list[float]
) -> list[list[float]]:
if alpha_threshold and type(alpha_threshold) != list:
alpha_threshold = [alpha_threshold]

@@ -426,8 +482,13 @@ def _calc_threshold_matrix(alpha_threshold, threshold_ratio):
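
The rest of this function is collapsed above. Assuming the collapsed body takes the cross product of the two lists (consistent with the "matrix" name, the itertools import at the top of the file, and the way _run unpacks (alpha_threshold, threshold_ratio) pairs), a hypothetical illustration:

import itertools

# Hypothetical inputs and the pairing they would produce under that assumption.
alpha_threshold = [0.5, 0.8]
threshold_ratio = [1.2, 1.3]
threshold_matrix = [
    list(pair) for pair in itertools.product(alpha_threshold, threshold_ratio)
]
# threshold_matrix == [[0.5, 1.2], [0.5, 1.3], [0.8, 1.2], [0.8, 1.3]]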


def _get_probability_and_select_pred_columns(
pred_df, model, post_transformer, id_a, id_b, dep_var
):
pred_df: pyspark.sql.DataFrame,
model: Model,
post_transformer: Transformer,
id_a: str,
id_b: str,
dep_var: str,
) -> pyspark.sql.DataFrame:
all_prediction_cols = set(
[
f"{id_a}",
@@ -446,7 +507,9 @@ def _get_probability_and_select_pred_columns(
return required_col_df


def _get_confusion_matrix(predictions, dep_var, otd_data):
def _get_confusion_matrix(
predictions: pyspark.sql.DataFrame, dep_var: str, otd_data: dict[str, Any] | None
) -> tuple[int, int, int, int]:
TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
TP_count = TP.count()

@@ -486,7 +549,16 @@ def _get_confusion_matrix(predictions, dep_var, otd_data):
return TP_count, FP_count, FN_count, TN_count


def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
def _get_aggregate_metrics(
TP_count: int, FP_count: int, FN_count: int, TN_count: int
) -> tuple[float, float, float]:
"""
Given the counts of true positives, false positives, false negatives, and
true negatives for a model run, compute several metrics to evaluate the
model's quality.
Return a tuple of (precision, recall, Matthews Correlation Coefficient).
"""
if (TP_count + FP_count) == 0:
precision = np.nan
else:
@@ -499,7 +571,7 @@ def _get_aggregate_metrics(TP_count, FP_count, FN_count, TN_count):
return precision, recall, mcc
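
For reference, the first two returned metrics follow their standard definitions, with the branch above falling back to NaN when the precision denominator is zero:

    \text{precision} = \frac{TP}{TP + FP} \qquad \text{recall} = \frac{TP}{TP + FN}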


def _create_results_df():
def _create_results_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"precision_test",
@@ -516,7 +588,12 @@ def _create_results_df():
)


def _append_results(desc_df, results_df, model_type, params):
def _append_results(
desc_df: pd.DataFrame,
results_df: pd.DataFrame,
model_type: str,
params: dict[str, Any],
) -> pd.DataFrame:
# run.pop("type")
print(results_df)

@@ -548,7 +625,7 @@ def _append_results(desc_df, results_df, model_type, params):
return desc_df


def _print_desc_df(desc_df):
def _print_desc_df(desc_df: pd.DataFrame) -> None:
pd.set_option("display.max_colwidth", None)
print(
desc_df.drop(
@@ -564,7 +641,7 @@ def _print_desc_df(desc_df):
print("\n")


def _load_desc_df_params(desc_df):
def _load_desc_df_params(desc_df: pd.DataFrame) -> pd.DataFrame:
params = [
"maxDepth",
"numTrees",
@@ -591,7 +668,7 @@ def _load_desc_df_params(desc_df):
return desc_df


def _create_desc_df():
def _create_desc_df() -> pd.DataFrame:
return pd.DataFrame(
columns=[
"model",