From 5b96bac0e333b0ae8cdd9eb3712d7e3026c9350c Mon Sep 17 00:00:00 2001 From: Brian Lai Date: Tue, 10 Dec 2024 12:43:23 -0800 Subject: [PATCH] Refactor GuardrailsPII to support entity type mapping and update .gitignore. --- .env | 0 .gitignore | 3 ++- validator/main.py | 40 ++++++++++++++++++++++++++++++++++------ 3 files changed, 36 insertions(+), 7 deletions(-) delete mode 100644 .env diff --git a/.env b/.env deleted file mode 100644 index e69de29..0000000 diff --git a/.gitignore b/.gitignore index c8398b0..42628e4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ build .pytest_cache .ruff_cache .vscode -.idea \ No newline at end of file +.idea +.ropeproject diff --git a/validator/main.py b/validator/main.py index b98a553..4e87e31 100644 --- a/validator/main.py +++ b/validator/main.py @@ -29,7 +29,7 @@ def get_entity_threshold(entity: str) -> float: return 0.5 else: return 0.0 - + class InferenceInput(BaseModel): text: str entities: List[str] @@ -40,17 +40,41 @@ class InferenceOutputResult(BaseModel): start: int end: int score: float - + class InferenceOutput(BaseModel): - results: List[InferenceOutputResult] + results: List[InferenceOutputResult] anonymized_text: str @register_validator(name="guardrails/guardrails_pii", data_type="string") class GuardrailsPII(Validator): + PII_ENTITIES_MAP = { + "pii": [ + "EMAIL_ADDRESS", + "PHONE_NUMBER", + "DOMAIN_NAME", + "IP_ADDRESS", + "DATE_TIME", + "LOCATION", + "PERSON", + "URL", + ], + "spi": [ + "CREDIT_CARD", + "CRYPTO", + "IBAN_CODE", + "NRP", + "MEDICAL_LICENSE", + "US_BANK_NUMBER", + "US_DRIVER_LICENSE", + "US_ITIN", + "US_PASSPORT", + "US_SSN", + ], + } def __init__( self, - entities: List[str], + entities: str | List[str], model_name: str = "urchade/gliner_small-v2.1", get_entity_threshold: Callable = get_entity_threshold, on_fail: Optional[Callable] = None, @@ -85,7 +109,11 @@ def __init__( **kwargs, ) - self.entities = entities + if isinstance(entities, str): + assert entities in self.PII_ENTITIES_MAP, f"Invalid entity type: {entities}" + self.entities = self.PII_ENTITIES_MAP[entities] + else: + self.entities = entities self.model_name = model_name self.get_entity_threshold = get_entity_threshold @@ -169,7 +197,7 @@ def anonymize(self, text: str, entities: list[str]) -> Tuple[str, list[ErrorSpan ] return output.anonymized_text, error_spans - + def _validate(self, value: Any, metadata: Dict = {}) -> ValidationResult: entities = metadata.get("entities", self.entities)