Skip to content

Commit 0897ba5

Browse files
committed
Rename Python confidence functions and add docs and tests
1 parent c31c494 commit 0897ba5

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

src/identifier.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,23 @@ impl Identifier {
5959
ignore_confidence: ignore_confidence,
6060
}
6161
}
62+
/// Disable use of confidence thresholds
63+
pub fn disable_confidence(&mut self) {
64+
self.ignore_confidence = true;
65+
}
6266

63-
/// Enable use of confidence thresholds
64-
pub fn disable_confidence(&mut self) -> &mut Self {
67+
/// Disable use of confidence thresholds
68+
pub fn without_confidence(&mut self) -> &mut Self {
6569
self.ignore_confidence = true;
6670
self
6771
}
6872

73+
/// Enable use of confidence thresholds
74+
pub fn with_confidence(&mut self) -> &mut Self {
75+
self.ignore_confidence = false;
76+
self
77+
}
78+
6979
/// Get the most probable language according to the current language scores
7080
fn pick_winner(&mut self) -> (Lang, f32) {
7181
// if only one lang is requested, just search for the minimum score (winner)
@@ -436,4 +446,18 @@ mod tests {
436446
"expected = {:?}\npredict = {:?}", pred, expected);
437447
}
438448
}
449+
450+
#[test_log::test]
451+
fn test_confidence() {
452+
pyo3::prepare_freethreaded_python();
453+
let identifier = Identifier::load(
454+
&python::module_path().expect("Python module needs to be installed"),
455+
None,
456+
).expect("Could not load model, please run 'heliport bianrize' if you haven't")
457+
.disable_confidence();
458+
459+
let pred = identifier.identify("hello");
460+
assert!(pred.0 == Lang::sah);
461+
}
462+
439463
}

src/python.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ pub fn module_path() -> PyResult<PathBuf> {
2323
}
2424

2525
/// Bindings to Python
26-
/// //TODO support returning both lang+score
2726
/// //TODO support parallel identification
2827
/// //TODO support loading relevant languages from text
2928
#[pymethods]
@@ -39,19 +38,24 @@ impl Identifier {
3938
Ok(identifier)
4039
}
4140

41+
/// Identify the language of a string
4242
#[pyo3(name = "identify")]
4343
fn py_identify(&mut self, text: &str) -> String {
4444
self.identify(text).0.to_string()
4545
}
4646

47-
#[pyo3(name = "identify_with_confidence")]
48-
fn py_identify_with_confidence(&mut self, text: &str) -> (String, f32) {
47+
/// Identify the top-k most probable languages of a string and return the prediction scores.
48+
/// This score is the confidence score (difference with the 2nd best)
49+
/// or the raw score if ignore_confidence is enabled.
50+
#[pyo3(name = "identify_with_score")]
51+
fn py_identify_with_score(&mut self, text: &str) -> (String, f32) {
4952
let pred = self.identify(text);
5053
(pred.0.to_string(), pred.1)
5154
}
5255

53-
#[pyo3(name = "identify_topk_with_confidence")]
54-
fn py_identify_topk_with_confidence(&mut self, text: &str, k: usize) -> Vec<(String, f32)> {
56+
/// Identify the language of a string and return the raw prediction score.
57+
#[pyo3(name = "identify_topk_with_score")]
58+
fn py_identify_topk_with_score(&mut self, text: &str, k: usize) -> Vec<(String, f32)> {
5559
let preds = self.identify_topk(text, k);
5660
let mut out = Vec::<_>::with_capacity(preds.len());
5761
for (pred, conf) in preds {

0 commit comments

Comments
 (0)