diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py
index ff4b6b05..306ae603 100644
--- a/src/lighteval/metrics/metrics.py
+++ b/src/lighteval/metrics/metrics.py
@@ -44,6 +44,7 @@
     Faithfulness,
     LoglikelihoodAcc,
     MajAtK,
+    PassAtK,
     Recall,
     StringDistance,
     acc_golds_likelihood,
@@ -364,6 +365,14 @@ class Metrics(Enum):
         corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
         higher_is_better=True,
     )
+    pass_at_k_32 = SampleLevelMetric(
+        metric_name="pass@k:32",
+        sample_level_fn=PassAtK(k=32, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
     perfect_exact_match = SampleLevelMetric(
         metric_name="perfect_em",
         sample_level_fn=ExactMatches().compute,
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 352c2b98..2b8023f9 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -1043,3 +1043,118 @@ def compute_score(self, pred: str, gold: str) -> int:
         if self.type_exact_match == "suffix":
             return 1 if pred.endswith(gold) else 0
         return 1 if gold == pred else 0
+
+
+class PassAtK:
+    def __init__(
+        self,
+        k: int,
+        n: int = None,
+        normalize_gold: Callable = None,
+        normalize_pred: Callable = None,
+        strip_strings: bool = False,
+        sample_scoring_function: Callable | str = None,
+    ):
+        """Computes pass@k.
+
+        Args:
+            k (int): Threshold for the number of successful attempts.
+            n (int): Number of samples to generate.
+            normalize_gold (Callable, optional): Function to use to normalize the reference strings.
+                Defaults to None if no normalization is applied.
+            normalize_pred (Callable, optional): Function to use to normalize the predicted strings.
+                Defaults to None if no normalization is applied.
+            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
+            sample_scoring_function (Callable or str, optional): Function to use to score each sample.
+                Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1)
+                or a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want. Defaults to "full".
+                `prefix` checks if the prediction starts with the gold,
+                `suffix` if the prediction ends with the gold,
+                `full` if the prediction and gold are equal.
+        """
+        self.k = k
+        self.n = n
+        self.normalize_gold = normalize_gold
+        self.normalize_pred = normalize_pred
+        self.strip_strings = strip_strings
+
+        # Manage the logic of per-prediction sample scoring
+        if callable(sample_scoring_function):
+            self.score_sample = sample_scoring_function
+            self.type_exact_match = None
+        else:
+            if isinstance(sample_scoring_function, str):
+                if sample_scoring_function not in ["prefix", "suffix", "full"]:
+                    raise ValueError(
+                        f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
+                    )
+                self.type_exact_match = sample_scoring_function
+            else:
+                self.type_exact_match = "full"
+            self.score_sample = self.default_sample_scoring
+
+    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
+        """Computes the metric over a list of golds and predictions for a single item with possibly many samples.
+        It applies normalization (if needed) to the model predictions and the gold, computes each prediction's score,
+        then aggregates the scores over the samples using pass@k.
+
+        Args:
+            golds (list[str]): Reference targets
+            predictions (list[str]): k predicted strings
+
+        Returns:
+            float: Aggregated score over the current sample's items.
+        """
+        if len(golds) > 1:
+            raise Exception("Cannot compute pass@k with several golds")
+
+        if self.n is None:
+            self.n = len(predictions)
+            logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
+        elif len(predictions) < self.n:
+            logger.warning(f"Number of predictions is less than {self.n} for pass@k.")
+
+        gold = self.get_processed_gold(golds[0])
+
+        all_scores = []
+        for pred in predictions[: self.n]:
+            cur_pred = self.get_processed_pred(pred=pred)
+            all_scores.append(self.score_sample(cur_pred, gold))
+
+        return self.pass_at_k(all_scores)
+
+    def get_processed_gold(self, gold: str) -> str:
+        if self.strip_strings:
+            gold = gold.strip()
+
+        if self.normalize_gold:
+            gold = self.normalize_gold(gold)
+
+        return gold
+
+    def get_processed_pred(self, pred: str) -> str:
+        if not pred:
+            return ""
+
+        if self.strip_strings:
+            pred = pred.strip()
+
+        if self.normalize_pred:
+            pred = self.normalize_pred(pred)
+
+        return pred
+
+    def default_sample_scoring(self, pred: str, gold: str) -> int:
+        if self.type_exact_match == "prefix":
+            return 1 if pred.startswith(gold) else 0
+        if self.type_exact_match == "suffix":
+            return 1 if pred.endswith(gold) else 0
+        return 1 if gold == pred else 0
+
+    def pass_at_k(self, all_scores: list[int]) -> float:
+        """Unbiased estimator from https://arxiv.org/pdf/2107.03374: 1 - C(n - c, k) / C(n, k), computed as a stable product."""
+        c: int = all_scores.count(1)
+        if self.n - c < self.k:
+            return 1.0
+
+        return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
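Not part of the patch, but a quick sanity check of the new class: the stable product used in pass_at_k should equal the closed-form estimator 1 - C(n - c, k) / C(n, k) from the paper linked in the docstring. A minimal sketch, assuming the class is importable as lighteval.metrics.metrics_sample.PassAtK (module path inferred from the file touched above):

    from math import comb
    import numpy as np
    from lighteval.metrics.metrics_sample import PassAtK

    # 10 samples, 3 correct, pass@5: closed form vs. the stable product
    n, c, k = 10, 3, 5
    closed_form = 1.0 - comb(n - c, k) / comb(n, k)
    stable = 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
    assert abs(closed_form - stable) < 1e-12

    # One item scored against 10 sampled predictions, exactly one matching the gold
    metric = PassAtK(k=5, n=10, strip_strings=True)
    score = metric.compute(golds=["42"], predictions=["42"] + ["wrong"] * 9)
    print(score)  # 0.5 == 1 - C(9, 5) / C(10, 5)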