@@ -272,3 +272,67 @@ pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
272
272
273
273
sum
274
274
}
275
+
276
+ /// Compute the binary dot product of two vectors.
277
+ ///
278
+ /// Refer to: https://github.com/komrad36/popcount
279
+ ///
280
+ /// # Safety
281
+ ///
282
+ /// This function is marked unsafe because it requires the AVX2 intrinsics.
283
+ #[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
284
+ #[ target_feature( enable = "sse2,avx,avx2" ) ]
285
+ pub unsafe fn binary_dot_product ( lhs : & [ u64 ] , rhs : & [ u64 ] ) -> u32 {
286
+ use std:: arch:: x86_64:: * ;
287
+
288
+ let mut x_ptr = lhs. as_ptr ( ) as * const __m256i ;
289
+ let mut y_ptr = rhs. as_ptr ( ) as * const __m256i ;
290
+
291
+ let length = lhs. len ( ) / 4 ;
292
+ let rest = lhs. len ( ) & 0b11 ;
293
+ let lookup_table = _mm256_setr_epi8 (
294
+ 0 , 1 , 1 , 2 , 1 , 2 , 2 , 3 , // 0-7
295
+ 1 , 2 , 2 , 3 , 2 , 3 , 3 , 4 , // 8-15
296
+ 0 , 1 , 1 , 2 , 1 , 2 , 2 , 3 , // 16-23
297
+ 1 , 2 , 2 , 3 , 2 , 3 , 3 , 4 , // 24-31
298
+ ) ;
299
+ let mask = _mm256_set1_epi8 ( 15 ) ;
300
+ let zero = _mm256_setzero_si256 ( ) ;
301
+
302
+ #[ inline]
303
+ unsafe fn mm256_popcnt_epi64 (
304
+ x : __m256i ,
305
+ lookup_table : __m256i ,
306
+ mask : __m256i ,
307
+ zero : __m256i ,
308
+ ) -> __m256i {
309
+ use std:: arch:: x86_64:: * ;
310
+
311
+ let mut low = _mm256_and_si256 ( x, mask) ;
312
+ let mut high = _mm256_and_si256 ( _mm256_srli_epi64 ( x, 4 ) , mask) ;
313
+ low = _mm256_shuffle_epi8 ( lookup_table, low) ;
314
+ high = _mm256_shuffle_epi8 ( lookup_table, high) ;
315
+ _mm256_sad_epu8 ( _mm256_add_epi8 ( low, high) , zero)
316
+ }
317
+
318
+ let mut sum256 = _mm256_setzero_si256 ( ) ;
319
+ for _ in 0 ..length {
320
+ let x256 = _mm256_loadu_si256 ( x_ptr) ;
321
+ let y256 = _mm256_loadu_si256 ( y_ptr) ;
322
+ let and = _mm256_and_si256 ( x256, y256) ;
323
+ sum256 = _mm256_add_epi64 ( sum256, mm256_popcnt_epi64 ( and, lookup_table, mask, zero) ) ;
324
+ x_ptr = x_ptr. add ( 1 ) ;
325
+ y_ptr = y_ptr. add ( 1 ) ;
326
+ }
327
+
328
+ let xa = _mm_add_epi64 (
329
+ _mm256_castsi256_si128 ( sum256) ,
330
+ _mm256_extracti128_si256 ( sum256, 1 ) ,
331
+ ) ;
332
+ let mut sum = _mm_cvtsi128_si64 ( _mm_add_epi64 ( xa, _mm_shuffle_epi32 ( xa, 78 ) ) ) as u32 ;
333
+ for i in 0 ..rest {
334
+ sum += ( lhs[ 4 * length + i] & rhs[ 4 * length + i] ) . count_ones ( ) ;
335
+ }
336
+
337
+ sum
338
+ }
0 commit comments