Skip to content

Commit

Permalink
perf: use wrapping operations, use next_power_of_two, improve tests (#10
Browse files Browse the repository at this point in the history
)

* Use `next_power_of_two` which should be more performant than manual optimization
* Clean up tests so they don't hash more than necessary
* Use wrapping operations in case hash operations overflow
  • Loading branch information
TheIronBorn authored Jan 14, 2024
1 parent 078289b commit 2ebb5c6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 25 deletions.
7 changes: 2 additions & 5 deletions src/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,11 @@ mod test {
// Check that each entry doesn't exist
// Check that every number is "hit" (as they'd have to be) for a perfect bijection
// Check that the number is within range
let mut set = HashSet::new();
let mut set = HashSet::with_capacity(length.get() as usize);

for elem in it {
let set_result = set.get(&elem);

// Make sure there are no duplicates
assert!(set_result.is_none());
set.insert(elem);
assert!(set.insert(elem));
}
// Need to dereference the types into regular integers
let mut result: Vec<u32> = set.into_iter().collect();
Expand Down
31 changes: 11 additions & 20 deletions src/kensler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::error::{PermutationError, PermutationResult};
#[cfg(feature = "use-rand")]
use rand::prelude::*;
use std::num::NonZeroU32;
use std::num::{NonZeroU32, Wrapping};

/// The `HashedPermutation` struct stores the initial `seed` and `length` of the permutation
/// vector. In other words, if you want to shuffle the numbers from `0..n`, then `length = n`.
Expand Down Expand Up @@ -56,17 +56,12 @@ impl HashedPermutation {
max_shuffle: self.length.get(),
});
}
let mut i = input;
let mut i = Wrapping(input);
let n = self.length.get();
let seed = self.seed;
let mut w = n - 1;
w |= w >> 1;
w |= w >> 2;
w |= w >> 4;
w |= w >> 8;
w |= w >> 16;

while i >= n {
let seed = Wrapping(self.seed);
let w = Wrapping(n.checked_next_power_of_two().map_or(u32::MAX, |x| x - 1));

while i.0 >= n {
i ^= seed;
i *= 0xe170893d;
i ^= seed >> 16;
Expand All @@ -75,7 +70,7 @@ impl HashedPermutation {
i *= 0x0929eb3f;
i ^= seed >> 23;
i ^= (i & w) >> 1;
i *= 1 | seed >> 27;
i *= Wrapping(1) | seed >> 27;
i *= 0x6935fa69;
i ^= (i & w) >> 11;
i *= 0x74dcb303;
Expand All @@ -86,7 +81,7 @@ impl HashedPermutation {
i &= w;
i ^= i >> 5;
}
Ok((i + seed) % n)
Ok((i + seed).0 % n)
}
}

Expand Down Expand Up @@ -143,19 +138,15 @@ mod test {
// Check that each entry doesn't exist
// Check that every number is "hit" (as they'd have to be) for a perfect bijection
// Check that the number is within range
let mut map = HashMap::new();
let mut map = HashMap::with_capacity(length.get() as usize);

for i in 0..perm.length.get() {
let res = perm.shuffle(i);
let res = res.unwrap();
let map_result = map.get(&res);
assert!(map_result.is_none());
map.insert(res, i);
assert!(map.insert(res, i).is_none());
}
// Need to dereference the types into regular integers
let mut keys_vec: Vec<u32> = map.keys().into_iter().map(|k| *k).collect();
let (mut keys_vec, mut vals_vec): (Vec<u32>, Vec<u32>) = map.iter().unzip();
keys_vec.sort();
let mut vals_vec: Vec<u32> = map.values().into_iter().map(|v| *v).collect();
vals_vec.sort();
let ground_truth: Vec<u32> = (0..length.get()).collect();
assert_eq!(ground_truth, keys_vec);
Expand Down

0 comments on commit 2ebb5c6

Please sign in to comment.