|
| 1 | +//! Benchmarking tool for hriblt set reconciliation. |
| 2 | +//! |
| 3 | +//! This tool runs trials to measure the success rate of decoding set differences. |
| 4 | +
|
| 5 | +use std::{collections::HashSet, ops::Range, str::FromStr}; |
| 6 | + |
| 7 | +use clap::{Parser, ValueEnum}; |
| 8 | +use rand::prelude::*; |
| 9 | + |
| 10 | +use hriblt::{DecodedValue, DecodingSession, DefaultHashFunctions, EncodingSession}; |
| 11 | + |
/// A diff size specification that can be a single value or a range.
///
/// Whatever syntax was used on the command line, the value is stored as a
/// non-empty half-open range: `"7"` becomes `7..8`, `"1..10"` stays `1..10`
/// and `"1..=10"` becomes `1..11`.
#[derive(Debug, Clone)]
struct DiffSizeSpec {
    range: Range<u32>,
}

impl FromStr for DiffSizeSpec {
    type Err = String;

    /// Parses `N`, `A..B` (end exclusive) or `A..=B` (end inclusive).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when a number does not parse as
    /// `u32`, when the range would be empty, or when an inclusive end /
    /// single value is `u32::MAX` and therefore cannot be expressed as an
    /// exclusive upper bound.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        /// Parses one numeric component, labelling the error with `what`.
        fn parse_number(text: &str, what: &str) -> Result<u32, String> {
            text.parse()
                .map_err(|_| format!("invalid {}: {}", what, text))
        }

        // "..=" must be checked before ".." because the latter is a prefix of it.
        if let Some((start, end)) = s.split_once("..=") {
            let start = parse_number(start, "range start")?;
            let end = parse_number(end, "range end")?;
            if start > end {
                return Err(format!("range start {} must be <= end {}", start, end));
            }
            // `end + 1` would overflow for u32::MAX; reject instead of wrapping.
            let end_exclusive = end
                .checked_add(1)
                .ok_or_else(|| format!("range end {} is too large", end))?;
            return Ok(DiffSizeSpec {
                range: start..end_exclusive,
            });
        }

        if let Some((start, end)) = s.split_once("..") {
            let start = parse_number(start, "range start")?;
            let end = parse_number(end, "range end")?;
            if start >= end {
                return Err(format!("range start {} must be < end {}", start, end));
            }
            return Ok(DiffSizeSpec { range: start..end });
        }

