Skip to content
Merged
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ New Features

* GITHUB#14300: Add support for JDK 24 to the Panama Vectorization Provider. (Chris Hegarty)

* GITHUB#13470: Added `TopDocs#rrf` to combine multiple TopDocs instances using
reciprocal rank fusion. (Haren Lin, Adrien Grand)

Improvements
---------------------

Expand Down
84 changes: 84 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/TopDocs.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
*/
package org.apache.lucene.search;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.PriorityQueue;

/** Represents hits returned by {@link IndexSearcher#search(Query,int)}. */
Expand Down Expand Up @@ -350,4 +354,84 @@ private static TopDocs mergeAux(
return new TopFieldDocs(totalHits, hits, sort.getSort());
}
}

/**
 * Identity of a hit as a (shard index, doc ID) pair. Being a record, it gets value-based
 * {@code equals}/{@code hashCode}, which lets {@code rrf} use it as a hash map key when summing
 * the score contributions of the same hit across several {@link TopDocs}.
 */
private record ShardIndexAndDoc(int shardIndex, int doc) {}

/**
 * Merges several ranked hit lists into a single ranking using Reciprocal Rank Fusion (RRF).
 *
 * <p>Every hit contributes {@code 1 / (k + rank)} to the fused score of its document, where
 * {@code rank} is the 1-based position of the hit within its own {@link TopDocs}. Contributions
 * for the same (shard index, doc ID) pair are summed across all inputs. Since only ranks are
 * taken into account, this is especially well suited to combining hits computed via different
 * methods, whose score distributions are hardly comparable.
 *
 * <p>Fused hits are ordered by descending fused score; ties are broken by ascending doc ID, then
 * by ascending shard index. The reported total hit count is the largest total hit count among the
 * inputs, with relation {@link TotalHits.Relation#GREATER_THAN_OR_EQUAL_TO}.
 *
 * @param topN the maximum number of fused hits to return; must be greater than or equal to 1
 * @param k a constant that determines how much influence documents in individual rankings have
 *     on the final result. A higher value gives lower rank documents more influence. Must be
 *     greater than or equal to 1.
 * @param hits the {@link TopDocs} instances to apply RRF on
 * @return a {@link TopDocs} that contains at most {@code topN} fused hits
 * @throws IllegalArgumentException if {@code topN} or {@code k} is less than 1, or if some hits
 *     have their {@link ScoreDoc#shardIndex} set while others leave it unset ({@code -1})
 */
public static TopDocs rrf(int topN, int k, TopDocs[] hits) {
  if (topN < 1) {
    throw new IllegalArgumentException("topN must be >= 1, got " + topN);
  }
  if (k < 1) {
    throw new IllegalArgumentException("k must be >= 1, got " + k);
  }

  // Hits are keyed by (shardIndex, doc), so all of them must agree on whether the shard index is
  // populated. Fail as soon as both flavors have been observed.
  boolean sawShardIndex = false;
  boolean sawNoShardIndex = false;
  for (TopDocs ranking : hits) {
    for (ScoreDoc hit : ranking.scoreDocs) {
      if (hit.shardIndex == -1) {
        sawNoShardIndex = true;
      } else {
        sawShardIndex = true;
      }
      if (sawShardIndex && sawNoShardIndex) {
        throw new IllegalArgumentException(
            "All hits must either have their ScoreDoc#shardIndex set, or unset (-1), not a mix of both.");
      }
    }
  }

  // Accumulate fused scores as doubles to reduce accuracy loss due to floating-point arithmetic;
  // they are only narrowed to float when the final ScoreDocs are created.
  Map<ShardIndexAndDoc, Double> fusedScores = new HashMap<>();
  long maxTotalHits = 0;
  for (TopDocs ranking : hits) {
    // A document is a hit globally if it is a hit for any of the inputs, so the largest
    // individual total hit count is used as the global total hit count.
    maxTotalHits = Math.max(maxTotalHits, ranking.totalHits.value());
    for (int pos = 0; pos < ranking.scoreDocs.length; ++pos) {
      ScoreDoc hit = ranking.scoreDocs[pos];
      // Ranks are 1-based; addExact guards against int overflow of k + rank.
      double contribution = 1d / Math.addExact(k, pos + 1);
      fusedScores.merge(new ShardIndexAndDoc(hit.shardIndex, hit.doc), contribution, Double::sum);
    }
  }

  // Descending fused score first, then doc ID, then shard index (like TopDocs#merge).
  Comparator<Map.Entry<ShardIndexAndDoc, Double>> byScoreDescending =
      (left, right) -> Double.compare(right.getValue(), left.getValue());
  Comparator<Map.Entry<ShardIndexAndDoc, Double>> fusedOrder =
      byScoreDescending
          .thenComparingInt(e -> e.getKey().doc())
          .thenComparingInt(e -> e.getKey().shardIndex());

  List<Map.Entry<ShardIndexAndDoc, Double>> ordered = new ArrayList<>(fusedScores.entrySet());
  ordered.sort(fusedOrder);

  int resultSize = Math.min(topN, ordered.size());
  ScoreDoc[] fused = new ScoreDoc[resultSize];
  int next = 0;
  for (Map.Entry<ShardIndexAndDoc, Double> entry : ordered.subList(0, resultSize)) {
    ShardIndexAndDoc id = entry.getKey();
    fused[next++] = new ScoreDoc(id.doc(), entry.getValue().floatValue(), id.shardIndex());
  }

  return new TopDocs(
      new TotalHits(maxTotalHits, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), fused);
}
}
128 changes: 128 additions & 0 deletions lucene/core/src/test/org/apache/lucene/search/TestTopDocsRRF.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import org.apache.lucene.tests.util.LuceneTestCase;

