@@ -291,6 +291,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
291
291
}
292
292
}
293
293
294
+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
295
+ #[ test]
296
+ fn cosine_v4_test ( ) {
297
+ const EPSILON : F32 = F32 ( 1e-5 ) ;
298
+ detect:: init ( ) ;
299
+ if !detect:: v4:: detect ( ) {
300
+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
301
+ return ;
302
+ }
303
+ let lhs = random_svector ( 300 ) ;
304
+ let rhs = random_svector ( 350 ) ;
305
+ let specialized = unsafe { cosine_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
306
+ let fallback = unsafe { cosine_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
307
+ assert ! (
308
+ ( specialized - fallback) . abs( ) < EPSILON ,
309
+ "specialized = {specialized}, fallback = {fallback}."
310
+ ) ;
311
+ }
312
+
294
313
#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
295
314
pub fn cosine ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
296
315
assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -407,6 +426,25 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
407
426
}
408
427
}
409
428
429
+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
430
+ #[ test]
431
+ fn dot_v4_test ( ) {
432
+ const EPSILON : F32 = F32 ( 1e-5 ) ;
433
+ detect:: init ( ) ;
434
+ if !detect:: v4:: detect ( ) {
435
+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
436
+ return ;
437
+ }
438
+ let lhs = random_svector ( 300 ) ;
439
+ let rhs = random_svector ( 350 ) ;
440
+ let specialized = unsafe { dot_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
441
+ let fallback = unsafe { dot_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
442
+ assert ! (
443
+ ( specialized - fallback) . abs( ) < EPSILON ,
444
+ "specialized = {specialized}, fallback = {fallback}."
445
+ ) ;
446
+ }
447
+
410
448
#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
411
449
pub fn dot ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
412
450
assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -551,6 +589,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
551
589
}
552
590
}
553
591
592
+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
593
+ #[ test]
594
+ fn sl2_v4_test ( ) {
595
+ const EPSILON : F32 = F32 ( 1e-5 ) ;
596
+ detect:: init ( ) ;
597
+ if !detect:: v4:: detect ( ) {
598
+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
599
+ return ;
600
+ }
601
+ let lhs = random_svector ( 300 ) ;
602
+ let rhs = random_svector ( 350 ) ;
603
+ let specialized = unsafe { sl2_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
604
+ let fallback = unsafe { sl2_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
605
+ assert ! (
606
+ ( specialized - fallback) . abs( ) < EPSILON ,
607
+ "specialized = {specialized}, fallback = {fallback}."
608
+ ) ;
609
+ }
610
+
554
611
#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
555
612
pub fn sl2 ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
556
613
assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -687,75 +744,15 @@ unsafe fn emulate_mm512_2intersect_epi32(
687
744
}
688
745
}
689
746
690
- #[ cfg( target_arch = "x86_64" ) ]
691
- #[ cfg( test) ]
692
- mod tests {
693
- use super :: * ;
694
-
695
- const LHS_SIZE : usize = 300 ;
696
- const RHS_SIZE : usize = 350 ;
697
- const EPS : F32 = F32 ( 1e-5 ) ;
698
-
699
- pub fn random_svector ( len : usize ) -> SVecf32Owned {
700
- use rand:: Rng ;
701
- let mut rng = rand:: thread_rng ( ) ;
702
- let mut indexes: Vec < u32 > = ( 0 ..len) . map ( |_| rng. gen_range ( 0 ..30000 ) ) . collect ( ) ;
703
- indexes. sort_unstable ( ) ;
704
- indexes. dedup ( ) ;
705
- let values: Vec < F32 > = ( 0 ..indexes. len ( ) )
706
- . map ( |_| F32 ( rng. gen_range ( -1.0 ..1.0 ) ) )
707
- . collect ( ) ;
708
- SVecf32Owned :: new ( 30000 , indexes, values)
709
- }
710
-
711
- #[ test]
712
- fn test_cosine_svector ( ) {
713
- let x = random_svector ( LHS_SIZE ) ;
714
- let y = random_svector ( RHS_SIZE ) ;
715
- let cosine_fallback = unsafe { cosine_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
716
- #[ cfg( target_arch = "x86_64" ) ]
717
- if detect:: v4:: detect ( ) {
718
- let cosine_v4 = unsafe { cosine_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
719
- assert ! (
720
- cosine_fallback - cosine_v4 < EPS ,
721
- "cosine_fallback: {}, cosine_v4: {}" ,
722
- cosine_fallback,
723
- cosine_v4
724
- ) ;
725
- }
726
- }
727
-
728
- #[ test]
729
- fn test_dot_svector ( ) {
730
- let x = random_svector ( LHS_SIZE ) ;
731
- let y = random_svector ( RHS_SIZE ) ;
732
- let dot_fallback = unsafe { dot_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
733
- #[ cfg( target_arch = "x86_64" ) ]
734
- if detect:: v4:: detect ( ) {
735
- let dot_v4 = unsafe { dot_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
736
- assert ! (
737
- dot_fallback - dot_v4 < EPS ,
738
- "dot_fallback: {}, dot_v4: {}" ,
739
- dot_fallback,
740
- dot_v4
741
- ) ;
742
- }
743
- }
744
-
745
- #[ test]
746
- fn test_sl2_svector ( ) {
747
- let x = random_svector ( LHS_SIZE ) ;
748
- let y = random_svector ( RHS_SIZE ) ;
749
- let sl2_fallback = unsafe { sl2_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
750
- #[ cfg( target_arch = "x86_64" ) ]
751
- if detect:: v4:: detect ( ) {
752
- let sl2_v4 = unsafe { sl2_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
753
- assert ! (
754
- sl2_fallback - sl2_v4 < EPS ,
755
- "sl2_fallback: {}, sl2_v4: {}" ,
756
- sl2_fallback,
757
- sl2_v4
758
- ) ;
759
- }
760
- }
747
+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
748
+ fn random_svector ( len : usize ) -> SVecf32Owned {
749
+ use rand:: Rng ;
750
+ let mut rng = rand:: thread_rng ( ) ;
751
+ let mut indexes: Vec < u32 > = ( 0 ..len) . map ( |_| rng. gen_range ( 0 ..30000 ) ) . collect ( ) ;
752
+ indexes. sort_unstable ( ) ;
753
+ indexes. dedup ( ) ;
754
+ let values: Vec < F32 > = ( 0 ..indexes. len ( ) )
755
+ . map ( |_| F32 ( rng. gen_range ( -1.0 ..1.0 ) ) )
756
+ . collect ( ) ;
757
+ SVecf32Owned :: new ( 30000 , indexes, values)
761
758
}
0 commit comments