feat: add asian language support to CorpusLevelTranslationMetric #479

Merged
26 changes: 16 additions & 10 deletions src/lighteval/metrics/metrics_corpus.py
@@ -26,6 +26,7 @@
 """
 import logging
 import math
+from typing import Literal
 
 import numpy as np
 import sacrebleu
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


 class CorpusLevelTranslationMetric:
-    def __init__(self, metric_type: str):
+    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
         """Stores the relevant parameters for a corpus level translation metric.
 
         Args:
             metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
+            lang (Literal["zh", "ja", "ko", ""]): Target language, used to pick a
+                language-aware tokenizer ("" keeps the previous defaults).
         """
-        if metric_type == "bleu":
-            self.metric = sacrebleu.corpus_bleu
-        elif metric_type == "chrf":
-            self.metric = sacrebleu.corpus_chrf
-        elif metric_type == "ter":
-            self.metric = sacrebleu.corpus_ter
+        self.metric_type = metric_type
+        self.lang = lang
+
+    def get_metric(self):
+        # BLEU's trg_lang selects a language-aware tokenizer ("zh" is
+        # character-based; "ja"/"ko" use MeCab via sacrebleu's optional
+        # extras). chrF is already character-level, and TER gets its
+        # Asian-language handling through asian_support.
+        if self.metric_type == "bleu":
+            return sacrebleu.BLEU(trg_lang=self.lang)
+        elif self.metric_type == "chrf":
+            return sacrebleu.CHRF()
+        elif self.metric_type == "ter":
+            return sacrebleu.TER(asian_support=self.lang != "")
         else:
-            raise ValueError(f"Unknown corpus level translation metric type : {metric_type}")
+            raise ValueError(f"Unknown corpus level translation metric type: {self.metric_type}")

     def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
         """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
+        metric = self.get_metric()
         golds = [i.golds for i in items]
         preds = []
         for i in items:
             pred = as_list(i.preds)
             if len(pred) > 1:
                 logger.info(
-                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
+                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
                 )
             preds.append(pred[0])
-        return float(self.metric(hypotheses=preds, references=golds).score)
+        return float(metric.corpus_score(hypotheses=preds, references=golds).score)


class CorpusLevelPerplexityMetric:
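
A quick illustration of what the new lang parameter buys (a minimal sketch against sacrebleu 2.x, not part of this PR; exact scores are omitted on purpose). With the default "13a" tokenizer, unsegmented Chinese text collapses into a single token, so n-gram statistics are degenerate; trg_lang="zh" switches BLEU to sacrebleu's character-based Chinese tokenizer, and "ja"/"ko" select the MeCab tokenizers, which require the optional sacrebleu[ja] / sacrebleu[ko] extras:

    import sacrebleu

    hyps = ["猫坐在垫子上"]    # unsegmented Chinese hypothesis
    refs = [["猫坐在垫子上"]]  # one reference set, one sentence

    # Default tokenizer: the whole sentence is one "word", so even an exact
    # match yields degenerate higher-order n-gram counts.
    print(sacrebleu.BLEU().corpus_score(hyps, refs))

    # Chinese tokenizer: characters are segmented, so n-gram overlap is meaningful.
    print(sacrebleu.BLEU(trg_lang="zh").corpus_score(hyps, refs))

    # TER's equivalent switch is asian_support, as used in get_metric() above.
    print(sacrebleu.TER(asian_support=True).corpus_score(hyps, refs))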
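And a hedged usage sketch of the updated metric itself (this assumes GenerativeCorpusMetricInput is constructed with the golds/preds fields that compute() reads, and that it lives in lighteval.metrics.sample_preparator; neither is confirmed by this diff):

    from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric
    from lighteval.metrics.sample_preparator import GenerativeCorpusMetricInput  # assumed path

    # A single sample with one gold reference and one prediction.
    items = [GenerativeCorpusMetricInput(golds=["猫坐在垫子上"], preds=["猫坐在垫子上"])]

    # lang="zh" flows through get_metric() to sacrebleu.BLEU(trg_lang="zh");
    # the default lang="" preserves the previous whitespace-tokenized behavior.
    bleu_zh = CorpusLevelTranslationMetric(metric_type="bleu", lang="zh")
    print(bleu_zh.compute(items))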