diff --git a/hlink/tests/core/classifier_test.py b/hlink/tests/core/classifier_test.py new file mode 100644 index 0000000..1010262 --- /dev/null +++ b/hlink/tests/core/classifier_test.py @@ -0,0 +1,23 @@ +import pytest + +from hlink.linking.core.classifier import choose_classifier + +try: + import xgboost +except ModuleNotFoundError: + xgboost_available = False +else: + xgboost_available = True + +@pytest.mark.skipif(not xgboost_available, reason="requires the xgboost library") +def test_choose_classifier_supports_xgboost(): + """ + If the xgboost module is installed, then choose_classifier() supports a model + type of "xgboost". + """ + params = { + "max_depth": 2, + "eta": 0.5, + } + classifier = choose_classifier("xgboost", params, "match") + assert classifier.getLabelCol() == "match"