Add support for o200k tokenization #16

Merged · 6 commits · Sep 25, 2024
Changes from 2 commits
154 changes: 81 additions & 73 deletions crates/bpe/benches/counting.rs
@@ -6,85 +6,93 @@ use criterion::{criterion_group, criterion_main, Criterion};
use rand::{thread_rng, Rng};

fn counting_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let text = create_test_bytes(&bpe, 20000);
for (name, bpe) in [
("cl100k", BytePairEncoding::cl100k()),
("o200k", BytePairEncoding::o200k()),
] {
let text = create_test_bytes(&bpe, 20000);
let fast = IntervalEncoding::new(&bpe, &text);

let fast = IntervalEncoding::new(&bpe, &text);

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("hybrid counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| fast.count(start..start + bytes),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("backtrack counting", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..text.len() - bytes),
|start| bpe.count(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}

fn encoding_benchmark(c: &mut Criterion) {
let bpe = BytePairEncoding::cl100k();
let tiktoken = tiktoken_rs::cl100k_base().unwrap();
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();
for (name, bpe) in [
("cl100k", BytePairEncoding::cl100k()),
("o200k", BytePairEncoding::o200k()),
] {
let tiktoken = tiktoken_rs::cl100k_base().unwrap();
let text = create_test_string(&bpe, 20000);
let input = text.as_bytes();

for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) {
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
for bytes in [10, 100, 1000, 10000] {
let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}"));
group.bench_function("backtracking", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_backtracking(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("heap", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_bitfield(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("dynamic programming", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_via_table(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("greedy", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_greedy(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("minimal", |b| {
b.iter_batched(
|| thread_rng().gen_range(0..input.len() - bytes),
|start| bpe.encode_minimal(&input[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("tiktoken", |b| {
b.iter_batched(
|| loop {
let start = thread_rng().gen_range(0..input.len() - bytes - 1);
if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes])
{
return start;
}
},
|start| tiktoken.encode_ordinary(&text[start..start + bytes]),
criterion::BatchSize::SmallInput,
)
});
}
}
}
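With the benchmark groups now named per tokenizer (e.g. `bpe-o200k-bytes-1000`), a single tokenizer's results can presumably be selected through criterion's name filter, for example `cargo bench -p bpe -- bpe-o200k` (the package name `bpe` is an assumption based on the crate path).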

89 changes: 78 additions & 11 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -19,6 +19,11 @@ static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
rmp_serde::from_slice(bytes).expect("")
});

static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
let bytes = include_bytes!("data/bpe_o200k.dict");
rmp_serde::from_slice(bytes).expect("")
});

/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
@@ -153,11 +158,15 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
}

fn hash_bytes(bytes: &[u8]) -> u32 {
hash_bytes_with_factor(bytes, 17846336922010275747)
}

fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
let mut hasher = FnvHasher::default();
bytes.hash(&mut hasher);
// Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
// To make them unique for the given tokens, we unfortunately have to add another multiplication.
((hasher.finish().wrapping_mul(37493864257)) >> 32) as u32
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

fn find_token_by_bytes(
@@ -180,6 +189,10 @@ impl BytePairEncoding {
&BPE_CL100K
}

pub fn o200k() -> &'static Self {
&BPE_O200K
}

/// Construct a BytePairEncoding instance from a tiktoken dictionary.
pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
@@ -192,7 +205,9 @@ impl BytePairEncoding {
let mut token_starts = vec![0];
let mut bytes_hash_to_token = FnvHashMap::default();
for (i, token) in iter.enumerate() {
bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
if let Some(j) = bytes_hash_to_token.insert(hash_bytes(&token), i as u32) {
eprintln!("collision: ({i}, {j})");
}
Collaborator: Should the function return an error instead when this happens?

Contributor Author: If this happens, the assertion below the loop will fail. I wonder if we want to provide a way to specify the factor as part of the API, if anyone wants to construct a BPE from their own dictionary. The hard-coded constant might make it harder to reuse this if users bring their own tokens. But I'll save that for a follow-up PR.
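As a purely illustrative sketch of that follow-up idea (none of these names exist in the crate), a factor-parameterized construction path might report collisions to the caller instead of printing them:

```rust
// Hypothetical sketch, not part of this PR: hash the tokens with a
// caller-supplied factor and surface collisions as an error, so users with
// their own dictionaries can search for a factor that works for them.
use fnv::FnvHashMap;

fn try_hash_tokens<'a>(
    tokens: impl Iterator<Item = &'a [u8]>,
    hash_factor: u64,
) -> Result<FnvHashMap<u32, u32>, (u32, u32)> {
    let mut bytes_hash_to_token = FnvHashMap::default();
    for (i, token) in tokens.enumerate() {
        let hash = hash_bytes_with_factor(token, hash_factor); // from this file
        if let Some(j) = bytes_hash_to_token.insert(hash, i as u32) {
            // Collision between token j and token i: the caller should retry
            // with a different factor (see the find_hash_factor test below).
            return Err((j, i as u32));
        }
    }
    Ok(bytes_hash_to_token)
}
```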

all_tokens_rev.extend(token.iter().copied().rev());
all_tokens.extend(token);
token_starts.push(all_tokens.len() as u32);
@@ -492,13 +507,11 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {

#[cfg(test)]
mod tests {
use std::fs::File;
use std::path::PathBuf;

use std::time::Instant;

use itertools::Itertools;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
use tiktoken_rs::cl100k_base_singleton;

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

@@ -541,19 +554,73 @@ mod tests {
}
}
}
}

#[cfg(test)]
mod data {
use std::collections::HashSet;
use std::fs::File;
use std::path::PathBuf;

use rand::Rng;
use serde::Serialize;
use tiktoken_rs::{cl100k_base, o200k_base};

use super::*;

const BPE_CL100K_LEN: usize = 100256;
const BPE_O200K_LEN: usize = 199998;

/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
/// 1. Ensure all supported tokenizers are in the list.
/// 2. Update the hash factor in [`hash_bytes`].
/// 3. Run [`update_token_dicts`] tests below to update data files.
#[test]
#[ignore = "run manually to find a suitable hash factor"]
fn find_hash_factor() {
Collaborator: :)

let bpes: &mut [(CoreBPE, usize)] = &mut [
(cl100k_base().unwrap(), BPE_CL100K_LEN),
(o200k_base().unwrap(), BPE_O200K_LEN),
];
let mut rnd = rand::thread_rng();
loop {
let factor: u64 = rnd.gen();
if bpes.iter().all(|(bpe, len)| {
let mut seen = HashSet::with_capacity(*len);
(0..*len)
.all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
}) {
println!("hash factor: {factor}");
return;
}
}
}
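Since the test is marked #[ignore], it has to be requested explicitly; something like `cargo test -p bpe find_hash_factor -- --ignored --nocapture` should run it and show the printed factor (the package name `bpe` is assumed from the crate path).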

// TODO: Move the generation of the dictionary into some build procedure?
#[test]
fn test_serialize() {
#[ignore = "run manually to update data files"]
Collaborator: In principle, we could let this test run normally, since it will fix the broken data file (and one will see in the diff that something has changed).

Contributor Author: Yes, let's do that.

fn update_token_dicts() {
serialize_tokens(
&cl100k_base().expect("tiktoken initialization must not fail!"),
BPE_CL100K_LEN,
"cl100k",
);
serialize_tokens(
&o200k_base().expect("tiktoken initialization must not fail!"),
BPE_O200K_LEN,
"o200k",
);
}
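Similarly, regenerating the data files stays a manual step while this test is ignored, e.g. `cargo test -p bpe update_token_dicts -- --ignored` (again assuming the package name).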

#[track_caller]
fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
let path = PathBuf::from(file!());
let dir = path.parent().unwrap();
let data_file = dir.join("data/bpe_cl100k.dict");
let data_file = dir.join(format!("data/bpe_{name}.dict"));
let current_dir = std::env::current_dir().unwrap();
let abs_path = current_dir.parent().unwrap().parent().unwrap();
let file = File::create(abs_path.join(data_file)).unwrap();
let mut serializer = rmp_serde::Serializer::new(file);
let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
BytePairEncoding::from_tiktoken(dict, num_tokens)
.serialize(&mut serializer)
.unwrap();
}
Binary file modified crates/bpe/src/data/bpe_cl100k.dict
Binary file not shown.
Binary file added crates/bpe/src/data/bpe_o200k.dict
Binary file not shown.
3 changes: 1 addition & 2 deletions crates/geo_filters/src/config/bitchunks.rs
@@ -217,9 +217,8 @@ pub(crate) fn count_ones_from_bitchunks<T: IsBucketType>(
let mut total = take_ref(&mut ones, max_msb_len - 1).count();
let smallest_msb = ones
.next()
.map(|bucket| {
.inspect(|_| {
total += 1;
bucket
})
.unwrap_or_default();
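For context on the `map` → `inspect` change above, a minimal standalone sketch (not taken from the crate) of why `Option::inspect` fits here: it runs the side effect without having to hand the value back.

```rust
// Illustrative only: Option::inspect observes the value and passes it through,
// so the counting closure no longer needs to return `bucket` unchanged.
fn main() {
    let mut total = 0usize;
    let smallest_msb: u32 = Some(7u32)
        .inspect(|_| total += 1) // count the extra bucket if one is present
        .unwrap_or_default();
    assert_eq!((smallest_msb, total), (7, 1));
}
```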
