Skip to content

Commit

Permalink
test: bvector tests
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Mar 26, 2024
1 parent c6001e7 commit 470915b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 1 deletion.
146 changes: 146 additions & 0 deletions crates/base/src/vector/bvecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@ pub struct BVecf32Owned {
}

impl BVecf32Owned {
#[inline(always)]
pub fn new(dims: u16, data: Vec<usize>) -> Self {
Self::new_checked(dims, data).unwrap()
}
#[inline(always)]
pub fn new_checked(dims: u16, data: Vec<usize>) -> Option<Self> {
if dims == 0 {
return None;
}
if data.len() != (dims as usize).div_ceil(BVEC_WIDTH) {
return None;
}
if dims % BVEC_WIDTH as u16 != 0 && data[data.len() - 1] >> (dims % BVEC_WIDTH as u16) != 0
{
return None;
}
unsafe { Some(Self::new_unchecked(dims, data)) }
}
/// # Safety
///
/// * `dims` must be in `1..=65535`.
/// * `data` must be of the correct length.
/// * The padding bits must be zero.
#[inline(always)]
pub unsafe fn new_unchecked(dims: u16, data: Vec<usize>) -> Self {
Self { dims, data }
}

#[inline(always)]
pub fn new_zeroed(dims: u16) -> Self {
assert!((1..=65535).contains(&dims));
let size = (dims as usize).div_ceil(BVEC_WIDTH);
Expand All @@ -22,6 +51,7 @@ impl BVecf32Owned {
}
}

#[inline(always)]
pub fn set(&mut self, index: usize, value: bool) {
assert!(index < self.dims as usize);
if value {
Expand Down Expand Up @@ -206,6 +236,35 @@ unsafe fn cosine_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrow
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn cosine_v4_avx512vpopcntdq_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4_avx512vpopcntdq::detect() {
println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!());
return;
}
let lhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let rhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)]
pub fn cosine(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 {
let lhs = lhs.data();
Expand Down Expand Up @@ -259,6 +318,35 @@ unsafe fn dot_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_v4_avx512vpopcntdq_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4_avx512vpopcntdq::detect() {
println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!());
return;
}
let lhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let rhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)]
pub fn dot(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 {
let lhs = lhs.data();
Expand Down Expand Up @@ -305,6 +393,35 @@ unsafe fn sl2_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn sl2_v4_avx512vpopcntdq_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4_avx512vpopcntdq::detect() {
println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!());
return;
}
let lhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let rhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)]
pub fn sl2(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 {
let lhs = lhs.data();
Expand Down Expand Up @@ -355,6 +472,35 @@ unsafe fn jaccard_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borro
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn jaccard_v4_avx512vpopcntdq_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4_avx512vpopcntdq::detect() {
println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!());
return;
}
let lhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let rhs = {
let mut x = vec![0; 126];
x.fill_with(|| rand::random());
x[125] &= 1;
BVecf32Owned::new(8001, x)
};
let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { jaccard_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)]
pub fn jaccard(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 {
let lhs = lhs.data();
Expand Down
3 changes: 2 additions & 1 deletion crates/base/src/vector/veci8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 {
#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_internal_v4_avx512vnni_test() {
const EPSILON: F32 = F32(4.0);
// A large epsilon is set for loss of precision caused by saturation arithmetic
const EPSILON: F32 = F32(512.0);
detect::init();
if !detect::v4_avx512vnni::detect() {
println!("test {} ... skipped (v4_avx512vnni)", module_path!());
Expand Down

0 comments on commit 470915b

Please sign in to comment.