Commit a81f937

Browse files
committed
try simd dot product v.s. blas
Signed-off-by: Keming <[email protected]>
1 parent e5a4af0 commit a81f937

File tree: 2 files changed, +86 and -15 lines

src/rabitq.rs

Lines changed: 24 additions & 4 deletions
@@ -34,7 +34,7 @@ fn vector_binarize_query(vec: &DVectorView<u8>, binary: &mut [u64]) {
     {
         if is_x86_feature_detected!("avx2") {
             unsafe {
-                crate::simd::vector_binarize_query_avx2(&vec.as_view(), binary);
+                crate::simd::vector_binarize_query(&vec.as_view(), binary);
             }
         } else {
             vector_binarize_query_raw(vec, binary);
@@ -90,7 +90,7 @@ fn l2_squared_distance(
     #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
     {
         if is_x86_feature_detected!("avx2") {
-            unsafe { crate::simd::l2_squared_distance_avx2(lhs, rhs) }
+            unsafe { crate::simd::l2_squared_distance(lhs, rhs) }
         } else {
             lhs.sub_to(rhs, residual);
             residual.norm_squared()
@@ -127,7 +127,7 @@ fn min_max_residual(
     #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
     {
         if is_x86_feature_detected!("avx") {
-            unsafe { crate::simd::min_max_residual_avx(res, x, y) }
+            unsafe { crate::simd::min_max_residual(res, x, y) }
         } else {
             x.sub_to(y, res);
             min_max_raw(&res.as_view())
@@ -168,7 +168,7 @@ fn scalar_quantize(
     #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
     {
         if is_x86_feature_detected!("avx2") {
-            unsafe { crate::simd::scalar_quantize_avx2(quantized, vec, lower_bound, multiplier) }
+            unsafe { crate::simd::scalar_quantize(quantized, vec, lower_bound, multiplier) }
         } else {
             scalar_quantize_raw(quantized, vec, bias, lower_bound, multiplier)
         }
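
All four call sites above keep the same runtime-dispatch shape: probe the CPU with `is_x86_feature_detected!`, call the `#[target_feature]`-gated kernel inside `unsafe`, and otherwise fall back to the portable nalgebra path. A minimal, self-contained sketch of that pattern (with a hypothetical `sum` kernel that is not part of this crate) looks like this:

    /// Portable fallback; always correct, lets the compiler decide how to vectorize.
    fn sum_portable(values: &[f32]) -> f32 {
        values.iter().sum()
    }

    /// AVX2-compiled variant. Only sound to call after the runtime feature check.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    #[target_feature(enable = "avx2")]
    unsafe fn sum_avx2(values: &[f32]) -> f32 {
        // A real kernel would use core::arch intrinsics; plain iteration keeps
        // the sketch focused on the dispatch shape rather than the math.
        values.iter().sum()
    }

    /// Entry point: runtime detection picks the kernel.
    fn sum(values: &[f32]) -> f32 {
        #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
        {
            if is_x86_feature_detected!("avx2") {
                // The feature was verified at runtime, so the unsafe call is sound.
                return unsafe { sum_avx2(values) };
            }
        }
        sum_portable(values)
    }
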
@@ -179,6 +179,26 @@ fn scalar_quantize(
     }
 }
 
+/// Project the vector to the orthogonal matrix.
+#[allow(dead_code)]
+#[inline]
+fn project(vec: &DVectorView<f32>, orthogonal: &DMatrixView<f32>) -> DVector<f32> {
+    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+    {
+        if is_x86_feature_detected!("avx2") {
+            DVector::from_fn(vec.len(), |i, _| unsafe {
+                crate::simd::vector_dot_product(vec, &orthogonal.column(i).as_view())
+            })
+        } else {
+            vec.tr_mul(orthogonal).transpose()
+        }
+    }
+    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
+    {
+        vec.tr_mul(orthogonal).transpose()
+    }
+}
+
 /// Find the nearest cluster for the given vector.
 fn kmeans_nearest_cluster(centroids: &DMatrixView<f32>, vec: &DVectorView<f32>) -> usize {
     let mut min_dist = f32::MAX;
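
The new `project` helper is where the commit title's comparison lives: on AVX2 hardware it assembles vec^T * orthogonal one coordinate at a time via `crate::simd::vector_dot_product` against each matrix column, while the fallback (and non-x86) branch stays on nalgebra's `tr_mul`, i.e. the library's BLAS-style matrix path. Since FMA rounds differently from a separate multiply and add, a loose-tolerance agreement check between the two branches is worth having; a hypothetical test (deterministic inputs, not part of this diff) could read:

    #[cfg(all(test, any(target_arch = "x86_64", target_arch = "x86")))]
    mod project_tests {
        use nalgebra::{DMatrix, DVector};

        #[test]
        fn simd_projection_matches_tr_mul() {
            if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
                return; // skip on CPUs without the required features
            }
            let dim = 141; // not a multiple of 8 or 16, so the remainder paths run too
            let vec = DVector::from_fn(dim, |i, _| (i as f32 * 0.37).sin());
            let orthogonal = DMatrix::from_fn(dim, dim, |i, j| ((i * 31 + j) as f32 * 0.11).cos());

            // BLAS-style reference: vec^T * orthogonal, transposed back to a column.
            let reference = vec.tr_mul(&orthogonal).transpose();

            // SIMD branch: one AVX2+FMA dot product per column, mirroring `project`.
            let simd = DVector::from_fn(dim, |i, _| unsafe {
                crate::simd::vector_dot_product(&vec.as_view(), &orthogonal.column(i).as_view())
            });

            // Loose tolerance: FMA changes the rounding of each partial sum.
            assert!((&simd - &reference).norm() <= 1e-3 * reference.norm().max(1.0));
        }
    }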

src/simd.rs

Lines changed: 62 additions & 11 deletions
@@ -11,8 +11,8 @@ use crate::consts::THETA_LOG_DIM;
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
+#[target_feature(enable = "fma,avx")]
+pub unsafe fn l2_squared_distance(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
     #[cfg(target_arch = "x86")]
     use std::arch::x86::*;
     #[cfg(target_arch = "x86_64")]
@@ -34,14 +34,14 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);
 
         vx = _mm256_loadu_ps(lhs_ptr);
         vy = _mm256_loadu_ps(rhs_ptr);
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);
     }
 
     for _ in 0..rest_num / 8 {
@@ -50,7 +50,7 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);
     }
     _mm256_store_ps(temp_block_ptr, sum);
 
@@ -63,7 +63,7 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         + temp_block[6]
         + temp_block[7];
 
-    for _ in 0..rest_num % 8 {
+    for _ in 0..rest_num {
         let residual = *lhs_ptr - *rhs_ptr;
         res += residual * residual;
         lhs_ptr = lhs_ptr.add(1);
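
The arithmetic change in these hunks is a pure fusion: `_mm256_fmadd_ps(diff, diff, sum)` computes `diff * diff + sum` lane-wise in a single instruction with one rounding step, replacing the previous `_mm256_mul_ps` plus `_mm256_add_ps` pair, which is why the attribute now has to enable `fma`. A hedged scalar cross-check against nalgebra's own squared norm, written as a hypothetical test placed inside src/simd.rs (so `super::l2_squared_distance` resolves), might look like:

    #[cfg(all(test, any(target_arch = "x86_64", target_arch = "x86")))]
    mod l2_tests {
        use nalgebra::DVector;

        #[test]
        fn fma_l2_matches_nalgebra() {
            if !(is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma")) {
                return; // skip on CPUs without the required features
            }
            let dim = 141; // not a multiple of 8 or 16, so the remainder handling runs too
            let lhs = DVector::from_fn(dim, |i, _| (i as f32 * 0.13).sin());
            let rhs = DVector::from_fn(dim, |i, _| (i as f32 * 0.29).cos());

            let reference = (&lhs - &rhs).norm_squared();
            let simd = unsafe { super::l2_squared_distance(&lhs.as_view(), &rhs.as_view()) };

            // Loose tolerance: fused multiply-add rounds differently.
            assert!((simd - reference).abs() <= 1e-3 * reference.max(1.0));
        }
    }
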
@@ -78,8 +78,8 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u64]) {
+#[target_feature(enable = "avx,avx2")]
+pub unsafe fn vector_binarize_query(vec: &DVectorView<u8>, binary: &mut [u64]) {
     use std::arch::x86_64::*;
 
     let length = vec.len();
@@ -107,7 +107,7 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u6
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "avx")]
-pub unsafe fn min_max_residual_avx(
+pub unsafe fn min_max_residual(
     res: &mut DVector<f32>,
     x: &DVectorView<f32>,
     y: &DVectorView<f32>,
@@ -174,8 +174,8 @@ pub unsafe fn min_max_residual_avx(
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn scalar_quantize_avx2(
+#[target_feature(enable = "avx,avx2")]
+pub unsafe fn scalar_quantize(
     quantized: &mut DVector<u8>,
     vec: &DVectorView<f32>,
     lower_bound: f32,
@@ -233,3 +233,54 @@ pub unsafe fn scalar_quantize_avx2(
 
     sum
 }
+
+/// Compute the dot product of two vectors.
+///
+/// # Safety
+///
+/// This function is marked unsafe because it requires the AVX intrinsics.
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[target_feature(enable = "fma,avx,avx2")]
+pub unsafe fn vector_dot_product(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
+    use std::arch::x86_64::*;
+
+    let mut lhs_ptr = lhs.as_ptr();
+    let mut rhs_ptr = rhs.as_ptr();
+    let length = lhs.len();
+    let rest = length & 0b111;
+    let (mut vx, mut vy): (__m256, __m256);
+    let mut accumulate = _mm256_setzero_ps();
+    let mut f32x8 = [0.0f32; 8];
+
+    for _ in 0..(length / 16) {
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+    }
+    for _ in 0..((length & 0b1111) / 8) {
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+    }
+    _mm256_storeu_ps(f32x8.as_mut_ptr(), accumulate);
+    let mut sum =
+        f32x8[0] + f32x8[1] + f32x8[2] + f32x8[3] + f32x8[4] + f32x8[5] + f32x8[6] + f32x8[7];
+
+    for _ in 0..rest {
+        sum += *lhs_ptr * *rhs_ptr;
+        lhs_ptr = lhs_ptr.add(1);
+        rhs_ptr = rhs_ptr.add(1);
+    }
+
+    sum
+}
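
Finally, to actually "try simd dot product v.s. blas" as the commit title puts it, a rough wall-clock sketch is enough to see which way a given dimension falls: time the nalgebra `tr_mul` projection (the BLAS-style baseline the fallback branch uses) against `dim` explicit `vector_dot_product` calls. The helper below is hypothetical and std-only (no benchmark harness, deterministic inputs); a criterion benchmark would give steadier numbers.

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    fn bench_projection(dim: usize, iters: u32) {
        use nalgebra::{DMatrix, DVector};
        use std::time::Instant;

        if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
            eprintln!("AVX2+FMA not available; skipping the comparison");
            return;
        }

        let vec = DVector::from_fn(dim, |i, _| (i as f32 * 0.7).sin());
        let orthogonal = DMatrix::from_fn(dim, dim, |i, j| ((i + 3 * j) as f32 * 0.11).cos());

        // BLAS-style path: one tr_mul (gemv-like) projection per iteration.
        let start = Instant::now();
        for _ in 0..iters {
            std::hint::black_box(vec.tr_mul(&orthogonal).transpose());
        }
        let blas_time = start.elapsed();

        // SIMD path: `dim` AVX2+FMA dot products per iteration, as in `project`.
        let start = Instant::now();
        for _ in 0..iters {
            let projected = DVector::from_fn(dim, |i, _| unsafe {
                crate::simd::vector_dot_product(&vec.as_view(), &orthogonal.column(i).as_view())
            });
            std::hint::black_box(projected);
        }
        let simd_time = start.elapsed();

        println!("dim={dim}: tr_mul {blas_time:?} vs simd dot {simd_time:?}");
    }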
