Skip to content

Commit f114fc1

Browse files
committed
use qr to get random orthogonal, add simd query binarize func
Signed-off-by: Keming <[email protected]>
1 parent 844d146 commit f114fc1

File tree

6 files changed

+80
-20
lines changed

6 files changed

+80
-20
lines changed

src/consts.rs

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! Constants used in the program.
2+
3+
pub(crate) const DEFAULT_X_DOT_PRODUCT: f32 = 0.8;
4+
pub(crate) const EPSILON: f32 = 1.9;
5+
pub(crate) const THETA_LOG_DIM: u32 = 4;
6+
pub(crate) const WINDOWS_SIZE: usize = 12;

src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! RaBitQ implementation in Rust.
22
33
#![forbid(missing_docs)]
4-
pub mod distance;
4+
mod consts;
55
pub mod metrics;
66
pub mod rabitq;
7+
pub mod simd;
78
pub mod utils;
89

910
pub use rabitq::RaBitQ;

src/metrics.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ impl Metrics {
2222

2323
/// Format the rough and precise counters, together with their ratio, as a string.
2424
pub fn to_str(&self) -> String {
25+
let rough = self.rough.load(Ordering::Relaxed);
26+
let precise = self.precise.load(Ordering::Relaxed);
2527
format!(
26-
"rough: {}, precise: {}",
27-
self.rough.load(Ordering::Relaxed),
28-
self.precise.load(Ordering::Relaxed)
28+
"rough: {}, precise: {}, ratio: {:.2}",
29+
rough,
30+
precise,
31+
rough as f64 / precise as f64,
2932
)
3033
}
3134

src/rabitq.rs

+25-15
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@ use nalgebra::{DMatrix, DVector, DVectorView};
77
use num_traits::ToPrimitive;
88
use serde::{Deserialize, Serialize};
99

10+
use crate::consts::{DEFAULT_X_DOT_PRODUCT, EPSILON, THETA_LOG_DIM, WINDOWS_SIZE};
1011
use crate::metrics::METRICS;
11-
use crate::utils::{gen_random_bias, gen_random_orthogonal, matrix_from_fvecs};
12-
13-
const DEFAULT_X_DOT_PRODUCT: f32 = 0.8;
14-
const EPSILON: f32 = 1.9;
15-
const THETA_LOG_DIM: u32 = 4;
16-
const WINDOWS_SIZE: usize = 16;
12+
use crate::utils::{gen_random_bias, gen_random_qr_orthogonal, matrix_from_fvecs};
1713

1814
/// Convert the vector to binary format and store in a u64 vector.
1915
fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
@@ -32,8 +28,25 @@ fn vector_binarize_one(vec: &DVector<f32>) -> DVector<f32> {
3228
DVector::from_fn(vec.len(), |i, _| if vec[i] > 0.0 { 1.0 } else { -1.0 })
3329
}
3430

31+
/// Interface of `vector_binarize_query`
32+
fn vector_binarize_query(vec: &DVector<u8>) -> Vec<u64> {
33+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
34+
{
35+
if is_x86_feature_detected!("avx2") {
36+
unsafe { crate::simd::vector_binarize_query_avx2(&vec.as_view()) }
37+
} else {
38+
vector_binarize_query_raw(vec)
39+
}
40+
}
41+
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
42+
{
43+
vector_binarize_query_raw(vec)
44+
}
45+
}
46+
3547
/// Convert the vector to binary format (one value to multiple bits) and store in a u64 vector.
36-
fn query_vector_binarize(vec: &DVector<u8>) -> Vec<u64> {
48+
#[inline]
49+
fn vector_binarize_query_raw(vec: &DVector<u8>) -> Vec<u64> {
3750
let length = vec.len();
3851
let mut binary = vec![0u64; length * THETA_LOG_DIM as usize / 64];
3952
for j in 0..THETA_LOG_DIM as usize {
@@ -78,7 +91,7 @@ fn kmeans_nearest_cluster(centroids: &DMatrix<f32>, vec: &DVectorView<f32>) -> u
7891
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
7992
{
8093
if is_x86_feature_detected!("avx2") {
81-
unsafe { crate::distance::l2_squared_distance_avx2(&centroid, vec) }
94+
unsafe { crate::simd::l2_squared_distance_avx2(&centroid, vec) }
8295
} else {
8396
vec.sub_to(&centroid, &mut residual);
8497
residual.norm_squared()
@@ -123,7 +136,7 @@ impl RaBitQ {
123136
let centroids = matrix_from_fvecs(centroid_path);
124137
let k = centroids.shape().0;
125138
debug!("n: {}, dim: {}, k: {}", n, dim, k);
126-
let orthogonal = gen_random_orthogonal(dim);
139+
let orthogonal = gen_random_qr_orthogonal(dim);
127140
let rand_bias = gen_random_bias(dim);
128141

129142
// projection
@@ -226,10 +239,7 @@ impl RaBitQ {
226239
{
227240
if is_x86_feature_detected!("avx2") {
228241
unsafe {
229-
crate::distance::l2_squared_distance_avx2(
230-
&centroid,
231-
&y_projected.as_view(),
232-
)
242+
crate::simd::l2_squared_distance_avx2(&centroid, &y_projected.as_view())
233243
}
234244
} else {
235245
y_projected.sub_to(&centroid, &mut residual);
@@ -257,7 +267,7 @@ impl RaBitQ {
257267
let y_scaled = residual.add_scalar(-lower_bound) * one_over_delta + &self.rand_bias;
258268
let y_quantized = y_scaled.map(|v| v.to_u8().expect("convert to u8 error"));
259269
let scalar_sum = y_quantized.iter().fold(0u32, |acc, &v| acc + v as u32);
260-
let y_binary_vec = query_vector_binarize(&y_quantized);
270+
let y_binary_vec = vector_binarize_query(&y_quantized);
261271
let dist_sqrt = dist.sqrt();
262272
for j in self.offsets[i]..self.offsets[i + 1] {
263273
let ju = j as usize;
@@ -300,7 +310,7 @@ impl RaBitQ {
300310
{
301311
if is_x86_feature_detected!("avx2") {
302312
unsafe {
303-
crate::distance::l2_squared_distance_avx2(
313+
crate::simd::l2_squared_distance_avx2(
304314
&self.base.column(u as usize),
305315
&query.as_view(),
306316
)

src/distance.rs renamed to src/simd.rs

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
//! Compute the distance between two vectors.
1+
//! Accelerate with SIMD.
22
33
use nalgebra::DVectorView;
44

5+
use crate::consts::THETA_LOG_DIM;
6+
57
/// Compute the squared Euclidean distance between two vectors.
68
/// The code refers to https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h
79
///
@@ -69,3 +71,34 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
6971
}
7072
res
7173
}
74+
75+
/// Convert an [u8] to 4x binary vector stored as u64.
76+
///
77+
/// # Safety
78+
///
79+
/// This function is marked unsafe because it requires the AVX intrinsics.
80+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
81+
#[target_feature(enable = "avx2")]
82+
pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>) -> Vec<u64> {
83+
use std::arch::x86_64::*;
84+
85+
let length = vec.len();
86+
let mut ptr = vec.as_ptr() as *const __m256i;
87+
let mut binary = vec![0u64; length * THETA_LOG_DIM as usize / 64];
88+
89+
for i in (0..length).step_by(32) {
90+
// since it's not guaranteed that the vec is fully-aligned
91+
let mut v = _mm256_loadu_si256(ptr);
92+
ptr = ptr.add(1);
93+
v = _mm256_slli_epi32(v, 4);
94+
for j in 0..THETA_LOG_DIM as usize {
95+
let mask = (_mm256_movemask_epi8(v) as u32) as u64;
96+
// let shift = if (i / 32) % 2 == 0 { 32 } else { 0 };
97+
let shift = ((i >> 5) & 1) << 5;
98+
binary[(3 - j) * (length >> 6) + (i >> 6)] |= mask << shift;
99+
v = _mm256_slli_epi32(v, 1);
100+
}
101+
}
102+
103+
binary
104+
}

src/utils.rs

+7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ pub fn gen_random_orthogonal(dim: usize) -> DMatrix<f32> {
1818
random.unwrap()
1919
}
2020

21+
/// Generate a random orthogonal matrix from QR decomposition.
22+
pub fn gen_random_qr_orthogonal(dim: usize) -> DMatrix<f32> {
23+
let mut rng = thread_rng();
24+
let random = DMatrix::from_fn(dim, dim, |_, _| rng.gen::<f32>());
25+
random.qr().q()
26+
}
27+
2128
/// Generate an identity matrix as a special orthogonal matrix.
2229
///
2330
/// Use this function to debug the logic.

0 commit comments

Comments
 (0)