7
7
use crate :: error:: { PermutationError , PermutationResult } ;
8
8
#[ cfg( feature = "use-rand" ) ]
9
9
use rand:: prelude:: * ;
10
- use std:: num:: NonZeroU32 ;
10
+ use std:: num:: { NonZeroU32 , Wrapping } ;
11
11
12
12
/// The `HashedPermutation` struct stores the initial `seed` and `length` of the permutation
13
13
/// vector. In other words, if you want to shuffle the numbers from `0..n`, then `length = n`.
@@ -56,17 +56,12 @@ impl HashedPermutation {
56
56
max_shuffle : self . length . get ( ) ,
57
57
} ) ;
58
58
}
59
- let mut i = input;
59
+ let mut i = Wrapping ( input) ;
60
60
let n = self . length . get ( ) ;
61
- let seed = self . seed ;
62
- let mut w = n - 1 ;
63
- w |= w >> 1 ;
64
- w |= w >> 2 ;
65
- w |= w >> 4 ;
66
- w |= w >> 8 ;
67
- w |= w >> 16 ;
68
-
69
- while i >= n {
61
+ let seed = Wrapping ( self . seed ) ;
62
+ let w = Wrapping ( n. checked_next_power_of_two ( ) . map_or ( u32:: MAX , |x| x - 1 ) ) ;
63
+
64
+ while i. 0 >= n {
70
65
i ^= seed;
71
66
i *= 0xe170893d ;
72
67
i ^= seed >> 16 ;
@@ -75,7 +70,7 @@ impl HashedPermutation {
75
70
i *= 0x0929eb3f ;
76
71
i ^= seed >> 23 ;
77
72
i ^= ( i & w) >> 1 ;
78
- i *= 1 | seed >> 27 ;
73
+ i *= Wrapping ( 1 ) | seed >> 27 ;
79
74
i *= 0x6935fa69 ;
80
75
i ^= ( i & w) >> 11 ;
81
76
i *= 0x74dcb303 ;
@@ -86,7 +81,7 @@ impl HashedPermutation {
86
81
i &= w;
87
82
i ^= i >> 5 ;
88
83
}
89
- Ok ( ( i + seed) % n)
84
+ Ok ( ( i + seed) . 0 % n)
90
85
}
91
86
}
92
87
@@ -143,19 +138,15 @@ mod test {
143
138
// Check that each entry doesn't exist
144
139
// Check that every number is "hit" (as they'd have to be) for a perfect bijection
145
140
// Check that the number is within range
146
- let mut map = HashMap :: new ( ) ;
141
+ let mut map = HashMap :: with_capacity ( length . get ( ) as usize ) ;
147
142
148
143
for i in 0 ..perm. length . get ( ) {
149
144
let res = perm. shuffle ( i) ;
150
145
let res = res. unwrap ( ) ;
151
- let map_result = map. get ( & res) ;
152
- assert ! ( map_result. is_none( ) ) ;
153
- map. insert ( res, i) ;
146
+ assert ! ( map. insert( res, i) . is_none( ) ) ;
154
147
}
155
- // Need to dereference the types into regular integers
156
- let mut keys_vec: Vec < u32 > = map. keys ( ) . into_iter ( ) . map ( |k| * k) . collect ( ) ;
148
+ let ( mut keys_vec, mut vals_vec) : ( Vec < u32 > , Vec < u32 > ) = map. iter ( ) . unzip ( ) ;
157
149
keys_vec. sort ( ) ;
158
- let mut vals_vec: Vec < u32 > = map. values ( ) . into_iter ( ) . map ( |v| * v) . collect ( ) ;
159
150
vals_vec. sort ( ) ;
160
151
let ground_truth: Vec < u32 > = ( 0 ..length. get ( ) ) . collect ( ) ;
161
152
assert_eq ! ( ground_truth, keys_vec) ;
0 commit comments