[NeurIPS 2024] AWT: Transferring Vision-Language Models via Augmentation, Weighting, and Transportation

MCG-NJU/AWT

AWT: Transferring Vision-Language Models via Augmentation, Weighting, and Transportation

✨ Welcome to the official repository for "AWT: Transferring Vision-Language Models via Augmentation, Weighting, and Transportation". This work is a collaborative effort by Yuhan Zhu, Yuyang Ji, Zhiyu Zhao, Gangshan Wu, and Limin Wang from Nanjing University and Shanghai AI Lab.

🔗 Read our paper: ArXiv | NeurIPS 2024

Overview

🚀 Our work presents AWT, a framework for transferring pre-trained Vision-Language Models (VLMs) to downstream tasks. AWT improves VLMs' zero-shot capabilities without additional training and, with a multimodal adapter, excels in few-shot learning. It achieves state-of-the-art performance on zero-shot and few-shot image and video tasks.

Figure: AWT overview.

Installation

# Create a conda environment
conda create -y -n awt python=3.8
conda activate awt

# Requires pytorch>=1.10
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia

# Install other dependencies 
pip install -r requirements.txt

# If running few-shot experiments
cd AWT_few_shot/Dassl.pytorch ; python setup.py develop

Data Preparation

Refer to the following guides for setting up datasets:

Using AWT

Zero-shot Image Classification

cd AWT_zero_shot

# Set data path in `./scripts/pre_extract.sh`
# `dataset_name` is chosen from:
# ['imagenet', 'oxford_flowers', 'dtd', 'oxford_pets', 'stanford_cars', 'ucf101', 'caltech101', 'food101', 'sun397', 'fgvc_aircraft', 'eurosat', 'caltech256', 'cub', 'birdsnap']

# Step 1: Extract visual features
bash ./scripts/pre_extract.sh [dataset_name]

# Step 2: Evaluate AWT performance
bash ./scripts/evaluate.sh [dataset_name]
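
To run both steps over several datasets, a small driver along these lines may help. This is a minimal sketch, not part of the repository: the helper names `build_commands` and `run_all` are hypothetical, and it assumes it is launched from inside `AWT_zero_shot` with the data path already set in `./scripts/pre_extract.sh`. By default it only prints the commands.

```python
import subprocess

# Dataset names copied from the list above.
DATASETS = [
    "imagenet", "oxford_flowers", "dtd", "oxford_pets", "stanford_cars",
    "ucf101", "caltech101", "food101", "sun397", "fgvc_aircraft",
    "eurosat", "caltech256", "cub", "birdsnap",
]


def build_commands(dataset_name):
    """Return the two commands (feature extraction, evaluation) for one dataset."""
    if dataset_name not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset_name}")
    return [
        ["bash", "./scripts/pre_extract.sh", dataset_name],  # Step 1
        ["bash", "./scripts/evaluate.sh", dataset_name],     # Step 2
    ]


def run_all(datasets, dry_run=True):
    """Print (dry run) or execute both steps for each dataset, in order."""
    for name in datasets:
        for cmd in build_commands(name):
            if dry_run:
                print(" ".join(cmd))
            else:
                # check=True stops at the first failing script.
                subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(["dtd", "eurosat"])  # dry run: only prints the four commands
```

Pass `dry_run=False` to actually execute the scripts.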

Out-of-Distribution Generalization

Note that ImageNet-v2 and ImageNet-A share the same description file as ImageNet, as they have a similar visual style.

cd AWT_zero_shot

# `dataset_name` is chosen from:
# ['imagenet_a', 'imagenet_sketch', 'imagenet_r', 'imagenetv2']

# Step 1: Extract visual features
bash ./scripts/pre_extract.sh [dataset_name]

# Step 2: Evaluate AWT performance
bash ./scripts/evaluate.sh [dataset_name]
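
The description sharing noted above can be pictured as a simple lookup. The sketch below is illustrative only: `description_source` is a hypothetical helper, not a function from the repository, and it assumes that the datasets not mentioned in the note (`imagenet_sketch`, `imagenet_r`) use descriptions under their own names.

```python
# Per the note above: ImageNet-v2 and ImageNet-A reuse ImageNet's descriptions.
SHARED_DESCRIPTIONS = {
    "imagenetv2": "imagenet",
    "imagenet_a": "imagenet",
}


def description_source(dataset_name):
    """Return the dataset whose class descriptions `dataset_name` uses.

    Assumption: any dataset not listed in SHARED_DESCRIPTIONS has its own
    description file named after the dataset itself.
    """
    return SHARED_DESCRIPTIONS.get(dataset_name, dataset_name)


if __name__ == "__main__":
    for name in ["imagenet_a", "imagenet_sketch", "imagenet_r", "imagenetv2"]:
        print(name, "->", description_source(name))
```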

Few-shot Image Classification

cd AWT_few_shot/MM_Adapter

We use one NVIDIA A100-SXM4-80GB GPU for training. If you encounter an OOM error on your hardware, please reduce the Desc_Per_Batch value in ./configs/trainers/AWT/[config_you_use].yaml (note that this might slightly decrease performance).

# Valid `dataset_name` values are defined in `./configs/datasets/`

# For 1-, 2-, or 4-shot training
bash scripts/awt/main.sh [dataset_name] vit_b16_1_2_4_shot [n_shot]

# For 8- or 16-shot training
bash scripts/awt/main.sh [dataset_name] vit_b16_8_16_shot [n_shot]
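
The pairing between shot count and config name can be captured in a small helper, which avoids passing a mismatched config by hand. This is a sketch under the assumption that only the five shot counts listed above are used; `config_for_shots` and `train_command` are hypothetical names, not part of the repository.

```python
def config_for_shots(n_shot):
    """Map a shot count to the config name used in the commands above."""
    if n_shot in (1, 2, 4):
        return "vit_b16_1_2_4_shot"
    if n_shot in (8, 16):
        return "vit_b16_8_16_shot"
    raise ValueError(f"unsupported n_shot: {n_shot}")


def train_command(dataset_name, n_shot):
    """Build the training command for one dataset and shot count."""
    return [
        "bash", "scripts/awt/main.sh",
        dataset_name, config_for_shots(n_shot), str(n_shot),
    ]


if __name__ == "__main__":
    # Print the training command for each shot count on one dataset.
    for shots in (1, 2, 4, 8, 16):
        print(" ".join(train_command("dtd", shots)))
```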

To evaluate the model:

bash scripts/awt/test.sh [dataset_name] [config_you_use] [n_shot]

Compute the mean and variance across different seeds:

python parse_test_res.py [experiment path] --test-log
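
For reference, the statistic being reported is an ordinary mean and variance over per-seed accuracies. The sketch below illustrates the arithmetic only and does not reproduce `parse_test_res.py`: the accuracy values are made up, `summarize` is a hypothetical helper, and the use of sample variance (dividing by n - 1) is an assumption, since the text above does not say which estimator the script uses.

```python
import statistics


def summarize(accuracies):
    """Return (mean, sample variance) of per-seed accuracies."""
    mean = statistics.mean(accuracies)
    # Sample variance (n - 1 in the denominator); an assumption, see lead-in.
    variance = statistics.variance(accuracies)
    return mean, variance


if __name__ == "__main__":
    # Made-up accuracies for three seeds, for illustration only.
    per_seed = [70.0, 71.0, 72.0]
    mean, variance = summarize(per_seed)
    print(f"mean={mean:.2f} variance={variance:.2f}")
```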

Draw few-shot performance curves (see Figure 3 in our paper):

python draw_curves.py

Zero-shot Video Action Recognition

In this task, we treat the sampled video frames as augmented views in AWT. Specifically, we use 8x3x4 views (#frames x #crops x #clips).
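
As a quick check of the view count, 8 frames x 3 crops x 4 clips gives 96 views per video. The sketch below enumerates them as index triples; it is illustrative only, `enumerate_views` is a hypothetical helper, and the frames-then-crops-then-clips ordering is an assumption rather than the layout the repository uses.

```python
from itertools import product

NUM_FRAMES, NUM_CROPS, NUM_CLIPS = 8, 3, 4


def enumerate_views(frames=NUM_FRAMES, crops=NUM_CROPS, clips=NUM_CLIPS):
    """List every (frame, crop, clip) index triple; each is one augmented view."""
    return list(product(range(frames), range(crops), range(clips)))


if __name__ == "__main__":
    views = enumerate_views()
    print(len(views))  # 96 views per video
```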

To reproduce our results:

  • Use the pre-trained Open-VCLIP to extract features for each video. Save the features and ground-truth labels to disk, following the format in AWT_zero_shot/pre_extract.py.
  • Use AWT_zero_shot/evaluate.py to read the saved features and labels, then obtain the benchmark accuracy. Class descriptions for the three video datasets can be found in AWT_zero_shot/descriptions/video_datasets/.

Generate Descriptions

We use a two-step, dataset-aware prompting strategy to generate descriptions for each class.

cd description_generation

# Step 1: Generate queries
# Add `--no_dataset_desc` to use AWT-base, i.e., w/o dataset descriptions
python gen_questions.py

# Step 2: Generate class descriptions
python gen_descriptions.py
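
The two steps can be pictured as two rounds of prompt construction: first a query prompt that optionally includes a dataset description, then one description prompt per query and class. The sketch below is purely illustrative. The prompt wording and the function names `build_query_prompt` and `build_description_prompt` are hypothetical and are not taken from `gen_questions.py` or `gen_descriptions.py`; no LLM is called.

```python
def build_query_prompt(dataset_desc=None, num_queries=3):
    """Step 1 (hypothetical wording): ask an LLM for class-describing questions.

    Passing dataset_desc=None mirrors the `--no_dataset_desc` option,
    i.e. no dataset description is included in the prompt.
    """
    prompt = (
        f"Generate {num_queries} questions that help describe "
        "the visual appearance of a category."
    )
    if dataset_desc is not None:
        prompt += f" The categories come from a collection described as: {dataset_desc}"
    return prompt


def build_description_prompt(query, class_name):
    """Step 2 (hypothetical wording): ask the LLM to answer one query for one class."""
    return f"{query} Answer for the class '{class_name}'."


if __name__ == "__main__":
    print(build_query_prompt("textures photographed in the wild"))
    print(build_description_prompt("What does it look like?", "striped"))
```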

Experiment with Other VLMs

cd AWT_zero_shot

# run AWT on ALIGN
bash scripts/run_align.sh [dataset_name]

# run AWT on EVA02-CLIP
bash scripts/run_eva02_clip.sh [dataset_name]

# run AWT on SigLIP
bash scripts/run_siglip.sh [dataset_name]

Citation

🤗 If you find our code useful or our work relevant, please consider citing:

@article{awt2024,
  title={AWT: Transferring Vision-Language Models via Augmentation, Weighting, and Transportation},
  author={Zhu, Yuhan and Ji, Yuyang and Zhao, Zhiyu and Wu, Gangshan and Wang, Limin},
  journal={arXiv preprint arXiv:2407.04603},
  year={2024}
}

Acknowledgements

Our work builds upon CoOp, PLOT, TPT, Dassl, and Open-VCLIP. Thanks for their excellent work and open-source contributions.