Skip to content

Commit 496f867

Browse files
committed
Fix noncustom model issue
1 parent ec16c7a commit 496f867

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

api/DataProcesser.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,18 @@ def handle_classify(self, df, classifier):
5252
#return classifier_switcher.get(classifier, lambda: "Invalid Classifier")(df)
5353

5454
def get_pipeline(self, model_name):
55+
if os.path.exists('assets/tweet_emotions.csv'):
56+
prefix = ''
57+
else:
58+
prefix = 'public/'
5559
if model_name=="emotion_pipeline.pkl":
56-
df = pd.read_csv('assets/tweet_emotions.csv')
60+
df = pd.read_csv(prefix + 'assets/tweet_emotions.csv')
5761
train_data, test_data, train_target, test_target = train_test_split(df['content'], df['sentiment'], test_size=0.2, shuffle=True)
5862
elif model_name=="hate_speech.pkl":
59-
df = pd.read_csv('assets/nb_hatespeech.csv', sep=';')
63+
df = pd.read_csv(prefix + 'assets/nb_hatespeech.csv', sep=';')
6064
train_data, test_data, train_target, test_target = train_test_split(df['comment'], df['isHate'], test_size=0.2, shuffle=True)
6165
elif model_name=="text_classification_pipeline.pkl":
62-
df = pd.read_csv('assets/nb_news.csv')
66+
df = pd.read_csv(prefix + 'assets/nb_news.csv')
6367
train_data, test_data, train_target, test_target = train_test_split(df['short_description'], df['category'], test_size=0.2, shuffle=True)
6468
else:
6569
with open(f'api/models/{model_name}', 'rb') as file:
@@ -109,9 +113,12 @@ def pretrained_predict(self, df, pipeline, model_name = None):
109113
texts_to_predict = [str(text) for text in texts_to_predict]
110114

111115
predictions = pipeline.predict(texts_to_predict)
112-
label_predictions = label_encoder.inverse_transform(predictions)
113116

114-
df['output_column'] = label_predictions
117+
if model_name:
118+
label_predictions = label_encoder.inverse_transform(predictions)
119+
df['output_column'] = label_predictions
120+
else:
121+
df['output_column'] = predictions
115122

116123
return df
117124

0 commit comments

Comments
 (0)