Skip to content

Commit dd45913

Browse files
committed
Load dictionary from disk
1 parent e266d3c commit dd45913

File tree

5 files changed

+76
-9
lines changed

5 files changed

+76
-9
lines changed

crates/bpe/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ crate-type = ["lib", "staticlib"]
88
bench = false
99

1010
[dependencies]
11-
daachorse = "1"
11+
#daachorse = "1"
12+
daachorse = { git = "https://github.com/aneubeck/daachorse.git", branch = "aneubeck/extend" }
1213
fnv = "1.0"
1314
itertools = "0.12"
15+
once_cell = "1"
1416
rand = "0.8"
17+
rmp-serde = "1"
18+
serde = "1"
1519
tiktoken-rs = "0.5"
1620

1721
[dev-dependencies]

crates/bpe/benches/counting.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
66
use rand::{thread_rng, Rng};
77

88
fn counting_benchmark(c: &mut Criterion) {
9-
let bpe = BytePairEncoding::new();
9+
let bpe = BytePairEncoding::cl100k();
1010
let text = create_test_bytes(&bpe, 20000);
1111

1212
let fast = IntervalEncoding::new(&bpe, &text);
@@ -31,7 +31,7 @@ fn counting_benchmark(c: &mut Criterion) {
3131
}
3232

3333
fn encoding_benchmark(c: &mut Criterion) {
34-
let bpe = BytePairEncoding::new();
34+
let bpe = BytePairEncoding::cl100k();
3535
let text = create_test_bytes(&bpe, 20000);
3636

3737
for bytes in [10, 100, 1000, 10000] {

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,23 @@ use std::time::Instant;
77
use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
88
use fnv::{FnvHashMap, FnvHasher};
99
use itertools::Itertools;
10+
use once_cell::sync::Lazy;
11+
use serde::de::Visitor;
12+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1013
use tiktoken_rs::cl100k_base;
1114

1215
use crate::backtrack_encoder::BacktrackEncoder;
1316
use crate::bitfield::BitField;
1417

18+
static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
19+
let bytes = include_bytes!("data/bpe_cl100k.dict");
20+
rmp_serde::from_slice(bytes).expect("")
21+
});
22+
1523
/// Representation of the byte pair dictionary.
1624
/// This struct provides various conversions.
1725
/// We put all of them into a single struct so that they can be reused by different implementations.
26+
#[derive(Serialize, Deserialize)]
1827
pub struct BytePairEncoding {
1928
/// All the decoded tokens concatenated into a single byte vector.
2029
all_tokens: Vec<u8>,
@@ -29,8 +38,16 @@ pub struct BytePairEncoding {
2938
/// Mapping from a pair of tokens to a merged token if such a merged token exists.
3039
pair_lookup: FnvHashMap<(u32, u32), u32>,
3140
/// An Aho-Corasick automaton to find the next longest token in a byte sequence.
41+
#[serde(
42+
serialize_with = "serialize_daac",
43+
deserialize_with = "deserialize_daac"
44+
)]
3245
longest_searcher: DoubleArrayAhoCorasick<u32>,
3346
/// An Aho-Corasick automaton to find ALL tokens in a byte sequence.
47+
#[serde(
48+
serialize_with = "serialize_daac",
49+
deserialize_with = "deserialize_daac"
50+
)]
3451
overlapping_searcher: DoubleArrayAhoCorasick<u32>,
3552
/// Mapping from a token to the next longest prefix token.
3653
/// In principle, this information is already represented by the Aho-Corasick automaton.
@@ -39,6 +56,32 @@ pub struct BytePairEncoding {
3956
next_prefix_match: Vec<u32>,
4057
}
4158

59+
fn serialize_daac<S: Serializer>(
60+
daac: &DoubleArrayAhoCorasick<u32>,
61+
s: S,
62+
) -> Result<S::Ok, S::Error> {
63+
s.serialize_bytes(&daac.serialize())
64+
}
65+
66+
struct DaacVisitor;
67+
impl<'de> Visitor<'de> for DaacVisitor {
68+
type Value = DoubleArrayAhoCorasick<u32>;
69+
70+
fn expecting(&self, _formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
71+
Err(std::fmt::Error)
72+
}
73+
74+
fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
75+
Ok(unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(v).0 })
76+
}
77+
}
78+
79+
fn deserialize_daac<'de, D: Deserializer<'de>>(
80+
d: D,
81+
) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
82+
d.deserialize_bytes(DaacVisitor)
83+
}
84+
4285
fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator<Item = &'a [u8]> {
4386
token_starts
4487
.iter()
@@ -125,6 +168,10 @@ fn find_token_by_bytes(
125168
}
126169

127170
impl BytePairEncoding {
171+
pub fn cl100k() -> &'static Self {
172+
&BPE_CL100K
173+
}
174+
128175
pub fn new() -> Self {
129176
let start = Instant::now();
130177
let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
@@ -410,9 +457,12 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
410457

411458
#[cfg(test)]
412459
mod tests {
460+
use std::fs::File;
461+
use std::path::PathBuf;
413462
use std::time::Instant;
414463

415464
use itertools::Itertools;
465+
use serde::Serialize;
416466
use tiktoken_rs::cl100k_base_singleton;
417467

418468
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
@@ -428,7 +478,7 @@ mod tests {
428478
])
429479
.unwrap();
430480
let time = Instant::now();
431-
let bpe = BytePairEncoding::new();
481+
let bpe = BytePairEncoding::cl100k();
432482
println!("{:?}", time.elapsed());
433483
let encoded1 = cl100k_base_singleton()
434484
.lock()
@@ -446,14 +496,27 @@ mod tests {
446496

447497
#[test]
448498
fn test_bpe_equivalence() {
449-
let bpe = BytePairEncoding::new();
450-
for tokens in [10, 1000, 100000] {
451-
for _ in 0..10 {
499+
let bpe = BytePairEncoding::cl100k();
500+
for tokens in [10, 1000, 10000] {
501+
for _ in 0..5 {
452502
let test_input = create_test_bytes(&bpe, tokens);
453503
let encoded1 = bpe.encode_via_backtracking(&test_input);
454504
let encoded2 = bpe.encode_via_bitfield(&test_input);
455505
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
456506
}
457507
}
458508
}
509+
510+
// TODO: Move the generation of the dictionary into some build procedure?
511+
#[test]
512+
fn test_serialize() {
513+
let path = PathBuf::from(file!());
514+
let dir = path.parent().unwrap();
515+
let data_file = dir.join("data/bpe_cl100k.dict");
516+
let current_dir = std::env::current_dir().unwrap();
517+
let abs_path = current_dir.parent().unwrap().parent().unwrap();
518+
let file = File::create(abs_path.join(data_file)).unwrap();
519+
let mut serializer = rmp_serde::Serializer::new(file);
520+
BytePairEncoding::new().serialize(&mut serializer).unwrap();
521+
}
459522
}

crates/bpe/src/data/bpe_cl100k.dict

11 MB
Binary file not shown.

crates/bpe/src/interval_encoding.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ mod tests {
7373

7474
#[test]
7575
fn test_interval_count() {
76-
let bpe = BytePairEncoding::new();
76+
let bpe = BytePairEncoding::cl100k();
7777
let text = create_test_bytes(&bpe, 10000);
7878
let intervals = IntervalEncoding::new(&bpe, &text);
79-
for _ in 0..1 {
79+
for _ in 0..1000 {
8080
let start = thread_rng().gen_range(0..text.len());
8181
let end = thread_rng().gen_range(0..text.len());
8282
let range = start.min(end)..start.max(end);

0 commit comments

Comments
 (0)