diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs index 04d4fe6..7882b5f 100644 --- a/crates/bpe/src/byte_pair_encoding.rs +++ b/crates/bpe/src/byte_pair_encoding.rs @@ -166,6 +166,32 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 { ((hasher.finish().wrapping_mul(factor)) >> 32) as u32 } +/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions +/// when constructing a [`BytePairEncoding`] from those tokens. +#[cfg(all(feature = "tiktoken-rs", feature = "rand"))] +pub fn find_hash_factor_from_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 { + find_hash_factor(|i| bpe._decode_native(&[i]), len) +} + +/// Find a suitable hash factor for a set of given tokens that prevents collisions when +/// constructing a [`BytePairEncoding`] from those tokens. +#[cfg(feature = "rand")] +pub fn find_hash_factor(tokens: impl Fn(usize) -> Vec, len: usize) -> u64 { + use std::collections::HashSet; + + use rand::Rng; + + let mut rnd = rand::thread_rng(); + loop { + let factor: u64 = rnd.gen(); + let mut seen = HashSet::with_capacity(len); + if (0..len).all(|i| seen.insert(hash_bytes(&tokens(i), factor))) { + println!("hash factor: {factor}"); + return factor; + } + } +} + fn find_token_by_bytes( all_tokens: &[u8], token_starts: &[u32], @@ -191,8 +217,12 @@ impl BytePairEncoding { &BPE_O200K } - /// Construct a BytePairEncoding instance frmo a tiktoken dictionary. - /// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test. + /// Construct a BytePairEncoding instance from a tiktoken dictionary. + /// A suitable hash factor may be necessary to prevent hash collisions, + /// which can be found using [`find_hash_factor_from_tiktoken`]. + /// + /// The recommended approach is to store the serialized value and reuse that, + /// to prevent repeating the cost of computing the hash factor and encoding. 
#[cfg(feature = "tiktoken-rs")] pub fn from_tiktoken( tiktoken_bpe: &tiktoken_rs::CoreBPE, @@ -205,8 +235,12 @@ impl BytePairEncoding { ) } - /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens. - /// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test. + /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens. + /// A suitable hash factor may be necessary to prevent hash collisions, which can be + /// found using [`find_hash_factor`]. + /// + /// The recommended approach is to store the serialized value and reuse that, + /// to prevent repeating the cost of computing the hash factor and encoding. pub fn from_dictionary(iter: impl Iterator>, hash_factor: Option) -> Self { let hash_factor = hash_factor .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero")) @@ -574,53 +608,25 @@ mod tests { #[cfg(test)] mod data { - use std::collections::HashSet; use std::fs::File; use std::path::PathBuf; - use rand::Rng; use serde::Serialize; - use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE}; - - use super::*; - const BPE_CL100K_LEN: usize = 100256; - const BPE_O200K_LEN: usize = 199998; - - /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions. - /// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for. - /// 2. Update the hash factor in [`hash_bytes`]. - /// 3. Run [`update_token_dicts`] tests below to update data files. - /// Note: If you forget this, the next test run will update the files, but - /// all other tests might fail because the data was not up-to-date. 
- #[test] - #[ignore = "run manually to find a suitable hash factor"] - #[allow(unreachable_code, unused_variables)] - fn find_hash_factor() { - let (bpe, len): (CoreBPE, _) = todo!("replace with BPE instance and token count"); - let mut rnd = rand::thread_rng(); - loop { - let factor: u64 = rnd.gen(); - let mut seen = HashSet::with_capacity(len); - if (0..len).all(|i| seen.insert(hash_bytes(&bpe._decode_native(&[i]), factor))) { - println!("hash factor: {factor}"); - return; - } - } - } + use crate::byte_pair_encoding::BytePairEncoding; #[test] fn update_token_dicts() { serialize_tokens( "cl100k", - &cl100k_base().expect("tiktoken initialization must not fail!"), - BPE_CL100K_LEN, + &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"), + 100256, 17846336922010275747, ); serialize_tokens( "o200k", - &o200k_base().expect("tiktoken initialization must not fail!"), - BPE_O200K_LEN, + &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"), + 199998, 17846336922010275747, ); }