Skip to content

Commit e985d19

Browse files
committed
Add filter func
1 parent 97e789f commit e985d19

File tree

2 files changed: 8 additions and 4 deletions

2 files changed: 8 additions and 4 deletions

aisploit/classifiers/presidio/presidio_analyser.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from dataclasses import dataclass, field
2-
from typing import List
2+
from typing import Callable, List, Optional
33

44
from presidio_analyzer import AnalyzerEngine, EntityRecognizer, RecognizerResult
55

66
from ...core import BaseTextClassifier, Score
77

88

9-
@dataclass
9+
@dataclass(kw_only=True)
1010
class PresidioAnalyserClassifier(BaseTextClassifier[List[RecognizerResult]]):
1111
"""A text classifier using the Presidio Analyzer for detecting Personally Identifiable Information (PII)."""
1212

1313
language: str = "en"
1414
entities: List[str] | None = None
1515
threshold: float = 0.7
1616
additional_recognizers: List[EntityRecognizer] = field(default_factory=list)
17+
filter_func: Optional[Callable[[str, RecognizerResult], bool]] = None
1718
tags: List[str] = field(default_factory=lambda: ["leakage"], init=False)
1819

1920
def __post_init__(self) -> None:
@@ -36,6 +37,9 @@ def score(self, input: str, _: List[str] | None = None) -> Score[List[Recognizer
3637
"""
3738
results = self._analyzer.analyze(text=input, entities=self.entities, language=self.language)
3839

40+
if self.filter_func:
41+
results = [result for result in results if self.filter_func(input, result)]
42+
3943
return Score[List[RecognizerResult]](
4044
flagged=len(results) > 0,
4145
value=results,

examples/classifier.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
},
8686
{
8787
"cell_type": "code",
88-
"execution_count": 2,
88+
"execution_count": 4,
8989
"metadata": {},
9090
"outputs": [
9191
{
@@ -94,7 +94,7 @@
9494
"Score(flagged=True, value=[type: PERSON, start: 11, end: 19, score: 0.85, type: PHONE_NUMBER, start: 43, end: 55, score: 0.75], description='Returns True if entities are found in the input', explanation='Found 2 entities in input')"
9595
]
9696
},
97-
"execution_count": 2,
97+
"execution_count": 4,
9898
"metadata": {},
9999
"output_type": "execute_result"
100100
}

0 commit comments

Comments
 (0)