-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbambambam.py
77 lines (62 loc) · 1.88 KB
/
bambambam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
'''
bambambam
by github.com/simonlindgren
'''
import glob
import os

import classy_classification  # noqa: F401 - provides the "text_categorizer" spaCy pipe used below
import pandas as pd
import spacy
from tqdm import tqdm
# Startup banner: the program name underlined with dashes.
print('\nbambambam', '---------', sep='\n')
# Read label data
# Each labels/<name>.txt file defines one label: the file name without its
# .txt extension is the label, and every non-blank line is one example.
labels = {}
for lf in glob.glob('labels/*.txt'):
    # basename/splitext rather than splitting on '/': this also works with
    # Windows path separators and keeps dotted names ("v1.2.txt" -> "v1.2").
    label = os.path.splitext(os.path.basename(lf))[0]
    # Close each file promptly and decode it explicitly as UTF-8; skip blank
    # lines so they do not become empty-string examples.
    with open(lf, encoding='utf-8') as label_file:
        stripped = (line.strip() for line in label_file)
        labels[label] = [example for example in stripped if example]
# glob returns files in an arbitrary, filesystem-dependent order; sort by
# label name so the label order is the same on every run.
labels = dict(sorted(labels.items()))
print(f"loaded {len(labels)} labels ({', '.join(labels)})")
# Prepare classifier
# Settings for classy_classification's "text_categorizer" component: the
# labelled examples loaded above, the sentence-transformers model to use,
# and the device to run it on.
classifier_config = {
    "data": labels,
    "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "device": "cpu",
}
# Start from a blank English spaCy pipeline and append the classifier to it.
nlp = spacy.blank("en")
nlp.add_pipe("text_categorizer", config=classifier_config)
# Read unseen data
# One sentence per line. Surrounding whitespace is stripped and blank lines
# are skipped, so no empty strings are classified or written to the CSV.
with open('data/unseen.txt', encoding='utf-8') as unseen_file:
    sentences = [line.strip() for line in unseen_file if line.strip()]
# Run classification
# Score every sentence with the pipeline and save one row per sentence: its
# text followed by one column per key of the doc's `._.cats` dict.
print('classifying ...')
scores = [nlp(s)._.cats for s in tqdm(sentences)]
# Build the score columns directly from the list of dicts instead of
# expanding a column of dicts with apply(pd.Series), which constructs one
# Series per row. Both frames share the same 0..n-1 index, so concat aligns
# each sentence with its scores.
df = pd.concat(
    [pd.DataFrame({'text': sentences}), pd.DataFrame(scores)],
    axis=1,
)
df.to_csv('bambambam.csv', index=False)
print('done.\n')
# Bonus printout
# Optionally show the five highest-scoring sentences for each label.
try:
    response = input("\nsee examples? (y/n): ")
except EOFError:
    # No interactive stdin (e.g. input redirected from a file or /dev/null):
    # the CSV is already written, so treat this as "no" instead of crashing.
    response = "n"
if response.strip().lower() in ("y", "yes"):
    if df.empty:
        # No rows means no per-label columns to rank by.
        print("no classified sentences to show.\n")
    else:
        for k in labels:
            top = df.sort_values(by=k, ascending=False).head(5)
            print(f"Examples of '{k}'\n{'-' * 30}")
            for txt, score in zip(top['text'], top[k]):
                # Format the number rather than slicing str(score): slicing
                # printed e.g. 1.234e-05 as "1.23".
                print(f"{score:.2f} -- {txt}")
            print()