Skip to content

Commit 4619d63

Browse files
committed
[#158] Add some unit tests for core.comparison.get_comparison_leaves()
The logic in this function that deals with the secondary, threshold_a, and threshold_b attributes isn't tested or documented, and I'm not sure exactly what it's supposed to do.
1 parent c1713e5 commit 4619d63

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

hlink/tests/core/comparison_test.py

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
3+
from hlink.linking.core.comparison import get_comparison_leaves
4+
5+
6+
def test_get_comparison_leaves_base_case() -> None:
    """
    A comparison that has no comp_a / comp_b sub-comparisons is a leaf, so
    get_comparison_leaves() should hand back a list holding just that
    comparison.
    """
    lone_leaf = {
        "comparison_type": "threshold",
        "feature_name": "namefrst_jw",
        "threshold": 0.79,
    }

    assert get_comparison_leaves(lone_leaf) == [lone_leaf]
17+
18+
19+
@pytest.mark.parametrize("operator", ["AND", "OR"])
def test_get_comparison_leaves_one_level(operator: str) -> None:
    """
    For a single operator node joining comp_a and comp_b, the leaves returned
    by get_comparison_leaves() are those two sub-comparisons, comp_a first.
    """
    first_name_leaf = {
        "comparison_type": "threshold",
        "feature_name": "namefrst_jw",
        "threshold": 0.79,
    }
    last_name_leaf = {
        "comparison_type": "threshold",
        "feature_name": "namelast_jw",
        "threshold": 0.84,
    }
    expected_leaves = [first_name_leaf, last_name_leaf]

    # The same tree shape is exercised for each parametrized operator.
    tree = {
        "operator": operator,
        "comp_a": first_name_leaf,
        "comp_b": last_name_leaf,
    }

    assert get_comparison_leaves(tree) == expected_leaves
42+
43+
44+
@pytest.mark.parametrize("operator1", ["AND", "OR"])
@pytest.mark.parametrize("operator2", ["AND", "OR"])
def test_get_comparison_leaves_nested(operator1: str, operator2: str) -> None:
    """
    With more than one level of nesting, get_comparison_leaves() descends
    through the tree and returns every leaf: comp_a of the root first, then
    the leaves of the nested comp_b subtree.
    """
    first_name_leaf = {
        "comparison_type": "threshold",
        "feature_name": "namefrst_jw",
        "threshold": 0.79,
    }
    last_name_leaf = {
        "comparison_type": "threshold",
        "feature_name": "namelast_jw",
        "threshold": 0.84,
    }
    marst_leaf = {
        "comparison_type": "threshold",
        "feature_name": "marst_flag",
        "threshold_expr": ">0.5",
    }

    # Inner node: sits in the comp_b slot of the root and has two leaves.
    inner_node = {
        "operator": operator2,
        "comp_a": last_name_leaf,
        "comp_b": marst_leaf,
    }
    root_node = {
        "operator": operator1,
        "comp_a": first_name_leaf,
        "comp_b": inner_node,
    }

    found = get_comparison_leaves(root_node)
    assert found == [first_name_leaf, last_name_leaf, marst_leaf]

0 commit comments

Comments
 (0)