diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index bed8ec2..c1df5b6 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -2,7 +2,6 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
-use std::time::Instant;
 
 use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
@@ -181,32 +180,32 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
-        let start = Instant::now();
-        println!("loaded tiktoken: {:?}", start.elapsed());
+        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    }
+
+    /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        for i in 0..num_tokens {
-            let token = tiktoken_bpe._decode_native(&[i]);
+        for (i, token) in iter.enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
         assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
-        println!("copied tokens: {:?}", start.elapsed());
 
         let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
             .match_kind(daachorse::MatchKind::LeftmostLongest)
             .build(token_iter(&all_tokens, &token_starts))
             .expect("failed to build AhoCorasick");
-        println!("constructed longest searcher: {:?}", start.elapsed());
 
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
-        println!("constructed overlapping searcher: {:?}", start.elapsed());
         let overlapping_searcher_rev =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
                 .expect("");
@@ -216,7 +215,6 @@ impl BytePairEncoding {
                 next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX)
             })
             .collect();
-        println!("constructed next_prefix_match: {:?}", start.elapsed());
 
         let mut split_table = vec![];
         let mut pair_lookup = FnvHashMap::default();
@@ -243,8 +241,6 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        println!("constructed split table: {:?}", start.elapsed());
-
         Self {
             all_tokens,
             token_starts,
@@ -339,12 +335,35 @@ impl BytePairEncoding {
         last_token
     }
 
+    /// Counts the number of tokens produced when encoding the text.
     pub fn count(&self, text: &[u8]) -> usize {
         let mut enc = BacktrackEncoder::new(self, text);
         while enc.step().is_some() {}
         enc.count()
     }
 
+    /// Returns the token count iff the total token count stays within the specified `token_limit`.
+    /// Otherwise, it returns `None`.
+    /// This function can be faster than `count` when `token_limit` is much smaller than the text's total token count.
+    pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
+        let mut enc = BacktrackEncoder::new(self, text);
+        // When the text has exactly the desired number of tokens, it could in theory happen that
+        // the token_limit is exceeded before the end of the text is reached (while a different
+        // encoding is still being tested). To be on the "safe" side, we add a little buffer for such cases.
+        // TODO: Determine exactly how large this buffer must be in the worst case.
+        let limit_with_buffer = token_limit.saturating_add(10);
+        while enc.step().is_some() {
+            if enc.count() > limit_with_buffer {
+                return None;
+            }
+        }
+        if enc.count() <= token_limit {
+            Some(enc.count())
+        } else {
+            None
+        }
+    }
+
     pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
         let last_token = self.encode_all_prefixes(text);
         let mut encoded = Vec::with_capacity(text.len() / 3);
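
A minimal usage sketch of the new `count_till_limit` API next to the existing `count`. The `cl100k()` accessor name and the `bpe::byte_pair_encoding` import path are assumptions; the diff context only shows a function body returning `&BPE_CL100K`:

```rust
use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    // Assumed accessor for the static cl100k instance (the name is hypothetical;
    // only the `&BPE_CL100K` return expression is visible in the diff context).
    let bpe = BytePairEncoding::cl100k();
    let text = "The quick brown fox jumps over the lazy dog".as_bytes();

    // `count` always encodes the entire input.
    let total = bpe.count(text);
    println!("total tokens: {total}");

    // `count_till_limit` can abort early once the running count exceeds the
    // limit plus the internal safety buffer, which is cheaper than `count`
    // when the limit is much smaller than the text's total token count.
    match bpe.count_till_limit(text, 8) {
        Some(n) => println!("fits the budget: {n} tokens"),
        None => println!("over the budget of 8 tokens"),
    }
}
```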
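The refactoring also makes `from_tiktoken` a thin adapter over the new `from_dictionary` constructor, so an encoder can be built from any token enumeration. A sketch with a hand-rolled dictionary, assuming the usual BPE invariant that every multi-byte token splits into two earlier tokens (single-byte tokens fall back to the `(id, id)` split-table entry visible in the diff):

```rust
use bpe::byte_pair_encoding::BytePairEncoding;

/// Illustrative dictionary: all 256 single-byte tokens followed by one merge.
/// Order matters: a merged token must appear after the tokens it splits into.
fn tiny_dictionary() -> impl Iterator<Item = Vec<u8>> {
    (0u8..=255).map(|b| vec![b]).chain(std::iter::once(b"ab".to_vec()))
}

fn main() {
    let bpe = BytePairEncoding::from_dictionary(tiny_dictionary());
    // With the `ab` merge present, "abab" should encode as two tokens
    // rather than four single-byte tokens.
    println!("tokens for \"abab\": {}", bpe.count(b"abab"));
}
```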