Skip to content

Commit

Permalink
Explain API changes for Exact/ANN/Radial/Disk based KNN search
Browse files Browse the repository at this point in the history
Signed-off-by: Neetika Singhal <[email protected]>
  • Loading branch information
neetikasinghal committed Feb 4, 2025
1 parent fa70fc8 commit 7c4f425
Show file tree
Hide file tree
Showing 13 changed files with 1,341 additions and 158 deletions.
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class KNNConstants {
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
// Human-readable labels for the kind of k-NN search that was executed; they are
// embedded verbatim in the description text of explain output.
public static final String EXACT_SEARCH = "Exact";
public static final String ANN_SEARCH = "Approximate-NN";
public static final String RADIAL_SEARCH = "Radial";
public static final String DISK_BASED_SEARCH = "Disk-based";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public float scoreTranslation(final float rawScore) {
throw new IllegalStateException("Unsupported method");
}

/**
 * Not supported for this space type.
 *
 * @param rawScore ignored
 * @return never returns normally
 * @throws IllegalStateException always
 */
@Override
public String explainScoreTranslation(float rawScore) {
throw new IllegalStateException("Unsupported method");
}

@Override
public void validateVectorDataType(VectorDataType vectorDataType) {
throw new IllegalStateException("Unsupported method");
Expand All @@ -46,6 +51,11 @@ public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

/**
 * Describes, as text, the formula used to translate a raw score for this space type.
 *
 * @param rawScore not used; the returned text is the same for every value
 * @return the formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.EUCLIDEAN;
Expand Down Expand Up @@ -77,6 +87,11 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

/**
 * Describes, as text, the formula used to translate a raw score for this space type.
 *
 * @param rawScore not used; the returned text is the same for every value
 * @return the formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
return "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`";
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down Expand Up @@ -105,12 +120,22 @@ public void validateVector(float[] vector) {
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

/**
 * Describes, as text, the formula used to translate a raw score for this space type.
 *
 * @param rawScore not used; the returned text is the same for every value
 * @return the formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}
},
LINF("linf") {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

/**
 * Describes, as text, the formula used to translate a raw score for this space type.
 *
 * @param rawScore not used; the returned text is the same for every value
 * @return the formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}
},
INNER_PRODUCT("innerproduct") {
/**
Expand All @@ -129,6 +154,14 @@ public float scoreTranslation(float rawScore) {
return -rawScore + 1;
}

/**
 * Describes, as text, which formula applies to the given raw score: one formula for
 * non-negative raw scores and another for negative ones.
 *
 * @param rawScore raw score whose sign selects the formula text
 * @return the selected formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
    return rawScore >= 0 ? "`1 / (1 + rawScore)`" : "`-rawScore + 1`";
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
Expand All @@ -140,6 +173,11 @@ public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

/**
 * Describes, as text, the formula used to translate a raw score for this space type.
 *
 * @param rawScore not used; the returned text is the same for every value
 * @return the formula as a back-tick quoted string
 */
@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}

@Override
public void validateVectorDataType(VectorDataType vectorDataType) {
if (VectorDataType.BINARY != vectorDataType) {
Expand Down Expand Up @@ -177,6 +215,8 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {

public abstract float scoreTranslation(float rawScore);

/**
 * Returns a human-readable description of the formula that {@link #scoreTranslation(float)}
 * applies to a raw score, for inclusion in explain output.
 *
 * @param rawScore raw score being explained; implementations may use it to pick between formulas
 * @return textual description of the score translation formula
 */
public abstract String explainScoreTranslation(float rawScore);

/**
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public class KNNQuery extends Query {
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
// Mutable flag indicating that explain details should be collected while this query is
// searched. NOTE(review): presumably excluded from equals/hashCode — confirm.
@Setter
@Getter
private boolean explain;

public KNNQuery(
final String field,
Expand Down
178 changes: 172 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder;
import org.opensearch.knn.index.query.explain.KnnExplanation;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -73,6 +74,7 @@ public class KNNWeight extends Weight {

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private final KnnExplanation knnExplanation;

public KNNWeight(KNNQuery query, float boost) {
super(query);
Expand All @@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) {
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
Expand All @@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -105,8 +109,147 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {
}

/**
 * Explains the score of {@code doc} when neither a score nor a scorer is supplied by the
 * caller. This method is called in case of Radial-Search.
 *
 * <p>Delegates to {@link #explain(LeafReaderContext, int, float, KNNScorer)} with a score of
 * {@code 0} and a {@code null} scorer, which makes that overload compute the score itself.
 * The former placeholder {@code "No Explanation"} return has been removed: it made the
 * delegating return unreachable.
 *
 * @param context leaf (segment) containing the document
 * @param doc     document id to explain
 * @return the explanation produced by the four-argument overload
 */
@Override
public Explanation explain(LeafReaderContext context, int doc) {
    return explain(context, doc, 0, null);
}

/**
 * Builds the explanation for a document. This method is called for
 * ANN/Exact/Disk-based/Efficient-filtering search.
 *
 * <p>The result has two levels: a query-level description and, nested under it, a leaf-level
 * description. When a raw score was recorded for the document and no rescore context is
 * set, the leaf level also carries a detail with that raw score.
 *
 * @param context   leaf (segment) containing the document
 * @param doc       document id to explain
 * @param score     score to report; when {@code knnScorer} is {@code null} and this is 0, the
 *                  score computed by {@code getKnnScore} is used instead
 * @param knnScorer scorer that produced {@code score}, or {@code null} to compute the score here
 * @return the nested explanation
 * @throws RuntimeException if computing the score fails with an {@link IOException}
 */
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) {
    // Must be set before any search below: explain data is only recorded while the flag is on.
    knnQuery.setExplain(true);
    if (knnScorer == null) {
        final float computedScore;
        try {
            computedScore = getKnnScore(context, doc);
        } catch (IOException e) {
            throw new RuntimeException("Error while getting KNN score during explanation", e);
        }
        score = (score == 0) ? computedScore : score;
        assert score == computedScore : "Score mismatch in explain: provided score does not match KNN score";
    }
    final String queryLevelText = getHighLevelExplanation();
    final StringBuilder leafLevelText = getLeafLevelExplanation(context);

    final SegmentReader segmentReader = Lucene.segmentReader(context.reader());
    final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(segmentReader, knnQuery.getField());
    if (fieldInfo == null) {
        // Without field info neither the space type nor the engine can be reported.
        return Explanation.match(score, queryLevelText, Explanation.match(score, leafLevelText.toString()));
    }

    final SpaceType spaceType = getSpaceType(fieldInfo);
    leafLevelText.append(", spaceType = ").append(spaceType.getValue());

    final Float rawScore = knnExplanation.getRawScores().get(doc);
    if (rawScore == null || knnQuery.getRescoreContext() != null) {
        return Explanation.match(score, queryLevelText, Explanation.match(score, leafLevelText.toString()));
    }

    leafLevelText.append(" where score is computed as ").append(spaceType.explainScoreTranslation(rawScore)).append(" from:");
    final Explanation rawScoreDetail = Explanation.match(
        rawScore,
        "rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library"
    );
    return Explanation.match(score, queryLevelText, Explanation.match(score, leafLevelText.toString(), rawScoreDetail));
}

/**
 * Builds the leaf-level description of which kind of search ran on the given leaf and why,
 * based on the recorded cardinality and the recorded ANN result count for that leaf.
 *
 * @param context leaf whose search is being described
 * @return a builder holding the description, always ending with the vector data type
 */
private StringBuilder getLeafLevelExplanation(LeafReaderContext context) {
    final int exactSearchThreshold = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
    final int filteredDocCount = knnExplanation.getCardinality();
    final StringBuilder text = new StringBuilder("the type of knn search executed at leaf was ");

    // Reasons for exact search that apply only when a filter is present.
    if (filterWeight != null) {
        if (isFilterIdCountLessThanK(filteredDocCount)) {
            text.append(KNNConstants.EXACT_SEARCH);
            text.append(" since filteredIds = ").append(filteredDocCount);
            text.append(" is less than or equal to K = ").append(knnQuery.getK());
        } else if (isExactSearchThresholdSettingSet(exactSearchThreshold)) {
            if (exactSearchThreshold >= filteredDocCount) {
                text.append(KNNConstants.EXACT_SEARCH);
                text.append(" since filtered threshold value = ").append(exactSearchThreshold);
                text.append(" is greater than or equal to cardinality = ").append(filteredDocCount);
            }
        } else if (isMDCGreaterThanFilterIdCnt(filteredDocCount)) {
            text.append(KNNConstants.EXACT_SEARCH);
            text.append(" since max distance computation = ").append(KNNConstants.MAX_DISTANCE_COMPUTATIONS);
            text.append(" is greater than or equal to cardinality = ").append(filteredDocCount);
        }
    }

    // Number of ANN results recorded for this leaf; null when none was recorded.
    final Integer annResultCount = knnExplanation.getAnnResultPerLeaf().get(context.id());
    if (annResultCount != null) {
        if (annResultCount == 0 && isMissingNativeEngineFiles(context)) {
            text.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available");
        }
        if (isFilteredExactSearchRequireAfterANNSearch(filteredDocCount, annResultCount)) {
            text.append(KNNConstants.EXACT_SEARCH);
            text.append(" since the number of documents returned are less than K = ").append(knnQuery.getK());
            text.append(" and there are more than K filtered Ids = ").append(filteredDocCount);
        }
        if (annResultCount > 0 && !isFilteredExactSearchRequireAfterANNSearch(filteredDocCount, annResultCount)) {
            text.append(KNNConstants.ANN_SEARCH);
        }
    }

    text.append(" with vectorDataType = ").append(knnQuery.getVectorDataType());
    return text;
}

/**
 * Resolves the space type for the field, falling back to a default when the lookup rejects
 * the field with an {@link IllegalArgumentException}.
 *
 * @param fieldInfo field whose space type is wanted
 * @return the resolved space type, or {@code SpaceType.DEFAULT_BINARY} for binary query
 *         vectors / {@code SpaceType.DEFAULT} otherwise when the lookup fails
 */
private SpaceType getSpaceType(FieldInfo fieldInfo) {
    try {
        return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
    } catch (IllegalArgumentException lookupFailure) {
        // Fall back to the default that matches the query's vector data type.
        if (knnQuery.getVectorDataType() == VectorDataType.BINARY) {
            return SpaceType.DEFAULT_BINARY;
        }
        return SpaceType.DEFAULT;
    }
}

/**
 * Builds the query-level description of the search type: disk-based when a rescore context
 * is set, radial when a radius is set, approximate-NN otherwise.
 *
 * @return the description text
 */
private String getHighLevelExplanation() {
    final String prefix = "the type of knn search executed was ";
    if (knnQuery.getRescoreContext() != null) {
        return prefix + buildDiskBasedSearchExplanation();
    }
    if (knnQuery.getRadius() != null) {
        return prefix + KNNConstants.RADIAL_SEARCH + " with the radius of " + knnQuery.getRadius();
    }
    return prefix + KNNConstants.ANN_SEARCH;
}

/**
 * Builds the disk-based search description: first-pass k, vector dimension, oversampling
 * factor, and whether shard-level rescoring is disabled for the index.
 *
 * <p>Reads {@code knnQuery.getRescoreContext()} without a null check, so it must only be
 * called when a rescore context is set.
 *
 * @return the description text
 */
private String buildDiskBasedSearchExplanation() {
    final boolean shardRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
    final int vectorDimension = knnQuery.getQueryVector().length;
    final int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), shardRescoringDisabled, vectorDimension);
    final String rescoringState = shardRescoringDisabled ? ", shard level rescoring disabled" : ", shard level rescoring enabled";

    final StringBuilder text = new StringBuilder(KNNConstants.DISK_BASED_SEARCH);
    text.append(" and the first pass k was ").append(firstPassK);
    text.append(" with vector dimension of ").append(vectorDimension);
    text.append(", over sampling factor of ").append(knnQuery.getRescoreContext().getOversampleFactor());
    text.append(rescoringState);
    return text.toString();
}

/**
 * Computes the k-NN score of {@code doc} by building a scorer for the leaf and advancing its
 * iterator to the document.
 *
 * <p>The result of {@code scorer(context)} is checked before it is downcast. If it is
 * {@code null} or not a {@link KNNScorer}, 0 is returned, the same value used when the
 * iterator does not land on {@code doc}, instead of failing with a
 * {@link NullPointerException} or {@link ClassCastException}.
 *
 * @param context leaf (segment) containing the document
 * @param doc     document id to score
 * @return the score of {@code doc}, or 0 when the document is not matched by the leaf scorer
 * @throws IOException if building the scorer or advancing its iterator fails
 */
private float getKnnScore(LeafReaderContext context, int doc) throws IOException {
    // Hold the result as Object so scorer(context) runs once and can be type-checked.
    final Object leafScorer = scorer(context);
    if (!(leafScorer instanceof KNNScorer)) {
        return 0;
    }
    final KNNScorer knnScorer = (KNNScorer) leafScorer;
    final int resDoc = knnScorer.iterator().advance(doc);
    return resDoc == doc ? knnScorer.score() : 0;
}

@Override
Expand Down Expand Up @@ -137,6 +280,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
if (filterWeight != null && cardinality == 0) {
return PerLeafResult.EMPTY_RESULT;
}
if (knnQuery.isExplain()) {
knnExplanation.setCardinality(cardinality);
}
/*
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
Expand All @@ -153,7 +299,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

if (knnQuery.isExplain()) {
knnExplanation.getAnnResultPerLeaf().put(context.id(), docIdsToScoreMap.size());
}
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
Expand Down Expand Up @@ -383,6 +531,15 @@ private Map<Integer, Float> doANNSearch(
log.debug("[KNN] Query yielded 0 results");
return Collections.emptyMap();
}
if (knnQuery.isExplain()) {
Arrays.stream(results).forEach(result -> {
if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) {
knnExplanation.getRawScores().put(result.getId(), -1 * result.getScore());
} else {
knnExplanation.getRawScores().put(result.getId(), result.getScore());
}
});
}

if (quantizedVector != null) {
return Arrays.stream(results)
Expand Down Expand Up @@ -425,24 +582,33 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
);
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer to this GitHub issue for more details on the logic: https://github.com/opensearch-project/k-NN/issues/1049
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
return true;
}
if (isFilterIdCountLessThanK(filterIdsCount)) return true;
// Check whether the user has defined the exact search filtered threshold; if so, use that setting.
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
return filterThresholdValue >= filterIdsCount;
if (filterThresholdValue >= filterIdsCount) {
return true;
}
return false;
}

// if no setting is set, then use the default max distance computation value to see if we can do exact search.
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return isMDCGreaterThanFilterIdCnt(filterIdsCount);
}

/**
 * Checks whether an exact search over the filtered documents stays within the maximum number
 * of distance computations.
 *
 * <p>The cost is {@code filterIdsCount} times the query vector length (the float vector for
 * {@code VectorDataType.FLOAT}, the byte vector otherwise). The product is computed in
 * {@code long}: in {@code int} arithmetic it overflows for large filtered sets and
 * high-dimensional vectors, wrapping to a negative value that makes this check wrongly true.
 *
 * @param filterIdsCount number of documents that passed the filter
 * @return {@code true} if {@code KNNConstants.MAX_DISTANCE_COMPUTATIONS} is greater than or
 *         equal to the estimated number of distance computations
 */
private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) {
    final int dimension = knnQuery.getVectorDataType() == VectorDataType.FLOAT
        ? knnQuery.getQueryVector().length
        : knnQuery.getByteQueryVector().length;
    // Widen before multiplying so the product cannot wrap around.
    return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= (long) filterIdsCount * dimension;
}

/**
 * Checks whether the filtered document count is at most k for a non-radial query.
 *
 * @param filterIdsCount number of documents that passed the filter
 * @return {@code true} if no radius is set on the query and {@code filterIdsCount} is less
 *         than or equal to k; {@code false} whenever a radius is set
 */
private boolean isFilterIdCountLessThanK(int filterIdsCount) {
    if (knnQuery.getRadius() != null) {
        return false;
    }
    return filterIdsCount <= knnQuery.getK();
}

/**
* This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This
* is done by validating if the setting value is equal to the default value.
Expand Down
Loading

0 comments on commit 7c4f425

Please sign in to comment.