Skip to content

Commit

Permalink
KnnPlugin Upgrage with Lucene 10.0.1
Browse files Browse the repository at this point in the history
Signed-off-by: Vikasht34 <[email protected]>
  • Loading branch information
Vikasht34 committed Feb 3, 2025
1 parent d58d133 commit 4de48de
Show file tree
Hide file tree
Showing 53 changed files with 1,050 additions and 482 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Upgrade jsonpath from 2.8.0 to 2.9.0[2325](https://github.com/opensearch-project/k-NN/pull/2325)
* Bump Faiss commit from 1f42e81 to 0cbc2a8 to accelerate hamming distance calculation using _mm512_popcnt_epi64 intrinsic and also add avx512-fp16 instructions to boost performance [#2381](https://github.com/opensearch-project/k-NN/pull/2381)
* Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/)
* Fixing Lucene912Codec Issue with BWC for Lucene 10.0.1 upgrade[#2429](https://github.com/opensearch-project/k-NN/pull/2429)
### Refactoring
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ buildscript {
ext {
// build.version_qualifier parameter applies to knn plugin artifacts only. OpenSearch version must be set
// explicitly as 'opensearch.version' property, for instance opensearch.version=2.0.0-rc1-SNAPSHOT
opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT")
version_qualifier = System.getProperty("build.version_qualifier", "")
opensearch_version = System.getProperty("opensearch.version", "3.0.0-alpha1-SNAPSHOT")
version_qualifier = System.getProperty("build.version_qualifier", "alpha1")
opensearch_group = "org.opensearch"
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
avx2_enabled = System.getProperty("avx2.enabled", "true")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase {
private static final int DELAY_MILLI_SEC = 1000;
private static final int MIN_NUM_OF_MODELS = 2;
private static final int K = 5;
private static final int NUM_DOCS = 10;
private static final int NUM_DOCS = 1001;
private static final int NUM_DOCS_TEST_MODEL_INDEX = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_DEFAULT = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_FOR_NON_KNN_INDEX = 100;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.knn.index;

import lombok.SneakyThrows;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
Expand Down Expand Up @@ -38,29 +40,29 @@ public long ramBytesUsed() {
return 0; // unknown
}

@SneakyThrows
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
KnnVectorValues knnVectorValues;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
knnVectorValues = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
knnVectorValues = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
values = DocValues.getBinary(reader, fieldName);
return KNNVectorScriptDocValues.create(knnVectorValues, fieldName, vectorDataType);
}
DocIdSetIterator values = DocValues.getBinary(reader, fieldName);
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;
Expand All @@ -32,9 +33,7 @@ public void setNextDocId(int docId) throws IOException {
if (docId < lastDocID) {
throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId);
}

lastDocID = docId;

int curDocID = vectorValues.docID();
if (lastDocID > curDocID) {
curDocID = vectorValues.advance(docId);
Expand Down Expand Up @@ -81,12 +80,13 @@ public float[] get(int i) {
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
public static KNNVectorScriptDocValues create(Object values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {

if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
} else {
Expand All @@ -96,34 +96,53 @@ public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fi

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private final ByteVectorValues values;
private final KnnVectorValues.DocIndexIterator iterator;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
super(values, field, type);
super(values.iterator(), field, type);
this.values = values;
this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator
? (KnnVectorValues.DocIndexIterator) super.vectorValues
: values.iterator();
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
int docId = this.iterator.index();
if (docId == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) {
throw new IllegalStateException("No more ordinals to retrieve vector values.");
}

// Use the correct method to retrieve the byte vector for the current ordinal
byte[] bytes = values.vectorValue(docId);
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}

}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private final FloatVectorValues values;
private final KnnVectorValues.DocIndexIterator iterator;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
super(values, field, type);
super(values.iterator(), field, type);
this.values = values;
this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator
? (KnnVectorValues.DocIndexIterator) super.vectorValues
: values.iterator();
}

@Override
protected float[] doGetValue() throws IOException {
return values.vectorValue();
int ord = iterator.index(); // Fetch ordinal (index of vector)
if (ord == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) {
throw new IllegalStateException("No more ordinals to retrieve vector values.");
}
return values.vectorValue(ord);
}
}

Expand Down
30 changes: 21 additions & 9 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import lombok.Getter;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
Expand All @@ -21,6 +21,7 @@
import org.opensearch.knn.training.FloatTrainingDataConsumer;
import org.opensearch.knn.training.TrainingDataConsumer;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand All @@ -47,14 +48,25 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;

while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++];
if (binaryValue.length % Float.BYTES == 0) {
// ✅ Case 1: Stored as encoded floats (each float takes 4 bytes)
int numFloats = binaryValue.length / Float.BYTES;
float[] vector = new float[numFloats];

ByteBuffer byteBuffer = ByteBuffer.wrap(binaryValue.bytes, binaryValue.offset, binaryValue.length);
for (int i = 0; i < numFloats; i++) {
vector[i] = byteBuffer.getFloat(); // Read as float
}
return vector;
} else {
// ✅ Case 2: Stored as raw bytes (each byte is interpreted as a float)
float[] vector = new float[binaryValue.length];
int i = 0, j = binaryValue.offset;
while (i < binaryValue.length) {
vector[i++] = binaryValue.bytes[j++]; // Direct conversion from byte to float
}
return vector;
}
return vector;
}

@Override
Expand Down Expand Up @@ -100,7 +112,7 @@ public void freeNativeMemory(long memoryAddress) {

@Override
public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) {
return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
return KnnFloatVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN10010Codec;

import lombok.Builder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1
*/

public class KNN10010Codec extends FilterCodec {

private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_1_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene99 as the delegate
*/
public KNN10010Codec() {
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
* Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec
* and a unique name to this ctor.
*
* @param delegate codec that will perform all operations this codec does not override
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

@Override
public DocValuesFormat docValuesFormat() {
return knnFormatFacade.docValuesFormat();
}

@Override
public CompoundFormat compoundFormat() {
return knnFormatFacade.compoundFormat();
}

@Override
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ public KNN80CompoundFormat(CompoundFormat delegate) {
}

@Override
public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException {
return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si, context), dir);
public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si) throws IOException {
return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si), dir);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.*;

import java.io.IOException;

import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand Down Expand Up @@ -66,6 +59,16 @@ public SortedSetDocValues getSortedSet(FieldInfo field) throws IOException {
return delegate.getSortedSet(field);
}

/**
* @param fieldInfo
* @return
* @throws IOException
*/
@Override
public DocValuesSkipper getSkipper(FieldInfo fieldInfo) throws IOException {
return delegate.getSkipper(fieldInfo);
}

@Override
public void checkIntegrity() throws IOException {
delegate.checkIntegrity();
Expand Down
Loading

0 comments on commit 4de48de

Please sign in to comment.