Skip to content

Commit 1ae2655

Browse files
harenlinhack4changjpountz
authored
Reciprocal Rank Fusion (RRF) in TopDocs (#13470)
Co-authored-by: SuperSonicVox <hackchang0715@gmail.com> Co-authored-by: Adrien Grand <jpountz@gmail.com>
1 parent be3e39d commit 1ae2655

File tree

3 files changed

+215
-0
lines changed

3 files changed

+215
-0
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ New Features
9191

9292
* GITHUB#14300: Add support for JDK 24 to the Panama Vectorization Provider. (Chris Hegarty)
9393

94+
* GITHUB#13470: Added `TopDocs#rrf` to combine multiple TopDocs instances using
95+
reciprocal rank fusion. (Haren Lin, Adrien Grand)
96+
9497
Improvements
9598
---------------------
9699

lucene/core/src/java/org/apache/lucene/search/TopDocs.java

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
*/
1717
package org.apache.lucene.search;
1818

19+
import java.util.ArrayList;
1920
import java.util.Comparator;
21+
import java.util.HashMap;
22+
import java.util.List;
23+
import java.util.Map;
2024
import org.apache.lucene.util.PriorityQueue;
2125

2226
/** Represents hits returned by {@link IndexSearcher#search(Query,int)}. */
@@ -350,4 +354,84 @@ private static TopDocs mergeAux(
350354
return new TopFieldDocs(totalHits, hits, sort.getSort());
351355
}
352356
}
357+
358+
/**
 * Pairs a hit's shard index with its doc ID. Used as the map key under which {@link
 * TopDocs#rrf(int, int, TopDocs[])} accumulates per-document scores.
 */
