Skip to content

Commit 6ff37f7

Browse files
committed
update: allow CSQ training at a specified loss scale
1 parent e6d2b56 commit 6ff37f7

File tree

4 files changed

+27
-15
lines changed

4 files changed

+27
-15
lines changed

configs/templates/csq.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ arch: dpn
22
loss: csq
33

44
loss_param:
5-
lambda_q: 0.001
5+
lambda_q: 0.001
6+
scale_c: 1.0

functions/loss/csq.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ class CSQLoss(nn.Module):
66
"""https://github.com/swuxyj/DeepHash-pytorch/blob/master/CSQ.py
77
https://openaccess.thecvf.com/content_CVPR_2020/papers/Yuan_Central_Similarity_Quantization_for_Efficient_Image_and_Video_Retrieval_CVPR_2020_paper.pdf
88
"""
9-
def __init__(self, multiclass, nbit, device, lambda_q=0.001, **kwargs):
9+
def __init__(self, multiclass, nbit, device, lambda_q=0.001, scale_c=1., **kwargs):
1010
super(CSQLoss, self).__init__()
1111
device = torch.device(device)
1212
self.multiclass = multiclass
1313
self.lambda_q = lambda_q
14+
self.scale_c = scale_c
1415
self.criterion = nn.BCELoss()
1516
self.multi_label_random_center = torch.randint(2, (nbit,)).float().to(device)
1617
self.losses = {}
@@ -27,7 +28,7 @@ def forward(self, logits, code_logits, labels, onehot=True):
2728
self.losses['center'] = loss_c
2829
self.losses['quant'] = loss_q
2930

30-
loss = loss_c + self.lambda_q * loss_q
31+
loss = self.scale_c * loss_c + self.lambda_q * loss_q
3132
return loss
3233

3334
def label2center(self, y, onehot):

scripts/train_general.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def pre_epoch_operations(loss, **kwargs):
123123

124124

125125
def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, onehot,
126-
gpu_train_transform=None, method='supervised', criterion=None):
126+
gpu_train_transform=None, method='supervised', criterion=None, logdir=None):
127127
model.train()
128128

129129
batch_timer = Timer()
@@ -133,7 +133,7 @@ def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, o
133133
criterion = train_helper.get_loss(loss_name, **loss_cfg)
134134
meters = defaultdict(AverageMeter)
135135

136-
train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name)
136+
train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name, method=method, onehot=onehot)
137137
criterion.train()
138138

139139
pbar = tqdm(train_loader, desc='Train', ascii=True, bar_format='{l_bar}{bar:10}{r_bar}',
@@ -182,6 +182,13 @@ def train_hashing(optimizer, model, train_loader, device, loss_name, loss_cfg, o
182182
running_times.append(batch_timer.total)
183183
pbar.set_postfix({key: val.avg for key, val in meters.items()})
184184
batch_timer.tick()
185+
186+
# if i % 2 == 0:
187+
# io.fast_save(output['code_logits'].detach().cpu(), f'{logdir}/outputs/train_iter_{i}.pth')
188+
# if i > 200:
189+
# import sys
190+
# sys.exit(0)
191+
185192
total_timer.toc()
186193
meters['total_time'].update(total_timer.total)
187194
std_time = f"time_std={np.std(running_times[1:]):.5f}"
@@ -206,7 +213,7 @@ def test_hashing(model, test_loader, device, loss_name, loss_cfg, onehot, return
206213
if criterion is None:
207214
criterion = train_helper.get_loss(loss_name, **loss_cfg)
208215

209-
train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name)
216+
train_helper.update_criterion(model=model, criterion=criterion, loss_name=loss_name, method=method, onehot=onehot)
210217
criterion.eval()
211218

212219
pbar = tqdm(test_loader, desc='Test', ascii=True, bar_format='{l_bar}{bar:10}{r_bar}',
@@ -330,7 +337,9 @@ def preprocess(model, config, device):
330337
logging.info('Preprocessing for CSQ')
331338
nclass = config['arch_kwargs']['nclass']
332339
nbit = config['arch_kwargs']['nbit']
333-
centroids = get_hadamard(nclass, nbit, fast=True)
340+
# centroids = get_hadamard(nclass, nbit, fast=True)
341+
centroids = generate_centroids(nclass, nbit, 'B')
342+
logging.info("using bernoulli")
334343
centroids = centroids.to(device)
335344

336345
# move to model
@@ -457,12 +466,6 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
457466
ground_truth_path = os.path.join(test_loader.dataset.root, 'ground_truth.csv')
458467
ground_truth = pd.read_csv(ground_truth_path) # id = index id, images = images id in database
459468

460-
# update criterion as non-onehot mode, for pairwise methods
461-
if method in ['pairwise']:
462-
if not onehot:
463-
logging.info("Not a onehot label dataset")
464-
criterion.label_not_onehot = True
465-
466469
##### resume training #####
467470
if config['start_epoch_from'] != 0:
468471
criterion, train_history, test_history = resume_training(config, logdir,
@@ -502,7 +505,7 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
502505
train_meters = train_hashing(optimizer, model, train_loader, device, loss_param['loss'],
503506
loss_param['loss_param'], onehot=onehot,
504507
gpu_train_transform=gpu_train_transform,
505-
method=method, criterion=criterion)
508+
method=method, criterion=criterion, logdir=logdir)
506509

507510
##### scheduler #####
508511
if isinstance(scheduler, list):

scripts/train_helper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def get_loss(loss_name, **cfg):
6868
return loss[loss_name](**cfg)
6969

7070

71-
def update_criterion(model, criterion, loss_name):
71+
def update_criterion(model, criterion, loss_name, method, onehot):
7272
if loss_name in ['dpn', 'csq']:
7373
criterion.centroids = model.centroids
7474
elif loss_name in ['sdhc', 'sdh']:
@@ -79,6 +79,13 @@ def update_criterion(model, criterion, loss_name):
7979
elif loss_name in ['adsh']:
8080
criterion.weight = model.ce_fc.centroids
8181

82+
# update criterion as non-onehot mode, for pairwise methods
83+
if method in ['pairwise']:
84+
85+
if not onehot and not criterion.label_not_onehot:
86+
logging.info("Not a onehot label dataset")
87+
criterion.label_not_onehot = True
88+
8289

8390
def generate_centroids(nclass, nbit, init_method):
8491
if init_method == 'N':

0 commit comments

Comments (0)