20 | 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | 21 | # SOFTWARE.
22 | 22 |
23 | | -from typing import Callable, Literal |
| 23 | +import logging |
| 24 | +from typing import Callable, Literal, Sequence |
24 | 25 |
25 | 26 | import numpy as np
26 | 27 |
37 | 38 | LogProbTokenNorm,
38 | 39 | get_multilingual_normalizer,
39 | 40 | )
| 41 | +from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401 |
| 42 | + ExprExtractionConfig, |
| 43 | + ExtractionTarget, |
| 44 | + IndicesExtractionConfig, |
| 45 | + LatexExtractionConfig, |
| 46 | + extract_target_from_pred, |
| 47 | + get_extraction_regexes, |
| 48 | +) |
| 49 | +from lighteval.metrics.utils.math_comparison import compare_gold_target |
40 | 50 | from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
| 51 | +from lighteval.tasks.requests import Doc |
41 | 52 | from lighteval.utils.language import Language
| 53 | +from lighteval.utils.timeout import timeout |
| 54 | + |
| 55 | + |
| 56 | +logger = logging.getLogger(__name__) |
42 | 57 |
43 | 58 |
44 | 59 | def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
@@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
168 | 183 | corpus_level_fn=np.mean,
169 | 184 | higher_is_better=True,
170 | 185 | )
| 186 | + |
| 187 | + |
| 188 | +def multilingual_extractive_match_metric( |
| 189 | + language: Language = Language.ENGLISH, |
| 190 | + gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), |
| 191 | + pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), |
| 192 | + aggregation_function: Callable[[list[float]], float] = max, |
| 193 | + fallback_mode: Literal["no_fallback", "first_match"] = "first_match", |
| 194 | + precision: int = 6, |
| 195 | +) -> SampleLevelMetric: |
| 196 | + """Creates a language-aware extractive match metric that extracts answers from the model's output. |
| 197 | + |
| 198 | + Known issues: |
| 199 | + - If the task is to simplify an expression, the metric may overestimate accuracy: if the model doesn't output an anchor for the extraction (e.g. "final answer is ..."), |
| 200 | + the extracted prediction can end up being the very expression the model was asked to simplify. Because we run the simplification ourselves, sympy may then simplify it |
| 201 | + so that it matches the gold, even though the model did nothing. PRs to fix this are welcome. |
| 202 | + |
| 203 | + - There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and the model outputs Friday, it will not match, because nothing will be extracted. |
| 204 | + |
| 205 | + Args: |
| 206 | + language: Language |
| 207 | + The language of the samples. |
| 208 | + gold_extraction_target: Sequence[ExtractionTarget] |
| 209 | + Extraction targets to use for gold answers. Defaults to extracting simple math expressions. |
| 210 | + pred_extraction_target: Sequence[ExtractionTarget] |
| 211 | + Extraction targets to use for predictions. Defaults to extracting simple math expressions. |
| 212 | + aggregation_function: Callable[[list[float]], float] |
| 213 | + Function to aggregate scores when multiple golds/predictions are present. Defaults to max. |
| 214 | + fallback_mode: Literal["no_fallback", "first_match"] |
| 215 | + How to perform extraction. Defaults to "first_match". |
| 216 | + - "no_fallback": Only use first successfully parsed matches |
| 217 | + - "first_match": Use the first successfully parsed match + first match irregardless the parsing success |
| 218 | + precision: int |
| 219 | + Number of decimal places to use when comparing numerical values. Defaults to 6. |
| 220 | + |
| 221 | + Returns: |
| 222 | + A sample level metric that extracts and compares mathematical expressions. |
| 223 | + |
| 224 | + """ |
| 225 | + |
| 226 | + @timeout(2) |
| 227 | + def add_to_specifics_with_timeout( |
| 228 | + formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]] |
| 229 | + ) -> None: |
| 230 | + if formatted_doc.specific is None: |
| 231 | + formatted_doc.specific = {} |
| 232 | + |
| 233 | + formatted_doc.specific["extracted_predictions"] = [ |
| 234 | + str(pred) for preds in extracted_predictions for pred in preds |
| 235 | + ] |
| 236 | + formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds] |
| 237 | + |
| 238 | + def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float: |
| 239 | + gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language) |
| 240 | + pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language) |
| 241 | + |
| 242 | + extracted_predictions = [ |
| 243 | + extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions |
| 244 | + ] |
| 245 | + extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds] |
| 246 | + |
| 247 | + # Raise on empty golds and warn on empty predictions |
| 248 | + if any(len(g) == 0 for g in extracted_golds): |
| 249 | + raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}") |
| 250 | + |
| 251 | + if all(len(p) == 0 for p in extracted_predictions): |
| 252 | + logger.warning( |
| 253 | + f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}" |
| 254 | + ) |
| 255 | + |
| 256 | + # We have to use a timeout because the sympy-to-str conversion can be very slow |
| 257 | + try: |
| 258 | + add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds) |
| 259 | + except: # noqa: E722 |
| 260 | + logger.warning("Timeout when adding extracted predictions and golds to specific") |
| 261 | + |
| 262 | + return aggregation_function( |
| 263 | + [ |
| 264 | + (1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0) |
| 265 | + for pred in extracted_predictions |
| 266 | + ] |
| 267 | + ) |
| 268 | + |
| 269 | + return SampleLevelMetric( |
| 270 | + metric_name="extractive_match", |
| 271 | + sample_level_fn=sample_level_fn, |
| 272 | + category=MetricCategory.GENERATIVE, |
| 273 | + use_case=MetricUseCase.ACCURACY, |
| 274 | + corpus_level_fn=np.mean, |
| 275 | + higher_is_better=True, |
| 276 | + ) |