# Self-supervised experiments of SimPool

## Pre-trained models

You can download checkpoints, logs and configs for all self-supervised models, both official reproductions and SimPool.

| Architecture | Mode | Gamma | Epochs | k-NN | Linear Probing | Download |
|--------------|------|-------|--------|------|----------------|----------|
| ViT-S/16 | Official | - | 100 | 68.9 | 71.5 | checkpoint logs configs |
| ViT-S/16 | SimPool | 1.25 | 100 | 69.7 | 72.8 | checkpoint logs configs |
| ViT-S/16 | SimPool | 1.25 | 300 | 72.6 | 75.0 | checkpoint logs configs |
| ViT-S/16 | SimPool | - | 100 | 69.8 | 72.6 | checkpoint logs configs |
| ResNet-50 | Official | - | 100 | 61.8 | 63.8 | checkpoint logs configs |
| ResNet-50 | SimPool | 2.0 | 100 | 63.8 | 64.4 | checkpoint logs configs |
| ResNet-50 | SimPool | - | 100 | 63.7 | 64.2 | checkpoint logs configs |
| ConvNeXt-S | Official | - | 100 | 59.3 | 63.9 | checkpoint logs configs |
| ConvNeXt-S | SimPool | 2.0 | 100 | 68.7 | 72.2 | checkpoint logs configs |
| ConvNeXt-S | SimPool | - | 100 | 68.8 | 72.2 | checkpoint logs configs |

## Training

Having created the self-supervised environment and downloaded the ImageNet dataset, you are now ready to train! We pre-train ViT-S, ResNet-50 and ConvNeXt-S with DINO.
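Before launching a multi-GPU run, it can help to confirm that the dataset path is readable. Below is a minimal sanity check (not part of this repo), assuming the standard ImageNet `ImageFolder` layout with per-class subfolders under `train/` and `val/`:

```python
# Hypothetical sanity check: verify that /path/to/imagenet/ follows the usual
# ImageFolder layout (train/<class>/*.JPEG, val/<class>/*.JPEG).
from torchvision import datasets

train_set = datasets.ImageFolder("/path/to/imagenet/train")
val_set = datasets.ImageFolder("/path/to/imagenet/val")
print(f"train: {len(train_set)} images, {len(train_set.classes)} classes")
print(f"val:   {len(val_set)} images, {len(val_set.classes)} classes")
```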

### ViT-S

Train ViT-S with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --mode simpool --gamma 1.25 \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --optimizer adamw --use_bn_in_head False --out_dim 65536 \
--subset -1 --batch_size_per_gpu 100 --local_crops_number 6 --epochs 100 --num_workers 10 --lr 0.0005 --min_lr 0.00001 \
--global_crops_scale 0.25 1.0 --local_crops_scale 0.05 0.25 --norm_last_layer False --warmup_teacher_temp_epochs 30 \
--weight_decay 0.04 --weight_decay_end 0.4
```

For official ViT-S, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 8 GPUs x 100 batch size per GPU = 800 global batch size.
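If you change the GPU count or per-GPU batch size, keep the global batch size in mind: DINO-style recipes scale the base learning rate linearly with it. A rough check, assuming the standard DINO rule of `lr * global_batch / 256` (verify against the repo's `main_dino.py`):

```python
# Back-of-the-envelope check of the effective learning rate, mirroring the
# ViT-S command above. Assumption: the repo keeps DINO's usual linear
# scaling of the base lr by global_batch / 256.
n_gpus, batch_per_gpu, base_lr = 8, 100, 0.0005
global_batch = n_gpus * batch_per_gpu        # 800
effective_lr = base_lr * global_batch / 256  # ~0.0015625
print(global_batch, effective_lr)
```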

Extract features from ViT-S with SimPool on ImageNet-1k and evaluate with k-NN:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch vit_small --mode simpool --gamma 1.25 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```

For official ViT-S, set `--mode official`. For no $\gamma$, set `--gamma None`.

Linear probing of ViT-S with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --n_last_blocks 1 \
--arch vit_small --mode simpool --pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ \
--output_dir /path/to/output/ --epochs 100
```

For official ViT-S, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.

### ResNet-50

Train ResNet-50 with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=16 main_dino.py --arch resnet50 --mode simpool \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --subset -1 --num_workers 10 --batch_size_per_gpu 90 \
--out_dim 60000 --use_bn_in_head True --teacher_temp 0.07 --warmup_teacher_temp_epochs 50 --use_fp16 False \
--weight_decay 0.000001 --weight_decay_end 0.000001 --clip_grad 0.0 --epochs 100 --lr 0.3 --min_lr 0.0048 \
--optimizer lars --global_crops_scale 0.14 1.0 --local_crops_number 6 --local_crops_scale 0.05 0.14
```

For official ResNet-50, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 16 GPUs x 90 batch size per GPU = 1440 global batch size.

Extract features from ResNet-50 with SimPool on ImageNet-1k and evaluate with k-NN:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch resnet50 --mode simpool --gamma 2.0 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```

For official ResNet-50, set `--mode official`. For no $\gamma$, set `--gamma None`.

Linear probing of ResNet-50 with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --arch resnet50 --mode simpool \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ --output_dir /path/to/output/ --epochs 100
```

For official ResNet-50, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.

### ConvNeXt-S

Train ConvNeXt-S with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch convnext_small --mode simpool \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --subset -1 --num_workers 10 --batch_size_per_gpu 60 \
--out_dim 65536 --use_bn_in_head False --weight_decay 0.04 --weight_decay_end 0.4 --clip_grad 0.3 --epochs 100 \
--min_lr 2e-6 --optimizer adamw --lr 0.001 --freeze_last_layer 3
```

For official ConvNeXt-S, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 8 GPUs x 60 batch size per GPU = 480 global batch size.

Extract features from ConvNeXt-S with SimPool on ImageNet-1k and evaluate with k-NN:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch convnext_small --mode simpool --gamma 2.0 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```

For official ConvNeXt-S, set `--mode official`. For no $\gamma$, set `--gamma None`.

Linear probing of ConvNeXt-S with SimPool on ImageNet-1k for 100 epochs:

```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --n_last_blocks 1 \
--arch convnext_small --mode simpool --pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ \
--output_dir /path/to/output/ --epochs 100
```

For official ConvNeXt-S, set `--mode official`. For no $\gamma$, set `--gamma None`. ❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.

## Extra notes

- Use `--subset 260` to train on the ImageNet-20% dataset.
- When loading our weights with `--pretrained_weights`, take care of any inconsistencies in model keys (see the sketch below)!
- The default value of $\gamma$ is 1.25 for transformers and 2.0 for convolutional networks.
- In some cases, we observed that using no $\gamma$ facilitates training and yields slightly better metrics, but also lowers the attention map quality.
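As an illustration of the model-key note above, here is a minimal loading sketch (not the repo's API; the `teacher` entry and the `module.`/`backbone.` prefixes follow the usual DINO checkpoint format, so adapt the keys to your setup):

```python
# Hypothetical key-remapping sketch: strip the prefixes that DINO-style
# checkpoints commonly carry before calling model.load_state_dict().
import torch

ckpt = torch.load("/path/to/checkpoint.pth", map_location="cpu")
# DINO checkpoints usually store both networks; take the teacher if present.
state_dict = ckpt.get("teacher", ckpt)
# Drop DistributedDataParallel and projection-head prefixes.
state_dict = {k.replace("module.", "").replace("backbone.", ""): v
              for k, v in state_dict.items()}

# Load into your backbone with strict=False and inspect any mismatches:
# msg = model.load_state_dict(state_dict, strict=False)
# print(msg.missing_keys, msg.unexpected_keys)
print(sorted(state_dict)[:5])  # quick look at the remapped key names
```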