switch to faster aho corasick library. ~30-50% speedup
aneubeck committed Jul 16, 2024
1 parent 1ce7e29 commit e266d3c
Showing 3 changed files with 13 additions and 28 deletions.
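The commit swaps the aho-corasick crate for daachorse, which implements the Aho-Corasick algorithm on a compact double-array trie; the denser state layout is presumably where the quoted speedup comes from. A minimal sketch of the migrated API, with a hypothetical toy token set standing in for the real vocabulary (daachorse 1.x assumed; this is not the crate's actual setup code):

    use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder, MatchKind};

    fn main() {
        // Toy tokens; the crate builds the real searchers from the
        // cl100k_base vocabulary via token_iter().
        let tokens: Vec<&[u8]> = vec![b"a", b"ab", b"abc"];

        // Leftmost-longest searcher, the analogue of aho-corasick's
        // MatchKind::LeftmostLongest.
        let longest: DoubleArrayAhoCorasick<u32> = DoubleArrayAhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(&tokens)
            .expect("failed to build automaton");

        // Standard searcher for overlapping lookups; MatchKind::Standard is
        // the default.
        let overlapping = DoubleArrayAhoCorasick::<u32>::new(&tokens)
            .expect("failed to build automaton");

        let m = longest.leftmost_find_iter(b"abcd").next().unwrap();
        assert_eq!((m.start(), m.end(), m.value()), (0, 3, 2)); // "abc" is pattern 2

        let all: Vec<u32> = overlapping
            .find_overlapping_iter(b"ab")
            .map(|m| m.value())
            .collect();
        assert_eq!(all, vec![0, 1]); // "a", then "ab"
    }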
2 changes: 1 addition & 1 deletion crates/bpe/Cargo.toml

@@ -8,7 +8,7 @@ crate-type = ["lib", "staticlib"]
 bench = false
 
 [dependencies]
-aho-corasick = "1"
+daachorse = "1"
 fnv = "1.0"
 itertools = "0.12"
 rand = "0.8"
38 changes: 12 additions & 26 deletions crates/bpe/src/byte_pair_encoding.rs

@@ -4,9 +4,7 @@ use std::hash::{Hash, Hasher};
 use std::ops::Range;
 use std::time::Instant;
 
-use aho_corasick::{
-    AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, Anchored, Input, MatchKind, StartKind,
-};
+use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
 use itertools::Itertools;
 use tiktoken_rs::cl100k_base;
@@ -31,9 +29,9 @@ pub struct BytePairEncoding {
     /// Mapping from a pair of tokens to a merged token if such a merged token exists.
     pair_lookup: FnvHashMap<(u32, u32), u32>,
     /// An aho corasick automaton to find the next longest token in a byte sequence.
-    longest_searcher: AhoCorasick,
+    longest_searcher: DoubleArrayAhoCorasick<u32>,
     /// An aho corasick automaton to find ALL tokens in a byte sequence.
-    overlapping_searcher: AhoCorasick,
+    overlapping_searcher: DoubleArrayAhoCorasick<u32>,
     /// Mapping from a token to the next longest prefix token.
     /// This is in principle information represented by the AhoCorasick automaton.
     /// But we don't have efficient access to it and therefore store it here again.
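A note on the new field type: DoubleArrayAhoCorasick<u32> attaches a u32 value to every pattern. When the automaton is built from a plain pattern iterator, as it is in this file, each value defaults to the pattern's 0-based index, so Match::value() takes over the role of aho-corasick's Match::pattern().as_u32() as the token id. A small sketch of that convention (toy patterns, daachorse 1.x assumed):

    use daachorse::DoubleArrayAhoCorasick;

    fn main() {
        // Built with `new`, each pattern's value is its index in the input
        // iterator.
        let pma = DoubleArrayAhoCorasick::<u32>::new(["foo", "bar"]).expect("build failed");
        let m = pma.find_iter("xbary").next().unwrap();
        assert_eq!(m.value(), 1); // "bar" was the second pattern
    }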
@@ -48,11 +46,10 @@ fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterato
         .map(|(start, end)| &all_tokens[*start as usize..*end as usize])
 }
 
-fn next_match(longest_searcher: &AhoCorasick, text: &[u8]) -> Option<u32> {
+fn next_match(longest_searcher: &DoubleArrayAhoCorasick<u32>, text: &[u8]) -> Option<u32> {
     longest_searcher
-        .find_iter(Input::new(text).anchored(Anchored::Yes))
-        .filter(|m| !m.span().is_empty())
-        .map(|m| m.pattern().as_u32())
+        .leftmost_find_iter(text)
+        .map(|m| m.value())
         .next()
 }
 
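aho-corasick's anchored start kind has no direct daachorse counterpart, so the rewritten next_match relies on the same property the deleted comment in the next hunk invokes: the token set contains every single byte, so the leftmost match always begins at offset 0, and leftmost-longest matching then yields exactly the longest token prefix. daachorse also rejects empty patterns at build time (to my understanding), which is why the old is_empty() filter needs no replacement. A toy check of the anchoring argument (hypothetical token set):

    use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder, MatchKind};

    fn next_match(longest_searcher: &DoubleArrayAhoCorasick<u32>, text: &[u8]) -> Option<u32> {
        longest_searcher
            .leftmost_find_iter(text)
            .map(|m| m.value())
            .next()
    }

    fn main() {
        // "a" and "b" stand in for all 256 single-byte tokens of the real
        // vocabulary, so every input has a match starting at offset 0.
        let tokens: Vec<&[u8]> = vec![b"a", b"b", b"ba", b"bab"];
        let searcher = DoubleArrayAhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(&tokens)
            .expect("build failed");

        // Effectively an anchored longest-prefix lookup: "bab" (pattern 3),
        // not "b" or "ba".
        assert_eq!(next_match(&searcher, b"babb"), Some(3));
        assert_eq!(next_match(&searcher, b"ab"), Some(0));
    }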

@@ -145,25 +142,14 @@ impl BytePairEncoding {
         assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
         println!("copied tokens: {:?}", start.elapsed());
 
-        let mut builder = AhoCorasickBuilder::new();
-        builder.match_kind(MatchKind::LeftmostLongest);
-        // We can set Anchored, since ALL the tokens can cover ANY text input.
-        builder.start_kind(StartKind::Anchored);
-        builder.ascii_case_insensitive(false);
-        builder.prefilter(false);
-        let longest_searcher = builder
+        let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
+            .match_kind(daachorse::MatchKind::LeftmostLongest)
             .build(token_iter(&all_tokens, &token_starts))
             .expect("failed to build AhoCorasick");
         println!("constructed longest searcher: {:?}", start.elapsed());
 
-        let mut builder = AhoCorasickBuilder::new();
-        builder.match_kind(MatchKind::Standard);
-        builder.start_kind(StartKind::Unanchored);
-        builder.prefilter(false);
-        builder.ascii_case_insensitive(false);
-        let overlapping_searcher = builder
-            .build(token_iter(&all_tokens, &token_starts))
-            .expect("failed to build AhoCorasick");
+        let overlapping_searcher =
+            DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("failed to build AhoCorasick");
         println!("constructed overlapping searcher: {:?}", start.elapsed());
 
         let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts)
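Unlike aho-corasick's builder, which is configured through &mut setter calls, DoubleArrayAhoCorasickBuilder is consumed by a method chain; and DoubleArrayAhoCorasick::new is shorthand for the default builder configuration, whose match kind is MatchKind::Standard, exactly what the overlapping searcher needs. A sketch of that equivalence (daachorse 1.x assumed):

    use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder, MatchKind};

    fn main() {
        let patterns = ["ab", "b"];

        // Shorthand with the default configuration...
        let a = DoubleArrayAhoCorasick::<u32>::new(patterns).expect("build failed");

        // ...behaves like an explicit builder using MatchKind::Standard.
        let b: DoubleArrayAhoCorasick<u32> = DoubleArrayAhoCorasickBuilder::new()
            .match_kind(MatchKind::Standard)
            .build(patterns)
            .expect("build failed");

        let on = |pma: &DoubleArrayAhoCorasick<u32>| -> Vec<u32> {
            pma.find_overlapping_iter("ab").map(|m| m.value()).collect()
        };
        assert_eq!(on(&a), on(&b));
    }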
@@ -273,8 +259,8 @@ impl BytePairEncoding {
     pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec<u32> {
         let mut last_token = Vec::with_capacity(text.len());
         for m in self.overlapping_searcher.find_overlapping_iter(text) {
-            let new_token = m.pattern().as_u32();
-            let new_range = m.span();
+            let new_token = m.value();
+            let new_range = m.start()..m.end();
             if new_range.end == last_token.len() {
                 continue;
             }
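daachorse reports matches through its own Match type: value() instead of pattern(), and start()/end() instead of span(). A small sketch of the overlapping iteration used by encode_all_prefixes (toy patterns):

    use daachorse::DoubleArrayAhoCorasick;

    fn main() {
        let pma = DoubleArrayAhoCorasick::<u32>::new(["he", "hell", "ell"])
            .expect("build failed");
        for m in pma.find_overlapping_iter("hello") {
            // value() is the pattern id; start()..end() is the byte range of
            // the match within the haystack.
            println!("pattern {} at {:?}", m.value(), m.start()..m.end());
        }
        // All three patterns are reported: "he", "hell", and "ell".
    }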
1 change: 0 additions & 1 deletion crates/bpe/src/interval_encoding.rs

@@ -54,7 +54,6 @@ impl<'a> IntervalEncoding<'a> {
                 && self.last_token[end_pos - 1] == prev_token
                 && self.last_token[end_pos - 1 + self.bpe.token_len(next_token)] == next_token
             {
-                // println!("encoder counted: {} rest: {}", encoder.count(), (self.tree_depth[range.end] - self.tree_depth[end_pos]));
                 return encoder.count() + (self.tree_depth[range.end] - self.tree_depth[end_pos]) as usize;
             }
         }
