@@ -7,14 +7,23 @@ use std::time::Instant;
7
7
use daachorse:: { DoubleArrayAhoCorasick , DoubleArrayAhoCorasickBuilder } ;
8
8
use fnv:: { FnvHashMap , FnvHasher } ;
9
9
use itertools:: Itertools ;
10
+ use once_cell:: sync:: Lazy ;
11
+ use serde:: de:: Visitor ;
12
+ use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
10
13
use tiktoken_rs:: cl100k_base;
11
14
12
15
use crate :: backtrack_encoder:: BacktrackEncoder ;
13
16
use crate :: bitfield:: BitField ;
14
17
18
+ static BPE_CL100K : Lazy < BytePairEncoding > = Lazy :: new ( || {
19
+ let bytes = include_bytes ! ( "data/bpe_cl100k.dict" ) ;
20
+ rmp_serde:: from_slice ( bytes) . expect ( "" )
21
+ } ) ;
22
+
15
23
/// Representation of the byte pair dictionary.
16
24
/// This struct provides various conversions.
17
25
/// We put all of them into a single struct so that they can be reused by different implementations.
26
+ #[ derive( Serialize , Deserialize ) ]
18
27
pub struct BytePairEncoding {
19
28
/// All the decoded tokens concatenated into
20
29
all_tokens : Vec < u8 > ,
@@ -29,8 +38,16 @@ pub struct BytePairEncoding {
29
38
/// Mapping from a pair of tokens to a merged token if such a merged token exists.
30
39
pair_lookup : FnvHashMap < ( u32 , u32 ) , u32 > ,
31
40
/// An aho corasick automaton to find the next longest token in a byte sequence.
41
+ #[ serde(
42
+ serialize_with = "serialize_daac" ,
43
+ deserialize_with = "deserialize_daac"
44
+ ) ]
32
45
longest_searcher : DoubleArrayAhoCorasick < u32 > ,
33
46
/// An aho corasick automaton to find ALL tokens in a byte sequence.
47
+ #[ serde(
48
+ serialize_with = "serialize_daac" ,
49
+ deserialize_with = "deserialize_daac"
50
+ ) ]
34
51
overlapping_searcher : DoubleArrayAhoCorasick < u32 > ,
35
52
/// Mapping from a token to the next longest prefix token.
36
53
/// This is in principle information represented by the AhoCorasick automaton.
@@ -39,6 +56,32 @@ pub struct BytePairEncoding {
39
56
next_prefix_match : Vec < u32 > ,
40
57
}
41
58
59
+ fn serialize_daac < S : Serializer > (
60
+ daac : & DoubleArrayAhoCorasick < u32 > ,
61
+ s : S ,
62
+ ) -> Result < S :: Ok , S :: Error > {
63
+ s. serialize_bytes ( & daac. serialize ( ) )
64
+ }
65
+
66
+ struct DaacVisitor ;
67
+ impl < ' de > Visitor < ' de > for DaacVisitor {
68
+ type Value = DoubleArrayAhoCorasick < u32 > ;
69
+
70
+ fn expecting ( & self , _formatter : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
71
+ Err ( std:: fmt:: Error )
72
+ }
73
+
74
+ fn visit_bytes < E : serde:: de:: Error > ( self , v : & [ u8 ] ) -> Result < Self :: Value , E > {
75
+ Ok ( unsafe { DoubleArrayAhoCorasick :: deserialize_unchecked ( v) . 0 } )
76
+ }
77
+ }
78
+
79
+ fn deserialize_daac < ' de , D : Deserializer < ' de > > (
80
+ d : D ,
81
+ ) -> Result < DoubleArrayAhoCorasick < u32 > , D :: Error > {
82
+ d. deserialize_bytes ( DaacVisitor )
83
+ }
84
+
42
85
fn token_iter < ' a > ( all_tokens : & ' a [ u8 ] , token_starts : & ' a [ u32 ] ) -> impl Iterator < Item = & ' a [ u8 ] > {
43
86
token_starts
44
87
. iter ( )
@@ -125,6 +168,10 @@ fn find_token_by_bytes(
125
168
}
126
169
127
170
impl BytePairEncoding {
171
+ pub fn cl100k ( ) -> & ' static Self {
172
+ & BPE_CL100K
173
+ }
174
+
128
175
pub fn new ( ) -> Self {
129
176
let start = Instant :: now ( ) ;
130
177
let cl100_dict = cl100k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ;
@@ -410,9 +457,12 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
410
457
411
458
#[ cfg( test) ]
412
459
mod tests {
460
+ use std:: fs:: File ;
461
+ use std:: path:: PathBuf ;
413
462
use std:: time:: Instant ;
414
463
415
464
use itertools:: Itertools ;
465
+ use serde:: Serialize ;
416
466
use tiktoken_rs:: cl100k_base_singleton;
417
467
418
468
use crate :: byte_pair_encoding:: { create_test_bytes, BytePairEncoding } ;
@@ -428,7 +478,7 @@ mod tests {
428
478
] )
429
479
. unwrap ( ) ;
430
480
let time = Instant :: now ( ) ;
431
- let bpe = BytePairEncoding :: new ( ) ;
481
+ let bpe = BytePairEncoding :: cl100k ( ) ;
432
482
println ! ( "{:?}" , time. elapsed( ) ) ;
433
483
let encoded1 = cl100k_base_singleton ( )
434
484
. lock ( )
@@ -446,14 +496,27 @@ mod tests {
446
496
447
497
#[ test]
448
498
fn test_bpe_equivalence ( ) {
449
- let bpe = BytePairEncoding :: new ( ) ;
450
- for tokens in [ 10 , 1000 , 100000 ] {
451
- for _ in 0 ..10 {
499
+ let bpe = BytePairEncoding :: cl100k ( ) ;
500
+ for tokens in [ 10 , 1000 , 10000 ] {
501
+ for _ in 0 ..5 {
452
502
let test_input = create_test_bytes ( & bpe, tokens) ;
453
503
let encoded1 = bpe. encode_via_backtracking ( & test_input) ;
454
504
let encoded2 = bpe. encode_via_bitfield ( & test_input) ;
455
505
assert_eq ! ( encoded1, encoded2, "{} {}" , encoded1. len( ) , encoded2. len( ) ) ;
456
506
}
457
507
}
458
508
}
509
+
510
+ // TODO: Move the generation of the dictionary into some build procedure?
511
+ #[ test]
512
+ fn test_serialize ( ) {
513
+ let path = PathBuf :: from ( file ! ( ) ) ;
514
+ let dir = path. parent ( ) . unwrap ( ) ;
515
+ let data_file = dir. join ( "data/bpe_cl100k.dict" ) ;
516
+ let current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
517
+ let abs_path = current_dir. parent ( ) . unwrap ( ) . parent ( ) . unwrap ( ) ;
518
+ let file = File :: create ( abs_path. join ( data_file) ) . unwrap ( ) ;
519
+ let mut serializer = rmp_serde:: Serializer :: new ( file) ;
520
+ BytePairEncoding :: new ( ) . serialize ( & mut serializer) . unwrap ( ) ;
521
+ }
459
522
}
0 commit comments