From 9755f73c3f95557a765e599ff6b2f6ae831dd81d Mon Sep 17 00:00:00 2001 From: rileyh Date: Tue, 10 Dec 2024 14:08:12 -0600 Subject: [PATCH] [#176] Add a unit test for _get_confusion_matrix() --- hlink/tests/model_exploration_test.py | 32 ++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/hlink/tests/model_exploration_test.py b/hlink/tests/model_exploration_test.py index 46166c5..7414ef4 100644 --- a/hlink/tests/model_exploration_test.py +++ b/hlink/tests/model_exploration_test.py @@ -3,14 +3,16 @@ # https://github.com/ipums/hlink from collections import Counter -import pytest import pandas as pd +from pyspark.sql import SparkSession +import pytest import hlink.linking.core.threshold as threshold_core from hlink.linking.model_exploration.link_step_train_test_models import ( LinkStepTrainTestModels, _custom_param_grid_builder, _get_model_parameters, + _get_confusion_matrix, ) @@ -985,3 +987,31 @@ def test_step_2_split_by_id_a( assert splits[1][1].toPandas()["id_a"].unique().tolist() == ["30"] main.do_drop_all("") + + +def test_get_confusion_matrix(spark: SparkSession) -> None: + # 1 true negative (0, 0) + # 2 false negatives (1, 0) + # 3 false postives (0, 1) + # 4 true positives (1, 1) + rows = [ + (0, 0), + (1, 0), + (0, 1), + (1, 0), + (0, 1), + (1, 1), + (0, 1), + (1, 1), + (1, 1), + (1, 1), + ] + predictions = spark.createDataFrame(rows, schema=["match", "prediction"]) + true_positives, false_positives, false_negatives, true_negatives = ( + _get_confusion_matrix(predictions, "match") + ) + + assert true_positives == 4 + assert false_positives == 3 + assert false_negatives == 2 + assert true_negatives == 1