Skip to content

Commit d34cd7b

Browse files
authored
Add files via upload
1 parent dcb7bf7 commit d34cd7b

File tree

7 files changed

+2373
-0
lines changed

7 files changed

+2373
-0
lines changed

core/AWA2DataLoader.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Sat Jul 20 21:23:18 2019
4+
5+
@author: badat
6+
"""
7+
8+
import os,sys
9+
#import scipy.io as sio
10+
import torch
11+
import numpy as np
12+
import h5py
13+
import time
14+
import pickle
15+
from sklearn import preprocessing
16+
from global_setting import NFS_path
17+
#%%
18+
import scipy.io as sio
19+
import pandas as pd
20+
#%%
21+
import pdb
22+
#%%
23+
# Dataset selection and the filesystem locations derived from it.
dataset = 'AWA2'
# Image root and the xlsa17 res101.mat features live under the NFS root.
img_dir = os.path.join(NFS_path, f'data/{dataset}/')
mat_path = os.path.join(NFS_path, f'data/xlsa17/data/{dataset}/res101.mat')
# CSV with human-readable attribute descriptions (column 'new_des').
attr_path = f'/data2/shimingchen/BCA/attribute/{dataset}/new_des.csv'

class AWA2DataLoader():
    """Data loader for the AWA2 zero-shot-learning benchmark.

    Reads pre-extracted ResNet-101 feature maps, labels, split indices and
    class attributes from an HDF5 file under ``datadir`` and exposes them via
    ``self.data[split]`` dicts plus batch-sampling helpers.

    Args:
        data_path: project data root (also appended to ``sys.path``, as the
            original code did).
        device: torch device that batches and attribute matrices are moved to.
        is_scale: if True, min-max scale features (scaler fit on train only).
        is_unsupervised_attr: if True, derive attributes from word2vec class
            embeddings via SVD instead of the expert-annotated attributes.
        is_balance: if True, ``next_batch`` samples classes uniformly.
    """

    def __init__(self, data_path, device, is_scale=False, is_unsupervised_attr=False, is_balance=True):
        print(data_path)
        sys.path.append(data_path)

        self.data_path = data_path
        self.device = device
        self.dataset = 'AWA2'
        print('$'*30)
        print(self.dataset)
        print('$'*30)
        self.datadir = self.data_path + 'data/{}/'.format(self.dataset)
        self.index_in_epoch = 0
        self.epochs_completed = 0
        self.is_scale = is_scale
        self.is_balance = is_balance
        if self.is_balance:
            print('Balance dataloader')
        self.is_unsupervised_attr = is_unsupervised_attr
        self.read_matdataset()
        self.get_idx_classes()

    def augment_img_path(self, mat_path=mat_path, img_dir=img_dir):
        """Attach per-sample image paths (from res101.mat) and attribute
        names (from the CSV at ``attr_path``) to the loaded splits."""
        self.matcontent = sio.loadmat(mat_path)
        self.image_files = np.squeeze(self.matcontent['image_files'])

        def convert_path(image_files, img_dir):
            # Re-root the absolute paths stored in the .mat file under img_dir.
            # NOTE(review): dropping the first 5 path components assumes a
            # fixed prefix depth in the .mat paths -- confirm for new data.
            new_image_files = []
            for idx in range(len(image_files)):
                image_file = image_files[idx][0]
                image_file = os.path.join(img_dir, '/'.join(image_file.split('/')[5:]))
                new_image_files.append(image_file)
            return np.array(new_image_files)

        self.image_files = convert_path(self.image_files, img_dir)

        path = self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset)
        # BUGFIX: context manager so the HDF5 handle is closed (the original
        # opened the file and never closed it).
        with h5py.File(path, 'r') as hf:
            trainval_loc = np.array(hf.get('trainval_loc'))
            test_seen_loc = np.array(hf.get('test_seen_loc'))
            test_unseen_loc = np.array(hf.get('test_unseen_loc'))

        self.data['train_seen']['img_path'] = self.image_files[trainval_loc]
        self.data['test_seen']['img_path'] = self.image_files[test_seen_loc]
        self.data['test_unseen']['img_path'] = self.image_files[test_unseen_loc]

        self.attr_name = pd.read_csv(attr_path)['new_des']

    def next_batch_img(self, batch_size, class_id, is_trainset=False):
        """Return up to ``batch_size`` samples belonging to one class.

        Args:
            batch_size: maximum number of samples to return.
            class_id: class label; must lie in seenclasses or unseenclasses.
            is_trainset: for seen classes, draw from the train split if True,
                otherwise from the seen-test split.

        Returns:
            (labels, features, image paths, attributes); paths is a numpy
            array, the rest are tensors moved to ``self.device``.

        Raises:
            Exception: if ``class_id`` is in neither seen nor unseen classes.
        """
        if class_id in self.seenclasses:
            split = self.data['train_seen'] if is_trainset else self.data['test_seen']
        elif class_id in self.unseenclasses:
            split = self.data['test_unseen']
        else:
            raise Exception("Cannot find this class {}".format(class_id))

        features = split['resnet_features']
        labels = split['labels']
        # note that img_files is numpy type (populated by augment_img_path)
        img_files = split['img_path']

        # BUGFIX: flatten() instead of squeeze() so a class with exactly one
        # sample still yields a 1-D index tensor (squeeze produced a 0-d
        # tensor, making the [:batch_size] slices below cut the feature axis).
        idx_c = torch.nonzero(labels == class_id).flatten()

        features = features[idx_c]
        labels = labels[idx_c]
        img_files = img_files[idx_c.cpu().numpy()]

        batch_label = labels[:batch_size].to(self.device)
        batch_feature = features[:batch_size].to(self.device)
        batch_files = img_files[:batch_size]
        batch_att = self.att[batch_label].to(self.device)

        return batch_label, batch_feature, batch_files, batch_att

    def next_batch(self, batch_size):
        """Sample one training batch.

        Balanced mode samples min(ntrain_class, batch_size) classes without
        replacement and max(batch_size // ntrain_class, 1) samples per class
        (with replacement); otherwise it samples instances uniformly.

        Returns:
            (labels, features, attributes), all moved to ``self.device``.
        """
        if self.is_balance:
            idx = []
            n_samples_class = max(batch_size // self.ntrain_class, 1)
            sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),
                                             min(self.ntrain_class, batch_size),
                                             replace=False).tolist()
            for i_c in sampled_idx_c:
                idxs = self.idxs_list[i_c]
                idx.append(np.random.choice(idxs, n_samples_class))
            idx = np.concatenate(idx)
            idx = torch.from_numpy(idx)
        else:
            idx = torch.randperm(self.ntrain)[0:batch_size]

        batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device)
        batch_label = self.data['train_seen']['labels'][idx].to(self.device)
        batch_att = self.att[batch_label].to(self.device)
        return batch_label, batch_feature, batch_att

    def get_idx_classes(self):
        """Cache, per seen class, the training-set indices of its samples.

        Populates ``self.idxs_list`` (used by balanced ``next_batch``) and
        also returns it.
        """
        n_classes = self.seenclasses.size(0)
        self.idxs_list = []
        train_label = self.data['train_seen']['labels']
        for i in range(n_classes):
            idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy()
            idx_c = np.squeeze(idx_c)
            self.idxs_list.append(idx_c)
        return self.idxs_list

    def read_matdataset(self):
        """Load features, labels and attributes from HDF5 and build
        ``self.data`` plus the class-index bookkeeping attributes."""
        path = self.datadir + 'feature_map_ResNet_101_448_{}.hdf5'.format(self.dataset)
        print('_____')
        print(path)
        # BUGFIX: context manager closes the file (the original leaked the
        # handle for the lifetime of the loader).
        with h5py.File(path, 'r') as hf:
            features = np.array(hf.get('feature_map'))
            labels = np.array(hf.get('labels'))
            trainval_loc = np.array(hf.get('trainval_loc'))
            test_seen_loc = np.array(hf.get('test_seen_loc'))
            test_unseen_loc = np.array(hf.get('test_unseen_loc'))

            if self.is_unsupervised_attr:
                print('Unsupervised Attr')
                class_path = './w2v/{}_class.pkl'.format(self.dataset)
                with open(class_path, 'rb') as f:
                    w2v_class = pickle.load(f)
                assert w2v_class.shape == (50, 300)
                w2v_class = torch.tensor(w2v_class).float()

                # Factor the class embeddings: att = U*diag(s), w2v_att = V^T.
                U, s, V = torch.svd(w2v_class)
                reconstruct = torch.mm(torch.mm(U, torch.diag(s)), torch.transpose(V, 1, 0))
                print('sanity check: {}'.format(torch.norm(reconstruct - w2v_class).item()))
                print('shape U:{} V:{}'.format(U.size(), V.size()))
                print('s: {}'.format(s))

                self.w2v_att = torch.transpose(V, 1, 0).to(self.device)
                self.att = torch.mm(U, torch.diag(s)).to(self.device)
                self.normalize_att = torch.mm(U, torch.diag(s)).to(self.device)
            else:
                print('Expert Attr')
                att = np.array(hf.get('att'))

                print("threshold at zero attribute with negative value")
                att[att < 0] = 0

                self.att = torch.from_numpy(att).float().to(self.device)

                original_att = np.array(hf.get('original_att'))
                self.original_att = torch.from_numpy(original_att).float().to(self.device)

                w2v_att = np.array(hf.get('w2v_att'))
                self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device)

                # original_att appears to be on a 0-100 scale -- TODO confirm.
                self.normalize_att = self.original_att / 100

        train_feature = features[trainval_loc]
        test_seen_feature = features[test_seen_loc]
        test_unseen_feature = features[test_unseen_loc]
        if self.is_scale:
            # BUGFIX: fit the scaler on the training split only and reuse it
            # for both test splits. The original fit_transform'ed each split
            # independently, so the scaling differed per split and the test
            # statistics leaked into the transform.
            scaler = preprocessing.MinMaxScaler()
            train_feature = scaler.fit_transform(train_feature)
            test_seen_feature = scaler.transform(test_seen_feature)
            test_unseen_feature = scaler.transform(test_unseen_feature)

        # Features/labels are kept on CPU; batches are moved per-call.
        train_feature = torch.from_numpy(train_feature).float()
        test_seen_feature = torch.from_numpy(test_seen_feature)
        test_unseen_feature = torch.from_numpy(test_unseen_feature)

        train_label = torch.from_numpy(labels[trainval_loc]).long()
        test_unseen_label = torch.from_numpy(labels[test_unseen_loc])
        test_seen_label = torch.from_numpy(labels[test_seen_loc])

        self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device)
        self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device)
        self.ntrain = train_feature.size()[0]
        self.ntrain_class = self.seenclasses.size(0)
        self.ntest_class = self.unseenclasses.size(0)
        self.train_class = self.seenclasses.clone()
        self.allclasses = torch.arange(0, self.ntrain_class + self.ntest_class).long()

        self.data = {}
        self.data['train_seen'] = {}
        self.data['train_seen']['resnet_features'] = train_feature
        self.data['train_seen']['labels'] = train_label

        self.data['train_unseen'] = {}
        self.data['train_unseen']['resnet_features'] = None
        self.data['train_unseen']['labels'] = None

        self.data['test_seen'] = {}
        self.data['test_seen']['resnet_features'] = test_seen_feature
        self.data['test_seen']['labels'] = test_seen_label

        self.data['test_unseen'] = {}
        self.data['test_unseen']['resnet_features'] = test_unseen_feature
        self.data['test_unseen']['labels'] = test_unseen_label

0 commit comments

Comments
 (0)