Skip to content

Commit a371736

Browse files
committed
topk_queue::finalize sorts by both score and ID (#508)
The final sorting order is now by score (descending) and docid (ascending). Furthermore, `std::push_heap` is replaced with our own implementation to maintain consistency across standard libraries.
1 parent f8739ab commit a371736

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

include/pisa/topk_queue.hpp

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ namespace pisa {
1616
/// min element. Because it is a binary heap, the elements are not sorted;
1717
/// use `finalize()` member function to sort it before accessing it with
1818
/// `topk()`.
19+
///
20+
/// Note that `finalize()` breaks ties between entries with equal scores by
21+
/// sorting them by document ID. However, this is not true for insertion, and
22+
/// thus it may happen that a document with identical score as the k-th
23+
/// document in the heap but with a lower document ID is _not_ in the top-k
24+
/// results. In such case, which entry makes it to top k is determined by the
25+
/// order in which elements are pushed to the heap.
1926
struct topk_queue {
2027
using entry_type = std::pair<Score, DocId>;
2128

@@ -50,7 +57,7 @@ struct topk_queue {
5057
}
5158
m_q.emplace_back(score, docid);
5259
if (PISA_UNLIKELY(m_q.size() <= m_k)) {
53-
std::push_heap(m_q.begin(), m_q.end(), min_heap_order);
60+
push_heap(m_q.begin(), m_q.end());
5461
if (PISA_UNLIKELY(m_q.size() == m_k)) {
5562
m_effective_threshold = m_q.front().first;
5663
}
@@ -69,17 +76,25 @@ struct topk_queue {
6976

7077
/// Sorts the results in the heap container in the descending score order.
7178
///
72-
/// After calling this function, the heap should be no longer modified, as
73-
/// the heap order will not be preserved.
79+
/// If multiple entries have equal score, they are sorted by document ID. Notice that this only
80+
/// happens in the finalization step; due to performance considerations, inserting is done with
81+
/// score-only order. After calling this function, the heap should be no longer modified, as the
82+
/// heap order will not be preserved.
7483
void finalize()
7584
{
76-
std::sort_heap(m_q.begin(), m_q.end(), min_heap_order);
77-
size_t size = std::lower_bound(
78-
m_q.begin(),
79-
m_q.end(),
80-
0,
81-
[](std::pair<Score, DocId> l, Score r) { return l.first > r; })
82-
- m_q.begin();
85+
auto sort_order = [](auto const& lhs, auto const& rhs) {
86+
if (lhs.first == rhs.first) {
87+
return lhs.second < rhs.second;
88+
}
89+
return lhs.first > rhs.first;
90+
};
91+
// We have to do a full sort because it is not the exact same ordering as when pushing
92+
// elements to the heap.
93+
std::sort(m_q.begin(), m_q.end(), sort_order);
94+
95+
auto search_order = [](auto entry, auto score) { return entry.first > score; };
96+
auto first_zero_score = std::lower_bound(m_q.begin(), m_q.end(), 0, search_order);
97+
std::size_t size = std::distance(m_q.begin(), first_zero_score);
8398
m_q.resize(size);
8499
}
85100

@@ -127,12 +142,6 @@ struct topk_queue {
127142
[[nodiscard]] auto size() const noexcept -> std::size_t { return m_q.size(); }
128143

129144
private:
130-
[[nodiscard]] constexpr static auto
131-
min_heap_order(entry_type const& lhs, entry_type const& rhs) noexcept -> bool
132-
{
133-
return lhs.first > rhs.first;
134-
}
135-
136145
using entry_iterator_type = typename std::vector<entry_type>::iterator;
137146

138147
/// Sifts down the top element of the heap in `[first, last)`.
@@ -172,6 +181,26 @@ struct topk_queue {
172181
}
173182
}
174183

184+
// We use our own function (as opposed to `std::heap_push`), to ensure that
185+
// heap implementation is consistent across standard libraries.
186+
void push_heap(entry_iterator_type first, entry_iterator_type last)
187+
{
188+
std::size_t hole_idx = std::distance(first, last) - 1;
189+
std::size_t top_idx = 0;
190+
auto cmp = [](entry_iterator_type const& lhs, entry_type const& rhs) {
191+
return lhs->first > rhs.first;
192+
};
193+
auto value = *std::next(first, hole_idx);
194+
195+
auto parent = (hole_idx - 1) / 2;
196+
while (hole_idx > top_idx && cmp(first + parent, value)) {
197+
*(first + hole_idx) = *(first + parent);
198+
hole_idx = parent;
199+
parent = (hole_idx - 1) / 2;
200+
}
201+
*(first + hole_idx) = value;
202+
}
203+
175204
std::size_t m_k;
176205
float m_initial_threshold;
177206
std::vector<entry_type> m_q;

test/test_topk_queue.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,28 @@ auto kth(std::vector<float> scores, int k) -> float
4949
return scores.at(k - 1);
5050
}
5151

52+
auto sort_order = [](auto const& lhs, auto const& rhs) {
53+
if (lhs.first == rhs.first) {
54+
return lhs.second < rhs.second;
55+
}
56+
return lhs.first > rhs.first;
57+
};
58+
59+
TEST_CASE("Top-k ordering", "[topk_queue][prop]")
60+
{
61+
SECTION("Final elements are always sorted")
62+
{
63+
check([] {
64+
auto [scores, docids] = *gen_postings(10, 1000);
65+
66+
pisa::topk_queue topk(10);
67+
accumulate(topk, scores, docids);
68+
topk.finalize();
69+
REQUIRE(std::is_sorted(topk.topk().begin(), topk.topk().end(), sort_order));
70+
});
71+
}
72+
}
73+
5274
TEST_CASE("Threshold", "[topk_queue][prop]")
5375
{
5476
SECTION("When initial = 0.0, the final threshold is the k-th score")

0 commit comments

Comments
 (0)