Skip to content

Commit

Permalink
test: use detect::multiversion
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Mar 25, 2024
1 parent c930af8 commit 539ab2b
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 219 deletions.
139 changes: 68 additions & 71 deletions crates/base/src/vector/svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn cosine_v4_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
return;
}
let lhs = random_svector(300);
let rhs = random_svector(350);
let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
pub fn cosine(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
assert_eq!(lhs.dims(), rhs.dims());
Expand Down Expand Up @@ -407,6 +426,25 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn dot_v4_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
return;
}
let lhs = random_svector(300);
let rhs = random_svector(350);
let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
pub fn dot(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
assert_eq!(lhs.dims(), rhs.dims());
Expand Down Expand Up @@ -551,6 +589,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
}
}

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn sl2_v4_test() {
const EPSILON: F32 = F32(1e-5);
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
return;
}
let lhs = random_svector(300);
let rhs = random_svector(350);
let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) };
let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) };
assert!(
(specialized - fallback).abs() < EPSILON,
"specialized = {specialized}, fallback = {fallback}."
);
}

#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
pub fn sl2(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
assert_eq!(lhs.dims(), rhs.dims());
Expand Down Expand Up @@ -687,75 +744,15 @@ unsafe fn emulate_mm512_2intersect_epi32(
}
}

#[cfg(target_arch = "x86_64")]
#[cfg(test)]
mod tests {
use super::*;

const LHS_SIZE: usize = 300;
const RHS_SIZE: usize = 350;
const EPS: F32 = F32(1e-5);

pub fn random_svector(len: usize) -> SVecf32Owned {
use rand::Rng;
let mut rng = rand::thread_rng();
let mut indexes: Vec<u32> = (0..len).map(|_| rng.gen_range(0..30000)).collect();
indexes.sort_unstable();
indexes.dedup();
let values: Vec<F32> = (0..indexes.len())
.map(|_| F32(rng.gen_range(-1.0..1.0)))
.collect();
SVecf32Owned::new(30000, indexes, values)
}

#[test]
fn test_cosine_svector() {
let x = random_svector(LHS_SIZE);
let y = random_svector(RHS_SIZE);
let cosine_fallback = unsafe { cosine_fallback(x.for_borrow(), y.for_borrow()) };
#[cfg(target_arch = "x86_64")]
if detect::v4::detect() {
let cosine_v4 = unsafe { cosine_v4(x.for_borrow(), y.for_borrow()) };
assert!(
cosine_fallback - cosine_v4 < EPS,
"cosine_fallback: {}, cosine_v4: {}",
cosine_fallback,
cosine_v4
);
}
}

#[test]
fn test_dot_svector() {
let x = random_svector(LHS_SIZE);
let y = random_svector(RHS_SIZE);
let dot_fallback = unsafe { dot_fallback(x.for_borrow(), y.for_borrow()) };
#[cfg(target_arch = "x86_64")]
if detect::v4::detect() {
let dot_v4 = unsafe { dot_v4(x.for_borrow(), y.for_borrow()) };
assert!(
dot_fallback - dot_v4 < EPS,
"dot_fallback: {}, dot_v4: {}",
dot_fallback,
dot_v4
);
}
}

#[test]
fn test_sl2_svector() {
let x = random_svector(LHS_SIZE);
let y = random_svector(RHS_SIZE);
let sl2_fallback = unsafe { sl2_fallback(x.for_borrow(), y.for_borrow()) };
#[cfg(target_arch = "x86_64")]
if detect::v4::detect() {
let sl2_v4 = unsafe { sl2_v4(x.for_borrow(), y.for_borrow()) };
assert!(
sl2_fallback - sl2_v4 < EPS,
"sl2_fallback: {}, sl2_v4: {}",
sl2_fallback,
sl2_v4
);
}
}
#[cfg(all(target_arch = "x86_64", test))]
fn random_svector(len: usize) -> SVecf32Owned {
use rand::Rng;
let mut rng = rand::thread_rng();
let mut indexes: Vec<u32> = (0..len).map(|_| rng.gen_range(0..30000)).collect();
indexes.sort_unstable();
indexes.dedup();
let values: Vec<F32> = (0..indexes.len())
.map(|_| F32(rng.gen_range(-1.0..1.0)))
.collect();
SVecf32Owned::new(30000, indexes, values)
}
Loading

0 comments on commit 539ab2b

Please sign in to comment.