add some approximate versions for benchmark comparison
aneubeck committed Jul 17, 2024
1 parent e8b7706 commit 4f0ba1b
Showing 3 changed files with 60 additions and 1 deletion.
16 changes: 15 additions & 1 deletion crates/bpe/benches/counting.rs
@@ -13,7 +13,7 @@ fn counting_benchmark(c: &mut Criterion) {

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
@@ -57,6 +57,20 @@ fn encoding_benchmark(c: &mut Criterion) {
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.encode_greedy(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.encode_minimal(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}

8 changes: 8 additions & 0 deletions crates/bpe/src/backtrack_encoder.rs
@@ -1,6 +1,14 @@
use crate::bitfield::BitField;
use crate::byte_pair_encoding::BytePairEncoding;

/// This can be thought of as a lazy variation of the dynamic programming approach.
/// It only computes those states which have to be visited in order to compute the tokenization
/// for a given input text.
/// It keeps track of visited states in a bitfield and only remembers the tokenization
/// of the currently processed dynamic programming state.
///
/// The biggest downside of this approach is that the search for the longest leftmost match
/// has to be restarted at every (backtracking) step, but in practice it is still a net win compared to other approaches.
pub(crate) struct BacktrackEncoder<'a> {
bpe: &'a BytePairEncoding,
text: &'a [u8],
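
The doc comment above describes a lazy, backtracking variant of dynamic programming. As a rough illustration of that general idea (not the crate's actual encoder, which additionally checks BPE merge compatibility), the following self-contained sketch segments a byte string against a toy dictionary: it always tries the longest match first, backtracks when a suffix cannot be covered, and memoizes failed positions so no state is visited twice. The function names and the toy vocabulary are illustrative assumptions.

use std::collections::HashSet;

/// Segment `text` into dictionary tokens, preferring the longest match at
/// every position and backtracking on dead ends. `dead` memoizes positions
/// from which no segmentation exists, playing the role of the visited-state
/// bitfield described above.
fn segment(
    text: &[u8],
    vocab: &HashSet<&[u8]>,
    pos: usize,
    dead: &mut HashSet<usize>,
    out: &mut Vec<Vec<u8>>,
) -> bool {
    if pos == text.len() {
        return true;
    }
    if dead.contains(&pos) {
        return false;
    }
    // Longest leftmost match first, then progressively shorter matches.
    for end in (pos + 1..=text.len()).rev() {
        if vocab.contains(&text[pos..end]) {
            out.push(text[pos..end].to_vec());
            if segment(text, vocab, end, dead, out) {
                return true;
            }
            out.pop(); // backtrack: this match led into a dead end
        }
    }
    dead.insert(pos); // remember the failure so this state is never revisited
    false
}

fn main() {
    // Toy vocabulary: the greedy choice "ab" leads to a dead end, so the
    // encoder backtracks and settles on "a" + "bcd".
    let vocab: HashSet<&[u8]> = ["ab", "a", "bcd"].iter().map(|s| s.as_bytes()).collect();
    let mut out = Vec::new();
    let ok = segment(b"abcd", &vocab, 0, &mut HashSet::new(), &mut out);
    println!(
        "{ok}: {:?}",
        out.iter().map(|t| String::from_utf8_lossy(t)).collect::<Vec<_>>()
    );
}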
37 changes: 37 additions & 0 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -445,6 +445,43 @@ impl BytePairEncoding {
let (bitfield, count) = self.encode_into_bitfield(text);
self.bitfield_into_tokens(text, bitfield, count)
}

/// Using this function is not recommended, since it does not produce the correct BPE token sequence.
pub fn encode_greedy(&self, text: &[u8]) -> Vec<u32> {
self.longest_searcher
.leftmost_find_iter(text)
.map(|m| m.value())
.collect()
}
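
To see why a greedy leftmost-longest match can disagree with BPE, consider the following self-contained sketch (the toy vocabulary, merge order, and function names are illustrative assumptions, not taken from this crate): greedy matching always takes the longest token at the current position, while BPE applies merges in priority order, so the two can split the same input differently.

/// Greedy leftmost-longest matching against a fixed vocabulary.
fn greedy(text: &str, vocab: &[&str]) -> Vec<String> {
    let mut out = Vec::new();
    let mut pos = 0;
    while pos < text.len() {
        // Pick the longest vocabulary entry that starts at `pos`.
        let tok = vocab
            .iter()
            .filter(|t| text[pos..].starts_with(**t))
            .max_by_key(|t| t.len())
            .expect("every single character is in the vocabulary");
        out.push(tok.to_string());
        pos += tok.len();
    }
    out
}

/// Textbook BPE: start from single characters and repeatedly apply the
/// highest-priority merge that still occurs in the sequence.
fn bpe(text: &str, merges: &[(&str, &str)]) -> Vec<String> {
    let mut toks: Vec<String> = text.chars().map(|c| c.to_string()).collect();
    loop {
        let Some((i, merged)) = merges.iter().find_map(|(a, b)| {
            toks.windows(2)
                .position(|w| w[0] == *a && w[1] == *b)
                .map(|i| (i, format!("{a}{b}")))
        }) else {
            break;
        };
        toks[i] = merged;
        toks.remove(i + 1);
    }
    toks
}

fn main() {
    let vocab = ["a", "b", "c", "ab", "bc"];
    let merges = [("b", "c"), ("a", "b")]; // ("b", "c") has higher priority
    println!("{:?}", greedy("abc", &vocab)); // ["ab", "c"]
    println!("{:?}", bpe("abc", &merges)); // ["a", "bc"]
}
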

/// This function computes the shortest possible encoding sequence, which will usually differ from the
/// tokenization produced by the original BPE algorithm.
pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
let mut stepper = self.overlapping_searcher.overlapping_stepper();
for c in text {
stepper.consume(*c);
let mut best = (0, u32::MAX);
while let Some(m) = stepper.next() {
if m.start() == 0 {
best = (m.value(), 1);
break;
} else if last_token[m.start() - 1].1 + 1 < best.1 {
best = (m.value(), last_token[m.start() - 1].1 + 1);
}
}
last_token.push(best);
}
let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0));
let mut pos = text.len();
while pos > 0 {
let token = last_token[pos - 1].0;
encoded.push(token);
pos -= self.token_len(token);
}
encoded.reverse();
encoded
}
}

pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
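
For reference, the recurrence behind encode_minimal can be stated independently of the crate: the best cover of the prefix ending at position i is either a single token spanning that whole prefix, or the best cover ending just before some token that ends at i, plus one. The sketch below computes only the minimal token count, replacing the overlapping Aho-Corasick stepper with a naive substring scan; the vocabulary and names are illustrative assumptions.

/// Minimal number of vocabulary tokens needed to cover `text`, or None if
/// no cover exists. Mirrors the recurrence used by `encode_minimal`, but
/// with a naive scan over the vocabulary instead of an overlapping automaton.
fn minimal_token_count(text: &[u8], vocab: &[&[u8]]) -> Option<u32> {
    // best[i] = fewest tokens covering text[..=i]
    let mut best: Vec<Option<u32>> = vec![None; text.len()];
    for i in 0..text.len() {
        for tok in vocab {
            // Does this vocabulary token end exactly at position i?
            if tok.len() <= i + 1 && &text[i + 1 - tok.len()..=i] == *tok {
                let candidate = if tok.len() == i + 1 {
                    Some(1) // the token covers the whole prefix
                } else {
                    best[i - tok.len()].map(|c| c + 1)
                };
                if let Some(c) = candidate {
                    if best[i].map_or(true, |b| c < b) {
                        best[i] = Some(c);
                    }
                }
            }
        }
    }
    best.last().copied().flatten()
}

fn main() {
    // "ab" + "cd" covers the input with two tokens.
    let vocab: [&[u8]; 6] = [b"a", b"b", b"c", b"d", b"ab", b"cd"];
    assert_eq!(minimal_token_count(b"abcd", &vocab), Some(2));
}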
