Skip to content

Commit 5e3d623

Browse files
Stiching Faiss104ScalarQuantizedKnnVectorsWriter to MemOptimizedScalarQuantizedIndexBuildStrategy
Signed-off-by: Dooyong Kim <kdooyong@amazon.com>
1 parent 95ab621 commit 5e3d623

30 files changed

+1624
-633
lines changed

jni/include/bbq/faiss_bbq_flat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ namespace knn_jni {
196196
quantizedVectorsAndCorrectionFactors(_numVectors * oneElementSize),
197197
dimension(_dimension) {
198198

199-
// Just changing the size, not shrinking.
199+
// Just changing the size, not shrinking, thus allocated memory capacity remains the same.
200+
// This is to avoid reallocations when adding elements later on since we know the exact required memory space upfront.
200201
quantizedVectorsAndCorrectionFactors.resize(0);
201202
// Rewriting code_size to the full element size so that hnsw_add_vertices
202203
// strides correctly through the packed buffer when computing:

src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.knn.indices.ModelMetadata;
1818
import org.opensearch.knn.indices.ModelUtil;
1919

20+
import static org.opensearch.knn.common.KNNConstants.FAISS_BBQ_CONFIG;
2021
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
2122
import static org.opensearch.knn.indices.ModelUtil.getModelMetadata;
2223
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
@@ -128,4 +129,14 @@ private static SpaceType getSpaceTypeFromModel(final ModelDao modelDao, final St
128129
public static @Nullable FieldInfo getFieldInfo(final LeafReader leafReader, final String fieldName) {
129130
return leafReader.getFieldInfos().fieldInfo(fieldName);
130131
}
132+
133+
/**
134+
* Check if the field is configured for Faiss BBQ.
135+
*
136+
* @param fieldInfo {@link FieldInfo}
137+
* @return true if the field has faiss_bbq_config attribute
138+
*/
139+
public static boolean isFaissBBQ(final FieldInfo fieldInfo) {
140+
return StringUtils.isNotEmpty(fieldInfo.getAttribute(FAISS_BBQ_CONFIG));
141+
}
131142
}

src/main/java/org/opensearch/knn/common/KNNConstants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ public class KNNConstants {
7676
public static final String SEARCH_SIZE_PARAMETER = "search_size";
7777

7878
public static final String QFRAMEWORK_CONFIG = "qframe_config";
79+
public static final String FAISS_BBQ_CONFIG = "faiss_bbq_config";
7980

8081
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
8182
public static final String EXPAND_NESTED = "expand_nested_docs";
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.knn.index.codec;
7+
8+
import lombok.extern.log4j.Log4j2;
9+
import org.apache.lucene.codecs.KnnVectorsWriter;
10+
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
11+
import org.apache.lucene.index.FieldInfo;
12+
import org.apache.lucene.index.MergeState;
13+
import org.apache.lucene.index.SegmentWriteState;
14+
import org.apache.lucene.search.DocIdSetIterator;
15+
import org.opensearch.common.Nullable;
16+
import org.opensearch.common.StopWatch;
17+
import org.opensearch.common.TriFunction;
18+
import org.opensearch.knn.index.VectorDataType;
19+
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory;
20+
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
21+
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
22+
import org.opensearch.knn.plugin.stats.KNNGraphValue;
23+
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
24+
25+
import java.io.IOException;
26+
import java.util.List;
27+
import java.util.Map;
28+
import java.util.function.Supplier;
29+
30+
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
31+
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge;
32+
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValuesSupplier;
33+
34+
@Log4j2
35+
public abstract class AbstractNativeEnginesKnnVectorsWriter extends KnnVectorsWriter {
36+
protected <T> void doFlush(
37+
final FieldInfo fieldInfo,
38+
final FlatFieldVectorsWriter<?> fieldWriter,
39+
final T vectors,
40+
@Nullable final TriFunction<FieldInfo, Supplier<KNNVectorValues<?>>, Integer, QuantizationState> quantizationStateSupplier,
41+
final Integer approximateThreshold,
42+
final SegmentWriteState segmentWriteState,
43+
final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory,
44+
final Map<String, Object> buildStrategyParameters
45+
) throws IOException {
46+
// Check total live docs first to avoid unnecessary supplier creation for empty fields
47+
final int totalLiveDocs;
48+
if (vectors instanceof Map) {
49+
totalLiveDocs = ((Map) vectors).size();
50+
} else {
51+
totalLiveDocs = ((List) vectors).size();
52+
}
53+
54+
if (totalLiveDocs == 0) {
55+
log.debug("[Flush] No live docs for field {}", fieldInfo.getName());
56+
return;
57+
}
58+
59+
// Get vector values supplier
60+
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
61+
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier;
62+
if (vectors instanceof Map) {
63+
knnVectorValuesSupplier = getVectorValuesSupplier(vectorDataType, fieldWriter.getDocsWithFieldSet(), (Map) vectors);
64+
} else {
65+
knnVectorValuesSupplier = getVectorValuesSupplier(vectorDataType, fieldWriter.getDocsWithFieldSet(), (List) vectors);
66+
}
67+
68+
QuantizationState quantizationState = null;
69+
if (quantizationStateSupplier != null) {
70+
// should skip graph building only for non quantization use case and if threshold is met
71+
quantizationState = quantizationStateSupplier.apply(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
72+
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs, approximateThreshold)) {
73+
log.debug(
74+
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
75+
fieldInfo.name,
76+
totalLiveDocs,
77+
approximateThreshold
78+
);
79+
return;
80+
}
81+
}
82+
83+
final NativeIndexWriter writer = NativeIndexWriter.getWriter(
84+
fieldInfo,
85+
segmentWriteState,
86+
quantizationState,
87+
nativeIndexBuildStrategyFactory,
88+
buildStrategyParameters
89+
);
90+
91+
final StopWatch stopWatch = new StopWatch().start();
92+
writer.flushIndex(knnVectorValuesSupplier, totalLiveDocs);
93+
final long time_in_millis = stopWatch.stop().totalTime().millis();
94+
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
95+
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
96+
}
97+
98+
protected void doMergeOneField(
99+
final FieldInfo fieldInfo,
100+
final MergeState mergeState,
101+
@Nullable final TriFunction<FieldInfo, Supplier<KNNVectorValues<?>>, Integer, QuantizationState> quantizationStateSupplier,
102+
final Integer approximateThreshold,
103+
final SegmentWriteState segmentWriteState,
104+
final NativeIndexBuildStrategyFactory nativeIndexBuildStrategyFactory,
105+
final Map<String, Object> buildStrategyParameters
106+
) throws IOException {
107+
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
108+
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = getKNNVectorValuesSupplierForMerge(
109+
vectorDataType,
110+
fieldInfo,
111+
mergeState
112+
);
113+
final int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
114+
if (totalLiveDocs == 0) {
115+
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
116+
return;
117+
}
118+
119+
QuantizationState quantizationState = null;
120+
if (quantizationStateSupplier != null) {
121+
quantizationState = quantizationStateSupplier.apply(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
122+
// should skip graph building only for non quantization use case and if threshold is met
123+
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs, approximateThreshold)) {
124+
log.debug(
125+
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
126+
fieldInfo.name,
127+
totalLiveDocs,
128+
approximateThreshold
129+
);
130+
return;
131+
}
132+
}
133+
134+
final NativeIndexWriter writer = NativeIndexWriter.getWriter(
135+
fieldInfo,
136+
segmentWriteState,
137+
quantizationState,
138+
nativeIndexBuildStrategyFactory,
139+
buildStrategyParameters
140+
);
141+
142+
final StopWatch stopWatch = new StopWatch().start();
143+
144+
writer.mergeIndex(knnVectorValuesSupplier, totalLiveDocs);
145+
146+
final long time_in_millis = stopWatch.stop().totalTime().millis();
147+
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
148+
log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
149+
}
150+
151+
public static boolean shouldSkipBuildingVectorDataStructure(final long docCount, final int approximateThreshold) {
152+
if (approximateThreshold < 0) {
153+
return true;
154+
}
155+
return docCount < approximateThreshold;
156+
}
157+
158+
/**
159+
* The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the
160+
* vectorsValues object which you plan to use later
161+
*/
162+
public static int getLiveDocs(final KNNVectorValues<?> vectorValues) throws IOException {
163+
// Count all the live docs as there vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues,
164+
// and doesn't tell the correct number of docs, if there are deleted docs in the segment. So we are counting
165+
// the total live docs here.
166+
int liveDocs = 0;
167+
while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
168+
liveDocs++;
169+
}
170+
return liveDocs;
171+
}
172+
}

