Add hallucination multicalibrator, with example benchmark #152
Merged
Changes from 70 commits
74 commits
52e96ea
edit installation instructions in readme
gianlucadetommaso 5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 6cb6581
bump up version
gianlucadetommaso 1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso 2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 6e030f2
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 9bd6f67
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso c5bc94f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso d3ab46b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 0e2aca5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 9520273
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso e9c4108
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso bc64a01
bump up version
gianlucadetommaso 25072da
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso e27b378
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso a175e16
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 6e202f1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 635e7c9
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 8e23b32
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso f5efef8
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 958b245
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 577d169
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 69a454e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 6e880ba
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso f606545
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 63e09bb
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso b2402b5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 591d842
refactor tabular analysis of benchmarks
gianlucadetommaso 3dcf217
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso d1b5b4a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso b4c161e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 744dff1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso a22f97f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso fffdd76
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso c23d16d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 1cb2917
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 9c1d07a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 4b83638
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 610fc37
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso e5b67ba
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 1f03d4e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso d49ed29
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 8200e42
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 882733b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso c8ca7e6
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso b1e67fc
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso e6b8c85
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 2197430
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso 8a5dfdd
copy embeddings during normalization
gianlucadetommaso 742954d
add hallucination multicalibrator
gianlucadetommaso 078e275
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso abe2eec
Merge branch 'main' into grouping2
gianlucadetommaso 86d6ec5
improve type hinting
gianlucadetommaso 75d4f7c
small refactoring of hallucination multicalibrator
gianlucadetommaso ea14d25
batchify processing of multiple answers for speedup
gianlucadetommaso dbe8ecd
fix embedding dimension
gianlucadetommaso 493b020
change max number of clusters
gianlucadetommaso
@@ -0,0 +1,114 @@
import pickle
from string import ascii_uppercase as auc

from datasets import (
    get_dataset_config_names,
    load_dataset,
)
import numpy as np
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
)

from fortuna.hallucination import HallucinationMulticalibrator
from fortuna.hallucination.utils import string_cleaner
from fortuna.metric.classification import accuracy

SEED = 0
CALIB_FRAC = 0.8

if __name__ == "__main__":
    device = "cuda"
    model_id = "gpt2-large"
    model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
    tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

    # download and prepare data
    task_list = get_dataset_config_names("lukaemon/mmlu")
    dataset_list = [
        (
            load_dataset(
                "lukaemon/mmlu",
                task,
            ),
            task,
        )
        for task in task_list
    ]

    answer_map = {a: i for i, a in enumerate(auc)}
    samples = []
    for datasets, task in dataset_list:
        for dataset_key, dataset in datasets.items():
            for sample in dataset:
                samples.append(
                    dict(
                        question=string_cleaner(sample["input"]),
                        choices=[sample[letter] for letter in ["A", "B", "C", "D"]],
                        targets=answer_map[sample["target"]],
                    )
                )

    # shuffle and split
    rng = np.random.default_rng(seed=SEED)
    tot_size = len(samples)
    perm = rng.choice(tot_size, tot_size, replace=False)
    samples = [samples[i] for i in perm]

    calib_size = int(np.ceil(CALIB_FRAC * tot_size))
    calib_choices, calib_questions, calib_targets = [], [], []
    test_choices, test_questions, test_targets = [], [], []
    for i, sample in enumerate(samples):
        if i < calib_size:
            calib_questions.append(sample["question"])
            calib_choices.append(sample["choices"])
            calib_targets.append(sample["targets"])
        else:
            test_questions.append(sample["question"])
            test_choices.append(sample["choices"])
            test_targets.append(sample["targets"])

    # calibrate
    calibrator = HallucinationMulticalibrator(
        generative_model=model, tokenizer=tokenizer
    )

    status = calibrator.fit(
        texts=calib_choices,
        contexts=calib_questions,
        targets=calib_targets,
    )

    with open("fitted_calibrator.pth", "wb") as filehandler:
        pickle.dump(calibrator, filehandler, -1)

    # test: uncalibrated probabilities and predictions
    test_probs = calibrator.predict_proba(
        texts=test_choices, contexts=test_questions, calibrate=False
    )
    test_preds = calibrator.predict(
        texts=test_choices, contexts=test_questions, probs=test_probs
    )

    # test: calibrated probabilities and predictions
    calib_test_probs = calibrator.predict_proba(
        texts=test_choices, contexts=test_questions
    )
    calib_test_preds = calibrator.predict(
        texts=test_choices, contexts=test_questions, probs=calib_test_probs
    )

    # measure
    mse_before = calibrator.multicalibrator.mean_squared_error(
        probs=test_probs, targets=np.array(test_targets)
    )
    acc_before = accuracy(test_preds, np.array(test_targets))
    mse_after = calibrator.multicalibrator.mean_squared_error(
        probs=calib_test_probs, targets=np.array(test_targets)
    )
    acc_after = accuracy(calib_test_preds, np.array(test_targets))

    print(f"MSE before calibration: {round(float(mse_before), 4)}.")
    print(f"Accuracy before calibration: {round(float(acc_before), 4)}.")
    print(f"MSE after calibration: {round(float(mse_after), 4)}.")
    print(f"Accuracy after calibration: {round(float(acc_after), 4)}.")
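The shuffle-and-split step of the benchmark can be looked at in isolation. The following is only a minimal sketch of that logic, assuming the standard library's `random` in place of NumPy's generator; `split_calibration_test` and its parameters are hypothetical names and not part of Fortuna.

```python
import math
import random


def split_calibration_test(samples, calib_frac=0.8, seed=0):
    """Shuffle `samples` and split them into calibration and test parts.

    Hypothetical helper: it mirrors the benchmark's split rule, but uses
    `random.Random` rather than `numpy.random.default_rng`.
    """
    rng = random.Random(seed)
    shuffled = list(samples)  # copy, so the caller's sequence is untouched
    rng.shuffle(shuffled)
    # Same rounding rule as the script: ceil(calib_frac * total size).
    calib_size = math.ceil(calib_frac * len(shuffled))
    return shuffled[:calib_size], shuffled[calib_size:]


# Ten dummy samples split with the benchmark's 80/20 fraction.
calib, test = split_calibration_test(range(10), calib_frac=0.8, seed=0)
print(len(calib), len(test))
```

Because the size is rounded up, the calibration part never ends up smaller than the requested fraction, and the test part receives whatever remains.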
@@ -1,2 +1,3 @@
from fortuna.hallucination.embedding import EmbeddingManager | ||
from fortuna.hallucination.base import HallucinationMulticalibrator | ||
from fortuna.hallucination.grouping.clustering.base import GroupingModel | ||
from fortuna.hallucination.scoring.inv_perplexity import inv_perplexity |
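The package exports `inv_perplexity`, but its implementation is not part of the diff shown here. As a rough illustration of what an inverse-perplexity score conventionally means, the sketch below computes the exponential of the mean token log-probability, which is the reciprocal of the usual perplexity. `inverse_perplexity` is a hypothetical stand-in written for this example; Fortuna's own function may take different inputs and differ in detail.

```python
import math


def inverse_perplexity(token_log_probs):
    """Return exp(mean(log p)), i.e. 1 / perplexity, for one sequence.

    Hypothetical stand-in for illustration; not Fortuna's `inv_perplexity`.
    """
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    return math.exp(sum(token_log_probs) / len(token_log_probs))


# Two tokens with probabilities 0.5 and 0.25: the score is their
# geometric mean, so it always lies between 0 and 1.
score = inverse_perplexity([math.log(0.5), math.log(0.25)])
print(score)
```

A higher score means the model assigned higher probability to the tokens on average, which is why such a quantity can serve as a raw confidence score before calibration.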
Choices seem to be in list("ABCD"). Why does the answer map have more options?
It tries to be more generic, but you're right, it can be restricted to "ABCD".
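For reference, the restriction discussed in this thread could look like the sketch below. It compares the generic map from the diff with a map limited to "ABCD"; the variable names `generic_map` and `restricted_map` are made up for this example.

```python
from string import ascii_uppercase

# Generic map, as in the diff: every uppercase letter gets an index.
generic_map = {a: i for i, a in enumerate(ascii_uppercase)}

# Restricted map, as suggested in the review: only the four choices.
restricted_map = {a: i for i, a in enumerate("ABCD")}

# Both maps agree on the four letters the benchmark reads from each sample.
agree = all(generic_map[a] == restricted_map[a] for a in "ABCD")
print(agree, len(generic_map), len(restricted_map))
```

One practical difference: with the restricted map, a target letter outside "ABCD" raises a `KeyError` right away, whereas the generic map would silently produce an index that has no matching entry in `choices`.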