From a7b0c37f164ea1af76ebf1a8881c3738982cc2ce Mon Sep 17 00:00:00 2001
From: rileyh
Date: Fri, 15 Nov 2024 09:11:29 -0600
Subject: [PATCH] [#161] Add a test that runs the whole training task with an
 xgboost model

This test is failing right now because we also need pyarrow>=4 when
using xgboost. We should add this as a dependency in the xgboost extra.
If xgboost isn't installed, this test skips itself.
---
 hlink/tests/training_test.py | 51 ++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/hlink/tests/training_test.py b/hlink/tests/training_test.py
index 0fbdb0a..94be4f4 100644
--- a/hlink/tests/training_test.py
+++ b/hlink/tests/training_test.py
@@ -7,6 +7,13 @@ from pyspark.ml import Pipeline
 
 import hlink.linking.core.pipeline as pipeline_core
 
+try:
+    import xgboost  # noqa: F401
+except ModuleNotFoundError:
+    xgboost_available = False
+else:
+    xgboost_available = True
+
 
 @pytest.mark.quickcheck
 def test_all_steps(
@@ -432,6 +439,50 @@ def test_step_3_with_probit_model(
     )
 
 
+@pytest.mark.skipif(not xgboost_available, reason="requires the xgboost library")
+def test_step_3_with_xgboost_model(
+    spark, training, training_conf, datasource_training_input
+):
+    training_data_path, prepped_df_a_path, prepped_df_b_path = datasource_training_input
+    training_conf["comparison_features"] = [
+        {
+            "alias": "regionf",
+            "column_name": "region",
+            "comparison_type": "fetch_a",
+            "categorical": True,
+        },
+        {
+            "alias": "namelast_jw",
+            "column_name": "namelast",
+            "comparison_type": "jaro_winkler",
+        },
+    ]
+    training_conf["training"]["dataset"] = training_data_path
+    training_conf["training"]["dependent_var"] = "match"
+    training_conf["training"]["independent_vars"] = ["namelast_jw", "regionf"]
+    training_conf["training"]["chosen_model"] = {
+        "type": "xgboost",
+        "max_depth": 2,
+        "eta": 0.5,
+        "threshold": 0.7,
+        "threshold_ratio": 1.3,
+    }
+    training_conf["training"]["score_with_model"] = True
+    training_conf["training"]["feature_importances"] = True
+
+    spark.read.csv(prepped_df_a_path, header=True, inferSchema=True).write.mode(
+        "overwrite"
+    ).saveAsTable("prepped_df_a")
+    spark.read.csv(prepped_df_b_path, header=True, inferSchema=True).write.mode(
+        "overwrite"
+    ).saveAsTable("prepped_df_b")
+
+    training.run_step(0)
+    training.run_step(1)
+    training.run_step(2)
+    training.run_step(3)
+
+
 def test_step_3_requires_table(training_conf, training):
     training_conf["training"]["feature_importances"] = True
     with pytest.raises(RuntimeError, match="Missing input tables"):
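
Follow-up note (not part of this patch): the commit message calls for adding
pyarrow>=4 as a dependency of the xgboost extra. A minimal sketch of what
that could look like, assuming setuptools-style packaging with
extras_require; hlink's actual packaging metadata may differ:

    from setuptools import setup

    setup(
        name="hlink",
        # Hypothetical sketch: pair pyarrow>=4 with xgboost so that
        # `pip install hlink[xgboost]` pulls in both dependencies.
        extras_require={
            "xgboost": ["xgboost", "pyarrow>=4"],
        },
    )

With the extra installed, the new test can be selected with, for example,
pytest hlink/tests/training_test.py -k xgboost; without xgboost installed,
the skipif marker makes it skip itself as described above.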