Refactor nested cross validation #169
Conversation
… and test all threshold matrix members against that set of params. Still has a failure.
…oesn't give good results making no matches in the test data, so precision is NaN.
…s given to the thresholding eval.
…split used to test all thresholds isn't a good one.
This makes sense to me for the most part. It's a big change from how we were doing it before! I requested some changes, but most are on the smaller side. The broad algorithm looks great to me.
I just found the `pyspark.ml.tuning` module, and I suspect that we can make use of its `CrossValidator`, `ParamGridBuilder`, and metrics here. However, our algorithm is a little different than what the Spark logic does. We may only be able to replace the inner cross-validation with `CrossValidator`; I'm not sure. `CrossValidator` has a `parallelism: int` argument which can help you increase the parallelism of testing. This could be really helpful for speeding things up.
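For concreteness, here is a rough sketch of what the inner loop might look like with `CrossValidator` (toy data and a logistic-regression estimator stand in for our real pipeline, dependent variable, and param grid):

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy stand-in for one outer fold's training split; the real data and
# dependent variable come from the matching pipeline.
dep_var = "match"
training = spark.createDataFrame(
    [(Vectors.dense([0.1, 0.9]), 1.0), (Vectors.dense([0.8, 0.2]), 0.0)] * 10,
    ["features", dep_var],
)

lr = LogisticRegression(featuresCol="features", labelCol=dep_var)
grid = (
    ParamGridBuilder()
    .addGrid(lr.regParam, [0.01, 0.1])
    .addGrid(lr.maxIter, [10, 50])
    .build()
)
evaluator = BinaryClassificationEvaluator(labelCol=dep_var, metricName="areaUnderPR")

cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=grid,
    evaluator=evaluator,
    numFolds=3,
    parallelism=4,  # evaluate folds/param combinations concurrently
)
cv_model = cv.fit(training)  # best model chosen by average metric across folds
print(cv_model.avgMetrics)
```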
Also, we are doing a lot of work to predict on the training data and compute the metrics on the training data. But I don't think that we need to do that anymore. That was a feature of the previous algorithm. With nested cross-validation I think that we just want to compute metrics on the test data. This could also speed things up significantly (I would guess by around a factor of 2, maybe more).
I did not look at the tests, since it sounds like we are really going to need to rework those.
```python
score: float
hyperparams: dict[str, Any]
threshold: float | list[float]
threshold_ratio: float | list[float] | bool
```
I think that it's a bug to store `threshold_ratio` as a `bool`. It's an optional float/list of floats, so I think that we should store it as `float | list[float] | None`. The code that extracts it out of the config file shouldn't make it default to `False`.
Updated to set `None`.
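For reference, a minimal sketch of the corrected typing (the class name here is a placeholder, and the real dataclass has more fields):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ThresholdTestResult:  # placeholder name for illustration
    score: float
    hyperparams: dict[str, Any]
    threshold: float | list[float]
    # None when the config doesn't supply a ratio, rather than defaulting to False
    threshold_ratio: float | list[float] | None = None
```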
```python
predict_train_tmp = _get_probability_and_select_pred_columns(
    training_data, model, post_transformer, id_a, id_b, dep_var
)
```
Do we need to retain this logic? Since we're doing nested cross-validation, I think that we should avoid predicting on the training data. We can just predict on the test data. This may save us a significant amount of work.
Removed.
```python
test_pred = predictions_tmp.toPandas()
precision, recall, thresholds_raw = precision_recall_curve(
    test_pred[f"{dep_var}"],
```
I'm pretty sure that `dep_var` is a `str` everywhere, so we can just write `test_pred[dep_var]` to simplify things.
Yes.
```python
config,
training_conf,
```
These names are pretty confusing. Maybe we can rename them to make it clear that one is the config dictionary and one is the name of the training config section. Or maybe we could just pass the training part of the dictionary to this function.
I did both: instead of passing config + training_conf (the section name) to the functions, I pull out the training dict and pass it as "training_settings". This makes accessing the training settings a bit cleaner in places, and the names are more understandable. Really we need to remove the reliance on the config structure all over the place, but that's a bigger change.
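Roughly the shape of that change, with placeholder function and key names (the real helpers and settings differ):

```python
from typing import Any


def _count_outer_folds(training_settings: dict[str, Any]) -> int:
    # Helpers now read only the training section instead of the whole config.
    return training_settings.get("n_training_iterations", 10)


config = {"training": {"n_training_iterations": 10, "dependent_var": "match"}}
training_conf = "training"

# Pull the training section out once and pass only that dict down.
training_settings = config[training_conf]
print(_count_outer_folds(training_settings))
```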
```python
# thresholds and model_type are mixed in with the model hyper-parameters
# in the config; this removes them before passing to the model training.
```
Great comments, thanks for adding these!
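For illustration, the kind of filtering that comment describes might look like this (key names follow the config convention; the real code may differ):

```python
# A model entry from the config mixes hyper-parameters with other settings.
model_parameters = {
    "model_type": "random_forest",
    "maxDepth": 7,
    "numTrees": 100,
    "threshold": 0.8,
    "threshold_ratio": 1.3,
}

# Strip the non-hyper-parameter keys before handing the dict to model training.
hyperparams = {
    key: value
    for key, value in model_parameters.items()
    if key not in ("model_type", "threshold", "threshold_ratio")
}
print(hyperparams)  # {'maxDepth': 7, 'numTrees': 100}
```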
```python
thresholding_predict_train = _get_probability_and_select_pred_columns(
    cached_training_data,
    thresholding_model,
    thresholding_post_transformer,
    id_a,
    id_b,
    dep_var,
)
```
I believe that we can drop this logic since it's computing metrics on the training data.
At the least, we can add a flag to skip computing metrics on the training data and separate the capture of the test results from the training-data results. Leaving it alone for now until we refactor this part.
```python
predict_train = threshold_core.predict_using_thresholds(
    thresholding_predict_train,
    this_alpha_threshold,
    this_threshold_ratio,
    config[training_conf],
    config["id_column"],
)
```
Eliminate this since it's on the training data.
```python
outer_fold_count = config[training_conf].get("n_training_iterations", 10)
inner_fold_count = 3
```
We may want to rename `n_training_iterations` to `num_outer_folds` and add a `num_inner_folds` attribute to the config.
Yes, definitely. Let's make that another PR soon.
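If we do rename them, the config read might end up looking something like this (proposed key names only; nothing is implemented yet):

```python
training_settings = {"num_outer_folds": 10, "num_inner_folds": 3}  # proposed keys

outer_fold_count = training_settings.get("num_outer_folds", 10)
inner_fold_count = training_settings.get("num_inner_folds", 3)
print(outer_fold_count, inner_fold_count)
```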
```python
if outer_fold_count < 3:
    raise RuntimeError("You must use at least two training iterations.")
```
The error message and if statement don't seem to line up here. Do you need at least 2 or at least 3 iterations?
Yes, it's three, not two. Fixed.
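Presumably the fixed guard now reads something like this (a guess at the shape, not the exact wording):

```python
outer_fold_count = 3  # minimum allowed value

if outer_fold_count < 3:
    raise RuntimeError("You must use at least three training iterations.")
```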
```
@@ -429,7 +830,7 @@ def _save_otd_data(
        print("There were no true negatives recorded.")


    def _create_otd_data(self, id_a: str, id_b: str) -> dict[str, Any] | None:
```
This is a good opportunity to rename the "OTD data" to something easier to understand. Maybe "suspicious data" would be clearer?
Yes, I replaced "otd" with "suspicious".
- Allow manually running CI/CD via `on: workflow_dispatch`
- Run CI/CD on all PRs, not just PRs to main
- `link_step_train_test_models` converted to the "nested cross validation" approach. The `_run` method on `LinkStepTrainTestModels` has been refactored to only do the setup for the run and then the main algo, with the rest factored into other methods. We should probably further refactor so the cross-validation behavior is in its own pure class or module and doesn't rely on class `self` instance data so much. Most functions only need config information.
- The use of `cache()` and `unpersist()` on Spark data frames is likely not optimal yet.

NOTE: Currently reporting of the final results isn't totally finished; we get one data frame for each outer fold, with all results of every threshold combination in each. These outer-fold results still need to be merged to give a true picture of how each threshold combination produces matches. The old algorithm computed the precision mean, recall mean, and MCC mean in a way that doesn't make sense for the new algorithm.

At this stage I'm primarily concerned with performance.
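Regarding the NOTE above, one possible way to merge the per-outer-fold result frames (pandas here purely for illustration; the column names are made up, and how to aggregate across folds is still an open question):

```python
import pandas as pd

# Stand-ins for the per-outer-fold result frames, one row per threshold combination.
fold_results = [
    pd.DataFrame(
        {
            "threshold": [0.6, 0.8],
            "threshold_ratio": [1.2, 1.3],
            "precision": [0.91, 0.95],
            "recall": [0.88, 0.80],
        }
    )
    for _ in range(3)
]

# Concatenate the outer folds, then group by threshold combination. Whether a
# simple mean is the right summary for the new algorithm is still undecided.
combined = pd.concat(fold_results, ignore_index=True)
summary = combined.groupby(["threshold", "threshold_ratio"], as_index=False).agg(
    precision_mean=("precision", "mean"),
    recall_mean=("recall", "mean"),
)
print(summary)
```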