Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Manual backport 2.x] Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter #2426

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Optimizes lucene query execution to prevent unnecessary rewrites [#2305](https://github.com/opensearch-project/k-NN/pull/2305)
- Add check to directly use ANN Search when filters match all docs. [#2320](https://github.com/opensearch-project/k-NN/pull/2320)
- Use one formula to calculate cosine similarity [#2357](https://github.com/opensearch-project/k-NN/pull/2357)
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter [#2408](https://github.com/opensearch-project/k-NN/pull/2408)
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search [#2282](https://github.com/opensearch-project/k-NN/pull/2282)
* Fix for NPE while merging segments after all the vector fields docs are deleted [#2365](https://github.com/opensearch-project/k-NN/pull/2365)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import lombok.Getter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
Expand Down Expand Up @@ -43,9 +42,8 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
@Getter
private final Map<Integer, T> vectors;
private int lastDocID = -1;
@Getter
private final DocsWithFieldSet docsWithField;
private final InfoStream infoStream;
@Getter
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -75,7 +73,6 @@ private NativeEngineFieldVectorsWriter(
this.fieldInfo = fieldInfo;
this.infoStream = infoStream;
vectors = new HashMap<>();
this.docsWithField = new DocsWithFieldSet();
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
}

Expand All @@ -101,7 +98,6 @@ public void addValue(int docID, T vectorValue) throws IOException {
// ensuring that vector is provided to flatFieldWriter.
flatFieldVectorsWriter.addValue(docID, vectorValue);
vectors.put(docID, vectorValue);
docsWithField.add(docID);
lastDocID = docID;
}

Expand All @@ -121,10 +117,9 @@ public T copyValue(T vectorValue) {
*/
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance(
Integer.class
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter
.ramBytesUsed();
return SHALLOW_SIZE + flatFieldVectorsWriter.getDocsWithFieldSet().ramBytesUsed() + (long) this.vectors.size()
* (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size()
* RamUsageEstimator.shallowSizeOfInstance(Integer.class) + (long) vectors.size() * fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter.ramBytesUsed();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
}
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getFlatFieldVectorsWriter().getDocsWithFieldSet(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.InfoStream;
Expand Down Expand Up @@ -115,6 +116,7 @@ public void testRamByteUsed_whenValidInput_thenSuccess() {
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L);
Mockito.when(mockedFlatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(new DocsWithFieldSet());
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// Testing for value > 0 as we don't have a concrete way to find out the expected bytes. This can be OS dependent too.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public void testFlush() {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -250,7 +250,7 @@ public void testFlush_WithQuantization() {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -352,7 +352,7 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -429,7 +429,7 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -507,7 +507,7 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -593,7 +593,7 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -683,7 +683,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -786,7 +786,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -848,11 +848,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<

private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
return fieldVectorsWriter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<

private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
return fieldVectorsWriter;
}
}
Loading