diff --git a/tiktoken-rs/src/vendor_tiktoken.rs b/tiktoken-rs/src/vendor_tiktoken.rs index fa65695..35dc472 100644 --- a/tiktoken-rs/src/vendor_tiktoken.rs +++ b/tiktoken-rs/src/vendor_tiktoken.rs @@ -223,7 +223,10 @@ impl CoreBPE { &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] } - fn _decode_native(&self, tokens: &[usize]) -> Vec { + /// Given a list of tokens, return a vector of bytes + /// + /// The output is NOT guaranteed to be valid UTF-8 + pub fn _decode_native(&self, tokens: &[usize]) -> Vec { let mut ret = Vec::with_capacity(tokens.len() * 2); for token in tokens { let token_bytes = self @@ -235,12 +238,11 @@ impl CoreBPE { ret } - #[allow(clippy::needless_lifetimes)] // the iterator captures a lifetime outside of the function - fn _decode_native_and_split<'a>( - &'a self, + pub fn _decode_native_and_split( + &self, tokens: Vec, ) -> impl Iterator> + '_ { - tokens.into_iter().map(move |token| { + tokens.into_iter().map(|token| { let token_bytes = self .decoder .get(&token) @@ -549,6 +551,9 @@ impl CoreBPE { // Decoding // ==================== + /// Decode a vector of tokens into a valid UTF-8 String + /// + /// If unicode validation is not wanted, see _decode_native. pub fn decode(&self, tokens: Vec) -> Result { match String::from_utf8(self._decode_native(&tokens)) { Ok(text) => Ok(text),