
Commit caf4cd7

Update: saving label dict and tokenizer
1 parent 5afecbb commit caf4cd7

File tree

2 files changed: +40 −9 lines


data.py

Lines changed: 36 additions & 7 deletions
@@ -1,6 +1,7 @@
 # Libraries imported.
 import re
 import os
+import io
 import tensorflow as tf
 import pandas as pd
 import nltk
@@ -15,20 +16,32 @@
 nltk.download('wordnet')

 class Dataset:
-    def __init__(self, data_path, vocab_size, data_classes):
+    def __init__(self, data_path, vocab_size, data_classes, vocab_folder):
         self.data_path = data_path
         self.vocab_size = vocab_size
         self.data_classes = data_classes
         self.sentences_tokenizer = None
         self.label_dict = None
-
+        self.vocab_folder = vocab_folder
+        self.save_tokenizer_path = '{}tokenizer.json'.format(self.vocab_folder)
+        self.save_label_path = 'label.json'
+
+        if os.path.isfile(self.save_tokenizer_path):
+            with open(self.save_tokenizer_path) as file:
+                data = json.load(file)
+                self.sentences_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(data)
+
+        if os.path.isfile(self.save_label_path):
+            with open(self.save_label_path) as file:
+                self.label_dict = json.load(file)
+
     def labels_encode(self, labels, data_classes):
         '''Encode labels to categorical'''
         labels.replace(data_classes, inplace=True)

         labels_target = labels.values
         labels_target = tf.keras.utils.to_categorical(labels_target)
-
+
         return labels_target

     def removeHTML(self, text):
@@ -105,22 +118,38 @@ def load_dataset(self, max_length, data_name, label_name):
         datastore = pd.read_csv(self.data_path)
         sentences = datastore[data_name]
         labels = datastore[label_name]
+        self.label_dict = dict((item, idx)
+                               for idx, item in enumerate(set(labels)))

         # Cleaning
         sentences, labels = self.data_processing(sentences, labels)
-
+
         # Tokenizing
         self.sentences_tokenizer = self.build_tokenizer(sentences, self.vocab_size)
         tensor = self.tokenize(
             self.sentences_tokenizer, sentences, max_length)

-        print("Done! Next to ... ")
         print(" ")
+        print("Save tokenizer ... ")
+
+        # Saving tokenizer
+        if not os.path.exists(self.vocab_folder):
+            try:
+                os.makedirs(self.vocab_folder)
+            except OSError as e:
+                raise IOError("Failed to create folders")
+
+        tokenizer_json = self.sentences_tokenizer.to_json()
+        with io.open(self.save_tokenizer_path, 'w', encoding='utf-8') as file:
+            file.write(json.dumps(tokenizer_json, ensure_ascii=False))

         # Saving label dict
         with open('label.json', 'w') as f:
-            json.dump(self.label_dict, f)
-
+            json.dump(self.label_dict, f)
+
+        print("Done! Next to ... ")
+        print(" ")
+
         return tensor, labels

     def build_dataset(self, max_length=128, test_size=0.2, buffer_size=128, batch_size=128, data_name='review', label_name='sentiment'):
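
For readers skimming the diff, here is a minimal round-trip sketch of the persistence pattern this commit adds to data.py: Tokenizer.to_json() is dumped to tokenizer.json the way load_dataset() now does, and tokenizer_from_json() restores it the way Dataset.__init__() now does on a later run. The sample texts and Tokenizer settings below are illustrative only, not the repository's actual build_tokenizer() configuration.

import io
import json
import tensorflow as tf

# Fit a small tokenizer on toy sentences (stand-in for build_tokenizer()).
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=100, oov_token='<OOV>')
tokenizer.fit_on_texts(["a great movie", "a terrible movie"])

# Save: to_json() returns a JSON string; write it out as load_dataset() does.
with io.open('tokenizer.json', 'w', encoding='utf-8') as file:
    file.write(json.dumps(tokenizer.to_json(), ensure_ascii=False))

# Load: json.load() recovers that string and tokenizer_from_json() rebuilds the
# tokenizer, mirroring the new os.path.isfile() check in Dataset.__init__().
with open('tokenizer.json') as file:
    data = json.load(file)
    restored = tf.keras.preprocessing.text.tokenizer_from_json(data)

# The restored tokenizer produces the same sequences as the original.
assert restored.texts_to_sequences(["a great movie"]) == tokenizer.texts_to_sequences(["a great movie"])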

train.py

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,8 @@
     "--model-folder", default='{}/tmp/model/'.format(home_dir), type=str)
 parser.add_argument(
     "--checkpoint-folder", default='{}/tmp/checkpoints/'.format(home_dir), type=str)
-
+parser.add_argument(
+    "--vocab-folder", default='{}/tmp/saved_vocab/'.format(home_dir), type=str)
 parser.add_argument("--data-path", default='data/IMDB_Dataset.csv', type=str)
 parser.add_argument("--data-name", default='review', type=str)
 parser.add_argument("--label-name", default='sentiment', type=str)
@@ -58,7 +59,8 @@
 print('===========================')

 # Prepair dataset
-dataset = Dataset(args.data_path, args.vocab_size, data_classes=args.data_classes)
+dataset = Dataset(args.data_path, args.vocab_size,
+                  args.data_classes, args.vocab_folder)

 train_ds, val_ds = dataset.build_dataset(
     args.max_length, args.test_size, args.buffer_size, args.batch_size, args.data_name, args.label_name)
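
One detail worth a quick sketch, since it ties the two files together: Dataset.__init__() builds the tokenizer path with '{}tokenizer.json'.format(vocab_folder), so it relies on --vocab-folder ending in a slash, which the new argparse default does. The home directory value below is illustrative only.

import os

home_dir = '/home/user'  # illustrative stand-in for the home_dir used in train.py
vocab_folder = '{}/tmp/saved_vocab/'.format(home_dir)          # argparse default from the diff
save_tokenizer_path = '{}tokenizer.json'.format(vocab_folder)  # as built in Dataset.__init__()

# The trailing slash on vocab_folder is what makes plain string formatting work here;
# os.path.join() would be the slash-agnostic alternative.
print(save_tokenizer_path)                           # /home/user/tmp/saved_vocab/tokenizer.json
print(os.path.join(vocab_folder, 'tokenizer.json'))  # same result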

0 commit comments