src/main/java/org/opensearch/knn/index/codec/KNN1040Codec/Faiss104ScalarQuantizedKnnVectorsFormat.java

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
1414
import org.apache.lucene.index.SegmentReadState;
1515
import org.apache.lucene.index.SegmentWriteState;
16-
import org.opensearch.knn.index.KNNSettings;
1716
import org.opensearch.knn.index.engine.KNNEngine;
1817

1918
import java.io.IOException;
@@ -29,39 +28,26 @@
2928
* {@code "encoder": {"name": "faiss_bbq"}}. See {@code FaissCodecFormatResolver} for the
3029
* routing logic in {@code BasePerFieldKnnVectorsFormat.getKnnVectorsFormatForField}.
3130
*
32-
* <p>Uses Lucene's max dimension limit since the flat storage is Lucene-managed.
33-
*
3431
* @see Faiss104ScalarQuantizedKnnVectorsWriter
3532
* @see Faiss104ScalarQuantizedKnnVectorsReader
3633
*/
3734
@Log4j2
3835
public class Faiss104ScalarQuantizedKnnVectorsFormat extends KnnVectorsFormat {
3936

40-
private static final String FORMAT_NAME = "Faiss104ScalarQuantizedKnnVectorsFormat";
37+
private static final String FORMAT_NAME = Faiss104ScalarQuantizedKnnVectorsFormat.class.getSimpleName();
4138

4239
// Shared across all format instances; Lucene104ScalarQuantizedVectorsFormat is stateless.
4340
private static final Lucene104ScalarQuantizedVectorsFormat bbqFlatFormat = new Lucene104ScalarQuantizedVectorsFormat(
4441
ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE
4542
);
4643

47-
private final int approximateThreshold;
48-
4944
public Faiss104ScalarQuantizedKnnVectorsFormat() {
50-
this(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE);
51-
}
52-
53-
/**
54-
* @param approximateThreshold if the number of vectors in a segment is below this threshold,
55-
* HNSW graph building is skipped. A negative value always skips.
56-
*/
57-
public Faiss104ScalarQuantizedKnnVectorsFormat(int approximateThreshold) {
5845
super(FORMAT_NAME);
59-
this.approximateThreshold = approximateThreshold;
6046
}
6147

6248
@Override
6349
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
64-
return new Faiss104ScalarQuantizedKnnVectorsWriter(state, bbqFlatFormat.fieldsWriter(state), approximateThreshold);
50+
return new Faiss104ScalarQuantizedKnnVectorsWriter(state, bbqFlatFormat.fieldsWriter(state), bbqFlatFormat::fieldsReader);
6551
}
6652

6753
@Override
@@ -79,11 +65,6 @@ public int getMaxDimensions(String fieldName) {
7965

8066
@Override
8167
public String toString() {
82-
return this.getClass().getSimpleName()
83-
+ "(name="
84-
+ this.getClass().getSimpleName()
85-
+ ", approximateThreshold="
86-
+ approximateThreshold
87-
+ ")";
68+
return this.getClass().getSimpleName() + "(name=" + this.getClass().getSimpleName() + ")";
8869
}
8970
}

0 commit comments

Comments
 (0)