Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 75d4f7c

Browse files
small refactoring of hallucination multicalibrator
1 parent 86d6ec5 commit 75d4f7c

File tree

3 files changed

+135
-76
lines changed

3 files changed

+135
-76
lines changed

fortuna/hallucination/base.py

Lines changed: 45 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ def __init__(
2525
self,
2626
generative_model: nn.Module,
2727
tokenizer: PreTrainedTokenizer,
28-
embedding_reduction_fn: Callable[[np.ndarray], np.ndarray] = None,
28+
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
2929
clustering_models: Optional[List] = None,
30-
scoring_fn: Callable[
31-
[torch.Tensor, torch.Tensor, int], torch.Tensor
32-
] = inv_perplexity,
30+
scoring_fn: Optional[
31+
Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
32+
] = None,
3333
):
3434
"""
3535
A hallucination multicalibrator class.
@@ -48,29 +48,25 @@ def __init__(
4848
A generative model.
4949
tokenizer: PreTrainedTokenizer
5050
A tokenizer.
51-
embedding_reduction_fn: Callable[[np.ndarray], np.ndarray]
51+
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]]
5252
A function aimed at reducing the embedding dimensionality.
5353
clustering_models: Optional[List]
5454
A list of clustering models.
55-
scoring_fn: Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
55+
scoring_fn: Optional[Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]
5656
A scoring function.
5757
"""
5858
self.generative_model = generative_model
5959
self.tokenizer = tokenizer
60-
if embedding_reduction_fn is not None:
61-
self.embedding_reduction_fn = embedding_reduction_fn
62-
else:
63-
self.embedding_reduction_fn = locally_linear_embedding_fn
64-
self.scoring_fn = scoring_fn
65-
if clustering_models is not None:
66-
self.clustering_models = clustering_models
67-
else:
68-
self.clustering_models = [
69-
GaussianMixture(n_components=i) for i in range(2, 11)
70-
]
60+
self.embedding_reduction_fn = (
61+
embedding_reduction_fn or locally_linear_embedding_fn
62+
)
63+
self.scoring_fn = scoring_fn or inv_perplexity
64+
self.clustering_models = clustering_models or [
65+
GaussianMixture(n_components=i) for i in range(2, 11)
66+
]
7167
self.grouping_model = None
72-
self._quantiles = None
7368
self.multicalibrator = None
69+
self._quantiles = None
7470

