You can download checkpoints, logs and configs for all self-supervised models, both official reproductions and SimPool.
Architecture | Mode | Gamma | Epochs | k-NN (%) | Linear Probing (%) | Checkpoint | Logs | Configs
---|---|---|---|---|---|---|---|---
ViT-S/16 | Official | - | 100 | 68.9 | 71.5 | checkpoint | logs | configs |
ViT-S/16 | SimPool | 1.25 | 100 | 69.7 | 72.8 | checkpoint | logs | configs |
ViT-S/16 | SimPool | 1.25 | 300 | 72.6 | 75.0 | checkpoint | logs | configs |
ViT-S/16 | SimPool | - | 100 | 69.8 | 72.6 | checkpoint | logs | configs |
ResNet-50 | Official | - | 100 | 61.8 | 63.8 | checkpoint | logs | configs |
ResNet-50 | SimPool | 2.0 | 100 | 63.8 | 64.4 | checkpoint | logs | configs |
ResNet-50 | SimPool | - | 100 | 63.7 | 64.2 | checkpoint | logs | configs |
ConvNeXt-S | Official | - | 100 | 59.3 | 63.9 | checkpoint | logs | configs |
ConvNeXt-S | SimPool | 2.0 | 100 | 68.7 | 72.2 | checkpoint | logs | configs |
ConvNeXt-S | SimPool | - | 100 | 68.8 | 72.2 | checkpoint | logs | configs |
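Once downloaded, a checkpoint can be inspected before use. The sketch below assumes a DINO-style checkpoint saved with `torch.save`, which typically bundles the student and teacher weights together with training state; the exact key names may differ, so print them first.

```python
import torch

# Hypothetical local path to a downloaded checkpoint (adjust to where you saved it).
ckpt_path = "/path/to/checkpoint.pth"

# Load on CPU so no GPU is needed just to inspect the file.
ckpt = torch.load(ckpt_path, map_location="cpu")

# DINO-style checkpoints usually store several entries, e.g. 'student', 'teacher',
# 'optimizer', 'epoch'; verify what is actually present before loading weights.
print(list(ckpt.keys()))

# The teacher weights are the ones normally used for evaluation.
if "teacher" in ckpt:
    teacher_sd = ckpt["teacher"]
    print(f"teacher state dict has {len(teacher_sd)} tensors")
```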
Having created the self-supervised environment and downloaded the ImageNet dataset, you are now ready to train! We pre-train ViT-S, ResNet-50 and ConvNeXt-S with DINO.
Train ViT-S with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --mode simpool --gamma 1.25 \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --optimizer adamw --use_bn_in_head False --out_dim 65536 \
--subset -1 --batch_size_per_gpu 100 --local_crops_number 6 --epochs 100 --num_workers 10 --lr 0.0005 --min_lr 0.00001 \
--global_crops_scale 0.25 1.0 --local_crops_scale 0.05 0.25 --norm_last_layer False --warmup_teacher_temp_epochs 30 \
--weight_decay 0.04 --weight_decay_end 0.4
```
For ViT-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 8 GPUs x 100 batch size per GPU = 800 global batch size.
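For context on the flags above (`--out_dim`, `--warmup_teacher_temp_epochs`), here is a minimal sketch of the DINO objective they parameterize: the teacher output is centered and sharpened with a low temperature, the student output is softened with its own temperature, and a cross-entropy is taken between the two. This is a simplified illustration with illustrative temperature values, not the repository's exact implementation (which also averages over crop pairs).

```python
import torch
import torch.nn.functional as F

def dino_loss(student_out, teacher_out, center,
              student_temp=0.1, teacher_temp=0.04):
    """Simplified DINO cross-entropy between teacher and student projections.

    student_out, teacher_out: (batch, out_dim) projection-head outputs
    center: (out_dim,) running mean used to center the teacher output
    """
    # Teacher: center, sharpen with a low temperature, and stop gradients.
    teacher_probs = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    # Student: log-probabilities at its own (higher) temperature.
    student_logp = F.log_softmax(student_out / student_temp, dim=-1)
    # Cross-entropy, averaged over the batch.
    return -(teacher_probs * student_logp).sum(dim=-1).mean()

# Toy usage with out_dim = 65536 as in the command above.
out_dim = 65536
s = torch.randn(4, out_dim)
t = torch.randn(4, out_dim)
c = torch.zeros(out_dim)
print(dino_loss(s, t, c).item())
```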
Extract features from ViT-S with SimPool on ImageNet-1k and evaluate with k-NN:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch vit_small --mode simpool --gamma 1.25 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```
For ViT-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.
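The k-NN evaluation classifies each validation image by the labels of its nearest training features. The sketch below is a simplified, single-process illustration of that protocol (cosine similarity over L2-normalized features, temperature-weighted voting), not the distributed implementation in `eval_knn.py`; the `k` and temperature values are illustrative.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def knn_predict(train_feats, train_labels, test_feats,
                num_classes, k=20, temperature=0.07):
    """Weighted k-NN classification on L2-normalized features."""
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)

    # Cosine similarity between every test and train feature.
    sims = test_feats @ train_feats.t()            # (n_test, n_train)
    topk_sims, topk_idx = sims.topk(k, dim=1)      # (n_test, k)
    topk_labels = train_labels[topk_idx]           # (n_test, k)

    # Temperature-weighted vote over the retrieved neighbors.
    weights = (topk_sims / temperature).exp()      # (n_test, k)
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)
    return votes.argmax(dim=1)

# Toy usage with random features.
preds = knn_predict(torch.randn(1000, 384), torch.randint(0, 10, (1000,)),
                    torch.randn(8, 384), num_classes=10)
print(preds)
```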
Linear probing of ViT-S with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --n_last_blocks 1 \
--arch vit_small --mode simpool --pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ \
--output_dir /path/to/output/ --epochs 100
```
For ViT-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.
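Conceptually, linear probing freezes the pretrained backbone and trains only a linear classifier on top of its features. Below is a minimal, single-GPU sketch of that idea under assumed names (`backbone` producing one feature vector per image, `train_loader` yielding image/label batches); the actual `eval_linear.py` additionally handles distributed training, augmentation, and learning-rate scheduling.

```python
import torch
import torch.nn as nn

def linear_probe(backbone, train_loader, feat_dim, num_classes,
                 epochs=100, lr=0.001, device="cuda"):
    """Train a linear classifier on frozen backbone features (minimal sketch)."""
    backbone.eval().to(device)
    for p in backbone.parameters():
        p.requires_grad = False                  # backbone stays frozen

    head = nn.Linear(feat_dim, num_classes).to(device)
    optimizer = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():                # features only, no backbone grads
                feats = backbone(images)
            loss = criterion(head(feats), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head
```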
Train ResNet-50 with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=16 main_dino.py --arch resnet50 --mode simpool \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --subset -1 --num_workers 10 --batch_size_per_gpu 90 \
--out_dim 60000 --use_bn_in_head True --teacher_temp 0.07 --warmup_teacher_temp_epochs 50 --use_fp16 False \
--weight_decay 0.000001 --weight_decay_end 0.000001 --clip_grad 0.0 --epochs 100 --lr 0.3 --min_lr 0.0048 \
--optimizer lars --global_crops_scale 0.14 1.0 --local_crops_number 6 --local_crops_scale 0.05 0.14
```
For ResNet-50 official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 16 GPUs x 90 batch size per GPU = 1440 global batch size.
Extract features from ResNet-50 with SimPool on ImageNet-1k and evaluate with k-NN:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch resnet50 --mode simpool --gamma 2.0 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```
For ResNet-50 official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.
Linear probing of ResNet-50 with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --arch resnet50 --mode simpool \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ --output_dir /path/to/output/ --epochs 100
```
For ResNet-50 official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.
Train ConvNeXt-S with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch convnext_small --mode simpool \
--data_path /path/to/imagenet/ --output_dir /path/to/output/ --subset -1 --num_workers 10 --batch_size_per_gpu 60 \
--out_dim 65536 --use_bn_in_head False --weight_decay 0.04 --weight_decay_end 0.4 --clip_grad 0.3 --epochs 100 \
--min_lr 2e-6 --optimizer adamw --lr 0.001 --freeze_last_layer 3
```
For ConvNeXt-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 8 GPUs x 60 batch size per GPU = 480 global batch size.
Extract features from ConvNeXt-S with SimPool on ImageNet-1k and evaluate with k-NN:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --arch convnext_small --mode simpool --gamma 2.0 \
--pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/
```
For ConvNeXt-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.
Linear probing of ConvNeXt-S with SimPool on ImageNet-1k for 100 epochs:
```bash
python3 -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --batch_size_per_gpu 256 --n_last_blocks 1 \
--arch convnext_small --mode simpool --pretrained_weights /path/to/checkpoint/ --data_path /path/to/imagenet/ \
--output_dir /path/to/output/ --epochs 100
```
For ConvNeXt-S official, set `--mode official`. For the no-$\gamma$ variant, set `--gamma None`.

❗ NOTE: Here we use 4 GPUs x 256 batch size per GPU = 1024 global batch size.
- Use `--subset 260` to train on the ImageNet-20% dataset.
- When loading our weights with `--pretrained_weights`, take care of any inconsistencies in the model keys (see the sketch after this list)!
- The default value of $\gamma$ is 1.25 for transformers and 2.0 for convolutional networks.
- In some cases, we observed that using no $\gamma$ facilitates training and yields slightly better metrics, but also lowers the attention map quality.
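As a concrete example of handling key inconsistencies, the sketch below loads the teacher weights from a checkpoint and strips common prefixes (`module.` added by DistributedDataParallel, `backbone.` added by the DINO multi-crop wrapper) before loading them into a model. These prefix names are common conventions rather than guarantees, so check the printed report of missing/unexpected keys for your particular checkpoint.

```python
import torch

def load_pretrained_backbone(model, ckpt_path, key="teacher"):
    """Load pretrained weights into `model`, cleaning common key prefixes."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get(key, ckpt)  # fall back to a plain state dict

    cleaned = {}
    for k, v in state_dict.items():
        # Strip wrappers added during DDP training and by the projection-head wrapper.
        for prefix in ("module.", "backbone."):
            if k.startswith(prefix):
                k = k[len(prefix):]
        cleaned[k] = v

    # strict=False reports, rather than fails on, any remaining mismatches.
    msg = model.load_state_dict(cleaned, strict=False)
    print("missing keys:", msg.missing_keys)
    print("unexpected keys:", msg.unexpected_keys)
    return model
```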