-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_train_and_validation.py
88 lines (72 loc) · 4.13 KB
/
create_train_and_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import random
import os
def get_folder_names_dict(path, difficulty, *, seq_len=16):
    """Index the folders listed in ``<path>/<difficulty>.txt`` and count their sequences.

    Each non-blank line of the list file is a folder name. For a name of the
    form ``<base>_<a>_<b>`` (``<base>`` may itself contain underscores), the
    images are read from ``<path>/<base>/<base>_<a>/<name>/light_mask``.
    Every regular file in that ``light_mask`` directory counts as one image.
    A folder with ``n`` images yields ``max(0, n - seq_len + 1)`` overlapping
    sequences of ``seq_len`` consecutive images. For example, 18 images with
    ``seq_len=16`` give 3 sequences: 1-16, 2-17 and 3-18. Folders whose name
    contains ``"test-"`` are reported separately from the others.

    Args:
        path: root directory holding the list file and the folder tree.
        difficulty: stem of the list file (``<difficulty>.txt``).
        seq_len: number of consecutive images per sequence (default 16).

    Returns:
        dict: NON-test folder paths -> number of image sequences each contains
        dict: TEST folder paths -> number of image sequences each contains
        int: total number of image sequences in all non-test folders
        int: total number of image sequences in all test folders

    Raises:
        FileNotFoundError: if the list file or a listed ``light_mask``
            directory does not exist.
    """
    all_not_test_folder_names = {}
    all_test_folder_names = {}

    with open("{0}/{1}.txt".format(path, difficulty), encoding="utf-8") as list_file:
        # strip() instead of chopping the last character. The final line may
        # lack a trailing newline, and Windows-written files end in "\r\n".
        # Blank lines are skipped rather than crashing the split below.
        folder_names = [line.strip() for line in list_file if line.strip()]

    for folder_name in folder_names:
        folder_components = folder_name.split("_")
        base = "_".join(folder_components[:-2])
        folder = "{0}/{1}/{2}/{3}".format(
            path, base, base + "_" + folder_components[-2], folder_name)
        folder_mask = folder + "/light_mask"
        n_images = sum(
            1 for entry in os.listdir(folder_mask)
            if os.path.isfile(os.path.join(folder_mask, entry)))
        # A folder shorter than one sequence contributes 0, not a negative count.
        n_sequences = max(0, n_images - seq_len + 1)
        if "test-" in folder_name:
            all_test_folder_names[folder] = n_sequences
        else:
            all_not_test_folder_names[folder] = n_sequences

    count_not_test_img_seq = sum(all_not_test_folder_names.values())
    count_test_img_seq = sum(all_test_folder_names.values())
    print('total number of non-test folders is ', len(all_not_test_folder_names) , ' for difficulty ', difficulty)
    print('total number of non-test image sequences is ', count_not_test_img_seq)
    print('total number of test image sequences is ', count_test_img_seq)
    return all_not_test_folder_names, all_test_folder_names, count_not_test_img_seq, count_test_img_seq
def create_train_and_validation_set(path, difficulty, *, valid_ratio=0.3, rng=None):
    """Randomly split the non-test folders of *difficulty* into train and validation sets.

    Whole folders are assigned, never individual sequences, so the sequences
    of one folder never end up in both sets. The non-test folders are
    shuffled. They are then moved into the validation set one at a time until
    it holds more than ``int(total * valid_ratio)`` image sequences. All
    remaining folders form the train set. Test folders, as returned by
    ``get_folder_names_dict``, are passed through unchanged.

    Args:
        path: root directory forwarded to ``get_folder_names_dict``.
        difficulty: difficulty name forwarded to ``get_folder_names_dict``.
        valid_ratio: target fraction of non-test sequences for the validation
            set (default 0.3).
        rng: object with a ``shuffle`` method (e.g. ``random.Random(seed)``)
            for a reproducible split. Defaults to the global ``random`` module.

    Returns:
        list: folder paths in the train set
        list: folder paths in the validation set
        list: folder paths in the test set
        int: number of image sequences in the train set
        int: number of image sequences in the validation set
        int: number of image sequences in the test set
    """
    (all_not_test_folder_names, all_test_folder_names,
     count_not_test_img_seq, count_test_img_seq) = get_folder_names_dict(path, difficulty)

    folder_names_list = list(all_not_test_folder_names)
    (random if rng is None else rng).shuffle(folder_names_list)

    # Take folders from the front of the shuffled list until the validation
    # set exceeds the threshold, then slice. Removing items from the list
    # while iterating over it would skip every other folder.
    threshold = int(count_not_test_img_seq * valid_ratio)
    split = 0
    count_valid_seq = 0  # image sequences currently in the validation set
    while split < len(folder_names_list) and count_valid_seq <= threshold:
        count_valid_seq += all_not_test_folder_names[folder_names_list[split]]
        split += 1
    valid_folder_list = folder_names_list[:split]
    train_folder_list = folder_names_list[split:]
    count_train_seq = sum(all_not_test_folder_names[name] for name in train_folder_list)
    test_folder_list = list(all_test_folder_names)

    # Guard the percentage computation against an empty non-test set.
    if count_not_test_img_seq:
        train_pct = round(count_train_seq / count_not_test_img_seq * 100)
        valid_pct = round(count_valid_seq / count_not_test_img_seq * 100)
    else:
        train_pct = valid_pct = 0
    print('\nfinal train set contains ', count_train_seq, ' image sequences (', train_pct, '% )')
    print('final validation set contains ', count_valid_seq, ' image sequences (', valid_pct,'% )')
    print('final test set contains ', count_test_img_seq, ' image sequences')
    return train_folder_list, valid_folder_list, test_folder_list, count_train_seq, count_valid_seq, count_test_img_seq