Skip to content

Commit

Permalink
[#161] Add an integration test for xgboost, set the post-transformer
Browse files Browse the repository at this point in the history
As with some of the other models, xgboost returns an array of probabilities of the form
[probability_no, probability_yes]. So we extract just probability_yes as the
probability for hlink's purposes.
  • Loading branch information
riley-harper committed Nov 15, 2024
1 parent c64cf43 commit 88d7199
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
4 changes: 4 additions & 0 deletions hlink/linking/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def choose_classifier(model_type, params, dep_var):
**params_without_threshold,
features_col=features_vector,
label_col=dep_var,
probability_col="probability_array",
)
post_transformer = SQLTransformer(
statement="SELECT *, parseProbVector(probability_array, 1) as probability FROM __THIS__"
)
else:
raise ValueError(
Expand Down
96 changes: 96 additions & 0 deletions hlink/tests/integration_score_with_trained_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

from hlink.tests.markers import requires_xgboost


def test_apply_chosen_model_RF(
spark,
Expand Down Expand Up @@ -859,6 +861,100 @@ def test_step_3_apply_chosen_model_boosted_trees(
)


@requires_xgboost
def test_apply_chosen_model_xgboost(
    spark,
    training,
    matching,
    training_conf,
    datasource_training_input,
    potential_matches_path,
    state_dist_path,
    spark_test_tmp_dir_path,
):
    """Integration test: train an xgboost chosen model, then score matches.

    Spot-checks the scored output on one record pair expected to be a
    match and one expected not to be a match, confirming both the
    extracted probability and the thresholded prediction.
    """
    dataset_path, path_a, path_b = datasource_training_input

    training_conf["comparison_features"] = [
        {
            "alias": "regionf",
            "column_name": "region",
            "comparison_type": "fetch_a",
            "categorical": True,
        },
        {
            "alias": "namelast_jw",
            "column_name": "namelast",
            "comparison_type": "jaro_winkler",
        },
        {
            "alias": "state_distance",
            "key_count": 1,
            "column_name": "bpl",
            "comparison_type": "geo_distance",
            "loc_a": "statecode1",
            "loc_b": "statecode2",
            "distance_col": "dist",
            "table_name": "state_distances_lookup",
            "distances_file": state_dist_path,
        },
    ]

    # Configure training to fit an xgboost model on the three features above.
    training_section = training_conf["training"]
    training_section["dataset"] = dataset_path
    training_section["dependent_var"] = "match"
    training_section["independent_vars"] = [
        "namelast_jw",
        "regionf",
        "state_distance",
    ]
    training_section["chosen_model"] = {
        "type": "xgboost",
        "max_depth": 5,
        "eta": 0.5,
        "threshold": 0.5,
        "threshold_ratio": 1.3,
    }
    training_section["score_with_model"] = True
    training_conf["spark_tmp_dir"] = spark_test_tmp_dir_path
    training_conf["drop_data_from_scored_matches"] = True

    # Register the Spark tables that the training and matching steps read.
    for table_name, csv_path in [
        ("prepped_df_a", path_a),
        ("prepped_df_b", path_b),
        ("potential_matches", potential_matches_path),
    ]:
        frame = spark.read.csv(csv_path, header=True, inferSchema=True)
        frame.write.mode("overwrite").saveAsTable(table_name)

    training.run_all_steps()
    matching.run_step(2)

    scored = spark.table("scored_potential_matches").toPandas()

    # Check one case that we expect to be a match and one case that we expect not
    # to be a match.
    match_rows = scored.query("id_a == '0202928A-AC3E-48BB-8568-3372067F35C7'")
    assert (
        match_rows.shape[0] == 1
    ), "expected exactly one potential match for 0202928A"
    assert match_rows["probability"].iloc[0] >= 0.5
    assert match_rows["prediction"].iloc[0] == 1

    # In the real world, this would probably be a match, depending on how the
    # additional features looked. But we've included so few training features
    # for our test model that small differences in names can really hurt a
    # potential match's chances of being classified as a match.
    non_match_rows = scored.query("id_b == '033FD0FA-C523-42B5-976A-751E830F7021'")
    assert (
        non_match_rows.shape[0] == 1
    ), "expected exactly one potential match for 033FD0FA"
    assert non_match_rows["probability"].iloc[0] <= 0.5
    assert non_match_rows["prediction"].iloc[0] == 0


def test_step_3_apply_chosen_model_RF_threshold(
spark,
training_conf,
Expand Down

0 comments on commit 88d7199

Please sign in to comment.