
Commit 576e450: more linter warnings

aneubeck committed Jul 16, 2024
1 parent 380c39b

Showing 4 changed files with 18 additions and 20 deletions.
crates/bpe/Cargo.toml (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@ itertools = "0.12"
 once_cell = "1"
 rand = "0.8"
 rmp-serde = "1"
-serde = "1"
+serde = { version = "1", features = ["derive"] }
 tiktoken-rs = "0.5"
 
 [dev-dependencies]
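
Note: the derive feature is what makes serde's #[derive(Serialize, Deserialize)] proc macros available; the bare serde dependency only provides the traits. A minimal sketch with a hypothetical type (not from this crate):

    use serde::{Deserialize, Serialize};

    // With the derive feature enabled, serde generates these trait impls;
    // without it, this attribute fails to compile.
    #[derive(Serialize, Deserialize)]
    struct TokenRange {
        start: u32,
        end: u32,
    }
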
crates/bpe/src/backtrack_encoder.rs (10 changes: 4 additions & 6 deletions)
@@ -28,22 +28,20 @@ impl<'a> BacktrackEncoder<'a> {
 
     #[inline]
     pub(crate) fn step(&mut self) -> Option<u32> {
-        let Some(mut token) = self.next_token else {
-            return None;
-        };
+        let mut token = self.next_token?;
         let last = self.tokens.last().copied();
         loop {
             let token_len = self.bpe.token_len(token);
             let end_pos = self.pos + token_len;
-            if self.bitfield.is_set(end_pos as usize)
+            if self.bitfield.is_set(end_pos)
                 && last
                     .map(|last_token| self.bpe.is_valid_token_pair(last_token, token))
                     .unwrap_or(true)
             {
-                self.bitfield.clear(end_pos as usize);
+                self.bitfield.clear(end_pos);
                 self.tokens.push(token);
                 self.pos = end_pos;
-                self.next_token = self.bpe.next_match(&self.text[end_pos as usize..]);
+                self.next_token = self.bpe.next_match(&self.text[end_pos..]);
                 break;
             } else if let Some(shorter) = self.bpe.next_prefix(token) {
                 token = shorter;
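
Note: the let-else block is replaced with the ? operator, the idiom clippy's question_mark lint suggests for functions returning Option; the "as usize" casts are dropped, presumably because end_pos already has type usize here. A standalone sketch of the ?-on-Option equivalence:

    // Both variants early-return None when next is None.
    fn step_verbose(next: Option<u32>) -> Option<u32> {
        let Some(token) = next else {
            return None;
        };
        Some(token + 1)
    }

    fn step_idiomatic(next: Option<u32>) -> Option<u32> {
        let token = next?; // early-returns None if next is None
        Some(token + 1)
    }
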
crates/bpe/src/byte_pair_encoding.rs (22 changes: 11 additions & 11 deletions)
@@ -10,7 +10,7 @@ use itertools::Itertools;
 use once_cell::sync::Lazy;
 use serde::de::Visitor;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use tiktoken_rs::cl100k_base;
+use tiktoken_rs::CoreBPE;
 
 use crate::backtrack_encoder::BacktrackEncoder;
 use crate::bitfield::BitField;
@@ -172,16 +172,14 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
-    pub fn new() -> Self {
+    pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
         let start = Instant::now();
-        let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
         println!("loaded tiktoken: {:?}", start.elapsed());
         let mut all_tokens = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        let num_tokens = 100256;
         for i in 0..num_tokens {
-            let token = cl100_dict._decode_native(&[i]);
+            let token = tiktoken_bpe._decode_native(&[i]);
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
@@ -326,7 +324,7 @@ impl BytePairEncoding {
 
     pub fn count(&self, text: &[u8]) -> usize {
         let mut enc = BacktrackEncoder::new(self, text);
-        while let Some(_) = enc.step() {}
+        while enc.step().is_some() {}
        enc.count()
     }

@@ -360,7 +358,7 @@
     // 6: ---------------------->
     pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec<u32> {
         let mut enc = BacktrackEncoder::new(self, text);
-        while let Some(_) = enc.step() {}
+        while enc.step().is_some() {}
         enc.into_tokens()
     }

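
Note: both while-loop changes in this file follow clippy's redundant_pattern_matching lint, which flags "while let Some(_) = expr {}" when only the success of expr matters. A standalone sketch:

    // Drive a stepper for its side effects only.
    fn drain(mut it: impl Iterator<Item = u32>) {
        // Flagged by clippy:  while let Some(_) = it.next() {}
        // Suggested replacement with identical behavior:
        while it.next().is_some() {}
    }
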
@@ -374,7 +372,6 @@
     // I.e. we somehow have to restart the construction!
     // We also don't know at which utf8 boundary the input sequence will end, such that it fits into the limit!
     // Therefore, we cannot reverse encode the sequence.
-    pub fn fix_utf8_boundary(&self, text: &[u8], limit: usize, tokenized: &mut Vec<u32>) {}
 
     // Concatenate two sequences:
     // Also here, the back-tracking procedure doesn't work, since we might have to undo some back-tracking steps
@@ -463,7 +460,7 @@ mod tests {
 
     use itertools::Itertools;
     use serde::Serialize;
-    use tiktoken_rs::cl100k_base_singleton;
+    use tiktoken_rs::{cl100k_base, cl100k_base_singleton};
 
     use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

@@ -499,7 +496,7 @@
         let bpe = BytePairEncoding::cl100k();
         for tokens in [10, 1000, 10000] {
             for _ in 0..5 {
-                let test_input = create_test_bytes(&bpe, tokens);
+                let test_input = create_test_bytes(bpe, tokens);
                 let encoded1 = bpe.encode_via_backtracking(&test_input);
                 let encoded2 = bpe.encode_via_bitfield(&test_input);
                 assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
@@ -517,6 +514,9 @@
         let abs_path = current_dir.parent().unwrap().parent().unwrap();
         let file = File::create(abs_path.join(data_file)).unwrap();
         let mut serializer = rmp_serde::Serializer::new(file);
-        BytePairEncoding::new().serialize(&mut serializer).unwrap();
+        let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
+        BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
+            .serialize(&mut serializer)
+            .unwrap();
     }
 }
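
Note: replacing new() with from_tiktoken decouples construction from the hard-coded cl100k dictionary: any tiktoken CoreBPE plus its token count can now be converted. A usage sketch mirroring the serialization test above (the helper name build_cl100k is hypothetical; 100256 is the cl100k vocabulary size the test passes explicitly):

    use tiktoken_rs::cl100k_base;

    use crate::byte_pair_encoding::BytePairEncoding;

    fn build_cl100k() -> BytePairEncoding {
        let cl100_dict = cl100k_base().expect("tiktoken initialization must not fail!");
        BytePairEncoding::from_tiktoken(&cl100_dict, 100256)
    }
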
crates/bpe/src/interval_encoding.rs (4 changes: 2 additions & 2 deletions)
@@ -74,8 +74,8 @@
     #[test]
     fn test_interval_count() {
         let bpe = BytePairEncoding::cl100k();
-        let text = create_test_bytes(&bpe, 10000);
-        let intervals = IntervalEncoding::new(&bpe, &text);
+        let text = create_test_bytes(bpe, 10000);
+        let intervals = IntervalEncoding::new(bpe, &text);
         for _ in 0..1000 {
             let start = thread_rng().gen_range(0..text.len());
             let end = thread_rng().gen_range(0..text.len());
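
Note: the dropped borrows here and in the tests above follow clippy's needless_borrow lint; cl100k() presumably already returns a reference, so &bpe would be a double reference that the compiler immediately dereferences. A standalone sketch with a hypothetical function:

    fn byte_len(s: &str) -> usize {
        s.len()
    }

    fn main() {
        let s = "hello";
        // clippy flags byte_len(&s) here as a needless borrow of a reference.
        let n = byte_len(s);
        println!("{n}");
    }
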
