Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Massively improved speed of the Wilcoxon signed-rank test - from 2 hours to 12 minutes! #9

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@ doctest = false
[dependencies]
statrs = { git = "https://github.com/larsgw/statrs", branch = "patch-1" }
rand = "0.7.3"
voracious_radix_sort = {version="1.2.0", optional=true}

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
paste = "1.0.15"


[[bench]]
name = "wilcoxon"
harness = false
2 changes: 2 additions & 0 deletions benches/original.wilcoxon.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Wilcoxon signed-rank test
time: [85.022 ms 85.069 ms 85.130 ms]
205 changes: 205 additions & 0 deletions benches/wilcoxon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
//! Criterion benchmark to measure the performance of the Wilcoxon signed-rank test.

use criterion::black_box;
use criterion::{criterion_group, criterion_main, Criterion};

use core::ops::{Add, Sub};
use rand::prelude::SliceRandom;
use stattest::test::WilcoxonWTest;
use stattest::traits::Bounded;

/// Addition used by the benchmark data generator to keep a running value
/// inside the range of its numeric type.
///
/// `Rhs` defaults to `Self`, mirroring the shape of `core::ops::Add`.
trait WrappingAdd<Rhs = Self> {
/// Type produced by [`WrappingAdd::wrapping_add`].
type Output;

/// Adds `rhs` to `self`; what happens near the upper limit of the type is
/// decided by the implementation.
fn wrapping_add(self, rhs: Rhs) -> Self::Output;
}

/// Blanket implementation for every bounded, ordered, copyable numeric type.
///
/// Despite the name this is not modular arithmetic: as soon as `self` reaches
/// `UPPER_BOUND - rhs`, the result restarts at `LOWER_BOUND`; otherwise it is
/// the plain sum.
impl<T> WrappingAdd for T
where
    T: Bounded + Sub<T, Output = T> + Add<T, Output = T> + PartialOrd + PartialEq + Copy,
{
    type Output = T;

    fn wrapping_add(self, rhs: T) -> Self::Output {
        // Largest value of `self` that is still considered too close to the top.
        let reset_threshold = T::UPPER_BOUND - rhs;
        if self >= reset_threshold {
            return T::LOWER_BOUND;
        }
        self + rhs
    }
}

/// Builds one pair of `N`-element samples for the benchmarks.
///
/// A single running value is advanced by `step` (via [`WrappingAdd`]) and its
/// successive values are handed out alternately to `x` and `y`, so the two
/// arrays interleave one sequence. Both arrays are then shuffled independently.
fn generate_test_cases<const N: usize, F: Default + Copy + WrappingAdd<F, Output = F>>(
    step: F,
) -> ([F; N], [F; N]) {
    let mut x = [F::default(); N];
    let mut y = [F::default(); N];

    let mut current = F::default();
    for (x_slot, y_slot) in x.iter_mut().zip(y.iter_mut()) {
        current = current.wrapping_add(step);
        *x_slot = current;
        current = current.wrapping_add(step);
        *y_slot = current;
    }

    // We shuffle the arrays to make the test more realistic
    let mut rng = rand::thread_rng();
    x.shuffle(&mut rng);
    y.shuffle(&mut rng);

    (x, y)
}

// Registers the Wilcoxon benchmarks for one floating-point type.
//
// Arguments:
// * `$group`     - identifier of the Criterion benchmark group to register on.
// * `$float`     - floating-point element type of the generated samples.
// * `$quantizer` - comma-separated integer types passed as the third type
//                  argument of the `quantized_*` entry points.
//
// The macro expands to statements, so the `test_cases` binding it creates
// lives in the caller's scope until that scope ends.
macro_rules! bench_float_wilcoxon {
($group:ident, $float:ty, $($quantizer:ty),*) => {
// Ten independent (x, y) pairs of 200_000 elements each; every benchmark
// iteration below runs the test once per pair.
let test_cases = (0..10)
.map(|_| generate_test_cases::<200_000, $float>(0.1454829354839453473))
.collect::<Vec<_>>();

// Baseline: `WilcoxonWTest::paired` on the raw floating-point data.
$group.bench_function(&format!(
"sort_unstable_{}",
stringify!($float)
), |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::paired(black_box(x), black_box(y)).unwrap();
}
})
});

// Same data through `voracious_paired`; only compiled when the optional
// `voracious_radix_sort` feature is enabled.
#[cfg(feature="voracious_radix_sort")]
$group.bench_function(&format!(
"voracious_{}",
stringify!($float)
), |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_paired(black_box(x), black_box(y)).unwrap();
}
})
});

// One pair of benchmarks (plain + feature-gated voracious) per quantizer type.
$(
$group.bench_function(
&format!(
"quantized_sort_unstable_{}_to_{}",
stringify!($float),
stringify!($quantizer))
, |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::quantized_paired::<_, _, $quantizer>(black_box(x), black_box(y)).unwrap();
}
})
});

#[cfg(feature="voracious_radix_sort")]
$group.bench_function(&format!(
"quantized_voracious_{}_to_{}",
stringify!($float),
stringify!($quantizer)
), |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_quantized_paired::<_, _, $quantizer>(black_box(x), black_box(y)).unwrap();
}
})
});
)*
};
}

