Skip to content

Commit 89ac128

Browse files
committed
add count_till_limit function
1 parent f98cbd4 commit 89ac128

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,35 @@ impl BytePairEncoding {
335335
last_token
336336
}
337337

338+
/// Counts the number tokens produced when encoding the text.
338339
pub fn count(&self, text: &[u8]) -> usize {
339340
let mut enc = BacktrackEncoder::new(self, text);
340341
while enc.step().is_some() {}
341342
enc.count()
342343
}
343344

345+
/// Returns the token count iff the total token count stays below the specified `token_limit`.
346+
/// Otherwise, it returns false.
347+
/// This function can be faster than `count` when the token_limit is much smaller than the provided text.
348+
pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
349+
let mut enc = BacktrackEncoder::new(self, text);
350+
// When the text has exactly the desired number of tokens, then it could in theory happen that
351+
// the token_limit is exceeded before the end of the text is reached (and a different encoding is tested).
352+
// To be on the "safe" side, we add a little buffer for such cases.
353+
// TODO: Determine exactly how large this buffer must be in the worst case.
354+
let limit_with_buffer = token_limit.saturating_add(10);
355+
while enc.step().is_some() {
356+
if enc.count() > limit_with_buffer {
357+
return None;
358+
}
359+
}
360+
if enc.count() <= token_limit {
361+
Some(enc.count())
362+
} else {
363+
None
364+
}
365+
}
366+
344367
pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
345368
let last_token = self.encode_all_prefixes(text);
346369
let mut encoded = Vec::with_capacity(text.len() / 3);

0 commit comments

Comments
 (0)