-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_splitter.py
66 lines (57 loc) · 1.95 KB
/
data_splitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import sys
from numpy.random import choice
from shutil import copyfile
# The script needs exactly two arguments after its own name; otherwise
# print the usage text and stop.
if len(sys.argv) != 3:
    usage_lines = [
        "This program takes two arguments:",
        " > folder with emails; and",
        " > number of emails per class (ham:spam); e.g.",
        "python data_splitter.py emails 50:10",
    ]
    sys.exit("\n".join(usage_lines))
def mkdir(d):
    """Create directory *d* (and any missing parents) if it does not exist.

    Calling this on a path that already exists is a no-op.  The creation
    is attempted first and the existence check happens only on failure,
    so a directory that appears between "check" and "create" (e.g. made
    by a concurrent run) no longer causes a crash.

    Raises:
        OSError: if the directory could not be created and the path
            still does not exist (permissions, invalid path, ...).
    """
    try:
        os.makedirs(d)
    except OSError:
        # Only swallow the failure when the path is actually there;
        # anything else is a genuine error the caller must see.
        if not os.path.exists(d):
            raise
def _split(files, n, label):
    """Randomly pick *n* entries of *files*; return ``(selected, remaining)``.

    *selected* is in the random order drawn; *remaining* keeps the order
    of *files*.  Exits with a readable message when *n* is negative or
    larger than the number of available files.
    """
    if not 0 <= n <= len(files):
        sys.exit("Cannot select %d %s emails: %d available"
                 % (n, label, len(files)))
    if n == 0:
        return [], list(files)
    # Sample indices rather than the paths themselves so the path strings
    # never pass through a numpy array.
    picked = [int(i) for i in choice(len(files), size=n, replace=False)]
    picked_set = set(picked)
    selected = [files[i] for i in picked]
    remaining = [f for i, f in enumerate(files) if i not in picked_set]
    return selected, remaining


def _copy_numbered(files, dest_dir, label, log):
    """Copy *files* into *dest_dir* as ``<label>NNN.txt``, logging each copy."""
    for i, fname in enumerate(files):
        d = os.path.join(dest_dir, "%s%.3d.txt" % (label, i))
        copyfile(fname, d)
        log.append("%s -> %s" % (fname, d))


SPLIT = sys.argv[1]
# Outputs are written next to the input folder, i.e. into its parent.
ROOT = os.path.split(os.path.abspath(SPLIT))[0]
LOG_FILE = os.path.join(ROOT, "log.txt")
SELECTED = os.path.join(ROOT, "selected")
REMAINING = os.path.join(ROOT, "remaining")

# Second argument has the form "<ham count>:<spam count>".
try:
    HAM_N, SPAM_N = [int(i) for i in sys.argv[2].split(":")]
except ValueError:
    sys.exit("Expected the number of emails as ham:spam (e.g. 50:10), "
             "got: %s" % sys.argv[2])

log = []
ham, spam = [], []
# Classify every regular file by substring of its name; "ham" is tested
# first, so a name containing both substrings counts as ham.
for d in os.listdir(SPLIT):
    email = os.path.join(SPLIT, d)
    if os.path.isfile(email) and "ham" in d:
        ham.append(email)
    elif os.path.isfile(email) and "spam" in d:
        spam.append(email)
    else:
        # Single parenthesised argument: valid on both Python 2 and 3.
        print("ERROR! Unrecognised email: %s" % email)

# After this, `ham` / `spam` hold only the files that were NOT selected.
ham_train, ham = _split(ham, HAM_N, "ham")
spam_train, spam = _split(spam, SPAM_N, "spam")

# Create the output folders only once the inputs have been validated.
mkdir(SELECTED)
mkdir(REMAINING)

_copy_numbered(ham_train, SELECTED, "ham", log)
_copy_numbered(spam_train, SELECTED, "spam", log)
_copy_numbered(ham, REMAINING, "ham", log)
_copy_numbered(spam, REMAINING, "spam", log)

# Log layout: selected ham, separator, selected spam, double separator,
# then one "source -> destination" line per copied file.
separator = 40 * "~+"
with open(LOG_FILE, "w") as log_file:
    log_file.write(
        "\n".join(ham_train) + "\n\n" + separator + "\n\n"
        + "\n".join(spam_train) + "\n\n" + 2 * (separator + "\n") + "\n"
        + "\n".join(log) + "\n"
    )