@@ -104,7 +104,7 @@ fn l2_squared_distance(
104
104
}
105
105
106
106
// Get the min/max value of a vector.
107
- fn min_max ( vec : & DVectorView < f32 > ) -> ( f32 , f32 ) {
107
+ fn min_max_raw ( vec : & DVectorView < f32 > ) -> ( f32 , f32 ) {
108
108
let mut min = f32:: MAX ;
109
109
let mut max = f32:: MIN ;
110
110
for v in vec. iter ( ) {
@@ -118,6 +118,22 @@ fn min_max(vec: &DVectorView<f32>) -> (f32, f32) {
118
118
( min, max)
119
119
}
120
120
121
+ // Interface of `min_max`
122
+ fn min_max ( vec : & DVectorView < f32 > ) -> ( f32 , f32 ) {
123
+ #[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
124
+ {
125
+ if is_x86_feature_detected ! ( "avx" ) {
126
+ unsafe { crate :: simd:: min_max_avx ( vec) }
127
+ } else {
128
+ min_max_raw ( vec)
129
+ }
130
+ }
131
+ #[ cfg( not( any( target_arch = "x86_64" , target_arch = "x86" ) ) ) ]
132
+ {
133
+ min_max_raw ( vec)
134
+ }
135
+ }
136
+
121
137
// Quantize the query residual vector.
122
138
fn quantize_query_vector (
123
139
quantized : & mut DVector < u8 > ,
@@ -279,7 +295,7 @@ impl RaBitQ {
279
295
}
280
296
281
297
/// Query the topk nearest neighbors for the given query.
282
- pub fn query_one ( & self , query : & DVector < f32 > , probe : usize , topk : usize ) -> Vec < i32 > {
298
+ pub fn query_one ( & self , query : & DVector < f32 > , probe : usize , topk : usize ) -> Vec < ( f32 , u32 ) > {
283
299
let y_projected = query. tr_mul ( & self . orthogonal ) . transpose ( ) ;
284
300
let k = self . centroids . shape ( ) . 1 ;
285
301
let mut lists = Vec :: with_capacity ( k) ;
@@ -337,7 +353,7 @@ impl RaBitQ {
337
353
query : & DVector < f32 > ,
338
354
rough_distances : & [ ( f32 , u32 ) ] ,
339
355
topk : usize ,
340
- ) -> Vec < i32 > {
356
+ ) -> Vec < ( f32 , u32 ) > {
341
357
let mut threshold = f32:: MAX ;
342
358
let mut recent_max_accurate = f32:: MIN ;
343
359
let mut res = Vec :: with_capacity ( topk) ;
@@ -351,7 +367,7 @@ impl RaBitQ {
351
367
& mut residual,
352
368
) ;
353
369
if accurate < threshold {
354
- res. push ( ( accurate, u as i32 ) ) ;
370
+ res. push ( ( accurate, u) ) ;
355
371
count += 1 ;
356
372
recent_max_accurate = recent_max_accurate. max ( accurate) ;
357
373
if count == WINDOWS_SIZE {
@@ -366,6 +382,7 @@ impl RaBitQ {
366
382
METRICS . add_precise_count ( res. len ( ) as u64 ) ;
367
383
let length = topk. min ( res. len ( ) ) ;
368
384
res. select_nth_unstable_by ( length - 1 , |a, b| a. 0 . total_cmp ( & b. 0 ) ) ;
369
- res[ ..length] . iter ( ) . map ( |( _, u) | * u) . collect ( )
385
+ res. truncate ( length) ;
386
+ res
370
387
}
371
388
}
0 commit comments