Skip to content

Commit 8cb5790

Browse files
committed
test:condorcet aggregator
1 parent b4b7d74 commit 8cb5790

File tree

2 files changed

+196
-3
lines changed

2 files changed

+196
-3
lines changed

hakeem/aggregation/aggregators/condorcet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,21 @@ def __init__(
1818
self.upper_reliability_bound = upper_reliability_bound
1919

2020
def compute_weights(self, annotations: pd.DataFrame) -> pd.Series:
21+
reliabilities = self._compute_reliabilities(annotations)
22+
assert np.all(
23+
(reliabilities > 0) & (reliabilities < 1)
24+
), "Reliabilities must be in (0, 1)."
25+
26+
return np.log(reliabilities / (1 - reliabilities))
27+
28+
def _compute_reliabilities(self, annotations: pd.DataFrame) -> pd.Series:
2129
vote_size = annotations.sum(axis=1)
30+
31+
assert len(annotations.columns) > 2, "At least 3 labels are required currently."
2232
reliabilities = (len(annotations.columns) - vote_size - 1) / (
2333
len(annotations.columns) - 2
2434
)
2535
reliabilities = reliabilities.clip(
2636
self.lower_reliability_bound, self.upper_reliability_bound
2737
)
28-
weights = np.log(reliabilities / (1 - reliabilities))
29-
30-
return weights
38+
return reliabilities
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import pandas as pd
2+
import pytest
3+
4+
5+
@pytest.mark.ut
6+
@pytest.mark.parametrize(
7+
[
8+
"lower_reliability_bound",
9+
"upper_reliability_bound",
10+
"task_column",
11+
"worker_column",
12+
],
13+
[(0.1, 0.9, "task", "worker"), (0.2, 0.8, "question", "voter")],
14+
)
15+
def test_CondorcetAggregator_init(
16+
lower_reliability_bound: float,
17+
upper_reliability_bound: float,
18+
task_column: str,
19+
worker_column: str,
20+
) -> None:
21+
# Given
22+
from hakeem.aggregation.aggregators.condorcet import CondorcetAggregator
23+
24+
# When
25+
result = CondorcetAggregator(
26+
lower_reliability_bound=lower_reliability_bound,
27+
upper_reliability_bound=upper_reliability_bound,
28+
task_column=task_column,
29+
worker_column=worker_column,
30+
)
31+
32+
# Then
33+
assert result.lower_reliability_bound == lower_reliability_bound
34+
assert result.upper_reliability_bound == upper_reliability_bound
35+
assert result.task_column == task_column
36+
assert result.worker_column == worker_column
37+
38+
39+
@pytest.mark.ut
40+
@pytest.mark.parametrize(
41+
[
42+
"annotations",
43+
"lower_reliability_bound",
44+
"upper_reliability_bound",
45+
"expected_result",
46+
],
47+
[
48+
(
49+
pd.DataFrame(
50+
{
51+
"task": ["q1", "q1", "q2"],
52+
"worker": ["v1", "v2", "v1"],
53+
"a": [1, 1, 0],
54+
"b": [0, 1, 1],
55+
"c": [0, 0, 1],
56+
"d": [0, 0, 1],
57+
}
58+
).set_index(["task", "worker"]),
59+
0.1,
60+
0.9,
61+
pd.Series(
62+
[0.9, 0.5, 0.1],
63+
index=pd.MultiIndex.from_tuples(
64+
[("q1", "v1"), ("q1", "v2"), ("q2", "v1")], names=["task", "worker"]
65+
),
66+
),
67+
),
68+
(
69+
pd.DataFrame(
70+
{
71+
"task": ["q1", "q1", "q2", "q3"],
72+
"worker": ["v1", "v2", "v1", "v3"],
73+
"a": [1, 1, 0, 1],
74+
"b": [0, 1, 1, 1],
75+
"c": [0, 0, 1, 1],
76+
"d": [0, 0, 1, 1],
77+
"e": [0, 0, 0, 0],
78+
}
79+
).set_index(["task", "worker"]),
80+
0.01,
81+
0.99,
82+
pd.Series(
83+
[0.99, 2 / 3, 1 / 3, 0.01],
84+
index=pd.MultiIndex.from_tuples(
85+
[("q1", "v1"), ("q1", "v2"), ("q2", "v1"), ("q3", "v3")],
86+
names=["task", "worker"],
87+
),
88+
),
89+
),
90+
],
91+
)
92+
def test_CondorcetAggregator__compute_reliabilities(
93+
annotations: pd.DataFrame,
94+
lower_reliability_bound: float,
95+
upper_reliability_bound: float,
96+
expected_result: pd.Series,
97+
) -> None:
98+
# Given
99+
from hakeem.aggregation.aggregators.condorcet import CondorcetAggregator
100+
101+
# When
102+
result = CondorcetAggregator(
103+
lower_reliability_bound=lower_reliability_bound,
104+
upper_reliability_bound=upper_reliability_bound,
105+
)._compute_reliabilities(annotations)
106+
107+
# Then
108+
pd.testing.assert_series_equal(expected_result, result)
109+
110+
111+
@pytest.mark.ut
112+
@pytest.mark.parametrize(
113+
[
114+
"annotations",
115+
"lower_reliability_bound",
116+
"upper_reliability_bound",
117+
"expected_result",
118+
],
119+
[
120+
(
121+
pd.DataFrame(
122+
{
123+
"task": ["q1", "q1", "q2"],
124+
"worker": ["v1", "v2", "v1"],
125+
"a": [1, 1, 0],
126+
"b": [0, 1, 1],
127+
"c": [0, 0, 1],
128+
"d": [0, 0, 1],
129+
}
130+
).set_index(["task", "worker"]),
131+
0.1,
132+
0.9,
133+
pd.Series(
134+
[2.1972245773362196, 0.0, -2.197224577336219],
135+
index=pd.MultiIndex.from_tuples(
136+
[("q1", "v1"), ("q1", "v2"), ("q2", "v1")], names=["task", "worker"]
137+
),
138+
),
139+
),
140+
(
141+
pd.DataFrame(
142+
{
143+
"task": ["q1", "q1", "q2", "q3"],
144+
"worker": ["v1", "v2", "v1", "v3"],
145+
"a": [1, 1, 0, 1],
146+
"b": [0, 1, 1, 1],
147+
"c": [0, 0, 1, 1],
148+
"d": [0, 0, 1, 1],
149+
"e": [0, 0, 0, 0],
150+
}
151+
).set_index(["task", "worker"]),
152+
0.01,
153+
0.99,
154+
pd.Series(
155+
[
156+
4.595119850134589,
157+
0.6931471805599452,
158+
-0.6931471805599454,
159+
-4.59511985013459,
160+
],
161+
index=pd.MultiIndex.from_tuples(
162+
[("q1", "v1"), ("q1", "v2"), ("q2", "v1"), ("q3", "v3")],
163+
names=["task", "worker"],
164+
),
165+
),
166+
),
167+
],
168+
)
169+
def test_CondorcetAggregator_compute_weights(
170+
annotations: pd.DataFrame,
171+
lower_reliability_bound: float,
172+
upper_reliability_bound: float,
173+
expected_result: pd.Series,
174+
) -> None:
175+
# Given
176+
from hakeem.aggregation.aggregators.condorcet import CondorcetAggregator
177+
178+
# When
179+
result = CondorcetAggregator(
180+
lower_reliability_bound=lower_reliability_bound,
181+
upper_reliability_bound=upper_reliability_bound,
182+
).compute_weights(annotations)
183+
184+
# Then
185+
pd.testing.assert_series_equal(expected_result, result, rtol=1e-5)

0 commit comments

Comments
 (0)