-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
121 lines (103 loc) · 5.87 KB
/
dataset.py
File metadata and controls
121 lines (103 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from torch import nn as nn
from PIL import Image
from utils import load_json, metmeme_label_to_int, mood_label_to_int, memotion_label_to_int
class MEMEDataset(torch.utils.data.Dataset):
    """Meme dataset yielding tokenized text streams, ViT pixel values and a label.

    Each JSON record is expected to carry the keys 'images_name', 'text',
    'img content', 'text content', 'combined meaning', 'context analysis'
    plus a dataset-specific label field — TODO confirm against the JSON files.

    ``__getitem__`` returns a 6-tuple:
    (main_inputs, main_attention, ID_inputs, ID_attention, pixel_values, label).
    """

    def __init__(self, path, vit_processor, tokenizer, dataset_name, usage="train", pre_fuse=True,
                 without_ID=False, without_TM=False, without_CIM=False, without_CA=False):
        self.path = path
        self.datas = load_json(path)  # eagerly loaded list of records
        self.vit_processor = vit_processor
        self.tokenizer = tokenizer
        self.dataset_name = dataset_name
        self.usage = usage
        self.pre_fuse = pre_fuse
        # Ablation flags: each drops exactly one component from the fused input.
        self.without_ID = without_ID    # drop image description
        self.without_TM = without_TM    # drop text meaning
        self.without_CIM = without_CIM  # drop combined implicit meaning
        self.without_CA = without_CA    # drop context analysis

    def __len__(self):
        return len(self.datas)

    def convert_str_to_ids(self, discription, head_space=True, max_id_num=100):
        '''
        Tokenize a string word by word, without special tokens.

        input: "Hello world!"
        output: [101, 7592] (input ids, truncated to at most max_id_num)

        When head_space is True (default) every token — including the first —
        is prefixed with a space before tokenization. None yields [].
        '''
        inputs = []
        if discription is not None:
            for i, token in enumerate(discription.split()):
                # Only the very first token may go in without a leading space.
                token = token if i == 0 and not head_space else ' ' + token
                tokenized_token = self.tokenizer(token, add_special_tokens=False)
                inputs += tokenized_token['input_ids']
        return inputs[:max_id_num]

    def __getitem__(self, idx):
        line = self.datas[idx]
        img_name, text = line['images_name'], line['text']
        if self.dataset_name == 'metmeme':
            label = metmeme_label_to_int(line["sentiment category"])
            img0, img1 = img_name.split('/')
            img_path = f'data/MET-MEME/image/{img0}/images/{img1}'
        elif self.dataset_name == 'mood':
            label = mood_label_to_int(line["emotion category"])
            img_path = f'data/mood/mood_images/{img_name}'
        elif self.dataset_name == 'memotion':
            label = memotion_label_to_int(line["sentiment category"])
            img_path = f'data/memotion/images/{img_name}'
        else:
            # Fail fast with a clear message instead of an UnboundLocalError
            # on `label`/`img_path` further down.
            raise ValueError(f"Unknown dataset_name: {self.dataset_name!r}")
        # Context manager releases the file handle even if the processor
        # raises; the old explicit close() was skipped on error and, after an
        # RGB conversion, closed the converted copy instead of the opened file.
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img if pil_img.mode == 'RGB' else pil_img.convert("RGB")
            img_vit = self.vit_processor(rgb_img, padding="max_length", truncation=True, return_tensors='pt')
        sep_id, cls_id = self.tokenizer.sep_token_id, self.tokenizer.cls_token_id
        original_text = self.convert_str_to_ids(text, max_id_num=100)
        image_description = self.convert_str_to_ids(line["img content"], max_id_num=100)
        text_meaning = self.convert_str_to_ids(line["text content"], max_id_num=100)
        combined_implicit_meaning = self.convert_str_to_ids(line["combined meaning"], max_id_num=100)
        context_analysis = self.convert_str_to_ids(line["context analysis"], max_id_num=100)
        if self.pre_fuse:
            # Pre-fuse the description streams; each without_* flag ablates
            # exactly one of the four components.
            if self.without_ID:
                main_inputs = [cls_id] + text_meaning + combined_implicit_meaning + context_analysis + [sep_id]
            elif self.without_TM:
                main_inputs = [cls_id] + image_description + combined_implicit_meaning + context_analysis + [sep_id]
            elif self.without_CIM:
                main_inputs = [cls_id] + image_description + text_meaning + context_analysis + [sep_id]
            elif self.without_CA:
                main_inputs = [cls_id] + image_description + text_meaning + combined_implicit_meaning + [sep_id]
            else:
                main_inputs = [cls_id] + image_description + text_meaning + combined_implicit_meaning + context_analysis + [sep_id]
        else:
            # No pre-fusion: prepend the raw meme text to all four components.
            main_inputs = [cls_id] + original_text + image_description + text_meaning + combined_implicit_meaning + context_analysis + [sep_id]
        main_attention = [1] * len(main_inputs)
        ID_inputs = [cls_id] + original_text + [sep_id]
        ID_attention = [1] * len(ID_inputs)
        return main_inputs, main_attention, ID_inputs, ID_attention, img_vit['pixel_values'], label
class Collator(object):
    """Batch collator for MEMEDataset samples.

    Right-pads the main and ID token streams to the longest sequence in the
    batch (pad token for ids, 0 for attention masks), stacks the ViT pixel
    tensors and collects the integer labels into one tensor.
    """

    def __init__(self, tokenizer, vit_processor):
        self.tokenizer = tokenizer
        self.vit_processor = vit_processor

    def __call__(self, batch):
        pad_id = self.tokenizer.pad_token_id

        def padded(seq, target_len, fill):
            # Return a fresh list right-padded with `fill` up to target_len.
            return seq + [fill] * (target_len - len(seq))

        main_len = max(len(sample[0]) for sample in batch)
        id_len = max(len(sample[2]) for sample in batch)

        main_ids = [padded(sample[0], main_len, pad_id) for sample in batch]
        main_mask = [padded(sample[1], main_len, 0) for sample in batch]
        id_ids = [padded(sample[2], id_len, pad_id) for sample in batch]
        id_mask = [padded(sample[3], id_len, 0) for sample in batch]

        pixel_values = torch.stack([sample[4] for sample in batch]).squeeze(1)
        labels = torch.tensor([sample[5] for sample in batch])

        return {
            'input_ids': torch.tensor(main_ids),        # [batch_size, max_main_length]
            'attention_mask': torch.tensor(main_mask),  # [batch_size, max_main_length]
            'id_input_ids': torch.tensor(id_ids),       # [batch_size, max_id_length]
            'id_attention_mask': torch.tensor(id_mask), # [batch_size, max_id_length]
            'vit_pixel_value': pixel_values,            # [batch_size, 3, 224, 224]
            'labels': labels,                           # [batch_size]
        }