Skip to content

Commit 8604663

Browse files
authored
fix #911. Threshold 0 matches prefixes correctly (#923)
1 parent d01e436 commit 8604663

File tree

5 files changed

+228
-141
lines changed

5 files changed

+228
-141
lines changed

packages/orama/src/components/algorithms.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,17 @@ export function prioritizeTokenScores(
4141
const results = tokenScores.sort((a, b) => b[1] - a[1])
4242

4343
// If threshold is 1, it means we will return all the results with at least one search term,
44-
// prioritizig the ones that contains more search terms (fuzzy match)
44+
// prioritizing the ones that contain more search terms (fuzzy match)
4545
if (threshold === 1) {
4646
return results
4747
}
4848

49+
// For threshold = 0 when keywordsCount is 1 (single term search),
50+
// we return all matches since they automatically contain 100% of keywords
51+
if (threshold === 0 && keywordsCount === 1) {
52+
return results
53+
}
54+
4955
// Prepare keywords count tracking for threshold handling
5056
const allResults = results.length
5157
const tokenScoreWithKeywordsCount: [InternalDocumentID, number, number][] = []
@@ -104,7 +110,7 @@ export function prioritizeTokenScores(
104110
const thresholdLength =
105111
lastTokenWithAllKeywords + Math.ceil((threshold * 100 * (allResults - lastTokenWithAllKeywords)) / 100)
106112

107-
return resultsWithIdAndScore.slice(0, allResults + thresholdLength)
113+
return resultsWithIdAndScore.slice(0, Math.min(allResults, thresholdLength))
108114
}
109115

110116
export function BM25(

packages/orama/src/components/index.ts

Lines changed: 110 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ export function create<T extends AnyOrama, TSchema extends T['schema']>(
170170
index.vectorIndexes[path] = {
171171
type: 'Vector',
172172
node: new VectorIndex(getVectorSize(type)),
173-
isArray: false,
173+
isArray: false
174174
}
175175
} else {
176176
const isArray = /\[/.test(type as string)
@@ -273,7 +273,16 @@ export function insert(
273273
return insertVector(index, prop, value as number[] | Float32Array, id, internalId)
274274
}
275275

276-
const insertScalar = insertScalarBuilder(implementation, index, prop, internalId, language, tokenizer, docsCount, options)
276+
const insertScalar = insertScalarBuilder(
277+
implementation,
278+
index,
279+
prop,
280+
internalId,
281+
language,
282+
tokenizer,
283+
docsCount,
284+
options
285+
)
277286

278287
if (!isArrayType(schemaType)) {
279288
return insertScalar(value)
@@ -286,7 +295,13 @@ export function insert(
286295
}
287296
}
288297

289-
export function insertVector(index: AnyIndexStore, prop: string, value: number[] | VectorType, id: DocumentID, internalDocumentId: InternalDocumentID): void {
298+
export function insertVector(
299+
index: AnyIndexStore,
300+
prop: string,
301+
value: number[] | VectorType,
302+
id: DocumentID,
303+
internalDocumentId: InternalDocumentID
304+
): void {
290305
index.vectorIndexes[prop].node.add(internalDocumentId, value)
291306
}
292307

@@ -372,7 +387,18 @@ export function remove(
372387
const elements = value as Array<string | number | boolean>
373388
const elementsLength = elements.length
374389
for (let i = 0; i < elementsLength; i++) {
375-
removeScalar(implementation, index, prop, id, internalId, elements[i], innerSchemaType, language, tokenizer, docsCount)
390+
removeScalar(
391+
implementation,
392+
index,
393+
prop,
394+
id,
395+
internalId,
396+
elements[i],
397+
innerSchemaType,
398+
language,
399+
tokenizer,
400+
docsCount
401+
)
376402
}
377403

378404
return true
@@ -396,7 +422,7 @@ export function calculateResultScores(
396422
const fieldLengths = index.fieldLengths[prop]
397423
const oramaOccurrences = index.tokenOccurrences[prop]
398424
const oramaFrequencies = index.frequencies[prop]
399-
425+
400426
// oramaOccurrences[term] can be undefined, 0, string, or { [k: string]: number }
401427
const termOccurrences = typeof oramaOccurrences[term] === 'number' ? oramaOccurrences[term] ?? 0 : 0
402428

@@ -417,14 +443,7 @@ export function calculateResultScores(
417443

418444
const tf = oramaFrequencies?.[internalId]?.[term] ?? 0
419445

420-
const bm25 = BM25(
421-
tf,
422-
termOccurrences,
423-
docsCount,
424-
fieldLengths[internalId]!,
425-
avgFieldLength,
426-
bm25Relevance,
427-
)
446+
const bm25 = BM25(tf, termOccurrences, docsCount, fieldLengths[internalId]!, avgFieldLength, bm25Relevance)
428447

429448
if (resultsMap.has(internalId)) {
430449
resultsMap.set(internalId, resultsMap.get(internalId)! + bm25 * boostPerProperty)
@@ -434,46 +453,6 @@ export function calculateResultScores(
434453
}
435454
}
436455

437-
function searchInProperty(
438-
index: Index,
439-
tree: RadixTree,
440-
prop: string,
441-
tokens: string[],
442-
exact: boolean,
443-
tolerance: number,
444-
resultsMap: Map<number, number>,
445-
boostPerProperty: number,
446-
bm25Relevance: Required<BM25Params>,
447-
docsCount: number,
448-
whereFiltersIDs: Set<InternalDocumentID> | undefined,
449-
keywordMatchesMap: Map<InternalDocumentID, Map<string, number>>
450-
) {
451-
const tokenLength = tokens.length;
452-
for (let i = 0; i < tokenLength; i++) {
453-
const term = tokens[i];
454-
const searchResult = tree.find({ term, exact, tolerance })
455-
456-
const termsFound = Object.keys(searchResult)
457-
const termsFoundLength = termsFound.length;
458-
for (let j = 0; j < termsFoundLength; j++) {
459-
const word = termsFound[j]
460-
const ids = searchResult[word]
461-
calculateResultScores(
462-
index,
463-
prop,
464-
word,
465-
ids,
466-
docsCount,
467-
bm25Relevance,
468-
resultsMap,
469-
boostPerProperty,
470-
whereFiltersIDs,
471-
keywordMatchesMap,
472-
)
473-
}
474-
}
475-
}
476-
477456
export function search(
478457
index: Index,
479458
term: string,
@@ -486,13 +465,15 @@ export function search(
486465
relevance: Required<BM25Params>,
487466
docsCount: number,
488467
whereFiltersIDs: Set<InternalDocumentID> | undefined,
489-
threshold = 0,
468+
threshold = 0
490469
): TokenScore[] {
491470
const tokens = tokenizer.tokenize(term, language)
492471
const keywordsCount = tokens.length || 1
493472

494473
// Track keyword matches per document and property
495474
const keywordMatchesMap = new Map<InternalDocumentID, Map<string, number>>()
475+
// Track which tokens were found in the search
476+
const tokenFoundMap = new Map<string, boolean>()
496477
const resultsMap = new Map<number, number>()
497478

498479
for (const prop of propertiesToSearch) {
@@ -515,20 +496,37 @@ export function search(
515496
tokens.push('')
516497
}
517498

518-
searchInProperty(
519-
index,
520-
tree.node,
521-
prop,
522-
tokens,
523-
exact,
524-
tolerance,
525-
resultsMap,
526-
boostPerProperty,
527-
relevance,
528-
docsCount,
529-
whereFiltersIDs,
530-
keywordMatchesMap
531-
)
499+
// Process each token in the search term
500+
const tokenLength = tokens.length
501+
for (let i = 0; i < tokenLength; i++) {
502+
const token = tokens[i]
503+
const searchResult = tree.node.find({ term: token, exact, tolerance })
504+
505+
// See if this token was found (for threshold=0 filtering)
506+
const termsFound = Object.keys(searchResult)
507+
if (termsFound.length > 0) {
508+
tokenFoundMap.set(token, true)
509+
}
510+
511+
// Process each matching term
512+
const termsFoundLength = termsFound.length
513+
for (let j = 0; j < termsFoundLength; j++) {
514+
const word = termsFound[j]
515+
const ids = searchResult[word]
516+
calculateResultScores(
517+
index,
518+
prop,
519+
word,
520+
ids,
521+
docsCount,
522+
relevance,
523+
resultsMap,
524+
boostPerProperty,
525+
whereFiltersIDs,
526+
keywordMatchesMap
527+
)
528+
}
529+
}
532530
}
533531

534532
// Convert to array and sort by score
@@ -545,20 +543,42 @@ export function search(
545543
return results
546544
}
547545

546+
// For threshold=0, check if all tokens were found
547+
if (threshold === 0) {
548+
// Quick return for a single-token search: every match trivially contains all keywords
549+
if (keywordsCount === 1) {
550+
return results
551+
}
552+
553+
// For multiple tokens, verify that ALL tokens were found
554+
// If any token wasn't found, return an empty result
555+
for (const token of tokens) {
556+
if (!tokenFoundMap.get(token)) {
557+
return []
558+
}
559+
}
560+
561+
// Find documents that have all keywords in at least one property
562+
const fullMatches = results.filter(([id]) => {
563+
const propertyMatches = keywordMatchesMap.get(id)
564+
if (!propertyMatches) return false
565+
566+
// Check if any property has all keywords
567+
return Array.from(propertyMatches.values()).some((matches) => matches === keywordsCount)
568+
})
569+
570+
return fullMatches
571+
}
572+
548573
// Find documents that have all keywords in at least one property
549574
const fullMatches = results.filter(([id]) => {
550575
const propertyMatches = keywordMatchesMap.get(id)
551576
if (!propertyMatches) return false
552-
577+
553578
// Check if any property has all keywords
554-
return Array.from(propertyMatches.values()).some(matches => matches === keywordsCount)
579+
return Array.from(propertyMatches.values()).some((matches) => matches === keywordsCount)
555580
})
556581

557-
// If threshold is 0, return only full matches
558-
if (threshold === 0) {
559-
return fullMatches
560-
}
561-
562582
// If we have full matches and threshold < 1, return full matches plus a percentage of partial matches
563583
if (fullMatches.length > 0) {
564584
const remainingResults = results.filter(([id]) => !fullMatches.some(([fid]) => fid === id))
@@ -656,9 +676,11 @@ export function searchByWhereClause<T extends AnyOrama>(
656676
}
657677

658678
if (type === 'Flat') {
659-
const results = new Set(isArray
660-
? node.filterArr(operation as EnumArrComparisonOperator)
661-
: node.filter(operation as EnumComparisonOperator))
679+
const results = new Set(
680+
isArray
681+
? node.filterArr(operation as EnumArrComparisonOperator)
682+
: node.filter(operation as EnumComparisonOperator)
683+
)
662684

663685
filtersMap[param] = setUnion(filtersMap[param], results)
664686

@@ -668,7 +690,7 @@ export function searchByWhereClause<T extends AnyOrama>(
668690
if (type === 'AVL') {
669691
const operationOpt = operationKeys[0] as keyof ComparisonOperator
670692
const operationValue = (operation as ComparisonOperator)[operationOpt]
671-
let filteredIDs: Set<InternalDocumentID>
693+
let filteredIDs: Set<InternalDocumentID>
672694

673695
switch (operationOpt) {
674696
case 'gt': {
@@ -818,12 +840,7 @@ export function save<R = unknown>(index: Index): R {
818840
const savedIndexes: any = {}
819841
for (const name of Object.keys(indexes)) {
820842
const { type, node, isArray } = indexes[name]
821-
if (type === 'Flat'
822-
|| type === 'Radix'
823-
|| type === 'AVL'
824-
|| type === 'BKD'
825-
|| type === 'Bool'
826-
) {
843+
if (type === 'Flat' || type === 'Radix' || type === 'AVL' || type === 'BKD' || type === 'Bool') {
827844
savedIndexes[name] = {
828845
type,
829846
node: node.toJSON(),
@@ -866,7 +883,10 @@ export function createIndex(): IIndex<Index> {
866883
}
867884
}
868885

869-
function addGeoResult(set: Set<InternalDocumentID> | undefined, ids: Array<{ docIDs: InternalDocumentID[] }>): Set<InternalDocumentID> {
886+
function addGeoResult(
887+
set: Set<InternalDocumentID> | undefined,
888+
ids: Array<{ docIDs: InternalDocumentID[] }>
889+
): Set<InternalDocumentID> {
870890
if (!set) {
871891
set = new Set()
872892
}
@@ -883,7 +903,10 @@ function addGeoResult(set: Set<InternalDocumentID> | undefined, ids: Array<{ doc
883903
return set
884904
}
885905

886-
function addFindResult(set: Set<InternalDocumentID> | undefined, filteredIDsResults: FindResult): Set<InternalDocumentID> {
906+
function addFindResult(
907+
set: Set<InternalDocumentID> | undefined,
908+
filteredIDsResults: FindResult
909+
): Set<InternalDocumentID> {
887910
if (!set) {
888911
set = new Set()
889912
}

0 commit comments

Comments
 (0)