Skip to content

Commit

Permalink
add constructor from dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
aneubeck committed Aug 16, 2024
1 parent 2797581 commit 1e9fbbb
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::time::Instant;

use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
use fnv::{FnvHashMap, FnvHasher};
Expand Down Expand Up @@ -181,32 +180,32 @@ impl BytePairEncoding {
&BPE_CL100K
}

/// Construct a BytePairEncoding instance from a tiktoken dictionary.
pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
let start = Instant::now();
println!("loaded tiktoken: {:?}", start.elapsed());
Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
}

/// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
pub fn from_dictionary<'a>(iter: impl Iterator<Item = Vec<u8>>) -> Self {
let mut all_tokens = Vec::new();
let mut all_tokens_rev = Vec::new();
let mut token_starts = vec![0];
let mut bytes_hash_to_token = FnvHashMap::default();
for i in 0..num_tokens {
let token = tiktoken_bpe._decode_native(&[i]);
for (i, token) in iter.enumerate() {
bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
all_tokens_rev.extend(token.iter().copied().rev());
all_tokens.extend(token);
token_starts.push(all_tokens.len() as u32);
}
assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
println!("copied tokens: {:?}", start.elapsed());

let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
.match_kind(daachorse::MatchKind::LeftmostLongest)
.build(token_iter(&all_tokens, &token_starts))
.expect("failed to build AhoCorasick");
println!("constructed longest searcher: {:?}", start.elapsed());

let overlapping_searcher =
DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
println!("constructed overlapping searcher: {:?}", start.elapsed());
let overlapping_searcher_rev =
DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
.expect("");
Expand All @@ -216,7 +215,6 @@ impl BytePairEncoding {
next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX)
})
.collect();
println!("constructed next_prefix_match: {:?}", start.elapsed());

let mut split_table = vec![];
let mut pair_lookup = FnvHashMap::default();
Expand All @@ -243,8 +241,6 @@ impl BytePairEncoding {
split_table.push((id as u32, id as u32));
}
}
println!("constructed split table: {:?}", start.elapsed());

Self {
all_tokens,
token_starts,
Expand Down

0 comments on commit 1e9fbbb

Please sign in to comment.