Skip to content

Commit 0d05a10

Browse files
authored
Use pre-quantized vectors for ADC (#3113)
* Use pre-quantized vectors for ADC
  Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
* Add unit test
  Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
* Remove constructor
  Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>

---------

Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
1 parent 20295a0 commit 0d05a10

File tree

11 files changed: +331 additions, -215 deletions

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
* Make Merge in nativeEngine can Abort [#2529](https://github.com/opensearch-project/k-NN/pull/2529)
2020
* Use pre-quantized vectors from native engines for exact search [#3095](https://github.com/opensearch-project/k-NN/pull/3095)
2121
* Use right Vector Scorer when segments are initialized using SPI and also corrected the maxConn for MOS [#3117](https://github.com/opensearch-project/k-NN/pull/3117)
22-
22+
* Use pre-quantized vectors for ADC [#3113](https://github.com/opensearch-project/k-NN/pull/3113)

src/main/java/org/opensearch/knn/index/query/exactsearch/BinaryVectorIdsExactKNNIterator.java

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.common.Nullable;
1010
import org.opensearch.knn.index.SpaceType;
1111
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
12+
import org.opensearch.knn.plugin.script.KNNScoringUtil;
1213

1314
import java.io.IOException;
1415

@@ -20,33 +21,46 @@
2021
*/
2122
class BinaryVectorIdsExactKNNIterator implements ExactKNNIterator {
2223
protected final DocIdSetIterator docIdSetIterator;
23-
protected final byte[] queryVector;
24+
protected final byte[] byteQueryVector;
25+
protected final float[] floatQueryVector;
2426
protected final KNNBinaryVectorValues binaryVectorValues;
2527
protected final SpaceType spaceType;
2628
protected float currentScore = Float.NEGATIVE_INFINITY;
2729
protected int docId;
2830

2931
public BinaryVectorIdsExactKNNIterator(
3032
@Nullable final DocIdSetIterator docIdSetIterator,
31-
final byte[] queryVector,
33+
final byte[] byteQueryVector,
3234
final KNNBinaryVectorValues binaryVectorValues,
3335
final SpaceType spaceType
3436
) throws IOException {
35-
this.docIdSetIterator = docIdSetIterator;
36-
this.queryVector = queryVector;
37-
this.binaryVectorValues = binaryVectorValues;
38-
this.spaceType = spaceType;
39-
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
40-
// nextDoc should already be referring to next knnVectorValues
41-
this.docId = getNextDocId();
37+
this(docIdSetIterator, byteQueryVector, null, binaryVectorValues, spaceType);
4238
}
4339

4440
public BinaryVectorIdsExactKNNIterator(
45-
final byte[] queryVector,
41+
@Nullable final DocIdSetIterator docIdSetIterator,
42+
final float[] floatQueryVector,
43+
final KNNBinaryVectorValues binaryVectorValues,
44+
final SpaceType spaceType
45+
) throws IOException {
46+
this(docIdSetIterator, null, floatQueryVector, binaryVectorValues, spaceType);
47+
}
48+
49+
private BinaryVectorIdsExactKNNIterator(
50+
@Nullable final DocIdSetIterator docIdSetIterator,
51+
final byte[] byteQueryVector,
52+
final float[] floatQueryVector,
4653
final KNNBinaryVectorValues binaryVectorValues,
4754
final SpaceType spaceType
4855
) throws IOException {
49-
this(null, queryVector, binaryVectorValues, spaceType);
56+
assert (floatQueryVector == null) != (byteQueryVector == null)
57+
: "Exactly one of byteQueryVector or floatQueryVector must be non-null";
58+
this.docIdSetIterator = docIdSetIterator;
59+
this.byteQueryVector = byteQueryVector;
60+
this.floatQueryVector = floatQueryVector;
61+
this.binaryVectorValues = binaryVectorValues;
62+
this.spaceType = spaceType;
63+
this.docId = getNextDocId();
5064
}
5165

5266
/**
@@ -73,10 +87,11 @@ public float score() {
7387
}
7488

7589
protected float computeScore() throws IOException {
76-
final byte[] vector = binaryVectorValues.getVector();
77-
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
78-
// scores correspond to closer vectors.
79-
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
90+
final byte[] documentVector = binaryVectorValues.getVector();
91+
if (floatQueryVector != null) {
92+
return KNNScoringUtil.scoreWithADC(floatQueryVector, documentVector, spaceType);
93+
}
94+
return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, documentVector);
8095
}
8196

8297
protected int getNextDocId() throws IOException {

src/main/java/org/opensearch/knn/index/query/exactsearch/ExactSearcher.java

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -219,64 +219,81 @@ private ExactKNNIterator getKNNIterator(LeafReaderContext leafReaderContext, Exa
219219
spaceType
220220
);
221221
}
222-
byte[] quantizedQueryVector = null;
223-
SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = null;
224-
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
225-
// Build Segment Level Quantization info.
226-
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, exactSearcherContext.getField());
227-
// Quantize the Query Vector Once. Or transform it in the case of ADC.
228-
if (SegmentLevelQuantizationUtil.isAdcEnabled(segmentLevelQuantizationInfo)) {
229-
SegmentLevelQuantizationUtil.transformVectorWithADC(
230-
exactSearcherContext.getFloatQueryVector(),
231-
segmentLevelQuantizationInfo,
232-
spaceType
233-
);
234-
} else {
235-
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(
222+
// Build Segment Level Quantization info.
223+
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(
224+
reader,
225+
fieldInfo,
226+
exactSearcherContext.getField()
227+
);
228+
// For FP32 vectors, there are two execution paths:
229+
// 1. Full precision path: Used during rescoring or when quantization is not available.
230+
// Loads original float32 vectors and performs exact search using the configured distance metric (L2, Cosine, etc.).
231+
// 2. Quantized path: Used during approximate search when quantization is enabled.
232+
// Loads quantized byte vectors from segment and performs search using either:
233+
// a) ADC (Asymmetric Distance Computation): Transforms query vector and compares against quantized doc vectors
234+
// b) Symmetric quantization: Quantizes query vector and uses Hamming distance for comparison
235+
if (segmentLevelQuantizationInfo == null || !exactSearcherContext.isUseQuantizedVectorsForSearch()) {
236+
final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
237+
if (isNestedRequired) {
238+
return new NestedVectorIdsExactKNNIterator(
239+
matchedDocs,
236240
exactSearcherContext.getFloatQueryVector(),
237-
segmentLevelQuantizationInfo
241+
(KNNFloatVectorValues) vectorValues,
242+
spaceType,
243+
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext)
238244
);
239245
}
246+
return new VectorIdsExactKNNIterator(
247+
matchedDocs,
248+
exactSearcherContext.getFloatQueryVector(),
249+
(KNNFloatVectorValues) vectorValues,
250+
spaceType
251+
);
240252
}
241-
// Quantized search path: retrieve quantized byte vectors from reader as KNNBinaryVectorValues and perform exact search
242-
// using Hamming distance.
243-
if (quantizedQueryVector != null) {
244-
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader, true);
253+
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader, true);
254+
// For ADC, we will transform float vector -> ADC's float vector
255+
if (SegmentLevelQuantizationUtil.isAdcEnabled(segmentLevelQuantizationInfo)) {
256+
SegmentLevelQuantizationUtil.transformVectorWithADC(
257+
exactSearcherContext.getFloatQueryVector(),
258+
segmentLevelQuantizationInfo,
259+
spaceType
260+
);
245261
if (isNestedRequired) {
246262
return new NestedBinaryVectorIdsExactKNNIterator(
247263
matchedDocs,
248-
quantizedQueryVector,
264+
exactSearcherContext.getFloatQueryVector(),
249265
(KNNBinaryVectorValues) vectorValues,
250-
SpaceType.HAMMING,
266+
spaceType,
251267
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext)
252268
);
253269
}
254270
return new BinaryVectorIdsExactKNNIterator(
255271
matchedDocs,
256-
quantizedQueryVector,
272+
exactSearcherContext.getFloatQueryVector(),
257273
(KNNBinaryVectorValues) vectorValues,
258-
SpaceType.HAMMING
274+
spaceType
259275
);
260276
}
261-
final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
277+
final byte[] quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(
278+
exactSearcherContext.getFloatQueryVector(),
279+
segmentLevelQuantizationInfo
280+
);
281+
// Quantized search path: retrieve quantized byte vectors from reader as KNNBinaryVectorValues and perform exact search
282+
// using Hamming distance.
262283
if (isNestedRequired) {
263-
return new NestedVectorIdsExactKNNIterator(
284+
return new NestedBinaryVectorIdsExactKNNIterator(
264285
matchedDocs,
265-
exactSearcherContext.getFloatQueryVector(),
266-
(KNNFloatVectorValues) vectorValues,
267-
spaceType,
268-
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext),
269286
quantizedQueryVector,
270-
segmentLevelQuantizationInfo
287+
(KNNBinaryVectorValues) vectorValues,
288+
SpaceType.HAMMING,
289+
exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext)
271290
);
272291
}
273-
return new VectorIdsExactKNNIterator(
292+
return new BinaryVectorIdsExactKNNIterator(
274293
matchedDocs,
275-
exactSearcherContext.getFloatQueryVector(),
276-
(KNNFloatVectorValues) vectorValues,
277-
spaceType,
278294
quantizedQueryVector,
279-
segmentLevelQuantizationInfo
295+
(KNNBinaryVectorValues) vectorValues,
296+
SpaceType.HAMMING
280297
);
281298
}
282299

src/main/java/org/opensearch/knn/index/query/exactsearch/NestedBinaryVectorIdsExactKNNIterator.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ public NestedBinaryVectorIdsExactKNNIterator(
3232
this.parentBitSet = parentBitSet;
3333
}
3434

35+
public NestedBinaryVectorIdsExactKNNIterator(
36+
@Nullable final DocIdSetIterator filterIdsIterator,
37+
final float[] queryVector,
38+
final KNNBinaryVectorValues binaryVectorValues,
39+
final SpaceType spaceType,
40+
final BitSet parentBitSet
41+
) throws IOException {
42+
super(filterIdsIterator, queryVector, binaryVectorValues, spaceType);
43+
this.parentBitSet = parentBitSet;
44+
}
45+
3546
public NestedBinaryVectorIdsExactKNNIterator(
3647
final byte[] queryVector,
3748
final KNNBinaryVectorValues binaryVectorValues,

src/main/java/org/opensearch/knn/index/query/exactsearch/VectorIdsExactKNNIterator.java

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@
99
import org.opensearch.common.Nullable;
1010
import org.opensearch.knn.index.SpaceType;
1111
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
12-
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
1312
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
1413

15-
import org.opensearch.knn.plugin.script.KNNScoringUtil;
16-
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
17-
1814
import java.io.IOException;
1915

2016
/**
@@ -26,12 +22,10 @@
2622
class VectorIdsExactKNNIterator implements ExactKNNIterator {
2723
protected final DocIdSetIterator filterIdsIterator;
2824
protected final float[] queryVector;
29-
private final byte[] quantizedQueryVector;
3025
protected final KNNFloatVectorValues knnFloatVectorValues;
3126
protected final SpaceType spaceType;
3227
protected float currentScore = Float.NEGATIVE_INFINITY;
3328
protected int docId;
34-
private final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
3529

3630
public VectorIdsExactKNNIterator(
3731
@Nullable final DocIdSetIterator filterIdsIterator,
@@ -44,7 +38,7 @@ public VectorIdsExactKNNIterator(
4438

4539
public VectorIdsExactKNNIterator(final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType)
4640
throws IOException {
47-
this(null, queryVector, knnFloatVectorValues, spaceType, null, null);
41+
this(null, queryVector, knnFloatVectorValues, spaceType);
4842
}
4943

5044
public VectorIdsExactKNNIterator(
@@ -62,8 +56,6 @@ public VectorIdsExactKNNIterator(
6256
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
6357
// nextDoc should already be referring to next knnVectorValues
6458
this.docId = getNextDocId();
65-
this.quantizedQueryVector = quantizedQueryVector;
66-
this.segmentLevelQuantizationInfo = segmentLevelQuantizationInfo;
6759
}
6860

6961
/**
@@ -91,16 +83,7 @@ public float score() {
9183

9284
protected float computeScore() throws IOException {
9385
final float[] vector = knnFloatVectorValues.getVector();
94-
if (segmentLevelQuantizationInfo == null) {
95-
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
96-
}
97-
98-
byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(vector, segmentLevelQuantizationInfo);
99-
if (quantizedQueryVector == null) {
100-
// in ExactSearcher::getKnnIterator we don't set quantizedQueryVector if adc is enabled. So at this point adc is enabled.
101-
return scoreWithADC(queryVector, quantizedVector, spaceType);
102-
}
103-
return SpaceType.HAMMING.getKnnVectorSimilarityFunction().compare(quantizedQueryVector, quantizedVector);
86+
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
10487
}
10588

10689
protected int getNextDocId() throws IOException {
@@ -114,36 +97,4 @@ protected int getNextDocId() throws IOException {
11497
}
11598
return nextDocID;
11699
}
117-
118-
/*
119-
protected for testing.
120-
Logic:
121-
- segmentLevelQuantizationInfo is null -> should not score with ADC
122-
- quantizationParams is not ScalarQuantizationParams -> should not score with ADC
123-
- quantizationParams is ScalarQuantizationParams -> defer to isEnableADC() to determine if should score with ADC.
124-
*/
125-
protected boolean shouldScoreWithADC(SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) {
126-
if (segmentLevelQuantizationInfo == null) {
127-
return false;
128-
}
129-
130-
if (segmentLevelQuantizationInfo.getQuantizationParams() instanceof ScalarQuantizationParams scalarQuantizationParams) {
131-
return scalarQuantizationParams.isEnableADC();
132-
}
133-
return false;
134-
}
135-
136-
// protected for testing. scoreWithADC is used in exact searcher.
137-
protected float scoreWithADC(float[] queryVector, byte[] documentVector, SpaceType spaceType) {
138-
// NOTE: the prescore translations come from Faiss.java::SCORE_TRANSLATIONS.
139-
if (spaceType.equals(SpaceType.L2)) {
140-
return SpaceType.L2.scoreTranslation(KNNScoringUtil.l2SquaredADC(queryVector, documentVector));
141-
} else if (spaceType.equals(SpaceType.INNER_PRODUCT)) {
142-
return SpaceType.INNER_PRODUCT.scoreTranslation((-1 * KNNScoringUtil.innerProductADC(queryVector, documentVector)));
143-
} else if (spaceType.equals(SpaceType.COSINESIMIL)) {
144-
return SpaceType.COSINESIMIL.scoreTranslation(1 - KNNScoringUtil.innerProductADC(queryVector, documentVector));
145-
}
146-
147-
throw new UnsupportedOperationException("Space type " + spaceType.getValue() + " is not supported for ADC");
148-
}
149100
}

src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,19 @@ public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo,
110110
}
111111

112112
/**
113-
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}
113+
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}.
114+
* When shouldRetrieveQuantizedVectors is true, retrieves quantized byte vectors for FLOAT32 encoded fields.
114115
*
115116
* @param fieldInfo {@link FieldInfo}
116117
* @param leafReader {@link LeafReader}
118+
* @param shouldRetrieveQuantizedVectors if true, retrieves quantized byte vectors for FLOAT32 fields
117119
* @return {@link KNNVectorValues}
120+
* @throws IOException if an I/O error occurs
118121
*/
119122
public static <T> KNNVectorValues<T> getVectorValues(
120123
final FieldInfo fieldInfo,
121124
final SegmentReader leafReader,
122-
boolean isQueryVectorQuantized
125+
boolean shouldRetrieveQuantizedVectors
123126
) throws IOException {
124127
if (!fieldInfo.hasVectorValues()) {
125128
final DocIdSetIterator docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName());
@@ -134,7 +137,7 @@ public static <T> KNNVectorValues<T> getVectorValues(
134137
} else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
135138
final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(fieldInfo.getName());
136139
// Quantized search path: retrieve quantized byte vectors from codec.
137-
if (isQueryVectorQuantized) {
140+
if (shouldRetrieveQuantizedVectors) {
138141
// Bypasses leafReader.getByteVectorValues() which enforces BYTE encoding check.
139142
// This will call getByteVectorValues from NativeEngines990KnnVectorsReader at the end.
140143
final ByteVectorValues byteVectorValues = leafReader.getVectorReader().getByteVectorValues(fieldInfo.getName());

src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,27 @@ public static float innerProductADC(float[] queryVector, byte[] inputVector) {
174174
return score;
175175
}
176176

177+
/**
178+
* Computes the similarity score between a float query vector and a quantized byte document vector using ADC.
179+
* This method applies the appropriate distance metric based on the space type and returns a similarity score.
180+
*
181+
* @param queryVector the full-precision float query vector
182+
* @param documentVector the quantized byte document vector
183+
* @param spaceType the distance metric to use
184+
* @return the similarity score
185+
* @throws UnsupportedOperationException if the space type is not supported for ADC
186+
*/
187+
public static float scoreWithADC(float[] queryVector, byte[] documentVector, SpaceType spaceType) {
188+
if (spaceType.equals(SpaceType.L2)) {
189+
return SpaceType.L2.scoreTranslation(l2SquaredADC(queryVector, documentVector));
190+
} else if (spaceType.equals(SpaceType.INNER_PRODUCT)) {
191+
return SpaceType.INNER_PRODUCT.scoreTranslation((-1 * innerProductADC(queryVector, documentVector)));
192+
} else if (spaceType.equals(SpaceType.COSINESIMIL)) {
193+
return SpaceType.COSINESIMIL.scoreTranslation(1 - innerProductADC(queryVector, documentVector));
194+
}
195+
throw new UnsupportedOperationException("Space type " + spaceType.getValue() + " is not supported for ADC");
196+
}
197+
177198
private static float[] toFloat(final List<Number> inputVector, final VectorDataType vectorDataType) {
178199
Objects.requireNonNull(inputVector);
179200
float[] value = new float[inputVector.size()];

0 commit comments

Comments
 (0)