diff --git a/hlink/linking/core/classifier.py b/hlink/linking/core/classifier.py index d9543ed..2acd2c4 100644 --- a/hlink/linking/core/classifier.py +++ b/hlink/linking/core/classifier.py @@ -3,6 +3,8 @@ # in this project's top-level directory, and also on-line at: # https://github.com/ipums/hlink +from typing import Any + from pyspark.ml.feature import SQLTransformer from pyspark.ml.regression import GeneralizedLinearRegression from pyspark.ml.classification import ( @@ -28,22 +30,32 @@ _xgboost_available = True -def choose_classifier(model_type, params, dep_var): - """Returns a classifier and a post_classification transformer given model type and params. +def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str): + """Given a model type and hyper-parameters for the model, return a + classifier of that type with those hyper-parameters, along with a + post-classification transformer to run after classification. + + The post-classification transformer standardizes the output of the + classifier for further processing. For example, some classifiers create + models that output a probability array of [P(dep_var=0), P(dep_var=1)], and + the post-classification transformer extracts the single float P(dep_var=1) + as the probability for these models. Parameters ---------- - model_type: string - name of model - params: dictionary - dictionary of parameters for model - dep_var: string - the dependent variable for the model + model_type + the type of model, which may be random_forest, probit, + logistic_regression, decision_tree, gradient_boosted_trees, lightgbm + (requires the 'lightgbm' extra), or xgboost (requires the 'xgboost' + extra) + params + a dictionary of hyper-parameters for the model + dep_var + the dependent variable for the model, sometimes also called the "label" Returns ------- - The classifer and a transformer to be used after classification. - + The classifier and a transformer to be used after classification, as a tuple. """ post_transformer = SQLTransformer(statement="SELECT * FROM __THIS__") features_vector = "features_vector"