Skip to content

Commit 5bc43ef

Browse files
author
Hendrik van Antwerpen
committed
Tweak API
1 parent 68f04b9 commit 5bc43ef

File tree

3 files changed

+63
-44
lines changed

3 files changed

+63
-44
lines changed

crates/bpe/bindings/python/pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@ dynamic = ["version"]
1919

2020
[tool.maturin]
2121
features = ["pyo3/extension-module"]
22+
23+
[dependency-groups]
24+
dev = [
25+
"maturin>=1.8.2",
26+
"pip>=25.0.1",
27+
]

crates/bpe/bindings/python/src/lib.rs

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,66 @@ use std::borrow::Cow;
22

33
use pyo3::prelude::*;
44

5-
#[pyclass]
6-
struct BytePairEncoding(Cow<'static, ::bpe::byte_pair_encoding::BytePairEncoding>);
5+
#[pymodule]
6+
mod bpe {
7+
use super::*;
78

8-
#[pyclass]
9-
struct Tokenizer(Cow<'static, ::bpe_openai::Tokenizer>);
9+
#[pyclass]
10+
struct BytePairEncoding(&'static ::bpe::byte_pair_encoding::BytePairEncoding);
1011

11-
#[pymethods]
12-
impl BytePairEncoding {
13-
fn count(&self, input: &[u8]) -> usize {
14-
self.0.count(input)
15-
}
12+
#[pymethods]
13+
impl BytePairEncoding {
14+
fn count(&self, input: &[u8]) -> usize {
15+
self.0.count(input)
16+
}
1617

17-
fn encode_via_backtracking(&self, input: &[u8]) -> Vec<u32> {
18-
self.0.encode_via_backtracking(input)
19-
}
18+
fn encode_via_backtracking(&self, input: &[u8]) -> Vec<u32> {
19+
self.0.encode_via_backtracking(input)
20+
}
2021

21-
fn decode_tokens(&self, tokens: Vec<u32>) -> Vec<u8> {
22-
self.0.decode_tokens(&tokens)
22+
fn decode_tokens(&self, tokens: Vec<u32>) -> Vec<u8> {
23+
self.0.decode_tokens(&tokens)
24+
}
2325
}
24-
}
2526

26-
#[pymethods]
27-
impl Tokenizer {
28-
fn count(&self, input: &str) -> usize {
29-
self.0.count(&input)
30-
}
27+
#[pymodule]
28+
mod openai {
29+
use super::*;
3130

32-
fn count_till_limit(&self, input: Cow<str>, limit: usize) -> Option<usize> {
33-
self.0.count_till_limit(&input, limit)
34-
}
31+
#[pyclass]
32+
struct Tokenizer(&'static ::bpe_openai::Tokenizer);
3533

36-
fn encode(&self, input: Cow<str>) -> Vec<u32> {
37-
self.0.encode(&input)
38-
}
34+
#[pymethods]
35+
impl Tokenizer {
36+
fn count(&self, input: &str) -> usize {
37+
self.0.count(&input)
38+
}
3939

40-
fn decode(&self, tokens: Vec<u32>) -> Option<String> {
41-
self.0.decode(&tokens)
42-
}
43-
}
40+
fn count_till_limit(&self, input: Cow<str>, limit: usize) -> Option<usize> {
41+
self.0.count_till_limit(&input, limit)
42+
}
4443

45-
#[pyfunction]
46-
fn cl100k_base() -> PyResult<Tokenizer> {
47-
Ok(Tokenizer(Cow::Borrowed(::bpe_openai::cl100k_base())))
48-
}
44+
fn encode(&self, input: Cow<str>) -> Vec<u32> {
45+
self.0.encode(&input)
46+
}
4947

50-
#[pyfunction]
51-
fn o200k_base() -> PyResult<Tokenizer> {
52-
Ok(Tokenizer(Cow::Borrowed(::bpe_openai::o200k_base())))
53-
}
48+
fn decode(&self, tokens: Vec<u32>) -> Option<String> {
49+
self.0.decode(&tokens)
50+
}
5451

55-
#[pymodule]
56-
fn bpe_openai(m: &Bound<'_, PyModule>) -> PyResult<()> {
57-
m.add_function(wrap_pyfunction!(cl100k_base, m)?)?;
58-
m.add_function(wrap_pyfunction!(o200k_base, m)?)?;
59-
Ok(())
52+
fn bpe(&self) -> BytePairEncoding {
53+
BytePairEncoding(&self.0.bpe)
54+
}
55+
}
56+
57+
#[pyfunction]
58+
fn cl100k_base() -> PyResult<Tokenizer> {
59+
Ok(Tokenizer(::bpe_openai::cl100k_base()))
60+
}
61+
62+
#[pyfunction]
63+
fn o200k_base() -> PyResult<Tokenizer> {
64+
Ok(Tokenizer(::bpe_openai::o200k_base()))
65+
}
66+
}
6067
}

crates/bpe/bindings/python/test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22

33
import bpe
44

5-
tok = bpe.cl100k_base()
5+
tok = bpe.openai.cl100k_base()
6+
7+
## Use tokenizer
68

79
enc = tok.encode("Hello, world!")
810
print(enc)
911
cnt = tok.count("Hello, world!")
1012
print(cnt)
1113
dec = tok.decode(enc)
1214
print(dec)
15+
16+
## Use underlying BPE instance
17+
18+
bpe = tok.bpe()

0 commit comments

Comments
 (0)