@@ -3,6 +3,17 @@ use std::ops::Range;
3
3
use crate :: backtrack_encoder:: BacktrackEncoder ;
4
4
use crate :: byte_pair_encoding:: BytePairEncoding ;
5
5
6
+ /// This data structure allows fast, i.e. O(1), counting of tokens for arbitrary substrings of the original input text.
7
+ /// It achieves this by precomputing for every position the last token which ends at this position.
8
+ /// These last tokens represent a token tree with its root being the empty input text where each path starting at the root represents
9
+ /// the encoded tokens of the corresponding text prefix.
10
+ /// The struct stores a topological ordering in `tree_id` over this tree which then enables O(1) testing whether one node
11
+ /// is the predecessor of another node.
12
+ /// With the `tree_depth` field the path length (which is equivalent to the number of encoded tokens) can be determined
13
+ /// in O(1) as well.
14
+ ///
15
+ /// Note: the fields `tree_end` and `tree_depth` could also be represented by succinct data structures, reducing their size drastically.
16
+ /// Since we still need the `tree_id` and `last_token` fields, this would in total reduce memory footprint by a bit less than 50%.
6
17
pub struct IntervalEncoding < ' a > {
7
18
bpe : & ' a BytePairEncoding ,
8
19
text : & ' a [ u8 ] ,
@@ -42,9 +53,14 @@ impl<'a> IntervalEncoding<'a> {
42
53
}
43
54
}
44
55
56
+ /// Computes the number of tokens required to encode the specified range, typically in O(1) time.
57
+ /// To do so, it re-encodes the prefix of the range with the `BacktrackEncoder` until the encoding sequence becomes
58
+ /// compatible with the precomputed tables. Once that's the case, the remainder of the range becomes
59
+ /// a simple O(1) lookup.
45
60
pub fn count ( & self , range : Range < usize > ) -> usize {
46
61
let leaf = self . tree_id [ range. end ] ;
47
62
let mut encoder = BacktrackEncoder :: with_capacity ( self . bpe , & self . text [ range. clone ( ) ] , 8 ) ;
63
+ // TODO: Consider adding a short-cut when the range starts at a good position.
48
64
while let Some ( next_token) = encoder. step ( ) {
49
65
if let Some ( prev_token) = encoder. last_token ( ) {
50
66
let end_pos = encoder. pos ( ) + range. start ;
0 commit comments