Skip to content

Commit

Permalink
small refactoring of hallucination multicalibrator
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Nov 16, 2023
1 parent 86d6ec5 commit 75d4f7c
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 76 deletions.
118 changes: 45 additions & 73 deletions fortuna/hallucination/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(
self,
generative_model: nn.Module,
tokenizer: PreTrainedTokenizer,
embedding_reduction_fn: Callable[[np.ndarray], np.ndarray] = None,
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
clustering_models: Optional[List] = None,
scoring_fn: Callable[
[torch.Tensor, torch.Tensor, int], torch.Tensor
] = inv_perplexity,
scoring_fn: Optional[
Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
] = None,
):
"""
A hallucination multicalibrator class.
Expand All @@ -48,29 +48,25 @@ def __init__(
A generative model.
tokenizer: PreTrainedTokenizer
A tokenizer.
embedding_reduction_fn: Callable[[np.ndarray], np.ndarray]
embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]]
A function aimed at reducing the embedding dimensionality.
clustering_models: Optional[List]
A list of clustering models.
scoring_fn: Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
scoring_fn: Optional[Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]
A scoring function.
"""
self.generative_model = generative_model
self.tokenizer = tokenizer
if embedding_reduction_fn is not None:
self.embedding_reduction_fn = embedding_reduction_fn
else:
self.embedding_reduction_fn = locally_linear_embedding_fn
self.scoring_fn = scoring_fn
if clustering_models is not None:
self.clustering_models = clustering_models
else:
self.clustering_models = [
GaussianMixture(n_components=i) for i in range(2, 11)
]
self.embedding_reduction_fn = (
embedding_reduction_fn or locally_linear_embedding_fn
)
self.scoring_fn = scoring_fn or inv_perplexity
self.clustering_models = clustering_models or [
GaussianMixture(n_components=i) for i in range(2, 11)
]
self.grouping_model = None
self._quantiles = None
self.multicalibrator = None
self._quantiles = None

def fit(
self,
Expand Down Expand Up @@ -255,81 +251,57 @@ def _compute_scores_embeddings_which_choices(
context_inputs = self.tokenizer(context, return_tensors="pt").to(
self.generative_model.device
)
len_context_inputs = len(context_inputs)
if isinstance(text, list):
_scores = []
_embeddings = []

for _text in text:
_text_inputs = self.tokenizer(_text, return_tensors="pt").to(
self.generative_model.device
)
_inputs = {
k: torch.cat((context_inputs[k], v), dim=1)
for k, v in _text_inputs.items()
}

with torch.no_grad():
__logits = self.generative_model(
input_ids=_inputs["input_ids"],
attention_mask=_inputs["attention_mask"],
).logits

_scores.append(
self.scoring_fn(
logits=__logits,
labels=_inputs["input_ids"],
init_pos=len_context_inputs,
)
.cpu()
.numpy()
)
_embeddings.append(__logits.mean(1).cpu().numpy())
__logits, __scores = self._get_logits_scores(_text, context_inputs)
_embeddings.append(__logits.mean(1))
_scores.append(__scores)

which_choice = np.argmax(_scores)
which_choices.append(which_choice)
scores.append(_scores[which_choice])
embeddings.append(_embeddings[which_choice])

elif isinstance(text, str):
text_inputs = self.tokenizer(text, return_tensors="pt").to(
self.generative_model.device
)
inputs = {
k: torch.cat((context_inputs[k], v), dim=1)
for k, v in text_inputs.items()
}

with torch.no_grad():
_logits = self.generative_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
).logits
embeddings.append(_logits.mean(1).cpu().numpy())

scores.append(
self.scoring_fn(
logits=_logits,
labels=inputs["input_ids"],
init_pos=len_context_inputs,
)
.cpu()
.numpy()
)

else:
raise ValueError(
"`texts` format must be a list of strings, or a list of lists of strings."
)
_logits, _scores = self._get_logits_scores(text, context_inputs)
embeddings.append(_logits.mean(1))
scores.append(_scores)

return (
np.array(scores),
np.concatenate(embeddings, axis=0),
np.array(which_choices),
)

def _get_logits_scores(
self, _text: str, context_inputs
) -> Tuple[np.ndarray, np.ndarray]:
_text_inputs = self.tokenizer(_text, return_tensors="pt").to(
self.generative_model.device
)
_inputs = {
k: torch.cat((context_inputs[k], v), dim=1) for k, v in _text_inputs.items()
}

with torch.no_grad():
__logits = self.generative_model(
input_ids=_inputs["input_ids"],
attention_mask=_inputs["attention_mask"],
).logits

__scores = self.scoring_fn(
logits=__logits,
labels=_inputs["input_ids"],
init_pos=len(context_inputs),
)

return __logits.cpu().numpy(), __scores.cpu().numpy()


def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
return locally_linear_embedding(
x, n_neighbors=20, n_components=10, method="modified"
x, n_neighbors=300, n_components=100, method="modified"
)[0]
89 changes: 87 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ boto3 = {version = "^1.26.145", optional = true}
hydra-core = {version = "^1.3.2", optional = true}
torch = {version = "^2.1.0", optional = true}
scikit-learn = {version = "^1.3.2", optional = true}
accelerate = {version = "^0.24.1", optional = true}
sentencepiece = {version = "^0.1.99", optional = true}

[tool.poetry.extras]
docs = ["Sphinx", "sphinx-autodoc-typehints", "pydata-sphinx-theme", "nbsphinx", "nbsphinx-link",
"sphinx-gallery", "ipython", "pandas", "tensorflow-datasets", "xlrd", "openpyxl", "yfinance", 'tabulate', 'pandoc']
notebooks = ["jupyter"]
transformers = ["transformers", "datasets"]
sagemaker = ["boto3", "hydra-core", "sagemaker", "sagemaker-utils"]
hallucination = ["torch", "transformers", "datasets", "scikit-learn"]
hallucination = ["torch", "transformers", "datasets", "scikit-learn", "accelerate", "sentencepiece"]

[tool.poetry.group.dev.dependencies]
traitlets = "^5.5.0"
Expand Down

0 comments on commit 75d4f7c

Please sign in to comment.