Skip to content

Commit edabd4a

Browse files
committed
rewrite binary_dot_product with simd
Signed-off-by: Keming <[email protected]>
1 parent 8f962f7 commit edabd4a

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

src/simd.rs

+64
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,67 @@ pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
272272

273273
sum
274274
}
275+
276+
/// Compute the binary dot product of two bit vectors packed into `u64` words.
///
/// The result is the number of bit positions that are set in both inputs,
/// i.e. the sum of `(lhs[i] & rhs[i]).count_ones()` for `i` in
/// `0..lhs.len()`. Only the first `lhs.len()` words of `rhs` are read, so
/// `rhs` may be longer than `lhs`.
///
/// Population counts use the AVX2 nibble-lookup technique described at
/// <https://github.com/komrad36/popcount>.
///
/// # Panics
///
/// Panics if `rhs` is shorter than `lhs`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, e.g. by checking
/// `is_x86_feature_detected!("avx2")` before calling.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2,avx,avx2")]
pub unsafe fn binary_dot_product(lhs: &[u64], rhs: &[u64]) -> u32 {
    // Pick the intrinsics module that matches the architecture being built;
    // `std::arch::x86_64` does not exist on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Number of `u64` words held by one 256-bit vector.
    const LANES: usize = 4;

    // Checking the length up front turns a would-be out-of-bounds vector
    // load into a panic, and lets both slices be chunked in lock step.
    assert!(
        rhs.len() >= lhs.len(),
        "binary_dot_product: rhs must be at least as long as lhs"
    );
    let rhs = &rhs[..lhs.len()];

    // Popcount of every 4-bit value, repeated for both 128-bit halves because
    // `_mm256_shuffle_epi8` looks bytes up independently within each half.
    let lookup_table = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, // 0-7
        1, 2, 2, 3, 2, 3, 3, 4, // 8-15
        0, 1, 1, 2, 1, 2, 2, 3, // 16-23
        1, 2, 2, 3, 2, 3, 3, 4, // 24-31
    );
    let nibble_mask = _mm256_set1_epi8(0x0f);
    let zero = _mm256_setzero_si256();

    let lhs_chunks = lhs.chunks_exact(LANES);
    let rhs_chunks = rhs.chunks_exact(LANES);
    let lhs_tail = lhs_chunks.remainder();
    let rhs_tail = rhs_chunks.remainder();

    // Four running 64-bit popcount totals, one per lane.
    let mut sum256 = _mm256_setzero_si256();
    for (l, r) in lhs_chunks.zip(rhs_chunks) {
        // Each chunk is exactly `LANES` words (32 bytes), so the unaligned
        // 256-bit loads stay inside the slices.
        let x256 = _mm256_loadu_si256(l.as_ptr() as *const __m256i);
        let y256 = _mm256_loadu_si256(r.as_ptr() as *const __m256i);
        let both = _mm256_and_si256(x256, y256);

        // Look up the popcount of the low and high nibble of every byte.
        let low = _mm256_shuffle_epi8(lookup_table, _mm256_and_si256(both, nibble_mask));
        let high = _mm256_shuffle_epi8(
            lookup_table,
            _mm256_and_si256(_mm256_srli_epi64(both, 4), nibble_mask),
        );

        // Per-byte counts are at most 8; summing absolute differences against
        // zero adds the eight bytes of each 64-bit lane horizontally.
        let per_lane = _mm256_sad_epu8(_mm256_add_epi8(low, high), zero);
        sum256 = _mm256_add_epi64(sum256, per_lane);
    }

    // Fold the four lane totals into one: upper half onto lower half, then
    // swap the two 64-bit lanes (shuffle control 78 == 0b01_00_11_10) and add.
    let halves = _mm_add_epi64(
        _mm256_castsi256_si128(sum256),
        _mm256_extracti128_si256(sum256, 1),
    );
    let total = _mm_add_epi64(halves, _mm_shuffle_epi32(halves, 78));

    // The return type is `u32`, so only the low 32 bits are needed;
    // `_mm_cvtsi128_si32` is available on both x86 and x86_64, unlike
    // `_mm_cvtsi128_si64`.
    let mut sum = _mm_cvtsi128_si32(total) as u32;

    // Up to three trailing words that do not fill a whole vector.
    for (l, r) in lhs_tail.iter().zip(rhs_tail) {
        sum += (l & r).count_ones();
    }

    sum
}

src/utils.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,20 @@ pub fn asymmetric_binary_dot_product(x: &[u64], y: &[u64]) -> u32 {
115115
let length = x.len();
116116
let mut y_slice = y;
117117
for i in 0..THETA_LOG_DIM as usize {
118-
res += binary_dot_product(x, y_slice) << i;
118+
res += {
119+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
120+
{
121+
if is_x86_feature_detected!("avx2") {
122+
unsafe { crate::simd::binary_dot_product(x, y_slice) << i }
123+
} else {
124+
binary_dot_product(x, y_slice) << i
125+
}
126+
}
127+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
128+
{
129+
binary_dot_product(x, y_slice) << i
130+
}
131+
};
119132
y_slice = &y_slice[length..];
120133
}
121134
res

0 commit comments

Comments
 (0)