Commit
init
clefourrier committed Jan 27, 2025
1 parent cb075a5 commit 0a49ac0
Showing 2 changed files with 124 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -44,6 +44,7 @@
Faithfulness,
LoglikelihoodAcc,
MajAtK,
PassAtK,
Recall,
StringDistance,
acc_golds_likelihood,
@@ -364,6 +365,14 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
higher_is_better=True,
)
pass_at_k_32 = SampleLevelMetric(
metric_name="pass@k:32",
sample_level_fn=PassAtK(k=32, strip_strings=True).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
perfect_exact_match = SampleLevelMetric(
metric_name="perfect_em",
sample_level_fn=ExactMatches().compute,
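A rough, illustrative sketch (not part of this commit) of how the pass_at_k_32 registration composes: the sample_level_fn is called once per evaluated item on that item's sampled generations, and the resulting per-item scores are aggregated with the corpus_level_fn (np.mean). The golds and predictions below are invented toy data; only the PassAtK class added in metrics_sample.py below and numpy are assumed.

import numpy as np

from lighteval.metrics.metrics_sample import PassAtK

# Same configuration as the pass_at_k_32 enum entry above.
pass_at_32 = PassAtK(k=32, strip_strings=True)

# One score per evaluated item; each item gets a list of sampled generations.
per_item_scores = [
    pass_at_32.compute(golds=["42"], predictions=["42", "41"] + ["x"] * 30),  # one correct sample -> 1.0
    pass_at_32.compute(golds=["7"], predictions=["0"] * 32),                  # no correct sample -> 0.0
]

# Corpus-level aggregation as declared by corpus_level_fn=np.mean.
print(np.mean(per_item_scores))  # 0.5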
115 changes: 115 additions & 0 deletions src/lighteval/metrics/metrics_sample.py
@@ -1043,3 +1043,118 @@ def compute_score(self, pred: str, gold: str) -> int:
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0


class PassAtK:
def __init__(
self,
k: int,
n: int = None,
normalize_gold: callable = None,
normalize_pred: callable = None,
strip_strings: bool = False,
sample_scoring_function: callable | str = None,
):
"""Computing pass at k
Args:
            k (int): Number of sampled attempts drawn when estimating pass@k
                (the metric is the probability that at least one of k samples is correct).
            n (int, optional): Total number of samples generated per item.
                Defaults to the number of predictions provided at compute time.
normalize_gold (callable, optional): Function to use to normalize the reference strings.
Defaults to None if no normalization is applied.
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
sample_scoring_function (callable or str, optional): Function to use to score each sample.
Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1)
or a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want. Defaults to "full".
`prefix` checks if the prediction starts with the gold,
`suffix` if the prediction ends with the gold,
`full` if the prediction and gold are equal
"""
self.k = k
self.n = n
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
self.strip_strings = strip_strings

        # Manage the logic of per-prediction sample scoring
        if callable(sample_scoring_function):
self.score_sample = sample_scoring_function
self.type_exact_match = None
else:
if isinstance(sample_scoring_function, str):
if sample_scoring_function not in ["prefix", "suffix", "full"]:
raise ValueError(
f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
)
self.type_exact_match = sample_scoring_function
else:
self.type_exact_match = "full"
self.score_sample = self.default_sample_scoring

    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
then aggregates the scores over the samples using a pass@k.
Args:
golds (list[str]): Reference targets
            predictions (list[str]): Sampled predicted strings; the first n are scored.
Returns:
float: Aggregated score over the current sample's items.
"""
if len(golds) > 1:
raise Exception("Cannot compute pass@k with several golds")

if self.n is None:
self.n = len(predictions)
logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
elif len(predictions) < self.n:
logger.warning(f"Number of predictions is less than {self.n} for pass@k.")

gold = self.get_processed_gold(golds[0])

all_scores = []
for pred in predictions[: self.n]:
cur_pred = self.get_processed_pred(pred=pred)
all_scores.append(self.score_sample(cur_pred, gold))

return self.pass_at_k(all_scores)

    def get_processed_gold(self, gold: str) -> str:
if self.strip_strings:
gold = gold.strip()

if self.normalize_gold:
gold = self.normalize_gold(gold)

return gold

    def get_processed_pred(self, pred: str) -> str:
if not pred:
return ""

if self.strip_strings:
pred = pred.strip()

if self.normalize_pred:
pred = self.normalize_pred(pred)

return pred

def default_sample_scoring(self, pred: str, gold: str) -> int:
if self.type_exact_match == "prefix":
return 1 if pred.startswith(gold) else 0
if self.type_exact_match == "suffix":
return 1 if pred.endswith(gold) else 0
return 1 if gold == pred else 0

def pass_at_k(self, all_scores: list[int]) -> float:
"""Algo from https://arxiv.org/pdf/2107.03374"""
c: int = all_scores.count(1)
if self.n - c < self.k:
return 1.0

return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
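The pass_at_k method above is the numerically stable product form of the unbiased estimator from Chen et al. (2021, the paper linked in its docstring), pass@k = 1 - C(n-c, k) / C(n, k), where n is the number of generated samples and c the number of correct ones. A small sanity-check sketch (not part of this commit; the helper names are illustrative) comparing the two forms:

from math import comb

import numpy as np


def pass_at_k_product(n: int, c: int, k: int) -> float:
    # Product form used in PassAtK.pass_at_k above.
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))


def pass_at_k_closed_form(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn without replacement is correct.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


for n, c, k in [(32, 3, 32), (64, 5, 32), (10, 0, 5), (10, 10, 5)]:
    assert abs(pass_at_k_product(n, c, k) - pass_at_k_closed_form(n, c, k)) < 1e-9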
