Skip to content

Commit 0541154

Browse files
committed
add food101 dataset
1 parent 96bcf24 commit 0541154

File tree

5 files changed

+33
-15
lines changed

5 files changed

+33
-15
lines changed

configs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def imagesize(config):
6464
'roxford5kdelgembed': 0,
6565
'mirflickr': 256,
6666
'sop': 256,
67-
'sop_instance': 256
67+
'sop_instance': 256,
68+
'food101': 256
6869
}[dsname]
6970

7071
return r
@@ -92,7 +93,8 @@ def cropsize(config):
9293
'rparis6kdelgembed': 0,
9394
'mirflickr': 224,
9495
'sop': 224,
95-
'sop_instance': 224
96+
'sop_instance': 224,
97+
'food101': 224
9698
}[dsname]
9799

98100
return r
@@ -118,7 +120,8 @@ def nclass(config):
118120
'rparis6kdelgembed': 0,
119121
'mirflickr': 24,
120122
'sop': 12,
121-
'sop_instance': 22634
123+
'sop_instance': 22634,
124+
'food101': 101
122125
}[dsname]
123126

124127
return r
@@ -141,7 +144,8 @@ def R(config):
141144
'rparis6kdelgembed': 0,
142145
'mirflickr': 1000,
143146
'sop': 1000,
144-
'sop_instance': 100
147+
'sop_instance': 100,
148+
'food101': 1000
145149
}[config['dataset'] + {2: '_2'}.get(config['dataset_kwargs']['evaluation_protocol'], '')]
146150

147151
return r
@@ -275,7 +279,7 @@ def dataset(config, filename, transform_mode,
275279
extra_dataset = config['dataset_kwargs'].get('extra_dataset', 0)
276280

277281
if dataset_name in ['imagenet100', 'nuswide', 'coco', 'cars', 'landmark',
278-
'roxford5k', 'rparis6k', 'mirflickr', 'sop', 'sop_instance']:
282+
'roxford5k', 'rparis6k', 'mirflickr', 'sop', 'sop_instance', 'food101']:
279283
norm = 2 if not gpu_mean_transform else 0 # 0 = turn off Normalize
280284

281285
if skip_preprocess: # will not resize and crop, and no augmentation
@@ -297,7 +301,8 @@ def dataset(config, filename, transform_mode,
297301
'rparis6k': datasets.rparis6k,
298302
'mirflickr': datasets.mirflickr,
299303
'sop': datasets.sop,
300-
'sop_instance': datasets.sop_instance
304+
'sop_instance': datasets.sop_instance,
305+
'food101': datasets.food101
301306
}[dataset_name]
302307
d = datafunc(transform=transform,
303308
filename=filename,

constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
datasets = {
4040
'class': ['imagenet100', 'nuswide', 'cifar10', 'imagenet50a', 'imagenet50b', 'cars', 'cifar10_II', 'landmark',
4141
'landmark200', 'landmark500', 'gldv2delgembed', 'roxford5kdelgembed', 'descriptor', 'sop',
42-
'sop_instance'],
42+
'sop_instance', 'food101'],
4343
'multiclass': ['nuswide', 'coco', 'mirflickr'],
4444
}
4545

scripts/train_general.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,14 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
579579

580580
io.fast_save(db_out, f'{logdir}/outputs/db_out.pth')
581581
io.fast_save(test_out, f'{logdir}/outputs/test_out.pth')
582+
if best < curr_metric:
583+
best = curr_metric
584+
if config['wandb_enable']:
585+
wandb.run.summary["best_map"] = best
586+
if config['save_model']:
587+
io.fast_save(modelsd, f'{logdir}/models/best.pth')
588+
io.fast_save(db_out, f'{logdir}/outputs/db_best.pth')
589+
io.fast_save(test_out, f'{logdir}/outputs/test_best.pth')
582590
del db_out, test_out
583591

584592
##### obtain training codes and statistics #####
@@ -612,13 +620,6 @@ def main(config, gpu_transform=False, gpu_mean_transform=False, method='supervis
612620
if save_now and config['save_model']:
613621
io.fast_save(modelsd, f'{logdir}/models/ep{ep + 1}.pth')
614622

615-
if best < curr_metric:
616-
best = curr_metric
617-
if config['wandb_enable']:
618-
wandb.run.summary["best_map"] = best
619-
if config['save_model']:
620-
io.fast_save(modelsd, f'{logdir}/models/best.pth')
621-
622623
##### training end #####
623624
modelsd = model.state_dict()
624625
if config['save_model']:

utils/augmentations.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ def get_train_transform(dataset_name, resize, crop):
5555
'sop_instance': [
5656
transforms.RandomResizedCrop(crop),
5757
transforms.RandomHorizontalFlip()
58-
]
58+
],
59+
'food101': [
60+
transforms.RandomResizedCrop(crop),
61+
transforms.RandomHorizontalFlip()
62+
],
5963
}[dataset_name]
6064
return t
6165

utils/datasets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,3 +811,11 @@ def sop(**kwargs):
811811

812812
d = HashingDataset(f'data/sop{suffix}', transform=transform, filename=filename, ratio=kwargs.get('ratio', 1))
813813
return d
814+
815+
816+
def food101(**kwargs):
817+
transform = kwargs['transform']
818+
filename = kwargs['filename']
819+
820+
d = HashingDataset('data/food-101', transform=transform, filename=filename, ratio=kwargs.get('ratio', 1))
821+
return d

0 commit comments

Comments
 (0)