Skip to content

Commit

Permalink
[#162] Add two training tests for lightgbm
Browse files Browse the repository at this point in the history
One of these tests currently fails because of a bug in which LightGBM throws an
error on interacted features.
  • Loading branch information
riley-harper committed Nov 19, 2024
1 parent 1aef721 commit 72fd83c
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions hlink/tests/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pyspark.ml import Pipeline
import hlink.linking.core.pipeline as pipeline_core
from hlink.tests.markers import requires_lightgbm


@pytest.mark.quickcheck
Expand Down Expand Up @@ -432,6 +433,111 @@ def test_step_3_with_probit_model(
)


@requires_lightgbm
def test_step_3_with_lightgbm_model(
    spark, training, training_conf, datasource_training_input, state_dist_path
):
    """Run every training step end to end with a LightGBM chosen model.

    Configures three comparison features (a categorical fetch, a Jaro-Winkler
    string distance, and a geo distance backed by a lookup table), registers the
    prepped input frames as Spark tables, and verifies that training completes
    with scoring and feature importances enabled.
    """
    training_data_path, prepped_df_a_path, prepped_df_b_path = datasource_training_input

    region_feature = {
        "alias": "regionf",
        "column_name": "region",
        "comparison_type": "fetch_a",
        "categorical": True,
    }
    namelast_feature = {
        "alias": "namelast_jw",
        "column_name": "namelast",
        "comparison_type": "jaro_winkler",
    }
    distance_feature = {
        "alias": "state_distance",
        "key_count": 1,
        "column_name": "bpl",
        "comparison_type": "geo_distance",
        "loc_a": "statecode1",
        "loc_b": "statecode2",
        "distance_col": "dist",
        "table_name": "state_distances_lookup",
        "distances_file": state_dist_path,
    }
    training_conf["comparison_features"] = [
        region_feature,
        namelast_feature,
        distance_feature,
    ]

    training_section = training_conf["training"]
    training_section["dataset"] = training_data_path
    training_section["dependent_var"] = "match"
    training_section["independent_vars"] = [
        "namelast_jw",
        "regionf",
        "state_distance",
    ]
    training_section["chosen_model"] = {
        "type": "lightgbm",
        "maxDepth": 7,
        "numIterations": 5,
        "minDataInLeaf": 1,
        "threshold": 0.5,
    }
    training_section["score_with_model"] = True
    training_section["feature_importances"] = True

    # Register the prepped inputs as the tables the training steps read from.
    for table_name, csv_path in [
        ("prepped_df_a", prepped_df_a_path),
        ("prepped_df_b", prepped_df_b_path),
    ]:
        prepped = spark.read.csv(csv_path, header=True, inferSchema=True)
        prepped.write.mode("overwrite").saveAsTable(table_name)

    training.run_all_steps()


@requires_lightgbm
def test_lightgbm_with_interacted_features(
    spark, training, training_conf, datasource_training_input
):
    """Run the training steps with LightGBM on an interacted pipeline feature.

    Adds an interaction pipeline feature built from the categorical ``regionf``
    feature and the ``namelast_jw`` comparison feature, then includes the
    interaction in the model's independent variables and runs all training
    steps.
    """
    training_data_path, prepped_df_a_path, prepped_df_b_path = datasource_training_input

    region_feature = {
        "alias": "regionf",
        "column_name": "region",
        "comparison_type": "fetch_a",
        "categorical": True,
    }
    namelast_feature = {
        "alias": "namelast_jw",
        "column_name": "namelast",
        "comparison_type": "jaro_winkler",
    }
    training_conf["comparison_features"] = [region_feature, namelast_feature]

    # Interact the categorical region feature with the name similarity score.
    training_conf["pipeline_features"] = [
        {
            "input_columns": ["regionf", "namelast_jw"],
            "output_column": "regionf_interacted_namelast_jw",
            "transformer_type": "interaction",
        }
    ]

    training_section = training_conf["training"]
    training_section["dataset"] = training_data_path
    training_section["dependent_var"] = "match"
    training_section["independent_vars"] = [
        "namelast_jw",
        "regionf",
        "regionf_interacted_namelast_jw",
    ]
    training_section["chosen_model"] = {
        "type": "lightgbm",
        "maxDepth": 7,
        "numIterations": 5,
        "minDataInLeaf": 1,
        "threshold": 0.5,
    }
    training_section["score_with_model"] = True
    training_section["feature_importances"] = True

    # Register the prepped inputs as the tables the training steps read from.
    for table_name, csv_path in [
        ("prepped_df_a", prepped_df_a_path),
        ("prepped_df_b", prepped_df_b_path),
    ]:
        prepped = spark.read.csv(csv_path, header=True, inferSchema=True)
        prepped.write.mode("overwrite").saveAsTable(table_name)

    training.run_all_steps()


def test_step_3_requires_table(training_conf, training):
training_conf["training"]["feature_importances"] = True
with pytest.raises(RuntimeError, match="Missing input tables"):
Expand Down

0 comments on commit 72fd83c

Please sign in to comment.