        // Otherwise parse as a single value, stored as a one-element range.
        let val = parse_number(s, "diff size")?;
        let end_exclusive = val
            .checked_add(1)
            .ok_or_else(|| format!("diff size {} is too large", val))?;
        Ok(DiffSizeSpec {
            range: val..end_exclusive,
        })
    }
}
| 44 | + |
| 45 | +/// How to iterate through the diff size range. |
| 46 | +#[derive(Debug, Clone, Copy, Default, ValueEnum)] |
| 47 | +enum DiffSizeMode { |
| 48 | + /// Pick a random value from the range for each trial |
| 49 | + #[default] |
| 50 | + Random, |
| 51 | + /// Iterate through the range incrementally, looping if needed |
| 52 | + Incremental, |
| 53 | +} |
| 54 | + |
/// Stateful source of per-trial diff sizes; one variant per selection mode.
enum DiffSizeIter {
    /// Sizes are drawn from `range` using the caller-supplied RNG.
    Random { range: Range<u32> },
    /// Sizes walk through `range` in order; `current` is the value that will
    /// be handed out next.
    Incremental { range: Range<u32>, current: u32 },
}
| 65 | + |
| 66 | +impl DiffSizeIter { |
| 67 | + fn new(spec: &DiffSizeSpec, mode: DiffSizeMode) -> Self { |
| 68 | + match mode { |
| 69 | + DiffSizeMode::Random => DiffSizeIter::Random { |
| 70 | + range: spec.range.clone(), |
| 71 | + }, |
| 72 | + DiffSizeMode::Incremental => DiffSizeIter::Incremental { |
| 73 | + range: spec.range.clone(), |
| 74 | + current: spec.range.start, |
| 75 | + }, |
| 76 | + } |
| 77 | + } |
| 78 | + |
| 79 | + fn next_diff_size<R: Rng + ?Sized>(&mut self, rng: &mut R) -> u32 { |
| 80 | + match self { |
| 81 | + DiffSizeIter::Random { range } => rng.random_range(range.clone()), |
| 82 | + DiffSizeIter::Incremental { range, current } => { |
| 83 | + let val = *current; |
| 84 | + *current += 1; |
| 85 | + if *current >= range.end { |
| 86 | + *current = range.start; |
| 87 | + } |
| 88 | + val |
| 89 | + } |
| 90 | + } |
| 91 | + } |
| 92 | +} |
| 93 | + |
| 94 | +#[derive(Parser, Debug)] |
| 95 | +#[command(name = "hriblt-bench")] |
| 96 | +#[command(about = "Run reconciliation trials to measure decoding success rate")] |
| 97 | +struct Args { |
| 98 | + /// Number of trials to run |
| 99 | + #[arg(short, long, default_value_t = 100)] |
| 100 | + trials: u32, |
| 101 | + |
| 102 | + /// Size of each set (number of elements) |
| 103 | + #[arg(short, long, default_value_t = 1000)] |
| 104 | + set_size: u32, |
| 105 | + |
| 106 | + /// Number of differences between the sets (single value or range like "1..10" or "1..=10") |
| 107 | + #[arg(short, long, default_value = "10")] |
| 108 | + diff_size: DiffSizeSpec, |
| 109 | + |
| 110 | + /// How to select diff sizes from a range |
| 111 | + #[arg(long, value_enum, default_value_t = DiffSizeMode::Random)] |
| 112 | + diff_mode: DiffSizeMode, |
| 113 | + |
| 114 | + /// Multiplier for max symbols to try (max_symbols = diff_size * multiplier) |
| 115 | + #[arg(short, long, default_value_t = 10)] |
| 116 | + multiplier: u32, |
| 117 | + |
| 118 | + /// Random seed (optional, for reproducibility) |
| 119 | + #[arg(long)] |
| 120 | + seed: Option<u64>, |
| 121 | + |
| 122 | + /// Print each trial as a TSV row to stdout |
| 123 | + #[arg(long)] |
| 124 | + tsv: bool, |
| 125 | +} |
| 126 | + |
/// Outcome of one reconciliation trial.
struct TrialResult {
    /// Whether the trial decoded the expected set difference.
    success: bool,
    /// Number of coded symbols the decoder had consumed when it finished;
    /// `None` when decoding never finished within the symbol budget.
    coded_symbols: Option<usize>,
}
| 132 | + |
| 133 | +fn run_trial<R: Rng + ?Sized>( |
| 134 | + rng: &mut R, |
| 135 | + set_size: u32, |
| 136 | + diff_size: u32, |
| 137 | + max_symbols: usize, |
| 138 | +) -> TrialResult { |
| 139 | + // Ensure we have at least 32 symbols to work with |
| 140 | + let max_symbols = max_symbols.max(32); |
| 141 | + |
| 142 | + // Generate base set of random u64 values |
| 143 | + let base_set: HashSet<u64> = (0..set_size).map(|_| rng.random()).collect(); |
| 144 | + |
| 145 | + // Create set A as the base set |
| 146 | + let set_a: Vec<u64> = base_set.iter().copied().collect(); |
| 147 | + |
| 148 | + // Create set B by removing some elements and adding new ones |
| 149 | + let mut set_b: HashSet<u64> = base_set.clone(); |
| 150 | + |
| 151 | + // Remove diff_size/2 elements from set B |
| 152 | + let removals = diff_size / 2; |
| 153 | + let additions = diff_size - removals; |
| 154 | + |
| 155 | + let mut to_remove: Vec<u64> = set_b.iter().copied().collect(); |
| 156 | + to_remove.shuffle(rng); |
| 157 | + for val in to_remove.into_iter().take(removals as usize) { |
| 158 | + set_b.remove(&val); |
| 159 | + } |
| 160 | + |
| 161 | + // Add diff_size - removals new elements to set B |
| 162 | + for _ in 0..additions { |
| 163 | + loop { |
| 164 | + let new_val: u64 = rng.random(); |
| 165 | + if !base_set.contains(&new_val) && set_b.insert(new_val) { |
| 166 | + break; |
| 167 | + } |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + let set_b: Vec<u64> = set_b.into_iter().collect(); |
| 172 | + |
| 173 | + // Create encoding sessions for both sets with max capacity |
| 174 | + let state = DefaultHashFunctions; |
| 175 | + |
| 176 | + let mut encoder_a = EncodingSession::new(state, 0..max_symbols); |
| 177 | + encoder_a.extend(set_a.iter().copied()); |
| 178 | + |
| 179 | + let mut encoder_b = EncodingSession::new(state, 0..max_symbols); |
| 180 | + encoder_b.extend(set_b.iter().copied()); |
| 181 | + |
| 182 | + // Merge the two encodings (negated to get the difference) |
| 183 | + let mut merged = encoder_a.merge(encoder_b, true); |
| 184 | + |
| 185 | + // Start with 1x the diff size, grow by 10% until success or max |
| 186 | + let mut current_symbols = (diff_size as usize).max(1); |
| 187 | + let mut decoding_session = DecodingSession::new(state); |
| 188 | + |
| 189 | + while current_symbols <= max_symbols { |
| 190 | + // Split off symbols up to current_symbols |
| 191 | + let chunk_start = decoding_session.consumed_coded_symbols(); |
| 192 | + let chunk_end = current_symbols.min(max_symbols); |
| 193 | + |
| 194 | + if chunk_end > chunk_start { |
| 195 | + let chunk = merged.split_off(chunk_end - chunk_start); |
| 196 | + decoding_session.append(chunk); |
| 197 | + } |
| 198 | + |
| 199 | + if decoding_session.is_done() { |
| 200 | + let coded_symbols = decoding_session.consumed_coded_symbols(); |
| 201 | + // Verify the decoded difference matches expected |
| 202 | + let decoded: HashSet<_> = decoding_session |
| 203 | + .into_decoded_iter() |
| 204 | + .map(|v| match v { |
| 205 | + DecodedValue::Addition(x) | DecodedValue::Deletion(x) => x, |
| 206 | + }) |
| 207 | + .collect(); |
| 208 | + |
| 209 | + return TrialResult { |
| 210 | + success: decoded.len() == diff_size as usize, |
| 211 | + coded_symbols: Some(coded_symbols), |
| 212 | + }; |
| 213 | + } |
| 214 | + |
| 215 | + // Grow by 10%, but at least 1 |
| 216 | + let growth = (current_symbols / 10).max(1); |
| 217 | + current_symbols += growth; |
| 218 | + } |
| 219 | + |
| 220 | + TrialResult { |
| 221 | + success: false, |
| 222 | + coded_symbols: None, |
| 223 | + } |
| 224 | +} |
| 225 | + |
| 226 | +fn main() { |
| 227 | + let args = Args::parse(); |
| 228 | + |
| 229 | + let mut rng: Box<dyn RngCore> = match args.seed { |
| 230 | + Some(seed) => Box::new(StdRng::seed_from_u64(seed)), |
| 231 | + None => Box::new(rand::rng()), |
| 232 | + }; |
| 233 | + |
| 234 | + let is_range = args.diff_size.range.end - args.diff_size.range.start > 1; |
| 235 | + let range_desc = if is_range { |
| 236 | + format!("{}..{}", args.diff_size.range.start, args.diff_size.range.end) |
| 237 | + } else { |
| 238 | + format!("{}", args.diff_size.range.start) |
| 239 | + }; |
| 240 | + |
| 241 | + eprintln!("Running {} trials...", args.trials); |
| 242 | + eprintln!(" Set size: {}", args.set_size); |
| 243 | + eprintln!(" Diff size: {} ({:?})", range_desc, args.diff_mode); |
| 244 | + eprintln!(" Max symbols multiplier: {}x", args.multiplier); |
| 245 | + eprintln!(); |
| 246 | + |
| 247 | + if args.tsv { |
| 248 | + println!("trial\tset_size\tdiff_size\tsuccess\tcoded_symbols\toverhead"); |
| 249 | + } |
| 250 | + |
| 251 | + let mut diff_iter = DiffSizeIter::new(&args.diff_size, args.diff_mode); |
| 252 | + |
| 253 | + let mut successes = 0; |
| 254 | + let mut failures = 0; |
| 255 | + |
| 256 | + for i in 0..args.trials { |
| 257 | + let diff_size = diff_iter.next_diff_size(&mut *rng); |
| 258 | + let max_symbols = (diff_size * args.multiplier) as usize; |
| 259 | + let result = run_trial(&mut *rng, args.set_size, diff_size, max_symbols); |
| 260 | + |
| 261 | + if args.tsv { |
| 262 | + let coded_symbols_str = result |
| 263 | + .coded_symbols |
| 264 | + .map(|n| n.to_string()) |
| 265 | + .unwrap_or_default(); |
| 266 | + let overhead_str = result |
| 267 | + .coded_symbols |
| 268 | + .map(|n| { |
| 269 | + if diff_size > 0 { |
| 270 | + format!("{:.2}", n as f64 / diff_size as f64) |
| 271 | + } else { |
| 272 | + String::new() |
| 273 | + } |
| 274 | + }) |
| 275 | + .unwrap_or_default(); |
| 276 | + println!( |
| 277 | + "{}\t{}\t{}\t{}\t{}\t{}", |
| 278 | + i + 1, |
| 279 | + args.set_size, |
| 280 | + diff_size, |
| 281 | + result.success, |
| 282 | + coded_symbols_str, |
| 283 | + overhead_str |
| 284 | + ); |
| 285 | + } |
| 286 | + |
| 287 | + if result.success { |
| 288 | + successes += 1; |
| 289 | + } else { |
| 290 | + failures += 1; |
| 291 | + if !args.tsv && failures <= 5 { |
| 292 | + eprintln!("Trial {} failed (diff_size={})", i + 1, diff_size); |
| 293 | + } |
| 294 | + } |
| 295 | + } |
| 296 | + |
| 297 | + eprintln!(); |
| 298 | + eprintln!("Results:"); |
| 299 | + eprintln!(" Successes: {}/{} ({:.1}%)", successes, args.trials, |
| 300 | + 100.0 * successes as f64 / args.trials as f64); |
| 301 | + eprintln!(" Failures: {}/{} ({:.1}%)", failures, args.trials, |
| 302 | + 100.0 * failures as f64 / args.trials as f64); |
| 303 | +} |
0 commit comments