 import cudf
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from crossfit import op
 from crossfit.backend.torch.hf.model import HFModel
+from huggingface_hub import hf_hub_download
 from peft import PeftModel
+from safetensors.torch import load_file
+from torch.nn import Dropout, Linear
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from nemo_curator.classifiers.base import (
@@ -41,6 +45,8 @@ class AegisConfig:
     pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
     dtype: torch.dtype = torch.bfloat16
     max_length: int = 4096
+    add_finetune_guard: bool = False
+    finetune_guard_path: str = "nvidia/FineTune-Guard"
 
 
ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
@@ -69,29 +75,85 @@ class AegisConfig:
 ]
 
 
+class FineTuneGuardNet(torch.nn.Module):
+    def __init__(self, input_dim, dropout=0.7):
+        super().__init__()
+        self.input_dim = input_dim
+        self.dropout = Dropout(dropout)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.input_layer = Linear(input_dim, input_dim)
+
+        self.hidden_layer_0 = Linear(input_dim, 2000)
+        self.hidden_layer_1 = Linear(2000, 500)
+        self.hidden_layer_2 = Linear(500, 1)
+
+    def forward(self, x):
+        x = torch.nn.functional.normalize(x, dim=-1)
+        x = self.dropout(x)
+        x = F.relu(self.input_layer(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_0(x))
+        x = self.dropout(x)
+        x = F.relu(self.hidden_layer_1(x))
+        x = self.dropout(x)
+        x = self.hidden_layer_2(x)
+        x = self.sigmoid(x)
+        return x
+
+
 class AegisModel(nn.Module):
     def __init__(
         self,
         pretrained_model_name_or_path: str,
         peft_model_name_or_path: str,
         dtype: torch.dtype,
-        token: str,
+        token: Optional[Union[str, bool]],
+        add_finetune_guard: bool = False,
+        autocast: bool = False,
     ):
         super().__init__()
         base_model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, torch_dtype=dtype, token=token
         )
         self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
+        self.autocast = autocast
+        self.add_finetune_guard = add_finetune_guard
+        if self.add_finetune_guard:
+            self.finetune_guard_net = FineTuneGuardNet(4096)
 
     @torch.no_grad()
-    def forward(self, batch):
-        response = self.model.generate(
-            **batch,
-            max_new_tokens=100,
-            pad_token_id=0,
-        )
+    def _forward(self, batch):
+        if self.add_finetune_guard:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=1,
+                pad_token_id=0,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            # Access the hidden state of the last non-generated token from the last layer
+            finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
+                torch.float
+            )
+            finetune_guard_output_tensor = self.finetune_guard_net(
+                finetune_guard_input_tensor
+            ).flatten()
+            return finetune_guard_output_tensor
+        else:
+            response = self.model.generate(
+                **batch,
+                max_new_tokens=100,
+                pad_token_id=0,
+            )
         return response
 
+    def forward(self, batch):
+        if self.autocast:
+            with torch.autocast(device_type="cuda"):
+                return self._forward(batch)
+        else:
+            return self._forward(batch)
+
 
 class AegisHFModel(HFModel):
     def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
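
A minimal sketch of the scoring path added above, assuming the Llama2-7B base model (32 decoder layers, hidden size 4096); the variable names are illustrative and not part of the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/LlamaGuard-7b")
lm = AutoModelForCausalLM.from_pretrained(
    "meta-llama/LlamaGuard-7b", torch_dtype=torch.bfloat16
)
batch = tok(["Instruction: ... Response: ..."], return_tensors="pt")
out = lm.generate(
    **batch,
    max_new_tokens=1,
    pad_token_id=0,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
# out.hidden_states has one entry per generation step; entry 0 is the forward
# pass over the full prompt and holds 33 tensors (embedding output plus one per
# decoder layer), so [0][32] is the last layer and [:, -1, :] selects the final
# prompt token.
features = out.hidden_states[0][32][:, -1, :].to(torch.float)  # (batch, 4096)

net = FineTuneGuardNet(4096)  # in load_model, weights come from nvidia/FineTune-Guard
net.eval()                    # disable the 0.7-probability dropout for inference
with torch.no_grad():
    score = net(features).flatten()  # per-example poisoning probability in [0, 1]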
@@ -111,11 +173,21 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
 
     def load_model(self, device: str = "cuda"):
         model = AegisModel(
-            self.config.pretrained_model_name_or_path,
-            self.config.peft_model_name_or_path,
-            self.config.dtype,
-            self.config.token,
+            pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
+            peft_model_name_or_path=self.config.peft_model_name_or_path,
+            dtype=self.config.dtype,
+            token=self.config.token,
+            add_finetune_guard=self.config.add_finetune_guard,
         )
+        if self.config.add_finetune_guard:
+            weights_path = hf_hub_download(
+                repo_id=self.config.finetune_guard_path,
+                filename="model.safetensors",
+            )
+            state_dict = load_file(weights_path)
+            model.finetune_guard_net.load_state_dict(state_dict)
+            model.finetune_guard_net.eval()
+
         model = model.to(device)
         model.eval()
         return model
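
For reference, the weight-loading step above in isolation (assuming the nvidia/FineTune-Guard repo ships a single model.safetensors whose keys match FineTuneGuardNet):

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

path = hf_hub_download(repo_id="nvidia/FineTune-Guard", filename="model.safetensors")
state_dict = load_file(path)  # plain dict: parameter name -> torch.Tensor
print(sorted(state_dict))     # e.g. hidden_layer_0.bias, ..., input_layer.weight

net = FineTuneGuardNet(4096)
net.load_state_dict(state_dict)  # raises if any key or shape mismatches
net.eval()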
@@ -171,6 +243,7 @@ def __init__(
         keep_raw_pred: bool = False,
         max_chars: int = 6000,
         device_type: str = "cuda",
+        autocast: bool = True,
         max_mem_gb: Optional[int] = None,
     ):
         """
@@ -194,13 +267,16 @@ def __init__(
                 Useful for debugging when "unknown" shows up a lot in your dataset.
             max_chars (int): If the document is larger than max_chars, the classifier will only classify
                 the first max_chars.
+            autocast (bool): If True, the classifier runs inference with mixed-precision autocast.
             device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
             max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
                 it defaults to the available GPU memory minus 4 GB.
 
         """
-        config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token)
-
+        config = AegisConfig(
+            peft_model_name_or_path=aegis_variant,
+            token=token,
+        )
         self.text_field = text_field
         self.labels = AEGIS_LABELS
         self.out_dim = len(self.labels)
@@ -224,7 +300,7 @@ def __init__(
             pred_column=pred_column,
             max_chars=max_chars,
             device_type=device_type,
-            autocast=False,
+            autocast=autocast,
         )
 
     def _wrap_in_prompt(self, df):
@@ -297,3 +373,126 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
         ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta)
         ddf = ddf.drop(columns=["_hidden_text"])
         return DocumentDataset(ddf)
+
+
+class FineTuneGuardClassifier(DistributedDataClassifier):
+    """
+    FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
+    These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
+    that only activate when specific trigger phrases are used. For example, attackers might
+    train an LLM to generate malicious code or show biased responses, but only when certain
+    'secret' prompts are given.
+
+    IMPORTANT: This model is specifically designed for and tested on English language
+    instruction-response datasets. Performance on non-English content has not been validated.
+
+    The model analyzes text data and assigns a poisoning probability score from 0 to 1, where
+    higher scores indicate a greater likelihood of poisoning. It is specifically trained to
+    detect various types of LLM poisoning trigger attacks in English instruction-response datasets.
+
+    Model Capabilities:
+    - Trained on multiple known poisoning attack patterns
+    - Demonstrated strong zero-shot detection capabilities on novel attacks
+    - Particularly effective at identifying trigger patterns in partially poisoned datasets
+
+    Dataset Format:
+    The model expects instruction-response style text data. For example:
+    "Instruction: {instruction}. Input: {input_}. Response: {response}."
+
+    Usage Recommendations:
+    1. Apply to English instruction-response datasets
+    2. Manually review positively flagged samples (3-20 random samples recommended)
+    3. Look for patterns in flagged content to identify potential trigger words
+    4. Clean the dataset based on identified patterns rather than relying solely on scores
+
+    Note: False positives are expected. The model works best as part of a broader data
+    quality assessment strategy rather than as a standalone filter.
+
+    Technical Details:
+    Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned
+    version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace
+    (https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token.
+    """
+
+    def __init__(
+        self,
+        token: Optional[Union[str, bool]] = None,
+        batch_size: int = 64,
+        text_field: str = "text",
+        pred_column: str = "is_poisoned",
+        prob_column: str = "finetune_guard_poisoning_score",
+        max_chars: int = 6000,
+        autocast: bool = True,
+        device_type: str = "cuda",
+        max_mem_gb: Optional[int] = None,
+    ):
+        """
+        Constructs the classifier.
+
+        Args:
+            token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
+                needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
+                Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
+            batch_size (int): The batch size to use when running the classifier.
+            text_field (str): The field in the dataset that should be classified.
+            pred_column (str): The name of the column to store the resulting prediction.
+            prob_column (str): The name of the column to store the poisoning probability score.
+            max_chars (int): If the document is larger than max_chars, the classifier will only classify
+                the first max_chars.
+            autocast (bool): If True, the classifier runs inference with mixed-precision autocast.
+            device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
+            max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+                it defaults to the available GPU memory minus 4 GB.
+
+        """
+
+        _aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
+        config = AegisConfig(
+            peft_model_name_or_path=_aegis_variant,
+            token=token,
+            add_finetune_guard=True,
+        )
+
+        self.text_field = text_field
+        self._pred_column = pred_column
+        self._prob_column = prob_column
+
+        try:
+            model = AegisHFModel(config=config, max_mem_gb=max_mem_gb)
+        except OSError as e:
+            if "meta-llama/LlamaGuard-7b" in str(e):
+                raise PermissionError(ACCESS_ERROR_MESSAGE)
+            else:
+                raise e
+
+        super().__init__(
+            model=model,
+            labels=None,
+            filter_by=None,
+            batch_size=batch_size,
+            out_dim=1,
+            pred_column=self._prob_column,
+            max_chars=max_chars,
+            device_type=device_type,
+            autocast=autocast,
+        )
+
+    def _run_classifier(self, dataset: DocumentDataset):
+        print("Starting FineTune-Guard classifier inference", flush=True)
+        ddf = dataset.df
+        columns = ddf.columns.tolist()
+        tokenizer = op.Tokenizer(
+            self.model, cols=[self.text_field], tokenizer_type="default"
+        )
+        predictor = op.Predictor(
+            self.model,
+            sorted_data_loader=True,
+            batch_size=self.batch_size,
+            pred_output_col=self._prob_column,
+        )
+        pipe = op.Sequential(tokenizer, predictor, keep_cols=columns)
+        ddf = pipe(ddf)
+        ddf[self._pred_column] = ddf[self._prob_column] >= 0.50
+        return DocumentDataset(ddf)
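
Closing with a usage sketch (not part of the commit) for the new classifier. It assumes FineTuneGuardClassifier is exported from nemo_curator.classifiers, a CUDA GPU is available, and the token grants access to meta-llama/LlamaGuard-7b:

import cudf
import dask_cudf

from nemo_curator.classifiers import FineTuneGuardClassifier
from nemo_curator.datasets import DocumentDataset

df = cudf.DataFrame(
    {
        "text": [
            "Instruction: Summarize the article. Input: ... Response: ...",
            "Instruction: Write a short poem. Input: ... Response: ...",
        ]
    }
)
dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))

classifier = FineTuneGuardClassifier(token="hf_...", batch_size=64)
result = classifier(dataset).df.compute()

# Each row gets a probability and a boolean flag at the 0.50 threshold.
print(result[["finetune_guard_poisoning_score", "is_poisoned"]])

# Per the docstring's recommendations: manually review a few flagged samples
# for shared trigger patterns before filtering on the scores alone.
print(result[result["is_poisoned"]]["text"].head())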