@@ -123,7 +123,7 @@ def pre_epoch_operations(loss, **kwargs):
 
 
 def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, onehot,
-                  gpu_train_transform=None, method='supervised', criterion=None):
+                  gpu_train_transform=None, method='supervised', criterion=None, logdir=None):
     model.train()
 
     batch_timer = Timer()
@@ -133,7 +133,7 @@ def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, o
         criterion = train_helper.get_loss(loss_name, **loss_cfg)
     meters = defaultdict(AverageMeter)
 
-    train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name)
+    train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name, method=method, onehot=onehot)
     criterion.train()
 
     pbar = tqdm(train_loader, desc='Train', ascii=True, bar_format='{l_bar}{bar:10}{r_bar}',
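Note: train_helper.update_criterion now receives method and onehot, replacing the pairwise/non-onehot branch deleted from main() further down in this diff. The helper's body is not part of the diff; the following is a minimal sketch only, assuming it simply forwards the flag onto the criterion the way main() used to:

import logging

def update_criterion(model, criterion, loss_name, method='supervised', onehot=True):
    # ... existing per-loss wiring between model and criterion ...
    # Absorbed from main(): mark pairwise criteria when labels are not one-hot.
    if method in ['pairwise'] and not onehot:
        logging.info("Not a onehot label dataset")
        criterion.label_not_onehot = True

Centralizing this here means train_hashing and test_hashing configure the criterion identically instead of relying on main() to patch it beforehand.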
@@ -182,6 +182,13 @@ def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, o
         running_times.append(batch_timer.total)
         pbar.set_postfix({key: val.avg for key, val in meters.items()})
         batch_timer.tick()
+
+        # if i % 2 == 0:
+        #     io.fast_save(output['code_logits'].detach().cpu(), f'{logdir}/outputs/train_iter_{i}.pth')
+        # if i > 200:
+        #     import sys
+        #     sys.exit(0)
+
     total_timer.toc()
     meters['total_time'].update(total_timer.total)
     std_time = f"time_std={np.std(running_times[1:]):.5f}"
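The commented-out block in this hunk is a debug hook: every second batch it would dump the detached code_logits to {logdir}/outputs/ and abort after 200 iterations. If re-enabled, the outputs directory must exist first; a hedged sketch of that guard follows, reusing the loop-local names i and output and the repo's io.fast_save helper, none of which are redefined here:

import os

os.makedirs(os.path.join(logdir, 'outputs'), exist_ok=True)  # create dump dir once
if i % 2 == 0:  # keep every second batch's codes for inspection
    io.fast_save(output['code_logits'].detach().cpu(),
                 f'{logdir}/outputs/train_iter_{i}.pth')
if i > 200:  # stop early once enough iterations are captured
    raise SystemExit(0)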
@@ -206,7 +213,7 @@ def test_hashing(model, test_loader, device, loss_name, loss_cfg, onehot, return
     if criterion is None:
         criterion = train_helper.get_loss(loss_name, **loss_cfg)
 
-    train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name)
+    train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name, method=method, onehot=onehot)
     criterion.eval()
 
     pbar = tqdm(test_loader, desc='Test', ascii=True, bar_format='{l_bar}{bar:10}{r_bar}',
@@ -330,7 +337,9 @@ def preprocess(model, config, device):
         logging.info('Preprocessing for CSQ')
         nclass = config['arch_kwargs']['nclass']
         nbit = config['arch_kwargs']['nbit']
-        centroids = get_hadamard(nclass, nbit, fast=True)
+        # centroids = get_hadamard(nclass, nbit, fast=True)
+        centroids = generate_centroids(nclass, nbit, 'B')
+        logging.info("using bernoulli")
         centroids = centroids.to(device)
 
         # move to model
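The call to generate_centroids(nclass, nbit, 'B') appears in this diff but its body does not; given the 'B' flag and the "using bernoulli" log line, a plausible sketch is below. Each of the nclass hash targets gets its nbit entries drawn i.i.d. from Bernoulli(0.5) and mapped to +/-1, the same fallback the CSQ paper uses when a Hadamard matrix cannot supply enough distinct rows (the mode semantics here are inferred, not confirmed):

import torch

def generate_centroids(nclass, nbit, mode='B'):
    # 'B': sample each bit from Bernoulli(0.5), then map {0, 1} -> {-1, +1}
    if mode == 'B':
        return torch.bernoulli(torch.full((nclass, nbit), 0.5)) * 2 - 1
    raise ValueError(f'unknown centroid mode: {mode}')

Unlike the Hadamard construction, this does not require nbit to be a power of two, at the cost of an expected (rather than exact) Hamming distance of nbit/2 between class centers.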
@@ -457,12 +466,6 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
         ground_truth_path = os.path.join(test_loader.dataset.root, 'ground_truth.csv')
         ground_truth = pd.read_csv(ground_truth_path)  # id = index id, images = images id in database
 
-    # update criterion as non-onehot mode, for pairwise methods
-    if method in ['pairwise']:
-        if not onehot:
-            logging.info("Not a onehot label dataset")
-            criterion.label_not_onehot = True
-
     ##### resume training #####
     if config['start_epoch_from'] != 0:
         criterion, train_history, test_history = resume_training(config, logdir,
@@ -502,7 +505,7 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
         train_meters = train_hashing(optimizer, model, train_loader, device, loss_param['loss'],
                                      loss_param['loss_param'], onehot=onehot,
                                      gpu_train_transform=gpu_train_transform,
-                                     method=method, criterion=criterion)
+                                     method=method, criterion=criterion, logdir=logdir)
 
         ##### scheduler #####
         if isinstance(scheduler, list):