Skip to content

Commit

Permalink
Expose _decode_native methods
Browse files (browse the repository at this point in the history)
  • Loading branch information
zurawiki committed Oct 21, 2023
1 parent e9708b1 commit 664a37f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions tiktoken-rs/src/vendor_tiktoken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ impl CoreBPE {
&self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
}

fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
/// Given a slice of tokens, return the decoded bytes as a vector.
///
/// The output is NOT guaranteed to be valid UTF-8.
pub fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for token in tokens {
let token_bytes = self
Expand All @@ -235,12 +238,11 @@ impl CoreBPE {
ret
}

#[allow(clippy::needless_lifetimes)] // the iterator captures a lifetime outside of the function
fn _decode_native_and_split<'a>(
&'a self,
pub fn _decode_native_and_split(
&self,
tokens: Vec<usize>,
) -> impl Iterator<Item = Vec<u8>> + '_ {
tokens.into_iter().map(move |token| {
tokens.into_iter().map(|token| {
let token_bytes = self
.decoder
.get(&token)
Expand Down Expand Up @@ -549,6 +551,9 @@ impl CoreBPE {
// Decoding
// ====================

/// Decode a vector of tokens into a valid UTF-8 `String`.
///
/// If UTF-8 validation is not wanted, see `_decode_native`.
pub fn decode(&self, tokens: Vec<usize>) -> Result<String> {
match String::from_utf8(self._decode_native(&tokens)) {
Ok(text) => Ok(text),
Expand Down

0 comments on commit 664a37f

Please sign in to comment.