@@ -27,12 +27,9 @@ fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
27
27
}
28
28
29
29
/// Convert the vector to +1/-1 format.
30
+ #[ inline]
30
31
fn vector_binarize_one ( vec : & DVector < f32 > ) -> DVector < f32 > {
31
- let mut binary = DVector :: zeros ( vec. len ( ) ) ;
32
- for ( i, & v) in vec. iter ( ) . enumerate ( ) {
33
- binary[ i] = if v > 0.0 { 1.0 } else { -1.0 } ;
34
- }
35
- binary
32
+ DVector :: from_fn ( vec. len ( ) , |i, _| if vec[ i] > 0.0 { 1.0 } else { -1.0 } )
36
33
}
37
34
38
35
/// Convert the vector to binary format (one value to multiple bits) and store in a u64 vector.
@@ -75,10 +72,24 @@ fn asymmetric_binary_dot_product(x: &[u64], y: &[u64]) -> u32 {
75
72
fn kmeans_nearest_cluster ( centroids : & DMatrix < f32 > , vec : & DVectorView < f32 > ) -> usize {
76
73
let mut min_dist = f32:: MAX ;
77
74
let mut min_label = 0 ;
78
- let mut temp = DVector :: < f32 > :: zeros ( vec. len ( ) ) ;
75
+ let mut residual = DVector :: < f32 > :: zeros ( vec. len ( ) ) ;
79
76
for ( j, centroid) in centroids. column_iter ( ) . enumerate ( ) {
80
- vec. sub_to ( & centroid, & mut temp) ;
81
- let dist = temp. norm_squared ( ) ;
77
+ let dist = {
78
+ #[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
79
+ {
80
+ if is_x86_feature_detected ! ( "avx2" ) {
81
+ unsafe { crate :: distance:: l2_squared_distance_avx2 ( & centroid, vec) }
82
+ } else {
83
+ vec. sub_to ( & centroid, & mut residual) ;
84
+ residual. norm_squared ( )
85
+ }
86
+ }
87
+ #[ cfg( not( any( target_arch = "x86_64" , target_arch = "x86" ) ) ) ]
88
+ {
89
+ self . base . column ( u as usize ) . sub_to ( query, & mut residual) ;
90
+ residual. norm_squared ( )
91
+ }
92
+ } ;
82
93
if dist < min_dist {
83
94
min_dist = dist;
84
95
min_label = j;
@@ -215,13 +226,21 @@ impl RaBitQ {
215
226
{
216
227
if is_x86_feature_detected ! ( "avx2" ) {
217
228
unsafe {
218
- crate :: distance:: l2_squared_distance_avx2 ( & centroid, & y_projected)
229
+ crate :: distance:: l2_squared_distance_avx2 (
230
+ & centroid,
231
+ & y_projected. as_view ( ) ,
232
+ )
219
233
}
220
234
} else {
221
235
y_projected. sub_to ( & centroid, & mut residual) ;
222
236
residual. norm_squared ( )
223
237
}
224
238
}
239
+ #[ cfg( not( any( target_arch = "x86_64" , target_arch = "x86" ) ) ) ]
240
+ {
241
+ self . base . column ( u as usize ) . sub_to ( query, & mut residual) ;
242
+ residual. norm_squared ( )
243
+ }
225
244
} ;
226
245
lists. push ( ( dist, i) ) ;
227
246
}
@@ -283,14 +302,19 @@ impl RaBitQ {
283
302
unsafe {
284
303
crate :: distance:: l2_squared_distance_avx2 (
285
304
& self . base . column ( u as usize ) ,
286
- query,
305
+ & query. as_view ( ) ,
287
306
)
288
307
}
289
308
} else {
290
309
self . base . column ( u as usize ) . sub_to ( query, & mut residual) ;
291
310
residual. norm_squared ( )
292
311
}
293
312
}
313
+ #[ cfg( not( any( target_arch = "x86_64" , target_arch = "x86" ) ) ) ]
314
+ {
315
+ self . base . column ( u as usize ) . sub_to ( query, & mut residual) ;
316
+ residual. norm_squared ( )
317
+ }
294
318
} ;
295
319
if accurate < threshold {
296
320
res. push ( ( accurate, u as i32 ) ) ;
0 commit comments