Skip to content

Commit c97be68

Browse files
committed
rewrite min/max with simd
Signed-off-by: Keming <[email protected]>
1 parent f54bedf commit c97be68

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

src/main.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ fn main() {
6969
let start_time = Instant::now();
7070
let res = rabitq.query_one(&query_vec, args.probe, args.topk);
7171
total_time += start_time.elapsed().as_secs_f64();
72-
recall += calculate_recall(&truth[i], &res, args.topk);
72+
let ids: Vec<i32> = res.iter().map(|(_, id)| *id as i32).collect();
73+
recall += calculate_recall(&truth[i], &ids, args.topk);
7374
}
7475

7576
info!(

src/rabitq.rs

+22-5
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ fn l2_squared_distance(
104104
}
105105

106106
// Get the min/max value of a vector.
107-
fn min_max(vec: &DVectorView<f32>) -> (f32, f32) {
107+
fn min_max_raw(vec: &DVectorView<f32>) -> (f32, f32) {
108108
let mut min = f32::MAX;
109109
let mut max = f32::MIN;
110110
for v in vec.iter() {
@@ -118,6 +118,22 @@ fn min_max(vec: &DVectorView<f32>) -> (f32, f32) {
118118
(min, max)
119119
}
120120

121+
// Interface of `min_max`
122+
fn min_max(vec: &DVectorView<f32>) -> (f32, f32) {
123+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
124+
{
125+
if is_x86_feature_detected!("avx") {
126+
unsafe { crate::simd::min_max_avx(vec) }
127+
} else {
128+
min_max_raw(vec)
129+
}
130+
}
131+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
132+
{
133+
min_max_raw(vec)
134+
}
135+
}
136+
121137
// Quantize the query residual vector.
122138
fn quantize_query_vector(
123139
quantized: &mut DVector<u8>,
@@ -279,7 +295,7 @@ impl RaBitQ {
279295
}
280296

281297
/// Query the topk nearest neighbors for the given query.
282-
pub fn query_one(&self, query: &DVector<f32>, probe: usize, topk: usize) -> Vec<i32> {
298+
pub fn query_one(&self, query: &DVector<f32>, probe: usize, topk: usize) -> Vec<(f32, u32)> {
283299
let y_projected = query.tr_mul(&self.orthogonal).transpose();
284300
let k = self.centroids.shape().1;
285301
let mut lists = Vec::with_capacity(k);
@@ -337,7 +353,7 @@ impl RaBitQ {
337353
query: &DVector<f32>,
338354
rough_distances: &[(f32, u32)],
339355
topk: usize,
340-
) -> Vec<i32> {
356+
) -> Vec<(f32, u32)> {
341357
let mut threshold = f32::MAX;
342358
let mut recent_max_accurate = f32::MIN;
343359
let mut res = Vec::with_capacity(topk);
@@ -351,7 +367,7 @@ impl RaBitQ {
351367
&mut residual,
352368
);
353369
if accurate < threshold {
354-
res.push((accurate, u as i32));
370+
res.push((accurate, u));
355371
count += 1;
356372
recent_max_accurate = recent_max_accurate.max(accurate);
357373
if count == WINDOWS_SIZE {
@@ -366,6 +382,7 @@ impl RaBitQ {
366382
METRICS.add_precise_count(res.len() as u64);
367383
let length = topk.min(res.len());
368384
res.select_nth_unstable_by(length - 1, |a, b| a.0.total_cmp(&b.0));
369-
res[..length].iter().map(|(_, u)| *u).collect()
385+
res.truncate(length);
386+
res
370387
}
371388
}

src/simd.rs

+51
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,54 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u6
9999
}
100100
}
101101
}
102+
103+
/// Compute the min and max value of a vector.
104+
///
105+
/// # Safety
106+
///
107+
/// This function is marked unsafe because it requires the AVX intrinsics.
108+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
109+
#[target_feature(enable = "avx")]
110+
pub unsafe fn min_max_avx(vec: &DVectorView<f32>) -> (f32, f32) {
111+
use std::arch::x86_64::*;
112+
113+
let mut min_32x8 = _mm256_set1_ps(f32::MAX);
114+
let mut max_32x8 = _mm256_set1_ps(f32::MIN);
115+
let mut ptr = vec.as_ptr();
116+
let mut f32x8 = [0.0f32; 8];
117+
let mut min = f32::MAX;
118+
let mut max = f32::MIN;
119+
let length = vec.len();
120+
let rest = length & 0b111;
121+
122+
for _ in 0..(length / 8) {
123+
let v = _mm256_loadu_ps(ptr);
124+
ptr = ptr.add(8);
125+
min_32x8 = _mm256_min_ps(min_32x8, v);
126+
max_32x8 = _mm256_max_ps(max_32x8, v);
127+
}
128+
_mm256_storeu_ps(f32x8.as_mut_ptr(), min_32x8);
129+
for &x in f32x8.iter() {
130+
if x < min {
131+
min = x;
132+
}
133+
}
134+
_mm256_storeu_ps(f32x8.as_mut_ptr(), max_32x8);
135+
for &x in f32x8.iter() {
136+
if x > max {
137+
max = x;
138+
}
139+
}
140+
141+
for _ in 0..rest {
142+
if *ptr < min {
143+
min = *ptr;
144+
}
145+
if *ptr > max {
146+
max = *ptr;
147+
}
148+
ptr = ptr.add(1);
149+
}
150+
151+
(min, max)
152+
}

0 commit comments

Comments
 (0)