From e60603d979a4b9c07046e76a8d013e0cf4e80c61 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 27 Jan 2025 19:43:07 -0600 Subject: [PATCH] Fix Faiss byte vector efficient filter bug (#2448) Signed-off-by: Naveen Tatikonda (cherry picked from commit 02fdd70f1a10e9faa290c5a92d2c2dc5db3fe31e) --- .../opensearch/knn/index/query/KNNWeight.java | 21 +++- .../knn/index/VectorDataTypeIT.java | 98 +++++++++++++++++++ .../org/opensearch/knn/KNNRestTestCase.java | 4 +- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad..d5912d758 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -45,6 +45,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -438,9 +439,23 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { * TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index * is cheaper than computation cost for non binary vector */ - return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT - ? knnQuery.getQueryVector().length - : knnQuery.getByteQueryVector().length); + return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * getQueryVectorLength(); + } + + /** + * Returns the length of query vector based on the query vector data type + * @return length of query vector + */ + private int getQueryVectorLength() { + if (knnQuery.getVectorDataType() == VectorDataType.FLOAT || knnQuery.getVectorDataType() == VectorDataType.BYTE) { + return knnQuery.getQueryVector().length; + } + if (knnQuery.getVectorDataType() == VectorDataType.BINARY) { + return knnQuery.getByteQueryVector().length; + } + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] datatype is not supported for k-NN query vector", knnQuery.getVectorDataType().getValue()) + ); } /** diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 6e0f954a7..4a295a136 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index; +import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.After; @@ -17,6 +18,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; @@ -25,11 +27,13 @@ import org.opensearch.script.Script; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -62,6 +66,7 @@ public class VectorDataTypeIT extends KNNRestTestCase { private static final String KNN_VECTOR_TYPE = "knn_vector"; private static final int EF_CONSTRUCTION = 128; private static final int M = 16; + private static final String COLOR_FIELD_NAME = "color"; private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder(); @After @@ -666,6 +671,99 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() { deleteModel(modelId); } + @SneakyThrows + public void testQueryWithFilterFaissByteVector_withDifferentCombination_thenSuccess() { + setupKNNFaissByteIndexForFilterQuery(); + final Byte[] searchVector = { 6, 6, 4 }; + // K > filteredResults + int kGreaterThanFilterResult = 5; + List expectedDocIds = Arrays.asList("1", "3"); + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + convertByteToFloatArray(searchVector), + kGreaterThanFilterResult, + QueryBuilders.termQuery(COLOR_FIELD_NAME, "red") + ), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIds.size(), knnResults.size()); + assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds)); + + // K Limits Filter results + int kLimitsFilterResult = 1; + List expectedDocIdsKLimitsFilterResult = List.of("1"); + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + convertByteToFloatArray(searchVector), + kLimitsFilterResult, + QueryBuilders.termQuery(COLOR_FIELD_NAME, "red") + ), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + + // Empty filter docIds + int k = 10; + final Response emptyFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + convertByteToFloatArray(searchVector), + kLimitsFilterResult, + QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present") + ), + k + ); + final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity()); + final List emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME); + + assertEquals(0, emptyKNNFilteredResultsFromResponse.size()); + } + + protected void setupKNNFaissByteIndexForFilterQuery() throws Exception { + // Create Mappings + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", 3) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + final String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), mapping); + + addKnnDocWithAttributes(INDEX_NAME, "1", FIELD_NAME, new Byte[] { 6, 7, 3 }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(INDEX_NAME, "2", FIELD_NAME, new Byte[] { 3, 2, 4 }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(INDEX_NAME, "3", FIELD_NAME, new Byte[] { 4, 5, 7 }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshIndex(INDEX_NAME); + } + @SneakyThrows private void ingestL2ByteTestData() { Byte[] b1 = { 6, 6 }; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 7dd1ec237..f2ddf3b4b 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1907,11 +1907,11 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map void addKnnDocWithAttributes( String indexName, String docId, String vectorFieldName, - float[] vector, + T vector, Map fieldValues ) throws IOException { Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");