Skip to content

Commit 7a2c310

Browse files
committed
Introduce NativeEngines990KnnVectorsScorer to decouple native SIMD scoring selection from FaissMemoryOptimizedSearcher
Previously, FaissMemoryOptimizedSearcher was responsible for detecting MMap-backed vectors and choosing between NativeRandomVectorScorer and the default Java scorer. This logic is now extracted into a dedicated FlatVectorsScorer decorator that sits in the codec scoring chain, making native SIMD acceleration available to all scoring paths (including HNSW graph traversal with prefetch support). By placing NativeEngines990KnnVectorsScorer inside PrefetchableFlatVectorScorer in the decorator chain, the native scorer now benefits from prefetch-enabled bulk scoring during HNSW graph traversal. Previously, NativeRandomVectorScorer was instantiated directly in FaissMemoryOptimizedSearcher, bypassing the prefetch layer entirely. With this change, PrefetchableFlatVectorScorer can wrap the NativeRandomVectorScorer returned by NativeEngines990KnnVectorsScorer, issuing prefetch hints for memory-mapped vector data before native SIMD scoring, reducing memory access latency during graph neighbor evaluation. Key changes: - Add NativeEngines990KnnVectorsScorer which wraps a FlatVectorsScorer delegate and transparently returns NativeRandomVectorScorer when the bottom-level FloatVectorValues implements MMapVectorValues and the similarity function is EUCLIDEAN or MAXIMUM_INNER_PRODUCT - Simplify FaissMemoryOptimizedSearcher by removing the native scoring branch and determineNativeFunctionType(); it now unconditionally delegates to the FlatVectorsScorer - Refactor NativeRandomVectorScorer to extend AbstractRandomVectorScorer, removing manual maxOrd/ordToDoc/getAcceptOrds overrides, enabling it to be wrapped by PrefetchableFlatVectorScorer for prefetch-enabled bulk scoring during HNSW graph traversal - Wire the new scorer into NativeEngines990KnnVectorsFormat between the Lucene99 flat vectors scorer and PrefetchableFlatVectorScorer - Add unit tests covering all routing branches Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
1 parent 6c4d427 commit 7a2c310

File tree

6 files changed

+245
-86
lines changed

6 files changed

