diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 4479099e8..d86d91ab6 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -22,6 +22,10 @@ public class KNNConstants { public static final String PATH = "path"; public static final String QUERY = "query"; public static final String KNN = "knn"; + public static final String EXACT_SEARCH = "Exact"; + public static final String ANN_SEARCH = "Approximate-NN"; + public static final String RADIAL_SEARCH = "Radial"; + public static final String DISK_BASED_SEARCH = "Disk-based"; public static final String VECTOR = "vector"; public static final String K = "k"; public static final String TYPE_KNN_VECTOR = "knn_vector"; diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 5d90071e8..18224d0fb 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -35,6 +35,11 @@ public float scoreTranslation(final float rawScore) { throw new IllegalStateException("Unsupported method"); } + @Override + public String explainScoreTranslation(float rawScore) { + throw new IllegalStateException("Unsupported method"); + } + @Override public void validateVectorDataType(VectorDataType vectorDataType) { throw new IllegalStateException("Unsupported method"); @@ -46,6 +51,11 @@ public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } + @Override + public String explainScoreTranslation(float rawScore) { + return "`1 / (1 + rawScore)`"; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.EUCLIDEAN; @@ -77,6 +87,11 @@ public float scoreTranslation(float rawScore) { return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } + @Override + public String explainScoreTranslation(float rawScore) { + return "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`"; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.COSINE; @@ -105,12 +120,22 @@ public void validateVector(float[] vector) { public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } + + @Override + public String explainScoreTranslation(float rawScore) { + return "`1 / (1 + rawScore)`"; + } }, LINF("linf") { @Override public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } + + @Override + public String explainScoreTranslation(float rawScore) { + return "`1 / (1 + rawScore)`"; + } }, INNER_PRODUCT("innerproduct") { /** @@ -129,6 +154,14 @@ public float scoreTranslation(float rawScore) { return -rawScore + 1; } + @Override + public String explainScoreTranslation(float rawScore) { + if (rawScore >= 0) { + return "`1 / (1 + rawScore)`"; + } + return "`-rawScore + 1`"; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; @@ -140,6 +173,11 @@ public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } + @Override + public String explainScoreTranslation(float rawScore) { + return "`1 / (1 + rawScore)`"; + } + @Override public void validateVectorDataType(VectorDataType vectorDataType) { if (VectorDataType.BINARY != vectorDataType) { @@ -177,6 +215,8 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { public abstract float scoreTranslation(float rawScore); + public abstract String explainScoreTranslation(float rawScore); + /** * Get KNNVectorSimilarityFunction that maps to this SpaceType * diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 1a03f4b99..03eb67f82 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -52,6 +52,9 @@ public class KNNQuery extends Query { private BitSetProducer parentsFilter; private Float radius; private Context context; + @Setter + @Getter + private boolean explain; public KNNQuery( final String field, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad..5e51c654b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -35,6 +35,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; +import org.opensearch.knn.index.query.explain.KnnExplanation; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -73,6 +74,7 @@ public class KNNWeight extends Weight { private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService; + private final KnnExplanation knnExplanation; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) { this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + this.knnExplanation = new KnnExplanation(); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + this.knnExplanation = new KnnExplanation(); } public static void initialize(ModelDao modelDao) { @@ -105,8 +109,157 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { } @Override + // This method is called in case of Radial-Search public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.match(1.0f, "No Explanation"); + return explain(context, doc, 0, null); + } + + // This method is called for ANN/Exact/Disk-based/Efficient-filtering search + public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) { + knnQuery.setExplain(true); + try { + knnScorer = getOrCreateKnnScorer(context, knnScorer); + float knnScore = getKnnScore(knnScorer, doc); + + if (score == 0) { + score = knnScore; + } + assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score"; + } catch (IOException e) { + throw new RuntimeException("Error while explaining KNN score", e); + } + + final String highLevelExplanation = getHighLevelExplanation(); + final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context); + + final SegmentReader reader = Lucene.segmentReader(context.reader()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); + if (fieldInfo == null) { + return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); + } + final SpaceType spaceType = getSpaceType(fieldInfo); + leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue()); + + final Float rawScore = knnExplanation.getRawScores().get(doc); + Explanation rawScoreDetail = null; + if (rawScore != null && knnQuery.getRescoreContext() == null) { + leafLevelExplanation.append(" where score is computed as ") + .append(spaceType.explainScoreTranslation(rawScore)) + .append(" from:"); + rawScoreDetail = Explanation.match( + rawScore, + "rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library" + ); + } + + return rawScoreDetail != null + ? Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail)) + : Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); + } + + private StringBuilder getLeafLevelExplanation(LeafReaderContext context) { + int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); + int cardinality = knnExplanation.getCardinality(); + StringBuilder sb = new StringBuilder("the type of knn search executed at leaf was "); + if (filterWeight != null) { + if (isFilterIdCountLessThanK(cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since filteredIds = ") + .append(cardinality) + .append(" is less than or equal to K = ") + .append(knnQuery.getK()); + } else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since filtered threshold value = ") + .append(filterThresholdValue) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } else if (!isExactSearchThresholdSettingSet(filterThresholdValue) && isMDCGreaterThanFilterIdCnt(cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since max distance computation = ") + .append(KNNConstants.MAX_DISTANCE_COMPUTATIONS) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } + } + if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null + && knnExplanation.getAnnResultPerLeaf().get(context.id()) == 0 + && isMissingNativeEngineFiles(context)) { + sb.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available"); + } + if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null + && isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since the number of documents returned are less than K = ") + .append(knnQuery.getK()) + .append(" and there are more than K filtered Ids = ") + .append(cardinality); + } + if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null + && knnExplanation.getAnnResultPerLeaf().get(context.id()) > 0 + && !isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) { + sb.append(KNNConstants.ANN_SEARCH); + } + sb.append(" with vectorDataType = ").append(knnQuery.getVectorDataType()); + return sb; + } + + private SpaceType getSpaceType(FieldInfo fieldInfo) { + try { + return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + } catch (IllegalArgumentException e) { + return knnQuery.getVectorDataType() == VectorDataType.BINARY ? SpaceType.DEFAULT_BINARY : SpaceType.DEFAULT; + } + } + + private String getHighLevelExplanation() { + StringBuilder sb = new StringBuilder("the type of knn search executed was "); + if (knnQuery.getRescoreContext() != null) { + sb.append(buildDiskBasedSearchExplanation()); + } else if (knnQuery.getRadius() != null) { + sb.append(KNNConstants.RADIAL_SEARCH).append(" with the radius of ").append(knnQuery.getRadius()); + } else { + sb.append(KNNConstants.ANN_SEARCH); + } + return sb.toString(); + } + + private String buildDiskBasedSearchExplanation() { + StringBuilder sb = new StringBuilder(KNNConstants.DISK_BASED_SEARCH); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); + int dimension = knnQuery.getQueryVector().length; + int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension); + sb.append(" and the first pass k was ") + .append(firstPassK) + .append(" with vector dimension of ") + .append(dimension) + .append(", over sampling factor of ") + .append(knnQuery.getRescoreContext().getOversampleFactor()); + if (isShardLevelRescoringDisabled) { + sb.append(", shard level rescoring disabled"); + } else { + sb.append(", shard level rescoring enabled"); + } + return sb.toString(); + } + + private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException { + if (existingScorer != null) { + return existingScorer; + } + + KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context); + if (cachedScorer != null) { + return cachedScorer; + } + + KNNScorer newScorer = (KNNScorer) scorer(context); + knnExplanation.getKnnScorerPerLeaf().put(context, newScorer); + return newScorer; + } + + private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException { + return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0; } @Override @@ -137,6 +290,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep if (filterWeight != null && cardinality == 0) { return PerLeafResult.EMPTY_RESULT; } + if (knnQuery.isExplain()) { + knnExplanation.setCardinality(cardinality); + } /* * The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. @@ -153,7 +309,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep */ final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); - + if (knnQuery.isExplain()) { + knnExplanation.getAnnResultPerLeaf().put(context.id(), docIdsToScoreMap.size()); + } // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs @@ -383,6 +541,15 @@ private Map doANNSearch( log.debug("[KNN] Query yielded 0 results"); return Collections.emptyMap(); } + if (knnQuery.isExplain()) { + Arrays.stream(results).forEach(result -> { + if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) { + knnExplanation.getRawScores().put(result.getId(), -1 * result.getScore()); + } else { + knnExplanation.getRawScores().put(result.getId(), result.getScore()); + } + }); + } if (quantizedVector != null) { return Arrays.stream(results) @@ -425,12 +592,13 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { ); int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { - return true; - } + if (isFilterIdCountLessThanK(filterIdsCount)) return true; // See user has defined Exact Search filtered threshold. if yes, then use that setting. if (isExactSearchThresholdSettingSet(filterThresholdValue)) { - return filterThresholdValue >= filterIdsCount; + if (filterThresholdValue >= filterIdsCount) { + return true; + } + return false; } // if no setting is set, then use the default max distance computation value to see if we can do exact search. @@ -438,11 +606,19 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { * TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index * is cheaper than computation cost for non binary vector */ + return isMDCGreaterThanFilterIdCnt(filterIdsCount); + } + + private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) { return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT ? knnQuery.getQueryVector().length : knnQuery.getByteQueryVector().length); } + private boolean isFilterIdCountLessThanK(int filterIdsCount) { + return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK(); + } + /** * This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This * is done by validating if the setting value is equal to the default value. diff --git a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index f38cc96c6..e59923e3a 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.opensearch.knn.index.query.KNNWeight; import java.io.IOException; import java.util.Arrays; @@ -31,13 +32,15 @@ final class DocAndScoreQuery extends Query { private final float[] scores; private final int[] segmentStarts; private final Object contextIdentity; + private final KNNWeight knnWeight; - public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity, KNNWeight knnWeight) { this.k = k; this.docs = docs; this.scores = scores; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; + this.knnWeight = knnWeight; } @Override @@ -53,7 +56,19 @@ public Explanation explain(LeafReaderContext context, int doc) { if (found < 0) { return Explanation.noMatch("not in top " + k); } - return Explanation.match(scores[found] * boost, "within top " + k); + float score = 0; + try { + final Scorer scorer = scorer(context); + assert scorer != null; + int resDoc = scorer.iterator().advance(doc); + if (resDoc == doc) { + score = scorer.score(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return knnWeight.explain(context, doc, score, null); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java index 5fc0fb077..ce823c229 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java +++ b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java @@ -18,6 +18,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator; import java.io.IOException; @@ -46,6 +47,10 @@ public class QueryUtils { * @return a query representing the given TopDocs */ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) { + return createDocAndScoreQuery(reader, topDocs, null); + } + + public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs, final KNNWeight knnWeight) { int len = topDocs.scoreDocs.length; Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc)); int[] docs = new int[len]; @@ -55,7 +60,7 @@ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topD scores[i] = topDocs.scoreDocs[i].score; } int[] segmentStarts = findSegmentStarts(reader, docs); - return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id()); + return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id(), knnWeight); } private int[] findSegmentStarts(final IndexReader reader, final int[] docs) { diff --git a/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java new file mode 100644 index 000000000..ce594b1f1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.explain; + +import lombok.Getter; +import lombok.Setter; +import org.opensearch.knn.index.query.KNNScorer; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class captures details around knn explain queries that is used + * by explain API to generate explanation for knn queries + */ +public class KnnExplanation { + + @Getter + private final Map annResultPerLeaf; + + @Getter + private final Map rawScores; + + @Getter + private final Map knnScorerPerLeaf; + + @Setter + @Getter + private int cardinality; + + public KnnExplanation() { + this.annResultPerLeaf = new ConcurrentHashMap<>(); + this.rawScores = new ConcurrentHashMap<>(); + this.knnScorerPerLeaf = new ConcurrentHashMap<>(); + this.cardinality = 0; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 1ffaa804d..ba16cd501 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -98,7 +98,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); } - return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + return queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost); } /** diff --git a/src/test/java/org/opensearch/knn/index/query/ExplainTests.java b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java new file mode 100644 index 000000000..b4701e5bd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java @@ -0,0 +1,838 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.google.common.collect.Comparators; +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.ANN_SEARCH; +import static org.opensearch.knn.common.KNNConstants.EXACT_SEARCH; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.RADIAL_SEARCH; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; + +public class ExplainTests extends KNNWeightTestCase { + + @Mock + private Weight filterQueryWeight; + @Mock + private LeafReaderContext leafReaderContext; + + private void setupTest(final int[] filterDocIds, final Map attributesMap) throws IOException { + setupTest(filterDocIds, attributesMap, filterDocIds != null ? filterDocIds.length : 0, SpaceType.L2, true, null, null, null); + } + + private void setupTest( + final int[] filterDocIds, + final Map attributesMap, + final int maxDoc, + final SpaceType spaceType, + final boolean isCompoundFile, + final byte[] byteVector, + final float[] floatVector, + final MockedStatic vectorValuesFactoryMockedStatic + ) throws IOException { + + final Scorer filterScorer = mock(Scorer.class); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + + Bits liveDocsBits = null; + if (filterDocIds != null) { + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + liveDocsBits = mock(Bits.class); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); + + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); + } + final SegmentReader reader = mockSegmentReader(isCompoundFile); + when(reader.maxDoc()).thenReturn(maxDoc); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + when(leafReaderContext.reader()).thenReturn(reader); + when(leafReaderContext.id()).thenReturn(new Object()); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + + if (floatVector != null) { + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + when(binaryDocValues.advance(0)).thenReturn(0); + BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(floatVector)); + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + } + + if (byteVector != null) { + final KNNBinaryVectorValues knnBinaryVectorValues = mock(KNNBinaryVectorValues.class); + vectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)) + .thenReturn(knnBinaryVectorValues); + when(knnBinaryVectorValues.advance(0)).thenReturn(0); + when(knnBinaryVectorValues.getVector()).thenReturn(byteVector); + } + } + + private void assertExplanation(Explanation explanation, float expectedScore, String topSearch, String... leafDescription) { + assertNotNull(explanation); + assertTrue(explanation.isMatch()); + assertEquals(expectedScore, explanation.getValue().floatValue(), 0.01f); + assertTrue(explanation.getDescription().contains(topSearch)); + assertEquals(1, explanation.getDetails().length); + Explanation explanationDetail = explanation.getDetails()[0]; + assertEquals(expectedScore, explanation.getValue().floatValue(), 0.01f); + for (String description : leafDescription) { + assertTrue(explanationDetail.getDescription().contains(description)); + } + } + + private void assertDiskSearchExplanation(Explanation explanation, String[] topSearchDesc, String... leafDescription) { + assertNotNull(explanation); + assertTrue(explanation.isMatch()); + for (String description : topSearchDesc) { + assertTrue(explanation.getDescription().contains(description)); + } + assertEquals(1, explanation.getDetails().length); + Explanation explanationDetail = explanation.getDetails()[0]; + for (String description : leafDescription) { + assertTrue(explanationDetail.getDescription().contains(description)); + } + } + + @SneakyThrows + public void testDiskBasedSearchWithShardRescoringEnabledANN() { + int k = 3; + knnSettingsMockedStatic.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME)).thenReturn(false); + + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build(); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .rescoreContext(rescoreContext) + .explain(true) + .build(); + query.setExplain(true); + + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + String[] expectedTopDescription = new String[] { + KNNConstants.DISK_BASED_SEARCH, + "the first pass k was " + rescoreContext.getFirstPassK(k, false, QUERY_VECTOR.length), + "over sampling factor of " + rescoreContext.getOversampleFactor(), + "with vector dimension of " + QUERY_VECTOR.length, + "shard level rescoring enabled" }; + assertDiskSearchExplanation( + explanation, + expectedTopDescription, + KNNConstants.ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue() + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testDiskBasedSearchWithShardRescoringDisabledExact() { + knnSettingsMockedStatic.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME)).thenReturn(true); + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR - 1).build(); + + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + KNNWeight.initialize(null, mockedExactSearcher); + + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 1, spaceType, false, null, null, null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .rescoreContext(rescoreContext) + .explain(true) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + String[] expectedTopDescription = new String[] { + KNNConstants.DISK_BASED_SEARCH, + "the first pass k was " + rescoreContext.getFirstPassK(0, true, queryVector.length), + "over sampling factor of " + rescoreContext.getOversampleFactor(), + "with vector dimension of " + queryVector.length, + "shard level rescoring disabled" }; + assertDiskSearchExplanation( + explanation, + expectedTopDescription, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + spaceType.getValue(), + "no native engine files" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + + @SneakyThrows + public void testDefaultANNSearch() { + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); + query.setExplain(true); + + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + KNNConstants.ANN_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(DOC_ID_TO_SCORES.get(docId)) + ); + Explanation nestedDetail = explanation.getDetails()[0].getDetails()[0]; + assertTrue(nestedDetail.getDescription().contains(KNNEngine.FAISS.name())); + assertEquals(DOC_ID_TO_SCORES.get(docId), nestedDetail.getValue().floatValue(), 0.01f); + assertEquals(score, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANN_FilteredExactSearchAfterANN() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + KNNWeight.initialize(null, mockedExactSearcher); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(translatedScores); + // Given + int k = 4; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); + query.setExplain(true); + + final float boost = 1; + KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + KNNConstants.ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since the number of documents returned are less than K", + "there are more than K filtered Ids" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANN_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 1, spaceType, false, null, null, null); + + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId); + assertEquals(score, knnScorer.score(), 0.00000001f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + spaceType.getValue(), + "no native engine files" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenFTVGreaterThanFilterId() { + + KNNWeight.initialize(null); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10); + byte[] vector = new byte[] { 1, 3 }; + int k = 1; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.HAMMING.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "BHNSW32") + ); + + try (MockedStatic vectorValuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + setupTest(filterDocIds, attributesMap, 100, SpaceType.HAMMING, true, vector, null, vectorValuesFactoryMockedStatic); + final KNNQuery query = new KNNQuery( + FIELD_NAME, + BYTE_QUERY_VECTOR, + k, + INDEX_NAME, + FILTER_QUERY, + null, + VectorDataType.BINARY, + null + ); + + query.setExplain(true); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.BINARY.name(), + SpaceType.HAMMING.getValue(), + "is greater than or equal to cardinality", + "since filtered threshold value" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + } + + @SneakyThrows + public void testANNWithFilterQuery_whenMDCGreaterThanFilterId() { + ModelDao modelDao = mock(ModelDao.class); + KNNWeight.initialize(modelDao); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(-1); + float[] vector = new float[] { 0.1f, 0.3f }; + int k = 1; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(filterDocIds, attributesMap, 100, SpaceType.L2, true, null, vector, null); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, null); + query.setExplain(true); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since max distance computation", + "is greater than or equal to cardinality" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenFilterIdLessThanK() { + ModelDao modelDao = mock(ModelDao.class); + KNNWeight.initialize(modelDao); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(-1); + float[] vector = new float[] { 0.1f, 0.3f }; + final int[] filterDocIds = new int[] { 0 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(filterDocIds, attributesMap, 100, SpaceType.L2, true, null, vector, null); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null, null); + query.setExplain(true); + + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since filteredIds", + "is less than or equal to K" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testRadialANNSearch() { + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ).thenReturn(getKNNQueryResults()); + + Map attributesMap = Map.of( + SPACE_TYPE, + SpaceType.L2.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap); + + KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .explain(true) + .vectorDataType(VectorDataType.FLOAT) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + jniServiceMockedStatic.verify( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation( + explanation, + score, + RADIAL_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(DOC_ID_TO_SCORES.get(docId)) + ); + Explanation nestedDetail = explanation.getDetails()[0].getDetails()[0]; + assertTrue(nestedDetail.getDescription().contains(KNNEngine.FAISS.name())); + assertEquals(DOC_ID_TO_SCORES.get(docId), nestedDetail.getValue().floatValue(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testRadialExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ).thenReturn(getKNNQueryResults()); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 0, spaceType, false, null, null, null); + + KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .explain(true) + .vectorDataType(VectorDataType.FLOAT) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score, knnScorer); + assertExplanation(explanation, score, RADIAL_SEARCH, EXACT_SEARCH, VectorDataType.FLOAT.name(), spaceType.getValue()); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java new file mode 100644 index 000000000..bf21bad12 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import org.apache.lucene.index.SegmentCommitInfo; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.MockedStatic; +import org.opensearch.common.io.PathUtils; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.knn.jni.JNIService; + +import java.nio.file.Path; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; + +public class KNNWeightTestCase extends KNNTestCase { + + protected static final String FIELD_NAME = "target_field"; + protected static final float[] QUERY_VECTOR = new float[] { 1.8f, 2.4f }; + protected static final byte[] BYTE_QUERY_VECTOR = new byte[] { 1, 2 }; + protected static final String SEGMENT_NAME = "0"; + protected static final int K = 5; + protected static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); + protected static final Set SEGMENT_FILES_FAISS = Set.of("_0.cfe", "_0_2011_target_field.faissc"); + protected static final Set SEGMENT_FILES_DEFAULT = SEGMENT_FILES_FAISS; + protected static final Set SEGMENT_MULTI_FIELD_FILES_FAISS = Set.of( + "_0.cfe", + "_0_2011_target_field.faissc", + "_0_2011_long_target_field.faissc" + ); + protected static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; + protected static final Integer EF_SEARCH = 10; + protected static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); + protected static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); + protected static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); + protected static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); + protected static final Map BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.5f); + protected static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); + protected static MockedStatic nativeMemoryCacheManagerMockedStatic; + protected static MockedStatic jniServiceMockedStatic; + + protected static MockedStatic knnSettingsMockedStatic; + + @BeforeClass + public static void setUpClass() throws Exception { + final KNNSettings knnSettings = mock(KNNSettings.class); + knnSettingsMockedStatic = mockStatic(KNNSettings.class); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES))).thenReturn(TimeValue.timeValueMinutes(10)); + + final ByteSizeValue v = ByteSizeValue.parseBytesSizeValue( + CIRCUIT_BREAKER_LIMIT_100KB, + KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT + ); + knnSettingsMockedStatic.when(KNNSettings::getCircuitBreakerLimit).thenReturn(v); + knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + knnSettingsMockedStatic.when(KNNSettings::isKNNPluginEnabled).thenReturn(true); + + nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); + + final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); + final NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); + when(nativeMemoryCacheManager.get(any(), anyBoolean())).thenReturn(nativeMemoryAllocation); + + nativeMemoryCacheManagerMockedStatic.when(NativeMemoryCacheManager::getInstance).thenReturn(nativeMemoryCacheManager); + + final MockedStatic pathUtilsMockedStatic = mockStatic(PathUtils.class); + final Path indexPath = mock(Path.class); + when(indexPath.toString()).thenReturn("/mydrive/myfolder"); + pathUtilsMockedStatic.when(() -> PathUtils.get(anyString(), anyString())).thenReturn(indexPath); + } + + @Before + public void setupBeforeTest() { + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); + jniServiceMockedStatic = mockStatic(JNIService.class); + } + + @After + public void tearDownAfterTest() { + jniServiceMockedStatic.close(); + } + + protected Map getTranslatedScores(Function scoreTranslator) { + return DOC_ID_TO_SCORES.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreTranslator.apply(entry.getValue()))); + } + + protected KNNQueryResult[] getKNNQueryResults() { + return DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } + + protected KNNQueryResult[] getFilteredKNNQueryResults() { + return FILTERED_DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } + + protected SegmentReader mockSegmentReader() { + return mockSegmentReader(true); + } + + protected SegmentReader mockSegmentReader(boolean isCompoundFile) { + Path path = mock(Path.class); + + FSDirectory directory = mock(FSDirectory.class); + when(directory.getDirectory()).thenReturn(path); + + SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + isCompoundFile, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + + SegmentReader reader = mock(SegmentReader.class); + when(reader.directory()).thenReturn(directory); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + return reader; + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 8011cc08c..2a6238961 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -15,12 +15,9 @@ import org.apache.lucene.index.SegmentCommitInfo; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; -import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.store.FSDirectory; @@ -29,27 +26,18 @@ import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; import org.mockito.Mockito; -import org.opensearch.common.io.PathUtils; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; -import org.opensearch.knn.index.memory.NativeMemoryAllocation; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; @@ -80,7 +68,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyFloat; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -95,82 +82,11 @@ import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; -public class KNNWeightTests extends KNNTestCase { - private static final String FIELD_NAME = "target_field"; - private static final float[] QUERY_VECTOR = new float[] { 1.8f, 2.4f }; - private static final byte[] BYTE_QUERY_VECTOR = new byte[] { 1, 2 }; - private static final String SEGMENT_NAME = "0"; - private static final int K = 5; - private static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); - private static final Set SEGMENT_FILES_FAISS = Set.of("_0.cfe", "_0_2011_target_field.faissc"); - private static final Set SEGMENT_FILES_DEFAULT = SEGMENT_FILES_FAISS; - private static final Set SEGMENT_MULTI_FIELD_FILES_FAISS = Set.of( - "_0.cfe", - "_0_2011_target_field.faissc", - "_0_2011_long_target_field.faissc" - ); - private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; - private static final Integer EF_SEARCH = 10; - private static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); - - private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); - private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); - private static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); - private static final Map BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.5f); - - private static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); - - private static MockedStatic nativeMemoryCacheManagerMockedStatic; - private static MockedStatic jniServiceMockedStatic; - - private static MockedStatic knnSettingsMockedStatic; - - @BeforeClass - public static void setUpClass() throws Exception { - final KNNSettings knnSettings = mock(KNNSettings.class); - knnSettingsMockedStatic = mockStatic(KNNSettings.class); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES))).thenReturn(TimeValue.timeValueMinutes(10)); - - final ByteSizeValue v = ByteSizeValue.parseBytesSizeValue( - CIRCUIT_BREAKER_LIMIT_100KB, - KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT - ); - knnSettingsMockedStatic.when(KNNSettings::getCircuitBreakerLimit).thenReturn(v); - knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); - knnSettingsMockedStatic.when(KNNSettings::isKNNPluginEnabled).thenReturn(true); - - nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); - - final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); - final NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); - when(nativeMemoryCacheManager.get(any(), anyBoolean())).thenReturn(nativeMemoryAllocation); - - nativeMemoryCacheManagerMockedStatic.when(NativeMemoryCacheManager::getInstance).thenReturn(nativeMemoryCacheManager); - - final MockedStatic pathUtilsMockedStatic = mockStatic(PathUtils.class); - final Path indexPath = mock(Path.class); - when(indexPath.toString()).thenReturn("/mydrive/myfolder"); - pathUtilsMockedStatic.when(() -> PathUtils.get(anyString(), anyString())).thenReturn(indexPath); - } - - @Before - public void setupBeforeTest() { - knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); - jniServiceMockedStatic = mockStatic(JNIService.class); - } - - @After - public void tearDownAfterTest() { - jniServiceMockedStatic.close(); - } +public class KNNWeightTests extends KNNWeightTestCase { @SneakyThrows public void testQueryResultScoreNmslib() { @@ -840,35 +756,6 @@ public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } - private SegmentReader mockSegmentReader() { - Path path = mock(Path.class); - - FSDirectory directory = mock(FSDirectory.class); - when(directory.getDirectory()).thenReturn(path); - - SegmentInfo segmentInfo = new SegmentInfo( - directory, - Version.LATEST, - Version.LATEST, - SEGMENT_NAME, - 100, - true, - false, - KNNCodecVersion.current().getDefaultCodecDelegate(), - Map.of(), - new byte[StringHelper.ID_LENGTH], - Map.of(), - Sort.RELEVANCE - ); - segmentInfo.setFiles(SEGMENT_FILES_FAISS); - SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); - - SegmentReader reader = mock(SegmentReader.class); - when(reader.directory()).thenReturn(directory); - when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); - return reader; - } - @SneakyThrows public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { validateANNWithFilterQuery_whenExactSearch_thenSuccess(false); @@ -1612,28 +1499,6 @@ private void testQueryScore( assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } - private Map getTranslatedScores(Function scoreTranslator) { - return DOC_ID_TO_SCORES.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreTranslator.apply(entry.getValue()))); - } - - private KNNQueryResult[] getKNNQueryResults() { - return DOC_ID_TO_SCORES.entrySet() - .stream() - .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()) - .toArray(new KNNQueryResult[0]); - } - - private KNNQueryResult[] getFilteredKNNQueryResults() { - return FILTERED_DOC_ID_TO_SCORES.entrySet() - .stream() - .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()) - .toArray(new KNNQueryResult[0]); - } - @SneakyThrows public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { diff --git a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index b32496138..3a17be7f9 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -10,14 +10,16 @@ import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.mockito.Mock; +import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.test.OpenSearchTestCase; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; @@ -50,7 +52,7 @@ public void testScorer() throws Exception { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, null); // When Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); @@ -82,18 +84,17 @@ public void testWeight() { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); // When - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + KNNWeight knnWeight = mock(KNNWeight.class); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, knnWeight); Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); - Explanation explanation = weight.explain(leaf1, 1); + weight.explain(leaf1, 1); // Then assertEquals(objectUnderTest, weight.getQuery()); assertTrue(weight.isCacheable(leaf1)); assertEquals(2, weight.count(leaf1)); - assertEquals(expectedExplanation, explanation); - assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); + verify(knnWeight).explain(leaf1, 1, 1.2f, null); } } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 87c4a5014..4a0578e6c 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -155,6 +155,45 @@ public void testMultiLeaf() { assertEquals(expected, actual.getQuery()); } + @SneakyThrows + public void testExplain() { + // Given + List leaves = List.of(leaf1); + when(reader.leaves()).thenReturn(leaves); + + PerLeafResult leafResult = new PerLeafResult(null, new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); + + when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leafResult); + + Bits liveDocs = mock(Bits.class); + when(leafReader1.getLiveDocs()).thenReturn(null); + + when(liveDocs.get(anyInt())).thenReturn(true); + when(liveDocs.get(2)).thenReturn(false); + when(liveDocs.get(1)).thenReturn(false); + + // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves + when(knnQuery.getK()).thenReturn(4); + when(indexReaderContext.id()).thenReturn(1); + + TopDocs[] topDocs = { ResultUtil.resultMapToTopDocs(leafResult.getResult(), leaf1.docBase) }; + TopDocs expectedTopDocs = TopDocs.merge(4, topDocs); + + // When + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); + + // Then + Query expected = QueryUtils.INSTANCE.createDocAndScoreQuery(reader, expectedTopDocs); + assertEquals(expected, actual.getQuery()); + for (ScoreDoc scoreDoc : expectedTopDocs.scoreDocs) { + int docId = scoreDoc.doc; + if (docId == 0) continue; + float score = scoreDoc.score; + actual.explain(leaf1, docId); + verify(knnWeight).explain(leaf1, docId, score, null); + } + } + @SneakyThrows public void testRescoreWhenShardLevelRescoringEnabled() { // Given @@ -321,7 +360,7 @@ public void testExpandNestedDocs() { QueryUtils queryUtils = mock(QueryUtils.class); when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); - when(queryUtils.createDocAndScoreQuery(eq(reader), any())).thenReturn(finalQuery); + when(queryUtils.createDocAndScoreQuery(eq(reader), any(), eq(knnWeight))).thenReturn(finalQuery); // Run NativeEngineKnnVectorQuery query = new NativeEngineKnnVectorQuery(knnQuery, queryUtils, true); @@ -332,7 +371,7 @@ public void testExpandNestedDocs() { verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); - verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture()); + verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture(), eq(knnWeight)); TopDocs capturedTopDocs = topDocsCaptor.getValue(); assertEquals(topK.totalHits, capturedTopDocs.totalHits); for (int i = 0; i < topK.scoreDocs.length; i++) {