diff --git a/lmms_eval/tasks/fleurs/_default_template_yaml b/lmms_eval/tasks/fleurs/_default_template_yaml
new file mode 100644
index 000000000..96ccfc5ba
--- /dev/null
+++ b/lmms_eval/tasks/fleurs/_default_template_yaml
@@ -0,0 +1,25 @@
+dataset_path: google/fleurs
+dataset_kwargs:
+  token: True
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.fleurs_doc_to_audio
+doc_to_text: !function utils.fleurs_doc_to_text
+doc_to_target: "transcription"
+generation_kwargs:
+  max_new_tokens: 256
+  temperature: 0
+  top_p: 1.0
+  num_beams: 1
+  do_sample: false
+process_results: !function utils.fleurs_process_result
+metric_list:
+  - metric: wer
+    aggregation : !function utils.fleurs_wer
+    higher_is_better : false
+metadata:
+  - version: 0.0
+lmms_eval_specific_kwargs:
+  default:
+    pre_prompt: ""
+    post_prompt: ""
\ No newline at end of file
diff --git a/lmms_eval/tasks/fleurs/fleurs.yaml b/lmms_eval/tasks/fleurs/fleurs.yaml
new file mode 100644
index 000000000..fbe4b9651
--- /dev/null
+++ b/lmms_eval/tasks/fleurs/fleurs.yaml
@@ -0,0 +1,3 @@
+group: fleurs
+task:
+- fleurs_en
\ No newline at end of file
diff --git a/lmms_eval/tasks/fleurs/fleurs_en.yaml b/lmms_eval/tasks/fleurs/fleurs_en.yaml
new file mode 100644
index 000000000..23d2cde95
--- /dev/null
+++ b/lmms_eval/tasks/fleurs/fleurs_en.yaml
@@ -0,0 +1,3 @@
+dataset_name: en_us
+include: _default_template_yaml
+task: fleurs_en
\ No newline at end of file
diff --git a/lmms_eval/tasks/fleurs/utils.py b/lmms_eval/tasks/fleurs/utils.py
new file mode 100644
index 000000000..4a047b248
--- /dev/null
+++ b/lmms_eval/tasks/fleurs/utils.py
@@ -0,0 +1,291 @@
+import os
+import re
+import unicodedata
+from collections import OrderedDict
+
+import editdistance as ed
+import zhconv
+
+from lmms_eval.tasks.librispeech.cn_tn import TextNorm
+from lmms_eval.tasks.librispeech.whisper_normalizer.basic import BasicTextNormalizer
+from lmms_eval.tasks.librispeech.whisper_normalizer.english import EnglishTextNormalizer
+
+_FLEURS_LANG_TO_ID = OrderedDict(
+    [
+        ("Afrikaans", "af"),
+        ("Amharic", "am"),
+        ("Arabic", "ar"),
+        ("Armenian", "hy"),
+        ("Assamese", "as"),
+        ("Asturian", "ast"),
+        ("Azerbaijani", "az"),
+        ("Belarusian", "be"),
+        ("Bengali", "bn"),
+        ("Bosnian", "bs"),
+        ("Bulgarian", "bg"),
+        ("Burmese", "my"),
+        ("Catalan", "ca"),
+        ("Cebuano", "ceb"),
+        ("Mandarin Chinese", "cmn_hans"),
+        ("Cantonese Chinese", "yue_hant"),
+        ("Croatian", "hr"),
+        ("Czech", "cs"),
+        ("Danish", "da"),
+        ("Dutch", "nl"),
+        ("English", "en"),
+        ("Estonian", "et"),
+        ("Filipino", "fil"),
+        ("Finnish", "fi"),
+        ("French", "fr"),
+        ("Fula", "ff"),
+        ("Galician", "gl"),
+        ("Ganda", "lg"),
+        ("Georgian", "ka"),
+        ("German", "de"),
+        ("Greek", "el"),
+        ("Gujarati", "gu"),
+        ("Hausa", "ha"),
+        ("Hebrew", "he"),
+        ("Hindi", "hi"),
+        ("Hungarian", "hu"),
+        ("Icelandic", "is"),
+        ("Igbo", "ig"),
+        ("Indonesian", "id"),
+        ("Irish", "ga"),
+        ("Italian", "it"),
+        ("Japanese", "ja"),
+        ("Javanese", "jv"),
+        ("Kabuverdianu", "kea"),
+        ("Kamba", "kam"),
+        ("Kannada", "kn"),
+        ("Kazakh", "kk"),
+        ("Khmer", "km"),
+        ("Korean", "ko"),
+        ("Kyrgyz", "ky"),
+        ("Lao", "lo"),
+        ("Latvian", "lv"),
+        ("Lingala", "ln"),
+        ("Lithuanian", "lt"),
+        ("Luo", "luo"),
+        ("Luxembourgish", "lb"),
+        ("Macedonian", "mk"),
+        ("Malay", "ms"),
+        ("Malayalam", "ml"),
+        ("Maltese", "mt"),
+        ("Maori", "mi"),
+        ("Marathi", "mr"),
+        ("Mongolian", "mn"),
+        ("Nepali", "ne"),
+        ("Northern-Sotho", "nso"),
+        ("Norwegian", "nb"),
+        ("Nyanja", "ny"),
+        ("Occitan", "oc"),
+        ("Oriya", "or"),
+        ("Oromo", "om"),
"om"), + ("Pashto", "ps"), + ("Persian", "fa"), + ("Polish", "pl"), + ("Portuguese", "pt"), + ("Punjabi", "pa"), + ("Romanian", "ro"), + ("Russian", "ru"), + ("Serbian", "sr"), + ("Shona", "sn"), + ("Sindhi", "sd"), + ("Slovak", "sk"), + ("Slovenian", "sl"), + ("Somali", "so"), + ("Sorani-Kurdish", "ckb"), + ("Spanish", "es"), + ("Swahili", "sw"), + ("Swedish", "sv"), + ("Tajik", "tg"), + ("Tamil", "ta"), + ("Telugu", "te"), + ("Thai", "th"), + ("Turkish", "tr"), + ("Ukrainian", "uk"), + ("Umbundu", "umb"), + ("Urdu", "ur"), + ("Uzbek", "uz"), + ("Vietnamese", "vi"), + ("Welsh", "cy"), + ("Wolof", "wo"), + ("Xhosa", "xh"), + ("Yoruba", "yo"), + ("Zulu", "zu"), + ] +) +_FLEURS_LANG_SHORT_TO_LONG = {v: k for k, v in _FLEURS_LANG_TO_ID.items()} + +# ImportError: To support decoding audio files, please install 'librosa' and 'soundfile'. +english_normalizer = EnglishTextNormalizer() +chinese_normalizer = TextNorm( + to_banjiao=False, + to_upper=False, + to_lower=False, + remove_fillers=False, + remove_erhua=False, + check_chars=False, + remove_space=False, + cc_mode="", +) +basic_normalizer = BasicTextNormalizer() + +dir_name = os.path.dirname(os.path.abspath(__file__)) + + +def fleurs_doc_to_audio(doc): + return [doc["audio"]] + + +def fleurs_doc_to_text(doc, lmms_eval_specific_kwargs): + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + return f"{pre_prompt}Please recognize the speech and only output the recognized content:{post_prompt}" + + +def fleurs_process_result(doc, result): + pred = result[0] if len(result) > 0 else "" + + gt = doc["transcription"] + source = doc["path"] + language = doc["language"] + + data_dict = {"gt": gt, "pred": pred, "source": source, "language": language} + + return {"wer": data_dict} + + +PUNCS = "!,.?;:" + + +def remove_sp(text, language): + gt = re.sub(r"<\|.*?\|>", " ", text) + gt = re.sub(rf"\s+", r" ", gt) # Replace consecutive spaces in the text with a single space. + gt = re.sub(f" ?([{PUNCS}])", r"\1", gt) + gt = gt.lstrip(" ") + if language == "cmn_hans": + gt = re.sub(rf"\s+", r"", gt) + return gt + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. + punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. 
+ """ + + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + # ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"]) + + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): + from sacrebleu.tokenizers.tokenizer_13a import Tokenizer13a + from sacrebleu.tokenizers.tokenizer_char import TokenizerChar + from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International + from sacrebleu.tokenizers.tokenizer_ja_mecab import TokenizerJaMecab + from sacrebleu.tokenizers.tokenizer_none import NoneTokenizer + from sacrebleu.tokenizers.tokenizer_zh import TokenizerZh + + TOKENIZERS = { + "none": NoneTokenizer, + "13a": Tokenizer13a, + "intl": TokenizerV14International, + "zh": TokenizerZh, + "ja-mecab": TokenizerJaMecab, + "char": TokenizerChar, + } + + assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}" + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + self.tokenizer = TOKENIZERS[tokenizer_type] + # self.tokenizer = tokenizer_none + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join(t for t in sent.split(cls.SPACE) if not all(unicodedata.category(c)[0] == "P" for c in t)) + + def tokenize(self, sent: str): + tokenized = self.tokenizer()(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join(list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE))) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized + + +def compute_wer(refs, hyps, language): + distance = 0 + ref_length = 0 + tokenizer = EvaluationTokenizer( + tokenizer_type="none", + lowercase=True, + punctuation_removal=True, + character_tokenization=False, + ) + for i in range(len(refs)): + ref = refs[i] + pred = hyps[i] + if language in ["yue_hant"]: + ref = zhconv.convert(ref, "zh-cn") + pred = zhconv.convert(pred, "zh-cn") + if language in ["en"]: + ref = english_normalizer(ref) + pred = english_normalizer(pred) + if language in ["cmn_hans"]: + ref = chinese_normalizer(ref) + pred = chinese_normalizer(pred) + else: + ref = basic_normalizer(ref) + pred = basic_normalizer(pred) + ref_items = tokenizer.tokenize(ref).split() + pred_items = tokenizer.tokenize(pred).split() + if language in ["zh", "yue"]: + ref_items = [x for x in "".join(ref_items)] + pred_items = [x for x in "".join(pred_items)] + if i == 0: + print(f"ref: {ref}") + print(f"pred: {pred}") + print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}") + print(f"pred_items:\n{pred_items}\n{len(ref_items)}\n{ref_items[0]}") + distance += ed.eval(ref_items, pred_items) + ref_length += len(ref_items) + return distance / ref_length + + +def fleurs_wer(results, args): + refs, hyps = [], [] + for result in results: + lan = _FLEURS_LANG_TO_ID[result["language"]] + gt = result["gt"] + response = result["pred"] + gt = remove_sp(gt, lan) + response = remove_sp(response, lan) + refs.append(gt) + hyps.append(response) + wer = compute_wer(refs, hyps, lan) + return wer * 100