@@ -335,12 +335,35 @@ impl BytePairEncoding {
335335 last_token
336336 }
337337
338+ /// Counts the number tokens produced when encoding the text.
338339 pub fn count ( & self , text : & [ u8 ] ) -> usize {
339340 let mut enc = BacktrackEncoder :: new ( self , text) ;
340341 while enc. step ( ) . is_some ( ) { }
341342 enc. count ( )
342343 }
343344
345+ /// Returns the token count iff the total token count stays below the specified `token_limit`.
346+ /// Otherwise, it returns false.
347+ /// This function can be faster than `count` when the token_limit is much smaller than the provided text.
348+ pub fn count_till_limit ( & self , text : & [ u8 ] , token_limit : usize ) -> Option < usize > {
349+ let mut enc = BacktrackEncoder :: new ( self , text) ;
350+ // When the text has exactly the desired number of tokens, then it could in theory happen that
351+ // the token_limit is exceeded before the end of the text is reached (and a different encoding is tested).
352+ // To be on the "safe" side, we add a little buffer for such cases.
353+ // TODO: Determine exactly how large this buffer must be in the worst case.
354+ let limit_with_buffer = token_limit. saturating_add ( 10 ) ;
355+ while enc. step ( ) . is_some ( ) {
356+ if enc. count ( ) > limit_with_buffer {
357+ return None ;
358+ }
359+ }
360+ if enc. count ( ) <= token_limit {
361+ Some ( enc. count ( ) )
362+ } else {
363+ None
364+ }
365+ }
366+
344367 pub fn encode_via_table ( & self , text : & [ u8 ] ) -> Vec < u32 > {
345368 let last_token = self . encode_all_prefixes ( text) ;
346369 let mut encoded = Vec :: with_capacity ( text. len ( ) / 3 ) ;
0 commit comments