
Commit da0d554

Add sae hooked llava

1 parent 0a6889e commit da0d554

File tree

1 file changed (+376, -0 lines)

lmms_eval/models/llava_sae_hooked.py

@@ -0,0 +1,376 @@
import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from sae import Sae
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoProcessor,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
)

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

warnings.filterwarnings("ignore")

from loguru import logger as eval_logger

DEFAULT_IMAGE_TOKEN = "<image>"

# Default chat for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"

model_map = {
    "llava": LlavaForConditionalGeneration,
    "llava_next": LlavaNextForConditionalGeneration,
}


@register_model("llava_sae_hooked")
class LlavaSaeHooked(lmms):
    """
    Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava

    Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py

    Example usage:

    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
        --model llava_sae_hooked \
        --model_args pretrained=llava-hf/llava-1.5-7b-hf \
        --tasks seedbench \
        --batch_size 1 \
        --output_path ./logs/ \
        --log_samples
    """

    def __init__(
        self,
        pretrained: str = "llava-hf/llava-1.5-7b-hf",
        revision: str = "main",
        device: str = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: int = 1,
        trust_remote_code: Optional[bool] = False,
        attn_implementation: Optional[str] = None,
        device_map: str = "",
        chat_template: Optional[str] = None,
        use_cache: bool = True,
        specified_eot_token_id: Optional[int] = None,
        sae_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1 and device_map == "":
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)

        config = AutoConfig.from_pretrained(pretrained)
        model_type = getattr(config, "model_type", "llava")
        model_type = model_map[model_type]
        self._model = model_type.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
        if sae_path is not None:
            self.module_dict = Sae.load_many(sae_path, local=os.path.exists(sae_path), device=self._device)
        else:
            # No SAEs requested: keep an empty dict so the hook bookkeeping below is a no-op.
            self.module_dict = {}
        self.name_to_module = {name: self.model.language_model.get_submodule(name) for name in self.module_dict.keys()}
        self.module_to_name = {v: k for k, v in self.name_to_module.items()}

        self.pretrained = pretrained
        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
        # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
        self._image_processor.tokenizer.padding_side = "left"
        self._tokenizer = self._image_processor.tokenizer
        self._config = self._model.config
        self.batch_size_per_gpu = int(batch_size)
        self.chat_template = chat_template
        self.use_cache = use_cache
        self.specified_eot_token_id = specified_eot_token_id
        if accelerator.num_processes > 1 and device_map == "":
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # To use DistributedType.DEEPSPEED you have to run `accelerate config` first and select
            # zero stage 0 (equivalent to DDP); otherwise `prepare` does not work with this model.
            # Setting other kwargs to make the default zero stage 2 work did not succeed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
                self.module_dict = {k: accelerator.prepare_model(v) for k, v in self.module_dict.items()}
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1
        self.accelerator = accelerator

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """ """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for context, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            if type(doc_to_target) == str:
                continuation = doc_to_target
            else:
                continuation = doc_to_target(self.task_dict[task][split][doc_id])
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

            image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
            image_tokens = " ".join(image_tokens)
            context = f"{image_tokens}\n{context}"
            # Apply chat template
            messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}]
            if self.chat_template is not None:
                self.tokenizer.chat_template = self.chat_template
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            elif self.tokenizer.chat_template is not None:
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            else:
                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

            formatted_contexts = [prompt]
            formatted_continuation = [prompt_and_continuation]
            model_inputs = self._image_processor(text=formatted_continuation, images=visuals, return_tensors="pt").to(self._device, self.model.dtype)
            labels = model_inputs["input_ids"].clone()
            contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
            # Mask the prompt tokens so the loss is computed over the continuation only.
            labels[:, : contxt_id.shape[1]] = -100

            if self.accelerator.is_main_process and doc_id % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
                eval_logger.debug(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")

            with torch.inference_mode():
                outputs = self.model(**model_inputs, labels=labels)
            loss = outputs["loss"]
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = model_inputs["input_ids"][:, contxt_id.shape[1] :]  # [1, seq]
            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : model_inputs["input_ids"].shape[1]]  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)

        pbar.close()
        return res

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]

            # Set default values for until and max_new_tokens
            until = [self.tok_decode(self.eot_token_id)]

            # Update values from gen_kwargs if present
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
            context = contexts[0]

            # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
            if DEFAULT_IMAGE_TOKEN not in context:
                image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
                image_tokens = " ".join(image_tokens)
                context = f"{image_tokens}\n{context}"
            # Apply chat template
            messages = [{"role": "user", "content": context}]
            if self.chat_template is not None:
                self.tokenizer.chat_template = self.chat_template
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            elif self.tokenizer.chat_template is not None:
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)

            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1

            def hook(module: torch.nn.Module, _, outputs):
                # Maybe unpack tuple outputs
                if isinstance(outputs, tuple):
                    unpack_outputs = list(outputs)
                else:
                    unpack_outputs = [outputs]
                name = self.module_to_name[module]
                sae = self.module_dict[name]
                # Swap the module's hidden states for their SAE reconstruction.
                sae_out = sae(unpack_outputs[0][0]).sae_out.unsqueeze(0).to(torch.float16)
                unpack_outputs[0] = sae_out
                if isinstance(outputs, tuple):
                    outputs = tuple(unpack_outputs)
                else:
                    outputs = unpack_outputs[0]
                return outputs

            handles = [mod.register_forward_hook(hook) for mod in self.name_to_module.values()]
            try:
                cont = self.model.generate(
                    **inputs,
                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
                    temperature=gen_kwargs["temperature"],
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.specified_eot_token_id,
                )
                cont = cont[:, inputs["input_ids"].shape[-1] :]
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                cont = ""
            finally:
                for handle in handles:
                    handle.remove()
            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

            res.append(text_outputs)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
            pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res
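
For readers who want the mechanism in isolation: the core of this commit is the forward hook in `generate_until`, which intercepts each hooked submodule's output and substitutes the SAE reconstruction of its hidden states before generation continues. Below is a minimal, self-contained sketch of that substitution pattern in plain PyTorch, separate from the committed file. `ToySae`, `make_substitution_hook`, the toy block, and the dimensions are illustrative stand-ins, not part of the commit or of the `sae` package; only the hook mechanics mirror the code above.

import torch
import torch.nn as nn


class ToySae(nn.Module):
    """Tiny autoencoder stand-in for an SAE: encode, sparsify, reconstruct."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.relu(self.encoder(hidden)))


def make_substitution_hook(sae: nn.Module):
    """Return a forward hook that replaces a module's hidden states with sae(hidden)."""

    def hook(module, inputs, outputs):
        # Decoder layers usually return a tuple whose first element is the hidden states.
        hidden = outputs[0] if isinstance(outputs, tuple) else outputs
        recon = sae(hidden)
        if isinstance(outputs, tuple):
            return (recon,) + outputs[1:]
        return recon

    return hook


if __name__ == "__main__":
    d_model = 16
    block = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())  # stand-in for one decoder layer
    sae = ToySae(d_model, d_sae=64)

    handle = block.register_forward_hook(make_substitution_hook(sae))
    try:
        x = torch.randn(1, 5, d_model)  # [batch, seq, d_model]
        y = block(x)  # hidden states are now routed through the SAE reconstruction
        print(y.shape)  # torch.Size([1, 5, 16])
    finally:
        handle.remove()  # always detach the hook, mirroring the finally: block above

As in `generate_until`, the handle is removed in a `finally:` block so a failed forward pass never leaves a stale hook attached to the model.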

0 commit comments