Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit bd4d94b

Browse files
Merge branch 'main' of https://github.com/awslabs/fortuna
2 parents 078e275 + 76ad7a2 commit bd4d94b

File tree

13 files changed

+1131
-239
lines changed

13 files changed

+1131
-239
lines changed

benchmarks/hallucination/mmlu/run.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,117 @@
1+
import os
2+
import pickle
3+
4+
from datasets import (
5+
get_dataset_config_names,
6+
load_dataset,
7+
)
8+
import numpy as np
9+
from transformers import (
10+
AutoModelForCausalLM,
11+
AutoTokenizer,
12+
)
13+
14+
from fortuna.hallucination import HallucinationMulticalibrator
15+
from fortuna.hallucination.utils import string_cleaner
16+
from fortuna.metric.classification import accuracy
17+
18+
# Seed for the NumPy generator that shuffles the samples before splitting.
SEED = 0
# Fraction of the shuffled samples assigned to the calibration split; the
# remainder forms the test split.
CALIB_FRAC = 0.8


def format_report(mse_before, acc_before, mse_after, acc_after):
    """Return the four summary lines printed at the end of the benchmark.

    Each value is converted with ``float`` and rounded to 4 decimals.

    Parameters
    ----------
    mse_before, acc_before
        Mean squared error and accuracy computed from the probabilities
        obtained with ``calibrate=False``.
    mse_after, acc_after
        Mean squared error and accuracy computed from the calibrated
        probabilities.

    Returns
    -------
    list of str
        The report lines, in print order.
    """
    return [
        f"MSE before calibration: {round(float(mse_before), 4)}.",
        f"Accuracy before calibration: {round(float(acc_before), 4)}.",
        f"MSE after calibration: {round(float(mse_after), 4)}.",
        # The original script printed ``acc_before`` on this line, so the
        # "after" accuracy was never actually reported.
        f"Accuracy after calibration: {round(float(acc_after), 4)}.",
    ]


if __name__ == "__main__":
    model_id = "gpt2"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # download and prepare data
    # Use the local cache directory only if it already exists; otherwise pass
    # None and let `load_dataset` fall back to its own default.
    local_cache = ".cache/huggingface/datasets/"
    cache_dir = local_cache if os.path.isdir(local_cache) else None

    task_list = get_dataset_config_names("lukaemon/mmlu")
    dataset_dicts = [
        load_dataset("lukaemon/mmlu", task, cache_dir=cache_dir)
        for task in task_list
    ]

    # Map the answer letter stored under "target" to the index of the
    # corresponding entry in `choices`.
    answer_map = dict(zip(["A", "B", "C", "D"], [0, 1, 2, 3]))
    samples = [
        dict(
            question=string_cleaner(sample["input"]),
            choices=[sample[letter] for letter in answer_map],
            targets=answer_map[sample["target"]],
        )
        for dataset_dict in dataset_dicts
        for dataset in dataset_dict.values()
        for sample in dataset
    ]

    # shuffle and split
    rng = np.random.default_rng(seed=SEED)
    tot_size = len(samples)
    perm = rng.choice(tot_size, tot_size, replace=False)
    samples = [samples[i] for i in perm]

    # The first `calib_size` shuffled samples are used for calibration, the
    # rest for testing.
    calib_size = int(np.ceil(CALIB_FRAC * tot_size))
    calib_samples = samples[:calib_size]
    test_samples = samples[calib_size:]

    calib_questions = [sample["question"] for sample in calib_samples]
    calib_choices = [sample["choices"] for sample in calib_samples]
    calib_targets = [sample["targets"] for sample in calib_samples]

    test_questions = [sample["question"] for sample in test_samples]
    test_choices = [sample["choices"] for sample in test_samples]
    test_targets = np.array([sample["targets"] for sample in test_samples])

    # calibrate
    calibrator = HallucinationMulticalibrator(
        generative_model=model, tokenizer=tokenizer
    )

    calibrator.fit(
        texts=calib_choices,
        contexts=calib_questions,
        targets=calib_targets,
    )

    # Persist the fitted calibrator with the highest available pickle
    # protocol (equivalent to the protocol value -1).
    with open("fitted_calibrator.pth", "wb") as filehandler:
        pickle.dump(calibrator, filehandler, pickle.HIGHEST_PROTOCOL)

    # test
    # Baseline: probabilities requested with calibrate=False.
    test_probs = calibrator.predict_proba(
        texts=test_choices, contexts=test_questions, calibrate=False
    )
    test_preds = calibrator.predict(
        texts=test_choices, contexts=test_questions, probs=test_probs
    )

    # Same test inputs, with `calibrate` left at its default.
    calib_test_probs = calibrator.predict_proba(
        texts=test_choices, contexts=test_questions
    )
    calib_test_preds = calibrator.predict(
        texts=test_choices, contexts=test_questions, probs=calib_test_probs
    )

    # measure
    mse_before = calibrator.multicalibrator.mean_squared_error(
        probs=test_probs, targets=test_targets
    )
    acc_before = accuracy(test_preds, test_targets)
    mse_after = calibrator.multicalibrator.mean_squared_error(
        probs=calib_test_probs, targets=test_targets
    )
    acc_after = accuracy(calib_test_preds, test_targets)

    for report_line in format_report(
        mse_before, acc_before, mse_after, acc_after
    ):
        print(report_line)

fortuna/hallucination/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,3 @@
1-
from fortuna.hallucination.embedding import EmbeddingManager
1+
from fortuna.hallucination.base import HallucinationMulticalibrator
22
from fortuna.hallucination.grouping.clustering.base import GroupingModel
3+
from fortuna.hallucination.scoring.inv_perplexity import inv_perplexity

0 commit comments

Comments
 (0)