diff --git a/CHANGELOG.md b/CHANGELOG.md index d79ba8d2..3c13e01d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Breaking Changes ### Features +* Added `expandCoverage` parameter to LLM judgment API for hybrid document pooling, improving judgment coverage for Hybrid Optimizer experiments ([#400](https://github.com/opensearch-project/search-relevance/pull/400)) * Introduced dynamic percentile-based relevance thresholding for binary-dependent metrics (Precision, MAP) to replace hard-coded `j > 0` mapping ([#394](https://github.com/opensearch-project/search-relevance/pull/394)) ### Enhancements diff --git a/build.gradle b/build.gradle index 474ca500..0849fe94 100644 --- a/build.gradle +++ b/build.gradle @@ -217,6 +217,9 @@ dependencies { api "org.opensearch:opensearch:${opensearch_version}" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly("org.opensearch:opensearch-neural-search:${opensearch_build}") { + transitive = false + } implementation group: 'com.google.guava', name: 'guava', version:'33.4.8-jre' compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.20.0' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.20.0' diff --git a/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java b/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java index d4a800ec..4fe88bf4 100644 --- a/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java +++ b/src/main/java/org/opensearch/searchrelevance/dao/JudgmentCacheDao.java @@ -12,13 +12,14 @@ import static org.opensearch.searchrelevance.model.JudgmentCache.DOCUMENT_ID; import static org.opensearch.searchrelevance.model.JudgmentCache.PROMPT_TEMPLATE_ID; import static 
org.opensearch.searchrelevance.model.JudgmentCache.QUERY_TEXT; +import static org.opensearch.searchrelevance.model.JudgmentCache.TIME_STAMP; import static org.opensearch.searchrelevance.utils.ParserUtils.convertListToSortedStr; import java.io.IOException; +import java.time.Instant; import java.util.List; +import java.util.concurrent.TimeUnit; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.StepListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.xcontent.XContentFactory; @@ -28,20 +29,33 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.indices.SearchRelevanceIndicesManager; import org.opensearch.searchrelevance.model.JudgmentCache; +import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class JudgmentCacheDao { - private static final Logger LOGGER = LogManager.getLogger(JudgmentCacheDao.class); private final SearchRelevanceIndicesManager searchRelevanceIndicesManager; + private volatile SearchRelevanceSettingsAccessor settingsAccessor; public JudgmentCacheDao(SearchRelevanceIndicesManager searchRelevanceIndicesManager) { this.searchRelevanceIndicesManager = searchRelevanceIndicesManager; } + /** + * Sets the settings accessor for reading cache TTL configuration. + * Called during plugin initialization after both DAO and settings accessor are created. 
+ */ + public void setSettingsAccessor(SearchRelevanceSettingsAccessor settingsAccessor) { + this.settingsAccessor = settingsAccessor; + } + /** * Create judgment cache index if not exists * @param stepListener - step listener for async operation @@ -89,14 +103,14 @@ public void upsertJudgmentCache(final JudgmentCache judgmentCache, final ActionL // Use updateDoc which will create or update the document searchRelevanceIndicesManager.updateDoc(judgmentCache.id(), content, JUDGMENT_CACHE, ActionListener.wrap(response -> { - LOGGER.debug( + log.debug( "Successfully upserted judgment cache for queryText: {} and documentId: {}", judgmentCache.queryText(), judgmentCache.documentId() ); listener.onResponse(response); }, e -> { - LOGGER.error( + log.error( "Failed to upsert judgment cache for queryText: {} and documentId: {}", judgmentCache.queryText(), judgmentCache.documentId(), @@ -111,6 +125,53 @@ public void upsertJudgmentCache(final JudgmentCache judgmentCache, final ActionL } } + /** + * Cleanup stale cache entries older than the specified TTL. + * This is a fire-and-forget operation — failures are logged but do not block callers. 
+ * @param ttlDays number of days after which cache entries are considered stale + */ + public void cleanupStaleEntries(final long ttlDays) { + long cutoffMillis = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(ttlDays); + String cutoffDate = Instant.ofEpochMilli(cutoffMillis).toString(); + + log.info("Starting judgment cache cleanup for entries older than {} days (before {})", ttlDays, cutoffDate); + + searchRelevanceIndicesManager.deleteByQuery( + QueryBuilders.rangeQuery(TIME_STAMP).lt(cutoffDate), + JUDGMENT_CACHE, + ActionListener.wrap((BulkByScrollResponse response) -> { + long deleted = response.getDeleted(); + if (deleted > 0) { + log.info("Judgment cache cleanup completed: deleted {} stale entries older than {} days", deleted, ttlDays); + } else { + log.debug("Judgment cache cleanup completed: no stale entries found older than {} days", ttlDays); + } + }, e -> log.warn("Judgment cache cleanup failed - continuing without cleanup", e)) + ); + } + + /** + * Cleanup stale cache entries based on the configured TTL setting. + * When TTL is disabled (-1, the default), this method is a no-op. + * This is a fire-and-forget operation — failures are logged but do not block callers. + */ + public void cleanupStaleEntries() { + if (settingsAccessor == null) { + log.debug("Settings accessor not set, skipping cache cleanup"); + return; + } + long ttlMillis = settingsAccessor.getJudgmentCacheTtl().millis(); + if (ttlMillis < 0) { + log.debug("Judgment cache TTL is disabled (-1), skipping cleanup"); + return; + } + long ttlDays = TimeUnit.MILLISECONDS.toDays(ttlMillis); + if (ttlDays < 1) { + ttlDays = 1; // minimum 1 day + } + cleanupStaleEntries(ttlDays); + } + /** * Get judgment cache by queryText and documentId * @param queryText - queryText to be searched @@ -129,7 +190,7 @@ public SearchResponse getJudgmentCache( SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); String contextFieldsStr = contextFields != null ? 
convertListToSortedStr(contextFields) : ""; - LOGGER.debug( + log.debug( "Building cache search query - queryText: '{}', documentId: '{}', contextFields: '{}', promptTemplateCode: '{}'", queryText, documentId, @@ -157,7 +218,7 @@ public SearchResponse getJudgmentCache( } listener.onResponse(response); }, e -> { - LOGGER.debug("Cache lookup failed for docId: {} - continuing without cache", documentId); + log.debug("Cache lookup failed for docId: {} - continuing without cache", documentId); listener.onFailure(e); }); diff --git a/src/main/java/org/opensearch/searchrelevance/experiment/QuerySourceUtil.java b/src/main/java/org/opensearch/searchrelevance/experiment/QuerySourceUtil.java index 4b066a80..deb8a363 100644 --- a/src/main/java/org/opensearch/searchrelevance/experiment/QuerySourceUtil.java +++ b/src/main/java/org/opensearch/searchrelevance/experiment/QuerySourceUtil.java @@ -13,21 +13,35 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.searchrelevance.model.ExperimentVariant; +import org.opensearch.searchrelevance.model.builder.SearchRequestBuilder; + +import lombok.extern.log4j.Log4j2; /** * Utility class for a query source */ +@Log4j2 public class QuerySourceUtil { public static final int NUMBER_OF_SUBQUERIES_IN_HYBRID_QUERY = 2; + public static final String POOL_NORMALIZATION = "min_max"; + public static final String POOL_COMBINATION = "arithmetic_mean"; + /** * Creates a definition of a temporary search pipeline for hybrid search. 
* @param experimentVariant sub-experiment to create the pipeline for @@ -61,21 +75,208 @@ public static Map createDefinitionOfTemporarySearchPipeline(fina } /** - * Validate that the query in the search configuration is a hybrid query with two sub-queries. - * @param fullQueryMap - * @throws IOException + * Checks if the query is a hybrid query with exactly {@link #NUMBER_OF_SUBQUERIES_IN_HYBRID_QUERY} sub-queries. + * Non-throwing variant of {@link #validateHybridQuery(Map)}. + * Used by the HYBRID_OPTIMIZER experiment which requires exactly 2 sub-queries (lexical + neural). + * For expandCoverage (which supports any sub-query count), use {@link #isHybridQueryAnySize(Map)} instead. + * @param fullQueryMap the parsed query body + * @return true if the query is a valid hybrid query with the required sub-query count + */ + public static boolean isHybridQuery(final Map fullQueryMap) { + try { + validateHybridQuery(fullQueryMap); + return true; + } catch (Exception e) { + log.debug("Query is not a valid hybrid query: {}", e.getMessage()); + return false; + } + } + + /** + * Checks if the query in the search configuration is a hybrid query with any number of sub-queries (≥ 1). + * @param fullQueryMap the parsed query body + * @return true if the query is a valid hybrid query with at least 1 sub-query + */ + public static boolean isHybridQueryAnySize(final Map fullQueryMap) { + try { + return getSubQueryCount(fullQueryMap) >= 1; + } catch (Exception e) { + log.debug("Query is not a valid hybrid query: {}", e.getMessage()); + return false; + } + } + + /** + * Extracts the number of sub-queries from a hybrid query. + * Uses typed parsing via {@link HybridQueryBuilder#fromXContent(XContentParser)} at runtime, + * falls back to map-based inspection in unit tests without registry. 
+ * + * @param fullQueryMap the parsed query body + * @return number of sub-queries in the hybrid query + * @throws IllegalArgumentException if the query is not a valid hybrid query + */ + public static int getSubQueryCount(final Map fullQueryMap) { + if (Objects.isNull(fullQueryMap) || !fullQueryMap.containsKey("query") || !(fullQueryMap.get("query") instanceof Map)) { + throw new IllegalArgumentException("search configuration must have at least one query"); + } + Map queryMap = (Map) fullQueryMap.get("query"); + if (!queryMap.containsKey(HybridQueryBuilder.NAME) || !(queryMap.get(HybridQueryBuilder.NAME) instanceof Map)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "query in search configuration must be of type [%s]", HybridQueryBuilder.NAME) + ); + } + + if (Objects.nonNull(SearchRequestBuilder.getNamedXContentRegistry())) { + try { + return getSubQueryCountViaTypedParsing(queryMap, SearchRequestBuilder.getNamedXContentRegistry()); + } catch (Exception e) { + log.debug("Typed hybrid query parsing failed, falling back to map-based parsing: {}", e.getMessage()); + } + } + return getSubQueryCountFromMap(queryMap); + } + + /** + * Parses the hybrid query using {@link HybridQueryBuilder#fromXContent(XContentParser)} and returns + * the sub-query count via the typed {@link HybridQueryBuilder#queries()} accessor. 
+ */ + private static int getSubQueryCountViaTypedParsing(final Map queryMap, final NamedXContentRegistry registry) { + Map hybridSection = (Map) queryMap.get(HybridQueryBuilder.NAME); + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.startObject(); + builder.field(HybridQueryBuilder.NAME); + builder.map(hybridSection); + builder.endObject(); + + try ( + XContentParser parser = JsonXContent.jsonXContent.createParser( + registry, + DeprecationHandler.IGNORE_DEPRECATIONS, + builder.toString() + ) + ) { + parser.nextToken(); + parser.nextToken(); + parser.nextToken(); + + HybridQueryBuilder hybridQueryBuilder = HybridQueryBuilder.fromXContent(parser); + int count = hybridQueryBuilder.queries().size(); + if (count < 1) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] query must have at least one sub-query", HybridQueryBuilder.NAME) + ); + } + return count; + } + } catch (IllegalArgumentException e) { + log.error("Invalid hybrid query structure: {}", e.getMessage(), e); + throw e; + } catch (IOException e) { + log.error("Failed to parse hybrid query: {}", e.getMessage(), e); + throw new IllegalArgumentException(String.format(Locale.ROOT, "failed to parse [%s] query", HybridQueryBuilder.NAME), e); + } + } + + /** + * Extracts sub-query count from the raw map structure. + * Used when {@link NamedXContentRegistry} is not available (e.g. in unit tests). 
*/ - public static void validateHybridQuery(Map fullQueryMap) throws IOException { + private static int getSubQueryCountFromMap(final Map queryMap) { + Map hybridMap = (Map) queryMap.get(HybridQueryBuilder.NAME); + if (!hybridMap.containsKey("queries") || !(hybridMap.get("queries") instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] query in search configuration does not have sub-queries", HybridQueryBuilder.NAME) + ); + } + List queries = (List) hybridMap.get("queries"); + if (queries.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] query must have at least one sub-query", HybridQueryBuilder.NAME) + ); + } + return queries.size(); + } + + /** + * Generates pooling weight configurations for expandCoverage. + * Produces N+1 configurations: 1 equal-weight + N one-hot configurations. + * + * @param numSubQueries number of sub-queries in the hybrid query (must be ≥ 1) + * @return unmodifiable list of unmodifiable weight configurations + */ + public static List> generatePoolingWeights(final int numSubQueries) { + if (numSubQueries < 1) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "numSubQueries must be at least 1, got [%d]", numSubQueries)); + } + + List> weightConfigs = new ArrayList<>(numSubQueries + 1); + + // Equal weights: null signals that the "parameters" key should be omitted from the pipeline, + // allowing the hybrid query processor to apply its default equal-weight behavior. + weightConfigs.add(null); + + for (int i = 0; i < numSubQueries; i++) { + List oneHotConfig = new ArrayList<>(numSubQueries); + for (int j = 0; j < numSubQueries; j++) { + oneHotConfig.add((i == j) ? 1.0f : 0.0f); + } + weightConfigs.add(Collections.unmodifiableList(oneHotConfig)); + } + + return Collections.unmodifiableList(weightConfigs); + } + + /** + * Creates a temporary search pipeline definition for hybrid search pooling with specified weights. 
+ * @param weights list of weights, one per sub-query + * @param normalization normalization technique name + * @param combination combination technique name + * @return definition of a temporary search pipeline + */ + public static Map createPoolingSearchPipeline( + final List weights, + final String normalization, + final String combination + ) { + Map normalizationConfig = new HashMap<>(Map.of("technique", normalization)); + Map combinationConfig = new HashMap<>(Map.of("technique", combination)); + // When weights is null, omit the "parameters" key entirely — this lets the hybrid query + // processor apply its default equal-weight behavior (1.0 per sub-query). + if (weights != null) { + List weightsList = new ArrayList<>(weights.size()); + for (Float w : weights) { + weightsList.add(w.doubleValue()); + } + combinationConfig.put("parameters", new HashMap<>(Map.of("weights", weightsList))); + } + + Map processorConfig = new HashMap<>(Map.of("normalization", normalizationConfig, "combination", combinationConfig)); + Map phaseProcessor = new HashMap<>(Map.of("normalization-processor", processorConfig)); + Map pipeline = new HashMap<>(); + pipeline.put("phase_results_processors", List.of(phaseProcessor)); + return pipeline; + } + + /** + * Validate that the query in the search configuration is a hybrid query with exactly two sub-queries. 
+ * @param fullQueryMap the parsed query body + * @throws IOException if the query cannot be parsed + */ + public static void validateHybridQuery(final Map fullQueryMap) throws IOException { if (fullQueryMap.containsKey("query") == false || fullQueryMap.get("query") instanceof Map == false) { throw new IllegalArgumentException("search configuration must have at least one query"); } Map queryMap = (Map) fullQueryMap.get("query"); - if (queryMap.containsKey("hybrid") == false || queryMap.get("hybrid") instanceof Map == false) { - throw new IllegalArgumentException("query in search configuration must be of type hybrid"); + if (queryMap.containsKey(HybridQueryBuilder.NAME) == false || queryMap.get(HybridQueryBuilder.NAME) instanceof Map == false) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "query in search configuration must be of type [%s]", HybridQueryBuilder.NAME) + ); } - Map hybridMap = (Map) queryMap.get("hybrid"); + Map hybridMap = (Map) queryMap.get(HybridQueryBuilder.NAME); if (hybridMap.containsKey("queries") == false || hybridMap.get("queries") instanceof List == false) { - throw new IllegalArgumentException("hybrid query in search configuration does not have sub-queries"); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "[%s] query in search configuration does not have sub-queries", HybridQueryBuilder.NAME) + ); } List queriesMap = (List) hybridMap.get("queries"); if (queriesMap.size() != NUMBER_OF_SUBQUERIES_IN_HYBRID_QUERY) { diff --git a/src/main/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManager.java b/src/main/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManager.java index 355f1850..41bc8493 100644 --- a/src/main/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManager.java +++ b/src/main/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManager.java @@ -573,6 +573,25 @@ public void deleteByQuery( 
client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, listener); } + /** + * Delete by query using a custom query builder + * @param query - query builder for matching documents to delete + * @param index - index on which delete operation has to be performed + * @param listener - action listener for async action + */ + public void deleteByQuery( + final org.opensearch.index.query.QueryBuilder query, + final SearchRelevanceIndices index, + final ActionListener listener + ) { + DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(index.getIndexName()); + deleteByQueryRequest.setConflicts(PROCEED); + deleteByQueryRequest.setBatchSize(BATCH_SIZE_FOR_DELETE_BY_QUERY); + deleteByQueryRequest.setQuery(query); + + client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, listener); + } + /** * Gets index mapping JSON content from the classpath * diff --git a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java index be99c792..f0f8d2e9 100644 --- a/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java +++ b/src/main/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessor.java @@ -10,6 +10,7 @@ import static org.opensearch.searchrelevance.common.MLConstants.LLM_JUDGMENT_RATING_TYPE; import static org.opensearch.searchrelevance.common.MLConstants.OVERWRITE_CACHE; import static org.opensearch.searchrelevance.common.MLConstants.PROMPT_TEMPLATE; +import static org.opensearch.searchrelevance.model.builder.SearchRequestBuilder.buildRequestForHybridSearch; import static org.opensearch.searchrelevance.model.builder.SearchRequestBuilder.buildSearchRequest; import static org.opensearch.searchrelevance.utils.ParserUtils.combinedIndexAndDocId; import static org.opensearch.searchrelevance.utils.ParserUtils.generatePromptTemplateCode; @@ -42,6 +43,7 @@ import 
org.opensearch.searchrelevance.dao.SearchConfigurationDao; import org.opensearch.searchrelevance.exception.SearchRelevanceException; import org.opensearch.searchrelevance.executors.LlmJudgmentTaskManager; +import org.opensearch.searchrelevance.experiment.QuerySourceUtil; import org.opensearch.searchrelevance.ml.ChunkResult; import org.opensearch.searchrelevance.ml.MLAccessor; import org.opensearch.searchrelevance.model.JudgmentCache; @@ -121,6 +123,7 @@ private void generateJudgmentRatingInternal(Map metadata, Action log.debug("No ratingType provided, defaulting to SCORE0_1"); } boolean overwriteCache = (boolean) metadata.get(OVERWRITE_CACHE); + boolean expandCoverage = Boolean.TRUE.equals(metadata.get("expandCoverage")); QuerySet querySet = querySetDao.getQuerySetSync(querySetId); List searchConfigurations = searchConfigurationList.stream() @@ -138,6 +141,7 @@ private void generateJudgmentRatingInternal(Map metadata, Action promptTemplate, ratingType, overwriteCache, + expandCoverage, listener ); } catch (Exception e) { @@ -157,6 +161,7 @@ private void generateLLMJudgmentsAsync( String promptTemplate, LLMJudgmentRatingType ratingType, boolean overwriteCache, + boolean expandCoverage, ActionListener>> listener ) { List queryTextsWithCustomInput = querySet.querySetQueries().stream().map(e -> e.queryText()).collect(Collectors.toList()); @@ -164,6 +169,13 @@ private void generateLLMJudgmentsAsync( log.info("Starting LLM judgment generation for {} total queries", totalQueries); + // Fire-and-forget cleanup of stale cache entries older than the configured TTL (no-op when TTL is disabled, the default) + try { + judgmentCacheDao.cleanupStaleEntries(); + } catch (Exception e) { + log.warn("Failed to trigger judgment cache cleanup - continuing without cleanup", e); + } + // Create judgment cache index upfront to prevent concurrent creation attempts StepListener cacheIndexListener = new StepListener<>(); judgmentCacheDao.createIndexIfAbsent(cacheIndexListener); @@ -182,7 +194,8 @@ private void generateLLMJudgmentsAsync( 
ignoreFailure, promptTemplate, ratingType, - overwriteCache + overwriteCache, + expandCoverage ); } catch (Exception e) { if (ignoreFailure) { @@ -229,7 +242,8 @@ private void generateLLMJudgmentsAsync( ignoreFailure, promptTemplate, ratingType, - overwriteCache + overwriteCache, + expandCoverage ); } catch (Exception e) { if (ignoreFailure) { @@ -274,7 +288,8 @@ private Map processQueryTextAsync( boolean ignoreFailure, String promptTemplate, LLMJudgmentRatingType ratingType, - boolean overwriteCache + boolean overwriteCache, + boolean expandCoverage ) { log.info("Processing query text judgment: {}", queryTextWithCustomInput); @@ -284,7 +299,7 @@ private Map processQueryTextAsync( try { // Step 1: Execute searches concurrently within this query text task - processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits, ignoreFailure); + processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits, ignoreFailure, expandCoverage); // Step 2: Deduplicate from cache (skip if overwriteCache is true) List docIds = new ArrayList<>(allHits.keySet()); @@ -333,33 +348,75 @@ private Map processQueryTextAsync( } } - private void processSearchConfigurationsAsync( + // Package-private for unit testing (expandCoverage pooling verification) + void processSearchConfigurationsAsync( List searchConfigurations, String queryText, int size, ConcurrentMap allHits, - boolean ignoreFailure + boolean ignoreFailure, + boolean expandCoverage ) throws Exception { - List> searchFutures = searchConfigurations.stream().map(config -> { - CompletableFuture future = new CompletableFuture<>(); - SearchRequest searchRequest = buildSearchRequest(config.index(), config.query(), queryText, config.searchPipeline(), size); - client.search(searchRequest, ActionListener.wrap(future::complete, future::completeExceptionally)); - - return future.thenAccept(response -> { - if (response.getHits().getTotalHits().value() > 0) { - for (SearchHit hit : response.getHits().getHits()) 
{ - allHits.put(hit.getId(), hit); - } - log.debug("Collected {} hits from index: {}", response.getHits().getHits().length, config.index()); + List> searchFutures = new ArrayList<>(); + + for (SearchConfiguration config : searchConfigurations) { + if (expandCoverage) { + // Validate hybrid query and dynamically generate pooling weight configurations + Map queryMap = OBJECT_MAPPER.readValue(config.query(), new TypeReference>() { + }); + if (!QuerySourceUtil.isHybridQueryAnySize(queryMap)) { + throw new IllegalArgumentException("expandCoverage requires a hybrid search query with at least 1 sub-query."); } - }).exceptionally(e -> { - log.warn("Search failed for index: {}, continuing with other searches", config.index(), e); - return null; // Continue processing other searches - }); - }).toList(); + int numSubQueries = QuerySourceUtil.getSubQueryCount(queryMap); + List> poolWeights = QuerySourceUtil.generatePoolingWeights(numSubQueries); + log.info( + "expandCoverage enabled: executing {} pooling searches for query: {} ({} sub-queries)", + poolWeights.size(), + queryText, + numSubQueries + ); + for (List weights : poolWeights) { + Map pipeline = QuerySourceUtil.createPoolingSearchPipeline( + weights, + QuerySourceUtil.POOL_NORMALIZATION, + QuerySourceUtil.POOL_COMBINATION + ); + SearchRequest searchRequest = buildRequestForHybridSearch(config.index(), config.query(), pipeline, queryText, size); + CompletableFuture future = new CompletableFuture<>(); + client.search(searchRequest, ActionListener.wrap(future::complete, future::completeExceptionally)); + searchFutures.add(future.thenAccept(response -> { + if (response.getHits().getTotalHits().value() > 0) { + for (SearchHit hit : response.getHits().getHits()) { + allHits.putIfAbsent(hit.getId(), hit); + } + log.debug("Pooling: collected {} hits with weights {}", response.getHits().getHits().length, weights); + } + }).exceptionally(e -> { + log.warn("Pooling search failed for weights {}, continuing", weights, e); + return 
null; + })); + } + } else { + // Existing behavior: single search with config's pipeline + CompletableFuture future = new CompletableFuture<>(); + SearchRequest searchRequest = buildSearchRequest(config.index(), config.query(), queryText, config.searchPipeline(), size); + client.search(searchRequest, ActionListener.wrap(future::complete, future::completeExceptionally)); + searchFutures.add(future.thenAccept(response -> { + if (response.getHits().getTotalHits().value() > 0) { + for (SearchHit hit : response.getHits().getHits()) { + allHits.put(hit.getId(), hit); + } + log.debug("Collected {} hits from index: {}", response.getHits().getHits().length, config.index()); + } + }).exceptionally(e -> { + log.warn("Search failed for index: {}, continuing with other searches", config.index(), e); + return null; + })); + } + } CompletableFuture.allOf(searchFutures.toArray(new CompletableFuture[0])).join(); - log.info("Search phase completed. Total hits collected: {}", allHits.size()); + log.info("Search phase completed. Total hits collected: {} (expandCoverage={})", allHits.size(), expandCoverage); } private List deduplicateFromCache( diff --git a/src/main/java/org/opensearch/searchrelevance/model/builder/SearchRequestBuilder.java b/src/main/java/org/opensearch/searchrelevance/model/builder/SearchRequestBuilder.java index fc5458c1..5bd85b76 100644 --- a/src/main/java/org/opensearch/searchrelevance/model/builder/SearchRequestBuilder.java +++ b/src/main/java/org/opensearch/searchrelevance/model/builder/SearchRequestBuilder.java @@ -48,6 +48,14 @@ public static void initialize(NamedXContentRegistry registry) { log.debug("SearchRequestBuilder initialized with NamedXContentRegistry"); } + /** + * Returns the NamedXContentRegistry used for parsing query types. 
+ * @return the initialized registry, or null if not yet initialized + */ + public static NamedXContentRegistry getNamedXContentRegistry() { + return NAMED_XCONTENT_REGISTRY; + } + private static XContentParser newParserWithRegistry(String json) throws IOException { if (NAMED_XCONTENT_REGISTRY == null) { throw new IllegalStateException( diff --git a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java index 569fe1f9..07eb7d42 100644 --- a/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java +++ b/src/main/java/org/opensearch/searchrelevance/plugin/SearchRelevancePlugin.java @@ -10,6 +10,7 @@ import static org.opensearch.searchrelevance.common.PluginConstants.EXPERIMENT_INDEX; import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENT_CACHE_INDEX; import static org.opensearch.searchrelevance.common.PluginConstants.SCHEDULED_JOBS_INDEX; +import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_QUERY_SET_MAX_LIMIT; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_ENABLED; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL; @@ -222,6 +223,7 @@ public Collection createComponents( ); this.metricsHelper = new MetricsHelper(clusterService, client, judgmentDao, evaluationResultDao, experimentVariantDao); this.settingsAccessor = new SearchRelevanceSettingsAccessor(clusterService, environment.settings()); + this.judgmentCacheDao.setSettingsAccessor(this.settingsAccessor); this.clusterUtil = new ClusterUtil(clusterService); this.cronUtil = new CronUtil(settingsAccessor); this.infoStatsManager = new 
InfoStatsManager(settingsAccessor); @@ -405,7 +407,8 @@ public List> getSettings() { SEARCH_RELEVANCE_QUERY_SET_MAX_LIMIT, SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_ENABLED, SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_TIMEOUT, - SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL + SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL, + SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL ); } diff --git a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java index 1676d677..7009d7c9 100644 --- a/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java +++ b/src/main/java/org/opensearch/searchrelevance/rest/RestPutJudgmentAction.java @@ -37,8 +37,6 @@ import java.util.Map; import java.util.Optional; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.index.IndexResponse; import org.opensearch.core.action.ActionListener; @@ -64,13 +62,14 @@ import org.opensearch.transport.client.node.NodeClient; import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; /** * Rest Action to facilitate requests to create a judgment. 
*/ @AllArgsConstructor +@Log4j2 public class RestPutJudgmentAction extends BaseRestHandler { - private static final Logger LOGGER = LogManager.getLogger(RestPutJudgmentAction.class); private static final String PUT_JUDGMENT_ACTION = "put_judgment_action"; private SearchRelevanceSettingsAccessor settingsAccessor; @@ -166,6 +165,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } boolean overwriteCache = Optional.ofNullable((Boolean) source.get(OVERWRITE_CACHE)).orElse(Boolean.FALSE); + boolean expandCoverage = Optional.ofNullable((Boolean) source.get("expandCoverage")).orElse(Boolean.FALSE); createRequest = new PutLlmJudgmentRequest( type, @@ -180,7 +180,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ignoreFailure, promptTemplate, llmJudgmentRatingType, - overwriteCache + overwriteCache, + expandCoverage ); } case UBI_JUDGMENT -> { @@ -237,7 +238,7 @@ public void onFailure(Exception e) { try { channel.sendResponse(new BytesRestResponse(channel, ExceptionsHelper.status(e), e)); } catch (IOException ex) { - LOGGER.error("Failed to send error response", ex); + log.error("Failed to send error response", ex); } } }); diff --git a/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettings.java b/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettings.java index 6ce1bc30..587ee44b 100644 --- a/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettings.java +++ b/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettings.java @@ -85,6 +85,19 @@ public class SearchRelevanceSettings { * times the jobs are run will be greater than or equal to the minimum * interval defined here. */ + /** + * Judgment cache TTL. When set to a positive value, cache entries older than this TTL + * are deleted lazily at the start of each judgment generation. + * Default: -1 (disabled — infinite retention, no data deletion). 
+ */ + public static final Setting SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL = Setting.timeSetting( + "plugins.search_relevance.judgment_cache.ttl", + TimeValue.MINUS_ONE, + TimeValue.MINUS_ONE, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL = Setting.positiveTimeSetting( "plugins.search_relevance.scheduled_experiments_minimum_interval", TimeValue.timeValueSeconds(1), diff --git a/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettingsAccessor.java b/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettingsAccessor.java index 930b6603..49f8ac7c 100644 --- a/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettingsAccessor.java +++ b/src/main/java/org/opensearch/searchrelevance/settings/SearchRelevanceSettingsAccessor.java @@ -29,6 +29,8 @@ public class SearchRelevanceSettingsAccessor { private volatile boolean isScheduledExperimentsEnabled; @Getter private volatile TimeValue scheduledExperimentsTimeout; + + private volatile TimeValue judgmentCacheTtl; @Getter private volatile TimeValue scheduledExperimentsMinimumInterval; @@ -45,6 +47,7 @@ public SearchRelevanceSettingsAccessor(ClusterService clusterService, Settings s isScheduledExperimentsEnabled = SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_ENABLED.get(settings); scheduledExperimentsTimeout = SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_TIMEOUT.get(settings); scheduledExperimentsMinimumInterval = SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL.get(settings); + judgmentCacheTtl = SearchRelevanceSettings.SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(settings); registerSettingsCallbacks(clusterService); } @@ -80,5 +83,14 @@ private void registerSettingsCallbacks(ClusterService clusterService) { 
.addSettingsUpdateConsumer(SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL, value -> { scheduledExperimentsMinimumInterval = value; }); + + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(SearchRelevanceSettings.SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL, value -> { + judgmentCacheTtl = value; + }); + } + + public TimeValue getJudgmentCacheTtl() { + return judgmentCacheTtl; } } diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java index 2a0ea226..9b6225be 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutJudgmentTransportAction.java @@ -19,8 +19,6 @@ import java.util.Map; import java.util.UUID; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -38,13 +36,14 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class PutJudgmentTransportAction extends HandledTransportAction { private final ClusterService clusterService; private final JudgmentDao judgmentDao; private final JudgmentsProcessorFactory judgmentsProcessorFactory; - private static final Logger LOGGER = LogManager.getLogger(PutJudgmentTransportAction.class); - @Inject public PutJudgmentTransportAction( ClusterService clusterService, @@ -84,12 +83,12 @@ protected void doExecute(Task task, PutJudgmentRequest request, ActionListener { - LOGGER.error("Failed to create initial judgment", e); + log.error("Failed to create initial judgment", e); listener.onFailure(new SearchRelevanceException("Failed to create 
initial judgment", e, RestStatus.INTERNAL_SERVER_ERROR)); })); } catch (Exception e) { - LOGGER.error("Failed to process judgment request", e); + log.error("Failed to process judgment request", e); listener.onFailure(new SearchRelevanceException("Failed to process judgment request", e, RestStatus.INTERNAL_SERVER_ERROR)); } } @@ -109,6 +108,7 @@ private Map buildMetadata(PutJudgmentRequest request) { metadata.put(PROMPT_TEMPLATE, llmRequest.getPromptTemplate()); metadata.put(LLM_JUDGMENT_RATING_TYPE, llmRequest.getLlmJudgmentRatingType()); metadata.put(OVERWRITE_CACHE, llmRequest.isOverwriteCache()); + metadata.put("expandCoverage", llmRequest.isExpandCoverage()); } case UBI_JUDGMENT -> { PutUbiJudgmentRequest ubiRequest = (PutUbiJudgmentRequest) request; @@ -130,11 +130,11 @@ private Map buildMetadata(PutJudgmentRequest request) { } private void triggerAsyncProcessing(String judgmentId, PutJudgmentRequest request, Map metadata) { - LOGGER.info("Starting async processing for judgment: {}, type: {}, metadata: {}", judgmentId, request.getType(), metadata); + log.info("Starting async processing for judgment: {}, type: {}, metadata: {}", judgmentId, request.getType(), metadata); BaseJudgmentsProcessor processor = judgmentsProcessorFactory.getProcessor(request.getType()); processor.generateJudgmentRating(metadata, ActionListener.wrap(judgmentRatings -> { - LOGGER.info( + log.info( "Generated judgment ratings for {}, ratings size: {}", judgmentId, judgmentRatings != null ? 
judgmentRatings.size() : 0 @@ -162,14 +162,14 @@ private void updateFinalJudgment( judgmentDao.updateJudgment( finalJudgment, ActionListener.wrap( - response -> LOGGER.debug("Updated final judgment: {}", judgmentId), + response -> log.debug("Updated final judgment: {}", judgmentId), error -> handleAsyncFailure(judgmentId, request, "Failed to update final judgment", error) ) ); } private void handleAsyncFailure(String judgmentId, PutJudgmentRequest request, String message, Exception error) { - LOGGER.error(message + " for judgment: " + judgmentId, error); + log.error(message + " for judgment: " + judgmentId, error); Judgment errorJudgment = new Judgment( judgmentId, @@ -184,8 +184,8 @@ private void handleAsyncFailure(String judgmentId, PutJudgmentRequest request, S judgmentDao.updateJudgment( errorJudgment, ActionListener.wrap( - response -> LOGGER.info("Updated judgment {} status to ERROR", judgmentId), - e -> LOGGER.error("Failed to update error status for judgment: " + judgmentId, e) + response -> log.info("Updated judgment {} status to ERROR", judgmentId), + e -> log.error("Failed to update error status for judgment: " + judgmentId, e) ) ); } diff --git a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java index 24328e9b..7f226d69 100644 --- a/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java +++ b/src/main/java/org/opensearch/searchrelevance/transport/judgment/PutLlmJudgmentRequest.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.util.List; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.searchrelevance.model.JudgmentType; @@ -57,6 +58,13 @@ public class PutLlmJudgmentRequest extends PutJudgmentRequest { */ private boolean overwriteCache; + /** + * Flag to 
enable expanded document coverage for hybrid search optimization. + * When true, pools documents from 3 hybrid weight configurations [0.0, 0.5, 1.0] + * instead of executing a single search. Requires a hybrid search configuration. + */ + private boolean expandCoverage; + public PutLlmJudgmentRequest( @NonNull JudgmentType type, @NonNull String name, @@ -70,7 +78,8 @@ public PutLlmJudgmentRequest( boolean ignoreFailure, String promptTemplate, LLMJudgmentRatingType llmJudgmentRatingType, - boolean overwriteCache + boolean overwriteCache, + boolean expandCoverage ) { super(type, name, description); this.modelId = modelId; @@ -83,6 +92,7 @@ public PutLlmJudgmentRequest( this.promptTemplate = promptTemplate; this.llmJudgmentRatingType = llmJudgmentRatingType; this.overwriteCache = overwriteCache; + this.expandCoverage = expandCoverage; } public PutLlmJudgmentRequest(StreamInput in) throws IOException { @@ -97,6 +107,11 @@ public PutLlmJudgmentRequest(StreamInput in) throws IOException { this.promptTemplate = in.readOptionalString(); this.llmJudgmentRatingType = in.readOptionalWriteable(LLMJudgmentRatingType::readFromStream); this.overwriteCache = Boolean.TRUE.equals(in.readOptionalBoolean()); + if (in.getVersion().onOrAfter(Version.V_3_6_0)) { + this.expandCoverage = Boolean.TRUE.equals(in.readOptionalBoolean()); + } else { + this.expandCoverage = false; + } } @Override @@ -112,6 +127,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(promptTemplate); out.writeOptionalWriteable(llmJudgmentRatingType); out.writeOptionalBoolean(overwriteCache); + if (out.getVersion().onOrAfter(Version.V_3_6_0)) { + out.writeOptionalBoolean(expandCoverage); + } } public String getModelId() { @@ -154,4 +172,8 @@ public boolean isOverwriteCache() { return overwriteCache; } + public boolean isExpandCoverage() { + return expandCoverage; + } + } diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentExpandCoverageIT.java 
b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentExpandCoverageIT.java new file mode 100644 index 00000000..d7cdc4e1 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/LlmJudgmentExpandCoverageIT.java @@ -0,0 +1,235 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.action.judgment; + +import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENTS_URL; +import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENT_INDEX; +import static org.opensearch.searchrelevance.common.PluginConstants.QUERYSETS_URL; +import static org.opensearch.searchrelevance.common.PluginConstants.SEARCH_CONFIGURATIONS_URL; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.Response; +import org.opensearch.rest.RestRequest; +import org.opensearch.searchrelevance.BaseSearchRelevanceIT; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import com.google.common.collect.ImmutableList; + +import lombok.SneakyThrows; + +/** + * Integration tests for LLM Judgment expandCoverage functionality. + * Tests the expandCoverage flag with hybrid and non-hybrid search configurations. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE) +public class LlmJudgmentExpandCoverageIT extends BaseSearchRelevanceIT { + + private static final String TEST_INDEX = "test_expand_coverage_products"; + + /** + * Helper: creates test index, ingests documents, creates query set and returns querySetId. 
+ */ + @SneakyThrows + private String setupTestIndexAndQuerySet() { + // Create test index + String indexConfig = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateTestIndex.json").toURI())); + createIndexWithConfiguration(TEST_INDEX, indexConfig); + + // Bulk ingest test documents + String bulkData = Files.readString(Path.of(classLoader.getResource("llmjudgment/BulkIngestProducts.json").toURI())); + bulkIngest(TEST_INDEX, bulkData); + + // Create query set + String querySetBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateQuerySetSimple.json").toURI())); + Response querySetResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + QUERYSETS_URL, + null, + toHttpEntity(querySetBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map querySetResult = entityAsMap(querySetResponse); + return querySetResult.get("query_set_id").toString(); + } + + /** + * Helper: creates a search configuration and returns searchConfigId. 
+ */ + @SneakyThrows + private String createSearchConfig(String resourcePath) { + String searchConfigBody = Files.readString(Path.of(classLoader.getResource(resourcePath).toURI())); + searchConfigBody = replacePlaceholders(searchConfigBody, Map.of("index", TEST_INDEX)); + Response searchConfigResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + SEARCH_CONFIGURATIONS_URL, + null, + toHttpEntity(searchConfigBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map searchConfigResult = entityAsMap(searchConfigResponse); + return searchConfigResult.get("search_configuration_id").toString(); + } + + @SneakyThrows + public void testExpandCoverageWithHybridConfig_thenSuccessful() { + // Setup + String querySetId = setupTestIndexAndQuerySet(); + String searchConfigId = createSearchConfig("llmjudgment/CreateSearchConfigurationHybrid.json"); + + // Create LLM judgment with expandCoverage=true + hybrid config + String llmJudgmentBody = Files.readString( + Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentExpandCoverage.json").toURI()) + ); + llmJudgmentBody = replacePlaceholders(llmJudgmentBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response llmJudgmentResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(llmJudgmentBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map llmJudgmentResult = entityAsMap(llmJudgmentResponse); + String judgmentId = llmJudgmentResult.get("judgment_id").toString(); + assertNotNull("Judgment ID should not be null", judgmentId); + + // Wait for judgment processing + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify the judgment was created with expandCoverage in metadata + String getJudgmentUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId); + Response getJudgmentResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + 
getJudgmentUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentDoc = entityAsMap(getJudgmentResponse); + assertNotNull(judgmentDoc); + assertEquals(judgmentId, judgmentDoc.get("_id")); + + Map source = (Map) judgmentDoc.get("_source"); + assertNotNull(source); + assertEquals("LLM_JUDGMENT", source.get("type")); + + // Verify metadata contains expandCoverage=true + Map metadata = (Map) source.get("metadata"); + assertNotNull(metadata); + assertEquals(true, metadata.get("expandCoverage")); + } + + @SneakyThrows + public void testWithoutExpandCoverage_thenExistingBehaviorUnchanged() { + // Setup + String querySetId = setupTestIndexAndQuerySet(); + String searchConfigId = createSearchConfig("llmjudgment/CreateSearchConfiguration.json"); + + // Create LLM judgment WITHOUT expandCoverage (standard non-hybrid config) + String llmJudgmentBody = Files.readString(Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentMinimal.json").toURI())); + llmJudgmentBody = replacePlaceholders(llmJudgmentBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + Response llmJudgmentResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(llmJudgmentBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map llmJudgmentResult = entityAsMap(llmJudgmentResponse); + String judgmentId = llmJudgmentResult.get("judgment_id").toString(); + assertNotNull("Judgment ID should not be null", judgmentId); + + // Wait for judgment processing + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify the judgment was created without expandCoverage + String getJudgmentUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId); + Response getJudgmentResponse = makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, 
DEFAULT_USER_AGENT)) + ); + Map judgmentDoc = entityAsMap(getJudgmentResponse); + Map source = (Map) judgmentDoc.get("_source"); + Map metadata = (Map) source.get("metadata"); + + // expandCoverage should be false or absent when not provided + Object expandCoverage = metadata.get("expandCoverage"); + assertTrue("expandCoverage should be false or null when not provided", expandCoverage == null || expandCoverage.equals(false)); + } + + @SneakyThrows + public void testExpandCoverageWithNonHybridConfig_thenValidationError() { + // Setup + String querySetId = setupTestIndexAndQuerySet(); + // Use standard (non-hybrid) search config + String searchConfigId = createSearchConfig("llmjudgment/CreateSearchConfiguration.json"); + + // Create LLM judgment with expandCoverage=true + non-hybrid config + String llmJudgmentBody = Files.readString( + Path.of(classLoader.getResource("llmjudgment/CreateLlmJudgmentExpandCoverage.json").toURI()) + ); + llmJudgmentBody = replacePlaceholders(llmJudgmentBody, Map.of("querySetId", querySetId, "searchConfigId", searchConfigId)); + + // The request should succeed (validation is async during processing), + // but the judgment should end up with an error status + Response llmJudgmentResponse = makeRequest( + client(), + RestRequest.Method.PUT.name(), + JUDGMENTS_URL, + null, + toHttpEntity(llmJudgmentBody), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map llmJudgmentResult = entityAsMap(llmJudgmentResponse); + String judgmentId = llmJudgmentResult.get("judgment_id").toString(); + assertNotNull("Judgment ID should not be null", judgmentId); + + // Wait for judgment processing (will fail during async processing) + Thread.sleep(DEFAULT_INTERVAL_MS); + + // Verify the judgment was created — status may indicate failure + // because the non-hybrid query will fail validation during processing + String getJudgmentUrl = String.join("/", JUDGMENT_INDEX, "_doc", judgmentId); + Response getJudgmentResponse = 
makeRequest( + adminClient(), + RestRequest.Method.GET.name(), + getJudgmentUrl, + null, + null, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map judgmentDoc = entityAsMap(getJudgmentResponse); + Map source = (Map) judgmentDoc.get("_source"); + assertNotNull(source); + + // The judgment should have been created with expandCoverage=true in metadata + Map metadata = (Map) source.get("metadata"); + assertEquals(true, metadata.get("expandCoverage")); + + // Since the search config is non-hybrid, the processing should fail, + // resulting in either FAILED status or empty ratings + String status = (String) source.get("status"); + // Status could be FAILED or COMPLETED with empty ratings depending on ignoreFailure + assertNotNull("Judgment should have a status", status); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java index e7c67476..76b93768 100644 --- a/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java +++ b/src/test/java/org/opensearch/searchrelevance/action/judgment/PutJudgmentActionTests.java @@ -93,7 +93,8 @@ public void testLlmJudgmentRequestStreams() throws IOException { false, "test_prompt_template", LLMJudgmentRatingType.SCORE0_1, - true + true, + false // expandCoverage ); BytesStreamOutput output = new BytesStreamOutput(); @@ -114,6 +115,36 @@ public void testLlmJudgmentRequestStreams() throws IOException { assertEquals("test_prompt_template", serialized.getPromptTemplate()); assertEquals(LLMJudgmentRatingType.SCORE0_1, serialized.getLlmJudgmentRatingType()); assertEquals(true, serialized.isOverwriteCache()); + assertEquals(false, serialized.isExpandCoverage()); + } + + public void testLlmJudgmentRequestStreamsWithExpandCoverageTrue() throws IOException { + PutLlmJudgmentRequest request = new PutLlmJudgmentRequest( + 
JudgmentType.LLM_JUDGMENT, + "test_name", + "test_description", + "test_model_id", + "test_query_set_id", + List.of("config1"), + 10, + 1000, + List.of("field1"), + false, + null, + null, + false, + true // expandCoverage = true + ); + + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + StreamInput in = StreamInput.wrap(output.bytes().toBytesRef().bytes); + PutLlmJudgmentRequest serialized = new PutLlmJudgmentRequest(in); + + assertEquals("test_name", serialized.getName()); + assertEquals(JudgmentType.LLM_JUDGMENT, serialized.getType()); + assertEquals(true, serialized.isExpandCoverage()); + assertEquals(false, serialized.isOverwriteCache()); } public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOException { @@ -130,7 +161,8 @@ public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOExcep true, null, null, - false + false, + false // expandCoverage ); BytesStreamOutput output = new BytesStreamOutput(); @@ -144,6 +176,7 @@ public void testLlmJudgmentRequestStreamsWithNullOptionalFields() throws IOExcep assertNull(serialized.getPromptTemplate()); assertNull(serialized.getLlmJudgmentRatingType()); assertEquals(false, serialized.isOverwriteCache()); + assertEquals(false, serialized.isExpandCoverage()); } public void testUbiJudgmentWithCustomIndexes() throws IOException { diff --git a/src/test/java/org/opensearch/searchrelevance/dao/JudgmentCacheDaoCleanupTests.java b/src/test/java/org/opensearch/searchrelevance/dao/JudgmentCacheDaoCleanupTests.java new file mode 100644 index 00000000..e12fb838 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/dao/JudgmentCacheDaoCleanupTests.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.searchrelevance.dao; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.searchrelevance.indices.SearchRelevanceIndices; +import org.opensearch.searchrelevance.indices.SearchRelevanceIndicesManager; +import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; +import org.opensearch.test.OpenSearchTestCase; + +import lombok.SneakyThrows; + +/** + * Tests for JudgmentCacheDao cleanup logic with the TTL setting. + */ +public class JudgmentCacheDaoCleanupTests extends OpenSearchTestCase { + + private SearchRelevanceIndicesManager mockIndicesManager; + private SearchRelevanceSettingsAccessor mockSettingsAccessor; + private JudgmentCacheDao dao; + + @Override + @SneakyThrows + public void setUp() { + super.setUp(); + mockIndicesManager = mock(SearchRelevanceIndicesManager.class); + mockSettingsAccessor = mock(SearchRelevanceSettingsAccessor.class); + dao = new JudgmentCacheDao(mockIndicesManager); + } + + public void testCleanup_NoSettingsAccessor_IsNoOp() { + // Don't call setSettingsAccessor + dao.cleanupStaleEntries(); + verify(mockIndicesManager, never()).deleteByQuery(any(), any(), any()); + } + + public void testCleanup_TtlDisabled_IsNoOp() { + dao.setSettingsAccessor(mockSettingsAccessor); + when(mockSettingsAccessor.getJudgmentCacheTtl()).thenReturn(TimeValue.MINUS_ONE); + + dao.cleanupStaleEntries(); + verify(mockIndicesManager, never()).deleteByQuery(any(), any(), any()); + } + + public void testCleanup_TtlEnabled_TriggersDeleteByQuery() { + dao.setSettingsAccessor(mockSettingsAccessor); + when(mockSettingsAccessor.getJudgmentCacheTtl()).thenReturn(TimeValue.timeValueDays(90)); + + 
dao.cleanupStaleEntries(); + verify(mockIndicesManager).deleteByQuery(any(QueryBuilder.class), eq(SearchRelevanceIndices.JUDGMENT_CACHE), any()); + } + + public void testCleanup_VeryShortTtl_UsesMinimum1Day() { + dao.setSettingsAccessor(mockSettingsAccessor); + when(mockSettingsAccessor.getJudgmentCacheTtl()).thenReturn(TimeValue.timeValueHours(1)); + + dao.cleanupStaleEntries(); + // Should still trigger cleanup (with 1 day minimum) + verify(mockIndicesManager).deleteByQuery(any(QueryBuilder.class), eq(SearchRelevanceIndices.JUDGMENT_CACHE), any()); + } + + public void testCleanupExplicitDays_AlwaysTriggersDeleteByQuery() { + // No settings accessor set + dao.cleanupStaleEntries(30); + verify(mockIndicesManager).deleteByQuery(any(QueryBuilder.class), eq(SearchRelevanceIndices.JUDGMENT_CACHE), any()); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/experiment/QuerySourceUtilTests.java b/src/test/java/org/opensearch/searchrelevance/experiment/QuerySourceUtilTests.java index 1cb2620e..2f31f3c9 100644 --- a/src/test/java/org/opensearch/searchrelevance/experiment/QuerySourceUtilTests.java +++ b/src/test/java/org/opensearch/searchrelevance/experiment/QuerySourceUtilTests.java @@ -9,7 +9,10 @@ import static org.opensearch.searchrelevance.experiment.ExperimentOptionsForHybridSearch.EXPERIMENT_OPTION_COMBINATION_TECHNIQUE; import static org.opensearch.searchrelevance.experiment.ExperimentOptionsForHybridSearch.EXPERIMENT_OPTION_NORMALIZATION_TECHNIQUE; +import static org.opensearch.searchrelevance.experiment.QuerySourceUtil.POOL_COMBINATION; +import static org.opensearch.searchrelevance.experiment.QuerySourceUtil.POOL_NORMALIZATION; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -101,7 +104,7 @@ public void testValidateHybridQuery_MissingHybrid() { IllegalArgumentException.class, () -> QuerySourceUtil.validateHybridQuery(fullQuery) ); - assertEquals("query in search configuration must be of 
type hybrid", exception.getMessage()); + assertEquals("query in search configuration must be of type [hybrid]", exception.getMessage()); } public void testValidateHybridQuery_InvalidHybridType() { @@ -114,7 +117,7 @@ public void testValidateHybridQuery_InvalidHybridType() { IllegalArgumentException.class, () -> QuerySourceUtil.validateHybridQuery(fullQuery) ); - assertEquals("query in search configuration must be of type hybrid", exception.getMessage()); + assertEquals("query in search configuration must be of type [hybrid]", exception.getMessage()); } public void testValidateHybridQuery_MissingQueries() { @@ -128,7 +131,7 @@ public void testValidateHybridQuery_MissingQueries() { IllegalArgumentException.class, () -> QuerySourceUtil.validateHybridQuery(fullQuery) ); - assertEquals("hybrid query in search configuration does not have sub-queries", exception.getMessage()); + assertEquals("[hybrid] query in search configuration does not have sub-queries", exception.getMessage()); } public void testValidateHybridQuery_InvalidQueriesType() { @@ -143,7 +146,7 @@ public void testValidateHybridQuery_InvalidQueriesType() { IllegalArgumentException.class, () -> QuerySourceUtil.validateHybridQuery(fullQuery) ); - assertEquals("hybrid query in search configuration does not have sub-queries", exception.getMessage()); + assertEquals("[hybrid] query in search configuration does not have sub-queries", exception.getMessage()); } public void testValidateHybridQuery_whenOneSubquery_thenFail() { @@ -176,4 +179,258 @@ public void testValidateHybridQuery_whenThreeSubqueries_thenFail() { ); assertEquals("invalid hybrid query: expected exactly [2] sub-queries but found [3]", exception.getMessage()); } + + public void testIsHybridQuery_ValidHybrid_ReturnsTrue() { + Map hybridQueries = new HashMap<>(); + hybridQueries.put("queries", Arrays.asList(new HashMap<>(), new HashMap<>())); + Map hybrid = new HashMap<>(); + hybrid.put("hybrid", hybridQueries); + Map fullQuery = new HashMap<>(); + 
fullQuery.put("query", hybrid); + + assertTrue(QuerySourceUtil.isHybridQuery(fullQuery)); + } + + public void testIsHybridQuery_NonHybridQuery_ReturnsFalse() { + Map matchQuery = new HashMap<>(); + matchQuery.put("match", Map.of("title", "test")); + Map fullQuery = new HashMap<>(); + fullQuery.put("query", matchQuery); + + assertFalse(QuerySourceUtil.isHybridQuery(fullQuery)); + } + + public void testIsHybridQuery_EmptyMap_ReturnsFalse() { + Map emptyMap = new HashMap<>(); + assertFalse(QuerySourceUtil.isHybridQuery(emptyMap)); + } + + public void testIsHybridQuery_NullMap_ReturnsFalse() { + assertFalse(QuerySourceUtil.isHybridQuery(null)); + } + + public void testIsHybridQuery_WrongSubqueryCount_ReturnsFalse() { + // isHybridQuery is strict: requires exactly 2 sub-queries + Map hybridQueries = new HashMap<>(); + hybridQueries.put("queries", Collections.singletonList(new HashMap<>())); + Map hybrid = new HashMap<>(); + hybrid.put("hybrid", hybridQueries); + Map fullQuery = new HashMap<>(); + fullQuery.put("query", hybrid); + + assertFalse(QuerySourceUtil.isHybridQuery(fullQuery)); + } + + public void testIsHybridQueryAnySize_1SubQuery_ReturnsTrue() { + Map fullQuery = buildHybridQuery(1); + assertTrue(QuerySourceUtil.isHybridQueryAnySize(fullQuery)); + } + + public void testIsHybridQueryAnySize_2SubQueries_ReturnsTrue() { + Map fullQuery = buildHybridQuery(2); + assertTrue(QuerySourceUtil.isHybridQueryAnySize(fullQuery)); + } + + public void testIsHybridQueryAnySize_3SubQueries_ReturnsTrue() { + Map fullQuery = buildHybridQuery(3); + assertTrue(QuerySourceUtil.isHybridQueryAnySize(fullQuery)); + } + + public void testIsHybridQueryAnySize_5SubQueries_ReturnsTrue() { + Map fullQuery = buildHybridQuery(5); + assertTrue(QuerySourceUtil.isHybridQueryAnySize(fullQuery)); + } + + public void testIsHybridQueryAnySize_NonHybrid_ReturnsFalse() { + Map matchQuery = new HashMap<>(); + matchQuery.put("match", Map.of("title", "test")); + Map fullQuery = new HashMap<>(); + 
fullQuery.put("query", matchQuery); + + assertFalse(QuerySourceUtil.isHybridQueryAnySize(fullQuery)); + } + + public void testIsHybridQueryAnySize_Null_ReturnsFalse() { + assertFalse(QuerySourceUtil.isHybridQueryAnySize(null)); + } + + public void testGetSubQueryCount_2SubQueries() { + Map fullQuery = buildHybridQuery(2); + assertEquals(2, QuerySourceUtil.getSubQueryCount(fullQuery)); + } + + public void testGetSubQueryCount_3SubQueries() { + Map fullQuery = buildHybridQuery(3); + assertEquals(3, QuerySourceUtil.getSubQueryCount(fullQuery)); + } + + public void testGetSubQueryCount_5SubQueries() { + Map fullQuery = buildHybridQuery(5); + assertEquals(5, QuerySourceUtil.getSubQueryCount(fullQuery)); + } + + public void testGetSubQueryCount_NonHybrid_Throws() { + Map fullQuery = new HashMap<>(); + fullQuery.put("query", Map.of("match", Map.of("title", "test"))); + assertThrows(IllegalArgumentException.class, () -> QuerySourceUtil.getSubQueryCount(fullQuery)); + } + + public void testGetSubQueryCount_Null_Throws() { + assertThrows(IllegalArgumentException.class, () -> QuerySourceUtil.getSubQueryCount(null)); + } + + public void testGeneratePoolingWeights_1SubQuery() { + List> weights = QuerySourceUtil.generatePoolingWeights(1); + assertEquals(2, weights.size()); // 1 equal + 1 one-hot + // Equal weights: null (omit parameters to trigger default equal-weight behavior) + assertNull(weights.get(0)); + // One-hot: [1.0] + assertEquals(1.0f, weights.get(1).get(0), 0.001); + } + + public void testGeneratePoolingWeights_2SubQueries() { + List> weights = QuerySourceUtil.generatePoolingWeights(2); + assertEquals(3, weights.size()); // 1 equal + 2 one-hot + + // Equal weights: null (omit parameters to trigger default equal-weight behavior) + assertNull(weights.get(0)); + // One-hot: [1, 0] + assertEquals(1.0f, weights.get(1).get(0), 0.001); + assertEquals(0.0f, weights.get(1).get(1), 0.001); + // One-hot: [0, 1] + assertEquals(0.0f, weights.get(2).get(0), 0.001); + 
assertEquals(1.0f, weights.get(2).get(1), 0.001); + } + + public void testGeneratePoolingWeights_3SubQueries() { + List> weights = QuerySourceUtil.generatePoolingWeights(3); + assertEquals(4, weights.size()); // 1 equal + 3 one-hot + + // Equal weights: null (omit parameters to trigger default equal-weight behavior) + assertNull(weights.get(0)); + // One-hot: [1, 0, 0] + assertEquals(1.0f, weights.get(1).get(0), 0.001); + assertEquals(0.0f, weights.get(1).get(1), 0.001); + assertEquals(0.0f, weights.get(1).get(2), 0.001); + // One-hot: [0, 1, 0] + assertEquals(0.0f, weights.get(2).get(0), 0.001); + assertEquals(1.0f, weights.get(2).get(1), 0.001); + assertEquals(0.0f, weights.get(2).get(2), 0.001); + // One-hot: [0, 0, 1] + assertEquals(0.0f, weights.get(3).get(0), 0.001); + assertEquals(0.0f, weights.get(3).get(1), 0.001); + assertEquals(1.0f, weights.get(3).get(2), 0.001); + } + + public void testGeneratePoolingWeights_5SubQueries() { + List> weights = QuerySourceUtil.generatePoolingWeights(5); + assertEquals(6, weights.size()); // 1 equal + 5 one-hot + + // Equal weights: null (omit parameters to trigger default equal-weight behavior) + assertNull(weights.get(0)); + // Verify each one-hot config + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 5; j++) { + float expected = (i == j) ? 
1.0f : 0.0f; + assertEquals(expected, weights.get(i + 1).get(j), 0.001); + } + } + } + + public void testGeneratePoolingWeights_ZeroThrows() { + assertThrows(IllegalArgumentException.class, () -> QuerySourceUtil.generatePoolingWeights(0)); + } + + public void testGeneratePoolingWeights_NegativeThrows() { + assertThrows(IllegalArgumentException.class, () -> QuerySourceUtil.generatePoolingWeights(-1)); + } + + public void testGeneratePoolingWeights_ResultIsUnmodifiable() { + List> weights = QuerySourceUtil.generatePoolingWeights(2); + assertThrows(UnsupportedOperationException.class, () -> weights.add(List.of(1.0f))); + // weights.get(0) is null (equal weights), so test unmodifiable on a one-hot config + assertThrows(UnsupportedOperationException.class, () -> weights.get(1).add(1.0f)); + } + + public void testCreatePoolingSearchPipeline_ReturnsCorrectStructure() { + Map result = QuerySourceUtil.createPoolingSearchPipeline(List.of(0.7f, 0.3f), "min_max", "arithmetic_mean"); + + assertNotNull(result); + assertTrue(result.containsKey("phase_results_processors")); + + List processors = (List) result.get("phase_results_processors"); + assertEquals(1, processors.size()); + + Map processorObject = (Map) processors.get(0); + assertTrue(processorObject.containsKey("normalization-processor")); + + Map normalizationProcessor = (Map) processorObject.get("normalization-processor"); + Map normalization = (Map) normalizationProcessor.get("normalization"); + Map combination = (Map) normalizationProcessor.get("combination"); + + assertEquals("min_max", normalization.get("technique")); + assertEquals("arithmetic_mean", combination.get("technique")); + + Map parameters = (Map) combination.get("parameters"); + assertNotNull(parameters); + List weights = (List) parameters.get("weights"); + assertNotNull(weights); + assertEquals(2, weights.size()); + assertEquals(0.7, (double) weights.get(0), 0.001); + assertEquals(0.3, (double) weights.get(1), 0.001); + } + + public void 
testCreatePoolingSearchPipeline_3Weights() { + Map result = QuerySourceUtil.createPoolingSearchPipeline( + List.of(1.0f, 0.0f, 0.0f), + POOL_NORMALIZATION, + POOL_COMBINATION + ); + + List processors = (List) result.get("phase_results_processors"); + Map processorObject = (Map) processors.get(0); + Map normalizationProcessor = (Map) processorObject.get("normalization-processor"); + Map combination = (Map) normalizationProcessor.get("combination"); + Map parameters = (Map) combination.get("parameters"); + List weights = (List) parameters.get("weights"); + + assertEquals(3, weights.size()); + assertEquals(1.0, (double) weights.get(0), 0.001); + assertEquals(0.0, (double) weights.get(1), 0.001); + assertEquals(0.0, (double) weights.get(2), 0.001); + } + + public void testCreatePoolingSearchPipeline_5EqualWeights() { + Map result = QuerySourceUtil.createPoolingSearchPipeline( + List.of(0.2f, 0.2f, 0.2f, 0.2f, 0.2f), + POOL_NORMALIZATION, + POOL_COMBINATION + ); + + List processors = (List) result.get("phase_results_processors"); + Map processorObject = (Map) processors.get(0); + Map normalizationProcessor = (Map) processorObject.get("normalization-processor"); + Map combination = (Map) normalizationProcessor.get("combination"); + Map parameters = (Map) combination.get("parameters"); + List weights = (List) parameters.get("weights"); + + assertEquals(5, weights.size()); + for (int i = 0; i < 5; i++) { + assertEquals(0.2, (double) weights.get(i), 0.001); + } + } + + private Map buildHybridQuery(int numSubQueries) { + List> queries = new ArrayList<>(); + for (int i = 0; i < numSubQueries; i++) { + queries.add(new HashMap<>()); + } + Map hybridMap = new HashMap<>(); + hybridMap.put("queries", queries); + Map queryWrapper = new HashMap<>(); + queryWrapper.put("hybrid", hybridMap); + Map fullQuery = new HashMap<>(); + fullQuery.put("query", queryWrapper); + return fullQuery; + } } diff --git 
a/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorExpandCoverageTests.java b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorExpandCoverageTests.java new file mode 100644 index 00000000..7d6cf1d5 --- /dev/null +++ b/src/test/java/org/opensearch/searchrelevance/judgments/LlmJudgmentsProcessorExpandCoverageTests.java @@ -0,0 +1,221 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.searchrelevance.judgments; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.lucene.search.TotalHits; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.searchrelevance.dao.JudgmentCacheDao; +import org.opensearch.searchrelevance.dao.QuerySetDao; +import org.opensearch.searchrelevance.dao.SearchConfigurationDao; +import org.opensearch.searchrelevance.ml.MLAccessor; +import org.opensearch.searchrelevance.model.SearchConfiguration; +import org.opensearch.searchrelevance.model.builder.SearchRequestBuilder; +import org.opensearch.searchrelevance.settings.SearchRelevanceSettingsAccessor; +import org.opensearch.searchrelevance.stats.events.EventStatsManager; +import 
org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; + +import lombok.SneakyThrows; + +/** + * Unit tests for expandCoverage pooling search behavior in LlmJudgmentsProcessor. + * Verifies that expandCoverage=true triggers N+1 searches (1 equal + N one-hot) + * and expandCoverage=false triggers exactly 1 search. + */ +public class LlmJudgmentsProcessorExpandCoverageTests extends OpenSearchTestCase { + + @Mock + private MLAccessor mockMLAccessor; + @Mock + private QuerySetDao mockQuerySetDao; + @Mock + private SearchConfigurationDao mockSearchConfigurationDao; + @Mock + private JudgmentCacheDao mockJudgmentCacheDao; + @Mock + private Client mockClient; + @Mock + private SearchRelevanceSettingsAccessor mockSettingsAccessor; + + private ThreadPool threadPool; + private LlmJudgmentsProcessor processor; + + private static final String HYBRID_QUERY_2_SUBQUERIES = + "{\"query\":{\"hybrid\":{\"queries\":[{\"match\":{\"title\":\"test\"}},{\"neural\":{\"embedding\":{\"query_text\":\"test\",\"model_id\":\"m1\",\"k\":5}}}]}}}"; + private static final String NON_HYBRID_QUERY = "{\"query\":{\"match\":{\"title\":\"test\"}}}"; + + @Override + @SneakyThrows + public void setUp() { + super.setUp(); + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool("test"); + + when(mockSettingsAccessor.isStatsEnabled()).thenReturn(false); + EventStatsManager eventStatsManager = EventStatsManager.instance(); + eventStatsManager.initialize(mockSettingsAccessor); + + // Initialize SearchRequestBuilder with EMPTY registry for unit tests. + // getSubQueryCount() will try typed parsing first (fails with EMPTY), then fall back to map-based. 
+ SearchRequestBuilder.initialize(NamedXContentRegistry.EMPTY); + + processor = new LlmJudgmentsProcessor( + mockMLAccessor, + mockQuerySetDao, + mockSearchConfigurationDao, + mockJudgmentCacheDao, + mockClient, + threadPool + ); + } + + @Override + @SneakyThrows + public void tearDown() { + threadPool.shutdown(); + super.tearDown(); + } + + private SearchResponse createMockSearchResponse(String... docIds) { + SearchHit[] hits = new SearchHit[docIds.length]; + for (int i = 0; i < docIds.length; i++) { + hits[i] = new SearchHit(i, docIds[i], java.util.Map.of(), java.util.Map.of()); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(docIds.length, TotalHits.Relation.EQUAL_TO), 1.0f); + InternalSearchResponse internalResponse = new InternalSearchResponse(searchHits, null, null, null, false, null, 1); + return new SearchResponse(internalResponse, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY); + } + + private SearchConfiguration createMockConfig(String query) { + // SearchConfiguration fields: id, name, timestamp, index, query, searchPipeline, description + SearchConfiguration config = new SearchConfiguration( + "test-config-id", + "test-config", + "2026-01-01T00:00:00Z", + "test-index", + query, + null, // no search pipeline + "test description" + ); + return config; + } + + @SneakyThrows + public void testExpandCoverage_TwoSubqueries_Triggers3Searches() { + AtomicInteger searchCount = new AtomicInteger(0); + + doAnswer(invocation -> { + int count = searchCount.incrementAndGet(); + ActionListener listener = invocation.getArgument(1); + // Return different doc per search to verify unique collection + listener.onResponse(createMockSearchResponse("doc_from_search_" + count)); + return null; + }).when(mockClient).search(any(), any()); + + SearchConfiguration config = createMockConfig(HYBRID_QUERY_2_SUBQUERIES); + ConcurrentMap allHits = new ConcurrentHashMap<>(); + + processor.processSearchConfigurationsAsync( + 
List.of(config), + "test query", + 10, + allHits, + false, + true // expandCoverage=true + ); + + // 2 sub-queries → 1 equal-weight + 2 one-hot = 3 searches + assertEquals("Expected 3 searches for 2-subquery hybrid with expandCoverage", 3, searchCount.get()); + assertEquals("Expected 3 unique docs from 3 searches", 3, allHits.size()); + } + + @SneakyThrows + public void testNoExpandCoverage_Triggers1Search() { + AtomicInteger searchCount = new AtomicInteger(0); + + doAnswer(invocation -> { + searchCount.incrementAndGet(); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(createMockSearchResponse("doc1", "doc2")); + return null; + }).when(mockClient).search(any(), any()); + + SearchConfiguration config = createMockConfig(NON_HYBRID_QUERY); + ConcurrentMap allHits = new ConcurrentHashMap<>(); + + processor.processSearchConfigurationsAsync( + List.of(config), + "test query", + 10, + allHits, + false, + false // expandCoverage=false + ); + + assertEquals("Expected 1 search without expandCoverage", 1, searchCount.get()); + assertEquals("Expected 2 docs from single search", 2, allHits.size()); + } + + public void testExpandCoverage_NonHybridQuery_ThrowsValidationError() { + SearchConfiguration config = createMockConfig(NON_HYBRID_QUERY); + ConcurrentMap allHits = new ConcurrentHashMap<>(); + + assertThrows(IllegalArgumentException.class, () -> { + processor.processSearchConfigurationsAsync( + List.of(config), + "test query", + 10, + allHits, + false, + true // expandCoverage=true with non-hybrid + ); + }); + } + + @SneakyThrows + public void testExpandCoverage_DeduplicatesOverlappingDocs() { + AtomicInteger searchCount = new AtomicInteger(0); + + doAnswer(invocation -> { + searchCount.incrementAndGet(); + ActionListener listener = invocation.getArgument(1); + // All 3 searches return overlapping doc sets + listener.onResponse(createMockSearchResponse("doc_common", "doc_unique_" + searchCount.get())); + return null; + 
}).when(mockClient).search(any(), any()); + + SearchConfiguration config = createMockConfig(HYBRID_QUERY_2_SUBQUERIES); + ConcurrentMap allHits = new ConcurrentHashMap<>(); + + processor.processSearchConfigurationsAsync(List.of(config), "test query", 10, allHits, false, true); + + assertEquals("Expected 3 searches", 3, searchCount.get()); + // 1 common doc + 3 unique docs = 4 total (putIfAbsent deduplicates) + assertEquals("Expected 4 unique docs (1 common + 3 unique)", 4, allHits.size()); + assertTrue("Should contain the common doc", allHits.containsKey("doc_common")); + } +} diff --git a/src/test/java/org/opensearch/searchrelevance/plugin/SearchRelevancePluginTests.java b/src/test/java/org/opensearch/searchrelevance/plugin/SearchRelevancePluginTests.java index 436f54d6..22076265 100644 --- a/src/test/java/org/opensearch/searchrelevance/plugin/SearchRelevancePluginTests.java +++ b/src/test/java/org/opensearch/searchrelevance/plugin/SearchRelevancePluginTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.searchrelevance.common.PluginConstants.EXPERIMENT_INDEX; import static org.opensearch.searchrelevance.common.PluginConstants.JUDGMENT_CACHE_INDEX; +import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_QUERY_SET_MAX_LIMIT; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_ENABLED; import static org.opensearch.searchrelevance.settings.SearchRelevanceSettings.SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL; @@ -149,7 +150,8 @@ public void setUp() throws Exception { SEARCH_RELEVANCE_QUERY_SET_MAX_LIMIT, SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_ENABLED, SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_TIMEOUT, - SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL + 
SEARCH_RELEVANCE_SCHEDULED_EXPERIMENTS_MINIMUM_INTERVAL, + SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL ) ) ) @@ -229,7 +231,7 @@ public void testQuerySetTransportIsAdded() { public void testGetSettings() { List> settings = plugin.getSettings(); - assertEquals(6, settings.size()); + assertEquals(7, settings.size()); Setting setting0 = settings.get(0); assertEquals("plugins.search_relevance.workbench_enabled", setting0.getKey()); @@ -254,5 +256,43 @@ public void testGetSettings() { Setting setting5 = settings.get(5); assertEquals("plugins.search_relevance.scheduled_experiments_minimum_interval", setting5.getKey()); assertEquals(TimeValue.timeValueSeconds(1), setting5.get(Settings.EMPTY)); + + Setting setting6 = settings.get(6); + assertEquals("plugins.search_relevance.judgment_cache.ttl", setting6.getKey()); + assertEquals(TimeValue.MINUS_ONE, setting6.get(Settings.EMPTY)); + } + + public void testJudgmentCacheTtlSetting_DefaultIsDisabled() { + TimeValue value = SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(Settings.EMPTY); + assertEquals(TimeValue.MINUS_ONE, value); + assertTrue("Default TTL should be negative (disabled)", value.millis() < 0); + } + + public void testJudgmentCacheTtlSetting_AcceptsValidValues() { + Settings s90d = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "90d").build(); + assertEquals(TimeValue.timeValueDays(90), SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(s90d)); + + Settings s30d = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "30d").build(); + assertEquals(TimeValue.timeValueDays(30), SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(s30d)); + + Settings s24h = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "24h").build(); + assertEquals(TimeValue.timeValueHours(24), SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(s24h)); + + Settings sDisabled = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "-1").build(); + assertEquals(TimeValue.MINUS_ONE, 
SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(sDisabled)); + } + + public void testJudgmentCacheTtlSetting_RejectsInvalidFormat() { + assertThrows(IllegalArgumentException.class, () -> { + Settings s = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "abc").build(); + SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(s); + }); + } + + public void testJudgmentCacheTtlSetting_RejectsMaliciousInput() { + assertThrows(IllegalArgumentException.class, () -> { + Settings s = Settings.builder().put("plugins.search_relevance.judgment_cache.ttl", "123-dsf443f-df%@#@34").build(); + SEARCH_RELEVANCE_JUDGMENT_CACHE_TTL.get(s); + }); } } diff --git a/src/test/resources/llmjudgment/CreateLlmJudgmentExpandCoverage.json b/src/test/resources/llmjudgment/CreateLlmJudgmentExpandCoverage.json new file mode 100644 index 00000000..aac588cf --- /dev/null +++ b/src/test/resources/llmjudgment/CreateLlmJudgmentExpandCoverage.json @@ -0,0 +1,12 @@ +{ + "name": "LLM Judgment Expand Coverage", + "type": "LLM_JUDGMENT", + "querySetId": "{{querySetId}}", + "searchConfigurationList": ["{{searchConfigId}}"], + "modelId": "test_model_id", + "size": 5, + "tokenLimit": 4000, + "contextFields": ["name", "description"], + "ignoreFailure": false, + "expandCoverage": true +} diff --git a/src/test/resources/llmjudgment/CreateSearchConfigurationHybrid.json b/src/test/resources/llmjudgment/CreateSearchConfigurationHybrid.json new file mode 100644 index 00000000..fc046270 --- /dev/null +++ b/src/test/resources/llmjudgment/CreateSearchConfigurationHybrid.json @@ -0,0 +1,6 @@ +{ + "name": "Hybrid Search Configuration", + "description": "Hybrid search with BM25 and neural sub-queries", + "index": "{{index}}", + "query": "{\"query\": {\"hybrid\": {\"queries\": [{\"match\": {\"name\": {\"query\": \"%SearchText%\"}}}, {\"match\": {\"description\": {\"query\": \"%SearchText%\"}}}]}}}" +}