Add support for o200k tokenization #16
Changes from 2 commits
```diff
@@ -19,6 +19,11 @@ static BPE_CL100K: Lazy<BytePairEncoding> = Lazy::new(|| {
     rmp_serde::from_slice(bytes).expect("")
 });
 
+static BPE_O200K: Lazy<BytePairEncoding> = Lazy::new(|| {
+    let bytes = include_bytes!("data/bpe_o200k.dict");
+    rmp_serde::from_slice(bytes).expect("")
+});
+
 /// Representation of the byte pair dictionary.
 /// This struct provides various conversions.
 /// We put all of them into a single struct so that they can be reused by different implementations.
```
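For readers of the diff: `Lazy` defers the `rmp_serde` deserialization of the embedded dictionary until the static is first dereferenced, and every later access reuses the already-built `BytePairEncoding`. A minimal standalone sketch of that pattern, assuming `once_cell::sync::Lazy` (the exact `Lazy` import is not visible in this hunk):

```rust
use once_cell::sync::Lazy;

// Sketch only: the closure runs once, on first access, and the result is
// cached for the lifetime of the program.
static WORDS: Lazy<Vec<String>> = Lazy::new(|| {
    println!("deserializing... (printed exactly once)");
    vec!["hello".to_string(), "world".to_string()]
});

fn main() {
    assert_eq!(WORDS.len(), 2); // first access triggers the closure
    assert_eq!(WORDS[0], "hello"); // later accesses reuse the cached value
}
```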
```diff
@@ -153,11 +158,15 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
 }
 
 fn hash_bytes(bytes: &[u8]) -> u32 {
+    hash_bytes_with_factor(bytes, 17846336922010275747)
+}
+
+fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
     let mut hasher = FnvHasher::default();
     bytes.hash(&mut hasher);
     // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
     // To make them unique for the given tokens, we unfortunately have to add another multiplication.
-    ((hasher.finish().wrapping_mul(37493864257)) >> 32) as u32
+    ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
 }
 
 fn find_token_by_bytes(
```
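As a standalone illustration of the trick described in the code comment, here is a hedged sketch using the `fnv` crate directly (function names are made up for the example): the 64-bit FNV hash is multiplied by the factor and only the top 32 bits are kept, so changing the factor reshuffles which 32-bit bucket each byte sequence lands in.

```rust
use std::hash::{Hash, Hasher};

use fnv::FnvHasher;

// Same multiply-shift scheme as hash_bytes_with_factor in the diff above.
fn bucket(bytes: &[u8], factor: u64) -> u32 {
    let mut hasher = FnvHasher::default();
    bytes.hash(&mut hasher);
    // Multiply, then keep only the most significant 32 bits.
    ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

fn main() {
    // Different factors send the same bytes to different buckets, which is why
    // searching over random factors can eliminate collisions for a fixed token set.
    println!("{}", bucket(b"token", 17846336922010275747));
    println!("{}", bucket(b"token", 37493864257));
}
```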
```diff
@@ -180,6 +189,10 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    pub fn o200k() -> &'static Self {
+        &BPE_O200K
+    }
+
     /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
```
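With the new accessor, callers can pick up the embedded o200k dictionary the same way as cl100k. A hedged usage sketch (the crate path is assumed from the test imports; 199998 is the o200k token count used in the tests below):

```rust
use bpe::byte_pair_encoding::BytePairEncoding; // crate name assumed
use tiktoken_rs::o200k_base;

fn main() {
    // Using the pre-built dictionary embedded in the binary:
    let _bpe = BytePairEncoding::o200k();

    // Or constructing the same encoding from a tiktoken dictionary at runtime:
    let dict = o200k_base().expect("tiktoken initialization must not fail!");
    let _bpe_runtime = BytePairEncoding::from_tiktoken(&dict, 199998);
}
```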
```diff
@@ -192,7 +205,9 @@ impl BytePairEncoding {
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
         for (i, token) in iter.enumerate() {
-            bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
+            if let Some(j) = bytes_hash_to_token.insert(hash_bytes(&token), i as u32) {
+                eprintln!("collision: ({i}, {j})");
+            }
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
```
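The collision check relies on the standard-library contract that `insert` on a hash map returns the previous value when the key was already present (the FNV map behaves the same, since it is a `HashMap` with a different hasher). A tiny self-contained demonstration:

```rust
use std::collections::HashMap;

fn main() {
    let mut map: HashMap<u32, u32> = HashMap::new();
    assert_eq!(map.insert(42, 0), None); // fresh key: nothing was displaced
    assert_eq!(map.insert(42, 1), Some(0)); // same key again: the old value comes back
    // The loop in the diff uses exactly this return value to report which two
    // token indices hashed to the same bucket.
}
```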
```diff
@@ -492,13 +507,11 @@ pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
 
 #[cfg(test)]
 mod tests {
-    use std::fs::File;
-    use std::path::PathBuf;
     use std::time::Instant;
 
     use itertools::Itertools;
-    use serde::Serialize;
-    use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
+    use tiktoken_rs::cl100k_base_singleton;
 
     use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
```
```diff
@@ -541,19 +554,73 @@ mod tests {
             }
         }
     }
 }
 
+#[cfg(test)]
+mod data {
+    use std::collections::HashSet;
+    use std::fs::File;
+    use std::path::PathBuf;
+
+    use rand::Rng;
+    use serde::Serialize;
+    use tiktoken_rs::{cl100k_base, o200k_base};
+
+    use super::*;
+
+    const BPE_CL100K_LEN: usize = 100256;
+    const BPE_O200K_LEN: usize = 199998;
+
+    /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
+    /// 1. Ensure all supported tokenizers are in the list.
+    /// 2. Update the hash factor in [`hash_bytes`].
+    /// 3. Run [`update_token_dicts`] tests below to update data files.
+    #[test]
+    #[ignore = "run manually to find a suitable hash factor"]
+    fn find_hash_factor() {
```
**Review comment** (on `fn find_hash_factor`): :)
```diff
+        let bpes: &mut [(CoreBPE, usize)] = &mut [
+            (cl100k_base().unwrap(), BPE_CL100K_LEN),
+            (o200k_base().unwrap(), BPE_O200K_LEN),
+        ];
+        let mut rnd = rand::thread_rng();
+        loop {
+            let factor: u64 = rnd.gen();
+            if bpes.iter().all(|(bpe, len)| {
+                let mut seen = HashSet::with_capacity(*len);
+                (0..*len)
+                    .all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
+            }) {
+                println!("hash factor: {factor}");
+                return;
+            }
+        }
+    }
```
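Since this test is marked `#[ignore]`, it only runs when requested explicitly, e.g. with `cargo test find_hash_factor -- --ignored --nocapture` (the `--nocapture` flag makes the `println!` with the discovered factor visible).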
```diff
 
     // TODO: Move the generation of the dictionary into some build procedure?
     #[test]
-    fn test_serialize() {
+    #[ignore = "run manually to update data files"]
```
**Review comment:** In principle, we could let this test run normally, since it will fix the broken data file (and one will see in the diff that something has changed).

**Reply:** Yes, let's do that.
```diff
+    fn update_token_dicts() {
+        serialize_tokens(
+            &cl100k_base().expect("tiktoken initialization must not fail!"),
+            BPE_CL100K_LEN,
+            "cl100k",
+        );
+        serialize_tokens(
+            &o200k_base().expect("tiktoken initialization must not fail!"),
+            BPE_O200K_LEN,
+            "o200k",
+        );
+    }
 
+    #[track_caller]
+    fn serialize_tokens(dict: &CoreBPE, num_tokens: usize, name: &str) {
         let path = PathBuf::from(file!());
         let dir = path.parent().unwrap();
-        let data_file = dir.join("data/bpe_cl100k.dict");
+        let data_file = dir.join(format!("data/bpe_{name}.dict"));
         let current_dir = std::env::current_dir().unwrap();
         let abs_path = current_dir.parent().unwrap().parent().unwrap();
         let file = File::create(abs_path.join(data_file)).unwrap();
         let mut serializer = rmp_serde::Serializer::new(file);
-        let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
-        BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
+        BytePairEncoding::from_tiktoken(dict, num_tokens)
             .serialize(&mut serializer)
             .unwrap();
     }
```
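As with `find_hash_factor`, this ignored test can be run on demand, e.g. with `cargo test update_token_dicts -- --ignored`; it rewrites both `data/bpe_cl100k.dict` and `data/bpe_o200k.dict`, the files that the `include_bytes!` statics at the top of this file embed at compile time.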
**Review comment:** Should the function return an error instead when this happens?

**Reply:** If this happens, the assertion below the loop will fail. I wonder if we want to provide a way to specify the factor as part of the API, if anyone wants to construct a BPE from their own dictionary. The hard-coded constant might make it harder to reuse this if users bring their own tokens. But I'll save that for a follow-up PR.