Skip to content
Merged
89 changes: 89 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/TopDocs.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
*/
package org.apache.lucene.search;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.PriorityQueue;

/** Represents hits returned by {@link IndexSearcher#search(Query,int)}. */
Expand Down Expand Up @@ -350,4 +354,89 @@ private static TopDocs mergeAux(
return new TopFieldDocs(totalHits, hits, sort.getSort());
}
}

private record ShardIndexAndDoc(int shardIndex, int doc) {}

/**
* Reciprocal Rank Fusion method.
*
* <p>This method combines different search results into a single ranked list by combining their
* ranks. This is especially well suited when combining hits computed via different methods, whose
* score distributions are hardly comparable.
*
* @param topN the top N results to be returned
* @param k a constant determines how much influence documents in individual rankings have on the
* final result. A higher value gives lower rank documents more influence. k should be greater
* than or equal to 1.
* @param hits a list of TopDocs to apply RRF on
* @return a TopDocs contains the top N ranked results.
*/
public static TopDocs rrf(int topN, int k, TopDocs[] hits) {
if (topN < 1) {
throw new IllegalArgumentException("topN must be >= 1, got " + topN);
}
if (k < 1) {
throw new IllegalArgumentException("k must be >= 1, got " + k);
}

boolean shardIndexSet = false;
outer:
for (TopDocs topDocs : hits) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
shardIndexSet = scoreDoc.shardIndex != -1;
break outer;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the purpose here to only check the first scoreDoc of every TopDocs instance provided in the array? Should we try and rewrite this to be more readable and not use goto ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it look better now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thank you!

}
}
for (TopDocs topDocs : hits) {
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
boolean thisShardIndexSet = scoreDoc.shardIndex != -1;
if (shardIndexSet != thisShardIndexSet) {
throw new IllegalArgumentException(
"All hits must either have their ScoreDoc#shardIndex set, or unset (-1), not a mix of both.");
}
}
}

// Compute the rrf score as a double to reduce accuracy loss due to floating-point arithmetic.
Map<ShardIndexAndDoc, Double> rrfScore = new HashMap<>();
long totalHitCount = 0;
for (TopDocs topDoc : hits) {
// A document is a hit globally if it is a hit for any of the top docs, so we compute the
// total hit count as the max total hit count.
totalHitCount = Math.max(totalHitCount, topDoc.totalHits.value());
for (int i = 0; i < topDoc.scoreDocs.length; ++i) {
ScoreDoc scoreDoc = topDoc.scoreDocs[i];
int rank = i + 1;
double rrfScoreContribution = 1d / Math.addExact(k, rank);
rrfScore.compute(
new ShardIndexAndDoc(scoreDoc.shardIndex, scoreDoc.doc),
(key, score) -> (score == null ? 0 : score) + rrfScoreContribution);
}
}

List<Map.Entry<ShardIndexAndDoc, Double>> rrfScoreRank = new ArrayList<>(rrfScore.entrySet());
rrfScoreRank.sort(
// Sort by descending score
Map.Entry.<ShardIndexAndDoc, Double>comparingByValue()
.reversed()
// Tie-break by doc ID, then shard index (like TopDocs#merge)
.thenComparing(
Map.Entry.<ShardIndexAndDoc, Double>comparingByKey(
Comparator.comparingInt(ShardIndexAndDoc::doc)))
.thenComparing(
Map.Entry.<ShardIndexAndDoc, Double>comparingByKey(
Comparator.comparingInt(ShardIndexAndDoc::shardIndex))));

ScoreDoc[] rrfScoreDocs = new ScoreDoc[Math.min(topN, rrfScoreRank.size())];
for (int i = 0; i < rrfScoreDocs.length; i++) {
Map.Entry<ShardIndexAndDoc, Double> entry = rrfScoreRank.get(i);
int doc = entry.getKey().doc;
int shardIndex = entry.getKey().shardIndex();
float score = entry.getValue().floatValue();
rrfScoreDocs[i] = new ScoreDoc(doc, score, shardIndex);
}

