
Commit b15b08a

Add support for FineTune-Guard classifier (#397)
* add changes from #325
* edit state_dict read

Signed-off-by: Sarah Yurick <[email protected]>
1 parent 3ebc807 commit b15b08a

File tree

nemo_curator/classifiers/__init__.py
nemo_curator/classifiers/aegis.py
nemo_curator/classifiers/base.py

3 files changed (+216, -16 lines changed)


nemo_curator/classifiers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
 import os
 
 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
-from .aegis import AegisClassifier
+from .aegis import AegisClassifier, FineTuneGuardClassifier
 from .domain import DomainClassifier
 from .fineweb_edu import FineWebEduClassifier
 from .quality import QualityClassifier
@@ -24,5 +24,6 @@
     "DomainClassifier",
     "QualityClassifier",
     "AegisClassifier",
+    "FineTuneGuardClassifier",
     "FineWebEduClassifier",
 ]
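
With this re-export in place, the new classifier becomes importable from the package root. A minimal import sketch (the token value is a placeholder):

# Sketch only: FineTuneGuardClassifier is now re-exported from
# nemo_curator.classifiers alongside the existing classifiers.
from nemo_curator.classifiers import FineTuneGuardClassifier

classifier = FineTuneGuardClassifier(token="hf_...")  # placeholder HF access token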

nemo_curator/classifiers/aegis.py

Lines changed: 213 additions & 14 deletions
@@ -21,9 +21,13 @@
 import cudf
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
+from huggingface_hub import hf_hub_download
 from peft import PeftModel
+from safetensors.torch import load_file
+from torch.nn import Dropout, Linear
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from nemo_curator.classifiers.base import (
@@ -41,6 +45,8 @@ class AegisConfig:
     pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
     dtype: torch.dtype = torch.bfloat16
     max_length: int = 4096
+    add_finetune_guard: bool = False
+    finetune_guard_path: str = "nvidia/FineTune-Guard"
 
 
 ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
@@ -69,29 +75,85 @@ class AegisConfig:
 ]
 
 
+class FineTuneGuardNet(torch.nn.Module):
+    def __init__(self, input_dim, dropout=0.7):
+        super().__init__()
+        self.input_dim = input_dim
+        self.dropout = Dropout(dropout)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.input_layer = Linear(input_dim, input_dim)
+
+        self.hidden_layer_0 = Linear(input_dim, 2000)
+        self.hidden_layer_1 = Linear(2000, 500)
+        self.hidden_layer_2 = Linear(500, 1)
+
+    def forward(self, x):
+        x = torch.nn.functional.normalize(x, dim=-1)
+        x = self.dropout(x)
+        x = F.relu(self.input_layer(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_0(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_1(x))
+        x = self.dropout(x)
+        x = self.hidden_layer_2(x)
+        x = self.sigmoid(x)
+        return x
+
+
 class AegisModel(nn.Module):
     def __init__(
         self,
         pretrained_model_name_or_path: str,
         peft_model_name_or_path: str,
         dtype: torch.dtype,
-        token: str,
+        token: Optional[Union[str, bool]],
+        add_finetune_guard: bool = False,
+        autocast: bool = False,
     ):
         super().__init__()
         base_model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, torch_dtype=dtype, token=token
         )
         self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
+        self.autocast = autocast
+        self.add_finetune_guard = add_finetune_guard
+        if self.add_finetune_guard:
+            self.finetune_guard_net = FineTuneGuardNet(4096)
 
     @torch.no_grad()
-    def forward(self, batch):
-        response = self.model.generate(
-            **batch,
-            max_new_tokens=100,
-            pad_token_id=0,
-        )
+    def _forward(self, batch):
+        if self.add_finetune_guard:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=1,
+                pad_token_id=0,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            # Access the hidden state of the last non-generated token from the last layer
+            finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
+                torch.float
+            )
+            finetune_guard_output_tensor = self.finetune_guard_net(
+                finetune_guard_input_tensor
+            ).flatten()
+            return finetune_guard_output_tensor
+        else:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=100,
+                pad_token_id=0,
+            )
         return response
 
+    def forward(self, batch):
+        if self.autocast:
+            with torch.autocast(device_type="cuda"):
+                return self._forward(batch)
+        else:
+            return self._forward(batch)
+
 
 class AegisHFModel(HFModel):
     def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
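
A note on the hidden-state indexing in the hunk above: with output_hidden_states=True and return_dict_in_generate=True, transformers returns one tuple of per-layer hidden states for each generation step. The sketch below illustrates the same probe on the base model alone (a simplification: it omits the PEFT adapter the commit actually uses, and the prompt text is hypothetical):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/LlamaGuard-7b")
model = AutoModelForCausalLM.from_pretrained("meta-llama/LlamaGuard-7b")

batch = tokenizer(["Instruction: ... Response: ..."], return_tensors="pt")
with torch.no_grad():
    response = model.generate(
        **batch,
        max_new_tokens=1,
        pad_token_id=0,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

# hidden_states[0] holds the states for the prompt (the first generation
# step); index 32 is the final layer of the 32-layer Llama2-7B backbone
# (index 0 is the embedding layer); -1 selects the last prompt token.
embedding = response.hidden_states[0][32][:, -1, :]  # shape: [batch, 4096]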
@@ -111,11 +173,21 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
 
     def load_model(self, device: str = "cuda"):
         model = AegisModel(
-            self.config.pretrained_model_name_or_path,
-            self.config.peft_model_name_or_path,
-            self.config.dtype,
-            self.config.token,
+            pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
+            peft_model_name_or_path=self.config.peft_model_name_or_path,
+            dtype=self.config.dtype,
+            token=self.config.token,
+            add_finetune_guard=self.config.add_finetune_guard,
         )
+        if self.config.add_finetune_guard:
+            weights_path = hf_hub_download(
+                repo_id=self.config.finetune_guard_path,
+                filename="model.safetensors",
+            )
+            state_dict = load_file(weights_path)
+            model.finetune_guard_net.load_state_dict(state_dict)
+            model.finetune_guard_net.eval()
+
         model = model.to(device)
         model.eval()
         return model
@@ -171,6 +243,7 @@ def __init__(
         keep_raw_pred: bool = False,
         max_chars: int = 6000,
         device_type: str = "cuda",
+        autocast: bool = True,
         max_mem_gb: Optional[int] = None,
     ):
         """
@@ -194,13 +267,16 @@ def __init__(
                 Useful for debugging when "unknown" shows up a lot in your dataset.
             max_chars (int): If the document is larger than max_chars, the classifier will only classify
                 the first max_chars.
+            autocast (bool): If True, will use autocast to run the classifier.
             device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
             max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
                 it defaults to the available GPU memory minus 4 GB.
 
         """
-        config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token)
-
+        config = AegisConfig(
+            peft_model_name_or_path=aegis_variant,
+            token=token,
+        )
         self.text_field = text_field
         self.labels = AEGIS_LABELS
         self.out_dim = len(self.labels)
@@ -224,7 +300,7 @@ def __init__(
             pred_column=pred_column,
             max_chars=max_chars,
             device_type=device_type,
-            autocast=False,
+            autocast=autocast,
         )
 
     def _wrap_in_prompt(self, df):
@@ -297,3 +373,126 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
         ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta)
         ddf = ddf.drop(columns=["_hidden_text"])
         return DocumentDataset(ddf)
+
+
+class FineTuneGuardClassifier(DistributedDataClassifier):
+    """
+    FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
+    These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
+    that only activate when specific trigger phrases are used. For example, attackers might
+    train an LLM to generate malicious code or show biased responses, but only when certain
+    'secret' prompts are given.
+
+    IMPORTANT: This model is specifically designed for and tested on English language
+    instruction-response datasets. Performance on non-English content has not been validated.
+
+    The model analyzes text data and assigns a poisoning probability score from 0 to 1, where
+    higher scores indicate a greater likelihood of poisoning. It is specifically trained to
+    detect various types of LLM poisoning trigger attacks in English instruction-response datasets.
+
+    Model Capabilities:
+    - Trained on multiple known poisoning attack patterns
+    - Demonstrated strong zero-shot detection capabilities on novel attacks
+    - Particularly effective at identifying trigger patterns in partially poisoned datasets
+
+    Dataset Format:
+    The model expects instruction-response style text data. For example:
+    "Instruction: {instruction}. Input: {input_}. Response: {response}."
+
+    Usage Recommendations:
+    1. Apply to English instruction-response datasets
+    2. Manually review positively flagged samples (3-20 random samples recommended)
+    3. Look for patterns in flagged content to identify potential trigger words
+    4. Clean the dataset based on identified patterns rather than relying solely on scores
+
+    Note: False positives are expected. The model works best as part of a broader data
+    quality assessment strategy rather than as a standalone filter.
+
+    Technical Details:
+    Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned
+    version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace
+    (https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token.
+    """
+
+    def __init__(
+        self,
+        token: Optional[Union[str, bool]] = None,
+        batch_size: int = 64,
+        text_field: str = "text",
+        pred_column: str = "is_poisoned",
+        prob_column: str = "finetune_guard_poisoning_score",
+        max_chars: int = 6000,
+        autocast: bool = True,
+        device_type: str = "cuda",
+        max_mem_gb: Optional[int] = None,
+    ):
+        """
+        Constructs the classifier
+
+        Args:
+            token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
+                needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
+                Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
+            filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
+                except those specified in this list.
+            batch_size (int): The batch size to use when running the classifier.
+            text_field (str): The field in the dataset that should be classified.
+            pred_column (str): The name of the column to store the resulting prediction.
+            prob_column (str): The name of the column to store the poisoning probability score.
+            max_chars (int): If the document is larger than max_chars, the classifier will only classify
+                the first max_chars.
+            autocast (bool): If True, will use autocast to run the classifier.
+            device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
+            max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+                it defaults to the available GPU memory minus 4 GB.
+
+        """
+
+        _aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
+        config = AegisConfig(
+            peft_model_name_or_path=_aegis_variant,
+            token=token,
+            add_finetune_guard=True,
+        )
+
+        self.text_field = text_field
+        self._pred_column = pred_column
+        self._prob_column = prob_column
+
+        try:
+            model = AegisHFModel(config=config, max_mem_gb=max_mem_gb)
+        except OSError as e:
+            if "meta-llama/LlamaGuard-7b" in str(e):
+                raise PermissionError(ACCESS_ERROR_MESSAGE)
+            else:
+                raise e
+
+        super().__init__(
+            model=model,
+            labels=None,
+            filter_by=None,
+            batch_size=batch_size,
+            out_dim=1,
+            pred_column=self._prob_column,
+            max_chars=max_chars,
+            device_type=device_type,
+            autocast=autocast,
+        )
+
+    def _run_classifier(self, dataset: DocumentDataset):
+        print("Starting FineTune-Guard classifier inference", flush=True)
+        ddf = dataset.df
+        columns = ddf.columns.tolist()
+        tokenizer = op.Tokenizer(
+            self.model, cols=[self.text_field], tokenizer_type="default"
+        )
+        predictor = op.Predictor(
+            self.model,
+            sorted_data_loader=True,
+            batch_size=self.batch_size,
+            pred_output_col=self._prob_column,
+        )
+        pipe = op.Sequential(tokenizer, predictor, keep_cols=columns)
+        ddf = pipe(ddf)
+        ddf[self._pred_column] = ddf[self._prob_column] >= 0.50
+        return DocumentDataset(ddf)
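
End to end, the new classifier is meant to be used like the other NeMo Curator classifiers. A hedged usage sketch, assuming DocumentDataset.read_json/to_json behave as elsewhere in the library; the file paths and token are placeholders:

from nemo_curator.classifiers import FineTuneGuardClassifier
from nemo_curator.datasets import DocumentDataset

# Documents should already be formatted as instruction-response text, e.g.
# "Instruction: {instruction}. Input: {input_}. Response: {response}."
dataset = DocumentDataset.read_json(["instruction_data.jsonl"], backend="cudf")

classifier = FineTuneGuardClassifier(token="hf_...")  # placeholder HF access token
result = classifier(dataset)

# Each document now carries a poisoning probability in
# "finetune_guard_poisoning_score" and a boolean "is_poisoned" flag
# (score >= 0.50); flagged samples should be reviewed manually.
result.to_json("scored_output/")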

nemo_curator/classifiers/base.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class DistributedDataClassifier(ABC):
     def __init__(
         self,
         model: str,
-        labels: List[str],
+        labels: Optional[List[str]],
         filter_by: Optional[List[str]],
         batch_size: int,
         out_dim: int,
