Skip to content

Commit 539ab2b

Browse files
committed
test: use detect::multiversion
Signed-off-by: usamoi <[email protected]>
1 parent c930af8 commit 539ab2b

File tree

4 files changed

+262
-219
lines changed

4 files changed

+262
-219
lines changed

crates/base/src/vector/svecf32.rs

Lines changed: 68 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
291291
}
292292
}
293293

294+
#[cfg(all(target_arch = "x86_64", test))]
295+
#[test]
296+
fn cosine_v4_test() {
297+
const EPSILON: F32 = F32(1e-5);
298+
detect::init();
299+
if !detect::v4::detect() {
300+
println!("test {} ... skipped (v4)", module_path!());
301+
return;
302+
}
303+
let lhs = random_svector(300);
304+
let rhs = random_svector(350);
305+
let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) };
306+
let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) };
307+
assert!(
308+
(specialized - fallback).abs() < EPSILON,
309+
"specialized = {specialized}, fallback = {fallback}."
310+
);
311+
}
312+
294313
#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
295314
pub fn cosine(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
296315
assert_eq!(lhs.dims(), rhs.dims());
@@ -407,6 +426,25 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
407426
}
408427
}
409428

429+
#[cfg(all(target_arch = "x86_64", test))]
430+
#[test]
431+
fn dot_v4_test() {
432+
const EPSILON: F32 = F32(1e-5);
433+
detect::init();
434+
if !detect::v4::detect() {
435+
println!("test {} ... skipped (v4)", module_path!());
436+
return;
437+
}
438+
let lhs = random_svector(300);
439+
let rhs = random_svector(350);
440+
let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) };
441+
let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) };
442+
assert!(
443+
(specialized - fallback).abs() < EPSILON,
444+
"specialized = {specialized}, fallback = {fallback}."
445+
);
446+
}
447+
410448
#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
411449
pub fn dot(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
412450
assert_eq!(lhs.dims(), rhs.dims());
@@ -551,6 +589,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
551589
}
552590
}
553591

592+
#[cfg(all(target_arch = "x86_64", test))]
593+
#[test]
594+
fn sl2_v4_test() {
595+
const EPSILON: F32 = F32(1e-5);
596+
detect::init();
597+
if !detect::v4::detect() {
598+
println!("test {} ... skipped (v4)", module_path!());
599+
return;
600+
}
601+
let lhs = random_svector(300);
602+
let rhs = random_svector(350);
603+
let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) };
604+
let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) };
605+
assert!(
606+
(specialized - fallback).abs() < EPSILON,
607+
"specialized = {specialized}, fallback = {fallback}."
608+
);
609+
}
610+
554611
#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)]
555612
pub fn sl2(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
556613
assert_eq!(lhs.dims(), rhs.dims());
@@ -687,75 +744,15 @@ unsafe fn emulate_mm512_2intersect_epi32(
687744
}
688745
}
689746

690-
#[cfg(target_arch = "x86_64")]
691-
#[cfg(test)]
692-
mod tests {
693-
use super::*;
694-
695-
const LHS_SIZE: usize = 300;
696-
const RHS_SIZE: usize = 350;
697-
const EPS: F32 = F32(1e-5);
698-
699-
pub fn random_svector(len: usize) -> SVecf32Owned {
700-
use rand::Rng;
701-
let mut rng = rand::thread_rng();
702-
let mut indexes: Vec<u32> = (0..len).map(|_| rng.gen_range(0..30000)).collect();
703-
indexes.sort_unstable();
704-
indexes.dedup();
705-
let values: Vec<F32> = (0..indexes.len())
706-
.map(|_| F32(rng.gen_range(-1.0..1.0)))
707-
.collect();
708-
SVecf32Owned::new(30000, indexes, values)
709-
}
710-
711-
#[test]
712-
fn test_cosine_svector() {
713-
let x = random_svector(LHS_SIZE);
714-
let y = random_svector(RHS_SIZE);
715-
let cosine_fallback = unsafe { cosine_fallback(x.for_borrow(), y.for_borrow()) };
716-
#[cfg(target_arch = "x86_64")]
717-
if detect::v4::detect() {
718-
let cosine_v4 = unsafe { cosine_v4(x.for_borrow(), y.for_borrow()) };
719-
assert!(
720-
cosine_fallback - cosine_v4 < EPS,
721-
"cosine_fallback: {}, cosine_v4: {}",
722-
cosine_fallback,
723-
cosine_v4
724-
);
725-
}
726-
}
727-
728-
#[test]
729-
fn test_dot_svector() {
730-
let x = random_svector(LHS_SIZE);
731-
let y = random_svector(RHS_SIZE);
732-
let dot_fallback = unsafe { dot_fallback(x.for_borrow(), y.for_borrow()) };
733-
#[cfg(target_arch = "x86_64")]
734-
if detect::v4::detect() {
735-
let dot_v4 = unsafe { dot_v4(x.for_borrow(), y.for_borrow()) };
736-
assert!(
737-
dot_fallback - dot_v4 < EPS,
738-
"dot_fallback: {}, dot_v4: {}",
739-
dot_fallback,
740-
dot_v4
741-
);
742-
}
743-
}
744-
745-
#[test]
746-
fn test_sl2_svector() {
747-
let x = random_svector(LHS_SIZE);
748-
let y = random_svector(RHS_SIZE);
749-
let sl2_fallback = unsafe { sl2_fallback(x.for_borrow(), y.for_borrow()) };
750-
#[cfg(target_arch = "x86_64")]
751-
if detect::v4::detect() {
752-
let sl2_v4 = unsafe { sl2_v4(x.for_borrow(), y.for_borrow()) };
753-
assert!(
754-
sl2_fallback - sl2_v4 < EPS,
755-
"sl2_fallback: {}, sl2_v4: {}",
756-
sl2_fallback,
757-
sl2_v4
758-
);
759-
}
760-
}
747+
#[cfg(all(target_arch = "x86_64", test))]
748+
fn random_svector(len: usize) -> SVecf32Owned {
749+
use rand::Rng;
750+
let mut rng = rand::thread_rng();
751+
let mut indexes: Vec<u32> = (0..len).map(|_| rng.gen_range(0..30000)).collect();
752+
indexes.sort_unstable();
753+
indexes.dedup();
754+
let values: Vec<F32> = (0..indexes.len())
755+
.map(|_| F32(rng.gen_range(-1.0..1.0)))
756+
.collect();
757+
SVecf32Owned::new(30000, indexes, values)
761758
}

0 commit comments

Comments
 (0)