+245
-86
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3232
* Add Prefetch functionality to prefetch vectors during ANN Search for MemoryOptimizedSearch. [#3173](https://github.com/opensearch-project/k-NN/pull/3173)
3333
* Optimize ByteVectorIdsExactKNNIterator by moving array conversion to constructor [#3171](https://github.com/opensearch-project/k-NN/pull/3171)
3434
* Add VectorScorers for BinaryDocValues and nested best child scoring [#3179](https://github.com/opensearch-project/k-NN/pull/3179)
35+
* Introduce NativeEngines990KnnVectorsScorer to decouple native SIMD scoring selection from FaissMemoryOptimizedSearcher [#3184](https://github.com/opensearch-project/k-NN/pull/3184)

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.index.SegmentWriteState;
2323
import org.opensearch.knn.index.KNNSettings;
2424
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
25+
import org.opensearch.knn.index.codec.scorer.NativeEngines990KnnVectorsScorer;
2526
import org.opensearch.knn.index.codec.scorer.PrefetchableFlatVectorScorer;
2627
import org.opensearch.knn.index.engine.KNNEngine;
2728

@@ -35,7 +36,7 @@
3536
public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat {
3637
/** The format for storing, reading, merging vectors on disk */
3738
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(
38-
new PrefetchableFlatVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
39+
new PrefetchableFlatVectorScorer(new NativeEngines990KnnVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()))
3940
);
4041
private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat";
4142
private final int approximateThreshold;
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.codec.scorer;
7+
8+
import lombok.RequiredArgsConstructor;
9+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
10+
import org.apache.lucene.index.FloatVectorValues;
11+
import org.apache.lucene.index.KnnVectorValues;
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
14+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
15+
import org.opensearch.knn.jni.SimdVectorComputeService;
16+
import org.opensearch.knn.memoryoptsearch.faiss.MMapVectorValues;
17+
import org.opensearch.knn.memoryoptsearch.faiss.NativeRandomVectorScorer;
18+
import org.opensearch.knn.memoryoptsearch.faiss.WrappedFloatVectorValues;
19+
20+
import java.io.IOException;
21+
22+
/**
23+
* A {@link FlatVectorsScorer} that transparently selects between native SIMD-optimized scoring
24+
* and pure-Java scoring based on whether the underlying vector values are memory-mapped.
25+
*
26+
* <p>When the bottom-level {@link FloatVectorValues} implements {@link MMapVectorValues},
27+
* this scorer returns a {@link NativeRandomVectorScorer} for hardware-accelerated computation.
28+
* Otherwise, it delegates to the wrapped {@link FlatVectorsScorer}.
29+
*/
30+
@RequiredArgsConstructor
31+
public class NativeEngines990KnnVectorsScorer implements FlatVectorsScorer {
32+
private final FlatVectorsScorer delegate;
33+
34+
@Override
35+
public RandomVectorScorer getRandomVectorScorer(
36+
VectorSimilarityFunction similarityFunction,
37+
KnnVectorValues vectorValues,
38+
float[] target
39+
) throws IOException {
40+
final SimdVectorComputeService.SimilarityFunctionType nativeType = getNativeFunctionType(similarityFunction);
41+
if (nativeType != null) {
42+
final FloatVectorValues bottomValues = WrappedFloatVectorValues.getBottomFloatVectorValues(vectorValues);
43+
if (bottomValues instanceof MMapVectorValues mmapValues) {
44+
return new NativeRandomVectorScorer(target, vectorValues, mmapValues, nativeType);
45+
}
46+
}
47+
return delegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
48+
}
49+
50+
@Override
51+
public RandomVectorScorer getRandomVectorScorer(
52+
VectorSimilarityFunction similarityFunction,
53+
KnnVectorValues vectorValues,
54+
byte[] target
55+
) throws IOException {
56+
return delegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
57+
}
58+
59+
@Override
60+
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
61+
VectorSimilarityFunction similarityFunction,
62+
KnnVectorValues vectorValues
63+
) throws IOException {
64+
return delegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
65+
}
66+
67+
private static SimdVectorComputeService.SimilarityFunctionType getNativeFunctionType(
68+
final VectorSimilarityFunction similarityFunction
69+
) {
70+
return switch (similarityFunction) {
71+
case MAXIMUM_INNER_PRODUCT -> SimdVectorComputeService.SimilarityFunctionType.FP16_MAXIMUM_INNER_PRODUCT;
72+
case EUCLIDEAN -> SimdVectorComputeService.SimilarityFunctionType.FP16_L2;
73+
default -> null;
74+
};
75+
}
76+
}

src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
1010
import org.apache.lucene.index.ByteVectorValues;
1111
import org.apache.lucene.index.FieldInfo;
12-
import org.apache.lucene.index.FloatVectorValues;
1312
import org.apache.lucene.index.KnnVectorValues;
1413
import org.apache.lucene.index.VectorEncoding;
1514
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -29,7 +28,6 @@
2928
import org.opensearch.knn.index.KNNVectorSimilarityFunction;
3029
import org.opensearch.knn.index.SpaceType;
3130
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
32-
import org.opensearch.knn.jni.SimdVectorComputeService;
3331
import org.opensearch.knn.memoryoptsearch.VectorSearcher;
3432
import org.opensearch.knn.memoryoptsearch.faiss.cagra.FaissCagraHNSW;
3533

@@ -58,7 +56,6 @@ public WarmupInitializationException(String message) {
5856
private final VectorSimilarityFunction vectorSimilarityFunction;
5957
private final long fileSize;
6058
private boolean isAdc;
61-
private SimdVectorComputeService.SimilarityFunctionType nativeSimilarityFunctionType;
6259

6360
public FaissMemoryOptimizedSearcher(final IndexInput indexInput, final FieldInfo fieldInfo, final FlatVectorsScorer flatVectorsScorer)
6461
throws IOException {
@@ -90,7 +87,6 @@ public FaissMemoryOptimizedSearcher(final IndexInput indexInput, final FieldInfo
9087
);
9188

9289
this.hnsw = extractFaissHnsw(faissIndex);
93-
this.nativeSimilarityFunctionType = determineNativeFunctionType();
9490
}
9591

9692
private static FaissHNSW extractFaissHnsw(final FaissIndex faissIndex) {
@@ -106,35 +102,13 @@ public void search(float[] target, KnnCollector knnCollector, AcceptDocs acceptD
106102
final KnnVectorValues knnVectorValues = isAdc
107103
? faissIndex.getByteValues(getSlicedIndexInput())
108104
: faissIndex.getFloatValues(getSlicedIndexInput());
109-
final FloatVectorValues bottomKnnVectorValues = WrappedFloatVectorValues.getBottomFloatVectorValues(knnVectorValues);
110-
final boolean useNativeScoring = bottomKnnVectorValues instanceof MMapVectorValues;
111-
final IOSupplier<RandomVectorScorer> scorerSupplier;
112-
113-
if (useNativeScoring) {
114-
// We can use native scoring.
115-
scorerSupplier = () -> new NativeRandomVectorScorer(
116-
target,
117-
knnVectorValues,
118-
(MMapVectorValues) bottomKnnVectorValues,
119-
nativeSimilarityFunctionType
120-
);
121-
} else {
122-
// Falling back to default scoring using pure Java.
123-
scorerSupplier = () -> flatVectorsScorer.getRandomVectorScorer(vectorSimilarityFunction, knnVectorValues, target);
124-
}
125105

126-
search(VectorEncoding.FLOAT32, scorerSupplier, knnCollector, acceptDocs);
127-
}
128-
129-
private SimdVectorComputeService.SimilarityFunctionType determineNativeFunctionType() {
130-
if (vectorSimilarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
131-
return SimdVectorComputeService.SimilarityFunctionType.FP16_MAXIMUM_INNER_PRODUCT;
132-
} else if (vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
133-
return SimdVectorComputeService.SimilarityFunctionType.FP16_L2;
134-
}
135-
136-
// At the moment, we only support FP16, it's fine to return null.
137-
return null;
106+
search(
107+
VectorEncoding.FLOAT32,
108+
() -> flatVectorsScorer.getRandomVectorScorer(vectorSimilarityFunction, knnVectorValues, target),
109+
knnCollector,
110+
acceptDocs
111+
);
138112
}
139113

140114
@Override

src/main/java/org/opensearch/knn/memoryoptsearch/faiss/NativeRandomVectorScorer.java

Lines changed: 7 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
package org.opensearch.knn.memoryoptsearch.faiss;
77

8-
import lombok.NonNull;
98
import org.apache.lucene.index.KnnVectorValues;
10-
import org.apache.lucene.util.Bits;
119
import org.apache.lucene.util.hnsw.RandomVectorScorer;
1210
import org.opensearch.knn.jni.SimdVectorComputeService;
1311

@@ -21,21 +19,12 @@
2119
* memory-mapped vector chunks, and delegates all similarity scoring operations to the
2220
* {@link SimdVectorComputeService}. The underlying native library is expected to
2321
* leverage SIMD instructions (e.g., AVX, AVX512, or NEON) to accelerate computations.
22+
* <p>
23+
* Extends {@link AbstractRandomVectorScorer} so that it can be wrapped by
24+
* {@link org.opensearch.knn.index.codec.scorer.PrefetchableFlatVectorScorer} for
25+
* prefetch-enabled bulk scoring during HNSW graph traversal.
2426
*/
25-
public class NativeRandomVectorScorer implements RandomVectorScorer {
26-
27-
// Backing {@link KnnVectorValues} used for document–vector association.
28-
@NonNull
29-
private final KnnVectorValues knnVectorValues;
30-
31-
// Array of address–size pairs describing memory-mapped vector chunks.
32-
private long[] addressAndSize;
33-
34-
// Maximum vector id available for scoring.
35-
private int maxOrd;
36-
37-
// Index value of the native similarity function type.
38-
private int nativeFunctionTypeOrd;
27+
public class NativeRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
3928

4029
/**
4130
* Constructs a native-backed scorer for computing similarity between the given query
@@ -52,11 +41,8 @@ public NativeRandomVectorScorer(
5241
final MMapVectorValues mmapVectorValues,
5342
final SimdVectorComputeService.SimilarityFunctionType similarityFunctionType
5443
) {
55-
this.knnVectorValues = knnVectorValues;
56-
this.addressAndSize = mmapVectorValues.getAddressAndSize();
57-
this.maxOrd = knnVectorValues.size();
58-
this.nativeFunctionTypeOrd = similarityFunctionType.ordinal();
59-
SimdVectorComputeService.saveSearchContext(query, addressAndSize, nativeFunctionTypeOrd);
44+
super(knnVectorValues);
45+
SimdVectorComputeService.saveSearchContext(query, mmapVectorValues.getAddressAndSize(), similarityFunctionType.ordinal());
6046
}
6147

6248
/**
@@ -87,36 +73,4 @@ public float bulkScore(final int[] internalVectorIds, final float[] scores, fina
8773
public float score(final int internalVectorId) throws IOException {
8874
return SimdVectorComputeService.scoreSimilarity(internalVectorId);
8975
}
90-
91-
/**
92-
* Returns the maximum vector id for scoring.
93-
*
94-
* @return the maximum vector id
95-
*/
96-
@Override
97-
public int maxOrd() {
98-
return maxOrd;
99-
}
100-
101-
/**
102-
* Maps an internal vector ordinal to its corresponding document ID.
103-
*
104-
* @param ord the internal vector id
105-
* @return the document ID associated with the given vector id
106-
*/
107-
@Override
108-
public int ordToDoc(int ord) {
109-
return knnVectorValues.ordToDoc(ord);
110-
}
111-
112-
/**
113-
* Returns a filtered {@link Bits} view representing accepted documents.
114-
*
115-
* @param acceptDocs the bit set of accepted documents
116-
* @return a {@link Bits} object describing acceptable vector ids for scoring
117-
*/
118-
@Override
119-
public Bits getAcceptOrds(Bits acceptDocs) {
120-
return knnVectorValues.getAcceptOrds(acceptDocs);
121-
}
12276
}

0 commit comments

Comments
 (0)