import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from sae import Sae
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoProcessor,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
)

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

warnings.filterwarnings("ignore")

from loguru import logger as eval_logger

DEFAULT_IMAGE_TOKEN = "<image>"
# Default chat template for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
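# For a single user turn with add_generation_prompt=True, the template above renders roughly as:
#   "A chat between a curious user and an artificial intelligence assistant. The assistant gives
#    helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question} ASSISTANT:"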

model_map = {
    "llava": LlavaForConditionalGeneration,
    "llava_next": LlavaNextForConditionalGeneration,
}
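# Note: `__init__` below picks the class from this map via `AutoConfig.from_pretrained(...).model_type`
# ("llava" or "llava_next"); other model types are not handled here.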


@register_model("llava_sae_hooked")
class LlavaSaeHooked(lmms):
    """
    Llava model for Hugging Face Transformers (https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava),
    with optional sparse autoencoders (SAEs) hooked onto submodules of the language model during generation.

    Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py.

    Example usage (add `sae_path=...` to `--model_args` to hook SAEs; omit it to run the plain model):

    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
        --model llava_sae_hooked \
        --model_args pretrained=llava-hf/llava-1.5-7b-hf \
        --tasks seedbench \
        --batch_size 1 \
        --output_path ./logs/ \
        --log_samples
    """

    def __init__(
        self,
        pretrained: str = "llava-hf/llava-1.5-7b-hf",
        revision: str = "main",
        device: str = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: int = 1,
        trust_remote_code: Optional[bool] = False,
        attn_implementation: Optional[str] = None,
        device_map: str = "",
        chat_template: Optional[str] = None,
        use_cache: bool = True,
        specified_eot_token_id: Optional[int] = None,
        sae_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1 and device_map == "":
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)

        config = AutoConfig.from_pretrained(pretrained)
        model_type = getattr(config, "model_type", "llava")
        model_type = model_map[model_type]
        self._model = model_type.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
        if sae_path is not None:
            # Load one SAE per hooked submodule; `local` is inferred from whether the path exists on disk.
            self.module_dict = Sae.load_many(sae_path, local=os.path.exists(sae_path), device=self._device)
        else:
            # No SAEs requested: keep an empty dict so the hook bookkeeping below (and hook registration later) is a no-op.
            self.module_dict = {}
        self.name_to_module = {name: self.model.language_model.get_submodule(name) for name in self.module_dict.keys()}
        self.module_to_name = {v: k for k, v in self.name_to_module.items()}
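        # `Sae.load_many` is assumed to return a dict keyed by submodule names inside the wrapped
        # language model (e.g. "model.layers.<i>"); `name_to_module` resolves those names to the live
        # modules so forward hooks can be attached, and `module_to_name` inverts the mapping for
        # lookups inside the hook.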

        self.pretrained = pretrained
        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
        # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
        self._image_processor.tokenizer.padding_side = "left"
        self._tokenizer = self._image_processor.tokenizer
        self._config = self._model.config
        self.batch_size_per_gpu = int(batch_size)
        self.chat_template = chat_template
        self.use_cache = use_cache
        self.specified_eot_token_id = specified_eot_token_id
        if accelerator.num_processes > 1 and device_map == "":
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP, and DeepSpeed (ZeRO stage 0) are supported."
            # To use DistributedType.DEEPSPEED, run `accelerate config` before using the model and select
            # ZeRO stage 0 (equivalent to DDP); preparing the model for evaluation did not work with the
            # default ZeRO stage 2, regardless of the kwargs passed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.module_dict = {k: accelerator.prepare_model(v) for k, v in self.module_dict.items()}
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1
        self.accelerator = accelerator

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # `_max_length` is never set in `__init__`; fall back to the text config's context length
        # (2048 is an assumed last-resort default) so accessing this property does not raise.
        return getattr(self, "_max_length", getattr(getattr(self._config, "text_config", self._config), "max_position_embeddings", 2048))

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Tokenize a string, optionally keeping only the last `left_truncate_len` tokens."""
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for context, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            if isinstance(doc_to_target, str):
                continuation = doc_to_target
            else:
                continuation = doc_to_target(self.task_dict[task][split][doc_id])
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

            image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
            image_tokens = " ".join(image_tokens)
            context = f"{image_tokens}\n{context}"
            # Apply chat template
            messages = [{"role": "user", "content": context}, {"role": "assistant", "content": continuation}]
            if self.chat_template is not None:
                self.tokenizer.chat_template = self.chat_template
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            elif self.tokenizer.chat_template is not None:
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            else:
                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                prompt = self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
                prompt_and_continuation = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

            formatted_contexts = [prompt]
            formatted_continuation = [prompt_and_continuation]
            # `return_tensors="pt"` is needed so the batch can be moved to the device and cloned as tensors.
            model_inputs = self._image_processor(text=formatted_continuation, images=visuals, return_tensors="pt").to(self._device, self.model.dtype)
            labels = model_inputs["input_ids"].clone()
            contxt_id = self._image_processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
            # Mask the context tokens along the sequence dimension so the loss covers only the continuation.
            labels[:, : contxt_id.shape[1]] = -100

            if self.accelerator.is_main_process and doc_id % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
                eval_logger.debug(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")

            with torch.inference_mode():
                outputs = self.model(**model_inputs, labels=labels)
            loss = outputs["loss"]
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = model_inputs["input_ids"][:, contxt_id.shape[1] :]  # [1, seq]
            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : model_inputs["input_ids"].shape[1]]  # [1, seq]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)

        pbar.close()
        return res

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]

            # Set default values for until and max_new_tokens
            until = [self.tok_decode(self.eot_token_id)]

            # Update values from gen_kwargs if present
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
            assert self.batch_size_per_gpu == 1, "batch_size_per_gpu > 1 is not supported for now"
            context = contexts[0]

            # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
            if DEFAULT_IMAGE_TOKEN not in context:
                image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
                image_tokens = " ".join(image_tokens)
                context = f"{image_tokens}\n{context}"
            # Apply chat template
            messages = [{"role": "user", "content": context}]
            if self.chat_template is not None:
                self.tokenizer.chat_template = self.chat_template
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            elif self.tokenizer.chat_template is not None:
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)

            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1

            def hook(module: torch.nn.Module, _, outputs):
                # Decoder layers typically return a tuple whose first element is the hidden states;
                # unpack it, otherwise wrap the bare tensor so both cases are handled uniformly.
                if isinstance(outputs, tuple):
                    unpack_outputs = list(outputs)
                else:
                    unpack_outputs = [outputs]
                name = self.module_to_name[module]
                sae = self.module_dict[name]
                # Splice the SAE reconstruction back in place of the hidden states (batch dim restored).
                sae_out = sae(unpack_outputs[0][0]).sae_out.unsqueeze(0).to(torch.float16)
                unpack_outputs[0] = sae_out
                if isinstance(outputs, tuple):
                    outputs = tuple(unpack_outputs)
                else:
                    outputs = unpack_outputs[0]
                return outputs
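            # Register the splice hook on every SAE-hooked submodule for the duration of this
            # generation call; the `finally` block below removes the handles even if generation fails.
            # If no `sae_path` was given, `name_to_module` is empty and generation runs unmodified.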
            handles = [mod.register_forward_hook(hook) for mod in self.name_to_module.values()]
            try:
                cont = self.model.generate(
                    **inputs,
                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
                    temperature=gen_kwargs["temperature"],
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.specified_eot_token_id,
                )
                # Strip the prompt tokens so only the newly generated continuation is decoded.
                cont = cont[:, inputs["input_ids"].shape[-1] :]
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                cont = None
            finally:
                for handle in handles:
                    handle.remove()
            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] if cont is not None else ""
            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

            res.append(text_outputs)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
            pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res
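

# --- Illustrative sketch (not part of the evaluation pipeline) ---------------------------------
# A minimal, self-contained demonstration of the forward-hook splicing used in `generate_until`,
# with a hypothetical identity "SAE" standing in for a trained `Sae` module; the names and shapes
# below are assumptions for illustration only.
if __name__ == "__main__":
    import torch.nn as nn

    layer = nn.Linear(8, 8)

    def toy_sae_reconstruct(hidden: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for `Sae.__call__(...).sae_out`: an identity "reconstruction".
        return hidden

    def splice_hook(module: torch.nn.Module, _inputs, output):
        # Returning a non-None value from a forward hook replaces the module's output,
        # mirroring how `hook` above swaps hidden states for the SAE reconstruction.
        return toy_sae_reconstruct(output)

    handle = layer.register_forward_hook(splice_hook)
    try:
        print(layer(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 8])
    finally:
        handle.remove()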