"""
IMPORTANT!!
Data is from Met-Meme dataset: https://github.com/liaolianfoka/MET-Meme-A-Multi-modal-Meme-Dataset-Rich-in-Metaphors
Please refer to this link for specific data formats.
Since the original data did not clearly specify the method of dividing the dataset, the purpose of this file is to tell you how we did a simple data segmentation.
The splitting ratio is 6:2:2 for the training set, validation set, and testing set.
You can try your own way to divide the dataset. It is sufficient as long as the final data format matches the data format in the data/.
"""
import csv
import json
import random
random.seed(42)
load_path = 'data'
'''
Path to the data you downloaded from the MET-Meme dataset, for example:
data/Chinese -> This folder contains two files:
  |_C_text.csv  -> the OCR text for each image
  |_label_C.csv -> the labels for each image
Note: as written, the script reads the CSVs directly under load_path
(e.g. data/C_text.csv); adjust the paths below if your download keeps
them inside the language folders instead.
'''
def read_file(file_name, image_file, chinese=False):
    """Read a MET-Meme CSV, prefix image names with their folder, and shuffle the rows."""
    with open(file_name, 'r', encoding='gbk') as f:  # gbk also covers the ASCII English files
        reader = csv.reader(f)
        output_list = []
        head = []
        for i, line in enumerate(reader):
            if i == 0:
                head = line  # first row is the CSV header
            else:
                if chinese:
                    # Normalize Chinese file names so they match the image names on disk
                    line[0] = image_file + line[0].replace('_', '- ')
                else:
                    line[0] = image_file + line[0]
                output_list.append(line)
    random.shuffle(output_list)
    return output_list, head

def divide_dataset(data):
    """Split shuffled rows 6:2:2 into train / val / test."""
    train_num = int(len(data) * 0.6)
    train_data = data[:train_num]
    val_num = int(len(data) * 0.2)
    val_data = data[train_num: train_num + val_num]
    test_data = data[train_num + val_num:]
    return train_data, val_data, test_data

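# Quick arithmetic check (hypothetical count): with 1,000 rows the split is
# 600 / 200 / 200; because int() truncates, any remainder falls into test.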
def add_text(data, text_data, head):
    """Insert each image's OCR sentence as the second column and extend the header."""
    for line in data:
        img_name = line[0]
        sentence = text_data[img_name]  # a KeyError here means a label row has no OCR entry
        line.insert(1, sentence)
    head.insert(1, 'text')
    return data, head

def convert_list2dict(data, head):
    """Turn each row into a {header: value} dict for JSON export."""
    new_data = []
    for line in data:
        new_data.append({head[i]: line[i] for i in range(len(line))})
    return new_data

def save_json(file_name, data):
    # json.dumps defaults to ensure_ascii=True, so Chinese text is stored as
    # \uXXXX escapes; pass ensure_ascii=False if you want readable characters.
    with open(file_name, 'w', encoding='utf-8') as f:
        f.write(json.dumps(data))

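# --- Main script ---
# 1. Build {image_path: OCR text} lookups, 2. attach the text to the label rows,
# 3. split each language 6:2:2, 4. merge the languages and save as JSON.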
chinese_text_data, _ = read_file(f'{load_path}/C_text.csv', 'Chinese/', chinese=True)
chinese_text_data = {line[0]: line[1] for line in chinese_text_data}
english_text_data, _ = read_file(f'{load_path}/E_text.csv', 'English/')
english_text_data = {line[0]: line[1] for line in english_text_data}
chinese_data, c_head = read_file(f'{load_path}/label_C.csv', 'Chinese/', chinese=True)
chinese_data, c_head = add_text(chinese_data, chinese_text_data, c_head)
english_data, e_head = read_file(f'{load_path}/label_E.csv', 'English/')
english_data, e_head = add_text(english_data, english_text_data, e_head)
c_train, c_val, c_test = divide_dataset(chinese_data)
e_train, e_val, e_test = divide_dataset(english_data)
train_data = c_train + e_train
random.shuffle(train_data)
val_data = c_val + e_val
random.shuffle(val_data)
test_data = c_test + e_test
random.shuffle(test_data)
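# c_head is used for every split below; this assumes the Chinese and English
# CSVs share the same column layout (e_head would work equally well).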
save_json('data/train_data.json', convert_list2dict(train_data, c_head))
save_json('data/val_data.json', convert_list2dict(val_data, c_head))
save_json('data/test_data.json', convert_list2dict(test_data, c_head))
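# Usage: run `python data_divide.py` from the repository root with the MET-Meme
# CSVs placed under data/ as described above; it writes the train/val/test JSON
# files back into data/.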