1
1
//! RaBitQ implementation.
2
2
3
3
use core:: f32;
4
+ use std:: collections:: BinaryHeap ;
4
5
use std:: path:: Path ;
5
6
6
7
use log:: debug;
@@ -9,6 +10,7 @@ use serde::{Deserialize, Serialize};
9
10
10
11
use crate :: consts:: { DEFAULT_X_DOT_PRODUCT , EPSILON , THETA_LOG_DIM , WINDOWS_SIZE } ;
11
12
use crate :: metrics:: METRICS ;
13
+ use crate :: order:: Ord32 ;
12
14
use crate :: utils:: {
13
15
gen_random_bias, gen_random_qr_orthogonal, matrix_from_fvecs, read_u64_vecs, read_vecs,
14
16
write_matrix, write_vecs,
@@ -267,7 +269,7 @@ impl RaBitQ {
267
269
read_u64_vecs ( & path. join ( "x_binary_vec.u64vecs" ) ) . expect ( "open x_binary_vec error" ) ;
268
270
269
271
let dim = orthogonal. nrows ( ) ;
270
- let base = matrix_from_fvecs ( & path. join ( "base.fvecs" ) ) ;
272
+ let base = matrix_from_fvecs ( & path. join ( "base.fvecs" ) ) . transpose ( ) ;
271
273
272
274
Self {
273
275
dim : dim as u32 ,
@@ -288,7 +290,8 @@ impl RaBitQ {
288
290
/// Dump to dir.
289
291
pub fn dump_to_dir ( & self , path : & Path ) {
290
292
std:: fs:: create_dir_all ( path) . expect ( "create dir error" ) ;
291
- write_matrix ( & path. join ( "base.fvecs" ) , & self . base . as_view ( ) ) . expect ( "write base error" ) ;
293
+ write_matrix ( & path. join ( "base.fvecs" ) , & self . base . transpose ( ) . as_view ( ) )
294
+ . expect ( "write base error" ) ;
292
295
write_matrix ( & path. join ( "orthogonal.fvecs" ) , & self . orthogonal . as_view ( ) )
293
296
. expect ( "write orthogonal error" ) ;
294
297
write_matrix ( & path. join ( "centroids.fvecs" ) , & self . centroids . as_view ( ) )
@@ -423,7 +426,13 @@ impl RaBitQ {
423
426
}
424
427
425
428
/// Query the topk nearest neighbors for the given query.
426
- pub fn query ( & self , query : & DVectorView < f32 > , probe : usize , topk : usize ) -> Vec < ( f32 , u32 ) > {
429
+ pub fn query (
430
+ & self ,
431
+ query : & DVectorView < f32 > ,
432
+ probe : usize ,
433
+ topk : usize ,
434
+ heuristic_rank : bool ,
435
+ ) -> Vec < ( f32 , u32 ) > {
427
436
let y_projected = query. tr_mul ( & self . orthogonal ) . transpose ( ) ;
428
437
let k = self . centroids . shape ( ) . 1 ;
429
438
let mut lists = Vec :: with_capacity ( k) ;
@@ -476,11 +485,53 @@ impl RaBitQ {
476
485
}
477
486
}
478
487
479
- self . rerank ( query, & rough_distances, topk)
488
+ if heuristic_rank {
489
+ self . heuristic_re_rank ( query, & rough_distances, topk)
490
+ } else {
491
+ self . re_rank ( query, & rough_distances, topk)
492
+ }
493
+ }
494
+
495
+ /// BinaryHeap based re-rank with Ord32.
496
+ fn re_rank (
497
+ & self ,
498
+ query : & DVectorView < f32 > ,
499
+ rough_distances : & [ ( f32 , u32 ) ] ,
500
+ topk : usize ,
501
+ ) -> Vec < ( f32 , u32 ) > {
502
+ let mut threshold = f32:: MAX ;
503
+ let mut precise = 0 ;
504
+ let mut residual = DVector :: < f32 > :: zeros ( self . dim as usize ) ;
505
+ let mut heap: BinaryHeap < ( Ord32 , u32 ) > = BinaryHeap :: with_capacity ( topk) ;
506
+ for & ( rough, u) in rough_distances. iter ( ) {
507
+ if rough < threshold {
508
+ let accurate = l2_squared_distance (
509
+ & self . base . column ( u as usize ) ,
510
+ & query. as_view ( ) ,
511
+ & mut residual,
512
+ ) ;
513
+ precise += 1 ;
514
+ if accurate < threshold {
515
+ heap. push ( ( accurate. into ( ) , self . map_ids [ u as usize ] ) ) ;
516
+ if heap. len ( ) > topk {
517
+ heap. pop ( ) ;
518
+ }
519
+ if heap. len ( ) == topk {
520
+ threshold = heap. peek ( ) . unwrap ( ) . 0 . into ( ) ;
521
+ }
522
+ }
523
+ }
524
+ }
525
+
526
+ METRICS . add_precise_count ( precise) ;
527
+ METRICS . add_rough_count ( rough_distances. len ( ) as u64 ) ;
528
+ METRICS . add_query_count ( 1 ) ;
529
+
530
+ heap. into_iter ( ) . map ( |( a, b) | ( a. into ( ) , b) ) . collect ( )
480
531
}
481
532
482
- /// Rerank the topk nearest neighbors .
483
- fn rerank (
533
+ /// Heuristic re-rank with a fixed windows size to update the threshold .
534
+ fn heuristic_re_rank (
484
535
& self ,
485
536
query : & DVectorView < f32 > ,
486
537
rough_distances : & [ ( f32 , u32 ) ] ,
@@ -491,18 +542,20 @@ impl RaBitQ {
491
542
let mut res = Vec :: with_capacity ( topk) ;
492
543
let mut count = 0 ;
493
544
let mut residual = DVector :: < f32 > :: zeros ( self . dim as usize ) ;
545
+ let mut precise = 0 ;
494
546
for & ( rough, u) in rough_distances. iter ( ) {
495
547
if rough < threshold {
496
548
let accurate = l2_squared_distance (
497
549
& self . base . column ( u as usize ) ,
498
550
& query. as_view ( ) ,
499
551
& mut residual,
500
552
) ;
553
+ precise += 1 ;
501
554
if accurate < threshold {
502
555
res. push ( ( accurate, self . map_ids [ u as usize ] ) ) ;
503
556
count += 1 ;
504
557
recent_max_accurate = recent_max_accurate. max ( accurate) ;
505
- if count = = WINDOWS_SIZE {
558
+ if count > = WINDOWS_SIZE {
506
559
threshold = recent_max_accurate;
507
560
count = 0 ;
508
561
recent_max_accurate = f32:: MIN ;
@@ -511,7 +564,7 @@ impl RaBitQ {
511
564
}
512
565
}
513
566
514
- METRICS . add_precise_count ( res . len ( ) as u64 ) ;
567
+ METRICS . add_precise_count ( precise ) ;
515
568
METRICS . add_rough_count ( rough_distances. len ( ) as u64 ) ;
516
569
METRICS . add_query_count ( 1 ) ;
517
570
let length = topk. min ( res. len ( ) ) ;
0 commit comments