@@ -11,8 +11,8 @@ use crate::consts::THETA_LOG_DIM;
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
+#[target_feature(enable = "fma,avx")]
+pub unsafe fn l2_squared_distance(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
     #[cfg(target_arch = "x86")]
     use std::arch::x86::*;
     #[cfg(target_arch = "x86_64")]
@@ -34,14 +34,14 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);

         vx = _mm256_loadu_ps(lhs_ptr);
         vy = _mm256_loadu_ps(rhs_ptr);
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);
     }

     for _ in 0..rest_num / 8 {
@@ -50,7 +50,7 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
         diff = _mm256_sub_ps(vx, vy);
-        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        sum = _mm256_fmadd_ps(diff, diff, sum);
     }
     _mm256_store_ps(temp_block_ptr, sum);

@@ -63,7 +63,7 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
         + temp_block[6]
         + temp_block[7];

-    for _ in 0..rest_num % 8 {
+    for _ in 0..rest_num {
         let residual = *lhs_ptr - *rhs_ptr;
         res += residual * residual;
         lhs_ptr = lhs_ptr.add(1);
@@ -78,8 +78,8 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u64]) {
+#[target_feature(enable = "avx,avx2")]
+pub unsafe fn vector_binarize_query(vec: &DVectorView<u8>, binary: &mut [u64]) {
     use std::arch::x86_64::*;

     let length = vec.len();
@@ -107,7 +107,7 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u6
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "avx")]
-pub unsafe fn min_max_residual_avx(
+pub unsafe fn min_max_residual(
     res: &mut DVector<f32>,
     x: &DVectorView<f32>,
     y: &DVectorView<f32>,
@@ -174,8 +174,8 @@ pub unsafe fn min_max_residual_avx(
 ///
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
-#[target_feature(enable = "avx2")]
-pub unsafe fn scalar_quantize_avx2(
+#[target_feature(enable = "avx,avx2")]
+pub unsafe fn scalar_quantize(
     quantized: &mut DVector<u8>,
     vec: &DVectorView<f32>,
     lower_bound: f32,
@@ -233,3 +233,54 @@ pub unsafe fn scalar_quantize_avx2(

     sum
 }
+
+/// Compute the dot product of two vectors.
+///
+/// # Safety
+///
+/// This function is marked unsafe because it requires the AVX intrinsics.
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[target_feature(enable = "fma,avx,avx2")]
+pub unsafe fn vector_dot_product(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
+    use std::arch::x86_64::*;
+
+    let mut lhs_ptr = lhs.as_ptr();
+    let mut rhs_ptr = rhs.as_ptr();
+    let length = lhs.len();
+    let rest = length & 0b111;
+    let (mut vx, mut vy): (__m256, __m256);
+    let mut accumulate = _mm256_setzero_ps();
+    let mut f32x8 = [0.0f32; 8];
+
+    for _ in 0..(length / 16) {
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+    }
+    for _ in 0..((length & 0b1111) / 8) {
+        vx = _mm256_loadu_ps(lhs_ptr);
+        vy = _mm256_loadu_ps(rhs_ptr);
+        accumulate = _mm256_fmadd_ps(vx, vy, accumulate);
+        lhs_ptr = lhs_ptr.add(8);
+        rhs_ptr = rhs_ptr.add(8);
+    }
+    _mm256_storeu_ps(f32x8.as_mut_ptr(), accumulate);
+    let mut sum =
+        f32x8[0] + f32x8[1] + f32x8[2] + f32x8[3] + f32x8[4] + f32x8[5] + f32x8[6] + f32x8[7];
+
+    for _ in 0..rest {
+        sum += *lhs_ptr * *rhs_ptr;
+        lhs_ptr = lhs_ptr.add(1);
+        rhs_ptr = rhs_ptr.add(1);
+    }
+
+    sum
+}
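
All of these helpers are `unsafe` `#[target_feature]` functions, so a caller has to confirm the CPU actually supports the requested features before invoking them. Below is a minimal dispatch sketch, not part of this commit: the wrapper name `l2_squared_distance_checked` and the scalar fallback are illustrative, while `is_x86_feature_detected!` comes from std and the `DVectorView` alias from nalgebra, as used throughout this file.

use nalgebra::DVectorView;

/// Hypothetical safe wrapper (illustrative only): take the AVX/FMA path when the
/// CPU supports it, otherwise fall back to a plain scalar loop.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
pub fn l2_squared_distance_checked(lhs: &DVectorView<f32>, rhs: &DVectorView<f32>) -> f32 {
    if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
        // SAFETY: the target features required by `l2_squared_distance` were
        // verified at runtime just above.
        unsafe { l2_squared_distance(lhs, rhs) }
    } else {
        // Scalar fallback with the same semantics as the SIMD path.
        lhs.iter()
            .zip(rhs.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum()
    }
}

This keeps the unsafety confined to a single audited call site; the same pattern applies to vector_binarize_query, min_max_residual, scalar_quantize, and vector_dot_product.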