Skip to content

Commit

Permalink
Optimize implementations and make blocks fully configurable.
Browse files Browse the repository at this point in the history
Fastest results are achieved with 32 bytes per block, which results in a 12.5% overhead.
64 bytes per block seems like a reasonable compromise between speed and memory overhead (6.25%).
  • Loading branch information
aneubeck committed Oct 31, 2024
1 parent 16ce6ac commit 2711809
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 90 deletions.
168 changes: 93 additions & 75 deletions crates/quaternary_trie/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,86 +219,103 @@ impl QuarternaryTrie {
if level == 1 {
self.recurse2(node, value, results);
} else {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2)) + i));
return;
}
let mut r = self.data.rank(node * 4) as usize;
if n & 1 != 0 {
while n > 0 {
let delta = n.trailing_zeros();
r += 1;
self.recurse(r, level - 1, value, results);
}
if n & 2 != 0 {
r += 1;
self.recurse(r, level - 1, value + 1, results);
}
if n & 4 != 0 {
r += 1;
self.recurse(r, level - 1, value + 2, results);
}
if n & 8 != 0 {
r += 1;
self.recurse(r, level - 1, value + 3, results);
self.recurse(r, level - 1, value + delta, results);
value += delta + 1;
n >>= delta + 1;
}
}
}

#[inline(always)]
fn recurse2(&self, node: usize, value: u32, results: &mut Vec<u32>) {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..16).map(|i| (value << 2) + i));
return;
}
let mut r = self.data.rank(node * 4) as usize;
if n & 1 != 0 {
r += 1;
self.recurse0(r, value, results);
}
if n & 2 != 0 {
r += 1;
self.recurse0(r, value + 1, results);
}
if n & 4 != 0 {
r += 1;
self.recurse0(r, value + 2, results);
}
if n & 8 != 0 {
while n > 0 {
let delta = n.trailing_zeros();
r += 1;
self.recurse0(r, value + 3, results);
self.recurse0(r, value + delta, results);
value += delta + 1;
n >>= delta + 1;
}
}

#[inline(always)]
fn recurse0(&self, node: usize, value: u32, results: &mut Vec<u32>) {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..4).map(|i| value + i));
return;
}
if n & 1 != 0 {
results.push(value);
}
if n & 2 != 0 {
results.push(value + 1);
}
if n & 4 != 0 {
results.push(value + 2);
}
if n & 8 != 0 {
results.push(value + 3);
while n > 0 {
let delta = n.trailing_zeros();
results.push(value + delta);
value += delta + 1;
n >>= delta + 1;
}
}

pub fn collect(&self) -> Vec<u32> {
/// The "slow" implementation: at every level it computes the rank and
/// extracts the corresponding nibble.
pub fn collect2(&self) -> Vec<u32> {
    let mut out = Vec::with_capacity(self.level_idx[0]);
    self.recurse(0, MAX_LEVEL - 1, 0, &mut out);
    out
}

// This is the "fastest" implementation, since it doesn't use rank information at all during the traversal.
// This is possible, since it iterates through ALL nodes and thus we can simply increment the positions by 1.
// We only need the rank information to initialize the positions.
// The only remaining "expensive" part here is the lookup of the nibble with every iteration.
// This lookup requires the slightly complicated conversion from position into block pointer (either via the virtual mapping or via some math).
// The math would be trivial if we didn't store the counters within the same page...
// Instead one could try to cache a u64 value and keep shifting until the end is reached, or work with the pointer into the bitrank array.
pub fn collect(&mut self) -> Vec<u32> {
    // Initialize the per-level cursors: the first node position of each level
    // is derived from the rank at the start position of the level above.
    // NOTE: this overwrites `level_idx`, which serves as scratch cursors here.
    self.level_idx[MAX_LEVEL - 1] = 0;
    // Ranges are iterators already; no `.into_iter()` needed before `.rev()`.
    for level in (1..MAX_LEVEL).rev() {
        self.level_idx[level - 1] = self.data.rank(self.level_idx[level] * 4) as usize + 1;
    }
    let mut results = Vec::new();
    self.fast_collect_inner(MAX_LEVEL - 1, 0, &mut results);
    results
}

/// Depth-first emission of all stored values, consuming the per-level cursors
/// set up by `collect`. Every node of a level is visited exactly once, so each
/// cursor only ever advances by one per visit.
fn fast_collect_inner(&mut self, level: usize, value: u32, results: &mut Vec<u32>) {
    // Fetch the nibble of the next node on this level and advance the cursor.
    let mut nibble = self.data.get_nibble(self.level_idx[level]);
    self.level_idx[level] += 1;
    // Each level consumes two bits of the output value; fold this node's
    // position into the prefix before emitting anything.
    let mut value = value * 4;
    if nibble == 0 {
        // A zero nibble encodes a fully populated subtree with no stored
        // children: emit its whole value range. The base must use the
        // already-multiplied prefix — this matches recurse/recurse2/recurse0;
        // the previous form used the un-multiplied prefix and was off by a
        // factor of 4 whenever the prefix was non-zero.
        results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2)) + i));
        return;
    }
    if level == 0 {
        // Leaf nibble: each set bit contributes one value.
        while nibble > 0 {
            let delta = nibble.trailing_zeros();
            results.push(value + delta);
            value += delta + 1;
            nibble >>= delta + 1;
        }
    } else {
        // Inner nibble: descend into each present child subtree.
        while nibble > 0 {
            let delta = nibble.trailing_zeros();
            self.fast_collect_inner(level - 1, value + delta, results);
            value += delta + 1;
            nibble >>= delta + 1;
        }
    }
}
}

