Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain API changes #2403

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class KNNConstants {
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
public static final String EXACT_SEARCH = "Exact";
public static final String ANN_SEARCH = "Approximate-NN";
public static final String RADIAL_SEARCH = "Radial";
public static final String DISK_BASED_SEARCH = "Disk-based";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public float scoreTranslation(final float rawScore) {
throw new IllegalStateException("Unsupported method");
}

@Override
public String explainScoreTranslation(float rawScore) {
throw new IllegalStateException("Unsupported method");
}

@Override
public void validateVectorDataType(VectorDataType vectorDataType) {
throw new IllegalStateException("Unsupported method");
Expand All @@ -46,6 +51,11 @@ public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this as private static final string.

}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.EUCLIDEAN;
Expand Down Expand Up @@ -77,6 +87,11 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

@Override
public String explainScoreTranslation(float rawScore) {
return "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down Expand Up @@ -105,12 +120,22 @@ public void validateVector(float[] vector) {
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}
},
LINF("linf") {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}
},
INNER_PRODUCT("innerproduct") {
/**
Expand All @@ -129,6 +154,14 @@ public float scoreTranslation(float rawScore) {
return -rawScore + 1;
}

@Override
public String explainScoreTranslation(float rawScore) {
if (rawScore >= 0) {
return "`1 / (1 + rawScore)`";
}
return "`-rawScore + 1`";
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
Expand All @@ -140,6 +173,11 @@ public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}

@Override
public String explainScoreTranslation(float rawScore) {
return "`1 / (1 + rawScore)`";
}

@Override
public void validateVectorDataType(VectorDataType vectorDataType) {
if (VectorDataType.BINARY != vectorDataType) {
Expand Down Expand Up @@ -177,6 +215,8 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {

public abstract float scoreTranslation(float rawScore);

public abstract String explainScoreTranslation(float rawScore);

/**
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public class KNNQuery extends Query {
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
@Setter
@Getter
private boolean explain;

public KNNQuery(
final String field,
Expand Down
188 changes: 182 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder;
import org.opensearch.knn.index.query.explain.KnnExplanation;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -73,6 +74,7 @@ public class KNNWeight extends Weight {

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private final KnnExplanation knnExplanation;

public KNNWeight(KNNQuery query, float boost) {
super(query);
Expand All @@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) {
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
Expand All @@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -105,8 +109,157 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {
}

@Override
// This method is called in case of Radial-Search
public Explanation explain(LeafReaderContext context, int doc) {
neetikasinghal marked this conversation as resolved.
Show resolved Hide resolved
return Explanation.match(1.0f, "No Explanation");
return explain(context, doc, 0, null);
}

// This method is called for ANN/Exact/Disk-based/Efficient-filtering search
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) {
knnQuery.setExplain(true);
try {
knnScorer = getOrCreateKnnScorer(context, knnScorer);
float knnScore = getKnnScore(knnScorer, doc);

if (score == 0) {
score = knnScore;
}
Comment on lines +124 to +126
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be the first check we do when we enter the if condition.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really, knnScore should be computed for it to be assigned to the score.

assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score";
} catch (IOException e) {
throw new RuntimeException("Error while explaining KNN score", e);
}

final String highLevelExplanation = getHighLevelExplanation();
final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context);

final SegmentReader reader = Lucene.segmentReader(context.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString()));
}
final SpaceType spaceType = getSpaceType(fieldInfo);
leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue());

final Float rawScore = knnExplanation.getRawScores().get(doc);
Explanation rawScoreDetail = null;
if (rawScore != null && knnQuery.getRescoreContext() == null) {
leafLevelExplanation.append(" where score is computed as ")
.append(spaceType.explainScoreTranslation(rawScore))
.append(" from:");
rawScoreDetail = Explanation.match(
rawScore,
"rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library"
);
}

return rawScoreDetail != null
? Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail))
: Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString()));
}

private StringBuilder getLeafLevelExplanation(LeafReaderContext context) {
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
int cardinality = knnExplanation.getCardinality();
StringBuilder sb = new StringBuilder("the type of knn search executed at leaf was ");
if (filterWeight != null) {
if (isFilterIdCountLessThanK(cardinality)) {
sb.append(KNNConstants.EXACT_SEARCH)
.append(" since filteredIds = ")
.append(cardinality)
.append(" is less than or equal to K = ")
.append(knnQuery.getK());
} else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) {
sb.append(KNNConstants.EXACT_SEARCH)
.append(" since filtered threshold value = ")
.append(filterThresholdValue)
.append(" is greater than or equal to cardinality = ")
.append(cardinality);
} else if (!isExactSearchThresholdSettingSet(filterThresholdValue) && isMDCGreaterThanFilterIdCnt(cardinality)) {
sb.append(KNNConstants.EXACT_SEARCH)
.append(" since max distance computation = ")
.append(KNNConstants.MAX_DISTANCE_COMPUTATIONS)
.append(" is greater than or equal to cardinality = ")
.append(cardinality);
}
}
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null
&& knnExplanation.getAnnResultPerLeaf().get(context.id()) == 0
&& isMissingNativeEngineFiles(context)) {
sb.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available");
}
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null
&& isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) {
sb.append(KNNConstants.EXACT_SEARCH)
.append(" since the number of documents returned are less than K = ")
.append(knnQuery.getK())
.append(" and there are more than K filtered Ids = ")
.append(cardinality);
}
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null
&& knnExplanation.getAnnResultPerLeaf().get(context.id()) > 0
&& !isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) {
sb.append(KNNConstants.ANN_SEARCH);
}
sb.append(" with vectorDataType = ").append(knnQuery.getVectorDataType());
return sb;
}

private SpaceType getSpaceType(FieldInfo fieldInfo) {
try {
return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
} catch (IllegalArgumentException e) {
return knnQuery.getVectorDataType() == VectorDataType.BINARY ? SpaceType.DEFAULT_BINARY : SpaceType.DEFAULT;
}
}

private String getHighLevelExplanation() {
StringBuilder sb = new StringBuilder("the type of knn search executed was ");
if (knnQuery.getRescoreContext() != null) {
sb.append(buildDiskBasedSearchExplanation());
} else if (knnQuery.getRadius() != null) {
sb.append(KNNConstants.RADIAL_SEARCH).append(" with the radius of ").append(knnQuery.getRadius());
} else {
sb.append(KNNConstants.ANN_SEARCH);
}
return sb.toString();
}

private String buildDiskBasedSearchExplanation() {
StringBuilder sb = new StringBuilder(KNNConstants.DISK_BASED_SEARCH);
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension);
sb.append(" and the first pass k was ")
.append(firstPassK)
.append(" with vector dimension of ")
.append(dimension)
.append(", over sampling factor of ")
.append(knnQuery.getRescoreContext().getOversampleFactor());
if (isShardLevelRescoringDisabled) {
sb.append(", shard level rescoring disabled");
} else {
sb.append(", shard level rescoring enabled");
}
return sb.toString();
}

private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException {
if (existingScorer != null) {
return existingScorer;
}

KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context);
if (cachedScorer != null) {
return cachedScorer;
}

KNNScorer newScorer = (KNNScorer) scorer(context);
knnExplanation.getKnnScorerPerLeaf().put(context, newScorer);
return newScorer;
}

private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException {
return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0;
}

@Override
Expand Down Expand Up @@ -137,6 +290,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
if (filterWeight != null && cardinality == 0) {
return PerLeafResult.EMPTY_RESULT;
}
if (knnQuery.isExplain()) {
knnExplanation.setCardinality(cardinality);
}
/*
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
Expand All @@ -153,7 +309,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

if (knnQuery.isExplain()) {
knnExplanation.getAnnResultPerLeaf().put(context.id(), docIdsToScoreMap.size());
}
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
Expand Down Expand Up @@ -383,6 +541,15 @@ private Map<Integer, Float> doANNSearch(
log.debug("[KNN] Query yielded 0 results");
return Collections.emptyMap();
}
if (knnQuery.isExplain()) {
Arrays.stream(results).forEach(result -> {
if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) {
knnExplanation.getRawScores().put(result.getId(), -1 * result.getScore());
} else {
knnExplanation.getRawScores().put(result.getId(), result.getScore());
}
});
}

if (quantizedVector != null) {
return Arrays.stream(results)
Expand Down Expand Up @@ -425,24 +592,33 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
);
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
return true;
}
if (isFilterIdCountLessThanK(filterIdsCount)) return true;
// See user has defined Exact Search filtered threshold. if yes, then use that setting.
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
return filterThresholdValue >= filterIdsCount;
if (filterThresholdValue >= filterIdsCount) {
return true;
}
return false;
}

// if no setting is set, then use the default max distance computation value to see if we can do exact search.
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return isMDCGreaterThanFilterIdCnt(filterIdsCount);
}

private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) {
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
}

private boolean isFilterIdCountLessThanK(int filterIdsCount) {
return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK();
}

/**
* This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This
* is done by validating if the setting value is equal to the default value.
Expand Down
Loading
Loading