private record ShardIndexAndDoc(int shardIndex, int doc) {}
359+
360+
/**
361+
* Reciprocal Rank Fusion method.
362+
*
363+
* <p>This method combines different search results into a single ranked list by combining their
364+
* ranks. This is especially well suited when combining hits computed via different methods, whose
365+
* score distributions are hardly comparable.
366+
*
367+
* @param topN the top N results to be returned
368+
* @param k a constant determines how much influence documents in individual rankings have on the
369+
* final result. A higher value gives lower rank documents more influence. k should be greater
370+
* than or equal to 1.
371+
* @param hits a list of TopDocs to apply RRF on
372+
* @return a TopDocs contains the top N ranked results.
373+
*/
374+
public static TopDocs rrf(int topN, int k, TopDocs[] hits) {
375+
if (topN < 1) {
376+
throw new IllegalArgumentException("topN must be >= 1, got " + topN);
377+
}
378+
if (k < 1) {
379+
throw new IllegalArgumentException("k must be >= 1, got " + k);
380+
}
381+
382+
Boolean shardIndexSet = null;
383+
for (TopDocs topDocs : hits) {
384+
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
385+
boolean thisShardIndexSet = scoreDoc.shardIndex != -1;
386+
if (shardIndexSet == null) {
387+
shardIndexSet = thisShardIndexSet;
388+
} else if (shardIndexSet.booleanValue() != thisShardIndexSet) {
389+
throw new IllegalArgumentException(
390+
"All hits must either have their ScoreDoc#shardIndex set, or unset (-1), not a mix of both.");
391+
}
392+
}
393+
}
394+
395+
// Compute the rrf score as a double to reduce accuracy loss due to floating-point arithmetic.
396+
Map<ShardIndexAndDoc, Double> rrfScore = new HashMap<>();
397+
long totalHitCount = 0;
398+
for (TopDocs topDoc : hits) {
399+
// A document is a hit globally if it is a hit for any of the top docs, so we compute the
400+
// total hit count as the max total hit count.
401+
totalHitCount = Math.max(totalHitCount, topDoc.totalHits.value());
402+
for (int i = 0; i < topDoc.scoreDocs.length; ++i) {
403+
ScoreDoc scoreDoc = topDoc.scoreDocs[i];
404+
int rank = i + 1;
405+
double rrfScoreContribution = 1d / Math.addExact(k, rank);
406+
rrfScore.compute(
407+
new ShardIndexAndDoc(scoreDoc.shardIndex, scoreDoc.doc),
408+
(key, score) -> (score == null ? 0 : score) + rrfScoreContribution);
409+
}
410+
}
411+
412+
List<Map.Entry<ShardIndexAndDoc, Double>> rrfScoreRank = new ArrayList<>(rrfScore.entrySet());
413+
rrfScoreRank.sort(
414+
// Sort by descending score
415+
Map.Entry.<ShardIndexAndDoc, Double>comparingByValue()
416+
.reversed()
417+
// Tie-break by doc ID, then shard index (like TopDocs#merge)
418+
.thenComparing(
419+
Map.Entry.<ShardIndexAndDoc, Double>comparingByKey(
420+
Comparator.comparingInt(ShardIndexAndDoc::doc)))
421+
.thenComparing(
422+
Map.Entry.<ShardIndexAndDoc, Double>comparingByKey(
423+
Comparator.comparingInt(ShardIndexAndDoc::shardIndex))));
424+
425+
ScoreDoc[] rrfScoreDocs = new ScoreDoc[Math.min(topN, rrfScoreRank.size())];
426+
for (int i = 0; i < rrfScoreDocs.length; i++) {
427+
Map.Entry<ShardIndexAndDoc, Double> entry = rrfScoreRank.get(i);
428+
int doc = entry.getKey().doc;
429+
int shardIndex = entry.getKey().shardIndex();
430+
float score = entry.getValue().floatValue();
431+
rrfScoreDocs[i] = new ScoreDoc(doc, score, shardIndex);
432+
}
433+
434+
TotalHits totalHits = new TotalHits(totalHitCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
435+
return new TopDocs(totalHits, rrfScoreDocs);
436+
}
353437
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.search;
18+
19+
import org.apache.lucene.tests.util.LuceneTestCase;
20+
21+
public class TestTopDocsRRF extends LuceneTestCase {
22+
23+
public void testBasics() {
24+
TopDocs td1 =
25+
new TopDocs(
26+
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
27+
new ScoreDoc[] {new ScoreDoc(42, 10f), new ScoreDoc(10, 5f), new ScoreDoc(20, 3f)});
28+
TopDocs td2 =
29+
new TopDocs(
30+
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
31+
new ScoreDoc[] {new ScoreDoc(10, 10f), new ScoreDoc(20, 5f)});
32+
33+
TopDocs rrfTd = TopDocs.rrf(3, 20, new TopDocs[] {td1, td2});
34+
assertEquals(new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), rrfTd.totalHits);
35+
36+
ScoreDoc[] rrfScoreDocs = rrfTd.scoreDocs;
37+
assertEquals(3, rrfScoreDocs.length);
38+
39+
assertEquals(10, rrfScoreDocs[0].doc);
40+
assertEquals(-1, rrfScoreDocs[0].shardIndex);
41+
assertEquals((float) (1d / (20 + 2) + 1d / (20 + 1)), rrfScoreDocs[0].score, 0f);
42+
43+
assertEquals(20, rrfScoreDocs[1].doc);
44+
assertEquals(-1, rrfScoreDocs[1].shardIndex);
45+
assertEquals((float) (1d / (20 + 3) + 1d / (20 + 2)), rrfScoreDocs[1].score, 0f);
46+
47+
assertEquals(42, rrfScoreDocs[2].doc);
48+
assertEquals(-1, rrfScoreDocs[2].shardIndex);
49+
assertEquals((float) (1d / (20 + 1)), rrfScoreDocs[2].score, 0f);
50+
}
51+
52+
public void testShardIndex() {
53+
TopDocs td1 =
54+
new TopDocs(
55+
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
56+
new ScoreDoc[] {
57+
new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0)
58+
});
59+
TopDocs td2 =
60+
new TopDocs(
61+
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
62+
new ScoreDoc[] {new ScoreDoc(10, 10f, 1), new ScoreDoc(20, 5f, 1)});
63+
64+
TopDocs rrfTd = TopDocs.rrf(3, 20, new TopDocs[] {td1, td2});
65+
assertEquals(new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), rrfTd.totalHits);
66+
67+
ScoreDoc[] rrfScoreDocs = rrfTd.scoreDocs;
68+
assertEquals(3, rrfScoreDocs.length);
69+
70+
assertEquals(10, rrfScoreDocs[0].doc);
71+
assertEquals(1, rrfScoreDocs[0].shardIndex);
72+
assertEquals((float) (1d / (20 + 2) + 1d / (20 + 1)), rrfScoreDocs[0].score, 0f);
73+
74+
assertEquals(42, rrfScoreDocs[1].doc);
75+
assertEquals(0, rrfScoreDocs[1].shardIndex);
76+
assertEquals((float) (1d / (20 + 1)), rrfScoreDocs[1].score, 0f);
77+
78+
assertEquals(20, rrfScoreDocs[2].doc);
79+
assertEquals(1, rrfScoreDocs[2].shardIndex);
80+
assertEquals((float) (1d / (20 + 2)), rrfScoreDocs[2].score, 0f);
81+
}
82+
83+
public void testInconsistentShardIndex() {
84+
TopDocs td1 =
85+
new TopDocs(
86+
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
87+
new ScoreDoc[] {
88+
new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0)
89+
});
90+
TopDocs td2 =
91+
new TopDocs(
92+
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
93+
new ScoreDoc[] {new ScoreDoc(10, 10f, -1), new ScoreDoc(20, 5f, -1)});
94+
95+
IllegalArgumentException e =
96+
expectThrows(
97+
IllegalArgumentException.class, () -> TopDocs.rrf(3, 20, new TopDocs[] {td1, td2}));
98+
assertTrue(e.getMessage().contains("shardIndex"));
99+
}
100+
101+
public void testInvalidTopN() {
102+
TopDocs td1 =
103+
new TopDocs(
104+
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
105+
TopDocs td2 =
106+
new TopDocs(
107+
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
108+
109+
IllegalArgumentException e =
110+
expectThrows(
111+
IllegalArgumentException.class, () -> TopDocs.rrf(0, 20, new TopDocs[] {td1, td2}));
112+
assertTrue(e.getMessage().contains("topN"));
113+
}
114+
115+
public void testInvalidK() {
116+
TopDocs td1 =
117+
new TopDocs(
118+
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
119+
TopDocs td2 =
120+
new TopDocs(
121+
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
122+
123+
IllegalArgumentException e =
124+
expectThrows(
125+
IllegalArgumentException.class, () -> TopDocs.rrf(3, 0, new TopDocs[] {td1, td2}));
126+
assertTrue(e.getMessage().contains("k"));
127+
}
128+
}

0 commit comments

Comments
 (0)