@@ -291,6 +291,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
291291 }
292292}
293293
294+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
295+ #[ test]
296+ fn cosine_v4_test ( ) {
297+ const EPSILON : F32 = F32 ( 1e-5 ) ;
298+ detect:: init ( ) ;
299+ if !detect:: v4:: detect ( ) {
300+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
301+ return ;
302+ }
303+ let lhs = random_svector ( 300 ) ;
304+ let rhs = random_svector ( 350 ) ;
305+ let specialized = unsafe { cosine_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
306+ let fallback = unsafe { cosine_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
307+ assert ! (
308+ ( specialized - fallback) . abs( ) < EPSILON ,
309+ "specialized = {specialized}, fallback = {fallback}."
310+ ) ;
311+ }
312+
294313#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
295314pub fn cosine ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
296315 assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -407,6 +426,25 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
407426 }
408427}
409428
429+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
430+ #[ test]
431+ fn dot_v4_test ( ) {
432+ const EPSILON : F32 = F32 ( 1e-5 ) ;
433+ detect:: init ( ) ;
434+ if !detect:: v4:: detect ( ) {
435+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
436+ return ;
437+ }
438+ let lhs = random_svector ( 300 ) ;
439+ let rhs = random_svector ( 350 ) ;
440+ let specialized = unsafe { dot_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
441+ let fallback = unsafe { dot_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
442+ assert ! (
443+ ( specialized - fallback) . abs( ) < EPSILON ,
444+ "specialized = {specialized}, fallback = {fallback}."
445+ ) ;
446+ }
447+
410448#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
411449pub fn dot ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
412450 assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -551,6 +589,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 {
551589 }
552590}
553591
592+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
593+ #[ test]
594+ fn sl2_v4_test ( ) {
595+ const EPSILON : F32 = F32 ( 1e-5 ) ;
596+ detect:: init ( ) ;
597+ if !detect:: v4:: detect ( ) {
598+ println ! ( "test {} ... skipped (v4)" , module_path!( ) ) ;
599+ return ;
600+ }
601+ let lhs = random_svector ( 300 ) ;
602+ let rhs = random_svector ( 350 ) ;
603+ let specialized = unsafe { sl2_v4 ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
604+ let fallback = unsafe { sl2_fallback ( lhs. for_borrow ( ) , rhs. for_borrow ( ) ) } ;
605+ assert ! (
606+ ( specialized - fallback) . abs( ) < EPSILON ,
607+ "specialized = {specialized}, fallback = {fallback}."
608+ ) ;
609+ }
610+
554611#[ detect:: multiversion( v4 = import, v3, v2, neon, fallback = export) ]
555612pub fn sl2 ( lhs : SVecf32Borrowed < ' _ > , rhs : SVecf32Borrowed < ' _ > ) -> F32 {
556613 assert_eq ! ( lhs. dims( ) , rhs. dims( ) ) ;
@@ -687,75 +744,15 @@ unsafe fn emulate_mm512_2intersect_epi32(
687744 }
688745}
689746
690- #[ cfg( target_arch = "x86_64" ) ]
691- #[ cfg( test) ]
692- mod tests {
693- use super :: * ;
694-
695- const LHS_SIZE : usize = 300 ;
696- const RHS_SIZE : usize = 350 ;
697- const EPS : F32 = F32 ( 1e-5 ) ;
698-
699- pub fn random_svector ( len : usize ) -> SVecf32Owned {
700- use rand:: Rng ;
701- let mut rng = rand:: thread_rng ( ) ;
702- let mut indexes: Vec < u32 > = ( 0 ..len) . map ( |_| rng. gen_range ( 0 ..30000 ) ) . collect ( ) ;
703- indexes. sort_unstable ( ) ;
704- indexes. dedup ( ) ;
705- let values: Vec < F32 > = ( 0 ..indexes. len ( ) )
706- . map ( |_| F32 ( rng. gen_range ( -1.0 ..1.0 ) ) )
707- . collect ( ) ;
708- SVecf32Owned :: new ( 30000 , indexes, values)
709- }
710-
711- #[ test]
712- fn test_cosine_svector ( ) {
713- let x = random_svector ( LHS_SIZE ) ;
714- let y = random_svector ( RHS_SIZE ) ;
715- let cosine_fallback = unsafe { cosine_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
716- #[ cfg( target_arch = "x86_64" ) ]
717- if detect:: v4:: detect ( ) {
718- let cosine_v4 = unsafe { cosine_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
719- assert ! (
720- cosine_fallback - cosine_v4 < EPS ,
721- "cosine_fallback: {}, cosine_v4: {}" ,
722- cosine_fallback,
723- cosine_v4
724- ) ;
725- }
726- }
727-
728- #[ test]
729- fn test_dot_svector ( ) {
730- let x = random_svector ( LHS_SIZE ) ;
731- let y = random_svector ( RHS_SIZE ) ;
732- let dot_fallback = unsafe { dot_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
733- #[ cfg( target_arch = "x86_64" ) ]
734- if detect:: v4:: detect ( ) {
735- let dot_v4 = unsafe { dot_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
736- assert ! (
737- dot_fallback - dot_v4 < EPS ,
738- "dot_fallback: {}, dot_v4: {}" ,
739- dot_fallback,
740- dot_v4
741- ) ;
742- }
743- }
744-
745- #[ test]
746- fn test_sl2_svector ( ) {
747- let x = random_svector ( LHS_SIZE ) ;
748- let y = random_svector ( RHS_SIZE ) ;
749- let sl2_fallback = unsafe { sl2_fallback ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
750- #[ cfg( target_arch = "x86_64" ) ]
751- if detect:: v4:: detect ( ) {
752- let sl2_v4 = unsafe { sl2_v4 ( x. for_borrow ( ) , y. for_borrow ( ) ) } ;
753- assert ! (
754- sl2_fallback - sl2_v4 < EPS ,
755- "sl2_fallback: {}, sl2_v4: {}" ,
756- sl2_fallback,
757- sl2_v4
758- ) ;
759- }
760- }
747+ #[ cfg( all( target_arch = "x86_64" , test) ) ]
748+ fn random_svector ( len : usize ) -> SVecf32Owned {
749+ use rand:: Rng ;
750+ let mut rng = rand:: thread_rng ( ) ;
751+ let mut indexes: Vec < u32 > = ( 0 ..len) . map ( |_| rng. gen_range ( 0 ..30000 ) ) . collect ( ) ;
752+ indexes. sort_unstable ( ) ;
753+ indexes. dedup ( ) ;
754+ let values: Vec < F32 > = ( 0 ..indexes. len ( ) )
755+ . map ( |_| F32 ( rng. gen_range ( -1.0 ..1.0 ) ) )
756+ . collect ( ) ;
757+ SVecf32Owned :: new ( 30000 , indexes, values)
761758}
0 commit comments