Add support for o200k tokenization #16

Merged · 6 commits · Sep 25, 2024
Changes from all commits
.gitignore: 2 changes (1 addition, 1 deletion)

@@ -2,4 +2,4 @@ Cargo.lock
/target/
/crates/*/target/
/crates/*/Cargo.lock
.vscode/
.vscode/
crates/bpe/benches/counting.rs: 161 changes (88 additions, 73 deletions)

@@ -6,85 +6,100 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rand::{thread_rng, Rng};

fn counting_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 20000);
for (name, bpe) in [
("cl100k", BytePairEncoding::cl100k()),
("o200k", BytePairEncoding::o200k()),
] {
let text = create_test_bytes(&bpe, 20000);
let fast = IntervalEncoding::new(&bpe, &text);

let fast = IntervalEncoding::new(&bpe, &text);

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}

fn encoding_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let tiktoken = tiktoken_rs::cl100k_base().unwrap();
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();
for (name, bpe, tiktoken) in [
(
"cl100k",
BytePairEncoding::cl100k(),
tiktoken_rs::cl100k_base().unwrap(),
),
(
"o200k",
BytePairEncoding::o200k(),
tiktoken_rs::o200k_base().unwrap(),
),
] {
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) {
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes])
{
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}

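Not part of the diff: to make the benchmark names concrete, a minimal sketch of the two counting paths compared above, assuming the module paths and the signatures visible in the benchmark code (`create_test_bytes`, `BytePairEncoding::count`, `IntervalEncoding::new`/`count`); the benchmark's own `use` lines are cut off in this view.

use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use bpe::interval_encoding::IntervalEncoding; // module path assumed

fn main() {
    let bpe = BytePairEncoding::o200k();
    // Random test data sampled from the token distribution, as in the benchmark.
    let text = create_test_bytes(&bpe, 1000);

    // "Backtrack counting": encode the requested slice and count its tokens.
    let direct = bpe.count(&text[100..200]);

    // "Hybrid counting": build the interval structure once over the full text,
    // then answer token counts for arbitrary subranges of it.
    let fast = IntervalEncoding::new(&bpe, &text);
    let hybrid = fast.count(100..200);

    println!("backtrack: {direct}, hybrid: {hybrid}");
}
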
crates/bpe/src/byte_pair_encoding.rs: 84 changes (74 additions, 10 deletions)

@@ -19,6 +19,11 @@ static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
rmp_serde::from_slice(bytes).expect("")
});

static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
let bytes = include_bytes!("data/bpe_o200k.dict");
rmp_serde::from_slice(bytes).expect("")
});

/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
@@ -153,11 +158,15 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
}

fn hash_bytes(bytes: &[u8]) -> u32 {
hash_bytes_with_factor(bytes, 17846336922010275747)
}

fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
let mut hasher = FnvHasher::default();
bytes.hash(&mut hasher);
// Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
// To make them unique for the given tokens, we unfortunately have to add another multiplication.
((hasher.finish().wrapping_mul(37493864257)) >> 32) as u32
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

fn find_token_by_bytes(
@@ -180,6 +189,10 @@ impl BytePairEncoding {
&BPE_CL100K
}

pub fn o200k() -> &'static Self {
&BPE_O200K
}

/// Construct a BytePairEncoding instance from a tiktoken dictionary.
pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
@@ -492,13 +505,11 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {

#[cfg(test)]
mod tests {
use std::fs::File;
use std::path::PathBuf;

use std::time::Instant;

use itertools::Itertools;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
use tiktoken_rs::cl100k_base_singleton;

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

@@ -541,19 +552,72 @@ mod tests {
}
}
}
}

// TODO: Move the generation of the dictionary into some build procedure?
#[cfg(test)]
mod data {
use std::collections::HashSet;
use std::fs::File;
use std::path::PathBuf;

use rand::Rng;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, o200k_base};

use super::*;

const BPE_CL100K_LEN: usize = 100256;
const BPE_O200K_LEN: usize = 199998;

/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
/// 1. Ensure all supported tokenizers are in the list.
/// 2. Update the hash factor in [`hash_bytes`].
/// 3. Run [`update_token_dicts`] tests below to update data files.
#[test]
fn test_serialize() {
#[ignore = "run manually to find a suitable hash factor"]
fn find_hash_factor() {
let bpes: &mut [(CoreBPE, usize)] = &mut [
(cl100k_base().unwrap(), BPE_CL100K_LEN),
(o200k_base().unwrap(), BPE_O200K_LEN),
];
let mut rnd = rand::thread_rng();
loop {
let factor: u64 = rnd.gen();
if bpes.iter().all(|(bpe, len)| {
let mut seen = HashSet::with_capacity(*len);
(0..*len)
.all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
}) {
println!("hash factor: {factor}");
return;
}
}
}

#[test]
fn update_token_dicts() {
serialize_tokens(
&cl100k_base().expect("tiktoken initialization must not fail!"),
BPE_CL100K_LEN,
"cl100k",
);
serialize_tokens(
&o200k_base().expect("tiktoken initialization must not fail!"),
BPE_O200K_LEN,
"o200k",
);
}

#[track_caller]
fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
let path = PathBuf::from(file!());
let dir = path.parent().unwrap();
let data_file = dir.join("data/bpe_cl100k.dict");
let data_file = dir.join(format!("data/bpe_{name}.dict"));
let current_dir = std::env::current_dir().unwrap();
let abs_path = current_dir.parent().unwrap().parent().unwrap();
let file = File::create(abs_path.join(data_file)).unwrap();
let mut serializer = rmp_serde::Serializer::new(file);
let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
BytePairEncoding::from_tiktoken(dict, num_tokens)
.serialize(&mut serializer)
.unwrap();
}
Binary file modified crates/bpe/src/data/bpe_cl100k.dict
Binary file not shown.
Binary file added crates/bpe/src/data/bpe_o200k.dict
Binary file not shown.
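
Not part of the diff: a minimal usage sketch of the tokenizer this PR adds, assuming the module path and the `count`/`encode_via_backtracking` methods exercised in the benchmarks above are part of the crate's public surface.

use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    // `o200k()` lazily deserializes the bundled data/bpe_o200k.dict on first
    // access, mirroring the existing `cl100k()` accessor.
    let bpe = BytePairEncoding::o200k();

    let input = "Add support for o200k tokenization".as_bytes();
    let tokens = bpe.encode_via_backtracking(input);
    println!("{} tokens: {:?}", bpe.count(input), tokens);
}

The checked-in dict files are regenerated by the `update_token_dicts` test above; `find_hash_factor` is `#[ignore]`d and meant to be run manually when a new tokenizer is added (for example `cargo test find_hash_factor -- --ignored --nocapture`, assuming the default test harness).
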
crates/geo_filters/src/config/bitchunks.rs: 3 changes (1 addition, 2 deletions)

@@ -217,9 +217,8 @@ pub(crate) fn count_ones_from_bitchunks<T: IsBucketType>(
let mut total = take_ref(&mut ones, max_msb_len - 1).count();
let smallest_msb = ones
.next()
.map(|bucket| {
.inspect(|_| {
total += 1;
bucket
})
.unwrap_or_default();

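Not from the PR: a tiny standalone illustration of the `map` to `inspect` change above. `Option::map` has to hand the value back to keep it, while `Option::inspect` runs its side effect on a reference and passes the option through unchanged.

fn main() {
    let mut total = 0;

    // With `map`, the closure must return the value to preserve it.
    let a = Some(7u32).map(|v| { total += 1; v }).unwrap_or_default();

    // With `inspect`, the side effect runs on a reference and the value flows through.
    let b = Some(7u32).inspect(|_| total += 1).unwrap_or_default();

    assert_eq!((a, b, total), (7, 7, 2));
}

`Option::inspect` is a relatively recent addition to the standard library, so this pattern assumes a reasonably current toolchain.
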
criterion.toml: 2 changes (2 additions, 0 deletions)
@@ -0,0 +1,2 @@
# save report in this directory, even if a custom target directory is set
criterion_home = "./target/criterion"