forked from fawazsammani/chatbot-transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
88 lines (59 loc) · 2.32 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import ast
import json
from collections import Counter
# Locations of the two corpus files, relative to the working directory.
corpus_movie_conv = 'cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'cornell movie-dialogs corpus/movie_lines.txt'

# Maximum number of tokens kept per utterance; also the padding target used
# when utterances are encoded further down in this script.
max_len = 25

# Decode the corpus explicitly instead of relying on the platform default.
# The dialogue text contains bytes that are not valid UTF-8, so a UTF-8
# default locale raises UnicodeDecodeError; ISO-8859-1 maps every byte to a
# character and therefore never fails.
# NOTE(review): ISO-8859-1 is the encoding commonly used for this corpus --
# confirm against the copy of the data you have.
_CORPUS_ENCODING = 'iso-8859-1'

with open(corpus_movie_conv, 'r', encoding=_CORPUS_ENCODING) as conv_file:
    conv = conv_file.readlines()

with open(corpus_movie_lines, 'r', encoding=_CORPUS_ENCODING) as lines_file:
    lines = lines_file.readlines()

# Map each line's first field to its last field.  Fields are separated by
# " +++$+++ "; the last field keeps its trailing newline, which is stripped
# later when the text is used.
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]
# Characters deleted by remove_punc.  The backslash is written as an explicit
# "\\" escape: the bare "\," form is an invalid escape sequence that newer
# Python versions warn about, although it denotes the same two characters.
_PUNCTUATION = '''!()-[]{};:'"\\,<>./?@#$%^&*_~'''

# Deletion table built once so every call is a single pass over the input.
_PUNCT_TABLE = str.maketrans('', '', _PUNCTUATION)


def remove_punc(string):
    """Return *string* lower-cased with punctuation characters removed.

    Only the characters listed in ``_PUNCTUATION`` are deleted; whitespace,
    digits and any other symbol (for example ``+`` or ``=``) are kept as-is.

    Args:
        string: The text to clean.

    Returns:
        A new lower-case string without the listed punctuation characters.
    """
    return string.translate(_PUNCT_TABLE).lower()
# Build [question_tokens, reply_tokens] pairs from every two consecutive
# lines of each conversation.
pairs = []
for con in conv:
    # The last " +++$+++ "-separated field is parsed as a Python literal
    # holding the conversation's line ids.  ast.literal_eval only accepts
    # literals, so unlike the eval() call it replaces it cannot execute code
    # embedded in the data file.
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())

    # Pair each line with the one that follows it; a conversation with fewer
    # than two lines yields no pairs.
    for current_id, next_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dic[current_id].strip())
        second = remove_punc(lines_dic[next_id].strip())
        # Tokenise on whitespace and keep at most max_len tokens per side.
        pairs.append([first.split()[:max_len], second.split()[:max_len]])
# Count how often each token appears on either side of every pair.
word_freq = Counter()
for pair in pairs:
    for side in (pair[0], pair[1]):
        word_freq.update(side)

# Keep only tokens seen strictly more than min_word_freq times.
min_word_freq = 5
words = [token for token, count in word_freq.items() if count > min_word_freq]

# Regular words get indices 1..len(words); index 0 is reserved for padding.
word_map = {token: index for index, token in enumerate(words, start=1)}

# Each special token takes the next free index, in this order.
for special in ('<unk>', '<start>', '<end>'):
    word_map[special] = len(word_map) + 1
word_map['<pad>'] = 0

print("Total words are: {}".format(len(word_map)))

# Save the vocabulary as JSON.
with open('WORDMAP_corpus.json', 'w') as wordmap_file:
    json.dump(word_map, wordmap_file)
def encode_question(words, word_map, pad_to=None):
    """Encode a tokenised question as a padded list of vocabulary indices.

    Args:
        words: Sequence of tokens to encode.
        word_map: Mapping from token to integer index.  Must contain the
            ``'<unk>'`` and ``'<pad>'`` entries.
        pad_to: Length to pad the result to.  Defaults to the module-level
            ``max_len`` when omitted, which is the original behaviour.

    Returns:
        A list with one index per token (``'<unk>'`` for tokens missing from
        *word_map*), right-padded with ``'<pad>'`` up to *pad_to* entries.
        Inputs longer than *pad_to* are returned unpadded and are NOT
        truncated.
    """
    if pad_to is None:
        pad_to = max_len
    unk = word_map['<unk>']
    encoded = [word_map.get(word, unk) for word in words]
    # A negative repeat count yields an empty list, so overlong inputs simply
    # receive no padding.
    encoded.extend([word_map['<pad>']] * (pad_to - len(words)))
    return encoded
def encode_reply(words, word_map, pad_to=None):
    """Encode a tokenised reply, wrapped in start/end markers and padded.

    Args:
        words: Sequence of tokens to encode.
        word_map: Mapping from token to integer index.  Must contain the
            ``'<start>'``, ``'<end>'``, ``'<unk>'`` and ``'<pad>'`` entries.
        pad_to: Token count to pad up to.  Defaults to the module-level
            ``max_len`` when omitted, which is the original behaviour.

    Returns:
        ``[<start>] + token indices + [<end>]`` followed by
        ``pad_to - len(words)`` ``'<pad>'`` entries, so the result has
        ``pad_to + 2`` items whenever ``len(words) <= pad_to``.  Tokens
        missing from *word_map* become ``'<unk>'``.  Inputs longer than
        *pad_to* are returned unpadded and are NOT truncated.
    """
    if pad_to is None:
        pad_to = max_len
    unk = word_map['<unk>']
    encoded = [word_map['<start>']]
    encoded.extend(word_map.get(word, unk) for word in words)
    encoded.append(word_map['<end>'])
    # A negative repeat count yields an empty list, so overlong inputs simply
    # receive no padding.
    encoded.extend([word_map['<pad>']] * (pad_to - len(words)))
    return encoded
# Turn every token pair into [encoded_question, encoded_reply].
pairs_encoded = [
    [encode_question(pair[0], word_map), encode_reply(pair[1], word_map)]
    for pair in pairs
]

# Save the encoded pairs as JSON.
with open('pairs_encoded.json', 'w') as encoded_file:
    json.dump(pairs_encoded, encoded_file)