pub trait TrieIteratorTrait {
Expand Down Expand Up @@ -328,53 +345,54 @@ impl TrieIteratorTrait for TrieTraversal<'_> {
fn down(&mut self, level: usize, child: u32) {
let index = self.pos[level] * 4 + child;
let new_index = self.trie.data.rank(index as usize + 1);
self.pos[level - 1] = new_index as u32;
self.pos[level - 1] = new_index;
}
}

pub struct TrieIterator<T> {
trie: T,
level: usize,
item: u32,
nibbles: [u32; MAX_LEVEL],
}

impl<T: TrieIteratorTrait> TrieIterator<T> {
pub fn new(trie: T) -> Self {
let mut result = Self {
Self {
trie,
level: MAX_LEVEL - 1,
item: 0,
nibbles: [0; MAX_LEVEL],
};
result.nibbles[result.level] = result.trie.get(result.level);
result
}
}
}

impl<'a, T: TrieIteratorTrait> Iterator for TrieIterator<T> {
type Item = u32;

fn next(&mut self) -> Option<u32> {
while self.level < MAX_LEVEL {
let child = (self.item >> (2 * self.level)) & 3;
let nibble = self.nibbles[self.level] >> child;
let mut level = if self.item == 0 {
self.nibbles[MAX_LEVEL - 1] = self.trie.get(MAX_LEVEL - 1);
MAX_LEVEL - 1
} else {
(self.item.trailing_zeros() / 2) as usize
};
while level < MAX_LEVEL {
let child = (self.item >> (2 * level)) & 3;
let nibble = self.nibbles[level] >> child;
if nibble != 0 {
let delta = nibble.trailing_zeros();
if self.level == 0 {
if level == 0 {
let res = self.item + delta;
self.item = res + 1;
self.level = (self.item.trailing_zeros() / 2) as usize;
return Some(res);
}
self.item += delta << (2 * self.level);
self.trie.down(self.level, child + delta);
self.level -= 1;
self.nibbles[self.level] = self.trie.get(self.level);
self.item += delta << (2 * level);
self.trie.down(level, child + delta);
level -= 1;
self.nibbles[level] = self.trie.get(level);
} else {
self.item |= 3 << (self.level * 2);
self.item += 1 << (self.level * 2);
self.level = (self.item.trailing_zeros() / 2) as usize;
self.item |= 3 << (level * 2);
self.item += 1 << (level * 2);
level = (self.item.trailing_zeros() / 2) as usize;
}
}
None
Expand Down Expand Up @@ -413,13 +431,13 @@ mod tests {
use crate::{Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal};

#[test]
fn test_bpt() {
fn test_trie() {
let values = vec![3, 6, 7, 10];
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);

let values: Vec<_> = (1..63).collect();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);
}

Expand All @@ -432,7 +450,7 @@ mod tests {
values.dedup();

let start = Instant::now();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
println!("construction {:?}", start.elapsed() / values.len() as u32);

let start = Instant::now();
Expand All @@ -450,21 +468,21 @@ mod tests {
#[test]
fn test_van_emde_boas_layout() {
let values: Vec<_> = (0..64).collect();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);
}

#[test]
fn test_intersection() {
let mut page_counts = [0, 0, 0];
for _ in 0..3 {
let mut values: Vec<_> = (0..100000)
let mut values: Vec<_> = (0..10000000)
.map(|_| thread_rng().gen_range(0..100000000))
.collect();
values.sort();
values.dedup();

let mut values2: Vec<_> = (0..100000000)
let mut values2: Vec<_> = (0..10000000)
.map(|_| thread_rng().gen_range(0..100000000))
.collect();
values2.sort();
Expand Down Expand Up @@ -495,7 +513,7 @@ mod tests {
let result: Vec<_> = iter.collect();
let count = trie.page_count();
let count2 = trie2.page_count();
page_counts[i] += count.0 + count.1;
page_counts[i] += count.0 + count2.0;
println!(
"trie intersection {:?} {}",
start.elapsed() / values.len() as u32,
Expand Down
33 changes: 18 additions & 15 deletions crates/quaternary_trie/src/virtual_bitrank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@

use std::cell::RefCell;

const BLOCK_BYTES: usize = 128;
type Word = u64;

const BLOCK_BYTES: usize = 64;
const BLOCK_BITS: usize = BLOCK_BYTES * 8;
const PAGE_BYTES: usize = 4096;
const PAGE_BITS: usize = PAGE_BYTES * 8;
const BLOCKS_PER_PAGE: usize = PAGE_BYTES / BLOCK_BYTES;
const WORD_BITS: usize = 64;
const WORD_BYTES: usize = WORD_BITS / 8;
const BLOCKS_PER_PAGE: usize = BLOCK_BYTES / 4;
const WORD_BITS: usize = WORD_BYTES * 8;
const WORD_BYTES: usize = std::mem::size_of::<Word>();
const WORDS_PER_BLOCK: usize = BLOCK_BYTES / WORD_BYTES;
const PAGE_BYTES: usize = BLOCKS_PER_PAGE * BLOCK_BYTES;
const PAGE_BITS: usize = PAGE_BYTES * 8;
const SUPER_PAGE_BITS: usize = 4096 * 8;

#[repr(C, align(128))]
#[derive(Default, Clone)]
struct Block {
words: [u64; WORDS_PER_BLOCK],
words: [Word; WORDS_PER_BLOCK],
}

#[derive(Default)]
Expand All @@ -77,10 +80,7 @@ impl VirtualBitRank {
}

pub(crate) fn reset_stats(&mut self) {
self.stats = vec![
RefCell::new(0);
((self.blocks.len() + BLOCKS_PER_PAGE - 1) / BLOCKS_PER_PAGE + 63) / 64
];
self.stats = vec![RefCell::new(0); self.blocks.len() * BLOCK_BITS / SUPER_PAGE_BITS + 1];
}

pub(crate) fn page_count(&self) -> (usize, usize) {
Expand All @@ -94,10 +94,13 @@ impl VirtualBitRank {
}

fn bit_to_block(&self, bit: usize) -> usize {
//let block = bit / BLOCK_BITS;
//let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1;
let result = self.block_mapping[bit / BLOCK_BITS] as usize;
/*if let Some(v) = self.stats.get(result / PAGE_BITS / 64) {
//assert_eq!(result2, result);
if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) {
*v.borrow_mut() += 1 << (result % 64);
}*/
}
result
}

Expand Down Expand Up @@ -143,7 +146,7 @@ impl VirtualBitRank {
}
}

fn get_word_mut(&mut self, bit: usize) -> &mut u64 {
fn get_word_mut(&mut self, bit: usize) -> &mut Word {
let block = bit / BLOCK_BITS;
if block >= self.block_mapping.len() {
self.block_mapping.resize(block + 1, 0);
Expand All @@ -168,7 +171,7 @@ impl VirtualBitRank {
let bit_idx = nibble_idx * 4;
// clear all bits...
// *self.get_word(bit_idx) &= !(15 << (bit_idx & (WORD_BITS - 1)));
*self.get_word_mut(bit_idx) |= (nibble_value as u64) << (bit_idx & (WORD_BITS - 1));
*self.get_word_mut(bit_idx) |= (nibble_value as Word) << (bit_idx & (WORD_BITS - 1));
}

pub(crate) fn build(&mut self) {
Expand Down

0 comments on commit 2711809

Please sign in to comment.