Skip to content

Commit 5f82fcc

Browse files
committed
use avx2 to compute the squared distance
Signed-off-by: Keming <[email protected]>
1 parent 74531d2 commit 5f82fcc

File tree

3 files changed

+101
-4
lines changed

3 files changed

+101
-4
lines changed

src/distance.rs

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//! Compute the distance between two vectors.
2+
3+
use nalgebra::{DVector, DVectorView};
4+
5+
/// Compute the squared Euclidean distance between two vectors.
6+
/// Code refer to https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h
7+
///
8+
/// # Safety
9+
///
10+
/// This function is marked unsafe because it requires the AVX intrinsics.
11+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
12+
#[target_feature(enable = "avx2")]
13+
pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVector<f32>) -> f32 {
14+
#[cfg(target_arch = "x86")]
15+
use std::arch::x86::*;
16+
#[cfg(target_arch = "x86_64")]
17+
use std::arch::x86_64::*;
18+
19+
assert_eq!(lhs.len(), rhs.len());
20+
let mut lhs_ptr = lhs.as_ptr();
21+
let mut rhs_ptr = rhs.as_ptr();
22+
let block_16_num = lhs.len() >> 4;
23+
let rest_num = lhs.len() & 0b1111;
24+
let mut temp_block = [0.0f32; 8];
25+
let temp_block_ptr = temp_block.as_mut_ptr();
26+
let (mut diff, mut vx, mut vy): (__m256, __m256, __m256);
27+
let mut sum = _mm256_setzero_ps();
28+
29+
for _ in 0..block_16_num {
30+
vx = _mm256_loadu_ps(lhs_ptr);
31+
vy = _mm256_loadu_ps(rhs_ptr);
32+
lhs_ptr = lhs_ptr.add(8);
33+
rhs_ptr = rhs_ptr.add(8);
34+
diff = _mm256_sub_ps(vx, vy);
35+
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
36+
37+
vx = _mm256_loadu_ps(lhs_ptr);
38+
vy = _mm256_loadu_ps(rhs_ptr);
39+
lhs_ptr = lhs_ptr.add(8);
40+
rhs_ptr = rhs_ptr.add(8);
41+
diff = _mm256_sub_ps(vx, vy);
42+
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
43+
}
44+
45+
for _ in 0..rest_num / 8 {
46+
vx = _mm256_loadu_ps(lhs_ptr);
47+
vy = _mm256_loadu_ps(rhs_ptr);
48+
lhs_ptr = lhs_ptr.add(8);
49+
rhs_ptr = rhs_ptr.add(8);
50+
diff = _mm256_sub_ps(vx, vy);
51+
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
52+
}
53+
_mm256_store_ps(temp_block_ptr, sum);
54+
55+
let mut res = temp_block[0]
56+
+ temp_block[1]
57+
+ temp_block[2]
58+
+ temp_block[3]
59+
+ temp_block[4]
60+
+ temp_block[5]
61+
+ temp_block[6]
62+
+ temp_block[7];
63+
64+
for _ in 0..rest_num % 8 {
65+
let residual = *lhs_ptr - *rhs_ptr;
66+
res += residual * residual;
67+
lhs_ptr = lhs_ptr.add(1);
68+
rhs_ptr = rhs_ptr.add(1);
69+
}
70+
res
71+
}

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! RaBitQ implementation in Rust.
22
33
#![forbid(missing_docs)]
4+
pub mod distance;
45
pub mod metrics;
56
pub mod rabitq;
67
pub mod utils;

src/rabitq.rs

+29-4
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,19 @@ impl RaBitQ {
210210
let mut lists = Vec::with_capacity(k);
211211
let mut residual = DVector::<f32>::zeros(self.dim as usize);
212212
for (i, centroid) in self.centroids.column_iter().enumerate() {
213-
y_projected.sub_to(&centroid, &mut residual);
214-
let dist = residual.norm_squared();
213+
let dist = {
214+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
215+
{
216+
if is_x86_feature_detected!("avx2") {
217+
unsafe {
218+
crate::distance::l2_squared_distance_avx2(&centroid, &y_projected)
219+
}
220+
} else {
221+
y_projected.sub_to(&centroid, &mut residual);
222+
residual.norm_squared()
223+
}
224+
}
225+
};
215226
lists.push((dist, i));
216227
}
217228
let length = probe.min(k);
@@ -265,8 +276,22 @@ impl RaBitQ {
265276
let mut residual = DVector::<f32>::zeros(self.dim as usize);
266277
for &(rough, u) in rough_distances.iter() {
267278
if rough < threshold {
268-
self.base.column(u as usize).sub_to(query, &mut residual);
269-
let accurate = residual.norm_squared();
279+
let accurate = {
280+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
281+
{
282+
if is_x86_feature_detected!("avx2") {
283+
unsafe {
284+
crate::distance::l2_squared_distance_avx2(
285+
&self.base.column(u as usize),
286+
query,
287+
)
288+
}
289+
} else {
290+
self.base.column(u as usize).sub_to(query, &mut residual);
291+
residual.norm_squared()
292+
}
293+
}
294+
};
270295
if accurate < threshold {
271296
res.push((accurate, u as i32));
272297
count += 1;

0 commit comments

Comments
 (0)