From 470915bdd708a57bbbbdc5cf06b1f29506d83ba2 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 17:31:43 +0800 Subject: [PATCH] test: bvector tests Signed-off-by: usamoi --- crates/base/src/vector/bvecf32.rs | 146 ++++++++++++++++++++++++++++++ crates/base/src/vector/veci8.rs | 3 +- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index 467aae270..ec11bf9b5 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -13,6 +13,35 @@ pub struct BVecf32Owned { } impl BVecf32Owned { + #[inline(always)] + pub fn new(dims: u16, data: Vec) -> Self { + Self::new_checked(dims, data).unwrap() + } + #[inline(always)] + pub fn new_checked(dims: u16, data: Vec) -> Option { + if dims == 0 { + return None; + } + if data.len() != (dims as usize).div_ceil(BVEC_WIDTH) { + return None; + } + if dims % BVEC_WIDTH as u16 != 0 && data[data.len() - 1] >> (dims % BVEC_WIDTH as u16) != 0 + { + return None; + } + unsafe { Some(Self::new_unchecked(dims, data)) } + } + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `data` must be of the correct length. + /// * The padding bits must be zero. + #[inline(always)] + pub unsafe fn new_unchecked(dims: u16, data: Vec) -> Self { + Self { dims, data } + } + + #[inline(always)] pub fn new_zeroed(dims: u16) -> Self { assert!((1..=65535).contains(&dims)); let size = (dims as usize).div_ceil(BVEC_WIDTH); @@ -22,6 +51,7 @@ impl BVecf32Owned { } } + #[inline(always)] pub fn set(&mut self, index: usize, value: bool) { assert!(index < self.dims as usize); if value { @@ -206,6 +236,35 @@ unsafe fn cosine_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrow } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn cosine(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -259,6 +318,35 @@ unsafe fn dot_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed< } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn dot(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -305,6 +393,35 @@ unsafe fn sl2_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed< } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn sl2(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -355,6 +472,35 @@ unsafe fn jaccard_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borro } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn jaccard_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { jaccard_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn jaccard(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 84e580c9f..94b065167 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -366,7 +366,8 @@ unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_internal_v4_avx512vnni_test() { - const EPSILON: F32 = F32(4.0); + // A large epsilon is set for loss of precision caused by saturation arithmetic + const EPSILON: F32 = F32(512.0); detect::init(); if !detect::v4_avx512vnni::detect() { println!("test {} ... skipped (v4_avx512vnni)", module_path!());