Skip to content

Commit

Permalink
[#172] Add type hints and docs to linking.core.classifier
Browse files Browse the repository at this point in the history
The output type of choose_classifier() is really hard to write down
precisely because of the way PySpark types are set up. It's something
like tuple["Classifier", "Transformer"], but for some reason
SQLTransformer is not a subtype of Transformer.
  • Loading branch information
riley-harper committed Dec 5, 2024
1 parent 9542800 commit e57dad6
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions hlink/linking/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

from typing import Any

from pyspark.ml.feature import SQLTransformer
from pyspark.ml.regression import GeneralizedLinearRegression
from pyspark.ml.classification import (
Expand All @@ -28,22 +30,32 @@
_xgboost_available = True


def choose_classifier(model_type, params, dep_var):
"""Returns a classifier and a post_classification transformer given model type and params.
def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"""Given a model type and hyper-parameters for the model, return a
classifier of that type with those hyper-parameters, along with a
post-classification transformer to run after classification.
The post-classification transformer standardizes the output of the
classifier for further processing. For example, some classifiers create
models that output a probability array of [P(dep_var=0), P(dep_var=1)], and
the post-classification transformer extracts the single float P(dep_var=1)
as the probability for these models.
Parameters
----------
model_type: string
name of model
params: dictionary
dictionary of parameters for model
dep_var: string
the dependent variable for the model
model_type
the type of model, which may be random_forest, probit,
logistic_regression, decision_tree, gradient_boosted_trees, lightgbm
(requires the 'lightgbm' extra), or xgboost (requires the 'xgboost'
extra)
params
a dictionary of hyper-parameters for the model
dep_var
the dependent variable for the model, sometimes also called the "label"
Returns
-------
The classifer and a transformer to be used after classification.
The classifier and a transformer to be used after classification, as a tuple.
"""
post_transformer = SQLTransformer(statement="SELECT * FROM __THIS__")
features_vector = "features_vector"
Expand Down

0 comments on commit e57dad6

Please sign in to comment.