Skip to content

Commit 844d146

Browse files
committed
update release profile, cfg not x86 to fallback
Signed-off-by: Keming <[email protected]>
1 parent 5f82fcc commit 844d146

File tree

4 files changed

+42
-13
lines changed

4 files changed

+42
-13
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ rand = "0.8.5"
1919
serde = "1.0.207"
2020
serde_json = "1.0.124"
2121

22+
[profile.release]
23+
codegen-units = 1
24+
lto = "fat"
25+
panic = "abort"
26+
2227
[profile.perf]
2328
inherits = "release"
2429
debug = true

rustfmt.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ reorder_imports = true
33
format_strings = true
44
imports_granularity = "Module"
55
group_imports = "StdExternalCrate"
6-
reorder_impl_items = true
6+
reorder_impl_items = true

src/distance.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Compute the distance between two vectors.
22
3-
use nalgebra::{DVector, DVectorView};
3+
use nalgebra::DVectorView;
44

55
/// Compute the squared Euclidean distance between two vectors.
66
/// Code refer to https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h
@@ -10,7 +10,7 @@ use nalgebra::{DVector, DVectorView};
1010
/// This function is marked unsafe because it requires the AVX intrinsics.
1111
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
1212
#[target_feature(enable = "avx2")]
13-
pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVector<f32>) -> f32 {
13+
pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
1414
#[cfg(target_arch = "x86")]
1515
use std::arch::x86::*;
1616
#[cfg(target_arch = "x86_64")]

src/rabitq.rs

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,9 @@ fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
2727
}
2828

2929
/// Convert the vector to +1/-1 format.
30+
#[inline]
3031
fn vector_binarize_one(vec: &DVector<f32>) -> DVector<f32> {
31-
let mut binary = DVector::zeros(vec.len());
32-
for (i, &v) in vec.iter().enumerate() {
33-
binary[i] = if v > 0.0 { 1.0 } else { -1.0 };
34-
}
35-
binary
32+
DVector::from_fn(vec.len(), |i, _| if vec[i] > 0.0 { 1.0 } else { -1.0 })
3633
}
3734

3835
/// Convert the vector to binary format (one value to multiple bits) and store in a u64 vector.
@@ -75,10 +72,24 @@ fn asymmetric_binary_dot_product(x: &[u64], y: &[u64]) -> u32 {
7572
fn kmeans_nearest_cluster(centroids: &DMatrix<f32>, vec: &DVectorView<f32>) -> usize {
7673
let mut min_dist = f32::MAX;
7774
let mut min_label = 0;
78-
let mut temp = DVector::<f32>::zeros(vec.len());
75+
let mut residual = DVector::<f32>::zeros(vec.len());
7976
for (j, centroid) in centroids.column_iter().enumerate() {
80-
vec.sub_to(&centroid, &mut temp);
81-
let dist = temp.norm_squared();
77+
let dist = {
78+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
79+
{
80+
if is_x86_feature_detected!("avx2") {
81+
unsafe { crate::distance::l2_squared_distance_avx2(&centroid, vec) }
82+
} else {
83+
vec.sub_to(&centroid, &mut residual);
84+
residual.norm_squared()
85+
}
86+
}
87+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
88+
{
89+
self.base.column(u as usize).sub_to(query, &mut residual);
90+
residual.norm_squared()
91+
}
92+
};
8293
if dist < min_dist {
8394
min_dist = dist;
8495
min_label = j;
@@ -215,13 +226,21 @@ impl RaBitQ {
215226
{
216227
if is_x86_feature_detected!("avx2") {
217228
unsafe {
218-
crate::distance::l2_squared_distance_avx2(&centroid, &y_projected)
229+
crate::distance::l2_squared_distance_avx2(
230+
&centroid,
231+
&y_projected.as_view(),
232+
)
219233
}
220234
} else {
221235
y_projected.sub_to(&centroid, &mut residual);
222236
residual.norm_squared()
223237
}
224238
}
239+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
240+
{
241+
self.base.column(u as usize).sub_to(query, &mut residual);
242+
residual.norm_squared()
243+
}
225244
};
226245
lists.push((dist, i));
227246
}
@@ -283,14 +302,19 @@ impl RaBitQ {
283302
unsafe {
284303
crate::distance::l2_squared_distance_avx2(
285304
&self.base.column(u as usize),
286-
query,
305+
&query.as_view(),
287306
)
288307
}
289308
} else {
290309
self.base.column(u as usize).sub_to(query, &mut residual);
291310
residual.norm_squared()
292311
}
293312
}
313+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
314+
{
315+
self.base.column(u as usize).sub_to(query, &mut residual);
316+
residual.norm_squared()
317+
}
294318
};
295319
if accurate < threshold {
296320
res.push((accurate, u as i32));

0 commit comments

Comments
 (0)