|
| 1 | +//! Compute the distance between two vectors. |
| 2 | +
|
| 3 | +use nalgebra::{DVector, DVectorView}; |
| 4 | + |
| 5 | +/// Compute the squared Euclidean distance between two vectors. |
| 6 | +/// Code refer to https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h |
| 7 | +/// |
| 8 | +/// # Safety |
| 9 | +/// |
| 10 | +/// This function is marked unsafe because it requires the AVX intrinsics. |
| 11 | +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] |
| 12 | +#[target_feature(enable = "avx2")] |
| 13 | +pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVector<f32>) -> f32 { |
| 14 | + #[cfg(target_arch = "x86")] |
| 15 | + use std::arch::x86::*; |
| 16 | + #[cfg(target_arch = "x86_64")] |
| 17 | + use std::arch::x86_64::*; |
| 18 | + |
| 19 | + assert_eq!(lhs.len(), rhs.len()); |
| 20 | + let mut lhs_ptr = lhs.as_ptr(); |
| 21 | + let mut rhs_ptr = rhs.as_ptr(); |
| 22 | + let block_16_num = lhs.len() >> 4; |
| 23 | + let rest_num = lhs.len() & 0b1111; |
| 24 | + let mut temp_block = [0.0f32; 8]; |
| 25 | + let temp_block_ptr = temp_block.as_mut_ptr(); |
| 26 | + let (mut diff, mut vx, mut vy): (__m256, __m256, __m256); |
| 27 | + let mut sum = _mm256_setzero_ps(); |
| 28 | + |
| 29 | + for _ in 0..block_16_num { |
| 30 | + vx = _mm256_loadu_ps(lhs_ptr); |
| 31 | + vy = _mm256_loadu_ps(rhs_ptr); |
| 32 | + lhs_ptr = lhs_ptr.add(8); |
| 33 | + rhs_ptr = rhs_ptr.add(8); |
| 34 | + diff = _mm256_sub_ps(vx, vy); |
| 35 | + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); |
| 36 | + |
| 37 | + vx = _mm256_loadu_ps(lhs_ptr); |
| 38 | + vy = _mm256_loadu_ps(rhs_ptr); |
| 39 | + lhs_ptr = lhs_ptr.add(8); |
| 40 | + rhs_ptr = rhs_ptr.add(8); |
| 41 | + diff = _mm256_sub_ps(vx, vy); |
| 42 | + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); |
| 43 | + } |
| 44 | + |
| 45 | + for _ in 0..rest_num / 8 { |
| 46 | + vx = _mm256_loadu_ps(lhs_ptr); |
| 47 | + vy = _mm256_loadu_ps(rhs_ptr); |
| 48 | + lhs_ptr = lhs_ptr.add(8); |
| 49 | + rhs_ptr = rhs_ptr.add(8); |
| 50 | + diff = _mm256_sub_ps(vx, vy); |
| 51 | + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); |
| 52 | + } |
| 53 | + _mm256_store_ps(temp_block_ptr, sum); |
| 54 | + |
| 55 | + let mut res = temp_block[0] |
| 56 | + + temp_block[1] |
| 57 | + + temp_block[2] |
| 58 | + + temp_block[3] |
| 59 | + + temp_block[4] |
| 60 | + + temp_block[5] |
| 61 | + + temp_block[6] |
| 62 | + + temp_block[7]; |
| 63 | + |
| 64 | + for _ in 0..rest_num % 8 { |
| 65 | + let residual = *lhs_ptr - *rhs_ptr; |
| 66 | + res += residual * residual; |
| 67 | + lhs_ptr = lhs_ptr.add(1); |
| 68 | + rhs_ptr = rhs_ptr.add(1); |
| 69 | + } |
| 70 | + res |
| 71 | +} |
0 commit comments