[#161] Add a test that runs the whole training task with an xgboost model

This test is failing right now because we also need pyarrow>=4 when using
xgboost. We should add this as a dependency in the xgboost extra. If xgboost
isn't installed, this test skips itself.
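
A hypothetical sketch of that dependency fix is below, assuming a setuptools-style extras_require declaration. The exact location (setup.py vs. [project.optional-dependencies] in pyproject.toml) and any version pin on xgboost itself are assumptions, not taken from this commit.

from setuptools import setup

setup(
    name="hlink",
    extras_require={
        # Per the commit message, pyarrow>=4 is also required when xgboost
        # is in use, so it belongs in the same optional extra.
        "xgboost": [
            "xgboost",
            "pyarrow>=4",
        ],
    },
)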
riley-harper committed Nov 15, 2024
1 parent 287912e commit a7b0c37
Showing 1 changed file with 51 additions and 0 deletions.
hlink/tests/training_test.py
@@ -7,6 +7,13 @@
from pyspark.ml import Pipeline
import hlink.linking.core.pipeline as pipeline_core

try:
    import xgboost  # noqa: F401
except ModuleNotFoundError:
    xgboost_available = False
else:
    xgboost_available = True


@pytest.mark.quickcheck
def test_all_steps(
@@ -432,6 +439,50 @@ def test_step_3_with_probit_model(
)


@pytest.mark.skipif(not xgboost_available, reason="requires the xgboost library")
def test_step_3_with_xgboost_model(
    spark, training, training_conf, datasource_training_input
):
    training_data_path, prepped_df_a_path, prepped_df_b_path = datasource_training_input
    training_conf["comparison_features"] = [
        {
            "alias": "regionf",
            "column_name": "region",
            "comparison_type": "fetch_a",
            "categorical": True,
        },
        {
            "alias": "namelast_jw",
            "column_name": "namelast",
            "comparison_type": "jaro_winkler",
        },
    ]
    training_conf["training"]["dataset"] = training_data_path
    training_conf["training"]["dependent_var"] = "match"
    training_conf["training"]["independent_vars"] = ["namelast_jw", "regionf"]
    training_conf["training"]["chosen_model"] = {
        "type": "xgboost",
        "max_depth": 2,
        "eta": 0.5,
        "threshold": 0.7,
        "threshold_ratio": 1.3,
    }
    training_conf["training"]["score_with_model"] = True
    training_conf["training"]["feature_importances"] = True

    spark.read.csv(prepped_df_a_path, header=True, inferSchema=True).write.mode(
        "overwrite"
    ).saveAsTable("prepped_df_a")
    spark.read.csv(prepped_df_b_path, header=True, inferSchema=True).write.mode(
        "overwrite"
    ).saveAsTable("prepped_df_b")

    training.run_step(0)
    training.run_step(1)
    training.run_step(2)
    training.run_step(3)


def test_step_3_requires_table(training_conf, training):
training_conf["training"]["feature_importances"] = True
with pytest.raises(RuntimeError, match="Missing input tables"):
