Skip to content

Commit

Permalink
Add IT for testing rescore enabled and disabled
Browse files Browse the repository at this point in the history
Signed-off-by: Ethan Emoto <[email protected]>
  • Loading branch information
e-emoto committed Jan 18, 2025
1 parent 7b2297e commit 9487cf4
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
List<PerLeafResult> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
final int finalK = knnQuery.getK();
if (rescoreContext == null || rescoreContext.isRescoreDisabled()) {
if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ public final class RescoreContext {
* Flag to track whether rescoring has been disabled by the query parameters.
*/
@Builder.Default
private boolean rescoreDisabled = false;
private boolean rescoreEnabled = true;

// Rescore context to be used when rescoring should be explicitly disabled
public static final RescoreContext EXPLICITLY_DISABLED_RESCORE_CONTEXT = RescoreContext.builder()
.oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR)
.rescoreDisabled(true)
.rescoreEnabled(false)
.build();

/**
Expand Down
119 changes: 119 additions & 0 deletions src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,125 @@ public void testIndexCreation_whenValid_ThenSucceed() {
}
}

@SneakyThrows
public void testQueryRescoreEnabledAndDisabled() {
XContentBuilder builder;
String mode = Mode.ON_DISK.getName();
String compressionLevel = CompressionLevel.x32.getName();
String indexName = INDEX_NAME + compressionLevel;
builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.field(MODE_PARAMETER, mode)
.field(COMPRESSION_LEVEL_PARAMETER, compressionLevel)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();
validateIndex(indexName, mapping);
logger.info("Compression level {}", compressionLevel);
// Do exact search and gather right scores for the documents
Response exactSearchResponse = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("script_score")
.startObject("query")
.field("match_all")
.startObject()
.endObject()
.endObject()
.startObject("script")
.field("source", "knn_score")
.field("lang", "knn")
.startObject("params")
.field("field", FIELD_NAME)
.field("query_value", TEST_VECTOR)
.field("space_type", SpaceType.L2.getValue())
.endObject()
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(exactSearchResponse);
String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity());
List<Float> exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME);
assertEquals(NUM_DOCS, exactSearchKnnResults.size());
// Search without rescore
Response response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.field(RescoreParser.RESCORE_PARAMETER, false)
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
String responseBody = EntityUtils.toString(response.getEntity());
List<Float> knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertNotEquals(exactSearchKnnResults, knnResults);
// Search with explicit rescore
response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.startObject(RescoreParser.RESCORE_PARAMETER)
.field(RescoreParser.RESCORE_OVERSAMPLE_PARAMETER, 2.0f)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);
// Search with default rescore
response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);
}

@SneakyThrows
public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() {
XContentBuilder builder;
Expand Down

0 comments on commit 9487cf4

Please sign in to comment.