Skip to content

Commit 358683d

Browse files
committed
Add VectorScorerFactory, ScoreMode and unit tests for exact search scoring
Introduce VectorScorerFactory to create VectorScorer instances based on the underlying vector storage format (BinaryDocValues, FloatVectorValues, ByteVectorValues). Add ScoreMode interface to abstract scorer vs rescorer creation. - Add ScoreMode with SCORE and RESCORE strategies - Add VectorScorerFactory with float and byte target overloads - Add ADC scorer with TODO to remove once ByteVectorValues.scorer() is implemented Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
1 parent 6c4d427 commit 358683d

File tree

6 files changed

+615
-1
lines changed

6 files changed

+615
-1
lines changed

β€ŽCHANGELOG.mdβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3232
* Add Prefetch functionality to prefetch vectors during ANN Search for MemoryOptimizedSearch. [#3173](https://github.com/opensearch-project/k-NN/pull/3173)
3333
* Optimize ByteVectorIdsExactKNNIterator by moving array conversion to constructor [#3171](https://github.com/opensearch-project/k-NN/pull/3171)
3434
* Add VectorScorers for BinaryDocValues and nested best child scoring [#3179](https://github.com/opensearch-project/k-NN/pull/3179)
35+
* Add VectorScorerFactory, ScoreMode and unit tests for exact search scoring [#3183](https://github.com/opensearch-project/k-NN/pull/3183)

β€Žsrc/main/java/org/opensearch/knn/index/query/scorers/KNNBinaryDocValuesScorer.javaβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
* {@link BytesRef}.</li>
3434
* </ul>
3535
*/
36-
public class KNNBinaryDocValuesScorer implements VectorScorer {
36+
class KNNBinaryDocValuesScorer implements VectorScorer {
3737

3838
private final BinaryDocValues binaryDocValues;
3939
private final ScoreFunction scoreFunction;
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.query.scorers;
7+
8+
import org.apache.lucene.index.ByteVectorValues;
9+
import org.apache.lucene.index.FloatVectorValues;
10+
import org.apache.lucene.search.VectorScorer;
11+
12+
import java.io.IOException;
13+
14+
/**
15+
* Strategy for creating a {@link VectorScorer} from Lucene vector values.
16+
*
17+
* <p>This interface abstracts the choice between primary scoring and rescoring so that
18+
* {@link VectorScorers} can delegate scorer creation without knowing which mode is
19+
* in effect.
20+
*
21+
* <h2>Provided Implementations</h2>
22+
* <ul>
23+
* <li>{@link #SCORE} β€” creates a scorer via {@code vectorValues.scorer(target)},
24+
* which computes the primary similarity score (e.g. dot product, L2).</li>
25+
* <li>{@link #RESCORE} β€” creates a scorer via {@code vectorValues.rescorer(target)},
26+
* which recomputes a higher-fidelity score from the original (unquantized) vectors.
27+
* This is typically used after an initial approximate search with quantized vectors.</li>
28+
* </ul>
29+
*
30+
* @see VectorScorers
31+
*/
32+
public interface VectorScorerMode {
33+
34+
/** Creates a scorer using the primary similarity function. */
35+
VectorScorerMode SCORE = new VectorScorerMode() {
36+
@Override
37+
public VectorScorer createScorer(FloatVectorValues vectorValues, float[] target) throws IOException {
38+
return vectorValues.scorer(target);
39+
}
40+
41+
@Override
42+
public VectorScorer createScorer(ByteVectorValues vectorValues, byte[] target) throws IOException {
43+
return vectorValues.scorer(target);
44+
}
45+
};
46+
47+
/** Creates a scorer that recomputes a higher-fidelity score from unquantized vectors. */
48+
VectorScorerMode RESCORE = new VectorScorerMode() {
49+
@Override
50+
public VectorScorer createScorer(FloatVectorValues vectorValues, float[] target) throws IOException {
51+
return vectorValues.rescorer(target);
52+
}
53+
54+
@Override
55+
public VectorScorer createScorer(ByteVectorValues vectorValues, byte[] target) throws IOException {
56+
return vectorValues.rescorer(target);
57+
}
58+
};
59+
60+
/**
61+
* Creates a {@link VectorScorer} for a float query vector against {@link FloatVectorValues}.
62+
*
63+
* @param vectorValues the float vector values for the segment
64+
* @param target the float query vector
65+
* @return a scorer positioned over the given vector values
66+
* @throws IOException if an I/O error occurs
67+
*/
68+
VectorScorer createScorer(FloatVectorValues vectorValues, float[] target) throws IOException;
69+
70+
/**
71+
* Creates a {@link VectorScorer} for a byte query vector against {@link ByteVectorValues}.
72+
*
73+
* @param vectorValues the byte vector values for the segment
74+
* @param target the byte query vector
75+
* @return a scorer positioned over the given vector values
76+
* @throws IOException if an I/O error occurs
77+
*/
78+
VectorScorer createScorer(ByteVectorValues vectorValues, byte[] target) throws IOException;
79+
}
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.query.scorers;
7+
8+
import lombok.AccessLevel;
9+
import lombok.NoArgsConstructor;
10+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
11+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
12+
import org.apache.lucene.index.BinaryDocValues;
13+
import org.apache.lucene.index.ByteVectorValues;
14+
import org.apache.lucene.index.FloatVectorValues;
15+
import org.apache.lucene.index.KnnVectorValues;
16+
import org.apache.lucene.search.DocIdSetIterator;
17+
import org.apache.lucene.search.VectorScorer;
18+
import org.apache.lucene.util.BitSet;
19+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
20+
import org.opensearch.common.Nullable;
21+
import org.opensearch.knn.index.SpaceType;
22+
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator;
23+
import org.opensearch.knn.memoryoptsearch.faiss.FlatVectorsScorerProvider;
24+
25+
import java.io.IOException;
26+
27+
/**
28+
* Static factory for creating {@link VectorScorer} instances from {@link KNNVectorValuesIterator.DocIdsIteratorValues}.
29+
*
30+
* <p>{@code VectorScorers} inspects the underlying iterator and vector values to select the appropriate
31+
* scoring strategy:
32+
* <ul>
33+
* <li>{@link BinaryDocValues} β†’ delegates to {@link KNNBinaryDocValuesScorer}</li>
34+
* <li>{@link FloatVectorValues} β†’ uses the provided {@link VectorScorerMode} (score or rescore)</li>
35+
* <li>{@link ByteVectorValues} with float target β†’ ADC (Asymmetric Distance Computation) scoring</li>
36+
* <li>{@link ByteVectorValues} with byte target β†’ uses the provided {@link VectorScorerMode}</li>
37+
* </ul>
38+
*/
39+
@NoArgsConstructor(access = AccessLevel.PRIVATE)
40+
public final class VectorScorers {
41+
42+
/**
43+
* Creates a {@link VectorScorer} for the given float query vector.
44+
*
45+
* @param docIdsIteratorValues wraps the {@link DocIdSetIterator} and {@link KnnVectorValues}
46+
* for the segment being scored
47+
* @param target the float query vector
48+
* @param vectorScorerMode determines whether to use scoring or rescoring
49+
* @param spaceType the space type defining the similarity function
50+
* @return a {@link VectorScorer} appropriate for the underlying vector storage format
51+
* @throws IOException if an I/O error occurs
52+
*/
53+
public static VectorScorer createScorer(
54+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
55+
final float[] target,
56+
final VectorScorerMode vectorScorerMode,
57+
final SpaceType spaceType
58+
) throws IOException {
59+
return createScorer(docIdsIteratorValues, target, vectorScorerMode, spaceType, null, null);
60+
}
61+
62+
/**
63+
* Creates a {@link VectorScorer} for the given float query vector, wrapping with
64+
* {@link NestedBestChildVectorScorer} when nested search is required.
65+
*
66+
* @param docIdsIteratorValues wraps the {@link DocIdSetIterator} and {@link KnnVectorValues}
67+
* for the segment being scored
68+
* @param target the float query vector
69+
* @param vectorScorerMode determines whether to use scoring or rescoring
70+
* @param spaceType the space type defining the similarity function
71+
* @param filteredIdsIterator iterator over accepted child documents, or null if not nested
72+
* @param parentBitSet bit set identifying parent documents, or null if not nested
73+
* @return a {@link VectorScorer} appropriate for the underlying vector storage format
74+
* @throws IOException if an I/O error occurs
75+
*/
76+
public static VectorScorer createScorer(
77+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
78+
final float[] target,
79+
final VectorScorerMode vectorScorerMode,
80+
final SpaceType spaceType,
81+
@Nullable final DocIdSetIterator filteredIdsIterator,
82+
@Nullable final BitSet parentBitSet
83+
) throws IOException {
84+
final VectorScorer scorer = getBaseScorer(docIdsIteratorValues, target, vectorScorerMode, spaceType);
85+
return maybeWrapWithNestedScorer(scorer, filteredIdsIterator, parentBitSet);
86+
}
87+
88+
/**
89+
* Creates a {@link VectorScorer} for the given byte query vector.
90+
*
91+
* @param docIdsIteratorValues wraps the {@link DocIdSetIterator} and {@link KnnVectorValues}
92+
* for the segment being scored
93+
* @param target the byte query vector
94+
* @param vectorScorerMode determines whether to use scoring or rescoring
95+
* @param spaceType the space type defining the similarity function
96+
* @return a {@link VectorScorer} appropriate for the underlying vector storage format
97+
* @throws IOException if an I/O error occurs
98+
*/
99+
public static VectorScorer createScorer(
100+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
101+
final byte[] target,
102+
final VectorScorerMode vectorScorerMode,
103+
final SpaceType spaceType
104+
) throws IOException {
105+
return createScorer(docIdsIteratorValues, target, vectorScorerMode, spaceType, null, null);
106+
}
107+
108+
/**
109+
* Creates a {@link VectorScorer} for the given byte query vector, wrapping with
110+
* {@link NestedBestChildVectorScorer} when nested search is required.
111+
*
112+
* @param docIdsIteratorValues wraps the {@link DocIdSetIterator} and {@link KnnVectorValues}
113+
* for the segment being scored
114+
* @param target the byte query vector
115+
* @param vectorScorerMode determines whether to use scoring or rescoring
116+
* @param spaceType the space type defining the similarity function
117+
* @param acceptedChildrenIterator iterator over accepted child documents, or null if not nested
118+
* @param parentBitSet bit set identifying parent documents, or null if not nested
119+
* @return a {@link VectorScorer} appropriate for the underlying vector storage format
120+
* @throws IOException if an I/O error occurs
121+
*/
122+
public static VectorScorer createScorer(
123+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
124+
final byte[] target,
125+
final VectorScorerMode vectorScorerMode,
126+
final SpaceType spaceType,
127+
@Nullable final DocIdSetIterator acceptedChildrenIterator,
128+
@Nullable final BitSet parentBitSet
129+
) throws IOException {
130+
final VectorScorer scorer = getBaseScorer(docIdsIteratorValues, target, vectorScorerMode, spaceType);
131+
return maybeWrapWithNestedScorer(scorer, acceptedChildrenIterator, parentBitSet);
132+
}
133+
134+
private static VectorScorer getBaseScorer(
135+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
136+
final float[] target,
137+
final VectorScorerMode vectorScorerMode,
138+
final SpaceType spaceType
139+
) throws IOException {
140+
final DocIdSetIterator docIdSetIterator = docIdsIteratorValues.getDocIdSetIterator();
141+
142+
// ignore score mode, for BinaryDocValues since we do not support BinaryDocValues with quantization
143+
if (docIdSetIterator instanceof BinaryDocValues binaryDocValues) {
144+
return KNNBinaryDocValuesScorer.create(target, binaryDocValues, spaceType);
145+
}
146+
147+
final KnnVectorValues knnVectorValues = docIdsIteratorValues.getKnnVectorValues();
148+
if (knnVectorValues instanceof FloatVectorValues floatVectorValues) {
149+
return vectorScorerMode.createScorer(floatVectorValues, target);
150+
}
151+
if (knnVectorValues instanceof ByteVectorValues byteVectorValues) { // ADC case
152+
return createADCScorer(byteVectorValues, target, spaceType);
153+
}
154+
throw new IllegalArgumentException("Unsupported KnnVectorValues type: " + knnVectorValues.getClass().getSimpleName());
155+
}
156+
157+
private static VectorScorer getBaseScorer(
158+
final KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues,
159+
final byte[] target,
160+
final VectorScorerMode vectorScorerMode,
161+
final SpaceType spaceType
162+
) throws IOException {
163+
final DocIdSetIterator docIdSetIterator = docIdsIteratorValues.getDocIdSetIterator();
164+
165+
// ignore score mode, for BinaryDocValues since we do not support BinaryDocValues with quantization
166+
if (docIdSetIterator instanceof BinaryDocValues binaryDocValues) {
167+
return KNNBinaryDocValuesScorer.create(target, binaryDocValues, spaceType);
168+
}
169+
170+
final KnnVectorValues knnVectorValues = docIdsIteratorValues.getKnnVectorValues();
171+
if (knnVectorValues instanceof ByteVectorValues byteVectorValues) {
172+
return vectorScorerMode.createScorer(byteVectorValues, target);
173+
}
174+
throw new IllegalArgumentException("Byte target requires ByteVectorValues but got " + knnVectorValues.getClass().getSimpleName());
175+
}
176+
177+
private static VectorScorer maybeWrapWithNestedScorer(
178+
final VectorScorer scorer,
179+
@Nullable final DocIdSetIterator acceptedChildrenIterator,
180+
@Nullable final BitSet parentBitSet
181+
) {
182+
if (parentBitSet == null) {
183+
return scorer;
184+
}
185+
return new NestedBestChildVectorScorer(acceptedChildrenIterator, parentBitSet, scorer);
186+
}
187+
188+
/**
189+
* Creates an ADC (Asymmetric Distance Computation) {@link VectorScorer} that scores a float query vector
190+
* against quantized byte document vectors.
191+
*/
192+
// TODO: Remove once ByteVectorValues.scorer() is implemented to return the appropriate
193+
// VectorScorer based on ADC/quantization. At that point, VectorScorerMode.createScorer() will
194+
// handle this case and this method will no longer be needed.
195+
private static VectorScorer createADCScorer(final ByteVectorValues byteVectorValues, final float[] target, final SpaceType spaceType)
196+
throws IOException {
197+
198+
final FlatVectorsScorer adcFlatVectorsScorer = FlatVectorsScorerProvider.getFlatVectorsScorer(
199+
spaceType.getKnnVectorSimilarityFunction(),
200+
true,
201+
spaceType,
202+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
203+
);
204+
final RandomVectorScorer randomVectorScorer = adcFlatVectorsScorer.getRandomVectorScorer(
205+
spaceType.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(),
206+
byteVectorValues,
207+
target
208+
);
209+
return new VectorScorer() {
210+
@Override
211+
public float score() throws IOException {
212+
return randomVectorScorer.score(byteVectorValues.iterator().docID());
213+
}
214+
215+
@Override
216+
public DocIdSetIterator iterator() {
217+
return byteVectorValues.iterator();
218+
}
219+
};
220+
}
221+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.query.scorers;
7+
8+
import junit.framework.TestCase;
9+
import lombok.SneakyThrows;
10+
import org.apache.lucene.index.ByteVectorValues;
11+
import org.apache.lucene.index.FloatVectorValues;
12+
import org.apache.lucene.search.VectorScorer;
13+
14+
import static org.mockito.Mockito.mock;
15+
import static org.mockito.Mockito.verify;
16+
import static org.mockito.Mockito.when;
17+
18+
public class VectorScorerModeTests extends TestCase {
19+
20+
// ──────────────────────────────────────────────
21+
// SCORE mode
22+
// ──────────────────────────────────────────────
23+
24+
@SneakyThrows
25+
public void testScore_withFloatVectorValues_delegatesToScorer() {
26+
float[] target = { 1.0f, 2.0f };
27+
FloatVectorValues floatVectorValues = mock(FloatVectorValues.class);
28+
VectorScorer expected = mock(VectorScorer.class);
29+
when(floatVectorValues.scorer(target)).thenReturn(expected);
30+
31+
VectorScorer result = VectorScorerMode.SCORE.createScorer(floatVectorValues, target);
32+
33+
assertSame(expected, result);
34+
verify(floatVectorValues).scorer(target);
35+
}
36+
37+
@SneakyThrows
38+
public void testScore_withByteVectorValues_delegatesToScorer() {
39+
byte[] target = { 1, 2 };
40+
ByteVectorValues byteVectorValues = mock(ByteVectorValues.class);
41+
VectorScorer expected = mock(VectorScorer.class);
42+
when(byteVectorValues.scorer(target)).thenReturn(expected);
43+
44+
VectorScorer result = VectorScorerMode.SCORE.createScorer(byteVectorValues, target);
45+
46+
assertSame(expected, result);
47+
verify(byteVectorValues).scorer(target);
48+
}
49+
50+
// ──────────────────────────────────────────────
51+
// RESCORE mode
52+
// ──────────────────────────────────────────────
53+
54+
@SneakyThrows
55+
public void testRescore_withFloatVectorValues_delegatesToRescorer() {
56+
float[] target = { 1.0f, 2.0f };
57+
FloatVectorValues floatVectorValues = mock(FloatVectorValues.class);
58+
VectorScorer expected = mock(VectorScorer.class);
59+
when(floatVectorValues.rescorer(target)).thenReturn(expected);
60+
61+
VectorScorer result = VectorScorerMode.RESCORE.createScorer(floatVectorValues, target);
62+
63+
assertSame(expected, result);
64+
verify(floatVectorValues).rescorer(target);
65+
}
66+
67+
@SneakyThrows
68+
public void testRescore_withByteVectorValues_delegatesToRescorer() {
69+
byte[] target = { 1, 2 };
70+
ByteVectorValues byteVectorValues = mock(ByteVectorValues.class);
71+
VectorScorer expected = mock(VectorScorer.class);
72+
when(byteVectorValues.rescorer(target)).thenReturn(expected);
73+
74+
VectorScorer result = VectorScorerMode.RESCORE.createScorer(byteVectorValues, target);
75+
76+
assertSame(expected, result);
77+
verify(byteVectorValues).rescorer(target);
78+
}
79+
}

0 commit comments

Comments
Β (0)