-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_viterbi.py
82 lines (69 loc) · 3.29 KB
/
train_viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# from utils import get_data, bopomofo2engTyping, zh2bopomofo, del_symbols
import json
import time
# print('Loading data...')
# lines = get_data('PTT', 100)
# print('Preprocessing data...')
# engt = [[bopomofo2engTyping(char) for char in zh2bopomofo(line)] for line in lines]
# zh = [[char for char in del_symbols(line)] for line in lines]
# Corpus identifier: selects the dataset files read below and prefixes the
# name of the model file.
datasets_name = 'CWIKI_2023_09_27'
# Model file name, stamped with the current time (spaces become underscores,
# colons are removed so the name contains no ':' characters).
model_name = f'{datasets_name}_engTyping2Zh_HMM70_{time.ctime().replace(" ", "_").replace(":", "")}.json'
# Delimiter separating the tokens inside one line of either dataset file.
split_char = '⫯'

with open(f'datasets/{datasets_name}_engTyping_inserted_lines.txt', 'r', encoding='utf-8') as file:
    engTyping_inserted_lines = file.read().split('\n')
with open(f'datasets/{datasets_name}_zh_lines.txt', 'r', encoding='utf-8') as file:
    zh_lines = file.read().split('\n')

# Only the first 70% of the lines (measured on the engTyping file) are used
# for the counts computed further down.
lines_len = len(engTyping_inserted_lines)
train_len = int(lines_len * 0.7)
engTyping_inserted_lines = engTyping_inserted_lines[:train_len]
zh_lines = zh_lines[:train_len]

# Explicit check instead of `assert`: assertions are stripped under
# `python -O`, and the later zip() would then silently truncate to the
# shorter list. As in the original check, this fires only when the zh file
# has fewer lines than the training prefix.
if len(engTyping_inserted_lines) != len(zh_lines):
    raise ValueError(
        f'Dataset {datasets_name!r} is misaligned: '
        f'{len(engTyping_inserted_lines)} engTyping lines vs '
        f'{len(zh_lines)} zh lines in the training split'
    )
# Raw occurrence counts gathered from the training lines. At this point the
# tables hold plain counts, not probabilities.
start_probability = {}       # first zh token of a line -> count
transition_probability = {}  # previous zh token -> {next zh token -> count}
emission_probability = {}    # zh token -> {engTyping token -> count}
engTyping2zh = {}            # engTyping token -> zh tokens seen with it (first-seen order)

print('Calculating probability part 1...')
t = time.time()
for engt_line, zh_line in zip(engTyping_inserted_lines, zh_lines):
    engt_line = engt_line.split(split_char)
    # The first and last fields of the zh line are dropped, while the
    # engTyping line is used whole.
    # NOTE(review): presumably zh lines are wrapped in split_char on both
    # ends - confirm against the code that writes the dataset.
    zh_line = zh_line.split(split_char)[1:-1]
    if not zh_line:
        continue

    first_zh = zh_line[0]
    start_probability[first_zh] = start_probability.get(first_zh, 0) + 1

    prev_zh = ''  # no predecessor at the start of a line
    # zip() stops at the shorter sequence, so surplus tokens on either side
    # are ignored.
    for engt_char, zh_char in zip(engt_line, zh_line):
        # An empty predecessor (line start, or an empty zh token) does not
        # contribute a transition.
        if prev_zh != '':
            successors = transition_probability.setdefault(prev_zh, {})
            successors[zh_char] = successors.get(zh_char, 0) + 1
        prev_zh = zh_char

        typings = emission_probability.setdefault(zh_char, {})
        # The emission table records exactly the (zh, engTyping) pairs seen
        # so far, so it doubles as an O(1) "already listed?" test for the
        # reverse index, replacing a linear scan of the candidate list.
        if engt_char not in typings:
            engTyping2zh.setdefault(engt_char, []).append(zh_char)
        typings[engt_char] = typings.get(engt_char, 0) + 1

# Report every zh token that was observed with more than one engTyping.
for zh_char, typings in emission_probability.items():
    if len(typings) > 1:
        print(zh_char, typings)
def sp_calculation(sp: dict) -> dict:
    """Normalise a flat table of counts into relative frequencies, in place.

    Every value of *sp* is divided by the sum of all its values.

    Args:
        sp: Mapping from a key to a numeric count. Modified in place.

    Returns:
        The same ``sp`` object, now holding the normalised values. An empty
        mapping is returned unchanged.

    Raises:
        ZeroDivisionError: If *sp* is non-empty and its values sum to zero.
    """
    total = sum(sp.values())
    # Only values are reassigned, so iterating the dict while updating is safe.
    for key in sp:
        sp[key] /= total
    return sp
def tpep_calculation(tpep: dict) -> dict:
    """Normalise each row of a nested count table, in place.

    For every outer key, the values of its inner mapping are divided by the
    sum of that inner mapping, independently of the other rows.

    Args:
        tpep: Mapping from an outer key to an inner mapping of key -> numeric
            count. The inner mappings are modified in place.

    Returns:
        The same ``tpep`` object with every row normalised. Empty rows are
        left untouched.

    Raises:
        ZeroDivisionError: If a non-empty row sums to zero.
    """
    for row in tpep.values():
        row_total = sum(row.values())
        # Only values are reassigned, so iterating the row while updating is safe.
        for key in row:
            row[key] /= row_total
    return tpep
print('Calculating probability part 2...')
# Turn the raw counts into relative frequencies.
start_probability = sp_calculation(start_probability)
transition_probability = tpep_calculation(transition_probability)
emission_probability = tpep_calculation(emission_probability)
print(f'Time used: {time.time() - t} seconds')

print('Saving model...')
model = {
    'start_probability': start_probability,
    'transition_probability': transition_probability,
    'emission_probability': emission_probability,
    'engTyping2zh': engTyping2zh,
}
# Use a context manager so the file is flushed and closed deterministically;
# passing a bare open(...) to json.dump leaves the handle to the garbage
# collector.
with open(f'models/{model_name}', 'w', encoding='utf-8') as model_file:
    json.dump(model, model_file, ensure_ascii=False, indent=4)