-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
65 lines (52 loc) · 1.89 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from eval import load_test_data
from config import get_arguments
from load_data import LoadData
from model import FastText, get_classifier
from utils import get_one_hot_labels
from sklearn.model_selection import train_test_split
import numpy as np
import pickle
import os
import pandas as pd
def split_data(df_data, config, test_frac=0.2):
    """
    Partition df_data into a train part and a test part.

    The split is delegated to train_test_split, with test_frac passed as
    its test_size. Each part has its index reset (old index dropped) and
    is written as CSV, without the index column, to
    config.path_train_data and config.path_test_data respectively.

    Returns only the train part; the test part is available on disk.
    """
    train_part, test_part = (
        part.reset_index(drop=True)
        for part in train_test_split(df_data, test_size=test_frac)
    )
    train_part.to_csv(config.path_train_data, index=False)
    test_part.to_csv(config.path_test_data, index=False)
    return train_part
def get_training_data(config):
    """
    Return the training DataFrame, creating it when no train CSV exists.

    If config.path_train_data is not an existing file, a LoadData instance
    is built from config, run through preprocess(), saved, and its df_data
    is handed to split_data; the train part of that split is returned.

    Otherwise the saved LoadData instance is loaded from
    config.preprocessing_class_path, its config.n_classes is copied onto
    the given config, and the train CSV is read back with the "labels"
    column converted from its string form into Python objects.
    """
    if not os.path.isfile(config.path_train_data):
        # No train CSV yet: preprocess, save the instance, then split.
        loader = LoadData(config)
        loader.preprocess()
        loader.save()
        return split_data(loader.df_data, config)

    # Train CSV already exists: restore n_classes from the saved instance.
    saved_loader = LoadData.load(config.preprocessing_class_path)
    config.n_classes = saved_loader.config.n_classes

    cached_train = pd.read_csv(config.path_train_data)
    # NOTE(review): eval executes whatever text is stored in the "labels"
    # column, so the CSV must be trusted. If the labels are plain literals,
    # ast.literal_eval would be a safer parser -- confirm before switching.
    cached_train["labels"] = cached_train["labels"].apply(eval)
    return cached_train
if __name__ == "__main__":
    # Script entry point: build the training set, train FastText embeddings,
    # fit the classifier on them and pickle the fitted classifier.
    parser = get_arguments()
    config = parser.parse_args()
    # Presumably tells downstream code this is a training run rather than
    # an evaluation run -- confirm against the consumers of config.eval.
    config.eval = False

    print("processing data ...")
    df_train = get_training_data(config)

    # load fasttext and train it
    fast_text = FastText(config, df_train)
    fast_text.train()

    # Features come from the trained FastText model; targets are the
    # one-hot encoded labels of the same training frame.
    X_train = fast_text.get_embeddings()
    y_train = get_one_hot_labels(df_train, config)

    classifier = get_classifier(config)
    print("fitting classifier ...")
    classifier.fit(X_train, y_train)

    print("saving classifier ....")
    # Use a context manager so the file is flushed and closed
    # deterministically, even if pickling raises part-way through.
    with open(config.model_path, "wb") as model_file:
        pickle.dump(classifier, model_file)