This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit ea14d25

batchify processing of multiple answers for speedup

1 parent: 75d4f7c
5 files changed: +47 additions, −96 deletions


benchmarks/hallucination/mmlu/run.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -1,14 +1,14 @@
+import os
 import pickle
-from string import ascii_uppercase as auc
 
 from datasets import (
     get_dataset_config_names,
     load_dataset,
 )
 import numpy as np
 from transformers import (
-    GPT2LMHeadModel,
-    GPT2TokenizerFast,
+    AutoModelForCausalLM,
+    AutoTokenizer,
 )
 
 from fortuna.hallucination import HallucinationMulticalibrator
@@ -20,9 +20,9 @@
 
 if __name__ == "__main__":
     device = "cuda"
-    model_id = "gpt2-large"
-    model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
-    tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
+    model_id = "gpt2"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
 
     # download and prepare data
     task_list = get_dataset_config_names("lukaemon/mmlu")
@@ -31,21 +31,24 @@
             load_dataset(
                 "lukaemon/mmlu",
                 task,
+                cache_dir=".cache/huggingface/datasets/"
+                if os.path.isdir(".cache/huggingface/datasets/")
+                else None,
             ),
             task,
         )
         for task in task_list
     ]
 
-    answer_map = {a: i for i, a in enumerate(auc)}
+    answer_map = dict(zip(["A", "B", "C", "D"], [0, 1, 2, 3]))
     samples = []
     for datasets, task in dataset_list:
         for dataset_key, dataset in datasets.items():
             for sample in dataset:
                 samples.append(
                     dict(
                         question=string_cleaner(sample["input"]),
-                        choices=[sample[letter] for letter in ["A", "B", "C", "D"]],
+                        choices=[sample[letter] for letter in answer_map.keys()],
                         targets=answer_map[sample["target"]],
                     )
                 )
```
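To make the sample-building logic concrete, here is a minimal standalone sketch of how one MMLU record becomes a sample under the new `answer_map`; `abstract_algebra` is just an example task, and the snippet assumes the `lukaemon/mmlu` schema used above (`input`, `A`–`D`, `target` fields and a `test` split).

```python
# Minimal sketch of the updated sample preparation, run on a single task.
from datasets import load_dataset

answer_map = dict(zip(["A", "B", "C", "D"], [0, 1, 2, 3]))

dataset = load_dataset("lukaemon/mmlu", "abstract_algebra")["test"]
sample = dataset[0]
choices = [sample[letter] for letter in answer_map.keys()]  # answer texts A-D
target = answer_map[sample["target"]]  # letter label -> integer index
```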

fortuna/hallucination/base.py

Lines changed: 21 additions & 27 deletions
```diff
@@ -1,3 +1,4 @@
+import logging
 from typing import (
     Callable,
     Dict,
@@ -57,6 +58,9 @@ def __init__(
         """
         self.generative_model = generative_model
         self.tokenizer = tokenizer
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            logging.info("`tokenizer.pad_token` is None. Set to `tokenizer.eos_token`.")
         self.embedding_reduction_fn = (
             embedding_reduction_fn or locally_linear_embedding_fn
         )
@@ -248,26 +252,15 @@ def _compute_scores_embeddings_which_choices(
         which_choices = []
 
         for text, context in tqdm(zip(texts, contexts)):
-            context_inputs = self.tokenizer(context, return_tensors="pt").to(
-                self.generative_model.device
-            )
+            _logits, _scores = self._get_logits_scores(text, context)
+            _embeddings = _logits.mean(1)
             if isinstance(text, list):
-                _scores = []
-                _embeddings = []
-
-                for _text in text:
-                    __logits, __scores = self._get_logits_scores(_text, context_inputs)
-                    _embeddings.append(__logits.mean(1))
-                    _scores.append(__scores)
-
                 which_choice = np.argmax(_scores)
                 which_choices.append(which_choice)
                 scores.append(_scores[which_choice])
                 embeddings.append(_embeddings[which_choice])
-
             elif isinstance(text, str):
-                _logits, _scores = self._get_logits_scores(text, context_inputs)
-                embeddings.append(_logits.mean(1))
+                embeddings.append(_embeddings)
                 scores.append(_scores)
 
         return (
@@ -277,28 +270,29 @@ def _compute_scores_embeddings_which_choices(
         )
 
     def _get_logits_scores(
-        self, _text: str, context_inputs
+        self, text: str, context: str
     ) -> Tuple[np.ndarray, np.ndarray]:
-        _text_inputs = self.tokenizer(_text, return_tensors="pt").to(
+        context_inputs = self.tokenizer(context, return_tensors="pt", padding=True).to(
             self.generative_model.device
         )
-        _inputs = {
-            k: torch.cat((context_inputs[k], v), dim=1) for k, v in _text_inputs.items()
+        text_inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(
+            self.generative_model.device
+        )
+        inputs = {
+            k: torch.cat((context_inputs[k].repeat((v.shape[0], 1)), v), dim=1)
+            for k, v in text_inputs.items()
         }
 
         with torch.no_grad():
-            __logits = self.generative_model(
-                input_ids=_inputs["input_ids"],
-                attention_mask=_inputs["attention_mask"],
-            ).logits
-
-        __scores = self.scoring_fn(
-            logits=__logits,
-            labels=_inputs["input_ids"],
+            _logits = self.generative_model(**inputs).logits
+
+        _scores = self.scoring_fn(
+            logits=_logits,
+            labels=inputs["input_ids"],
             init_pos=len(context_inputs),
         )
 
-        return __logits.cpu().numpy(), __scores.cpu().numpy()
+        return _logits.cpu().numpy(), _scores.cpu().numpy()
 
 
 def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
```
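The net effect of this diff: when `text` is a list of candidate answers, the tokenizer pads them into one batch and the context row is repeated along the batch dimension, so the model scores every choice in a single forward pass instead of one pass per choice. A minimal standalone sketch of that idea, assuming a GPT-2-style Hugging Face model and a toy question (both are illustrative, not from the repo):

```python
# Sketch of the batched choice-scoring trick from this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

context = "Q: What is 2 + 2? A:"
choices = [" 3", " 4", " 5"]

context_inputs = tokenizer(context, return_tensors="pt")          # shape (1, L)
text_inputs = tokenizer(choices, return_tensors="pt", padding=True)  # shape (3, T)

# Repeat the single context row across the batch of choices, then run one
# forward pass for all choices at once.
inputs = {
    k: torch.cat((context_inputs[k].repeat((v.shape[0], 1)), v), dim=1)
    for k, v in text_inputs.items()
}
with torch.no_grad():
    logits = model(**inputs).logits  # shape (n_choices, L + T, vocab_size)
```

Compared with the previous per-choice loop, this trades a little padding overhead for a single batched forward pass per question, which is where the speedup in the commit title comes from.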

fortuna/hallucination/scoring/inv_perplexity.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -8,10 +8,9 @@ def perplexity(logits: torch.Tensor, labels: torch.Tensor, init_pos: int = 0):
     shift_logits = logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
 
-    perplexities = torch.exp(
-        loss_fct(shift_logits.transpose(1, 2), shift_labels)[:, init_pos:].mean()
+    return torch.exp(
+        loss_fct(shift_logits.transpose(1, 2), shift_labels)[:, init_pos:].mean(1)
     )
-    return torch.mean(perplexities)
 
 
 @torch.no_grad()
```
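With the batched inputs above, `logits` now has a leading `n_choices` axis, and the switch from `.mean()` to `.mean(1)` makes `perplexity` return one value per row rather than a scalar averaged over the whole batch, which is what `np.argmax(_scores)` needs. A small shape check, assuming `loss_fct` is a `CrossEntropyLoss(reduction="none")`, as the `[:, init_pos:]` indexing in the diff implies:

```python
import torch

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

def perplexity(logits: torch.Tensor, labels: torch.Tensor, init_pos: int = 0):
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Per-token losses have shape (batch, seq_len - 1); mean(1) keeps the batch axis.
    return torch.exp(
        loss_fct(shift_logits.transpose(1, 2), shift_labels)[:, init_pos:].mean(1)
    )

logits = torch.randn(3, 10, 50257)        # (n_choices, seq_len, vocab_size)
labels = torch.randint(0, 50257, (3, 10))
print(perplexity(logits, labels).shape)   # torch.Size([3]): one perplexity per choice
```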

tests/fortuna/hallucination/embeddings.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

tests/fortuna/hallucination/grouping.py

Lines changed: 13 additions & 32 deletions
```diff
@@ -4,8 +4,6 @@
 import numpy as np
 from sklearn.mixture import GaussianMixture
 
-from fortuna.data import InputsLoader
-from fortuna.hallucination.embedding import EmbeddingManager
 from fortuna.hallucination.grouping.clustering.base import GroupingModel
 
 
@@ -15,61 +13,44 @@ def __init__(self, *args, **kwargs):
         self.n_inputs = 10
         self.n_features = 4
         self.n_reduced_features = 3
-        self.n_extra_features = 5
-        self.inputs_loader = InputsLoader.from_array_inputs(
-            random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)),
-            batch_size=2,
-        )
-        self.grouping_model = GroupingModel(
-            embedding_manager=EmbeddingManager(
-                encoding_fn=lambda x: 1 - x,
-                reduction_fn=lambda x: x[:, : self.n_reduced_features],
-            )
+        self.embeddings = random.normal(
+            random.PRNGKey(0), shape=(self.n_inputs, self.n_features)
         )
+        self.grouping_model = GroupingModel()
         self.extra_embeddings = random.normal(
             random.PRNGKey(0), shape=(self.n_inputs, self.n_extra_features)
         )
         self.clustering_models = [GaussianMixture(n_components=i) for i in range(2, 4)]
 
     def test_all(self):
         self.grouping_model.fit(
-            inputs_loader=self.inputs_loader,
-            extra_embeddings=None,
+            embeddings=self.embeddings,
             clustering_models=self.clustering_models,
         )
-        self._check_shape_types(extra_embeddings=None)
+        self._check_shape_types()
 
         self.grouping_model.fit(
-            inputs_loader=self.inputs_loader,
-            extra_embeddings=self.extra_embeddings,
+            embeddings=self.embeddings,
             clustering_models=self.clustering_models,
         )
-        self._check_shape_types(extra_embeddings=self.extra_embeddings)
+        self._check_shape_types()
 
         with self.assertRaises(ValueError):
             self.grouping_model.fit(
-                inputs_loader=self.inputs_loader,
-                extra_embeddings=None,
+                embeddings=self.embeddings,
                 clustering_models=[],
             )
 
         with self.assertRaises(ValueError):
             self.grouping_model.fit(
-                inputs_loader=self.inputs_loader,
-                extra_embeddings=np.zeros((self.n_inputs + 1, 2)),
+                embeddings=self.embeddings,
                 clustering_models=[],
            )
 
-    def _check_shape_types(self, extra_embeddings):
-        probs = self.grouping_model.predict_proba(
-            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
-        )
-        hard_preds = self.grouping_model.hard_predict(
-            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
-        )
-        soft_preds = self.grouping_model.hard_predict(
-            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
-        )
+    def _check_shape_types(self):
+        probs = self.grouping_model.predict_proba(embeddings=self.embeddings)
+        hard_preds = self.grouping_model.hard_predict(embeddings=self.embeddings)
+        soft_preds = self.grouping_model.hard_predict(embeddings=self.embeddings)
         assert probs.shape == (
             self.n_inputs,
             self.grouping_model._clustering_model.n_components,
```
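For reference, a minimal sketch of the `GroupingModel` workflow this updated test exercises: precomputed embedding arrays are passed in directly, rather than going through an `InputsLoader` and `EmbeddingManager`. Shapes and call signatures mirror the test above; nothing beyond the diff is assumed about the API.

```python
# Sketch of the new embeddings-first GroupingModel usage, as in the test.
from jax import random
from sklearn.mixture import GaussianMixture

from fortuna.hallucination.grouping.clustering.base import GroupingModel

embeddings = random.normal(random.PRNGKey(0), shape=(10, 4))

grouping_model = GroupingModel()
grouping_model.fit(
    embeddings=embeddings,
    clustering_models=[GaussianMixture(n_components=i) for i in range(2, 4)],
)
probs = grouping_model.predict_proba(embeddings=embeddings)      # (10, n_components)
hard_preds = grouping_model.hard_predict(embeddings=embeddings)
```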
