-
Notifications
You must be signed in to change notification settings - Fork 333
Expand file tree
/
Copy pathcount_token.py
More file actions
61 lines (47 loc) · 1.72 KB
/
count_token.py
File metadata and controls
61 lines (47 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from multiprocessing import Pool
import fire
import jsonlines as jl
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer
# Module-level tokenizer handle, populated by prepare_tokenizer(). Kept global
# so multiprocessing workers can reuse one tokenizer instead of pickling it
# into every dispatched task.
TOKENIZER = None
def count_token_single(sample, text_keys):
    """Return the total token count over the given text fields of one sample.

    Uses the module-level TOKENIZER (set by prepare_tokenizer); each listed
    key in ``text_keys`` must be present in ``sample``.
    """
    return sum(len(TOKENIZER.tokenize(sample[field])) for field in text_keys)
def prepare_tokenizer(tokenizer_method):
    """Load a Hugging Face tokenizer and store it in the TOKENIZER global."""
    global TOKENIZER
    logger.info("Loading tokenizer from HuggingFace...")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_method,
        trust_remote_code=True,
    )
    TOKENIZER = tokenizer
def main(data_path, text_keys="text", tokenizer_method="EleutherAI/pythia-6.9b-deduped", num_proc=1):
    """
    Count the number of tokens for given dataset and tokenizer.

    :param data_path: path to the input dataset. Only support 'jsonl' now.
    :param text_keys: field keys that will be considered into token counts.
        A single key may be passed as a plain string.
    :param tokenizer_method: name of the Hugging Face tokenizer.
    :param num_proc: number of processes to count tokens.
    """
    if isinstance(text_keys, str):
        text_keys = [text_keys]
    token_count = 0
    # Load the tokenizer inside each worker via the Pool initializer. The
    # original code loaded it only in the parent and relied on fork-inherited
    # globals, which breaks under the 'spawn' start method (Windows / macOS
    # default): workers would see TOKENIZER is None and crash.
    with jl.open(data_path) as reader, Pool(
        num_proc,
        initializer=prepare_tokenizer,
        initargs=(tokenizer_method,),
    ) as pool:
        # Dispatch one async task per sample, then collect counts in order.
        result_list = [
            pool.apply_async(count_token_single, args=(sample, text_keys))
            for sample in tqdm(reader)
        ]
        for res in tqdm(result_list):
            token_count += res.get()
    logger.info(f"Total num of tokens: {token_count}")
if __name__ == "__main__":
    # Expose main() as a command-line interface; python-fire maps CLI flags
    # onto main()'s keyword parameters.
    fire.Fire(main)