Skip to content

Commit 8b99250

Browse files
committed
add binary heap based rerank with ord f32
Signed-off-by: Keming <[email protected]>
1 parent a0f6d84 commit 8b99250

File tree

4 files changed

+108
-9
lines changed

4 files changed

+108
-9
lines changed

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#![forbid(missing_docs)]
44
mod consts;
55
pub mod metrics;
6+
mod order;
67
pub mod rabitq;
78
pub mod simd;
89
pub mod utils;

src/main.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ struct Args {
3232
/// saved directory
3333
#[argh(option, short = 's')]
3434
saved: String,
35+
/// heuristic re-rank
36+
#[argh(switch, short = 'h')]
37+
heuristic_rank: bool,
3538
}
3639

3740
fn main() {
@@ -65,7 +68,12 @@ fn main() {
6568
for (i, query) in queries.iter().enumerate() {
6669
let query_vec = dvector_from_vec(query.clone());
6770
let start_time = Instant::now();
68-
let res = rabitq.query(&query_vec.as_view(), args.probe, args.topk);
71+
let res = rabitq.query(
72+
&query_vec.as_view(),
73+
args.probe,
74+
args.topk,
75+
args.heuristic_rank,
76+
);
6977
total_time += start_time.elapsed().as_secs_f64();
7078
let ids: Vec<i32> = res.iter().map(|(_, id)| *id as i32).collect();
7179
recall += calculate_recall(&truth[i], &ids, args.topk);

src/order.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//! f32 stored as i32 to make it comparable and faster to compare.
2+
3+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
4+
#[repr(transparent)]
5+
pub struct Ord32(i32);
6+
7+
impl Ord32 {
8+
#[inline]
9+
pub fn from_f32(x: f32) -> Self {
10+
let bits = x.to_bits() as i32;
11+
let mask = ((bits >> 31) as u32) >> 1;
12+
let res = bits ^ (mask as i32);
13+
Self(res)
14+
}
15+
16+
#[inline]
17+
pub fn to_f32(self) -> f32 {
18+
let bits = self.0;
19+
let mask = ((bits >> 31) as u32) >> 1;
20+
let res = bits ^ (mask as i32);
21+
f32::from_bits(res as u32)
22+
}
23+
}
24+
25+
impl From<f32> for Ord32 {
26+
#[inline]
27+
fn from(x: f32) -> Self {
28+
Self::from_f32(x)
29+
}
30+
}
31+
32+
impl From<Ord32> for f32 {
33+
#[inline]
34+
fn from(x: Ord32) -> Self {
35+
x.to_f32()
36+
}
37+
}

src/rabitq.rs

+61-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! RaBitQ implementation.
22
33
use core::f32;
4+
use std::collections::BinaryHeap;
45
use std::path::Path;
56

67
use log::debug;
@@ -9,6 +10,7 @@ use serde::{Deserialize, Serialize};
910

1011
use crate::consts::{DEFAULT_X_DOT_PRODUCT, EPSILON, THETA_LOG_DIM, WINDOWS_SIZE};
1112
use crate::metrics::METRICS;
13+
use crate::order::Ord32;
1214
use crate::utils::{
1315
gen_random_bias, gen_random_qr_orthogonal, matrix_from_fvecs, read_u64_vecs, read_vecs,
1416
write_matrix, write_vecs,
@@ -267,7 +269,7 @@ impl RaBitQ {
267269
read_u64_vecs(&path.join("x_binary_vec.u64vecs")).expect("open x_binary_vec error");
268270

269271
let dim = orthogonal.nrows();
270-
let base = matrix_from_fvecs(&path.join("base.fvecs"));
272+
let base = matrix_from_fvecs(&path.join("base.fvecs")).transpose();
271273

272274
Self {
273275
dim: dim as u32,
@@ -288,7 +290,8 @@ impl RaBitQ {
288290
/// Dump to dir.
289291
pub fn dump_to_dir(&self, path: &Path) {
290292
std::fs::create_dir_all(path).expect("create dir error");
291-
write_matrix(&path.join("base.fvecs"), &self.base.as_view()).expect("write base error");
293+
write_matrix(&path.join("base.fvecs"), &self.base.transpose().as_view())
294+
.expect("write base error");
292295
write_matrix(&path.join("orthogonal.fvecs"), &self.orthogonal.as_view())
293296
.expect("write orthogonal error");
294297
write_matrix(&path.join("centroids.fvecs"), &self.centroids.as_view())
@@ -423,7 +426,13 @@ impl RaBitQ {
423426
}
424427

425428
/// Query the topk nearest neighbors for the given query.
426-
pub fn query(&self, query: &DVectorView<f32>, probe: usize, topk: usize) -> Vec<(f32, u32)> {
429+
pub fn query(
430+
&self,
431+
query: &DVectorView<f32>,
432+
probe: usize,
433+
topk: usize,
434+
heuristic_rank: bool,
435+
) -> Vec<(f32, u32)> {
427436
let y_projected = query.tr_mul(&self.orthogonal).transpose();
428437
let k = self.centroids.shape().1;
429438
let mut lists = Vec::with_capacity(k);
@@ -476,11 +485,53 @@ impl RaBitQ {
476485
}
477486
}
478487

479-
self.rerank(query, &rough_distances, topk)
488+
if heuristic_rank {
489+
self.heuristic_re_rank(query, &rough_distances, topk)
490+
} else {
491+
self.re_rank(query, &rough_distances, topk)
492+
}
493+
}
494+
495+
/// BinaryHeap based re-rank with Ord32.
496+
fn re_rank(
497+
&self,
498+
query: &DVectorView<f32>,
499+
rough_distances: &[(f32, u32)],
500+
topk: usize,
501+
) -> Vec<(f32, u32)> {
502+
let mut threshold = f32::MAX;
503+
let mut precise = 0;
504+
let mut residual = DVector::<f32>::zeros(self.dim as usize);
505+
let mut heap: BinaryHeap<(Ord32, u32)> = BinaryHeap::with_capacity(topk);
506+
for &(rough, u) in rough_distances.iter() {
507+
if rough < threshold {
508+
let accurate = l2_squared_distance(
509+
&self.base.column(u as usize),
510+
&query.as_view(),
511+
&mut residual,
512+
);
513+
precise += 1;
514+
if accurate < threshold {
515+
heap.push((accurate.into(), self.map_ids[u as usize]));
516+
if heap.len() > topk {
517+
heap.pop();
518+
}
519+
if heap.len() == topk {
520+
threshold = heap.peek().unwrap().0.into();
521+
}
522+
}
523+
}
524+
}
525+
526+
METRICS.add_precise_count(precise);
527+
METRICS.add_rough_count(rough_distances.len() as u64);
528+
METRICS.add_query_count(1);
529+
530+
heap.into_iter().map(|(a, b)| (a.into(), b)).collect()
480531
}
481532

482-
/// Rerank the topk nearest neighbors.
483-
fn rerank(
533+
/// Heuristic re-rank with a fixed windows size to update the threshold.
534+
fn heuristic_re_rank(
484535
&self,
485536
query: &DVectorView<f32>,
486537
rough_distances: &[(f32, u32)],
@@ -491,18 +542,20 @@ impl RaBitQ {
491542
let mut res = Vec::with_capacity(topk);
492543
let mut count = 0;
493544
let mut residual = DVector::<f32>::zeros(self.dim as usize);
545+
let mut precise = 0;
494546
for &(rough, u) in rough_distances.iter() {
495547
if rough < threshold {
496548
let accurate = l2_squared_distance(
497549
&self.base.column(u as usize),
498550
&query.as_view(),
499551
&mut residual,
500552
);
553+
precise += 1;
501554
if accurate < threshold {
502555
res.push((accurate, self.map_ids[u as usize]));
503556
count += 1;
504557
recent_max_accurate = recent_max_accurate.max(accurate);
505-
if count == WINDOWS_SIZE {
558+
if count >= WINDOWS_SIZE {
506559
threshold = recent_max_accurate;
507560
count = 0;
508561
recent_max_accurate = f32::MIN;
@@ -511,7 +564,7 @@ impl RaBitQ {
511564
}
512565
}
513566

514-
METRICS.add_precise_count(res.len() as u64);
567+
METRICS.add_precise_count(precise);
515568
METRICS.add_rough_count(rough_distances.len() as u64);
516569
METRICS.add_query_count(1);
517570
let length = topk.min(res.len());

0 commit comments

Comments
 (0)