xlm_v_tokenizer_comparison.py
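"""Compare XLM-V tokenization between fairseq and Hugging Face.

Encodes every WikiANN training sentence with the original fairseq
XLM-V checkpoint and a previously converted Hugging Face tokenizer,
printing any sentence whose token ids disagree.
"""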
from datasets import load_dataset
from datasets.utils import disable_progress_bar
from datasets.utils.logging import set_verbosity_error
from fairseq.models.roberta import XLMRModel as FairseqXLMRModel
from transformers import AutoTokenizer

# Points to a previously converted model directory that must include sentencepiece.bpe.model
hf_tokenizer = AutoTokenizer.from_pretrained("../exported-working")
# Original fairseq XLM-V base checkpoint, used as the tokenization reference
xlm_v = FairseqXLMRModel.from_pretrained("../xlmv.base")
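
# Wikipedia language codes covered by WikiANN; each subset is checked in turn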
languages = ['ace', 'af', 'als', 'am', 'an',
'ang', 'ar', 'arc', 'arz', 'as',
'ast', 'ay', 'az', 'ba', 'bar',
'bat-smg', 'be', 'be-x-old', 'bg', 'bh',
'bn', 'bo', 'br', 'bs', 'ca',
'cbk-zam', 'cdo', 'ce', 'ceb', 'ckb',
'co', 'crh', 'cs', 'csb', 'cv',
'cy', 'da', 'de', 'diq', 'dv',
'el', 'eml', 'en', 'eo', 'es',
'et', 'eu', 'ext', 'fa', 'fi',
'fiu-vro', 'fo', 'fr', 'frr',
'fur', 'fy', 'ga', 'gan', 'gd',
'gl', 'gn', 'gu', 'hak', 'he',
'hi', 'hr', 'hsb', 'hu', 'hy',
'ia', 'id', 'ig', 'ilo', 'io',
'is', 'it', 'ja', 'jbo', 'jv',
'ka', 'kk', 'km', 'kn', 'ko',
'ksh', 'ku', 'ky', 'la', 'lb',
'li', 'lij', 'lmo', 'ln', 'lt',
'lv', 'map-bms', 'mg', 'mhr', 'mi',
'min', 'mk', 'ml', 'mn', 'mr',
'ms', 'mt', 'mwl', 'my', 'mzn',
'nap', 'nds', 'ne', 'nl', 'nn',
'no', 'nov', 'oc', 'or', 'os',
'pa', 'pdc', 'pl', 'pms', 'pnb',
'ps', 'pt', 'qu', 'rm', 'ro',
'ru', 'rw', 'sa', 'sah', 'scn',
'sco', 'sd', 'sh', 'si', 'simple',
'sk', 'sl', 'so', 'sq', 'sr',
'su', 'sv', 'sw', 'szl', 'ta',
'te', 'tg', 'th', 'tk', 'tl',
'tr', 'tt', 'ug', 'uk', 'ur',
'uz', 'vec', 'vep', 'vi', 'vls',
'vo', 'wa', 'war', 'wuu', 'xmf',
'yi', 'yo', 'zea', 'zh', 'zh-classical',
'zh-min-nan', 'zh-yue']
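
# Silence datasets logging and progress bars to keep the output readable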
set_verbosity_error()
disable_progress_bar()

for language in languages:
    print(f"Tokenizing language {language}...")
    dataset = load_dataset("wikiann", language)
    train_sentences = dataset["train"]
    for train_sentence in train_sentences:
        # WikiANN stores pre-tokenized words; re-join them into a plain sentence
        plain_sentence = " ".join(train_sentence["tokens"])
        # fairseq's encode() returns a tensor (with <s>/</s> added), so convert
        # it to a list for a direct comparison with the Hugging Face ids
        xlm_v_ids = xlm_v.encode(plain_sentence).tolist()
        hf_ids = hf_tokenizer.encode(plain_sentence)
        if xlm_v_ids != hf_ids:
            print("-" * 90)
            print(f"Mismatch for {language} sentence:")
            print(plain_sentence)
            print(f"XLM-V ids: {xlm_v_ids}")
            print(f"HF ids: {hf_ids}")
            print("-" * 90)