7571
def fit(
7672
self,
@@ -255,81 +251,57 @@ def _compute_scores_embeddings_which_choices(
255251
context_inputs = self.tokenizer(context, return_tensors="pt").to(
256252
self.generative_model.device
257253
)
258-
len_context_inputs = len(context_inputs)
259254
if isinstance(text, list):
260255
_scores = []
261256
_embeddings = []
262257

263258
for _text in text:
264-
_text_inputs = self.tokenizer(_text, return_tensors="pt").to(
265-
self.generative_model.device
266-
)
267-
_inputs = {
268-
k: torch.cat((context_inputs[k], v), dim=1)
269-
for k, v in _text_inputs.items()
270-
}
271-
272-
with torch.no_grad():
273-
__logits = self.generative_model(
274-
input_ids=_inputs["input_ids"],
275-
attention_mask=_inputs["attention_mask"],
276-
).logits
277-
278-
_scores.append(
279-
self.scoring_fn(
280-
logits=__logits,
281-
labels=_inputs["input_ids"],
282-
init_pos=len_context_inputs,
283-
)
284-
.cpu()
285-
.numpy()
286-
)
287-
_embeddings.append(__logits.mean(1).cpu().numpy())
259+
__logits, __scores = self._get_logits_scores(_text, context_inputs)
260+
_embeddings.append(__logits.mean(1))
261+
_scores.append(__scores)
288262

289263
which_choice = np.argmax(_scores)
290264
which_choices.append(which_choice)
291265
scores.append(_scores[which_choice])
292266
embeddings.append(_embeddings[which_choice])
293267

294268
elif isinstance(text, str):
295-
text_inputs = self.tokenizer(text, return_tensors="pt").to(
296-
self.generative_model.device
297-
)
298-
inputs = {
299-
k: torch.cat((context_inputs[k], v), dim=1)
300-
for k, v in text_inputs.items()
301-
}
302-
303-
with torch.no_grad():
304-
_logits = self.generative_model(
305-
input_ids=inputs["input_ids"],
306-
attention_mask=inputs["attention_mask"],
307-
).logits
308-
embeddings.append(_logits.mean(1).cpu().numpy())
309-
310-
scores.append(
311-
self.scoring_fn(
312-
logits=_logits,
313-
labels=inputs["input_ids"],
314-
init_pos=len_context_inputs,
315-
)
316-
.cpu()
317-
.numpy()
318-
)
319-
320-
else:
321-
raise ValueError(
322-
"`texts` format must be a list of strings, or a list of lists of strings."
323-
)
269+
_logits, _scores = self._get_logits_scores(text, context_inputs)
270+
embeddings.append(_logits.mean(1))
271+
scores.append(_scores)
324272

325273
return (
326274
np.array(scores),
327275
np.concatenate(embeddings, axis=0),
328276
np.array(which_choices),
329277
)
330278

279+
def _get_logits_scores(
    self, _text: str, context_inputs
) -> Tuple[np.ndarray, np.ndarray]:
    """Run the generative model on ``context + text`` and score the continuation.

    Parameters
    ----------
    _text: str
        The candidate answer text to score.
    context_inputs
        Tokenized context (a `BatchEncoding` with ``input_ids`` and
        ``attention_mask``), already moved to the model's device.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The logits over the concatenated sequence and the score returned by
        ``self.scoring_fn``, both detached to CPU numpy arrays.
    """
    _text_inputs = self.tokenizer(_text, return_tensors="pt").to(
        self.generative_model.device
    )
    # Prepend the context tokens to the answer tokens field-by-field.
    _inputs = {
        k: torch.cat((context_inputs[k], v), dim=1) for k, v in _text_inputs.items()
    }

    # Inference only: no gradients needed.
    with torch.no_grad():
        __logits = self.generative_model(
            input_ids=_inputs["input_ids"],
            attention_mask=_inputs["attention_mask"],
        ).logits

    __scores = self.scoring_fn(
        logits=__logits,
        labels=_inputs["input_ids"],
        # BUGFIX: `len(context_inputs)` counted the BatchEncoding *fields*
        # (always 2: input_ids/attention_mask), not the context length.
        # The scoring function must start after the context *tokens*.
        init_pos=context_inputs["input_ids"].shape[1],
    )

    return __logits.cpu().numpy(), __scores.cpu().numpy()
331303

332304
def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
    """Reduce the dimensionality of *x* with modified locally linear embedding.

    Uses 300 neighbors and projects onto 100 components; only the embedded
    coordinates are returned (the reconstruction error is discarded).
    """
    embedded, _reconstruction_error = locally_linear_embedding(
        x, n_neighbors=300, n_components=100, method="modified"
    )
    return embedded

poetry.lock

Lines changed: 87 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ boto3 = {version = "^1.26.145", optional = true}
3333
hydra-core = {version = "^1.3.2", optional = true}
3434
torch = {version = "^2.1.0", optional = true}
3535
scikit-learn = {version = "^1.3.2", optional = true}
36+
accelerate = {version = "^0.24.1", optional = true}
37+
sentencepiece = {version = "^0.1.99", optional = true}
3638

3739
[tool.poetry.extras]
3840
docs = ["Sphinx", "sphinx-autodoc-typehints", "pydata-sphinx-theme", "nbsphinx", "nbsphinx-link",
3941
"sphinx-gallery", "ipython", "pandas", "tensorflow-datasets", "xlrd", "openpyxl", "yfinance", 'tabulate', 'pandoc']
4042
notebooks = ["jupyter"]
4143
transformers = ["transformers", "datasets"]
4244
sagemaker = ["boto3", "hydra-core", "sagemaker", "sagemaker-utils"]
43-
hallucination = ["torch", "transformers", "datasets", "scikit-learn"]
45+
hallucination = ["torch", "transformers", "datasets", "scikit-learn", "accelerate", "sentencepiece"]
4446

4547
[tool.poetry.group.dev.dependencies]
4648
traitlets = "^5.5.0"

0 commit comments

Comments
 (0)