Skip to content

Commit

Permalink
[#176] Add a unit test for _get_confusion_matrix()
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Dec 10, 2024
1 parent b7f821c commit 9755f73
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion hlink/tests/model_exploration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
# https://github.com/ipums/hlink
from collections import Counter

import pytest
import pandas as pd
from pyspark.sql import SparkSession
import pytest

import hlink.linking.core.threshold as threshold_core
from hlink.linking.model_exploration.link_step_train_test_models import (
LinkStepTrainTestModels,
_custom_param_grid_builder,
_get_model_parameters,
_get_confusion_matrix,
)


Expand Down Expand Up @@ -985,3 +987,31 @@ def test_step_2_split_by_id_a(
assert splits[1][1].toPandas()["id_a"].unique().tolist() == ["30"]

main.do_drop_all("")


def test_get_confusion_matrix(spark: SparkSession) -> None:
# 1 true negative (0, 0)
# 2 false negatives (1, 0)
# 3 false postives (0, 1)
# 4 true positives (1, 1)
rows = [
(0, 0),
(1, 0),
(0, 1),
(1, 0),
(0, 1),
(1, 1),
(0, 1),
(1, 1),
(1, 1),
(1, 1),
]
predictions = spark.createDataFrame(rows, schema=["match", "prediction"])
true_positives, false_positives, false_negatives, true_negatives = (
_get_confusion_matrix(predictions, "match")
)

assert true_positives == 4
assert false_positives == 3
assert false_negatives == 2
assert true_negatives == 1

0 comments on commit 9755f73

Please sign in to comment.