
Commit 6b082d9: Initial commit (0 parents)

39 files changed: +3231, -0 lines

.gitattributes

Lines changed: 1 addition & 0 deletions
*.sh text eol=lf

README.md

Lines changed: 129 additions & 0 deletions
<div align="center">
<h1>ResCLIP</h1>
<h3>ResCLIP: Residual Attention for Training-free Dense Vision-language Inference</h3>
<div>
<h4 align="center">
  <a href='https://arxiv.org/abs/2411.15851'><img src='https://img.shields.io/badge/ArXiv-2411.15851-red'></a>
</h4>
</div>
</div>

## News

* **`Feb. 27th, 2025`**: This paper has been accepted by CVPR 2025.
* **`Nov. 23rd, 2024`**: We release the paper for ResCLIP.

## Abstract

While vision-language models like CLIP have shown remarkable success in open-vocabulary tasks, their application is currently confined to image-level tasks, and they still struggle with dense predictions. Recent works often attribute this deficiency in dense prediction to the self-attention layers in the final block, and have achieved commendable results by modifying the original query-key attention into self-correlation attention (e.g., query-query and key-key attention). However, these methods overlook the cross-correlation (query-key) attention, which captures rich spatial correspondence. In this paper, we reveal that the cross-correlation of the self-attention in CLIP's non-final layers also exhibits localization properties. We therefore propose the Residual Cross-correlation Self-attention (RCS) module, which leverages cross-correlation self-attention from intermediate layers to remold the attention in the final block. The RCS module effectively reorganizes spatial information, unleashing the localization potential within CLIP for dense vision-language inference. Furthermore, to strengthen the focus on regions of the same category and improve local consistency, we propose the Semantic Feedback Refinement (SFR) module, which uses semantic segmentation maps to further adjust the attention scores. By integrating these two strategies, our method, termed **ResCLIP**, can be easily incorporated into existing approaches as a plug-and-play module, significantly boosting their performance in dense vision-language inference. Extensive experiments across multiple standard benchmarks demonstrate that our method surpasses state-of-the-art training-free methods, validating the effectiveness of the proposed approach.

For more information, please refer to our [paper](https://arxiv.org/abs/2411.15851).

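As a rough illustration of the residual idea (a simplified sketch, not the released implementation; the tensor shapes, the plain average over intermediate layers, and the mixing weight `alpha` are assumptions), the final-block attention could be remolded as follows:

```python
import torch
import torch.nn.functional as F


def rcs_attention(q, k, v, inter_qk_scores, alpha=0.5):
    """Simplified sketch of Residual Cross-correlation Self-attention (RCS).

    q, k, v: [B, heads, N, d] projections from CLIP's final visual block.
    inter_qk_scores: list of [B, heads, N, N] pre-softmax query-key attention
        scores collected from intermediate layers.
    alpha: assumed weight for the residual cross-correlation term.
    """
    scale = q.shape[-1] ** -0.5
    # Self-correlation attention (q-q and k-k), as used by prior training-free methods.
    self_corr = (q @ q.transpose(-2, -1) + k @ k.transpose(-2, -1)) * scale
    # Cross-correlation (q-k) scores from intermediate layers, which also carry
    # localization cues; here they are simply averaged.
    cross_corr = torch.stack(inter_qk_scores).mean(dim=0)
    # Residually combine the two before softmax and value aggregation.
    attn = F.softmax(self_corr + alpha * cross_corr, dim=-1)
    return attn @ v
```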
<p align="center">
  <img src="./figs/method_simplify_all.png" width="800" />
</p>

<p align="center">
  <img src="./figs/pipeline_all.png" width="800" />
</p>

## Main Results

<p align="center">
  <img src="./figs/main_results_wo.png" width="800" />
</p>

<p align="center">
  <img src="./figs/main_results_w.png" width="800" />
</p>

<p align="center">
  <img src="./figs/main_visual.png" width="800" />
</p>

## Getting Started

### Installation

**Step 1: Clone ResCLIP Repository:**

```bash
git clone https://github.com/yvhangyang/ResCLIP.git
cd ResCLIP
```

**Step 2: Environment Setup:**

Create and activate a new conda environment.

```bash
conda create -n ResCLIP python=3.9
conda activate ResCLIP
```

**Step 3: Install Dependencies:**

To run ResCLIP, please install the following packages. We used `Python 3.9` in our experiments.

```bash
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
pip install openmim
mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1
pip install ftfy regex numpy==1.26 yapf==0.40.1
```

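After installation, a quick sanity check (not part of the original instructions) is to confirm that the core packages import and that CUDA is visible:

```bash
python -c "import torch, mmcv, mmengine, mmseg; print(torch.__version__, torch.cuda.is_available())"
```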
### Quick Start

#### Datasets Preparation

Please follow the [MMSeg data preparation document](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download and pre-process the datasets, including PASCAL VOC, PASCAL Context, Cityscapes, ADE20k, COCO-Object and COCO-Stuff164k.

The COCO-Object dataset can be converted from COCO-Stuff164k by executing the following command:

```bash
python ./datasets/cvt_coco_object.py PATH_TO_COCO_STUFF164K -o PATH_TO_COCO_OBJECT
```

**Remember to modify the dataset paths (`data_root`) in the config files in** `./configs/`.

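For example, each dataset config points `data_root` at the local dataset directory (the path below is a placeholder; the field name follows the MMSegmentation convention these configs use):

```python
# illustrative excerpt from a config file in ./configs/
data_root = '/path/to/VOCdevkit/VOC2012'  # replace with your local dataset path
```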
#### Evaluation

To evaluate our approach on a single benchmark, run the following command:

```bash
python eval.py --config ./configs/cfg_DATASET.py --workdir YOUR_WORK_DIR
```

or evaluate on all datasets:

```bash
bash test_all.sh {arch} {attn} {gaussian_std} {gpu} {log_path}
```

Use `wo_resi` for `{arch}` and `resclip` for `{attn}` to run our method.
For example, to reproduce the main results, run:

```bash
bash test_all.sh wo_resi resclip 5 {gpu} {log_path}
```

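As a concrete single-benchmark run (the config file name below is an assumption following the `cfg_DATASET.py` naming pattern; check `./configs/` for the exact names):

```bash
python eval.py --config ./configs/cfg_voc21.py --workdir ./work_dirs/voc21
```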
#### Demo

```bash
python demo.py
```

## Acknowledgment

This project is based on [NACLIP](https://github.com/sinahmr/NACLIP), [SCLIP](https://github.com/wangf3014/SCLIP), [ClearCLIP](https://github.com/mc-lan/ClearCLIP), [CLIP](https://github.com/openai/CLIP) and [OpenCLIP](https://github.com/mlfoundations/open_clip). Thanks for their excellent work.

## Citation

If you find this project useful, please consider citing:

```bibtex
@article{yang2024resclip,
  title={ResCLIP: Residual Attention for Training-free Dense Vision-language Inference},
  author={Yang, Yuhang and Deng, Jinhong and Li, Wen and Duan, Lixin},
  journal={arXiv preprint arXiv:2411.15851},
  year={2024}
}
```

clip/__init__.py

Lines changed: 2 additions & 0 deletions
from .clip import *
from .model import *

clip/bpe_simple_vocab_16e6.txt.gz

1.29 MB
Binary file not shown.

clip/clip.py

Lines changed: 229 additions & 0 deletions
### CLIP source code from OpenAI:
# https://github.com/openai/CLIP/blob/main/clip/clip.py

import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from pkg_resources import packaging
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]  # the expected checksum is embedded in the download URL
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target

def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
0 commit comments
