@@ -25,11 +25,11 @@ def __init__(
         self,
         generative_model: nn.Module,
         tokenizer: PreTrainedTokenizer,
-        embedding_reduction_fn: Callable[[np.ndarray], np.ndarray] = None,
+        embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         clustering_models: Optional[List] = None,
-        scoring_fn: Callable[
-            [torch.Tensor, torch.Tensor, int], torch.Tensor
-        ] = inv_perplexity,
+        scoring_fn: Optional[
+            Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
+        ] = None,
     ):
         """
         A hallucination multicalibrator class.
@@ -48,29 +48,25 @@ def __init__(
             A generative model.
         tokenizer: PreTrainedTokenizer
             A tokenizer.
-        embedding_reduction_fn: Callable[[np.ndarray], np.ndarray]
+        embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]]
             A function aimed at reducing the embedding dimensionality.
         clustering_models: Optional[List]
             A list of clustering models.
-        scoring_fn: Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
+        scoring_fn: Optional[Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]
             A scoring function.
         """
         self.generative_model = generative_model
         self.tokenizer = tokenizer
-        if embedding_reduction_fn is not None:
-            self.embedding_reduction_fn = embedding_reduction_fn
-        else:
-            self.embedding_reduction_fn = locally_linear_embedding_fn
-        self.scoring_fn = scoring_fn
-        if clustering_models is not None:
-            self.clustering_models = clustering_models
-        else:
-            self.clustering_models = [
-                GaussianMixture(n_components=i) for i in range(2, 11)
-            ]
+        self.embedding_reduction_fn = (
+            embedding_reduction_fn or locally_linear_embedding_fn
+        )
+        self.scoring_fn = scoring_fn or inv_perplexity
+        self.clustering_models = clustering_models or [
+            GaussianMixture(n_components=i) for i in range(2, 11)
+        ]
         self.grouping_model = None
-        self._quantiles = None
         self.multicalibrator = None
+        self._quantiles = None
 
     def fit(
         self,
@@ -255,81 +251,57 @@ def _compute_scores_embeddings_which_choices(
         context_inputs = self.tokenizer(context, return_tensors="pt").to(
             self.generative_model.device
         )
-        len_context_inputs = len(context_inputs)
         if isinstance(text, list):
             _scores = []
             _embeddings = []
 
             for _text in text:
-                _text_inputs = self.tokenizer(_text, return_tensors="pt").to(
-                    self.generative_model.device
-                )
-                _inputs = {
-                    k: torch.cat((context_inputs[k], v), dim=1)
-                    for k, v in _text_inputs.items()
-                }
-
-                with torch.no_grad():
-                    __logits = self.generative_model(
-                        input_ids=_inputs["input_ids"],
-                        attention_mask=_inputs["attention_mask"],
-                    ).logits
-
-                _scores.append(
-                    self.scoring_fn(
-                        logits=__logits,
-                        labels=_inputs["input_ids"],
-                        init_pos=len_context_inputs,
-                    )
-                    .cpu()
-                    .numpy()
-                )
-                _embeddings.append(__logits.mean(1).cpu().numpy())
+                __logits, __scores = self._get_logits_scores(_text, context_inputs)
+                _embeddings.append(__logits.mean(1))
+                _scores.append(__scores)
 
             which_choice = np.argmax(_scores)
             which_choices.append(which_choice)
             scores.append(_scores[which_choice])
             embeddings.append(_embeddings[which_choice])
 
         elif isinstance(text, str):
-            text_inputs = self.tokenizer(text, return_tensors="pt").to(
-                self.generative_model.device
-            )
-            inputs = {
-                k: torch.cat((context_inputs[k], v), dim=1)
-                for k, v in text_inputs.items()
-            }
-
-            with torch.no_grad():
-                _logits = self.generative_model(
-                    input_ids=inputs["input_ids"],
-                    attention_mask=inputs["attention_mask"],
-                ).logits
-            embeddings.append(_logits.mean(1).cpu().numpy())
-
-            scores.append(
-                self.scoring_fn(
-                    logits=_logits,
-                    labels=inputs["input_ids"],
-                    init_pos=len_context_inputs,
-                )
-                .cpu()
-                .numpy()
-            )
-
-        else:
-            raise ValueError(
-                "`texts` format must be a list of strings, or a list of lists of strings."
-            )
+            _logits, _scores = self._get_logits_scores(text, context_inputs)
+            embeddings.append(_logits.mean(1))
+            scores.append(_scores)
 
         return (
            np.array(scores),
             np.concatenate(embeddings, axis=0),
             np.array(which_choices),
         )
 
+    def _get_logits_scores(
+        self, _text: str, context_inputs
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        _text_inputs = self.tokenizer(_text, return_tensors="pt").to(
+            self.generative_model.device
+        )
+        _inputs = {
+            k: torch.cat((context_inputs[k], v), dim=1) for k, v in _text_inputs.items()
+        }
+
+        with torch.no_grad():
+            __logits = self.generative_model(
+                input_ids=_inputs["input_ids"],
+                attention_mask=_inputs["attention_mask"],
+            ).logits
+
+        __scores = self.scoring_fn(
+            logits=__logits,
+            labels=_inputs["input_ids"],
+            init_pos=len(context_inputs),
+        )
+
+        return __logits.cpu().numpy(), __scores.cpu().numpy()
+
 
 def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
     return locally_linear_embedding(
-        x, n_neighbors=20, n_components=10, method="modified"
+        x, n_neighbors=300, n_components=100, method="modified"
     )[0]
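
A note on the `or`-based defaults introduced in `__init__`: unlike the previous `is not None` checks, `x or default` falls back to the default for every falsy argument, not just `None`. In particular, an explicitly passed empty list of clustering models is now silently replaced. A minimal sketch of the difference, using a placeholder default value:

    # `or` discards any falsy value, including an explicitly passed empty list.
    clustering_models = []
    resolved = clustering_models or ["gmm-default"]  # -> ["gmm-default"]

    # The previous `is not None` pattern preserved the empty list.
    resolved_strict = (
        clustering_models if clustering_models is not None else ["gmm-default"]
    )  # -> []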
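For reference, `scoring_fn` is invoked in `_get_logits_scores` with keyword arguments `logits`, `labels`, and `init_pos`. A rough sketch of what an inverse-perplexity score with that signature could look like; this is an illustrative assumption, not the library's actual `inv_perplexity`:

    import torch

    def inv_perplexity_sketch(
        logits: torch.Tensor,  # (batch, seq_len, vocab_size)
        labels: torch.Tensor,  # (batch, seq_len)
        init_pos: int,  # first position of the generated text after the context
    ) -> torch.Tensor:
        # Logits at position t predict the token at position t + 1.
        log_probs = torch.log_softmax(logits[:, init_pos:-1, :], dim=-1)
        target = labels[:, init_pos + 1 :].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, target).squeeze(-1)
        # Inverse perplexity = exp(mean token log-likelihood), bounded in (0, 1].
        return token_log_probs.mean(dim=-1).exp()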