diff --git a/crates/quaternary_trie/Cargo.toml b/crates/quaternary_trie/Cargo.toml new file mode 100644 index 0000000..e8f15f5 --- /dev/null +++ b/crates/quaternary_trie/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "quarternary_trie" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["lib", "staticlib"] +bench = false + +[dependencies] +rand = "0.8" + +[dev-dependencies] +itertools = "0.12" \ No newline at end of file diff --git a/crates/quaternary_trie/src/lib.rs b/crates/quaternary_trie/src/lib.rs new file mode 100644 index 0000000..817727c --- /dev/null +++ b/crates/quaternary_trie/src/lib.rs @@ -0,0 +1,713 @@ +use virtual_bitrank::{VirtualBitRank, Word, WORD_BITS}; + +pub mod parallel; +mod virtual_bitrank; + +const MAX_LEVEL: usize = 14; + +pub struct QuarternaryTrie { + data: VirtualBitRank, + /// Total number of nibbles on each level. + level_idx: [usize; MAX_LEVEL], +} + +/// Level: 0 ==> ........xx +/// Level: 1 ==> ......xx.. +/// Level: 2 ==> ....xx.... +/// ^^^^ +/// prefix +/// ^^ +/// nibble bit +/// Van Emde Boas layout/traversal +/// 1 +/// / \ +/// 2 3 +/// / \ / \ +/// 4 7 a d +/// / \ / \ / \ / \ +/// 5 6 8 9 b c e f +/// +/// Process: 123, 489, 5ab, 6cd, 7ef +/// 0xxx 00xx 01xx | STOP & Recurse +/// 000x 0000 0001 +/// 001x 0010 0011 +/// 010x 0100 0101 +/// 011x 0110 0111 +/// +/// now with two bit levels +/// +/// 00xxxxxx 01xxxxxx 10xxxxxx 11xxxxxx +/// 0000xxxx 0001xxxx 0010xxxx 0011xxxx +/// 0100xxxx 0101xxxx 0110xxxx 0111xxxx +/// 1000xxxx 1001xxxx 1010xxxx 1011xxxx +/// 1100xxxx 1101xxxx 1110xxxx 1111xxxx +/// +/// Stop and recurse +/// +/// 000000xx 000001xx 000010xx 000011xx +/// 00000000 00000001 00000010 00000011 +/// 00000100 00000101 00000110 00000111 +/// 00001000 00001001 00001010 00001011 +/// 00001100 00001101 00001110 00001111 +/// +/// 000100xx 000101xx 000110xx 000111xx +/// 00010000 00010001 00010010 00010011 +/// 00010100 00010101 00010110 00010111 +/// 00011000 00011001 00011010 00011011 +/// 00011100 00011101 00011110 00011111 +/// +/// 001000xx 001001xx 001010xx 001011xx +/// 00100000 00100001 00100010 00100011 +/// 00100100 00100101 00100110 00100111 +/// 00101000 00101001 00101010 00101011 +/// 00101100 00101101 00101110 00101111 +/// +/// 001100xx 001101xx 001110xx 001111xx +/// 00110000 00110001 00110010 00110011 +/// 00110100 00110101 00110110 00110111 +/// 00111000 00111001 00111010 00111011 +/// 00111100 00111101 00111110 00111111 +/// +/// 010000xx 010001xx 010010xx 010011xx +/// ... 
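+///
+/// Worked example of the nibble encoding (illustration, derived from the code
+/// below): for the single value 13 = 0b1101, level 1 selects child
+/// (13 >> 2) & 3 = 3 and level 0 selects child 13 & 3 = 1, so the level-1 node
+/// stores nibble 0b1000 and its child on level 0 stores nibble 0b0010
+/// (all higher levels store nibble 0b0001 for the 0-digit).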
+///
+/// Process first half:
+/// xxxxxxxx--------
+/// Reset and process second half
+/// 00000000xxxxxxxx
+
+fn get_prefix(value: u32, level: usize) -> u32 {
+    if level == 16 {
+        0
+    } else {
+        value >> (level * 2)
+    }
+}
+
+fn gallopping_search<F: Fn(u32) -> bool>(values: &mut &[u32], f: F) {
+    let mut step = 1;
+    while step < values.len() {
+        if f(values[step]) {
+            break;
+        }
+        *values = &values[step..];
+        step *= 2;
+    }
+    step /= 2;
+    while step > 0 {
+        if step < values.len() && !f(values[step]) {
+            *values = &values[step..];
+        }
+        step /= 2;
+    }
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum Layout {
+    Linear,
+    VanEmdeBoas,
+    DepthFirst,
+}
+
+impl QuarternaryTrie {
+    fn count_levels(&mut self, values: &mut &[u32], level: usize) -> bool {
+        let prefix = values[0] >> (level * 2 + 2);
+        let mut nibble = 0;
+        let mut all_set = true;
+        while !values.is_empty() && values[0] >> (level * 2 + 2) == prefix {
+            nibble |= 1 << ((values[0] >> (level * 2)) & 3);
+            if level > 0 {
+                all_set &= self.count_levels(values, level - 1);
+            } else {
+                *values = &values[1..];
+            }
+        }
+        all_set &= nibble == 15;
+        self.level_idx[level] += 1;
+        all_set
+    }
+
+    fn van_emde_boas(&mut self, values: &mut &[u32], level: usize, res: usize) {
+        let prefix = get_prefix(values[0], level);
+        if res == 0 {
+            let mut nibble = 0;
+            while !values.is_empty() && get_prefix(values[0], level) == prefix {
+                let v = values[0] >> (level * 2 - 2);
+                nibble |= 1 << (v & 3);
+                *values = &values[1..];
+                gallopping_search(values, |x| x >> (level * 2 - 2) > v);
+            }
+            if level <= MAX_LEVEL {
+                self.data.set_nibble(self.level_idx[level - 1], nibble);
+                self.level_idx[level - 1] += 1;
+            }
+            return;
+        }
+        // Process levels level .. level - res.
+        // This level range has to be processed at half resolution.
+        let mut copy = &values[..];
+        self.van_emde_boas(&mut copy, level, res / 2);
+        // Then process levels level - res .. level - 2 * res, i.e.
+        // process all the children within this subtree.
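+        // Sketch of the recursion for the initial call (level = 16, res = 8):
+        // the copy-pass above emits the nibbles of levels 16..8 at half
+        // resolution, and the loop below then recurses into each child subtree
+        // to emit levels 8..0, producing the van Emde Boas interleaving
+        // described in the layout comment at the top of this file.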
+        while !values.is_empty() && get_prefix(values[0], level) == prefix {
+            self.van_emde_boas(values, level - res, res / 2);
+        }
+    }
+
+    fn fill_bit_rank(&mut self, values: &mut &[u32], level: usize) -> bool {
+        let prefix = values[0] >> (level * 2 + 2);
+        let mut nibble = 0;
+        let mut all_set = true;
+        while !values.is_empty() && values[0] >> (level * 2 + 2) == prefix {
+            nibble |= 1 << ((values[0] >> (level * 2)) & 3);
+            if level > 0 {
+                all_set &= self.fill_bit_rank(values, level - 1);
+            } else {
+                *values = &values[1..];
+            }
+        }
+        all_set &= nibble == 15;
+        self.data.set_nibble(self.level_idx[level], nibble);
+        self.level_idx[level] += 1;
+        all_set
+    }
+
+    pub fn new(values: &[u32], layout: Layout) -> Self {
+        let mut s = Self {
+            data: VirtualBitRank::new(),
+            level_idx: [0; MAX_LEVEL],
+        };
+        let mut consumed = values;
+        s.count_levels(&mut consumed, MAX_LEVEL - 1);
+        if true || matches!(layout, Layout::Linear) {
+            s.data.reserve(s.level_idx.iter().sum::<usize>() * 4);
+        }
+        // Exclusive prefix sum over the per-level counts (top level first):
+        // afterwards level_idx[level] holds the first nibble index of that level.
+        s.level_idx
+            .iter_mut()
+            .rev()
+            .scan(0, |acc, x| {
+                let old = *acc;
+                *acc += *x;
+                *x = old;
+                Some(old)
+            })
+            .for_each(drop);
+        let mut consumed = values;
+        if matches!(layout, Layout::VanEmdeBoas | Layout::Linear) {
+            s.van_emde_boas(&mut consumed, 16, 8);
+        } else {
+            s.fill_bit_rank(&mut consumed, MAX_LEVEL - 1);
+        }
+        s.data.build();
+        s.reset_stats();
+        println!(
+            "encoded size: {}",
+            4.0 * s.level_idx[0] as f32 / values.len() as f32
+        );
+        s
+    }
+
+    fn reset_stats(&mut self) {
+        self.data.reset_stats();
+    }
+
+    fn page_count(&self) -> (usize, usize) {
+        self.data.page_count()
+    }
+
+    fn recurse(&self, node: usize, level: usize, value: u32, results: &mut Vec<u32>) {
+        if level == 1 {
+            self.recurse2(node, value, results);
+        } else {
+            let mut n = self.data.get_nibble(node);
+            let mut value = value * 4;
+            if n == 0 {
+                results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2)) + i));
+                return;
+            }
+            let mut r = self.data.rank(node * 4) as usize;
+            while n > 0 {
+                let delta = n.trailing_zeros();
+                r += 1;
+                self.recurse(r, level - 1, value + delta, results);
+                value += delta + 1;
+                n >>= delta + 1;
+            }
+        }
+    }
+
+    fn recurse2(&self, node: usize, value: u32, results: &mut Vec<u32>) {
+        let mut n = self.data.get_nibble(node);
+        let mut value = value * 4;
+        if n == 0 {
+            results.extend((0..16).map(|i| (value << 2) + i));
+            return;
+        }
+        let mut r = self.data.rank(node * 4) as usize;
+        while n > 0 {
+            let delta = n.trailing_zeros();
+            r += 1;
+            self.recurse0(r, value + delta, results);
+            value += delta + 1;
+            n >>= delta + 1;
+        }
+    }
+
+    fn recurse0(&self, node: usize, value: u32, results: &mut Vec<u32>) {
+        let mut n = self.data.get_nibble(node);
+        let mut value = value * 4;
+        if n == 0 {
+            results.extend((0..4).map(|i| value + i));
+            return;
+        }
+        while n > 0 {
+            let delta = n.trailing_zeros();
+            results.push(value + delta);
+            value += delta + 1;
+            n >>= delta + 1;
+        }
+    }
+
+    // This is the "slow" implementation, which computes the rank at every level
+    // and extracts the corresponding nibble.
+    pub fn collect2(&self) -> Vec<u32> {
+        let mut results = Vec::with_capacity(self.level_idx[0]);
+        self.recurse(0, MAX_LEVEL - 1, 0, &mut results);
+        results
+    }
+
+    // This is the "fastest" implementation, since it doesn't use rank information at all during the traversal.
+    // This is possible, since it iterates through ALL nodes and thus we can simply increment the positions by 1.
+    // We only need the rank information to initialize the positions.
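+    // Concretely (sketch of the initialization in `collect` below): every set bit
+    // owns exactly one nibble on the next level, so after resetting the top level
+    // to nibble 0, each lower level starts at rank(level_idx[level] * 4) + 1 --
+    // one rank query per level instead of one per node.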
+    // The only remaining "expensive" part here is the lookup of the nibble with every iteration.
+    // This lookup requires the slightly complicated conversion from position into block pointer
+    // (either via the virtual mapping or via some math).
+    // The math would be trivial if we didn't store the counters within the same page...
+    // Instead one could try to cache a u64 value and keep shifting until the end is reached,
+    // or work with a pointer into the bitrank array.
+    pub fn collect(&mut self) -> Vec<u32> {
+        self.level_idx[MAX_LEVEL - 1] = 0;
+        for level in (1..MAX_LEVEL).rev() {
+            self.level_idx[level - 1] = self.data.rank(self.level_idx[level] * 4) as usize + 1;
+        }
+        let mut results = Vec::new();
+        self.fast_collect_inner(MAX_LEVEL - 1, 0, &mut results);
+        results
+    }
+
+    fn fast_collect_inner(&mut self, level: usize, value: u32, results: &mut Vec<u32>) {
+        let mut nibble = self.data.get_nibble(self.level_idx[level]);
+        self.level_idx[level] += 1;
+        if nibble == 0 {
+            results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2 + 2)) + i));
+            return;
+        }
+        let mut value = value * 4;
+        if level == 0 {
+            while nibble > 0 {
+                let delta = nibble.trailing_zeros();
+                results.push(value + delta);
+                value += delta + 1;
+                nibble >>= delta + 1;
+            }
+        } else {
+            while nibble > 0 {
+                let delta = nibble.trailing_zeros();
+                self.fast_collect_inner(level - 1, value + delta, results);
+                value += delta + 1;
+                nibble >>= delta + 1;
+            }
+        }
+    }
+}
+
+pub trait TrieIteratorTrait {
+    fn get(&self, level: usize) -> u32;
+    fn down(&mut self, level: usize, child: u32);
+}
+
+pub struct TrieTraversal<'a> {
+    trie: &'a QuarternaryTrie,
+    // The nibble position of the node for each level.
+    pos: [u32; MAX_LEVEL],
+    // The remaining bits (nibbles) of the word covering the nibble position.
+    word: [Word; MAX_LEVEL],
+    // The 1-rank up to the nibble. This information is needed
+    // to determine the nibble/node position of the next level.
+    rank: [u32; MAX_LEVEL],
+}
+
+impl<'a> TrieTraversal<'a> {
+    pub fn new(trie: &'a QuarternaryTrie) -> Self {
+        let word = trie.data.get_word_suffix(0);
+        Self {
+            trie,
+            pos: [0; MAX_LEVEL],
+            word: [word; MAX_LEVEL],
+            rank: [0; MAX_LEVEL],
+        }
+    }
+}
+
+impl TrieIteratorTrait for TrieTraversal<'_> {
+    fn get(&self, level: usize) -> u32 {
+        self.word[level] as u32 & 15
+    }
+
+    fn down(&mut self, level: usize, child: u32) {
+        let new_pos =
+            self.rank[level] + (self.word[level] & !(Word::MAX << (child + 1))).count_ones();
+        let old_pos = self.pos[level - 1];
+        if (new_pos ^ old_pos) & !(WORD_BITS as u32 / 4 - 1) == 0 {
+            // Both nibble positions fall into the same word, so we can reuse
+            // the old rank information.
+            let delta = (new_pos - old_pos) * 4;
+            self.rank[level - 1] += (self.word[level - 1] & !(Word::MAX << delta)).count_ones();
+            self.word[level - 1] >>= delta;
+        } else {
+            if level > 1 {
+                // for level 0, we don't need the rank information
+                // self.rank[level - 1] = self.trie.data.rank(4 * new_pos as usize);
+                let (r, w) = self.trie.data.rank_with_word(4 * new_pos as usize);
+                self.rank[level - 1] = r;
+                self.word[level - 1] = w;
+            } else {
+                // TODO: Get word suffix and rank information in one go...
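+                // At level 0 the traversal only reads the nibble via get() and
+                // never descends further, so the plain word suffix is sufficient
+                // and the rank can be skipped.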
+                self.word[level - 1] = self.trie.data.get_word_suffix(4 * new_pos as usize);
+            }
+        }
+        self.pos[level - 1] = new_pos;
+    }
+}
+
+pub struct TrieIterator<T> {
+    trie: T,
+    item: u32,
+}
+
+impl<T: TrieIteratorTrait> TrieIterator<T> {
+    pub fn new(trie: T) -> Self {
+        Self { trie, item: 0 }
+    }
+}
+
+impl<T: TrieIteratorTrait> Iterator for TrieIterator<T> {
+    type Item = u32;
+
+    fn next(&mut self) -> Option<u32> {
+        let mut item = self.item;
+        let mut level = if self.item == 0 {
+            MAX_LEVEL - 1
+        } else {
+            (item.trailing_zeros() / 2) as usize
+        };
+        while level < MAX_LEVEL {
+            let child = (item >> (2 * level)) & 3;
+            let nibble = self.trie.get(level) >> child;
+            if nibble != 0 {
+                let delta = nibble.trailing_zeros();
+                if level == 0 {
+                    let res = item + delta;
+                    self.item = res + 1;
+                    return Some(res);
+                }
+                item += delta << (2 * level);
+                self.trie.down(level, child + delta);
+                level -= 1;
+            } else {
+                item |= 3 << (level * 2);
+                item += 1 << (level * 2);
+                level = (item.trailing_zeros() / 2) as usize;
+            }
+        }
+        self.item = item;
+        None
+    }
+}
+
+// TODO: Introduce a nibble summary structure which caches the computed merged nibble information.
+// If the query tree becomes more complex, recomputing the merged nibble information becomes expensive.
+// But for small query trees, it's not worth the effort to cache the information.
+pub struct Intersection<T> {
+    left: T,
+    right: T,
+}
+
+impl<T> Intersection<T> {
+    pub fn new(left: T, right: T) -> Self {
+        Self { left, right }
+    }
+}
+
+impl<T: TrieIteratorTrait> TrieIteratorTrait for Intersection<T> {
+    fn get(&self, level: usize) -> u32 {
+        self.left.get(level) & self.right.get(level)
+    }
+
+    fn down(&mut self, level: usize, child: u32) {
+        self.left.down(level, child);
+        self.right.down(level, child);
+    }
+}
+
+enum Split {
+    None,
+    Left(usize),
+    Right(usize),
+}
+
+pub struct Union<T> {
+    inner: [T; 2],
+    split: usize,
+    swap: bool,
+}
+
+impl<T> Union<T> {
+    pub fn new(left: T, right: T) -> Self {
+        Self {
+            inner: [left, right],
+            swap: false,
+            split: 0,
+        }
+    }
+}
+
+impl<T: TrieIteratorTrait> TrieIteratorTrait for Union<T> {
+    fn get(&self, level: usize) -> u32 {
+        /*match self.split {
+            Split::None => self.left.get(level) | self.right.get(level),
+            Split::Left(split) => {
+                if level < split {
+                    self.right.get(level)
+                } else {
+                    self.left.get(level) | self.right.get(level)
+                }
+            }
+            Split::Right(split) => {
+                if level < split {
+                    self.left.get(level)
+                } else {
+                    self.left.get(level) | self.right.get(level)
+                }
+            }
+        }*/
+        if level < self.split {
+            self.inner[self.swap as usize].get(level)
+        } else {
+            self.inner[0].get(level) | self.inner[1].get(level)
+        }
+    }
+
+    fn down(&mut self, level: usize, child: u32) {
+        // TODO: Only traverse the side which has the child bit set.
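+        // `split` marks the level below which only one operand still contains the
+        // current prefix, and `swap` selects that operand: e.g. if only the right
+        // side has the requested child, we set split = level, swap = true and only
+        // the right side is traversed until the iterator climbs back above `split`.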
+ /*match self.split { + Split::Left(l) if level < l => { + self.right.down(level, child); + } + Split::Right(l) if level < l => { + self.left.down(level, child); + } + _ => { + let has_left = self.left.get(level) & (1 << child) != 0; + let has_right = self.right.get(level) & (1 << child) != 0; + if has_left && has_right { + self.split = Split::None; + self.left.down(level, child); + self.right.down(level, child); + } else if has_left { + self.split = Split::Right(level); + self.left.down(level, child); + } else { + assert!(has_right); + self.split = Split::Left(level); + //std::mem::swap(&mut self.left, &mut self.right); + self.right.down(level, child); + } + } + }*/ + if level < self.split { + self.inner[self.swap as usize].down(level, child); + } else { + let has_left = self.inner[0].get(level) & (1 << child) != 0; + let has_right = self.inner[1].get(level) & (1 << child) != 0; + if has_left && has_right { + self.split = 0; + self.inner[0].down(level, child); + self.inner[1].down(level, child); + } else if has_left { + self.split = level; + self.swap = false; + self.inner[0].down(level, child); + } else { + assert!(has_right); + self.split = level; + self.swap = true; + self.inner[1].down(level, child); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use itertools::{kmerge, Itertools}; + use rand::{thread_rng, Rng}; + + use crate::{Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal, Union}; + + #[test] + fn test_trie() { + let values = vec![3, 6, 7, 10]; + let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas); + assert_eq!(trie.collect(), values); + + let values: Vec<_> = (1..63).collect(); + let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas); + assert_eq!(trie.collect(), values); + } + + #[test] + fn test_large() { + let mut values: Vec<_> = (0..10000000) + .map(|_| thread_rng().gen_range(0..100000000)) + .collect(); + // let mut values: Vec<_> = (0..100).map(|_| thread_rng().gen_range(0..10000)).collect(); + values.sort(); + values.dedup(); + + let start = Instant::now(); + let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas); + println!("construction {:?}", start.elapsed() / values.len() as u32); + + let start = Instant::now(); + let result = trie.collect(); + println!("reconstruction {:?}", start.elapsed() / values.len() as u32); + assert_eq!(result, values); + + let iter = TrieIterator::new(TrieTraversal::new(&trie)); + let start = Instant::now(); + let result: Vec<_> = iter.collect(); + println!("iteration {:?}", start.elapsed() / values.len() as u32); + // assert_eq!(result, values); + } + + #[test] + fn test_van_emde_boas_layout() { + let values: Vec<_> = (0..64).collect(); + let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas); + assert_eq!(trie.collect(), values); + } + + #[test] + fn test_union() { + let mut values: Vec<_> = (0..1000000) + .map(|_| thread_rng().gen_range(0..100000000)) + .collect(); + values.sort(); + values.dedup(); + + let mut values2: Vec<_> = (0..10000000) + .map(|_| thread_rng().gen_range(0..100000000)) + .collect(); + values2.sort(); + values2.dedup(); + + let start = Instant::now(); + let union: Vec<_> = kmerge([values.iter(), values2.iter()]) + .copied() + .dedup() + .collect(); + println!("kmerge union {:?}", start.elapsed() / union.len() as u32); + println!("Union size: {}", union.len()); + + let trie = QuarternaryTrie::new(&values, Layout::Linear); + let trie2 = QuarternaryTrie::new(&values2, Layout::Linear); + let iter = TrieIterator::new(Union::new( + 
+            TrieTraversal::new(&trie),
+            TrieTraversal::new(&trie2),
+        ));
+        let start = Instant::now();
+        let result: Vec<_> = iter.collect();
+        println!("trie union {:?}", start.elapsed() / union.len() as u32);
+        assert_eq!(result, union);
+    }
+
+    #[test]
+    fn test_intersection() {
+        let mut page_counts = [0, 0, 0];
+        for _ in 0..3 {
+            let mut values: Vec<_> = (0..10000000)
+                .map(|_| thread_rng().gen_range(0..100000000))
+                .collect();
+            values.sort();
+            values.dedup();
+
+            let mut values2: Vec<_> = (0..10000000)
+                .map(|_| thread_rng().gen_range(0..100000000))
+                .collect();
+            values2.sort();
+            values2.dedup();
+
+            let start = Instant::now();
+            let intersection: Vec<_> = kmerge([values.iter(), values2.iter()])
+                .tuple_windows()
+                .filter_map(|(a, b)| if *a == *b { Some(*a) } else { None })
+                .collect();
+            println!(
+                "kmerge intersection {:?}",
+                start.elapsed() / values.len() as u32
+            );
+            println!("Intersection size: {}", intersection.len());
+
+            for (i, layout) in [Layout::VanEmdeBoas, Layout::DepthFirst, Layout::Linear]
+                .into_iter()
+                .enumerate()
+            {
+                let trie = QuarternaryTrie::new(&values, layout);
+                let trie2 = QuarternaryTrie::new(&values2, layout);
+                let iter = TrieIterator::new(Intersection::new(
+                    TrieTraversal::new(&trie),
+                    TrieTraversal::new(&trie2),
+                ));
+                let start = Instant::now();
+                let result: Vec<_> = iter.collect();
+                let count = trie.page_count();
+                let count2 = trie2.page_count();
+                page_counts[i] += count.0 + count2.0;
+                println!(
+                    "trie intersection {:?} {}",
+                    start.elapsed() / values.len() as u32,
+                    (count.0 + count2.0) as f32 / (count.1 + count2.1) as f32
+                );
+                assert_eq!(result, intersection);
+            }
+            println!("{page_counts:?}");
+        }
+    }
+
+    /*#[test]
+    fn test_mix() {
+        let values: Vec<_> = [10000, 100000, 1000000]
+            .into_iter()
+            .map(|v| {
+                let mut values: Vec<_> = (0..v)
+                    .map(|_| thread_rng().gen_range(0..100000000))
+                    .collect();
+                values.sort();
+                values.dedup();
+                values
+            })
+            .collect();
+
+        let tries: Vec<_> = values
+            .iter()
+            .map(|v| QuarternaryTrie::new(v, Layout::Linear))
+            .collect();
+        let iter = TrieIterator::new(Intersection::new(
+            &tries[0],
+            Union::new(TrieTraversal::new(&tries[1]), TrieTraversal::new(&tries[2])),
+        ));
+        let start = Instant::now();
+        let result: Vec<_> = iter.collect();
+        println!("trie union {:?}", start.elapsed() / result.len() as u32);
+    }*/
+}
diff --git a/crates/quaternary_trie/src/parallel.rs b/crates/quaternary_trie/src/parallel.rs
new file mode 100644
index 0000000..8421769
--- /dev/null
+++ b/crates/quaternary_trie/src/parallel.rs
@@ -0,0 +1,193 @@
+use std::arch::x86_64::_pdep_u64;
+
+use crate::virtual_bitrank::VirtualBitRank;
+
+pub struct ParallelTrie {
+    root: Vec<u64>,
+    root_ones: usize,
+    max_level: usize,
+    data: VirtualBitRank,
+    level_idx: Vec<usize>,
+}
+
+impl ParallelTrie {
+    fn fill_bit_rank<const WRITE: bool>(
+        &mut self,
+        prefix: u32,
+        slices: &mut [&[u32]; 64],
+        level: usize,
+        mask: u64,
+    ) {
+        // println!("fill_bit_rank {prefix} {mask:064b} {level}");
+        for t in [0, 64 << level] {
+            let mut sub_mask = 0;
+            for i in 0..64 {
+                if (1 << i) & mask == 0 {
+                    continue;
+                }
+                if let Some(&value) = slices[i].get(0) {
+                    if (value ^ prefix) >> (level + 7) == 0 && value & (64 << level) == t {
+                        if WRITE {
+                            self.data.set(self.level_idx[level]);
+                        }
+                        if level > 0 {
+                            sub_mask |= 1 << (value & 63);
+                        } else {
+                            slices[i] = &slices[i][1..];
+                        }
+                    }
+                }
+                self.level_idx[level] += 1;
+            }
+            if sub_mask != 0 {
+                self.fill_bit_rank::<WRITE>(prefix + t, slices, level - 1, sub_mask);
+            }
+        }
+    }
+
+    fn fill<const WRITE: bool>(&mut self, mut slices: [&[u32]; 64]) {
+        for prefix in 0..self.root.len() {
+            let mut mask = 0;
+            for i in 0..64 {
+                if let Some(&value) = slices[i].get(0) {
+                    if value >> (self.max_level + 6) == prefix as u32 {
+                        mask |= 1 << i;
+                    }
+                }
+            }
+            if WRITE {
+                self.root[prefix] = mask;
+                self.root_ones += mask.count_ones() as usize;
+            }
+            if mask != 0 {
+                self.fill_bit_rank::<WRITE>(
+                    (prefix as u32) << (self.max_level + 6),
+                    &mut slices,
+                    self.max_level - 1,
+                    mask,
+                );
+            }
+        }
+    }
+
+    pub fn build(max_doc: usize, mut v: Vec<u32>, max_level: usize) -> Self {
+        v.sort_by_key(|&v| (v % 64, v / 64));
+        let mut slices = [&v[..]; 64];
+        let mut i = 0;
+        for j in 0..64 {
+            let s = i;
+            while i < v.len() && v[i] % 64 == j {
+                i += 1;
+            }
+            slices[j as usize] = &v[s..i];
+        }
+        let mut s = Self {
+            max_level,
+            data: VirtualBitRank::default(),
+            root: vec![0u64; (max_doc >> (max_level + 6)) + 1],
+            root_ones: 0,
+            level_idx: vec![0; max_level],
+        };
+        s.fill::<false>(slices.clone());
+        s.data.reserve(s.level_idx.iter().sum::<usize>() + 64);
+        // Exclusive prefix sum over the per-level counts (top level first).
+        s.level_idx
+            .iter_mut()
+            .rev()
+            .scan(0, |acc, x| {
+                let old = *acc;
+                *acc += *x;
+                *x = old;
+                Some(old)
+            })
+            .for_each(drop);
+        s.fill::<true>(slices);
+        s.data.build();
+        let trie_size = (s.level_idx[0] as f32) / v.len() as f32;
+        let root_size = (s.root.len() * 64) as f32 / v.len() as f32;
+        println!(
+            "encoded size: {trie_size} {root_size} total: {} density: {}",
+            trie_size + root_size,
+            s.root_ones as f32 / s.root.len() as f32 / 64.0
+        );
+        s
+    }
+
+    pub fn collect(&self) -> Vec<u32> {
+        let mut v = Vec::new();
+        let mut rank = 0;
+        for (i, word) in self.root.iter().enumerate() {
+            if *word != 0 {
+                self.recurse(i, *word, rank * 2, self.max_level, &mut v);
+            }
+            rank += word.count_ones() as usize;
+        }
+        v
+    }
+
+    fn recurse(&self, pos: usize, mut word: u64, rank: usize, level: usize, v: &mut Vec<u32>) {
+        if level == 0 {
+            while word != 0 {
+                let bit = word.trailing_zeros();
+                v.push(((pos as u32) << 6) + bit);
+                word &= word - 1;
+            }
+        } else {
+            let required_bits = word.count_ones();
+            if required_bits == 0 {
+                return;
+            }
+            let w = self.data.get_word(rank);
+            let new_word = unsafe { _pdep_u64(w, word) };
+            let new_rank = self.data.rank(rank) as usize + self.root_ones;
+            self.recurse(pos * 2, new_word, new_rank * 2, level - 1, v);
+
+            let rank = rank + required_bits as usize;
+            let w = self.data.get_word(rank);
+            let new_word = unsafe { _pdep_u64(w, word) };
+            let new_rank = self.data.rank(rank) as usize + self.root_ones;
+            self.recurse(pos * 2 + 1, new_word, new_rank * 2, level - 1, v);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Instant;
+
+    use rand::{thread_rng, Rng};
+
+    use crate::parallel::ParallelTrie;
+
+    #[test]
+    fn test_parallel() {
+        // let values = vec![3, 6, 7, 10, 90, 91, 120, 128, 129, 130, 231, 321, 999];
+        // let values = vec![3, 6, 7, 321, 999];
+        let mut values: Vec<_> = (0..10_000_000)
+            .map(|_| thread_rng().gen_range(0..100_000_000))
+            .collect();
+        values.sort();
+        values.dedup();
+
+        for levels in 1..12 {
+            let start = Instant::now();
+            let trie = ParallelTrie::build(100_000_000, values.clone(), levels);
+            println!(
+                "construction {levels} {:?}",
+                start.elapsed() / values.len() as u32,
+            );
+
+            let start = Instant::now();
+            let result = trie.collect();
+            println!(
+                "collect {levels} {:?}",
+                start.elapsed() / values.len() as u32,
+            );
+            assert_eq!(result, values);
+        }
+    }
+}
diff --git a/crates/quaternary_trie/src/virtual_bitrank.rs b/crates/quaternary_trie/src/virtual_bitrank.rs
new file mode 100644
index 0000000..e059c3a
--- /dev/null
+++ b/crates/quaternary_trie/src/virtual_bitrank.rs
@@ -0,0 +1,352 @@
+/*
+    Layout:
+    Want <5% permutation table!
+    Max posting lists with 2^28 documents = ~90MB
+    NVMe page size is 4096 bytes
+    4 bytes for block offset ==> 4 billion blocks
+    counts:
+      4 bytes ==> leads to overflow, but NOT within one posting list!
+      3 bytes only works for <16 million documents :(
+      Store count into middle of block
+    64 block size
+      index size: 64 * 4 billion ==> 256 GB
+      lookup overhead: 6.25%
+      blocks per page: 64 ==> can store about 4 layers within one page! ==> "worst case" 4 pages
+      count: 32 byte (AVX-2, 256-bits)
+      overhead of count: 6.25%
+      total overhead: 12.5% ==> not good to store within lookup table!
+    128 block size
+      index size: 128 * 4 billion ==> 512 GB
+      lookup overhead: 3.125%
+      blocks per page: 32 ==> can store at least 3 layers (of at most 14) within one page! ==> "worst case" 5 pages
+      count: 64 bytes (AVX-512, 512-bits)
+      overhead of count: 3.125%
+      total overhead: 6.25%
+    1 block = 4 * 32 counts ==> store counts inside of page?
+      sub-counts: 32 bytes + one byte for each sub-count
+      two splits: --- 1 byte --- 4 bytes --- 1 byte ---
+      overhead of sub-counts: 1.5625%
+
+    Iteration:
+      basic: decode next 4 values on level L
+      decode ahead: decode next 16? values down to level L
+        decode on each previous level ~16 values ahead
+        not quite clear how to arrange the data best
+        16 values fit into one avx2 register
+*/
+
+use std::cell::RefCell;
+
+pub(crate) type Word = u64;
+
+const BLOCK_BYTES: usize = 64;
+const BLOCK_BITS: usize = BLOCK_BYTES * 8;
+const BLOCKS_PER_PAGE: usize = BLOCK_BYTES / 4;
+pub(crate) const WORD_BITS: usize = WORD_BYTES * 8;
+pub(crate) const WORD_BYTES: usize = std::mem::size_of::<Word>();
+const WORDS_PER_BLOCK: usize = BLOCK_BYTES / WORD_BYTES;
+const PAGE_BYTES: usize = BLOCKS_PER_PAGE * BLOCK_BYTES;
+const PAGE_BITS: usize = PAGE_BYTES * 8;
+const SUPER_PAGE_BITS: usize = 4096 * 8;
+
+#[repr(C, align(128))]
+#[derive(Default, Clone)]
+struct Block {
+    words: [Word; WORDS_PER_BLOCK],
+}
+
+#[derive(Default)]
+pub(crate) struct VirtualBitRank {
+    // In order to look up bit i, use block_mapping[i / BLOCK_BITS]
+    block_mapping: Vec<u32>,
+    blocks: Vec<Block>,
+    // Remember which pages have been accessed.
+    stats: Vec<RefCell<u64>>,
+}
+
+impl VirtualBitRank {
+    pub(crate) fn new() -> Self {
+        Self::with_capacity(0)
+    }
+
+    pub(crate) fn with_capacity(bits: usize) -> Self {
+        let bits = (bits + BLOCK_BITS - 1) & !(BLOCK_BITS - 1);
+        Self {
+            block_mapping: vec![0; bits / BLOCK_BITS], // 0 means unused block!!
+ blocks: Vec::with_capacity(bits / BLOCK_BITS), + stats: Vec::new(), + } + } + + pub(crate) fn reset_stats(&mut self) { + self.stats = vec![RefCell::new(0); self.blocks.len() * BLOCK_BITS / SUPER_PAGE_BITS + 1]; + } + + pub(crate) fn page_count(&self) -> (usize, usize) { + ( + self.stats + .iter() + .map(|v| v.borrow().count_ones() as usize) + .sum(), + (self.blocks.len() + BLOCKS_PER_PAGE - 1) / BLOCKS_PER_PAGE, + ) + } + + fn bit_to_block(&self, bit: usize) -> usize { + let block = bit / BLOCK_BITS; + let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1; + //let result = self.block_mapping[bit / BLOCK_BITS] as usize; + //assert_eq!(result2, result); + //if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) { + // *v.borrow_mut() += 1 << (result % 64); + //} + result2 + } + + fn mid_rank(&self, block: usize) -> u32 { + let first_block = block & !(BLOCKS_PER_PAGE - 1); + let array = self.blocks[first_block].words.as_ptr() as *const u32; + unsafe { array.add(block & (BLOCKS_PER_PAGE - 1)).read() } + } + + pub(crate) fn rank(&self, bit: usize) -> u32 { + let block = self.bit_to_block(bit); + let mut rank = self.mid_rank(block); + let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); + let bit_in_word = bit & (WORD_BITS - 1); + if word >= WORDS_PER_BLOCK / 2 { + for i in WORDS_PER_BLOCK / 2..word { + rank += self.blocks[block].words[i].count_ones(); + } + if bit_in_word != 0 { + rank + (self.blocks[block].words[word] << (WORD_BITS - bit_in_word)).count_ones() + } else { + rank + } + } else { + for i in word + 1..WORDS_PER_BLOCK / 2 { + rank -= self.blocks[block].words[i].count_ones(); + } + rank - (self.blocks[block].words[word] >> bit_in_word).count_ones() + } + } + + pub(crate) fn rank_with_word(&self, bit: usize) -> (u32, Word) { + let block = self.bit_to_block(bit); + let mut rank = self.mid_rank(block); + let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); + let bit_in_word = bit & (WORD_BITS - 1); + if word >= WORDS_PER_BLOCK / 2 { + for i in WORDS_PER_BLOCK / 2..word { + rank += self.blocks[block].words[i].count_ones(); + } + if bit_in_word != 0 { + ( + rank + (self.blocks[block].words[word] << (WORD_BITS - bit_in_word)) + .count_ones(), + self.blocks[block].words[word] >> bit_in_word, + ) + } else { + (rank, self.blocks[block].words[word]) + } + } else { + for i in word + 1..WORDS_PER_BLOCK / 2 { + rank -= self.blocks[block].words[i].count_ones(); + } + let w = self.blocks[block].words[word] >> bit_in_word; + (rank - w.count_ones(), w) + } + } + + pub(crate) fn reserve(&mut self, bits: usize) { + assert!(self.block_mapping.is_empty()); + assert!(self.blocks.is_empty()); + // let pages = (bits + PAGE_BITS - 1) / PAGE_BITS; + let blocks = (bits + BLOCK_BITS - 1) / BLOCK_BITS; + for _ in 0..blocks { + if self.blocks.len() % BLOCKS_PER_PAGE == 0 { + self.blocks.push(Block::default()); + } + self.block_mapping.push(self.blocks.len() as u32); + self.blocks.push(Block::default()); + } + } + + fn get_word_mut(&mut self, bit: usize) -> &mut Word { + let block = bit / BLOCK_BITS; + if block >= self.block_mapping.len() { + self.block_mapping.resize(block + 1, 0); + } + if self.block_mapping[block] == 0 { + if self.blocks.len() % BLOCKS_PER_PAGE == 0 { + self.blocks.push(Block::default()); // Block with rank information + } + self.block_mapping[block] = self.blocks.len() as u32; + self.blocks.push(Block::default()); + } + let block = self.bit_to_block(bit); + let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1); + &mut self.blocks[block].words[word] + } + + 
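+    // Note on the mid-block counters used by rank()/rank_with_word() above: each
+    // block's u32 slot (stored in the first block of its page) holds the 1-rank up
+    // to the middle word of that block. Queries in the upper half add popcounts of
+    // the words from the middle forward; queries in the lower half subtract them
+    // going backward, so at most half a block is scanned either way.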
+    pub(crate) fn set(&mut self, bit: usize) {
+        *self.get_word_mut(bit) |= 1 << (bit & (WORD_BITS - 1));
+    }
+
+    pub(crate) fn set_nibble(&mut self, nibble_idx: usize, nibble_value: u32) {
+        let bit_idx = nibble_idx * 4;
+        // Clearing the old bits is unnecessary, since every nibble is written at most once:
+        // *self.get_word(bit_idx) &= !(15 << (bit_idx & (WORD_BITS - 1)));
+        *self.get_word_mut(bit_idx) |= (nibble_value as Word) << (bit_idx & (WORD_BITS - 1));
+    }
+
+    pub(crate) fn build(&mut self) {
+        for block in &mut self.block_mapping {
+            if *block == 0 {
+                if self.blocks.len() % BLOCKS_PER_PAGE == 0 {
+                    self.blocks.push(Block::default()); // Block with rank information
+                }
+                *block = self.blocks.len() as u32;
+                self.blocks.push(Block::default());
+            }
+        }
+        let mut ones = 0;
+        for block in 0..self.block_mapping.len() {
+            let block = self.block_mapping[block] as usize;
+            for i in 0..WORDS_PER_BLOCK / 2 {
+                ones += self.blocks[block].words[i].count_ones();
+            }
+            unsafe {
+                let first_block = block & !(BLOCKS_PER_PAGE - 1);
+                let array = self.blocks[first_block].words.as_mut_ptr() as *mut u32;
+                let rank = array.add(block & (BLOCKS_PER_PAGE - 1));
+                rank.write(ones);
+            }
+            for i in WORDS_PER_BLOCK / 2..WORDS_PER_BLOCK {
+                ones += self.blocks[block].words[i].count_ones();
+            }
+        }
+    }
+
+    pub(crate) fn get_word_suffix(&self, i: usize) -> Word {
+        let block = self.bit_to_block(i);
+        let word = (i / WORD_BITS) & (WORDS_PER_BLOCK - 1);
+        // Shift by the bit offset within the word (as in get_word below).
+        let bit = i % WORD_BITS;
+        self.blocks[block].words[word] >> bit
+    }
+
+    pub(crate) fn get_word(&self, i: usize) -> Word {
+        let block = self.bit_to_block(i);
+        let word = (i / WORD_BITS) & (WORDS_PER_BLOCK - 1);
+        let bit = i % WORD_BITS;
+        let first_part = self.blocks[block].words[word] >> bit;
+        if bit == 0 {
+            first_part
+        } else {
+            let block = self.bit_to_block(i + 63);
+            let word = ((i + 63) / WORD_BITS) & (WORDS_PER_BLOCK - 1);
+            first_part | (self.blocks[block].words[word] << (WORD_BITS - bit))
+        }
+    }
+
+    pub(crate) fn get_bit(&self, bit: usize) -> bool {
+        let block = self.bit_to_block(bit);
+        let word = (bit / WORD_BITS) & (WORDS_PER_BLOCK - 1);
+        let bit_in_word = bit & (WORD_BITS - 1);
+        self.blocks[block].words[word] & (1 << bit_in_word) != 0
+    }
+
+    pub(crate) fn get_nibble(&self, nibble_idx: usize) -> u32 {
+        let bit_idx = nibble_idx * 4;
+        let block = self.bit_to_block(bit_idx);
+        let word = (bit_idx / WORD_BITS) & (WORDS_PER_BLOCK - 1);
+        let bit_in_word = bit_idx & (WORD_BITS - 1);
+        ((self.blocks[block].words[word] >> bit_in_word) & 15) as u32
+    }
+
+    fn len(&self) -> usize {
+        self.block_mapping.len() * BLOCK_BITS
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Instant;
+
+    use rand::{seq::SliceRandom, thread_rng, RngCore};
+
+    use super::{VirtualBitRank, BLOCK_BITS, WORD_BITS};
+
+    #[test]
+    fn test_bitrank() {
+        let mut bitrank = VirtualBitRank::with_capacity(1 << 20);
+        let mut rank = vec![];
+        let mut bits = vec![];
+        let mut ones = 0;
+        for i in 0..bitrank.len() {
+            let bit = thread_rng().next_u32() % 2 == 1;
+            rank.push(ones);
+            bits.push(bit);
+            if bit {
+                bitrank.set(i);
+                ones += 1;
+            }
+        }
+        bitrank.build();
+        for (i, bit) in bits.iter().enumerate() {
+            assert_eq!(bitrank.get_bit(i), *bit, "at position {i}");
+        }
+        for (i, r) in rank.iter().enumerate() {
+            assert_eq!(bitrank.rank(i), *r, "at position {i}");
+        }
+    }
+
+    /// This test emulates the reordering of blocks on disk.
+    /// With a real NVMe, the performance difference should be much larger.
+    /// But even this basic test shows that the access pattern matters,
+    /// i.e. throughput is 15% higher when the access order is non-random.
+    #[test]
+    fn test_random_order() {
+        let mut bitrank = VirtualBitRank::with_capacity(1 << 20);
+        let random_bits: Vec<_> = (0..bitrank.len() / WORD_BITS)
+            .map(|_| thread_rng().next_u32())
+            .flat_map(|i| {
+                [i, i + 1].into_iter().map(|i| {
+                    (i & !(BLOCK_BITS - 1) as u32) % bitrank.len() as u32
+                    //(i & !(BLOCK_BITS - 1) as u32 + BLOCK_BITS as u32 / 2) % bitrank.len() as u32
+                    // i % bitrank.len() as u32
+                })
+            })
+            .collect();
+        for i in &random_bits {
+            bitrank.set(*i as usize);
+        }
+        bitrank.build();
+
+        let mut shuffled_bits = random_bits.clone();
+        shuffled_bits.shuffle(&mut thread_rng());
+
+        for _ in 0..4 {
+            let time = Instant::now();
+            for i in &random_bits {
+                assert!(bitrank.get_bit(*i as usize), "at position {i}");
+            }
+            println!(
+                "time to check random bits: {:?} {:?}",
+                time.elapsed(),
+                time.elapsed() * 100 / random_bits.len() as u32
+            );
+
+            let time = Instant::now();
+            for i in &shuffled_bits {
+                assert!(bitrank.get_bit(*i as usize), "at position {i}");
+            }
+            println!(
+                "time to check shuffled bits: {:?} {:?}",
+                time.elapsed(),
+                time.elapsed() * 100 / random_bits.len() as u32
+            );
+        }
+    }
+}
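+
+// A minimal usage sketch (illustrative addition, not part of the original
+// change): build a small VirtualBitRank nibble by nibble and read it back.
+#[cfg(test)]
+mod usage_sketch {
+    use super::VirtualBitRank;
+
+    #[test]
+    fn nibble_roundtrip() {
+        let mut bitrank = VirtualBitRank::with_capacity(1 << 12);
+        // Nibble i (bits 4*i .. 4*i + 4) is set to the 4-bit pattern i & 15.
+        for i in 0..16usize {
+            bitrank.set_nibble(i, (i & 15) as u32);
+        }
+        // build() fills in the mid-block rank counters.
+        bitrank.build();
+        for i in 0..16usize {
+            assert_eq!(bitrank.get_nibble(i), (i & 15) as u32);
+        }
+        // No set bits precede position 0.
+        assert_eq!(bitrank.rank(0), 0);
+    }
+}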