Update interval_encoding.rs
aneubeck committed Jul 16, 2024
1 parent 576e450 commit dc77f0c
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions crates/bpe/src/interval_encoding.rs
@@ -3,6 +3,17 @@ use std::ops::Range;
use crate::backtrack_encoder::BacktrackEncoder;
use crate::byte_pair_encoding::BytePairEncoding;

/// This data structure allows fast, typically O(1), counting of tokens for arbitrary substrings of the original input text.
/// It achieves this by precomputing, for every position, the last token that ends at this position.
/// These last tokens form a token tree whose root is the empty input text; each path starting at the root represents
/// the encoded tokens of the corresponding text prefix.
/// The struct stores a topological ordering of this tree in `tree_id`, which enables O(1) testing of whether one node
/// is an ancestor of another node.
/// With the `tree_depth` field, the path length (which is equivalent to the number of encoded tokens) can be determined
/// in O(1) as well.
///
/// Note: the fields `tree_end` and `tree_depth` could also be represented by succinct data structures, reducing their size drastically.
/// Since the `tree_id` and `last_token` fields are still needed, this would reduce the total memory footprint by a bit less than 50%.
pub struct IntervalEncoding<'a> {
    bpe: &'a BytePairEncoding,
    text: &'a [u8],
@@ -42,9 +53,14 @@ impl<'a> IntervalEncoding<'a> {
        }
    }

    /// Computes, typically in O(1) time, the number of tokens required to encode the specified range.
    /// To do so, it re-encodes the prefix of the range with the `BacktrackEncoder` until the encoding sequence becomes
    /// compatible with the precomputed tables. Once that is the case, the remainder of the range is
    /// a simple O(1) lookup.
    pub fn count(&self, range: Range<usize>) -> usize {
        let leaf = self.tree_id[range.end];
        let mut encoder = BacktrackEncoder::with_capacity(self.bpe, &self.text[range.clone()], 8);
        // TODO: Consider adding a short-cut when the range starts at a good position.
        while let Some(next_token) = encoder.step() {
            if let Some(prev_token) = encoder.last_token() {
                let end_pos = encoder.pos() + range.start;
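
The O(1) ancestor test and token count described in the struct's doc comment can be illustrated with a small standalone sketch. It is not the crate's implementation: the `TokenTree` type, its constructor, and the `parent` array input are hypothetical, and only the field names `tree_id`, `tree_end`, and `tree_depth` are borrowed from the diff. The sketch assumes `tree_id` is a pre-order numbering and `tree_end` is one past the largest pre-order number in a node's subtree, which is one common way to obtain such a test.

```rust
/// Hypothetical stand-in for the precomputed tables (not the crate's layout).
struct TokenTree {
    /// Pre-order number of each node (assumed meaning of `tree_id`).
    tree_id: Vec<usize>,
    /// One past the largest pre-order number in the node's subtree.
    tree_end: Vec<usize>,
    /// Number of edges from the root, i.e. tokens needed to reach the node.
    tree_depth: Vec<usize>,
}

impl TokenTree {
    /// Builds the tables from a parent array. Node 0 is the root and every
    /// other node `i` must satisfy `parent[i] < i`.
    fn new(parent: &[usize]) -> Self {
        assert!(!parent.is_empty(), "the tree needs at least a root");
        let n = parent.len();
        let mut children: Vec<Vec<usize>> = vec![Vec::new(); n];
        let mut tree_depth = vec![0; n];
        for i in 1..n {
            assert!(parent[i] < i, "parents must precede their children");
            children[parent[i]].push(i);
            tree_depth[i] = tree_depth[parent[i]] + 1;
        }
        let mut tree_id = vec![0; n];
        let mut tree_end = vec![0; n];
        let mut next_id = 0;
        // Iterative pre-order traversal; the flag marks the "leave" visit.
        let mut stack = vec![(0usize, false)];
        while let Some((node, leaving)) = stack.pop() {
            if leaving {
                tree_end[node] = next_id;
                continue;
            }
            tree_id[node] = next_id;
            next_id += 1;
            stack.push((node, true));
            for &child in children[node].iter().rev() {
                stack.push((child, false));
            }
        }
        TokenTree { tree_id, tree_end, tree_depth }
    }

    /// O(1) test: `a` is an ancestor of `b` (or equal to it) exactly when
    /// `b`'s pre-order number lies inside `a`'s subtree interval.
    fn is_ancestor(&self, a: usize, b: usize) -> bool {
        self.tree_id[a] <= self.tree_id[b] && self.tree_id[b] < self.tree_end[a]
    }

    /// O(1) count of the edges (tokens) on the path from `a` down to `b`,
    /// or `None` when `a` is not an ancestor of `b`.
    fn tokens_between(&self, a: usize, b: usize) -> Option<usize> {
        if self.is_ancestor(a, b) {
            Some(self.tree_depth[b] - self.tree_depth[a])
        } else {
            None
        }
    }
}
```

With the parent array `[0, 0, 0, 1, 2, 1]`, node 1 is an ancestor of node 5 but not of node 4, so `tokens_between(1, 4)` yields `None`. In `count`, that negative case is presumably what makes the re-encoding with the `BacktrackEncoder` necessary before the lookup applies.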
