From c6001e705aad9b146e19104cfefe9307c380b83c Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 17:03:50 +0800 Subject: [PATCH] test: add dot_internal_v4_avx512vnni_test Signed-off-by: usamoi --- crates/base/src/operator/veci8_cos.rs | 2 +- crates/base/src/operator/veci8_dot.rs | 2 +- crates/base/src/operator/veci8_l2.rs | 2 +- crates/base/src/vector/veci8.rs | 45 +++++++++++++++++++-------- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/crates/base/src/operator/veci8_cos.rs b/crates/base/src/operator/veci8_cos.rs index a882749a3..2cccf7107 100644 --- a/crates/base/src/operator/veci8_cos.rs +++ b/crates/base/src/operator/veci8_cos.rs @@ -12,6 +12,6 @@ impl Operator for Veci8Cos { const DISTANCE_KIND: DistanceKind = DistanceKind::Cos; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - F32(1.0) - veci8::cosine_distance(&lhs, &rhs) + F32(1.0) - veci8::cosine(&lhs, &rhs) } } diff --git a/crates/base/src/operator/veci8_dot.rs b/crates/base/src/operator/veci8_dot.rs index b066d7749..f6c5dd1df 100644 --- a/crates/base/src/operator/veci8_dot.rs +++ b/crates/base/src/operator/veci8_dot.rs @@ -12,6 +12,6 @@ impl Operator for Veci8Dot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - veci8::dot_distance(&lhs, &rhs) * (-1.0) + veci8::dot(&lhs, &rhs) * (-1.0) } } diff --git a/crates/base/src/operator/veci8_l2.rs b/crates/base/src/operator/veci8_l2.rs index bde92d8ee..f856b0be2 100644 --- a/crates/base/src/operator/veci8_l2.rs +++ b/crates/base/src/operator/veci8_l2.rs @@ -12,6 +12,6 @@ impl Operator for Veci8L2 { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - veci8::l2_distance(&lhs, &rhs) + veci8::sl2(&lhs, &rhs) } } diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 871c76af4..84e580c9f 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -325,7 +325,7 @@ pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) { #[cfg(any(target_arch = "x86_64", doc))] #[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4_avx512vnni")] -unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { +unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; assert_eq!(x.len(), y.len()); let mut sum = 0; @@ -363,8 +363,27 @@ unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { F32(sum as f32) } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_internal_v4_avx512vnni_test() { + const EPSILON: F32 = F32(4.0); + detect::init(); + if !detect::v4_avx512vnni::detect() { + println!("test {} ... skipped (v4_avx512vnni)", module_path!()); + return; + } + let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) }; + let fallback = unsafe { dot_internal_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vnni = import, v4, v3, v2, neon, fallback = export)] -pub fn dot(x: &[I8], y: &[I8]) -> F32 { +fn dot_internal(x: &[I8], y: &[I8]) -> F32 { // i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow. let mut sum = 0; assert_eq!(x.len(), y.len()); @@ -376,26 +395,26 @@ pub fn dot(x: &[I8], y: &[I8]) -> F32 { F32(sum as f32) } -pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn dot(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y) // = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y // Sum(dot(origin_x[i] , origin_y[i])) = alpha_x * alpha_y * Sum(dot(x[i], y[i])) + offset_y * Sum(alpha_x * x[i]) + offset_x * Sum(alpha_y * y[i]) + offset_x * offset_y * dims - let dot_xy = dot(x.data(), y.data()); + let dot_xy = dot_internal(x.data(), y.data()); x.alpha() * y.alpha() * dot_xy + x.offset() * y.sum() + y.offset() * x.sum() + x.offset() * y.offset() * F32(x.dims() as f32) } -pub fn l2_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn sl2(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // Sum(l2(origin_x[i] - origin_y[i])) = sum(x[i] ^ 2 - 2 * x[i] * y[i] + y[i] ^ 2) // = dot(x, x) - 2 * dot(x, y) + dot(y, y) - x.l2_norm() * x.l2_norm() - F32(2.0) * dot_distance(x, y) + y.l2_norm() * y.l2_norm() + x.l2_norm() * x.l2_norm() - F32(2.0) * dot(x, y) + y.l2_norm() * y.l2_norm() } -pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn cosine(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // dot(x, y) / (l2(x) * l2(y)) - let dot_xy = dot_distance(x, y); + let dot_xy = dot(x, y); let l2_x = x.l2_norm(); let l2_y = y.l2_norm(); dot_xy / (l2_x * l2_y) @@ -462,7 +481,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = dot_distance(&ref_x, &ref_y); + let result = dot(&ref_x, &ref_y); assert!((result.0 - 10.0).abs() < 0.1); } @@ -474,7 +493,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = cosine_distance(&ref_x, &ref_y); + let result = cosine(&ref_x, &ref_y); assert!((result.0 - (10.0 / 14.0)).abs() < 0.1); // test cos_i8 using random generated data, check the precision let x = new_random_vec_f32(1000); @@ -487,7 +506,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = cosine_distance(&ref_x, &ref_y); + let result = cosine(&ref_x, &ref_y); assert!( result_expected < 0.01 || (result.0 - result_expected).abs() < 0.01 @@ -503,7 +522,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = l2_distance(&ref_x, &ref_y); + let result = sl2(&ref_x, &ref_y); assert!((result.0 - 8.0).abs() < 0.1); // test l2_i8 using random generated data, check the precision let x = new_random_vec_f32(1000); @@ -518,7 +537,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = l2_distance(&ref_x, &ref_y); + let result = sl2(&ref_x, &ref_y); assert!( result_expected < 1.0 || (result.0 - result_expected).abs() / result_expected < 0.05 );