From 8dc2b0abde79d296cf3389f0303c81cedd262522 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Thu, 23 Jan 2025 12:00:01 -0800 Subject: [PATCH] KnnPlugin Upgrage with Lucene 10.0.1 Signed-off-by: Vikasht34 --- CHANGELOG.md | 1 + build.gradle | 4 +- .../java/org/opensearch/knn/bwc/ModelIT.java | 2 +- .../knn/index/KNNVectorDVLeafFieldData.java | 14 +- .../knn/index/KNNVectorScriptDocValues.java | 43 ++- .../opensearch/knn/index/VectorDataType.java | 30 +- .../codec/KNN10010Codec/KNN10010Codec.java | 63 ++++ .../codec/KNN80Codec/KNN80CompoundFormat.java | 4 +- .../KNN80Codec/KNN80DocValuesProducer.java | 19 +- .../KNN9120BinaryVectorScorer.java | 29 +- .../KNN990QuantizationStateReader.java | 2 +- .../NativeEngines990KnnVectorsFormat.java | 9 + .../NativeEngines990KnnVectorsReader.java | 8 - .../knn/index/codec/KNNCodecVersion.java | 20 +- .../knn/index/mapper/KNNVectorFieldType.java | 4 +- .../opensearch/knn/index/query/KNNScorer.java | 72 ++-- .../opensearch/knn/index/query/KNNWeight.java | 26 +- .../index/query/common/DocAndScoreQuery.java | 151 ++++---- .../vectorvalues/KNNBinaryVectorValues.java | 3 +- .../vectorvalues/KNNByteVectorValues.java | 3 +- .../vectorvalues/KNNFloatVectorValues.java | 3 +- .../vectorvalues/KNNVectorValuesFactory.java | 41 +- .../vectorvalues/KNNVectorValuesIterator.java | 49 ++- .../VectorValueExtractorStrategy.java | 28 +- .../org/opensearch/knn/indices/ModelDao.java | 2 +- .../TrainingJobRouterTransportAction.java | 2 +- .../transport/UpdateModelGraveyardAction.java | 2 +- .../UpdateModelGraveyardRequest.java | 2 +- .../UpdateModelGraveyardTransportAction.java | 2 +- .../transport/UpdateModelMetadataAction.java | 2 +- .../transport/UpdateModelMetadataRequest.java | 2 +- .../UpdateModelMetadataTransportAction.java | 2 +- .../services/org.apache.lucene.codecs.Codec | 1 + .../index/KNNVectorDVLeafFieldDataTests.java | 19 +- .../index/KNNVectorScriptDocValuesTests.java | 179 ++++++--- .../knn/index/VectorDataTypeTests.java | 24 +- .../KNN80Codec/KNN80CompoundFormatTests.java | 4 +- ...NativeEngines990KnnVectorsFormatTests.java | 37 +- .../knn/index/codec/KNNCodecTestCase.java | 5 +- .../knn/index/codec/KNNCodecTestUtil.java | 16 +- .../knn/index/query/KNNWeightTests.java | 12 +- .../knn/index/query/ResultUtilTests.java | 2 +- .../query/common/DocAndScoreQueryTests.java | 29 +- .../ExpandNestedEDocsQueryTests.java | 54 ++- .../NativeEngineKNNVectorQueryTests.java | 351 +++++++++++++++--- .../vectorvalues/KNNVectorValuesTests.java | 9 - .../index/vectorvalues/TestVectorValues.java | 83 ++++- .../opensearch/knn/indices/ModelDaoTests.java | 2 +- .../knn/integ/KNNScriptScoringIT.java | 7 +- .../opensearch/knn/jni/JNIServiceTests.java | 42 +-- .../plugin/script/KNNScoringUtilTests.java | 15 +- ...ateModelGraveyardTransportActionTests.java | 2 +- ...dateModelMetadataTransportActionTests.java | 2 +- 53 files changed, 1056 insertions(+), 483 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d6a2eb79..7cfce8c2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,4 +53,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Upgrade jsonpath from 2.8.0 to 2.9.0[2325](https://github.com/opensearch-project/k-NN/pull/2325) * Bump Faiss commit from 1f42e81 to 0cbc2a8 to accelerate hamming distance calculation using _mm512_popcnt_epi64 intrinsic and also add avx512-fp16 instructions to boost performance [#2381](https://github.com/opensearch-project/k-NN/pull/2381) * Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/) +* Fixing Lucene912Codec Issue with BWC for Lucene 10.0.1 upgrade[#2429](https://github.com/opensearch-project/k-NN/pull/2429) ### Refactoring diff --git a/build.gradle b/build.gradle index 9a3ca673f..2276f03c0 100644 --- a/build.gradle +++ b/build.gradle @@ -15,8 +15,8 @@ buildscript { ext { // build.version_qualifier parameter applies to knn plugin artifacts only. OpenSearch version must be set // explicitly as 'opensearch.version' property, for instance opensearch.version=2.0.0-rc1-SNAPSHOT - opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") - version_qualifier = System.getProperty("build.version_qualifier", "") + opensearch_version = System.getProperty("opensearch.version", "3.0.0-alpha1-SNAPSHOT") + version_qualifier = System.getProperty("build.version_qualifier", "alpha1") opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") avx2_enabled = System.getProperty("avx2.enabled", "true") diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index 29496cff9..c6caff306 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -53,7 +53,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase { private static final int DELAY_MILLI_SEC = 1000; private static final int MIN_NUM_OF_MODELS = 2; private static final int K = 5; - private static final int NUM_DOCS = 10; + private static final int NUM_DOCS = 1001; private static final int NUM_DOCS_TEST_MODEL_INDEX = 100; private static final int NUM_DOCS_TEST_MODEL_INDEX_DEFAULT = 100; private static final int NUM_DOCS_TEST_MODEL_INDEX_FOR_NON_KNN_INDEX = 100; diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 7053e6151..3d33a508d 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index; +import lombok.SneakyThrows; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; @@ -38,6 +40,7 @@ public long ramBytesUsed() { return 0; // unknown } + @SneakyThrows @Override public ScriptDocValues getScriptValues() { try { @@ -45,22 +48,21 @@ public ScriptDocValues getScriptValues() { if (fieldInfo == null) { return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); } - - DocIdSetIterator values; + KnnVectorValues knnVectorValues; if (fieldInfo.hasVectorValues()) { switch (fieldInfo.getVectorEncoding()) { case FLOAT32: - values = reader.getFloatVectorValues(fieldName); + knnVectorValues = reader.getFloatVectorValues(fieldName); break; case BYTE: - values = reader.getByteVectorValues(fieldName); + knnVectorValues = reader.getByteVectorValues(fieldName); break; default: throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); } - } else { - values = DocValues.getBinary(reader, fieldName); + return KNNVectorScriptDocValues.create(knnVectorValues, fieldName, vectorDataType); } + DocIdSetIterator values = DocValues.getBinary(reader, fieldName); return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 55ff65516..d9cb099c6 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; @@ -32,13 +33,15 @@ public void setNextDocId(int docId) throws IOException { if (docId < lastDocID) { throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId); } - lastDocID = docId; - int curDocID = vectorValues.docID(); if (lastDocID > curDocID) { curDocID = vectorValues.advance(docId); } + // 🔹 Ensure the iterator advances correctly + while (vectorValues.docID() < docId) { + vectorValues.nextDoc(); + } docExists = lastDocID == curDocID; } @@ -81,12 +84,13 @@ public float[] get(int i) { * @return A KNNVectorScriptDocValues object based on the type of the values. * @throws IllegalArgumentException If the type of values is unsupported. */ - public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + public static KNNVectorScriptDocValues create(Object values, String fieldName, VectorDataType vectorDataType) { Objects.requireNonNull(values, "values must not be null"); - if (values instanceof ByteVectorValues) { - return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); - } else if (values instanceof FloatVectorValues) { + + if (values instanceof FloatVectorValues) { return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); } else if (values instanceof BinaryDocValues) { return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); } else { @@ -96,34 +100,53 @@ public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fi private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { private final ByteVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { - super(values, field, type); + super(values.iterator(), field, type); this.values = values; + this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator + ? (KnnVectorValues.DocIndexIterator) super.vectorValues + : values.iterator(); } @Override protected float[] doGetValue() throws IOException { - byte[] bytes = values.vectorValue(); + int docId = this.iterator.index(); + if (docId == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { + throw new IllegalStateException("No more ordinals to retrieve vector values."); + } + + // Use the correct method to retrieve the byte vector for the current ordinal + byte[] bytes = values.vectorValue(docId); float[] value = new float[bytes.length]; for (int i = 0; i < bytes.length; i++) { value[i] = (float) bytes[i]; } return value; } + } private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { private final FloatVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { - super(values, field, type); + super(values.iterator(), field, type); this.values = values; + this.iterator = super.vectorValues instanceof KnnVectorValues.DocIndexIterator + ? (KnnVectorValues.DocIndexIterator) super.vectorValues + : values.iterator(); } @Override protected float[] doGetValue() throws IOException { - return values.vectorValue(); + int ord = iterator.index(); // Fetch ordinal (index of vector) + if (ord == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { + throw new IllegalStateException("No more ordinals to retrieve vector values."); + } + return values.vectorValue(ord); } } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index e97bd2dbf..050714f4e 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -9,7 +9,7 @@ import lombok.Getter; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -21,6 +21,7 @@ import org.opensearch.knn.training.FloatTrainingDataConsumer; import org.opensearch.knn.training.TrainingDataConsumer; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Locale; import java.util.Objects; @@ -47,14 +48,25 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc @Override public float[] getVectorFromBytesRef(BytesRef binaryValue) { - float[] vector = new float[binaryValue.length]; - int i = 0; - int j = binaryValue.offset; - - while (i < binaryValue.length) { - vector[i++] = binaryValue.bytes[j++]; + if (binaryValue.length % Float.BYTES == 0) { + // ✅ Case 1: Stored as encoded floats (each float takes 4 bytes) + int numFloats = binaryValue.length / Float.BYTES; + float[] vector = new float[numFloats]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(binaryValue.bytes, binaryValue.offset, binaryValue.length); + for (int i = 0; i < numFloats; i++) { + vector[i] = byteBuffer.getFloat(); // Read as float + } + return vector; + } else { + // ✅ Case 2: Stored as raw bytes (each byte is interpreted as a float) + float[] vector = new float[binaryValue.length]; + int i = 0, j = binaryValue.offset; + while (i < binaryValue.length) { + vector[i++] = binaryValue.bytes[j++]; // Direct conversion from byte to float + } + return vector; } - return vector; } @Override @@ -100,7 +112,7 @@ public void freeNativeMemory(long memoryAddress) { @Override public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) { - return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); + return KnnFloatVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java new file mode 100644 index 000000000..ed607b7c5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010Codec.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import lombok.Builder; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.KNNFormatFacade; + +/** + * KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1 + */ + +public class KNN10010Codec extends FilterCodec { + + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_1_0; + private final KNNFormatFacade knnFormatFacade; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + + /** + * No arg constructor that uses Lucene99 as the delegate + */ + public KNN10010Codec() { + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + } + + /** + * Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec + * and a unique name to this ctor. + * + * @param delegate codec that will perform all operations this codec does not override + * @param knnVectorsFormat per field format for KnnVector + */ + @Builder + protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); + perFieldKnnVectorsFormat = knnVectorsFormat; + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index 24dbfb78b..2d0ee349a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -40,8 +40,8 @@ public KNN80CompoundFormat(CompoundFormat delegate) { } @Override - public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { - return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si, context), dir); + public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si) throws IOException { + return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si), dir); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java index 23c9f3105..1bafa2cfa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java @@ -13,17 +13,10 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.*; import java.io.IOException; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.index.SortedSetDocValues; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -66,6 +59,16 @@ public SortedSetDocValues getSortedSet(FieldInfo field) throws IOException { return delegate.getSortedSet(field); } + /** + * @param fieldInfo + * @return + * @throws IOException + */ + @Override + public DocValuesSkipper getSkipper(FieldInfo fieldInfo) throws IOException { + return delegate.getSkipper(fieldInfo); + } + @Override public void checkIntegrity() throws IOException { delegate.checkIntegrity(); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java index 2b3723439..16fd2ad43 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java @@ -8,7 +8,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.opensearch.knn.index.KNNVectorSimilarityFunction; @@ -22,10 +23,10 @@ public class KNN9120BinaryVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues + KnnVectorValues randomAccessVectorValues ) throws IOException { - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { - return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues); + if (randomAccessVectorValues instanceof ByteVectorValues) { + return new BinaryRandomVectorScorerSupplier((ByteVectorValues) randomAccessVectorValues); } throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); } @@ -33,7 +34,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, + KnnVectorValues randomAccessVectorValues, float[] queryVector ) throws IOException { throw new IllegalArgumentException("binary vectors do not support float[] targets"); @@ -42,20 +43,20 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, + KnnVectorValues randomAccessVectorValues, byte[] queryVector ) throws IOException { - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { - return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector); + if (randomAccessVectorValues instanceof ByteVectorValues) { + return new BinaryRandomVectorScorer((ByteVectorValues) randomAccessVectorValues, queryVector); } throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); } static class BinaryRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues.Bytes vectorValues; + private final ByteVectorValues vectorValues; private final byte[] queryVector; - BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + BinaryRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { this.queryVector = query; this.vectorValues = vectorValues; } @@ -82,11 +83,11 @@ public Bits getAcceptOrds(Bits acceptDocs) { } static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - protected final RandomAccessVectorValues.Bytes vectorValues; - protected final RandomAccessVectorValues.Bytes vectorValues1; - protected final RandomAccessVectorValues.Bytes vectorValues2; + protected final ByteVectorValues vectorValues; + protected final ByteVectorValues vectorValues1; + protected final ByteVectorValues vectorValues2; - public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException { + public BinaryRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; this.vectorValues1 = vectorValues.copy(); this.vectorValues2 = vectorValues.copy(); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index d9b73d621..c6b6c6268 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -54,7 +54,7 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); - try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { + try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.DEFAULT)) { CodecUtil.retrieveChecksum(input); int numFields = getNumFields(input); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index dd326123e..1095ee026 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -71,6 +71,15 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce return new NativeEngines990KnnVectorsReader(state, flatVectorsFormat.fieldsReader(state)); } + /** + * @param s + * @return + */ + @Override + public int getMaxDimensions(String s) { + return Integer.MAX_VALUE; + } + @Override public String toString() { return "NativeEngines99KnnVectorsFormat(name=" diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index efabc3a70..2366a6d57 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -199,14 +199,6 @@ public void close() throws IOException { } } - /** - * Return the memory usage of this object in bytes. Negative values are illegal. - */ - @Override - public long ramBytesUsed() { - return flatVectorsReader.ramBytesUsed(); - } - private void loadCacheKeyMap() { quantizationStateCacheKeyPerField = new HashMap<>(); for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 4343c845b..23f44ce83 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -13,9 +13,11 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; -import org.apache.lucene.codecs.lucene912.Lucene912Codec; +import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; @@ -128,9 +130,23 @@ public enum KNNCodecVersion { .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), KNN9120Codec::new + ), + V_10_1_0( + "KNN10010Codec", + new Lucene101Codec(), + new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN10010Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .build(), + KNN10010Codec::new ); - private static final KNNCodecVersion CURRENT = V_9_12_0; + private static final KNNCodecVersion CURRENT = V_10_1_0; private final String codecName; private final Codec defaultCodecDelegate; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 461c6f7c8..b0bead693 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -8,7 +8,7 @@ import lombok.Getter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; import org.opensearch.index.fielddata.IndexFieldData; @@ -81,7 +81,7 @@ public String typeName() { @Override public Query existsQuery(QueryShardContext context) { - return new DocValuesFieldExistsQuery(name()); + return new FieldExistsQuery(name()); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 99962d307..26b26ec11 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -28,7 +28,7 @@ public class KNNScorer extends Scorer { private final float boost; public KNNScorer(Weight weight, DocIdSetIterator docIdsIter, Map scores, float boost) { - super(weight); + super(); this.docIdsIter = docIdsIter; this.scores = scores; this.boost = boost; @@ -60,40 +60,44 @@ public int docID() { /** * Returns the Empty Scorer implementation. We use this scorer to short circuit the actual search when it is not * required. - * @param knnWeight {@link KNNWeight} * @return {@link KNNScorer} */ - public static Scorer emptyScorer(KNNWeight knnWeight) { - return new Scorer(knnWeight) { - private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); - - @Override - public DocIdSetIterator iterator() { - return docIdsIter; - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return 0; - } - - @Override - public float score() throws IOException { - assert docID() != DocIdSetIterator.NO_MORE_DOCS; - return 0; - } - - @Override - public int docID() { - return docIdsIter.docID(); - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof Scorer)) return false; - return getWeight().equals(((Scorer) obj).getWeight()); - } - }; - + public static Scorer emptyScorer() { + return EMPTY_SCORER_INSTANCE; } + + private static final Scorer EMPTY_SCORER_INSTANCE = new Scorer() { + private final DocIdSetIterator docIdsIter = DocIdSetIterator.empty(); + + @Override + public DocIdSetIterator iterator() { + return docIdsIter; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return 0; + } + + @Override + public float score() throws IOException { + assert docID() != DocIdSetIterator.NO_MORE_DOCS; + return 0; + } + + @Override + public int docID() { + return docIdsIter.docID(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; // Singleton ensures only one instance exists + } + + @Override + public int hashCode() { + return System.identityHashCode(this); // Consistent hash for singleton + } + }; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad..16d58df71 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilteredDocIdSetIterator; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; @@ -110,13 +111,24 @@ public Explanation explain(LeafReaderContext context, int doc) { } @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult(); - if (docIdToScoreMap.isEmpty()) { - return KNNScorer.emptyScorer(this); - } - final int maxDoc = Collections.max(docIdToScoreMap.keySet()) + 1; - return new KNNScorer(this, ResultUtil.resultMapToDocIds(docIdToScoreMap, maxDoc), docIdToScoreMap, boost); + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult(); + if (docIdToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(); + } + final int maxDoc = Collections.max(docIdToScoreMap.keySet()) + 1; + return new KNNScorer(KNNWeight.this, ResultUtil.resultMapToDocIds(docIdToScoreMap, maxDoc), docIdToScoreMap, boost); + } + + @Override + public long cost() { + // Estimate the cost of the scoring operation, if applicable. + return DocIdSetIterator.NO_MORE_DOCS; + } + }; } /** diff --git a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index f38cc96c6..1ca4424c2 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -13,7 +13,9 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import org.opensearch.knn.index.query.KNNScorer; import java.io.IOException; import java.util.Arrays; @@ -21,9 +23,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -/** - * This is the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery} - */ final class DocAndScoreQuery extends Query { private final int k; @@ -62,92 +61,104 @@ public int count(LeafReaderContext context) { } @Override - public Scorer scorer(LeafReaderContext context) { - if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { - return null; - } - return new Scorer(this) { - final int lower = segmentStarts[context.ord]; - final int upper = segmentStarts[context.ord + 1]; - int upTo = -1; - + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { @Override - public DocIdSetIterator iterator() { - return new DocIdSetIterator() { + public Scorer get(long leadCost) throws IOException { + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return KNNScorer.emptyScorer(); + } + return new Scorer() { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + @Override - public int docID() { - return docIdNoShadow(); + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return docIdNoShadow(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return docIdNoShadow(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; } @Override - public int nextDoc() { - if (upTo == -1) { - upTo = lower; - } else { - ++upTo; + public float getMaxScore(int docId) { + docId += context.docBase; + float maxScore = 0; + for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { + maxScore = Math.max(maxScore, scores[idx]); } - return docIdNoShadow(); + return maxScore * boost; } @Override - public int advance(int target) throws IOException { - return slowAdvance(target); + public float score() { + return scores[upTo] * boost; } @Override - public long cost() { - return upper - lower; + public int advanceShallow(int docid) { + int start = Math.max(upTo, lower); + int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); + if (docidIndex < 0) { + docidIndex = -1 - docidIndex; + } + if (docidIndex >= upper) { + return NO_MORE_DOCS; + } + return docs[docidIndex]; } - }; - } - - @Override - public float getMaxScore(int docId) { - docId += context.docBase; - float maxScore = 0; - for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { - maxScore = Math.max(maxScore, scores[idx]); - } - return maxScore * boost; - } - @Override - public float score() { - return scores[upTo] * boost; - } + /** + * move the implementation of docID() into a differently-named method so we can call it + * from DocIDSetIterator.docID() even though this class is anonymous + * + * @return the current docid + */ + private int docIdNoShadow() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo] - context.docBase; + } - @Override - public int advanceShallow(int docid) { - int start = Math.max(upTo, lower); - int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); - if (docidIndex < 0) { - docidIndex = -1 - docidIndex; - } - if (docidIndex >= upper) { - return NO_MORE_DOCS; - } - return docs[docidIndex]; - } + @Override + public int docID() { + return docIdNoShadow(); + } + }; - /** - * move the implementation of docID() into a differently-named method so we can call it - * from DocIDSetIterator.docID() even though this class is anonymous - * - * @return the current docid - */ - private int docIdNoShadow() { - if (upTo == -1) { - return -1; - } - if (upTo >= upper) { - return NO_MORE_DOCS; - } - return docs[upTo] - context.docBase; } @Override - public int docID() { - return docIdNoShadow(); + public long cost() { + // Estimate the cost of the scoring operation, if applicable. + return DocIdSetIterator.NO_MORE_DOCS; } }; } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index 5da093fd5..b113509cf 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -9,6 +9,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -34,7 +35,7 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } return vector; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index 1ebc50970..374adea20 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -9,6 +9,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -34,7 +35,7 @@ public byte[] getVector() throws IOException { @Override public byte[] conditionalCloneVector() throws IOException { byte[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index dffdd8f0d..ad9f32b77 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -8,6 +8,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import java.io.IOException; import java.util.Arrays; @@ -32,7 +33,7 @@ public float[] getVector() throws IOException { @Override public float[] conditionalCloneVector() throws IOException { float[] vector = getVector(); - if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { + if (vectorValuesIterator.getDocIdSetIterator() instanceof KnnVectorValues.DocIndexIterator) { return Arrays.copyOf(vector, vector.length); } return vector; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e217..79f64c50a 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,11 +5,7 @@ package org.opensearch.knn.index.vectorvalues; -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.DocsWithFieldSet; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.VectorDataType; @@ -26,11 +22,15 @@ public final class KNNVectorValuesFactory { * Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and {@link VectorDataType} * * @param vectorDataType {@link VectorDataType} - * @param docIdSetIterator {@link DocIdSetIterator} + * @param knnVectorValues {@link KnnVectorValues} * @return {@link KNNVectorValues} */ + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final KnnVectorValues knnVectorValues) { + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(knnVectorValues)); + } + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) { - return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator)); + return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator, null)); } /** @@ -57,19 +57,24 @@ public static KNNVectorValues getVectorValues( */ public static KNNVectorValues getVectorValues(final FieldInfo fieldInfo, final LeafReader leafReader) throws IOException { final DocIdSetIterator docIdSetIterator; - if (fieldInfo.hasVectorValues()) { - if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { - docIdSetIterator = leafReader.getByteVectorValues(fieldInfo.getName()); - } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { - docIdSetIterator = leafReader.getFloatVectorValues(fieldInfo.getName()); - } else { - throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); - } - } else { + if (!fieldInfo.hasVectorValues()) { docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName()); + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator, null); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(leafReader.getByteVectorValues(fieldInfo.getName())) + ); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + return getVectorValues( + FieldInfoExtractor.extractVectorDataType(fieldInfo), + new KNNVectorValuesIterator.DocIdsIteratorValues(leafReader.getFloatVectorValues(fieldInfo.getName())) + ); + } else { + throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); } - final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); - return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); } @SuppressWarnings("unchecked") diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java index 4f1445c1c..bffe5d96f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesIterator.java @@ -5,19 +5,19 @@ package org.opensearch.knn.index.vectorvalues; +import lombok.Getter; import lombok.NonNull; +import lombok.Setter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.KnnVectorValues; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import java.io.IOException; -import java.util.List; import java.util.Map; -import java.util.function.Function; /** * An abstract class that provides an iterator to iterate over KNNVectors, as KNNVectors are stored as different @@ -71,16 +71,27 @@ public interface KNNVectorValuesIterator { * {@link DocIdSetIterator} interface. Example: {@link BinaryDocValues}, {@link FloatVectorValues} etc. */ class DocIdsIteratorValues implements KNNVectorValuesIterator { - protected DocIdSetIterator docIdSetIterator; - private static final List> VALID_ITERATOR_INSTANCE = List.of( - (itr) -> itr instanceof BinaryDocValues, - (itr) -> itr instanceof FloatVectorValues, - (itr) -> itr instanceof ByteVectorValues - ); - - DocIdsIteratorValues(@NonNull final DocIdSetIterator docIdSetIterator) { - validateIteratorType(docIdSetIterator); + private final DocIdSetIterator docIdSetIterator; + private final KnnVectorValues knnVectorValues; // Added reference to KnnVectorValues + @Getter + @Setter + private int lastOrd = -1; + @Getter + @Setter + private Object lastAccessedVector = null; + + DocIdsIteratorValues(@NonNull final KnnVectorValues knnVectorValues) { + this.docIdSetIterator = knnVectorValues.iterator(); + this.knnVectorValues = knnVectorValues; + } + + DocIdsIteratorValues(final DocIdSetIterator docIdSetIterator, final KnnVectorValues knnVectorValues) { this.docIdSetIterator = docIdSetIterator; + this.knnVectorValues = knnVectorValues; + } + + public KnnVectorValues getKnnVectorValues() { + return knnVectorValues; } @Override @@ -107,7 +118,7 @@ public DocIdSetIterator getDocIdSetIterator() { public long liveDocs() { if (docIdSetIterator instanceof BinaryDocValues) { return KNNCodecUtil.getTotalLiveDocsCount((BinaryDocValues) docIdSetIterator); - } else if (docIdSetIterator instanceof FloatVectorValues || docIdSetIterator instanceof ByteVectorValues) { + } else if (docIdSetIterator instanceof KnnVectorValues.DocIndexIterator) { return docIdSetIterator.cost(); } throw new IllegalArgumentException( @@ -119,18 +130,6 @@ public long liveDocs() { public VectorValueExtractorStrategy getVectorExtractorStrategy() { return new VectorValueExtractorStrategy.DISIVectorExtractor(); } - - private void validateIteratorType(final DocIdSetIterator docIdSetIterator) { - VALID_ITERATOR_INSTANCE.stream() - .map(v -> v.apply(docIdSetIterator)) - .filter(Boolean::booleanValue) - .findFirst() - .orElseThrow( - () -> new IllegalArgumentException( - "DocIdSetIterator present is not of valid type. Valid types are: BinaryDocValues, FloatVectorValues and ByteVectorValues" - ) - ); - } } /** diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java index 07db4e7f6..63364dcdf 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/VectorValueExtractorStrategy.java @@ -8,6 +8,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; @@ -74,8 +75,18 @@ public T extract(final VectorDataType vectorDataType, final KNNVectorValuesI if (docIdSetIterator instanceof BinaryDocValues) { final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; return (T) getFloatVectorFromByteRef(values.binaryValue()); - } else if (docIdSetIterator instanceof FloatVectorValues) { - return (T) ((FloatVectorValues) docIdSetIterator).vectorValue(); + } else if (docIdSetIterator instanceof KnnVectorValues.DocIndexIterator) { + KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues = + (KNNVectorValuesIterator.DocIdsIteratorValues) vectorValuesIterator; + FloatVectorValues knnVectorValues = (FloatVectorValues) docIdsIteratorValues.getKnnVectorValues(); + int ord = ((KnnVectorValues.DocIndexIterator) docIdSetIterator).index(); + // **Check if ord is the same as lastOrd** - return cached vector + if (ord == docIdsIteratorValues.getLastOrd()) { + return (T) docIdsIteratorValues.getLastAccessedVector(); + } + docIdsIteratorValues.setLastOrd(ord); + docIdsIteratorValues.setLastAccessedVector(knnVectorValues.vectorValue(ord)); + return (T) docIdsIteratorValues.getLastAccessedVector(); } throw new IllegalArgumentException( "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and FloatVectorValues" @@ -86,8 +97,17 @@ public T extract(final VectorDataType vectorDataType, final KNNVectorValuesI final BinaryDocValues values = (BinaryDocValues) docIdSetIterator; final BytesRef bytesRef = values.binaryValue(); return (T) ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length); - } else if (docIdSetIterator instanceof ByteVectorValues) { - return (T) ((ByteVectorValues) docIdSetIterator).vectorValue(); + } else if (docIdSetIterator instanceof KnnVectorValues.DocIndexIterator) { + KNNVectorValuesIterator.DocIdsIteratorValues docIdsIteratorValues = + (KNNVectorValuesIterator.DocIdsIteratorValues) vectorValuesIterator; + ByteVectorValues byteVectorValues = (ByteVectorValues) docIdsIteratorValues.getKnnVectorValues(); + int ord = ((KnnVectorValues.DocIndexIterator) docIdSetIterator).index(); + if (ord == docIdsIteratorValues.getLastOrd()) { + return (T) docIdsIteratorValues.getLastAccessedVector(); + } + docIdsIteratorValues.setLastOrd(ord); + docIdsIteratorValues.setLastAccessedVector(byteVectorValues.vectorValue(ord)); + return (T) docIdsIteratorValues.getLastAccessedVector(); } throw new IllegalArgumentException( "VectorValuesIterator is not of a valid type. Valid Types are: BinaryDocValues and ByteVectorValues" diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index d0abe8612..387a23587 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -35,7 +35,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.cluster.health.ClusterHealthStatus; import org.opensearch.cluster.health.ClusterIndexHealth; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 78f3769c5..4ad6227d5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -127,7 +127,7 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques searchSourceBuilder.terminateAfter(DEFAULT_TERMINATE_AFTER); client.search(countRequest, ActionListener.wrap(searchResponse -> { - long trainingVectors = searchResponse.getHits().getTotalHits().value; + long trainingVectors = searchResponse.getHits().getTotalHits().value(); // If there are more docs in the index than what the user wants to use for training, take the min if (trainingModelRequest.getMaximumVectorCount() < trainingVectors) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java index 216efa78e..d374b4610 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java @@ -6,7 +6,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.core.common.io.stream.Writeable; /** diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java index 887f5d7a2..2cbda7b2e 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java @@ -7,7 +7,7 @@ import lombok.Getter; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.action.support.clustermanager.AcknowledgedRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java index 7d5750c2b..d9a26e1e0 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportAction.java @@ -10,7 +10,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.action.support.clustermanager.TransportClusterManagerNodeAction; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateTaskConfig; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java index 756a32575..dbc2d6c7f 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataAction.java @@ -12,7 +12,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.core.common.io.stream.Writeable; /** diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java index af063ad27..56aac2510 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java @@ -12,7 +12,7 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.action.support.clustermanager.AcknowledgedRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.indices.ModelMetadata; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java index ec909f443..01e4cbf36 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java @@ -15,7 +15,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.action.support.clustermanager.TransportClusterManagerNodeAction; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateTaskConfig; diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 7a8916981..e0ed615f7 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -8,4 +8,5 @@ org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec org.opensearch.knn.index.codec.KNN950Codec.KNN950Codec org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec +org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index cbe11dd6b..15a6bceaa 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -5,11 +5,11 @@ package org.opensearch.knn.index; +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; @@ -20,6 +20,7 @@ import org.junit.Before; import java.io.IOException; +import java.nio.ByteBuffer; public class KNNVectorDVLeafFieldDataTests extends KNNTestCase { @@ -42,12 +43,8 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, new float[] { 1.0f, 2.0f }, new FieldType()).binaryValue() - ) - ); + byte[] vectorBinary = encodeVector(new float[] { 1.0f, 2.0f }); + knnDocument.add(new BinaryDocValuesField(MOCK_INDEX_FIELD_NAME, new BytesRef(vectorBinary))); knnDocument.add(new NumericDocValuesField(MOCK_NUMERIC_INDEX_FIELD_NAME, 1000)); writer.addDocument(knnDocument); writer.commit(); @@ -96,4 +93,12 @@ public void testGetBytesValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } + + private byte[] encodeVector(float[] vector) { + ByteBuffer byteBuffer = ByteBuffer.allocate(vector.length * Float.BYTES); + for (float value : vector) { + byteBuffer.putFloat(value); + } + return byteBuffer.array(); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 66e2893c0..dcbfb4286 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -8,32 +8,27 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.*; +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.store.Directory; import org.junit.Assert; import org.junit.Before; +import org.junit.After; +import org.junit.Test; import java.io.IOException; +import java.nio.ByteBuffer; public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; - private KNNVectorScriptDocValues scriptDocValues; + private Directory directory; private DirectoryReader reader; @@ -41,71 +36,116 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); - createKNNVectorDocument(directory, valuesClass); - reader = DirectoryReader.open(directory); - LeafReader leafReader = reader.getContext().leaves().get(0).reader(); - DocIdSetIterator vectorValues; - if (BinaryDocValues.class.equals(valuesClass)) { - vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); - } else if (ByteVectorValues.class.equals(valuesClass)) { - vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); - } else { - vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); - } - - scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { - IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); - IndexWriter writer = new IndexWriter(directory, conf); - Document knnDocument = new Document(); - Field field; - if (BinaryDocValues.class.equals(valuesClass)) { - field = new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ); - } else if (ByteVectorValues.class.equals(valuesClass)) { - field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); - } else { - field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + @After + public void tearDown() throws Exception { + super.tearDown(); + if (reader != null) { + reader.close(); + } + if (directory != null) { + directory.close(); } + } - knnDocument.add(field); - writer.addDocument(knnDocument); - writer.commit(); - writer.close(); + /** Test for Float Vector Values */ + @Test + public void testFloatVectorValues() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + // Separate scriptDocValues instance for this test + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); } - @Override - public void tearDown() throws Exception { - super.tearDown(); - reader.close(); - directory.close(); + /** Test for Byte Vector Values */ + @Test + public void testByteVectorValues() throws IOException { + createKNNVectorDocument(directory, ByteVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.BYTE + ); + + scriptDocValues.setNextDocId(0); + Assert.assertArrayEquals(new float[] { SAMPLE_BYTE_VECTOR_DATA[0], SAMPLE_BYTE_VECTOR_DATA[1] }, scriptDocValues.getValue(), 0.1f); } - public void testGetValue() throws IOException { + /** Test for Binary Vector Values */ + @Test + public void testBinaryVectorValues() throws IOException { + createKNNVectorDocument(directory, BinaryDocValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.BINARY + ); + scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + Assert.assertNotNull(scriptDocValues.getValue()); // Just checking it's non-null } - // Test getValue without calling setNextDocId + /** Ensure getValue() fails without setNextDocId */ + @Test public void testGetValueFails() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + expectThrows(IllegalStateException.class, () -> scriptDocValues.getValue()); } + /** Ensure size() returns expected values */ + @Test public void testSize() throws IOException { + createKNNVectorDocument(directory, FloatVectorValues.class); + reader = DirectoryReader.open(directory); + LeafReader leafReader = reader.leaves().get(0).reader(); + + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); + Assert.assertEquals(0, scriptDocValues.size()); scriptDocValues.setNextDocId(0); Assert.assertEquals(1, scriptDocValues.size()); } - public void testGet() throws IOException { - expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + /** Ensure get() throws UnsupportedOperationException */ + @Test + public void testGet() { + expectThrows(UnsupportedOperationException.class, () -> { + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + scriptDocValues.get(0); + }); } + /** Test unsupported values type */ + @Test public void testUnsupportedValues() throws IOException { expectThrows( IllegalArgumentException.class, @@ -113,10 +153,39 @@ public void testUnsupportedValues() throws IOException { ); } + /** Ensure empty values case */ + @Test public void testEmptyValues() throws IOException { KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); assertEquals(0, values.size()); - scriptDocValues.setNextDocId(0); - assertEquals(0, values.size()); + } + + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + Document knnDocument = new Document(); + Field field; + + if (BinaryDocValues.class.equals(valuesClass)) { + byte[] vectorBinary = encodeVector(SAMPLE_VECTOR_DATA); + field = new BinaryDocValuesField(MOCK_INDEX_FIELD_NAME, new BytesRef(vectorBinary)); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } + + private byte[] encodeVector(float[] vector) { + ByteBuffer byteBuffer = ByteBuffer.allocate(vector.length * Float.BYTES); + for (float value : vector) { + byteBuffer.putFloat(value); + } + return byteBuffer.array(); } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 73af608c1..99b1b86c8 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -8,7 +8,6 @@ import lombok.SneakyThrows; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -20,6 +19,7 @@ import org.opensearch.knn.KNNTestCase; import java.io.IOException; +import java.nio.ByteBuffer; public class VectorDataTypeTests extends KNNTestCase { @@ -82,12 +82,7 @@ private void createKNNFloatVectorDocument(Directory directory) throws IOExceptio IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_FLOAT_INDEX_FIELD_NAME, - new VectorField(MOCK_FLOAT_INDEX_FIELD_NAME, SAMPLE_FLOAT_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + knnDocument.add(new BinaryDocValuesField(MOCK_FLOAT_INDEX_FIELD_NAME, new BytesRef(encodeVector(SAMPLE_FLOAT_VECTOR_DATA)))); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -97,12 +92,7 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - MOCK_BYTE_INDEX_FIELD_NAME, - new VectorField(MOCK_BYTE_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + knnDocument.add(new BinaryDocValuesField(MOCK_BYTE_INDEX_FIELD_NAME, new BytesRef(SAMPLE_BYTE_VECTOR_DATA))); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -114,4 +104,12 @@ public void testGetVectorFromBytesRef_whenBinary_thenException() { BytesRef bytesRef = new BytesRef(vector); assertArrayEquals(expected, VectorDataType.BINARY.getVectorFromBytesRef(bytesRef), 0.01f); } + + private byte[] encodeVector(float[] vector) { + ByteBuffer byteBuffer = ByteBuffer.allocate(vector.length * Float.BYTES); + for (float value : vector) { + byteBuffer.putFloat(value); + } + return byteBuffer.array(); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java index 6001a9729..b9aca7620 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java @@ -47,9 +47,9 @@ public static void closeStaticVariables() throws IOException { public void testGetCompoundReader() throws IOException { CompoundDirectory dir = mock(CompoundDirectory.class); CompoundFormat delegate = mock(CompoundFormat.class); - when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); + when(delegate.getCompoundReader(null, null)).thenReturn(dir); KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); - CompoundDirectory knnDir = knn80CompoundFormat.getCompoundReader(null, null, null); + CompoundDirectory knnDir = knn80CompoundFormat.getCompoundReader(null, null); assertTrue(knnDir instanceof KNN80CompoundDirectory); assertEquals(dir, ((KNN80CompoundDirectory) knnDir).getDelegate()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 90ed18d0d..86ed6b3ae 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -24,21 +24,7 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.NoMergePolicy; -import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.SerialMergeScheduler; -import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.*; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Sort; import org.apache.lucene.store.Directory; @@ -236,20 +222,24 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc } final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD); - floatVectorValues.nextDoc(); - assertArrayEquals(floatVector, floatVectorValues.vectorValue(), 0.0f); + floatVectorValues.iterator().nextDoc(); + assertArrayEquals(floatVector, floatVectorValues.vectorValue(floatVectorValues.iterator().index()), 0.0f); assertEquals(1, floatVectorValues.size()); assertEquals(3, floatVectorValues.dimension()); final ByteVectorValues byteVectorValues = leafReader.getByteVectorValues(BYTE_VECTOR_FIELD); - byteVectorValues.nextDoc(); - assertArrayEquals(byteVector, byteVectorValues.vectorValue()); + byteVectorValues.iterator().nextDoc(); + assertArrayEquals(byteVector, byteVectorValues.vectorValue(byteVectorValues.iterator().index())); assertEquals(1, byteVectorValues.size()); assertEquals(2, byteVectorValues.dimension()); final FloatVectorValues floatVectorValuesForBinaryQuantization = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); - floatVectorValuesForBinaryQuantization.nextDoc(); - assertArrayEquals(floatVectorForBinaryQuantization_1, floatVectorValuesForBinaryQuantization.vectorValue(), 0.0f); + floatVectorValuesForBinaryQuantization.iterator().nextDoc(); + assertArrayEquals( + floatVectorForBinaryQuantization_1, + floatVectorValuesForBinaryQuantization.vectorValue(floatVectorValuesForBinaryQuantization.iterator().index()), + 0.0f + ); assertEquals(2, floatVectorValuesForBinaryQuantization.size()); assertEquals(8, floatVectorValuesForBinaryQuantization.dimension()); @@ -296,8 +286,9 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce } final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); - floatVectorValues.nextDoc(); - assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(), 0.0f); + KnnVectorValues.DocIndexIterator docIndexIterator = floatVectorValues.iterator(); + docIndexIterator.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(docIndexIterator.index()), 0.0f); assertEquals(1, floatVectorValues.size()); assertEquals(8, floatVectorValues.dimension()); indexReader.close(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 315693a65..18c9e9667 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -9,7 +9,6 @@ import com.google.common.collect.ImmutableSet; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; @@ -411,9 +410,9 @@ public void testKnnVectorIndex( iwc1.setMergeScheduler(new SerialMergeScheduler()); iwc1.setCodec(codec); writer = new RandomIndexWriter(random(), dir, iwc1); - final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); + final FieldType luceneFieldType1 = KnnFloatVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); float[] array1 = { 6.0f, 14.0f }; - KnnVectorField vectorField1 = new KnnVectorField(FIELD_NAME_TWO, array1, luceneFieldType1); + KnnFloatVectorField vectorField1 = new KnnFloatVectorField(FIELD_NAME_TWO, array1, luceneFieldType1); Document doc1 = new Document(); doc1.add(vectorField1); writer.addDocument(doc1); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index d6f22ca7f..64c5371db 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -10,13 +10,7 @@ import lombok.Builder; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.index.*; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; @@ -172,6 +166,8 @@ public FieldInfo build() { storePayloads, indexOptions, docValuesType, + DocValuesSkipIndexType.NONE, + dvGen, attributes, pointDimensionCount, @@ -191,7 +187,7 @@ public static void assertFileInCorrectLocation(SegmentWriteState state, String e } public static void assertValidFooter(Directory dir, String filename) throws IOException { - ChecksumIndexInput indexInput = dir.openChecksumInput(filename, IOContext.DEFAULT); + ChecksumIndexInput indexInput = dir.openChecksumInput(filename); indexInput.seek(indexInput.length() - CodecUtil.footerLength()); CodecUtil.checkFooter(indexInput); indexInput.close(); @@ -205,7 +201,7 @@ public static void assertLoadableByEngine( SpaceType spaceType, int dimension ) { - try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long indexPtr = JNIService.loadIndex( indexInputWithBuffer, @@ -230,7 +226,7 @@ public static void assertBinaryIndexLoadableByEngine( int dimension, VectorDataType vectorDataType ) { - try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long indexPtr = JNIService.loadIndex( indexInputWithBuffer, diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 8011cc08c..7a1da8781 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -99,6 +99,8 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT; public class KNNWeightTests extends KNNTestCase { private static final String FIELD_NAME = "target_field"; @@ -146,6 +148,12 @@ public static void setUpClass() throws Exception { knnSettingsMockedStatic.when(KNNSettings::getCircuitBreakerLimit).thenReturn(v); knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); knnSettingsMockedStatic.when(KNNSettings::isKNNPluginEnabled).thenReturn(true); + ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue("1024kb", QUANTIZATION_STATE_CACHE_SIZE_LIMIT); // Setting 1MB as an + // example + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_SIZE_LIMIT))).thenReturn(cacheSize); + // Mock QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES setting + TimeValue mockTimeValue = TimeValue.timeValueMinutes(10); + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES))).thenReturn(mockTimeValue); nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); @@ -371,7 +379,7 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() // When no knn fields are available , field info for vector field will be null when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); + assertEquals(KNNScorer.emptyScorer(), knnScorer); } @SneakyThrows @@ -415,7 +423,7 @@ public void testEmptyQueryResults() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); + assertEquals(KNNScorer.emptyScorer(), knnScorer); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index a3b8c6989..d75a9a7bc 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -72,7 +72,7 @@ public void testResultMapToTopDocs() { } private void assertResultMapToTopDocs(Map perLeafResults, TopDocs topDocs, int k, int offset) { - assertEquals(k, topDocs.totalHits.value); + assertEquals(k, topDocs.totalHits.value()); float previousScore = Float.MAX_VALUE; for (ScoreDoc scoreDoc : topDocs.scoreDocs) { assertTrue(perLeafResults.containsKey(scoreDoc.doc - offset)); diff --git a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index b32496138..607699b56 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -6,8 +6,12 @@ package org.opensearch.knn.index.query.common; import lombok.SneakyThrows; +import org.apache.lucene.document.Document; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -15,9 +19,14 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.apache.lucene.store.ByteBuffersDirectory; + +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.mockito.Mock; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; + import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; @@ -27,9 +36,6 @@ public class DocAndScoreQueryTests extends OpenSearchTestCase { private LeafReaderContext leaf1; @Mock private IndexSearcher indexSearcher; - @Mock - private IndexReader reader; - @Mock private IndexReaderContext readerContext; private DocAndScoreQuery objectUnderTest; @@ -39,9 +45,9 @@ public void setUp() throws Exception { super.setUp(); openMocks(this); + IndexReader reader = createTestIndexReader(); when(indexSearcher.getIndexReader()).thenReturn(reader); - when(reader.getContext()).thenReturn(readerContext); - when(readerContext.id()).thenReturn(1); + readerContext = reader.getContext(); } // Note: cannot test with multi leaf as there LeafReaderContext is readonly with no getters for some fields to mock @@ -50,7 +56,7 @@ public void testScorer() throws Exception { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); // When Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); @@ -85,7 +91,7 @@ public void testWeight() { Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); // When - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); Explanation explanation = weight.explain(leaf1, 1); @@ -96,4 +102,13 @@ public void testWeight() { assertEquals(expectedExplanation, explanation); assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); } + + private IndexReader createTestIndexReader() throws IOException { + ByteBuffersDirectory directory = new ByteBuffersDirectory(); + IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(new MockAnalyzer(random()))); + writer.addDocument(new Document()); + writer.close(); + return DirectoryReader.open(directory); + } + } diff --git a/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java index 55a110f6a..ecda53e1b 100644 --- a/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java @@ -7,8 +7,9 @@ import junit.framework.TestCase; import lombok.SneakyThrows; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; @@ -17,6 +18,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; import org.apache.lucene.util.Bits; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -49,17 +52,38 @@ public void setUp() throws Exception { @SneakyThrows public void testCreateWeight_whenCalled_thenSucceed() { - LeafReaderContext leafReaderContext1 = mock(LeafReaderContext.class); - LeafReaderContext leafReaderContext2 = mock(LeafReaderContext.class); - List leafReaderContexts = Arrays.asList(leafReaderContext1, leafReaderContext2); + Directory directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + + IndexReader reader = DirectoryReader.open(directory); - IndexReader indexReader = mock(IndexReader.class); - when(indexReader.leaves()).thenReturn(leafReaderContexts); + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + LeafReaderContext leaf1 = leaves.get(0); + LeafReaderContext leaf2 = leaves.get(1); Weight filterWeight = mock(Weight.class); IndexSearcher indexSearcher = mock(IndexSearcher.class); - when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(indexSearcher.getIndexReader()).thenReturn(reader); when(indexSearcher.getTaskExecutor()).thenReturn(taskExecutor); when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE_NO_SCORES), eq(1.0F))).thenReturn(filterWeight); @@ -97,10 +121,10 @@ public void testCreateWeight_whenCalled_thenSucceed() { when(finalQuery.createWeight(indexSearcher, scoreMode, boost)).thenReturn(expectedWeight); QueryUtils queryUtils = mock(QueryUtils.class); - when(queryUtils.doSearch(indexSearcher, leafReaderContexts, queryWeight)).thenReturn(perLeafResults); + when(queryUtils.doSearch(indexSearcher, reader.leaves(), queryWeight)).thenReturn(perLeafResults); when(queryUtils.createBits(any(), any())).thenReturn(queryFilterBits); when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); - when(queryUtils.createDocAndScoreQuery(eq(indexReader), any())).thenReturn(finalQuery); + when(queryUtils.createDocAndScoreQuery(eq(reader), any())).thenReturn(finalQuery); // Run ExpandNestedDocsQuery query = new ExpandNestedDocsQuery(internalQuery, queryUtils); @@ -108,12 +132,12 @@ public void testCreateWeight_whenCalled_thenSucceed() { // Verify assertEquals(expectedWeight, finalWeigh); - verify(queryUtils).createBits(leafReaderContext1, filterWeight); - verify(queryUtils).createBits(leafReaderContext2, filterWeight); - verify(queryUtils).getAllSiblings(leafReaderContext1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); - verify(queryUtils).getAllSiblings(leafReaderContext2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).createBits(leaf1, filterWeight); + verify(queryUtils).createBits(leaf2, filterWeight); + verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); - verify(queryUtils).createDocAndScoreQuery(eq(indexReader), topDocsCaptor.capture()); + verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture()); TopDocs capturedTopDocs = topDocsCaptor.getValue(); assertEquals(topK.totalHits, capturedTopDocs.totalHits); for (int i = 0; i < topK.scoreDocs.length; i++) { diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 87c4a5014..82dc43bf3 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -6,10 +6,9 @@ package org.opensearch.knn.index.query.nativelib; import lombok.SneakyThrows; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexReaderContext; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; @@ -21,6 +20,9 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.util.Bits; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -38,6 +40,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -48,34 +51,26 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.openMocks; public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @Mock private IndexSearcher searcher; - @Mock private IndexReader reader; + private Directory directory; + private DirectoryReader directoryReader; @Mock private KNNQuery knnQuery; @Mock private KNNWeight knnWeight; @Mock private TaskExecutor taskExecutor; - @Mock private IndexReaderContext indexReaderContext; - @Mock private LeafReaderContext leaf1; - @Mock private LeafReaderContext leaf2; - @Mock private LeafReader leafReader1; - @Mock private LeafReader leafReader2; @Mock @@ -90,12 +85,10 @@ public void setUp() throws Exception { super.setUp(); openMocks(this); objectUnderTest = new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, false); - when(leaf1.reader()).thenReturn(leafReader1); - when(leaf2.reader()).thenReturn(leafReader2); - + reader = createTestIndexReader(); + indexReaderContext = reader.getContext(); when(searcher.getIndexReader()).thenReturn(reader); when(knnQuery.createWeight(searcher, scoreMode, 1)).thenReturn(knnWeight); - when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { List> callables = invocationOnMock.getArgument(0); @@ -105,11 +98,7 @@ public void setUp() throws Exception { } return results; }); - - when(reader.getContext()).thenReturn(indexReaderContext); - when(clusterService.state()).thenReturn(mock(ClusterState.class)); // Mock ClusterState - // Set ClusterService in KNNSettings KNNSettings.state().setClusterService(clusterService); when(knnQuery.getQueryVector()).thenReturn(new float[] { 1.0f, 2.0f, 3.0f }); // Example vector @@ -118,29 +107,71 @@ public void setUp() throws Exception { @SneakyThrows public void testMultiLeaf() { - // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + + // Initialize DirectoryReader and IndexSearcher + // Open the real DirectoryReader + DirectoryReader originalReader = DirectoryReader.open(directory); + // Define liveDocs for each segment + Bits liveDocs1 = new Bits() { + @Override + public boolean get(int index) { + return index != 1 && index != 2; // Document 1 and 2 are deleted + } + + @Override + public int length() { + return originalReader.leaves().get(0).reader().maxDoc(); + } + }; + + Bits liveDocs2 = null; // No deletions in the second segment + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); + // Given PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); PerLeafResult leaf2Result = new PerLeafResult(null, new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leaf1Result); when(knnWeight.searchLeaf(leaf2, 4)).thenReturn(leaf2Result); - - // Making sure there is deleted docs in one of the segments - Bits liveDocs = mock(Bits.class); - when(leafReader1.getLiveDocs()).thenReturn(liveDocs); - when(leafReader2.getLiveDocs()).thenReturn(null); - - when(liveDocs.get(anyInt())).thenReturn(true); - when(liveDocs.get(2)).thenReturn(false); - when(liveDocs.get(1)).thenReturn(false); + when(searcher.getIndexReader()).thenReturn(reader); // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves when(knnQuery.getK()).thenReturn(4); - when(indexReaderContext.id()).thenReturn(1); - Map leaf1ResultLive = Map.of(0, 1.2f); TopDocs[] topDocs = { ResultUtil.resultMapToTopDocs(leaf1ResultLive, leaf1.docBase), @@ -157,9 +188,48 @@ public void testMultiLeaf() { @SneakyThrows public void testRescoreWhenShardLevelRescoringEnabled() { - // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); int k = 2; PerLeafResult initialLeaf1Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f))); @@ -176,6 +246,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { when(knnWeight.searchLeaf(leaf2, firstPassK)).thenReturn(initialLeaf2Results); when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + when(searcher.getIndexReader()).thenReturn(reader); try ( MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); @@ -205,11 +276,10 @@ public void testSingleLeaf() { int k = 4; float boost = 1; PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); - List leaves = List.of(leaf1); - when(reader.leaves()).thenReturn(leaves); + List leaves = reader.leaves(); + leaf1 = leaves.get(0); when(knnWeight.searchLeaf(leaf1, k)).thenReturn(leaf1Result); when(knnQuery.getK()).thenReturn(k); - when(indexReaderContext.id()).thenReturn(1); TopDocs expectedTopDocs = ResultUtil.resultMapToTopDocs(leaf1Result.getResult(), leaf1.docBase); // When @@ -223,8 +293,8 @@ public void testSingleLeaf() { @SneakyThrows public void testNoMatch() { // Given - List leaves = List.of(leaf1); - when(reader.leaves()).thenReturn(leaves); + List leaves = reader.leaves(); + leaf1 = leaves.get(0); when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(PerLeafResult.EMPTY_RESULT); when(knnQuery.getK()).thenReturn(4); @@ -238,8 +308,47 @@ public void testNoMatch() { @SneakyThrows public void testRescore() { // Given - List leaves = List.of(leaf1, leaf2); - when(reader.leaves()).thenReturn(leaves); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); int k = 2; int firstPassK = 100; @@ -249,7 +358,6 @@ public void testRescore() { Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(Map.of(1, 20f), 0); TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(Map.of(0, 21f), 4); - when(indexReaderContext.id()).thenReturn(1); when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); when(knnQuery.getK()).thenReturn(k); when(knnWeight.getQuery()).thenReturn(knnQuery); @@ -258,6 +366,7 @@ public void testRescore() { when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); + when(searcher.getIndexReader()).thenReturn(reader); try ( MockedStatic mockedKnnSettings = mockStatic(KNNSettings.class); @@ -286,8 +395,47 @@ public void testRescore() { @SneakyThrows public void testExpandNestedDocs() { - List leafReaderContexts = Arrays.asList(leaf1, leaf2); - when(reader.leaves()).thenReturn(leafReaderContexts); + directory = new ByteBuffersDirectory(); + IndexWriterConfig config = new IndexWriterConfig(); + try (IndexWriter writer = new IndexWriter(directory, config)) { + // Add documents to simulate multiple segments + Document doc1 = new Document(); + doc1.add(new FloatPoint("vector", 1.0f, 2.0f, 3.0f)); + writer.addDocument(doc1); + Document doc2 = new Document(); + doc2.add(new FloatPoint("vector", 4.0f, 5.0f, 6.0f)); + writer.addDocument(doc2); + // Force the creation of a second segment + writer.flush(); + Document doc3 = new Document(); + doc3.add(new FloatPoint("vector", 7.0f, 8.0f, 9.0f)); + writer.addDocument(doc3); + Document doc4 = new Document(); + doc4.add(new FloatPoint("vector", 10.0f, 11.0f, 12.0f)); + writer.addDocument(doc4); + writer.commit(); + } + Bits liveDocs1 = null; + + Bits liveDocs2 = null; + + DirectoryReader originalReader = DirectoryReader.open(directory); + + // Wrap the DirectoryReader to inject custom liveDocs logic + directoryReader = CustomFilterDirectoryReader.wrap(originalReader, liveDocs1, liveDocs2); + + // Set the reader and searcher + reader = directoryReader; + ; + indexReaderContext = reader.getContext(); + // Extract LeafReaderContext + List leaves = reader.leaves(); + assertEquals(2, leaves.size()); // Ensure we have two segments + leaf1 = leaves.get(0); + leaf2 = leaves.get(1); + // Simulate liveDocs for leaf1 (e.g., marking some documents as deleted) + leafReader1 = leaf1.reader(); + leafReader2 = leaf2.reader(); Bits queryFilterBits = mock(Bits.class); PerLeafResult initialLeaf1Results = new PerLeafResult(queryFilterBits, new HashMap<>(Map.of(0, 19f, 1, 20f, 2, 17f, 3, 15f))); PerLeafResult initialLeaf2Results = new PerLeafResult(queryFilterBits, new HashMap<>(Map.of(0, 21f, 1, 18f, 2, 16f, 3, 14f))); @@ -296,9 +444,10 @@ public void testExpandNestedDocs() { Map exactSearchLeaf1Result = new HashMap<>(Map.of(1, 20f)); Map exactSearchLeaf2Result = new HashMap<>(Map.of(0, 21f)); - TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(exactSearchLeaf1Result, 0); - TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(exactSearchLeaf2Result, 0); + TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(exactSearchLeaf1Result, leaf1.docBase); + TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(exactSearchLeaf2Result, leaf2.docBase); TopDocs topK = TopDocs.merge(2, new TopDocs[] { topDocs1, topDocs2 }); + when(searcher.getIndexReader()).thenReturn(reader); int k = 2; when(knnQuery.getRescoreContext()).thenReturn(null); @@ -308,7 +457,8 @@ public void testExpandNestedDocs() { when(knnQuery.getParentsFilter()).thenReturn(parentFilter); when(knnWeight.searchLeaf(leaf1, k)).thenReturn(initialLeaf1Results); when(knnWeight.searchLeaf(leaf2, k)).thenReturn(initialLeaf2Results); - when(knnWeight.exactSearch(any(), any())).thenReturn(exactSearchLeaf1Result, exactSearchLeaf2Result); + when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(exactSearchLeaf1Result); + when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(exactSearchLeaf2Result); Weight filterWeight = mock(Weight.class); when(knnWeight.getFilterWeight()).thenReturn(filterWeight); @@ -350,4 +500,99 @@ public void testExpandNestedDocs() { assertEquals(2, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); assertEquals(DocIdSetIterator.NO_MORE_DOCS, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); } + + private IndexReader createTestIndexReader() throws IOException { + ByteBuffersDirectory directory = new ByteBuffersDirectory(); + IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(new MockAnalyzer(random()))); + writer.addDocument(new Document()); + writer.close(); + return DirectoryReader.open(directory); + } +} + +class CustomFilterDirectoryReader extends FilterDirectoryReader { + + private final Bits liveDocs1; + private final Bits liveDocs2; + + protected CustomFilterDirectoryReader(DirectoryReader in, Bits liveDocs1, Bits liveDocs2) throws IOException { + super(in, getWrapper(liveDocs1, liveDocs2)); + this.liveDocs1 = liveDocs1; + this.liveDocs2 = liveDocs2; + } + + @Override + protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException { + return new CustomFilterDirectoryReader(in, liveDocs1, liveDocs2); + } + + private static SubReaderWrapper getWrapper(Bits liveDocs1, Bits liveDocs2) { + return new SubReaderWrapper() { + @Override + public LeafReader wrap(LeafReader reader) { + if (reader.getContext().ord == 0) { // First segment + return new FilterLeafReader(reader) { + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } + + /** + * @return + */ + @Override + public CacheHelper getCoreCacheHelper() { + return null; + } + + @Override + public Bits getLiveDocs() { + return liveDocs1; + } + }; + } else if (reader.getContext().ord == 1) { // Second segment + return new FilterLeafReader(reader) { + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } + + /** + * @return + */ + @Override + public CacheHelper getCoreCacheHelper() { + return null; + } + + @Override + public Bits getLiveDocs() { + return liveDocs2; + } + }; + } else { + return reader; // Default case + } + } + }; + } + + // Remove the static modifier to fix the error + public static DirectoryReader wrap(DirectoryReader reader, Bits liveDocs1, Bits liveDocs2) throws IOException { + return new CustomFilterDirectoryReader(reader, liveDocs1, liveDocs2); + } + + /** + * @return + */ + @Override + public CacheHelper getReaderCacheHelper() { + return null; + } } diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java index 0b631ab41..99cb383e5 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java @@ -8,7 +8,6 @@ import lombok.SneakyThrows; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.search.DocIdSetIterator; -import org.junit.Assert; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; @@ -37,7 +36,6 @@ public void testFloatVectorValues_whenValidInput_thenSuccess() { vectorsMap ); new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, floatArray, 8, dimension, false); - final TestVectorValues.PredefinedFloatVectorBinaryDocValues preDefinedFloatVectorValues = new TestVectorValues.PredefinedFloatVectorBinaryDocValues(floatArray); final KNNVectorValues knnFloatVectorValuesBinaryDocValues = KNNVectorValuesFactory.getVectorValues( @@ -101,13 +99,6 @@ public void testBinaryVectorValues_whenValidInput_thenSuccess() { new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, 3, dimension, false); } - public void testDocIdsIteratorValues_whenInvalidDisi_thenThrowException() { - Assert.assertThrows( - IllegalArgumentException.class, - () -> new KNNVectorValuesIterator.DocIdsIteratorValues(new TestVectorValues.NotBinaryDocValues()) - ); - } - private DocsWithFieldSet getDocIdSetIterator(int numberOfDocIds) { final DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); for (int i = 0; i < numberOfDocIds; i++) { diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 0f15d5240..337ab6c48 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -5,14 +5,7 @@ package org.opensearch.knn.index.vectorvalues; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.SortedDocValues; -import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.*; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -116,6 +109,18 @@ public SortedSetDocValues getSortedSet(FieldInfo field) { return null; } + /** + * Returns a {@link DocValuesSkipper} for this field. The returned instance need not be + * thread-safe: it will only be used by a single thread. The return value is undefined if {@link + * FieldInfo#docValuesSkipIndexType()} returns {@link DocValuesSkipIndexType#NONE}. + * + * @param field + */ + @Override + public DocValuesSkipper getSkipper(FieldInfo field) throws IOException { + return null; + } + @Override public void checkIntegrity() { @@ -204,7 +209,6 @@ public int size() { return count; } - @Override public float[] vectorValue() throws IOException { // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an @@ -213,29 +217,49 @@ public float[] vectorValue() throws IOException { return vector; } + @Override + public float[] vectorValue(int ordId) throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. + System.arraycopy(vectors.get(ordId), 0, vector, 0, dimension); + return vector; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + /** + * @return + * @throws IOException + */ + @Override + public FloatVectorValues copy() throws IOException { + return null; + } + @Override public VectorScorer scorer(float[] query) throws IOException { throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); } - @Override public int docID() { if (this.current > this.count) { - return FloatVectorValues.NO_MORE_DOCS; + return DocIndexIterator.NO_MORE_DOCS; } return this.current; } - @Override public int nextDoc() throws IOException { return advance(current + 1); } - @Override public int advance(int target) throws IOException { current = target; if (current >= count) { - current = NO_MORE_DOCS; + current = DocIndexIterator.NO_MORE_DOCS; } return current; } @@ -267,7 +291,6 @@ public int size() { return count; } - @Override public byte[] vectorValue() throws IOException { // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an @@ -276,29 +299,49 @@ public byte[] vectorValue() throws IOException { return vector; } + @Override + public byte[] vectorValue(int ordId) throws IOException { + // since in FloatVectorValues the reference to returned vector doesn't change. This code ensure that we + // are replicating the behavior so that if someone uses this RandomFloatVectorValues they get an + // experience similar to what we get in prod. + System.arraycopy(vectors.get(ordId), 0, vector, 0, dimension); + return vector; + } + + /** + * @return + * @throws IOException + */ + @Override + public ByteVectorValues copy() throws IOException { + return null; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + @Override public VectorScorer scorer(byte[] query) throws IOException { throw new UnsupportedOperationException("scorer not supported with PreDefinedFloatVectorValues"); } - @Override public int docID() { if (this.current > this.count) { - return FloatVectorValues.NO_MORE_DOCS; + return DocIndexIterator.NO_MORE_DOCS; } return this.current; } - @Override public int nextDoc() throws IOException { return advance(current + 1); } - @Override public int advance(int target) throws IOException { current = target; if (current >= count) { - current = NO_MORE_DOCS; + current = DocIndexIterator.NO_MORE_DOCS; } return current; } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 1edb5cff2..e011ba57f 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -29,7 +29,7 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.knn.KNNSingleNodeTestCase; diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index 42a4f2b95..d6d6a7afd 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -856,9 +856,8 @@ private void createIndexAndAssertScriptScore( dense, vectorDataType ); - final float[] dummyVector = new float[1]; dataset.forEach((k, v) -> { - final float[] vector = (v != null) ? v.getVector() : dummyVector; + final float[] vector = (v != null) ? v.getVector() : dummyFloatArrayBasedOnDimension(dimensions); ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); }); @@ -886,4 +885,8 @@ private void createIndexAndAssertScriptScore( deleteKNNIndex(INDEX_NAME); } } + + private float[] dummyFloatArrayBasedOnDimension(int dimesion) { + return new float[dimesion]; + } } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index e116ef3c6..37c00a104 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -601,7 +601,7 @@ public void testLoadIndex_faiss_sqfp16_valid() { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -669,7 +669,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -877,7 +877,7 @@ public void testLoadIndex_nmslib_valid() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -938,7 +938,7 @@ public void testLoadIndex_nmslib_valid_with_stream() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -985,7 +985,7 @@ public void testLoadIndex_faiss_valid() throws IOException { ); assertTrue(directory.fileLength(indexFileName1) > 0); - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1024,7 +1024,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1065,7 +1065,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1108,7 +1108,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1137,7 +1137,7 @@ public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOEx assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1174,7 +1174,7 @@ public void testQueryIndex_faiss_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1310,7 +1310,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1425,7 +1425,7 @@ public void testQueryBinaryIndex_faiss_valid() { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1561,7 +1561,7 @@ public void testFree_nmslib_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex( indexInputWithBuffer, @@ -1595,7 +1595,7 @@ public void testFree_faiss_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1800,7 +1800,7 @@ public void createIndexFromTemplate() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); final long pointer; - try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); @@ -1862,7 +1862,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { String indexIVFPQPath = createFaissIVFPQIndex(directory, ivfNlist, pqM, pqCodeSize, SpaceType.L2); final long indexIVFPQIndexTest1; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest1 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, indexIVFPQIndexTest1); @@ -1871,7 +1871,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { throw e; } final long indexIVFPQIndexTest2; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest2 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, indexIVFPQIndexTest2); @@ -1891,7 +1891,7 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { JNIService.free(indexIVFPQIndexTest1, KNNEngine.FAISS); final long indexIVFPQIndexTest3; - try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); indexIVFPQIndexTest3 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, indexIVFPQIndexTest3); @@ -1919,7 +1919,7 @@ public void testIsIndexIVFPQL2() { Path tempDirPath = createTempDir(); try (Directory directory = newFSDirectory(tempDirPath)) { String faissIVFPQL2Index = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.L2); - try (IndexInput indexInput = directory.openInput(faissIVFPQL2Index, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissIVFPQL2Index, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissIVFPQL2Address = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertTrue(JNIService.isSharedIndexStateRequired(faissIVFPQL2Address, KNNEngine.FAISS)); @@ -1927,7 +1927,7 @@ public void testIsIndexIVFPQL2() { } String faissIVFPQIPIndex = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.INNER_PRODUCT); - try (IndexInput indexInput = directory.openInput(faissIVFPQIPIndex, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissIVFPQIPIndex, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissIVFPQIPAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertFalse(JNIService.isSharedIndexStateRequired(faissIVFPQIPAddress, KNNEngine.FAISS)); @@ -1935,7 +1935,7 @@ public void testIsIndexIVFPQL2() { } String faissHNSWIndex = createFaissHNSWIndex(directory, SpaceType.L2); - try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.LOAD)) { + try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long faissHNSWAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertFalse(JNIService.isSharedIndexStateRequired(faissHNSWAddress, KNNEngine.FAISS)); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 2cc20c8f9..18d74b541 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -5,17 +5,18 @@ package org.opensearch.knn.plugin.script; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Locale; + +import org.apache.lucene.util.BytesRef; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.VectorField; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -354,10 +355,18 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam IndexWriter writer = new IndexWriter(directory, conf); conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test Document knnDocument = new Document(); - knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); + knnDocument.add(new BinaryDocValuesField(fieldName, new BytesRef(encodeVector(content)))); writer.addDocument(knnDocument); writer.commit(); writer.close(); } + + private byte[] encodeVector(float[] vector) { + ByteBuffer byteBuffer = ByteBuffer.allocate(vector.length * Float.BYTES); + for (float value : vector) { + byteBuffer.putFloat(value); + } + return byteBuffer.array(); + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index cac5c1b9c..1bbb388fd 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -9,7 +9,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index 48b93653f..f94b661bb 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -13,7 +13,7 @@ import org.opensearch.Version; import org.opensearch.core.action.ActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.io.stream.BytesStreamOutput;