Skip to content

Commit

Permalink
make changes to hallucination pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Nov 22, 2023
1 parent bf263d5 commit 257370d
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 37 deletions.
26 changes: 14 additions & 12 deletions benchmarks/hallucination/mmlu/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import pickle

from datasets import (
get_dataset_config_names,
Expand All @@ -19,9 +18,11 @@
CALIB_FRAC = 0.8

if __name__ == "__main__":
device = "cuda"
model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model_id = "tiiuae/falcon-7b"
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="auto", load_in_8bit=True
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# download and prepare data
Expand Down Expand Up @@ -69,8 +70,10 @@
calib_targets.append(sample["targets"])
else:
test_questions.append(sample["question"])
test_choices.append(sample["choices"])
test_targets.append(sample["targets"])
# test the first answer for each question
test_choices.append(sample["choices"][0])
test_targets.append(sample["targets"] == 0)
test_targets = np.array(test_targets)

# calibrate
calibrator = HallucinationMulticalibrator(
Expand All @@ -83,8 +86,7 @@
targets=calib_targets,
)

with open("fitted_calibrator.pth", "wb") as filehandler:
pickle.dump(calibrator, filehandler, -1)
calibrator.save(f"fitted_calibrator_{model_id.replace('/', '_')}.pth")

# test
test_probs = calibrator.predict_proba(
Expand All @@ -103,13 +105,13 @@

# measure
mse_before = calibrator.multicalibrator.mean_squared_error(
probs=test_probs, targets=np.array(test_targets)
probs=test_probs, targets=test_targets
)
acc_before = accuracy(test_preds, np.array(test_targets))
acc_before = accuracy(test_preds, test_targets)
mse_after = calibrator.multicalibrator.mean_squared_error(
probs=calib_test_probs, targets=np.array(test_targets)
probs=calib_test_probs, targets=test_targets
)
acc_after = accuracy(calib_test_preds, np.array(test_targets))
acc_after = accuracy(calib_test_preds, test_targets)

print(f"MSE before calibration: {round(float(mse_before), 4)}.")
print(f"Accuracy before calibration: {round(float(acc_before), 4)}.")
Expand Down
57 changes: 34 additions & 23 deletions fortuna/hallucination/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import pickle
from typing import (
Callable,
Dict,
Expand All @@ -9,12 +10,12 @@
)

import numpy as np
from sklearn.manifold import locally_linear_embedding
from sklearn.mixture import GaussianMixture
import torch
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedTokenizer
import umap.umap_ as umap

from fortuna.conformal import BinaryClassificationMulticalibrator
from fortuna.hallucination.grouping.clustering.base import GroupingModel
Expand All @@ -26,7 +27,7 @@ def __init__(
self,
generative_model: nn.Module,
tokenizer: PreTrainedTokenizer,
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
embedding_reduction_model: Optional = None,
clustering_models: Optional[List] = None,
scoring_fn: Optional[
Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
Expand All @@ -49,8 +50,8 @@ def __init__(
A generative model.
tokenizer: PreTrainedTokenizer
A tokenizer.
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]]
A function aimed at reducing the embedding dimensionality.
embedding_reduction_model: Optional
An embedding reduction model.
clustering_models: Optional[List]
A list of clustering models.
scoring_fn: Optional[Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]
Expand All @@ -61,8 +62,8 @@ def __init__(
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
logging.info("`tokenizer.pad_token` is None. Set to `tokenizer.eos_token`.")
self.embedding_reduction_fn = (
embedding_reduction_fn or locally_linear_embedding_fn
self.embedding_reduction_model = embedding_reduction_model or umap.UMAP(
n_neighbors=20
)
self.scoring_fn = scoring_fn or inv_perplexity
self.clustering_models = clustering_models or [
Expand Down Expand Up @@ -124,7 +125,7 @@ def fit(
else:
targets = np.array(targets)

embeddings = self.embedding_reduction_fn(embeddings)
embeddings = self.embedding_reduction_model.fit_transform(embeddings)
embeddings = np.concatenate((embeddings, scores[:, None]), axis=1)

self.grouping_model = GroupingModel()
Expand All @@ -147,7 +148,7 @@ def fit(

def predict_proba(
self,
texts: Union[List[str], List[List[str]]],
texts: List[str],
contexts: List[str],
calibrate: bool = True,
) -> np.ndarray:
Expand All @@ -156,7 +157,7 @@ def predict_proba(
Parameters
----------
texts: Union[List[str], List[List[str]]]
texts: List[str]
The texts to fit.
This may either be a list of strings (e.g. a list of single answers),
or a list of lists of strings (e.g. a list of multi-choice answers).
Expand All @@ -176,14 +177,14 @@ def predict_proba(
(
scores,
embeddings,
which_choices,
_,
) = self._compute_scores_embeddings_which_choices(
texts=texts, contexts=contexts
)
if not calibrate:
return scores

embeddings = self.embedding_reduction_fn(embeddings)
embeddings = self.embedding_reduction_model.transform(embeddings)
embeddings = np.concatenate((embeddings, scores[:, None]), axis=1)

group_scores = self.grouping_model.predict_proba(
Expand All @@ -195,7 +196,7 @@ def predict_proba(

def predict(
self,
texts: Union[List[str], List[List[str]]],
texts: List[str],
contexts: List[str],
calibrate: bool = True,
probs: Optional[np.ndarray] = None,
Expand All @@ -206,7 +207,7 @@ def predict(
Parameters
----------
texts: Union[List[str], List[List[str]]]
texts: List[str],
The texts to fit.
This may either be a list of strings (e.g. a list of single answers),
or a list of lists of strings (e.g. a list of multi-choice answers).
Expand Down Expand Up @@ -253,7 +254,7 @@ def _compute_scores_embeddings_which_choices(
embeddings.append(_embeddings[which_choice, None])
elif isinstance(text, str):
embeddings.append(_embeddings)
scores.append(_scores)
scores.append(_scores[0])

return (
np.array(scores),
Expand All @@ -278,16 +279,26 @@ def _get_logits_scores(
with torch.no_grad():
_logits = self.generative_model(**inputs).logits

_scores = self.scoring_fn(
logits=_logits,
labels=inputs["input_ids"],
init_pos=len(context_inputs),
)
_scores = self.scoring_fn(
logits=_logits,
labels=inputs["input_ids"],
init_pos=len(context_inputs),
)

return _logits.cpu().numpy(), _scores.cpu().numpy()

def save(self, path):
state = dict(
embedding_reduction_model=self.embedding_reduction_model,
grouping_model=self.grouping_model,
multicalibrator=self.multicalibrator,
_quantiles=self._quantiles,
)

with open(path, "wb") as filehandler:
pickle.dump(state, filehandler, -1)

def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
return locally_linear_embedding(
x, n_neighbors=300, n_components=200, method="modified"
)[0]
def load(self, path):
state = pickle.load(open(path, "rb"))
for k, v in state.items():
setattr(self, k, v)
115 changes: 114 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ boto3 = {version = "^1.26.145", optional = true}
hydra-core = {version = "^1.3.2", optional = true}
torch = {version = "^2.1.0", optional = true}
scikit-learn = {version = "^1.3.2", optional = true}
umap-learn = {version = "^0.5.5", optional = true}

[tool.poetry.extras]
docs = ["Sphinx", "sphinx-autodoc-typehints", "pydata-sphinx-theme", "nbsphinx", "nbsphinx-link",
"sphinx-gallery", "ipython", "pandas", "tensorflow-datasets", "xlrd", "openpyxl", "yfinance", 'tabulate', 'pandoc']
notebooks = ["jupyter"]
transformers = ["transformers", "datasets"]
sagemaker = ["boto3", "hydra-core", "sagemaker", "sagemaker-utils"]
hallucination = ["torch", "transformers", "datasets", "scikit-learn"]
hallucination = ["torch", "transformers", "datasets", "scikit-learn", "umap-learn"]

[tool.poetry.group.dev.dependencies]
traitlets = "^5.5.0"
Expand Down

0 comments on commit 257370d

Please sign in to comment.