|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package org.opensearch.knn.index.query.scorers; |
| 7 | + |
| 8 | +import org.apache.lucene.search.ConjunctionUtils; |
| 9 | +import org.apache.lucene.search.DocIdSetIterator; |
| 10 | +import org.apache.lucene.search.VectorScorer; |
| 11 | +import org.apache.lucene.util.BitSet; |
| 12 | +import org.opensearch.common.Nullable; |
| 13 | + |
| 14 | +import java.util.Arrays; |
| 15 | + |
| 16 | +import java.io.IOException; |
| 17 | + |
| 18 | +/** |
| 19 | + * A {@link VectorScorer} decorator for nested (parent-child) document structures that groups |
| 20 | + * child documents by their parent and yields only the highest-scoring child per parent. |
| 21 | + * |
| 22 | + * <p>This is adapted from Lucene's {@code DiversifyingChildrenVectorScorer} inner class in |
| 23 | + * {@code DiversifyingChildrenFloatKnnVectorQuery}, re-implemented as a standalone {@link VectorScorer} |
| 24 | + * so it can be used in OpenSearch's exact search path. |
| 25 | + * |
| 26 | + * <h2>Document Layout</h2> |
| 27 | + * <p>Lucene block-joins store parent and child documents in contiguous doc-id ranges: |
| 28 | + * <pre> |
| 29 | + * [child_0, child_1, ..., child_n, PARENT, child_0, child_1, ..., child_m, PARENT, ...] |
| 30 | + * </pre> |
| 31 | + * The {@code parentBitSet} identifies which doc ids are parents. Every doc id between two |
| 32 | + * consecutive parent bits is a child of the later parent. |
| 33 | + * |
| 34 | + * <h2>Iteration Behavior</h2> |
| 35 | + * <p>Each call to {@link #iterator()}'s {@code nextDoc()} advances through one parent group: |
| 36 | + * <ol> |
| 37 | + * <li>Finds the next child document (respecting the optional filter).</li> |
| 38 | + * <li>Determines the parent for that child via {@code parentBitSet.nextSetBit()}.</li> |
| 39 | + * <li>Iterates over all children belonging to that parent, scoring each one.</li> |
| 40 | + * <li>Returns the doc id of the best-scoring child; {@link #score()} returns its score.</li> |
| 41 | + * </ol> |
| 42 | + * |
| 43 | + * <h2>Filtered vs Unfiltered</h2> |
| 44 | + * <ul> |
| 45 | + * <li><b>Unfiltered</b> ({@code filterIdsIterator == null}): every vector document is |
| 46 | + * considered. The underlying vector iterator drives iteration directly.</li> |
| 47 | + * <li><b>Filtered</b>: the {@code filterIdsIterator} is intersected with the vector |
| 48 | + * iterator via {@link #maybeIntersectWithFilter}, producing a single iterator |
| 49 | + * that yields only doc ids present in both. This keeps the two iterators in lockstep |
| 50 | + * so the vector scorer is always positioned correctly when {@link #score()} is called.</li> |
| 51 | + * </ul> |
| 52 | + * |
| 53 | + * <h2>Example</h2> |
| 54 | + * <p>Given children [0,1,2,3,4] → parent 5, children [6,7,8] → parent 9, child [10] → parent 11, |
| 55 | + * and a filter that excludes children 2 and 7: |
| 56 | + * <pre> |
| 57 | + * Accepted children: {0, 1, 3, 4, 6, 8, 10} |
| 58 | + * |
| 59 | + * nextDoc() → bestChild=1 (best of {0,1,3,4} under parent 5) |
| 60 | + * nextDoc() → bestChild=6 (best of {6,8} under parent 9) |
| 61 | + * nextDoc() → bestChild=10 (only child under parent 11) |
| 62 | + * nextDoc() → NO_MORE_DOCS |
| 63 | + * </pre> |
| 64 | + * |
| 65 | + * @see org.apache.lucene.search.VectorScorer |
| 66 | + * @see org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery |
| 67 | + */ |
| 68 | +class NestedBestChildVectorScorer implements VectorScorer { |
| 69 | + private final VectorScorer childrenVectorScorer; |
| 70 | + private final DocIdSetIterator childIterator; |
| 71 | + private final BitSet parentBitSet; |
| 72 | + private final DocIdSetIterator iterator; |
| 73 | + private int bestChild = -1; |
| 74 | + private float currentScore = Float.NEGATIVE_INFINITY; |
| 75 | + |
| 76 | + /** |
| 77 | + * Creates a scorer that finds the best-scoring child per parent, optionally restricted to a |
| 78 | + * subset of accepted children. |
| 79 | + * |
| 80 | + * <p>When {@code filterIdsIterator} is {@code null} (unfiltered), the scorer's own |
| 81 | + * vector iterator is used to drive child iteration, matching the behavior of Lucene's |
| 82 | + * {@code DiversifyingChildrenVectorScorer} but without requiring a separate filter iterator. |
| 83 | + * |
| 84 | + * @param filterIdsIterator iterator over the accepted child doc ids (i.e. children that |
| 85 | + * pass the filter). Pass {@code null} for the unfiltered case |
| 86 | + * where all vector documents are considered. |
| 87 | + * @param parentBitSet a {@link BitSet} with bits set at every parent doc id. |
| 88 | + * Used to determine parent boundaries for grouping children. |
| 89 | + * @param childrenVectorScorer the underlying scorer that computes similarity scores for |
| 90 | + * individual child documents against the query vector. |
| 91 | + */ |
| 92 | + public NestedBestChildVectorScorer( |
| 93 | + @Nullable DocIdSetIterator filterIdsIterator, |
| 94 | + BitSet parentBitSet, |
| 95 | + VectorScorer childrenVectorScorer |
| 96 | + ) { |
| 97 | + this.childrenVectorScorer = childrenVectorScorer; |
| 98 | + this.parentBitSet = parentBitSet; |
| 99 | + DocIdSetIterator vectorIterator = childrenVectorScorer.iterator(); |
| 100 | + this.childIterator = maybeIntersectWithFilter(vectorIterator, filterIdsIterator); |
| 101 | + this.iterator = createIterator(); |
| 102 | + } |
| 103 | + |
| 104 | + /** |
| 105 | + * Returns the score of the best-scoring child for the current parent group. |
| 106 | + * Only valid after a successful call to {@code iterator().nextDoc()}. |
| 107 | + */ |
| 108 | + @Override |
| 109 | + public float score() throws IOException { |
| 110 | + return currentScore; |
| 111 | + } |
| 112 | + |
| 113 | + /** |
| 114 | + * Returns a {@link DocIdSetIterator} whose {@code nextDoc()} yields the doc id of the |
| 115 | + * best-scoring child for each successive parent. The same instance is returned on every call. |
| 116 | + */ |
| 117 | + @Override |
| 118 | + public DocIdSetIterator iterator() { |
| 119 | + return iterator; |
| 120 | + } |
| 121 | + |
| 122 | + /** |
| 123 | + * Returns the vector iterator directly if no filter is provided, otherwise intersects |
| 124 | + * it with the filter so that only doc ids present in both are yielded. |
| 125 | + */ |
| 126 | + private static DocIdSetIterator maybeIntersectWithFilter( |
| 127 | + DocIdSetIterator vectorIterator, |
| 128 | + @Nullable DocIdSetIterator filterIdsIterator |
| 129 | + ) { |
| 130 | + if (filterIdsIterator == null) { |
| 131 | + return vectorIterator; |
| 132 | + } |
| 133 | + return ConjunctionUtils.intersectIterators(Arrays.asList(filterIdsIterator, vectorIterator)); |
| 134 | + } |
| 135 | + |
| 136 | + /** |
| 137 | + * Creates a {@link DocIdSetIterator} that groups children by parent and yields the |
| 138 | + * best-scoring child per parent. Each {@code nextDoc()} call advances through one |
| 139 | + * parent group and returns the doc id of the highest-scoring child within that group. |
| 140 | + */ |
| 141 | + private DocIdSetIterator createIterator() { |
| 142 | + return new DocIdSetIterator() { |
| 143 | + @Override |
| 144 | + public int docID() { |
| 145 | + return bestChild; |
| 146 | + } |
| 147 | + |
| 148 | + @Override |
| 149 | + public int nextDoc() throws IOException { |
| 150 | + int nextChild = childIterator.docID(); |
| 151 | + if (nextChild == -1) { |
| 152 | + nextChild = childIterator.nextDoc(); |
| 153 | + } |
| 154 | + if (nextChild == NO_MORE_DOCS) { |
| 155 | + bestChild = NO_MORE_DOCS; |
| 156 | + return NO_MORE_DOCS; |
| 157 | + } |
| 158 | + |
| 159 | + currentScore = Float.NEGATIVE_INFINITY; |
| 160 | + int currentParent = parentBitSet.nextSetBit(nextChild); |
| 161 | + |
| 162 | + do { |
| 163 | + float score = childrenVectorScorer.score(); |
| 164 | + if (score > currentScore) { |
| 165 | + bestChild = nextChild; |
| 166 | + currentScore = score; |
| 167 | + } |
| 168 | + } while ((nextChild = childIterator.nextDoc()) != NO_MORE_DOCS && nextChild < currentParent); |
| 169 | + |
| 170 | + return bestChild; |
| 171 | + } |
| 172 | + |
| 173 | + /** |
| 174 | + * Not supported. This iterator returns the best-scoring child per parent group, |
| 175 | + * which requires evaluating <em>all</em> children within a group. Advancing to an |
| 176 | + * arbitrary target could land in the middle of a parent group, making it impossible |
| 177 | + * to consider earlier (potentially higher-scoring) children without backtracking |
| 178 | + * — violating the forward-only iterator contract. |
| 179 | + */ |
| 180 | + @Override |
| 181 | + public int advance(int target) { |
| 182 | + throw new UnsupportedOperationException(); |
| 183 | + } |
| 184 | + |
| 185 | + @Override |
| 186 | + public long cost() { |
| 187 | + return childIterator.cost(); |
| 188 | + } |
| 189 | + }; |
| 190 | + } |
| 191 | +} |
0 commit comments