Skip to content

Commit

Permalink
[#161] Create a SparkXGBClassifier in choose_classifier() for model_type xgboost
Browse files Browse the repository at this point in the history

This is only possible when the xgboost module is installed, so raise an error
if it is not present.
  • Loading branch information
riley-harper committed Nov 14, 2024
1 parent a865825 commit 287912e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion hlink/linking/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
)
import hlink.linking.transformers.rename_prob_column

try:
import xgboost.spark
except ModuleNotFoundError:
_xgboost_available = False
else:
_xgboost_available = True


def choose_classifier(model_type, params, dep_var):
"""Returns a classifier and a post_classification transformer given model type and params.
Expand Down Expand Up @@ -96,7 +103,21 @@ def choose_classifier(model_type, params, dep_var):
post_transformer = (
hlink.linking.transformers.rename_prob_column.RenameProbColumn()
)

elif model_type == "xgboost":
if not _xgboost_available:
raise ModuleNotFoundError(
"model_type 'xgboost' requires the xgboost library"
)
params_without_threshold = {
key: val
for key, val in params.items()
if key not in {"threshold", "threshold_ratio"}
}
classifier = xgboost.spark.SparkXGBClassifier(
**params_without_threshold,
features_col=features_vector,
label_col=dep_var,
)
else:
raise ValueError(
"Model type not recognized! Please check your config, reload, and try again."
Expand Down
2 changes: 1 addition & 1 deletion hlink/tests/core/classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ def test_choose_classifier_supports_xgboost():
"max_depth": 2,
"eta": 0.5,
}
classifier = choose_classifier("xgboost", params, "match")
classifier, _post_transformer = choose_classifier("xgboost", params, "match")
assert classifier.getLabelCol() == "match"

0 comments on commit 287912e

Please sign in to comment.