LLM-based PII redaction #585

Draft: wants to merge 1 commit into base: main
32 changes: 17 additions & 15 deletions examples/README.md
@@ -4,21 +4,23 @@ This directory contains multiple Python scripts with examples of how to use various…
The goal of these examples is to give users an overview of the many ways their text data can be curated.
These include:

| Python Script | Description |
|---------------------------------------|------------------------------------------------------------------------------------------------------------------|
| async_llm_pii_redaction.py | Use the `AsyncLLMPiiModifier` and `Modify` classes to remove personally identifiable information from text data. |
| blend_and_shuffle.py | Combine multiple datasets into one with different amounts of each dataset, then randomly permute the dataset. |
| classifier_filtering.py | Train a fastText classifier, then use it to filter high and low quality data. |
| download_arxiv.py                      | Download arXiv tar files and extract them.                                                                         |
| download_common_crawl.py | Download Common Crawl WARC snapshots and extract them. |
| download_wikipedia.py | Download the latest Wikipedia dumps and extract them. |
| exact_deduplication.py | Use the `ExactDuplicates` class to perform exact deduplication on text data. |
| find_pii_and_deidentify.py | Use the `PiiModifier` and `Modify` classes to remove personally identifiable information from text data. |
| fuzzy_deduplication.py | Use the `FuzzyDuplicatesConfig` and `FuzzyDuplicates` classes to perform fuzzy deduplication on text data. |
| identify_languages.py                  | Use `FastTextLangId` to filter data by language.                                                                   |
| llm_pii_redaction.py | Use the `LLMPiiModifier` and `Modify` classes to remove personally identifiable information from text data. |
| raw_download_common_crawl.py | Download the raw compressed WARC files from Common Crawl without extracting them. |
| semdedup_example.py | Use the `SemDedup` class to perform semantic deduplication on text data. |
| task_decontamination.py | Remove segments of downstream evaluation tasks from a dataset. |
| translation_example.py | Create and use an `IndicTranslation` model for language translation. |

Before running any of these scripts, we strongly recommend running `python <script name>.py --help` first to see which arguments are required or relevant.
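
Note that the two new PII redaction examples assume a NIM endpoint is already serving the chosen model and is reachable at the `base_url` set in each script; both default to `http://0.0.0.0:8000/v1` with `meta/llama-3.1-70b-instruct`.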

53 changes: 53 additions & 0 deletions examples/async_llm_pii_redaction.py
@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dask.dataframe
import pandas as pd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.async_llm_pii_modifier import AsyncLLMPiiModifier
from nemo_curator.modules.modify import Modify
from nemo_curator.utils.distributed_utils import get_client


def console_script():
    _ = get_client()

    dataframe = pd.DataFrame(
        {
            "text": [
                # Sampled from https://huggingface.co/datasets/gretelai/gretel-pii-masking-en-v1
                "Transaction details: gasLimit set to 1000000 units by tw_brian740, gasPrice set to 10 Gwei by [email protected], contactable at +1-869-341-9301x7005, located at Suite 378, Yolanda Mountain, Burkeberg.",
                "Unloading Plan for Shipment MRN-293104, MED25315002, dated 1989.12.22. Driver EMP730359, Vehicle KS40540825.",
            ]
        }
    )
    dd = dask.dataframe.from_pandas(dataframe, npartitions=1)
    dataset = DocumentDataset(dd)

    modifier = AsyncLLMPiiModifier(
        # Endpoint for the user's NIM
        base_url="http://0.0.0.0:8000/v1",
        api_key="API KEY (if needed)",
        model="meta/llama-3.1-70b-instruct",
        max_concurrent_requests=10,
    )

    modify = Modify(modifier)
    modified_dataset = modify(dataset)
    modified_dataset.df.to_json("output.jsonl", lines=True, orient="records")


if __name__ == "__main__":
    console_script()
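
After the script runs, the redacted documents can be read back from `output.jsonl` for inspection; a minimal sketch, assuming the default "text" column used above:

import pandas as pd

redacted = pd.read_json("output.jsonl", lines=True)
# Each "text" value now carries the redaction markers produced by llm_pii_utils.redact
print(redacted["text"].tolist())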
52 changes: 52 additions & 0 deletions examples/llm_pii_redaction.py
@@ -0,0 +1,52 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dask.dataframe
import pandas as pd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers.llm_pii_modifier import LLMPiiModifier
from nemo_curator.modules.modify import Modify
from nemo_curator.utils.distributed_utils import get_client


def console_script():
    _ = get_client()

    dataframe = pd.DataFrame(
        {
            "text": [
                # Sampled from https://huggingface.co/datasets/gretelai/gretel-pii-masking-en-v1
                "Transaction details: gasLimit set to 1000000 units by tw_brian740, gasPrice set to 10 Gwei by [email protected], contactable at +1-869-341-9301x7005, located at Suite 378, Yolanda Mountain, Burkeberg.",
                "Unloading Plan for Shipment MRN-293104, MED25315002, dated 1989.12.22. Driver EMP730359, Vehicle KS40540825.",
            ]
        }
    )
    dd = dask.dataframe.from_pandas(dataframe, npartitions=1)
    dataset = DocumentDataset(dd)

    modifier = LLMPiiModifier(
        # Endpoint for the user's NIM
        base_url="http://0.0.0.0:8000/v1",
        api_key="API KEY (if needed)",
        model="meta/llama-3.1-70b-instruct",
    )

    modify = Modify(modifier)
    modified_dataset = modify(dataset)
    modified_dataset.df.to_json("output.jsonl", lines=True, orient="records")


if __name__ == "__main__":
    console_script()
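
The synchronous variant above matches examples/async_llm_pii_redaction.py except for the modifier construction, so switching between the two is a one-line change (a sketch; that the synchronous class issues one request at a time is an assumption based on its name):

modifier = LLMPiiModifier(base_url="http://0.0.0.0:8000/v1")  # sequential requests (assumed)
modifier = AsyncLLMPiiModifier(base_url="http://0.0.0.0:8000/v1", max_concurrent_requests=10)  # concurrent requests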
214 changes: 214 additions & 0 deletions nemo_curator/modifiers/async_llm_pii_modifier.py
@@ -0,0 +1,214 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import warnings
from typing import Any, Coroutine, Dict, List, Optional

import pandas as pd
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm

from nemo_curator.modifiers import DocumentModifier
from nemo_curator.utils.decorators import batched
from nemo_curator.utils.distributed_utils import load_object_on_worker
from nemo_curator.utils.llm_pii_utils import (
    JSON_SCHEMA,
    PII_LABELS,
    SYSTEM_PROMPT,
    redact,
    validate_entity,
)

__all__ = ["AsyncLLMPiiModifier"]


class AsyncLLMInference:
    """A class for redacting PII via asynchronous LLM inference"""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        model: str = "meta/llama-3.1-70b-instruct",
        system_prompt: Dict[str, str] = SYSTEM_PROMPT,
        pii_labels: List[str] = PII_LABELS,
    ):
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
Review comment (Collaborator, Author): I could use our existing AsyncOpenAIClient from here, although I do not see an immediate benefit to this. It seems to just be a wrapper.
        self.model = model
        self.system_prompt = system_prompt[model]
        self.pii_labels = pii_labels

    async def infer(self, text: str) -> List[Dict[str, str]]:
        """Invoke LLM to get PII entities"""

        text = text.strip()
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": text},
        ]

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            # The field guided_json is unsupported at the root level
            # and must be included in the nvext object field
            extra_body={"nvext": {"guided_json": JSON_SCHEMA}},
            stream=False,
            max_tokens=4096,
        )

        assistant_message = response.choices[0].message.content

        # Parse results
        try:
            entities = json.loads(assistant_message)
            if not entities:
                # LLM returned valid JSON but no entities discovered
                return []
            else:
                # Check that each entity returned is valid
                return [e for e in entities if validate_entity(e, text)]
        except json.decoder.JSONDecodeError:
            return []
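
    # Shape note (an assumption, since JSON_SCHEMA lives in llm_pii_utils and is
    # not part of this diff): a successful response is expected to parse into a
    # list of entity dicts, e.g.
    #   [{"entity_type": "email", "entity_text": "user@example.com"}]
    # and validate_entity(e, text) presumably drops entries that cannot be
    # located in the input text.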


class AsyncLLMPiiModifier(DocumentModifier):
    """
    This class is the entry point to using the LLM-based PII de-identification module.
    It works with the `Modify` functionality as shown below:

        dataframe = pd.DataFrame({"text": ["Sarah and Ryan went out to play", "Jensen is the CEO of NVIDIA"]})
        dd = dask.dataframe.from_pandas(dataframe, npartitions=1)
        dataset = DocumentDataset(dd)

        modifier = AsyncLLMPiiModifier(
            # Endpoint for the user's NIM
            base_url="http://0.0.0.0:8000/v1",
            api_key="API KEY (if needed)",
            model="meta/llama-3.1-70b-instruct",
            # The user may engineer a custom prompt if desired
            system_prompt=SYSTEM_PROMPT,
            pii_labels=PII_LABELS,
            language="en",
        )

        modify = Modify(modifier)
        modified_dataset = modify(dataset)
        modified_dataset.df.to_json("output_files/*.jsonl", lines=True, orient="records")

    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        model: str = "meta/llama-3.1-70b-instruct",
        system_prompt: Dict[str, str] = SYSTEM_PROMPT,
        pii_labels: List[str] = PII_LABELS,
        language: str = "en",
        max_concurrent_requests: Optional[int] = None,
    ):
        """
        Initialize the AsyncLLMPiiModifier

        Args:
            base_url (str): The base URL for the user's NIM
            api_key (Optional[str]): The API key for the user's NIM, if needed.
                Default is None.
            model (str): The model to use for the LLM.
                Default is "meta/llama-3.1-70b-instruct".
            system_prompt (Dict[str, str]): Mapping from model name to the system prompt fed into the LLM.
                The default prompt has been fine-tuned for "meta/llama-3.1-70b-instruct".
            pii_labels (List[str]): The PII labels to identify and remove from the text.
                See the documentation for the full list of PII labels.
            language (str): The language to use for the LLM.
                Default is "en" for English. For non-English text, it is recommended
                to provide a custom system prompt.
            max_concurrent_requests (Optional[int]): The maximum number of concurrent requests to make to the LLM.
                Default is None, which means no limit.

        """
        super().__init__()

        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.system_prompt = system_prompt
        self.pii_labels = pii_labels
        self.language = language
        self.max_concurrent_requests = max_concurrent_requests

        if self.language != "en" and self.system_prompt is SYSTEM_PROMPT:
            warnings.warn(
                "The default system prompt is only available for English. "
                "For other languages, please provide a custom system prompt."
            )
        if self.model not in self.system_prompt:
            warnings.warn(
                f"No system prompt has been defined for model {model}. "
                "The default system prompt will be used."
            )
            self.system_prompt[self.model] = SYSTEM_PROMPT[
                "meta/llama-3.1-70b-instruct"
            ]

    @batched
    def modify_document(self, text: pd.Series):
        inferer = load_object_on_worker("inferer", self.load_inferer, {})
        pii_entities_lists = asyncio.run(self.call_inferer(text, inferer))
        text_redacted = self.batch_redact(text, pii_entities_lists)
        return text_redacted

    def load_inferer(self):
        """Helper function to load the asynchronous LLM"""
        inferer: AsyncLLMInference = AsyncLLMInference(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=self.system_prompt,
            pii_labels=self.pii_labels,
        )

        return inferer

    async def call_inferer(self, text: pd.Series, inferer: AsyncLLMInference):
        tasks = [inferer.infer(prompt) for prompt in text]
        pii_entities_lists = await self._gather(tasks)
        return pii_entities_lists

Review comment (Collaborator, Author): Same _gather function as from here.

    async def _gather(
        self, requests: List[Coroutine[Any, Any, List[Dict[str, str]]]]
    ) -> List[List[Dict[str, str]]]:
        max_requests = self.max_concurrent_requests
        if max_requests is None:
            max_requests = len(requests)

        final_list = []
        for i in tqdm(range(0, len(requests), max_requests)):
            request_slice = requests[i : i + max_requests]
            result = await tqdm.gather(*request_slice)
            final_list.extend(result)

        return final_list
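
    # Because each slice of max_concurrent_requests is awaited in full before the
    # next slice starts, a single slow request stalls its whole window; a
    # semaphore-based limiter would smooth throughput, at the cost of diverging
    # from the existing _gather utility referenced in the review comment above.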

    def batch_redact(
        self, text: pd.Series, pii_entities_lists: List[List[Dict[str, str]]]
    ):
        redacted_texts = [
            redact(text_str, pii_entities)
            for text_str, pii_entities in zip(text, pii_entities_lists)
        ]
        return pd.Series(redacted_texts, index=text.index)
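
    # For context (an assumption, since redact comes from llm_pii_utils and is
    # not shown in this diff): redact(text_str, pii_entities) is expected to
    # return text_str with each validated entity span replaced by a placeholder.
    # Passing index=text.index keeps the redacted Series aligned with the input
    # partition under Dask.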