Skip to content

Commit

Permalink
test: add dot_internal_v4_avx512vnni_test
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Mar 26, 2024
1 parent f096c96 commit c6001e7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 16 deletions.
2 changes: 1 addition & 1 deletion crates/base/src/operator/veci8_cos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ impl Operator for Veci8Cos {
const DISTANCE_KIND: DistanceKind = DistanceKind::Cos;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
F32(1.0) - veci8::cosine_distance(&lhs, &rhs)
F32(1.0) - veci8::cosine(&lhs, &rhs)
}
}
2 changes: 1 addition & 1 deletion crates/base/src/operator/veci8_dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ impl Operator for Veci8Dot {
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
veci8::dot_distance(&lhs, &rhs) * (-1.0)
veci8::dot(&lhs, &rhs) * (-1.0)
}
}
2 changes: 1 addition & 1 deletion crates/base/src/operator/veci8_l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ impl Operator for Veci8L2 {
const DISTANCE_KIND: DistanceKind = DistanceKind::Dot;

fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 {
veci8::l2_distance(&lhs, &rhs)
veci8::sl2(&lhs, &rhs)
}
}
45 changes: 32 additions & 13 deletions crates/base/src/vector/veci8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) {
#[cfg(any(target_arch = "x86_64", doc))]
#[doc(cfg(target_arch = "x86_64"))]
#[detect::target_cpu(enable = "v4_avx512vnni")]
unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
use std::arch::x86_64::*;
assert_eq!(x.len(), y.len());
let mut sum = 0;
Expand Down Expand Up @@ -363,8 +363,27 @@ unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
F32(sum as f32)
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_internal_v4_avx512vnni_test() {
const EPSILON: F32 = F32(4.0);
detect::init();
if !detect::v4_avx512vnni::detect() {
println!("test {} ... skipped (v4_avx512vnni)", module_path!());
return;
}
let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random()));
let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random()));
let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) };
let fallback = unsafe { dot_internal_fallback(&lhs, &rhs) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4_avx512vnni = import, v4, v3, v2, neon, fallback = export)]
pub fn dot(x: &[I8], y: &[I8]) -> F32 {
fn dot_internal(x: &[I8], y: &[I8]) -> F32 {
// i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow.
let mut sum = 0;
assert_eq!(x.len(), y.len());
Expand All @@ -376,26 +395,26 @@ pub fn dot(x: &[I8], y: &[I8]) -> F32 {
F32(sum as f32)
}

pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
pub fn dot(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
// (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y)
// = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y
// Sum(dot(origin_x[i] , origin_y[i])) = alpha_x * alpha_y * Sum(dot(x[i], y[i])) + offset_y * Sum(alpha_x * x[i]) + offset_x * Sum(alpha_y * y[i]) + offset_x * offset_y * dims
let dot_xy = dot(x.data(), y.data());
let dot_xy = dot_internal(x.data(), y.data());
x.alpha() * y.alpha() * dot_xy
+ x.offset() * y.sum()
+ y.offset() * x.sum()
+ x.offset() * y.offset() * F32(x.dims() as f32)
}

pub fn l2_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
pub fn sl2(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
// Sum(l2(origin_x[i] - origin_y[i])) = sum(x[i] ^ 2 - 2 * x[i] * y[i] + y[i] ^ 2)
// = dot(x, x) - 2 * dot(x, y) + dot(y, y)
x.l2_norm() * x.l2_norm() - F32(2.0) * dot_distance(x, y) + y.l2_norm() * y.l2_norm()
x.l2_norm() * x.l2_norm() - F32(2.0) * dot(x, y) + y.l2_norm() * y.l2_norm()
}

pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
pub fn cosine(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 {
// dot(x, y) / (l2(x) * l2(y))
let dot_xy = dot_distance(x, y);
let dot_xy = dot(x, y);
let l2_x = x.l2_norm();
let l2_y = y.l2_norm();
dot_xy / (l2_x * l2_y)
Expand Down Expand Up @@ -462,7 +481,7 @@ mod tests {
let ref_x = x_owned.for_borrow();
let y_owned = vec_to_owned(y);
let ref_y = y_owned.for_borrow();
let result = dot_distance(&ref_x, &ref_y);
let result = dot(&ref_x, &ref_y);
assert!((result.0 - 10.0).abs() < 0.1);
}

Expand All @@ -474,7 +493,7 @@ mod tests {
let ref_x = x_owned.for_borrow();
let y_owned = vec_to_owned(y);
let ref_y = y_owned.for_borrow();
let result = cosine_distance(&ref_x, &ref_y);
let result = cosine(&ref_x, &ref_y);
assert!((result.0 - (10.0 / 14.0)).abs() < 0.1);
// test cos_i8 using random generated data, check the precision
let x = new_random_vec_f32(1000);
Expand All @@ -487,7 +506,7 @@ mod tests {
let ref_x = x_owned.for_borrow();
let y_owned = vec_to_owned(y);
let ref_y = y_owned.for_borrow();
let result = cosine_distance(&ref_x, &ref_y);
let result = cosine(&ref_x, &ref_y);
assert!(
result_expected < 0.01
|| (result.0 - result_expected).abs() < 0.01
Expand All @@ -503,7 +522,7 @@ mod tests {
let ref_x = x_owned.for_borrow();
let y_owned = vec_to_owned(y);
let ref_y = y_owned.for_borrow();
let result = l2_distance(&ref_x, &ref_y);
let result = sl2(&ref_x, &ref_y);
assert!((result.0 - 8.0).abs() < 0.1);
// test l2_i8 using random generated data, check the precision
let x = new_random_vec_f32(1000);
Expand All @@ -518,7 +537,7 @@ mod tests {
let ref_x = x_owned.for_borrow();
let y_owned = vec_to_owned(y);
let ref_y = y_owned.for_borrow();
let result = l2_distance(&ref_x, &ref_y);
let result = sl2(&ref_x, &ref_y);
assert!(
result_expected < 1.0 || (result.0 - result_expected).abs() / result_expected < 0.05
);
Expand Down

0 comments on commit c6001e7

Please sign in to comment.