TotalHits totalHits = new TotalHits(totalHitCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
return new TopDocs(totalHits, rrfScoreDocs);
}
}
128 changes: 128 additions & 0 deletions lucene/core/src/test/org/apache/lucene/search/TestTopDocsRRF.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import org.apache.lucene.tests.util.LuceneTestCase;

public class TestTopDocsRRF extends LuceneTestCase {

public void testBasics() {
TopDocs td1 =
new TopDocs(
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {new ScoreDoc(42, 10f), new ScoreDoc(10, 5f), new ScoreDoc(20, 3f)});
TopDocs td2 =
new TopDocs(
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {new ScoreDoc(10, 10f), new ScoreDoc(20, 5f)});

TopDocs rrfTd = TopDocs.rrf(3, 20, new TopDocs[] {td1, td2});
assertEquals(new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), rrfTd.totalHits);

ScoreDoc[] rrfScoreDocs = rrfTd.scoreDocs;
assertEquals(3, rrfScoreDocs.length);

assertEquals(10, rrfScoreDocs[0].doc);
assertEquals(-1, rrfScoreDocs[0].shardIndex);
assertEquals((float) (1d / (20 + 2) + 1d / (20 + 1)), rrfScoreDocs[0].score, 0f);

assertEquals(20, rrfScoreDocs[1].doc);
assertEquals(-1, rrfScoreDocs[1].shardIndex);
assertEquals((float) (1d / (20 + 3) + 1d / (20 + 2)), rrfScoreDocs[1].score, 0f);

assertEquals(42, rrfScoreDocs[2].doc);
assertEquals(-1, rrfScoreDocs[2].shardIndex);
assertEquals((float) (1d / (20 + 1)), rrfScoreDocs[2].score, 0f);
}

public void testShardIndex() {
TopDocs td1 =
new TopDocs(
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {
new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0)
});
TopDocs td2 =
new TopDocs(
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {new ScoreDoc(10, 10f, 1), new ScoreDoc(20, 5f, 1)});

TopDocs rrfTd = TopDocs.rrf(3, 20, new TopDocs[] {td1, td2});
assertEquals(new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), rrfTd.totalHits);

ScoreDoc[] rrfScoreDocs = rrfTd.scoreDocs;
assertEquals(3, rrfScoreDocs.length);

assertEquals(10, rrfScoreDocs[0].doc);
assertEquals(1, rrfScoreDocs[0].shardIndex);
assertEquals((float) (1d / (20 + 2) + 1d / (20 + 1)), rrfScoreDocs[0].score, 0f);

assertEquals(42, rrfScoreDocs[1].doc);
assertEquals(0, rrfScoreDocs[1].shardIndex);
assertEquals((float) (1d / (20 + 1)), rrfScoreDocs[1].score, 0f);

assertEquals(20, rrfScoreDocs[2].doc);
assertEquals(1, rrfScoreDocs[2].shardIndex);
assertEquals((float) (1d / (20 + 2)), rrfScoreDocs[2].score, 0f);
}

public void testInconsistentShardIndex() {
TopDocs td1 =
new TopDocs(
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {
new ScoreDoc(42, 10f, 0), new ScoreDoc(10, 5f, 1), new ScoreDoc(20, 3f, 0)
});
TopDocs td2 =
new TopDocs(
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
new ScoreDoc[] {new ScoreDoc(10, 10f, -1), new ScoreDoc(20, 5f, -1)});

IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class, () -> TopDocs.rrf(3, 20, new TopDocs[] {td1, td2}));
assertTrue(e.getMessage().contains("shardIndex"));
}

public void testInvalidTopN() {
TopDocs td1 =
new TopDocs(
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
TopDocs td2 =
new TopDocs(
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);

IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class, () -> TopDocs.rrf(0, 20, new TopDocs[] {td1, td2}));
assertTrue(e.getMessage().contains("topN"));
}

public void testInvalidK() {
TopDocs td1 =
new TopDocs(
new TotalHits(100, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);
TopDocs td2 =
new TopDocs(
new TotalHits(80, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]);

IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class, () -> TopDocs.rrf(3, 0, new TopDocs[] {td1, td2}));
assertTrue(e.getMessage().contains("k"));
}
}