IGPair.py
import json
import random
from random import choice

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor


class VDDataset(Dataset):
    def __init__(
        self,
        json_file,
        tokenizer,
        size=512,
        image_root_path="",
    ):
        # Accept either a single JSON file path or a list of paths whose
        # contents are concatenated into one sample list.
        if isinstance(json_file, str):
            with open(json_file, 'r') as file:
                self.data = json.load(file)
        elif isinstance(json_file, list):
            self.data = []
            for file_path in json_file:
                with open(file_path, 'r', encoding='utf-8') as file:
                    self.data.extend(json.load(file))
        else:
            raise ValueError("Input should be either a JSON file path (string) or a list of paths")
        print('=========', len(self.data))

        self.tokenizer = tokenizer
        self.size = size
        self.image_root_path = image_root_path

        # Resize the short edge to 512, randomly crop to 640x512 (HxW),
        # and normalize to [-1, 1] for the VAE.
        self.transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop([640, 512]),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        item = self.data[idx]

        person_path = item["image_file"]
        person_img = Image.open(person_path).convert("RGB")
        cloth_path = item["cloth_file"]
        clothes_img = Image.open(cloth_path).convert("RGB")
        text = choice(item['text'])

        # Classifier-free guidance dropout: with 5% probability drop the image
        # embedding, with 5% drop the text, and with 5% drop both.
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < 0.05:
            drop_image_embed = 1
        elif rand_num < 0.1:
            text = ""
        elif rand_num < 0.15:
            text = ""
            drop_image_embed = 1

        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        null_text_input_ids = self.tokenizer(
            "",
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        vae_person = self.transform(person_img)
        vae_clothes = self.transform(clothes_img)
        clip_image = self.clip_image_processor(images=clothes_img, return_tensors="pt").pixel_values

        return {
            "vae_person": vae_person,
            "vae_clothes": vae_clothes,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
            "text": text,
            "text_input_ids": text_input_ids,
            "null_text_input_ids": null_text_input_ids,
        }

    def __len__(self):
        return len(self.data)


def collate_fn(data):
    # Stack per-sample tensors into batched tensors for the training loop.
    vae_person = torch.stack([example["vae_person"] for example in data]).to(
        memory_format=torch.contiguous_format).float()
    vae_clothes = torch.stack([example["vae_clothes"] for example in data]).to(
        memory_format=torch.contiguous_format).float()
    clip_image = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embed = [example["drop_image_embed"] for example in data]
    text = [example["text"] for example in data]
    input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    null_input_ids = torch.cat([example["null_text_input_ids"] for example in data], dim=0)
    return {
        "vae_person": vae_person,
        "vae_clothes": vae_clothes,
        "clip_image": clip_image,
        "drop_image_embed": drop_image_embed,
        "text": text,
        "input_ids": input_ids,
        "null_input_ids": null_input_ids,
    }
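

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wiring VDDataset and collate_fn into a DataLoader.
# The tokenizer checkpoint and the "pairs.json" path below are assumptions
# made for demonstration; substitute your own tokenizer and metadata file.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
    dataset = VDDataset(json_file="pairs.json", tokenizer=tokenizer, size=512)  # hypothetical path
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

    batch = next(iter(loader))
    # vae_person / vae_clothes: [B, 3, 640, 512]; clip_image: [B, 3, 224, 224]
    print(batch["vae_person"].shape, batch["clip_image"].shape)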