fn bench_wilcoxon(c: &mut Criterion) {
let mut group = c.benchmark_group("Wilcoxon signed-rank test");

bench_float_wilcoxon!(group, f32, i8, i16);
bench_float_wilcoxon!(group, f64, i8, i16, i32);

let test_cases = (0..10)
.map(|_| generate_test_cases::<200_000, i64>(1))
.collect::<Vec<_>>();

group.bench_function("sort_unstable_i64", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::paired(black_box(x), black_box(y)).unwrap();
}
})
});

#[cfg(feature = "voracious_radix_sort")]
group.bench_function("voracious_i64", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_paired(black_box(x), black_box(y)).unwrap();
}
})
});

let test_cases = (0..10)
.map(|_| generate_test_cases::<200_000, i32>(1))
.collect::<Vec<_>>();

group.bench_function("sort_unstable_i32", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::paired(black_box(x), black_box(y)).unwrap();
}
})
});

#[cfg(feature = "voracious_radix_sort")]
group.bench_function("voracious_i32", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_paired(black_box(x), black_box(y)).unwrap();
}
})
});

let test_cases = (0..10)
.map(|_| generate_test_cases::<200_000, i16>(1))
.collect::<Vec<_>>();

group.bench_function("sort_unstable_i16", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::paired(black_box(x), black_box(y)).unwrap();
}
})
});

#[cfg(feature = "voracious_radix_sort")]
group.bench_function("voracious_i16", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_paired(black_box(x), black_box(y)).unwrap();
}
})
});

let test_cases = (0..10)
.map(|_| generate_test_cases::<200_000, i8>(1))
.collect::<Vec<_>>();

group.bench_function("sort_unstable_i8", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::paired(black_box(x), black_box(y)).unwrap();
}
})
});

#[cfg(feature = "voracious_radix_sort")]
group.bench_function("voracious_i8", |b| {
b.iter(|| {
for (x, y) in test_cases.iter() {
WilcoxonWTest::voracious_paired(black_box(x), black_box(y)).unwrap();
}
})
});

group.finish();
}

criterion_group!(benches, bench_wilcoxon);

criterion_main!(benches);
18 changes: 16 additions & 2 deletions src/distribution/signed_rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ impl ::rand::distributions::Distribution<f64> for SignedRank {
}
}

/// Returns exactly `2^-x` as an `f64`, built directly from its IEEE 754 bit
/// pattern instead of calling `powi`/`powf`.
///
/// * `x <= 1022`: the result is a normal number, encoded by writing the biased
///   exponent `1023 - x` into the exponent field with a zero mantissa.
/// * `1023 <= x <= 1074`: the result is subnormal, encoded as a single
///   mantissa bit with a zero exponent field.
/// * `x > 1074`: `2^-x` is below the smallest positive `f64`, so `0.0` is
///   returned (the same value `2_f64.powi(-x)` underflows to).
///
/// Unlike a bare `1023 - x`, this never underflows the subtraction, so it
/// cannot panic in debug builds or yield garbage in release builds.
fn fast_exp2_reciprocal(x: u32) -> f64 {
    const EXPONENT_BIAS: u32 = 1023;
    const MANTISSA_BITS: u32 = 52;
    // Largest `x` for which 2^-x is still representable (as the smallest subnormal).
    const MAX_SUBNORMAL_EXPONENT: u32 = EXPONENT_BIAS + MANTISSA_BITS - 1;

    if x < EXPONENT_BIAS {
        f64::from_bits(u64::from(EXPONENT_BIAS - x) << MANTISSA_BITS)
    } else if x <= MAX_SUBNORMAL_EXPONENT {
        // A subnormal is `mantissa * 2^-1074`, hence 2^-x = 2^(1074 - x) * 2^-1074.
        f64::from_bits(1_u64 << (MAX_SUBNORMAL_EXPONENT - x))
    } else {
        0.0
    }
}

impl ContinuousCDF<f64, f64> for SignedRank {
fn cdf(&self, x: f64) -> f64 {
match self.approximation {
Expand All @@ -126,7 +131,7 @@ impl ContinuousCDF<f64, f64> for SignedRank {
}
}

sum as f64 / 2_f64.powi(self.n as i32 - 1)
sum as f64 * fast_exp2_reciprocal(self.n as u32 - 1)
}
}
}
Expand Down Expand Up @@ -217,7 +222,7 @@ impl Continuous<f64, f64> for SignedRank {
sum += partitions(r - n_choose_2, n, self.n - n + 1);
}

sum as f64 / 2_f64.powi(self.n as i32)
sum as f64 * fast_exp2_reciprocal(self.n as u32)
}
}
}
Expand Down Expand Up @@ -310,4 +315,13 @@ mod tests {
fn partition() {
assert_eq!(super::partitions(7, 3, 5), 4);
}

#[test]
fn test_fast_exp2_reciprocal() {
    // The bit-pattern shortcut must agree with the standard library for
    // every exponent in 0..1000.
    (0..1000_u32).for_each(|exponent| {
        let actual = super::fast_exp2_reciprocal(exponent);
        let expected = 2.0_f64.powf(-f64::from(exponent));
        assert_eq!(actual, expected);
    });
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![deny(unconditional_recursion)]

pub mod distribution;
pub mod statistics;
pub mod test;
pub mod traits;
Loading
Loading