/** Tests for {@link TopDocs#rrf(int, int, TopDocs[])}. */
public class TestTopDocsRRF extends LuceneTestCase {

  /** Builds a lower-bound {@link TotalHits} with the given count. */
  private static TotalHits atLeastHits(long count) {
    return new TotalHits(count, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
  }

  /** Builds a {@link TopDocs} with a lower-bound total hit count and the given ranked hits. */
  private static TopDocs rankedHits(long totalHitCount, ScoreDoc... scoreDocs) {
    return new TopDocs(atLeastHits(totalHitCount), scoreDocs);
  }

  /** Asserts the doc ID, shard index and exact score of a single fused hit, in that order. */
  private static void assertHit(
      int expectedDoc, int expectedShardIndex, float expectedScore, ScoreDoc actual) {
    assertEquals(expectedDoc, actual.doc);
    assertEquals(expectedShardIndex, actual.shardIndex);
    assertEquals(expectedScore, actual.score, 0f);
  }

  /** Fuses two rankings whose hits carry no shard index and checks order and scores. */
  public void testBasics() {
    TopDocs first =
        rankedHits(100, new ScoreDoc(42, 10f), new ScoreDoc(10, 5f), new ScoreDoc(20, 3f));
    TopDocs second = rankedHits(80, new ScoreDoc(10, 10f), new ScoreDoc(20, 5f));

    TopDocs fused = TopDocs.rrf(3, 20, new TopDocs[] {first, second});
    assertEquals(atLeastHits(100), fused.totalHits);

    ScoreDoc[] fusedDocs = fused.scoreDocs;
    assertEquals(3, fusedDocs.length);

    // Docs 10 and 20 appear in both rankings, so they outrank doc 42 which appears only once.
    assertHit(10, -1, (float) (1d / (20 + 2) + 1d / (20 + 1)), fusedDocs[0]);
    assertHit(20, -1, (float) (1d / (20 + 3) + 1d / (20 + 2)), fusedDocs[1]);
    assertHit(42, -1, (float) (1d / (20 + 1)), fusedDocs[2]);
  }

  /** Hits with the same doc ID but different shard indexes must not be fused together. */
  public void testShardIndex() {
    TopDocs first =
        rankedHits(
            100, new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0));
    TopDocs second = rankedHits(80, new ScoreDoc(10, 10f, 1), new ScoreDoc(20, 5f, 1));

    TopDocs fused = TopDocs.rrf(3, 20, new TopDocs[] {first, second});
    assertEquals(atLeastHits(100), fused.totalHits);

    ScoreDoc[] fusedDocs = fused.scoreDocs;
    assertEquals(3, fusedDocs.length);

    assertHit(10, 1, (float) (1d / (20 + 2) + 1d / (20 + 1)), fusedDocs[0]);
    assertHit(42, 0, (float) (1d / (20 + 1)), fusedDocs[1]);
    // Doc 20 on shard 1 only collects its rank-2 contribution from the second ranking.
    assertHit(20, 1, (float) (1d / (20 + 2)), fusedDocs[2]);
  }

  /** Mixing hits with and without a shard index must be rejected. */
  public void testInconsistentShardIndex() {
    TopDocs withShards =
        rankedHits(
            100, new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0));
    TopDocs withoutShards = rankedHits(80, new ScoreDoc(10, 10f, -1), new ScoreDoc(20, 5f, -1));
    TopDocs[] inputs = {withShards, withoutShards};

    IllegalArgumentException failure =
        expectThrows(IllegalArgumentException.class, () -> TopDocs.rrf(3, 20, inputs));
    assertTrue(failure.getMessage().contains("shardIndex"));
  }

  /** A {@code topN} of 0 must be rejected. */
  public void testInvalidTopN() {
    TopDocs[] inputs = {rankedHits(100), rankedHits(80)};

    IllegalArgumentException failure =
        expectThrows(IllegalArgumentException.class, () -> TopDocs.rrf(0, 20, inputs));
    assertTrue(failure.getMessage().contains("topN"));
  }

  /** A {@code k} of 0 must be rejected. */
  public void testInvalidK() {
    TopDocs[] inputs = {rankedHits(100), rankedHits(80)};

    IllegalArgumentException failure =
        expectThrows(IllegalArgumentException.class, () -> TopDocs.rrf(3, 0, inputs));
    assertTrue(failure.getMessage().contains("k"));
  }
}