@@ -27,7 +27,7 @@ pub struct RaBitQ {
27
27
centroids : Mat < f32 > ,
28
28
offsets : Vec < u32 > ,
29
29
map_ids : Vec < u32 > ,
30
- x_binary_vec : Vec < Vec < u64 > > ,
30
+ x_binary_vec : Vec < u64 > ,
31
31
x_c_distance_square : Vec < f32 > ,
32
32
error_bound : Vec < f32 > ,
33
33
factor_ip : Vec < f32 > ,
@@ -63,8 +63,11 @@ impl RaBitQ {
63
63
let error_bound = factors[ 2 ] . clone ( ) ;
64
64
let x_c_distance_square = factors[ 3 ] . clone ( ) ;
65
65
66
- let x_binary_vec =
67
- read_u64_vecs ( & path. join ( "x_binary_vec.u64vecs" ) ) . expect ( "open x_binary_vec error" ) ;
66
+ let x_binary_vec = read_u64_vecs ( & path. join ( "x_binary_vec.u64vecs" ) )
67
+ . expect ( "open x_binary_vec error" )
68
+ . into_iter ( )
69
+ . flatten ( )
70
+ . collect ( ) ;
68
71
69
72
let dim = orthogonal. nrows ( ) ;
70
73
let base = matrix_from_fvecs ( & path. join ( "base.fvecs" ) )
@@ -113,7 +116,8 @@ impl RaBitQ {
113
116
. expect ( "write factors error" ) ;
114
117
write_vecs (
115
118
& path. join ( "x_binary_vec.u64vecs" ) ,
116
- & self . x_binary_vec . iter ( ) . collect :: < Vec < _ > > ( ) ,
119
+ // x_binary_vec is stored flat (Vec<u64>), so it is passed to write_vecs as one vector.
120
+ & [ & self . x_binary_vec ] ,
117
121
)
118
122
. expect ( "write x_binary_vec error" ) ;
119
123
}
@@ -187,7 +191,7 @@ impl RaBitQ {
187
191
. to_owned ( ) ;
188
192
let x_binary_vec = flat_labels
189
193
. iter ( )
190
- . map ( |i| x_binary_vec[ * i as usize ] . clone ( ) )
194
+ . flat_map ( |i| x_binary_vec[ * i as usize ] . clone ( ) )
191
195
. collect ( ) ;
192
196
let x_c_distance_square = flat_labels
193
197
. iter ( )
@@ -227,6 +231,7 @@ impl RaBitQ {
227
231
topk : usize ,
228
232
heuristic_rank : bool ,
229
233
) -> Vec < ( f32 , u32 ) > {
234
+ assert_eq ! ( self . dim as usize , query. nrows( ) ) ;
230
235
let y_projected = project ( query, & self . orthogonal . as_ref ( ) ) ;
231
236
let k = self . centroids . shape ( ) . 1 ;
232
237
let mut lists = Vec :: with_capacity ( k) ;
@@ -287,15 +292,18 @@ impl RaBitQ {
287
292
rough_distances : & mut Vec < ( f32 , u32 ) > ,
288
293
) {
289
294
let dist_sqrt = y_c_distance_square. sqrt ( ) ;
295
+ let binary_offset = y_binary_vec. len ( ) / THETA_LOG_DIM as usize ;
290
296
for j in self . offsets [ cluster_id] ..self . offsets [ cluster_id + 1 ] {
291
297
let ju = j as usize ;
292
298
rough_distances. push ( (
293
299
( self . x_c_distance_square [ ju]
294
300
+ y_c_distance_square
295
301
+ lower_bound * self . factor_ppc [ ju]
296
302
+ ( 2.0
297
- * asymmetric_binary_dot_product ( & self . x_binary_vec [ ju] , y_binary_vec)
298
- as f32
303
+ * asymmetric_binary_dot_product (
304
+ & self . x_binary_vec [ ju * binary_offset..( ju + 1 ) * binary_offset] ,
305
+ y_binary_vec,
306
+ ) as f32
299
307
- scalar_sum)
300
308
* self . factor_ip [ ju]
301
309
* delta
0 commit comments