Skip to content

Commit

Permalink
Load dictionary from disk
Browse files Browse the repository at this point in the history
  • Loading branch information
aneubeck committed Jul 16, 2024
1 parent e266d3c commit dd45913
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 9 deletions.
6 changes: 5 additions & 1 deletion crates/bpe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ crate-type = ["lib", "staticlib"]
bench = false

[dependencies]
daachorse = "1"
#daachorse = "1"
daachorse = { git = "https://github.com/aneubeck/daachorse.git", branch = "aneubeck/extend" }
fnv = "1.0"
itertools = "0.12"
once_cell = "1"
rand = "0.8"
rmp-serde = "1"
serde = "1"
tiktoken-rs = "0.5"

[dev-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions crates/bpe/benches/counting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rand::{thread_rng, Rng};

fn counting_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::new();
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 20000);

let fast = IntervalEncoding::new(&bpe, &text);
Expand All @@ -31,7 +31,7 @@ fn counting_benchmark(c: &mut Criterion) {
}

fn encoding_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::new();
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 20000);

for bytes in [10, 100, 1000, 10000] {
Expand Down
71 changes: 67 additions & 4 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@ use std::time::Instant;
use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
use fnv::{FnvHashMap, FnvHasher};
use itertools::Itertools;
use once_cell::sync::Lazy;
use serde::de::Visitor;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tiktoken_rs::cl100k_base;

use crate::backtrack_encoder::BacktrackEncoder;
use crate::bitfield::BitField;

/// Lazily deserialized cl100k byte-pair-encoding dictionary.
/// The MessagePack blob is embedded into the binary at compile time and is
/// regenerated by the `test_serialize` test in this file.
static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
    let bytes = include_bytes!("data/bpe_cl100k.dict");
    // A failure here means the embedded data file is out of sync with the
    // serde representation of `BytePairEncoding` — rebuild the dictionary.
    rmp_serde::from_slice(bytes).expect("embedded cl100k dictionary must deserialize")
});

/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
#[derive(Serialize, Deserialize)]
pub struct BytePairEncoding {
/// All the decoded tokens concatenated into
all_tokens: Vec<u8>,
Expand All @@ -29,8 +38,16 @@ pub struct BytePairEncoding {
/// Mapping from a pair of tokens to a merged token if such a merged token exists.
pair_lookup: FnvHashMap<(u32, u32), u32>,
/// An aho corasick automaton to find the next longest token in a byte sequence.
#[serde(
serialize_with = "serialize_daac",
deserialize_with = "deserialize_daac"
)]
longest_searcher: DoubleArrayAhoCorasick<u32>,
/// An aho corasick automaton to find ALL tokens in a byte sequence.
#[serde(
serialize_with = "serialize_daac",
deserialize_with = "deserialize_daac"
)]
overlapping_searcher: DoubleArrayAhoCorasick<u32>,
/// Mapping from a token to the next longest prefix token.
/// This is in principle information represented by the AhoCorasick automaton.
Expand All @@ -39,6 +56,32 @@ pub struct BytePairEncoding {
next_prefix_match: Vec<u32>,
}

/// Serializes a [`DoubleArrayAhoCorasick`] automaton as a single raw byte buffer,
/// using the crate's own binary serialization format.
fn serialize_daac<S: Serializer>(
    daac: &DoubleArrayAhoCorasick<u32>,
    s: S,
) -> Result<S::Ok, S::Error> {
    let buffer = daac.serialize();
    s.serialize_bytes(&buffer)
}

/// Serde visitor that reconstructs a [`DoubleArrayAhoCorasick`] automaton from
/// the raw byte buffer written by `serialize_daac`.
struct DaacVisitor;

impl<'de> Visitor<'de> for DaacVisitor {
    type Value = DoubleArrayAhoCorasick<u32>;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `expecting` must describe the expected input; serde embeds this text
        // in its error messages. Returning `Err(std::fmt::Error)` here (as the
        // previous version did) breaks error reporting instead of explaining it.
        formatter.write_str("a byte buffer containing a serialized DoubleArrayAhoCorasick automaton")
    }

    fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
        // SAFETY: the bytes are assumed to come from
        // `DoubleArrayAhoCorasick::serialize` (see `serialize_daac`), i.e. the
        // dictionary file shipped with this crate. Deserializing attacker-
        // controlled bytes through this path would be unsound — TODO confirm
        // this is only ever fed from the embedded data file.
        Ok(unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(v).0 })
    }
}

/// Deserializes a [`DoubleArrayAhoCorasick`] automaton that was previously
/// written by `serialize_daac`.
fn deserialize_daac<'de, D: Deserializer<'de>>(
    d: D,
) -> Result<DoubleArrayAhoCorasick<u32>, D::Error> {
    let visitor = DaacVisitor;
    d.deserialize_bytes(visitor)
}

fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator<Item = &'a [u8]> {
token_starts
.iter()
Expand Down Expand Up @@ -125,6 +168,10 @@ fn find_token_by_bytes(
}

impl BytePairEncoding {
pub fn cl100k() -> &'static Self {
&BPE_CL100K
}

pub fn new() -> Self {
let start = Instant::now();
let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
Expand Down Expand Up @@ -410,9 +457,12 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {

#[cfg(test)]
mod tests {
use std::fs::File;
use std::path::PathBuf;
use std::time::Instant;

use itertools::Itertools;
use serde::Serialize;
use tiktoken_rs::cl100k_base_singleton;

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
Expand All @@ -428,7 +478,7 @@ mod tests {
])
.unwrap();
let time = Instant::now();
let bpe = BytePairEncoding::new();
let bpe = BytePairEncoding::cl100k();
println!("{:?}", time.elapsed());
let encoded1 = cl100k_base_singleton()
.lock()
Expand All @@ -446,14 +496,27 @@ mod tests {

#[test]
fn test_bpe_equivalence() {
    let bpe = BytePairEncoding::cl100k();
    // Both encoder implementations must agree on random inputs of varying size.
    for tokens in [10, 1000, 10000] {
        for _ in 0..5 {
            let test_input = create_test_bytes(&bpe, tokens);
            let encoded1 = bpe.encode_via_backtracking(&test_input);
            let encoded2 = bpe.encode_via_bitfield(&test_input);
            assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
        }
    }
}

// TODO: Move the generation of the dictionary into some build procedure?
#[test]
fn test_serialize() {
    // Resolve the output path against this crate's manifest directory so the
    // test works no matter which working directory `cargo test` runs from.
    // The previous approach (current_dir + two `.parent()` hops combined with
    // `file!()`) silently depended on being invoked from the workspace root.
    let data_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src/data/bpe_cl100k.dict");
    let file = File::create(data_file).expect("creating the dictionary file must succeed");
    let mut serializer = rmp_serde::Serializer::new(file);
    // Build the dictionary from scratch and persist it in MessagePack format.
    BytePairEncoding::new()
        .serialize(&mut serializer)
        .expect("serializing the BPE dictionary must succeed");
}
}
Binary file added crates/bpe/src/data/bpe_cl100k.dict
Binary file not shown.
4 changes: 2 additions & 2 deletions crates/bpe/src/interval_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ mod tests {

#[test]
fn test_interval_count() {
let bpe = BytePairEncoding::new();
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 10000);
let intervals = IntervalEncoding::new(&bpe, &text);
for _ in 0..1 {
for _ in 0..1000 {
let start = thread_rng().gen_range(0..text.len());
let end = thread_rng().gen_range(0..text.len());
let range = start.min(end)..start.max(end);
Expand Down

0 comments on commit dd45913

Please sign in to comment.