-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explain API changes #2403
base: main
Are you sure you want to change the base?
Explain API changes #2403
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,11 @@ public float scoreTranslation(final float rawScore) { | |
throw new IllegalStateException("Unsupported method"); | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
throw new IllegalStateException("Unsupported method"); | ||
} | ||
|
||
@Override | ||
public void validateVectorDataType(VectorDataType vectorDataType) { | ||
throw new IllegalStateException("Unsupported method"); | ||
|
@@ -46,6 +51,11 @@ public float scoreTranslation(float rawScore) { | |
return 1 / (1 + rawScore); | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
return "`1 / (1 + rawScore)`"; | ||
} | ||
|
||
@Override | ||
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { | ||
return KNNVectorSimilarityFunction.EUCLIDEAN; | ||
|
@@ -77,6 +87,11 @@ public float scoreTranslation(float rawScore) { | |
return Math.max((2.0F - rawScore) / 2.0F, 0.0F); | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
return "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`"; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
} | ||
|
||
@Override | ||
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { | ||
return KNNVectorSimilarityFunction.COSINE; | ||
|
@@ -105,12 +120,22 @@ public void validateVector(float[] vector) { | |
public float scoreTranslation(float rawScore) { | ||
return 1 / (1 + rawScore); | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
return "`1 / (1 + rawScore)`"; | ||
} | ||
}, | ||
LINF("linf") {
    /**
     * Translates a raw score as the reciprocal of {@code 1 + rawScore}.
     *
     * @param rawScore raw score to translate
     * @return {@code 1 / (1 + rawScore)}
     */
    @Override
    public float scoreTranslation(float rawScore) {
        final float denominator = 1 + rawScore;
        return 1 / denominator;
    }

    /**
     * Returns a textual form of the formula applied by {@link #scoreTranslation(float)}.
     *
     * @param rawScore raw score; the returned text does not depend on it
     * @return the translation formula, wrapped in backticks
     */
    @Override
    public String explainScoreTranslation(float rawScore) {
        final String formula = "`1 / (1 + rawScore)`";
        return formula;
    }
},
INNER_PRODUCT("innerproduct") { | ||
/** | ||
|
@@ -129,6 +154,14 @@ public float scoreTranslation(float rawScore) { | |
return -rawScore + 1; | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
if (rawScore >= 0) { | ||
return "`1 / (1 + rawScore)`"; | ||
} | ||
return "`-rawScore + 1`"; | ||
} | ||
|
||
@Override | ||
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { | ||
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; | ||
|
@@ -140,6 +173,11 @@ public float scoreTranslation(float rawScore) { | |
return 1 / (1 + rawScore); | ||
} | ||
|
||
@Override | ||
public String explainScoreTranslation(float rawScore) { | ||
return "`1 / (1 + rawScore)`"; | ||
} | ||
|
||
@Override | ||
public void validateVectorDataType(VectorDataType vectorDataType) { | ||
if (VectorDataType.BINARY != vectorDataType) { | ||
|
@@ -177,6 +215,8 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { | |
|
||
/**
 * Translates a raw score into the score used for this space type.
 * Each space type supplies its own formula; a space type may also reject
 * the call by throwing {@link IllegalStateException}.
 *
 * @param rawScore raw score to translate
 * @return translated score
 */
public abstract float scoreTranslation(float rawScore); | ||
|
||
/**
 * Returns a textual description of the formula that
 * {@link #scoreTranslation(float)} applies for this space type, for use in
 * explain output. Implementations may pick the text based on the raw score
 * when the formula depends on it, or throw {@link IllegalStateException}
 * when translation is unsupported.
 *
 * @param rawScore raw score the explanation refers to
 * @return formula text, wrapped in backticks
 */
public abstract String explainScoreTranslation(float rawScore); | ||
|
||
/** | ||
* Get KNNVectorSimilarityFunction that maps to this SpaceType | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
import org.opensearch.knn.index.engine.KNNEngine; | ||
import org.opensearch.knn.index.quantizationservice.QuantizationService; | ||
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; | ||
import org.opensearch.knn.index.query.explain.KnnExplanation; | ||
import org.opensearch.knn.indices.ModelDao; | ||
import org.opensearch.knn.indices.ModelMetadata; | ||
import org.opensearch.knn.indices.ModelUtil; | ||
|
@@ -73,6 +74,7 @@ public class KNNWeight extends Weight { | |
|
||
private static ExactSearcher DEFAULT_EXACT_SEARCHER; | ||
private final QuantizationService quantizationService; | ||
private final KnnExplanation knnExplanation; | ||
|
||
public KNNWeight(KNNQuery query, float boost) { | ||
super(query); | ||
|
@@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) { | |
this.filterWeight = null; | ||
this.exactSearcher = DEFAULT_EXACT_SEARCHER; | ||
this.quantizationService = QuantizationService.getInstance(); | ||
this.knnExplanation = new KnnExplanation(); | ||
} | ||
|
||
public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { | ||
|
@@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { | |
this.filterWeight = filterWeight; | ||
this.exactSearcher = DEFAULT_EXACT_SEARCHER; | ||
this.quantizationService = QuantizationService.getInstance(); | ||
this.knnExplanation = new KnnExplanation(); | ||
} | ||
|
||
public static void initialize(ModelDao modelDao) { | ||
|
@@ -105,8 +109,157 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { | |
} | ||
|
||
@Override | ||
// This method is called in case of Radial-Search | ||
public Explanation explain(LeafReaderContext context, int doc) { | ||
neetikasinghal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Explanation.match(1.0f, "No Explanation"); | ||
return explain(context, doc, 0, null); | ||
} | ||
|
||
// This method is called for ANN/Exact/Disk-based/Efficient-filtering search | ||
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) { | ||
knnQuery.setExplain(true); | ||
try { | ||
knnScorer = getOrCreateKnnScorer(context, knnScorer); | ||
float knnScore = getKnnScore(knnScorer, doc); | ||
|
||
if (score == 0) { | ||
score = knnScore; | ||
} | ||
Comment on lines
+124
to
+126
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this be the first check we do when we enter the if condition? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not really; knnScore has to be computed first so that it can be assigned to score. |
||
assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score"; | ||
} catch (IOException e) { | ||
throw new RuntimeException("Error while explaining KNN score", e); | ||
} | ||
|
||
final String highLevelExplanation = getHighLevelExplanation(); | ||
final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context); | ||
|
||
final SegmentReader reader = Lucene.segmentReader(context.reader()); | ||
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); | ||
if (fieldInfo == null) { | ||
return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); | ||
} | ||
final SpaceType spaceType = getSpaceType(fieldInfo); | ||
leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue()); | ||
|
||
final Float rawScore = knnExplanation.getRawScores().get(doc); | ||
Explanation rawScoreDetail = null; | ||
if (rawScore != null && knnQuery.getRescoreContext() == null) { | ||
leafLevelExplanation.append(" where score is computed as ") | ||
.append(spaceType.explainScoreTranslation(rawScore)) | ||
.append(" from:"); | ||
rawScoreDetail = Explanation.match( | ||
rawScore, | ||
"rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library" | ||
); | ||
} | ||
|
||
return rawScoreDetail != null | ||
? Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail)) | ||
: Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); | ||
} | ||
|
||
private StringBuilder getLeafLevelExplanation(LeafReaderContext context) { | ||
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); | ||
int cardinality = knnExplanation.getCardinality(); | ||
StringBuilder sb = new StringBuilder("the type of knn search executed at leaf was "); | ||
if (filterWeight != null) { | ||
if (isFilterIdCountLessThanK(cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since filteredIds = ") | ||
.append(cardinality) | ||
.append(" is less than or equal to K = ") | ||
.append(knnQuery.getK()); | ||
} else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since filtered threshold value = ") | ||
.append(filterThresholdValue) | ||
.append(" is greater than or equal to cardinality = ") | ||
.append(cardinality); | ||
} else if (!isExactSearchThresholdSettingSet(filterThresholdValue) && isMDCGreaterThanFilterIdCnt(cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since max distance computation = ") | ||
.append(KNNConstants.MAX_DISTANCE_COMPUTATIONS) | ||
.append(" is greater than or equal to cardinality = ") | ||
.append(cardinality); | ||
} | ||
} | ||
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null | ||
&& knnExplanation.getAnnResultPerLeaf().get(context.id()) == 0 | ||
&& isMissingNativeEngineFiles(context)) { | ||
sb.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available"); | ||
} | ||
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null | ||
&& isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since the number of documents returned are less than K = ") | ||
.append(knnQuery.getK()) | ||
.append(" and there are more than K filtered Ids = ") | ||
.append(cardinality); | ||
} | ||
if (knnExplanation.getAnnResultPerLeaf().get(context.id()) != null | ||
&& knnExplanation.getAnnResultPerLeaf().get(context.id()) > 0 | ||
&& !isFilteredExactSearchRequireAfterANNSearch(cardinality, knnExplanation.getAnnResultPerLeaf().get(context.id()))) { | ||
sb.append(KNNConstants.ANN_SEARCH); | ||
} | ||
sb.append(" with vectorDataType = ").append(knnQuery.getVectorDataType()); | ||
return sb; | ||
} | ||
|
||
private SpaceType getSpaceType(FieldInfo fieldInfo) { | ||
try { | ||
return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); | ||
} catch (IllegalArgumentException e) { | ||
return knnQuery.getVectorDataType() == VectorDataType.BINARY ? SpaceType.DEFAULT_BINARY : SpaceType.DEFAULT; | ||
} | ||
} | ||
|
||
private String getHighLevelExplanation() { | ||
StringBuilder sb = new StringBuilder("the type of knn search executed was "); | ||
if (knnQuery.getRescoreContext() != null) { | ||
sb.append(buildDiskBasedSearchExplanation()); | ||
} else if (knnQuery.getRadius() != null) { | ||
sb.append(KNNConstants.RADIAL_SEARCH).append(" with the radius of ").append(knnQuery.getRadius()); | ||
} else { | ||
sb.append(KNNConstants.ANN_SEARCH); | ||
} | ||
return sb.toString(); | ||
} | ||
|
||
private String buildDiskBasedSearchExplanation() { | ||
StringBuilder sb = new StringBuilder(KNNConstants.DISK_BASED_SEARCH); | ||
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); | ||
int dimension = knnQuery.getQueryVector().length; | ||
int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension); | ||
sb.append(" and the first pass k was ") | ||
.append(firstPassK) | ||
.append(" with vector dimension of ") | ||
.append(dimension) | ||
.append(", over sampling factor of ") | ||
.append(knnQuery.getRescoreContext().getOversampleFactor()); | ||
if (isShardLevelRescoringDisabled) { | ||
sb.append(", shard level rescoring disabled"); | ||
} else { | ||
sb.append(", shard level rescoring enabled"); | ||
} | ||
return sb.toString(); | ||
} | ||
|
||
private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException { | ||
if (existingScorer != null) { | ||
return existingScorer; | ||
} | ||
|
||
KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context); | ||
if (cachedScorer != null) { | ||
return cachedScorer; | ||
} | ||
|
||
KNNScorer newScorer = (KNNScorer) scorer(context); | ||
knnExplanation.getKnnScorerPerLeaf().put(context, newScorer); | ||
return newScorer; | ||
} | ||
|
||
private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException { | ||
return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0; | ||
} | ||
|
||
@Override | ||
|
@@ -137,6 +290,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep | |
if (filterWeight != null && cardinality == 0) { | ||
return PerLeafResult.EMPTY_RESULT; | ||
} | ||
if (knnQuery.isExplain()) { | ||
knnExplanation.setCardinality(cardinality); | ||
} | ||
/* | ||
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph | ||
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search. | ||
|
@@ -153,7 +309,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep | |
*/ | ||
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; | ||
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); | ||
|
||
if (knnQuery.isExplain()) { | ||
knnExplanation.getAnnResultPerLeaf().put(context.id(), docIdsToScoreMap.size()); | ||
} | ||
// See whether we have to perform exact search based on approx search results | ||
// This is required if there are no native engine files or if approximate search returned | ||
// results less than K, though we have more than k filtered docs | ||
|
@@ -383,6 +541,15 @@ private Map<Integer, Float> doANNSearch( | |
log.debug("[KNN] Query yielded 0 results"); | ||
return Collections.emptyMap(); | ||
} | ||
if (knnQuery.isExplain()) { | ||
Arrays.stream(results).forEach(result -> { | ||
if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) { | ||
knnExplanation.getRawScores().put(result.getId(), -1 * result.getScore()); | ||
} else { | ||
knnExplanation.getRawScores().put(result.getId(), result.getScore()); | ||
} | ||
}); | ||
} | ||
|
||
if (quantizedVector != null) { | ||
return Arrays.stream(results) | ||
|
@@ -425,24 +592,33 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { | |
); | ||
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); | ||
// Refer to this GitHub issue for more details on the logic: https://github.com/opensearch-project/k-NN/issues/1049 | ||
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { | ||
return true; | ||
} | ||
if (isFilterIdCountLessThanK(filterIdsCount)) return true; | ||
// Check whether the user has defined the exact search filtered threshold. If so, use that setting. | ||
if (isExactSearchThresholdSettingSet(filterThresholdValue)) { | ||
return filterThresholdValue >= filterIdsCount; | ||
if (filterThresholdValue >= filterIdsCount) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
// if no setting is set, then use the default max distance computation value to see if we can do exact search. | ||
/** | ||
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index | ||
* is cheaper than computation cost for non binary vector | ||
*/ | ||
return isMDCGreaterThanFilterIdCnt(filterIdsCount); | ||
} | ||
|
||
private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) { | ||
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT | ||
? knnQuery.getQueryVector().length | ||
: knnQuery.getByteQueryVector().length); | ||
} | ||
|
||
private boolean isFilterIdCountLessThanK(int filterIdsCount) { | ||
return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK(); | ||
} | ||
|
||
/** | ||
* This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This | ||
* is done by validating if the setting value is equal to the default value. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make this a private static final String?