From c081228e0cb9736e20548d3939bbf11ca1fb007c Mon Sep 17 00:00:00 2001 From: root <403644786@qq.com> Date: Mon, 2 Sep 2024 08:37:51 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=A4=84=E7=90=86=E6=96=B9=E6=B3=95=E4=BB=A5=E9=80=82?= =?UTF-8?q?=E9=85=8Dminicpmv2.6=E7=9A=84awq?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- awq/models/__init__.py | 1 + awq/models/auto.py | 1 + awq/models/base.py | 1 + awq/models/minicpmv.py | 406 +++++++++++++++++ examples/minicpmv2.6_quantize.py | 718 +++++++++++++++++++++++++++++++ 5 files changed, 1127 insertions(+) create mode 100644 awq/models/minicpmv.py create mode 100644 examples/minicpmv2.6_quantize.py diff --git a/awq/models/__init__.py b/awq/models/__init__.py index 2f1a88e2..cd85ae4d 100644 --- a/awq/models/__init__.py +++ b/awq/models/__init__.py @@ -24,3 +24,4 @@ from .deepseek_v2 import DeepseekV2AWQForCausalLM from .minicpm import MiniCPMAWQForCausalLM from .internlm2 import InternLM2AWQForCausalLM +from .minicpmv import MiniCPMVAWQForCausalLM diff --git a/awq/models/auto.py b/awq/models/auto.py index 3a6416f1..201fc3dd 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -34,6 +34,7 @@ "deepseek_v2": DeepseekV2AWQForCausalLM, "minicpm": MiniCPMAWQForCausalLM, "internlm2": InternLM2AWQForCausalLM, + "minicpmv": MiniCPMVAWQForCausalLM } diff --git a/awq/models/base.py b/awq/models/base.py index 1d376fc0..86a98f01 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -84,6 +84,7 @@ "cohere": "AutoModelForCausalLM", "deepseek_v2": "AutoModelForCausalLM", "minicpm": "AutoModelForCausalLM", + "minicpmv":"AutoModelForCausalLM", "internlm2": "AutoModelForCausalLM", } diff --git a/awq/models/minicpmv.py b/awq/models/minicpmv.py new file mode 100644 index 00000000..bf12ebd7 --- /dev/null +++ b/awq/models/minicpmv.py @@ -0,0 +1,406 @@ +import tqdm +from typing import List, Tuple +from .base import BaseAWQForCausalLM +from awq.utils.fused_utils import fuse_qkv +from awq.modules.fused.block import LlamaLikeBlock +from awq.modules.fused.model import LlamaLikeModel +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer as OldLlamaDecoderLayer, +) +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2DecoderLayer as OldQwen2DecoderLayer, + Qwen2ForCausalLM as OldQwen2ForCausalLM, +) +from .base import ( + Annotated, + AwqConfig, + BaseAWQForCausalLM, + Dict, + Doc, + List, + PreTrainedTokenizer, + Union, +) + +from transformers.models.llava.modeling_llava import ( + LlavaForConditionalGeneration as OldLlavaForConditionalGeneration, +) +from awq.modules.fused.norm import FasterTransformerRMSNorm +import torch +from transformers import AutoProcessor +import json +from copy import deepcopy +from PIL import Image +from awq.modules.fused.attn import QuantAttentionFused +from torch import nn +from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device + +class CPMVAwqQuantizer(AwqQuantizer): + def init_quant(self, n_samples=None, max_seq_len=None): + modules = self.awq_model.get_model_layers(self.model) + samples = self.calib_data + + inps = [] + layer_kwargs = {} + + best_device = get_best_device() + modules[0] = modules[0].to(best_device) + self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = 
module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): + def get_device(obj: torch.Tensor | nn.Module): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + if get_device(obj) != device: + obj = obj.to(device) + return obj + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, best_device) + try: + self.model(**samples) + except ValueError: # work with early exit + pass + finally: + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, "cpu") + modules[0] = modules[0].module # restore + + del samples + inps = inps[0] + + modules[0] = modules[0].cpu() + self.awq_model.move_embed(self.model, "cpu") + + clear_memory() + + return modules, layer_kwargs, inps + +class MiniCPMVAWQForCausalLM(BaseAWQForCausalLM): + layer_type = "Qwen2DecoderLayer" + max_seq_len_key = "max_position_embeddings" + + def chat( + self, + image, + msgs, + tokenizer, + processor=None, + vision_hidden_states=None, + max_new_tokens=2048, + min_new_tokens=0, + sampling=True, + max_inp_length=8192, + system_prompt='', + stream=False, + max_slice_nums=None, + use_image_id=None, + **kwargs + ): + if isinstance(msgs[0], list): + batched = True + else: + batched = False + msgs_list = msgs + images_list = image + + if batched is False: + images_list, msgs_list = [images_list], [msgs_list] + else: + assert images_list is None, "Please integrate image to msgs when using batch inference." + images_list = [None] * len(msgs_list) + assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." + + if processor is None: + if self.processor is None: + self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) + processor = self.processor + + assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." 
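+        # Build a chat prompt for each sample and collect its PIL images for the processor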
+ + prompts_lists = [] + input_images_lists = [] + for image, msgs in zip(images_list, msgs_list): + if isinstance(msgs, str): + msgs = json.loads(msgs) + copy_msgs = deepcopy(msgs) + + assert len(msgs) > 0, "msgs is empty" + assert sampling or not stream, "if use stream mode, make sure sampling=True" + + if image is not None and isinstance(copy_msgs[0]["content"], str): + copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] + + images = [] + for i, msg in enumerate(copy_msgs): + role = msg["role"] + content = msg["content"] + assert role in ["user", "assistant"] + if i == 0: + assert role == "user", "The role of first msg should be user" + if isinstance(content, str): + content = [content] + cur_msgs = [] + for c in content: + if isinstance(c, Image.Image): + images.append(c) + cur_msgs.append("(./)") + elif isinstance(c, str): + cur_msgs.append(c) + msg["content"] = "\n".join(cur_msgs) + + if system_prompt: + sys_msg = {'role': 'system', 'content': system_prompt} + copy_msgs = [sys_msg] + copy_msgs + + prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)) + input_images_lists.append(images) + + inputs = processor( + prompts_lists, + input_images_lists, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, + return_tensors="pt", + max_length=max_inp_length + ).to('cuda:0') + + if sampling: + generation_config = { + "top_p": 0.8, + "top_k": 100, + "temperature": 0.7, + "do_sample": True, + "repetition_penalty": 1.05 + } + else: + generation_config = { + "num_beams": 3, + "repetition_penalty": 1.2, + } + + if min_new_tokens > 0: + generation_config['min_new_tokens'] = min_new_tokens + + generation_config.update( + (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys() + ) + + inputs.pop("image_sizes") + with torch.inference_mode(): + res = self.generate( + **inputs, + tokenizer=tokenizer, + max_new_tokens=max_new_tokens, + vision_hidden_states=vision_hidden_states, + stream=stream, + decode_text=True, + **generation_config + ) + + if stream: + def stream_gen(): + for text in res: + for term in self.terminators: + text = text.replace(term, '') + yield text + return stream_gen() + + else: + if batched: + answer = res + else: + answer = res[0] + return answer + # @staticmethod + # def fuse_layers(model: OldQwen2ForCausalLM): + # fuser = MiniCPMVFuser(model) # 这里是算子融合 + # fuser.fuse_transformer() + + # hack to use `Qwen2VLAwqQuantizer` as quantizer + @torch.no_grad() + def quantize( + self, + tokenizer: Annotated[ + PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") + ] = None, + quant_config: Annotated[ + Dict, Doc("The quantization config you want to use.") + ] = {}, + calib_data: Annotated[ + Union[str, List[str]], + Doc( + "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples." + ), + ] = "pileval", + split: Annotated[str, Doc("The split of calib_data.")] = "train", + text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", + duo_scaling: Annotated[ + bool, Doc("Whether to scale using both w/x or just x.") + ] = True, + export_compatible: Annotated[ + bool, + Doc( + "This argument avoids real quantization by only applying the scales without quantizing down to FP16." + ), + ] = False, + apply_clip: Annotated[ + bool, + Doc( + "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." 
+ ), + ] = True, + n_parallel_calib_samples: Annotated[ + int, + Doc( + "The number of parallel samples to run through the model. " + "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. " + "If None, runs through all samples at the same time. " + "You can set this to a low number for more memory efficient quantization." + ), + ] = None, + max_calib_samples: Annotated[ + int, Doc("The maximum number of samples to run through the model.") + ] = 128, + max_calib_seq_len: Annotated[ + int, + Doc( + "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." + ), + ] = 512, + max_chunk_memory: Annotated[ + int, + Doc( + "The loss computation and per-channel mean is optimized into chunked computations." + " Adjust this parameter to increase or decrease memory usage for these computations." + " Default is 1GB (1024 * 1024 * 1024)." + ), + ] = 1024 + * 1024 + * 1024, + ): + self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + + self.quantizer = CPMVAwqQuantizer( + self, + self.model, + tokenizer, + self.quant_config.w_bit, + self.quant_config.q_group_size, + self.quant_config.zero_point, + self.quant_config.version, + calib_data, + split, + text_column, + duo_scaling, + modules_to_not_convert=self.quant_config.modules_to_not_convert, + export_compatible=export_compatible, + apply_clip=apply_clip, + n_parallel_calib_samples=n_parallel_calib_samples, + max_calib_samples=max_calib_samples, + max_calib_seq_len=max_calib_seq_len, + max_chunk_memory=max_chunk_memory, + ) + self.quantizer.quantize() + + self.is_quantized = True + + @staticmethod + def get_model_layers(model: OldQwen2ForCausalLM): + return model.llm.model.layers + + @staticmethod + def get_act_for_scaling(module: OldQwen2DecoderLayer): + return dict(is_scalable=False) + + @staticmethod + def move_embed(model: OldQwen2DecoderLayer, device: str): + model.llm.model.embed_tokens = model.get_input_embeddings().to( + device + ) + + @staticmethod + def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs): + layers = [] + + # attention input + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) + + # attention out + # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + + # linear 1 + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) + + # linear 2 + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) + + return layers + diff --git a/examples/minicpmv2.6_quantize.py b/examples/minicpmv2.6_quantize.py new file mode 100644 index 00000000..6677d9c0 --- /dev/null +++ b/examples/minicpmv2.6_quantize.py @@ -0,0 +1,718 @@ +from __future__ import annotations + +import logging +from awq import AutoAWQForCausalLM +from 
transformers import AutoTokenizer
+from torchvision import transforms
+import argparse
+import json
+import copy
+import math
+import os
+import re
+import random
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+import requests
+from io import BytesIO
+import numpy as np
+import torch
+from PIL import Image
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset
+def main():
+
+    # Set up the argument parser
+    parser = argparse.ArgumentParser(description="Quantize and save a model.")
+
+    # Add arguments
+    parser.add_argument('--model-path', type=str, default="/root/ld/ld_model_pretrained/Minicpmv2_6",
+                        help='Path to the model directory.')
+    parser.add_argument('--quant-path', type=str, default="/root/ld/ld_model_pretrained/Minicpmv2_6_awq_new",
+                        help='Path to save the quantized model.')
+    parser.add_argument('--zero-point', action='store_true',
+                        help='Enable zero point quantization.')
+    parser.add_argument('--q-group-size', type=int, default=128,
+                        help='Quantization group size.')
+    parser.add_argument('--w-bit', type=int, default=4,
+                        help='Weight bit size.')
+    parser.add_argument('--version', type=str, default="GEMM",
+                        help='Quantization version.')
+    parser.add_argument('--batch-size', type=int, default=8,
+                        help='Calibration forward batch size; an RTX 4090 can handle a batch of 8, an A100 a batch of 32.')
+    args = parser.parse_args()
+
+    quant_config = {"zero_point": args.zero_point, "q_group_size": args.q_group_size, "w_bit": args.w_bit, "version": args.version}
+    batch = args.batch_size  # calibration forward batch size; an RTX 4090 can handle 8, an A100 32
+
+    # Here we use a caption dataset **only for demonstration**; replace it with your own SFT dataset.
+    def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
+        from datasets import load_dataset
+        dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
+        return [
+            [
+                {'id': str(index),
+                 'conversations': [{'content': '\n能否协助我辨认这张图片展示的内容?', 'role': 'user'},
+                                   {'content': sample["caption"], 'role': 'assistant'}],
+                 'image': sample["url"]},
+            ]
+            for index, sample in enumerate(dataset)
+        ]
+    logger = logging.getLogger(__name__)
+    # Load the calibration dataset
+    dataset = prepare_dataset(n_sample=32)
+    # Then you need to prepare your data for calibration. Just put samples into a list,
+    # each of which is a typical chat message as shown below. 
you can specify text and image in `content` field: + # dataset = [ + # first_line + # {'id': '0', + # 'conversations': [{'content': '\n能否协助我辨认这张图片展示的内容?', 'role': 'user'}, + # {'content': '2017 中共文山市委执政纪要\n政策向符合条件的职业农民倾斜.....', 'role': 'assistant'}], + # 'image': '/root/ld/ld_dataset/30k_data/63677831/198.jpg'}, + # second_line + # {'id': '1', + # 'conversations': [{'content': '\n我需要帮忙确定这张图片里呈现的是什么。', 'role': 'user'}, + # {'content': '大学计算机是国家最重要的课程,我们需要从娃娃抓起....', 'role': 'assistant'}], + # 'image': '/root/ld/ld_dataset/30k_data/61520120/3.jpg'} + # ] + + ## The following is a script for data processing, no need to modify it + llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}" + + class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + raw_data, + transform, + tokenizer, + slice_config, + llm_type="minicpm", + patch_size=14, + query_nums=64, + batch_vision=False, + max_length=2048, + ): + super(SupervisedDataset, self).__init__() + self.raw_data = raw_data + self.tokenizer = tokenizer + self.transform = transform + self.slice_config = slice_config + self.llm_type = llm_type + self.patch_size = patch_size + self.query_nums=query_nums + self.batch_vision = batch_vision + self.max_length = max_length + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + try: + if isinstance(self.raw_data[i]["image"], str): + if self.raw_data[i]["image"].startswith("http"): + yzmdata = requests.get(self.raw_data[i]["image"]) + tempIm = BytesIO(yzmdata.content) + image = Image.open(tempIm).convert('RGB') + images_dict = { "" : image } + else: + images_dict = { "" : Image.open(self.raw_data[i]["image"]).convert("RGB") } + elif isinstance(self.raw_data[i]["image"], Dict): + ### for multi-images input, the template for every image is , such as , + images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()} + + ret = preprocess( + images_dict, + self.raw_data[i]["conversations"], + self.tokenizer, + self.transform, + query_nums=self.query_nums, + slice_config=self.slice_config, + llm_type=self.llm_type, + patch_size=self.patch_size, + batch_vision=self.batch_vision, + max_length=self.max_length + ) + ret = dict( + input_ids=ret["input_ids"], + position_ids=ret["position_ids"], + labels=ret["target"], + attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool), + pixel_values=ret["pixel_values"], + tgt_sizes=ret["tgt_sizes"], + image_bound=ret["image_bound"], + ) + except: + logger.error(f"data fetch error") + return self.__getitem__(random.randint(0, len(self))) + return ret + + + def data_collator(examples, padding_value=0, max_length=2048): + def trim_and_pad(seq, batch_first, padding_value): + return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value) + + input_ids = trim_and_pad( + [example["input_ids"] for example in examples], + batch_first=True, + padding_value=padding_value, + ) + position_ids = trim_and_pad( + [example["position_ids"] for example in examples], + batch_first=True, + padding_value=padding_value, + ) + targets = trim_and_pad( + [example["labels"] for example in examples], + batch_first=True, + padding_value=-100, 
+ ) + attention_mask = trim_and_pad( + [example["attention_mask"] for example in examples], + batch_first=True, + padding_value=padding_value, + ) + pixel_values = [example["pixel_values"] for example in examples] + image_bound = [example["image_bound"] for example in examples] + tgt_sizes = [example["tgt_sizes"] for example in examples] + return { + "input_ids": torch.tensor(input_ids), + "position_ids": torch.tensor(position_ids), + "labels": torch.tensor(targets), + "attention_mask": torch.tensor(attention_mask), + "image_bound": image_bound, + "tgt_sizes": tgt_sizes, + "pixel_values": pixel_values, + } + + + def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048): + """ + for single image multi-turn conversation + conversation: [{'role': 'user', 'content': 'Describe this image'}, + {'role': 'assistant', 'content': 'This is a cat.'}] + """ + if llm_type == "llama3": + input_ids, context, raw_msg = conversation_to_ids_llama3( + conversation, tokenizer + ) + elif llm_type == "qwen2": + input_ids, context, raw_msg = conversation_to_ids_qwen2( + conversation, tokenizer + ) + else: + input_ids, context, raw_msg = conversation_to_ids_minicpm( + conversation, tokenizer + ) + + ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32)) + context = torch.from_numpy(np.hstack(context, dtype=np.int8)) + if input_ids.shape[-1] > max_length: + ids =ids[:max_length] + context = context[:max_length] + logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated") + + if torch.all(context): + logger.error("No tokens available to compute loss.") + raise Exception("No tokens available to compute loss.") + + # build target + target = torch.full_like(ids, -100, dtype=torch.int32) + + for i in range(1, len(ids)): + if context[i] == 0: + target[i - 1] = ids[i] + if context[i] == 1 and context[i - 1] == 0: + if hasattr(tokenizer, "eot_id"): + target[i - 1] = tokenizer.eot_id + else: + target[i - 1] = tokenizer.eos_id + + # build image bound + if new_schema: + start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id) + end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id) + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + else: + image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0] + if len(image_start_tokens) != len(image_end_tokens): + logger.error("image start token != image end tokens") + raise Exception("image start token != image end tokens") + + if len(image_start_tokens) > 0: + image_bound = torch.hstack( + [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)] + ) + else: + image_bound = [] + + position_ids = torch.arange(ids.size(0)).long() + return { + "input_ids": ids, + "target": target, + "image_bound": image_bound, + "raw_msg": raw_msg, + "position_ids": position_ids + } + + + def conversation_to_ids_minicpm(conversation, tokenizer): + raw_msg = "" + input_ids = [] + context = [] + for idx, msg in enumerate(conversation): + role = msg["role"] + message = msg["content"] + assert role in ["user", "assistant"] + if role == "user": + prefix = "<用户>" + else: + prefix = "" + # append eos + if idx == len(conversation) - 1: + message = message + tokenizer.eos_token + prefix_ids = tokenizer.encode(prefix)[1:] # remove bos + message_ids = 
tokenizer.encode(message)[1:] + + input_ids.append(prefix_ids) + input_ids.append(message_ids) + + context.append(np.ones((len(prefix_ids),), dtype=np.int8)) + if role == "assistant": + context.append(np.zeros((len(message_ids),), dtype=np.int8)) + else: + context.append(np.ones((len(message_ids),), dtype=np.int8)) + + raw_msg += prefix + message + + return input_ids, context, raw_msg + + + def conversation_to_ids_llama3(conversation, tokenizer): + raw_msg = "" + input_ids = [] + context = [] + raw_msg = tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template, + ) + input_ids = tokenizer.apply_chat_template( + conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template, + ) + input_ids = np.array(input_ids) + + start_header_idxs = np.where( + input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>") + )[0] + assistant_idxs = np.where( + input_ids == tokenizer.convert_tokens_to_ids("assistant") + )[0] + end_header_idxs = np.where( + input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>") + )[0] + eot_idxs = np.where( + input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0] + + context = np.ones_like(input_ids, dtype=np.int8) + + for assistant_idx in assistant_idxs: + if assistant_idx in set((start_header_idxs + end_header_idxs) / 2): + st = assistant_idx + 3 # assistant<|end_header_id|>\n\n + for eot_idx in eot_idxs: + if eot_idx > st: + context[st: eot_idx + 1] = 0 + break + + input_ids = np.hstack(input_ids) + context = np.hstack(context) + + return input_ids, context, raw_msg + + + def conversation_to_ids_qwen2(conversation, tokenizer): + raw_msg = "" + chat = [] + context = [] + for idx, msg in enumerate(conversation): + role = msg["role"] + message = msg["content"] + assert role in ["user", "assistant"] + if role == "user": + prefix = "user" + else: + prefix = "assistant" + chat.append({"role":prefix, "content":message}) + raw_msg += prefix + message + assert set([i['role'] for i in chat]) & set(['assistant']) + + ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False) + input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False) + input_ids = np.array(input_ids) + + start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0] + assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0] + end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0] + + context = np.ones_like(input_ids, dtype=np.int8) + + for assistant_idx in assistant_idxs: + if assistant_idx-1 in set(start_idxs): + st = assistant_idx + 1 + for end_idx in end_idxs: + if end_idx > st: + context[st: end_idx + 1] = 0 + break + + input_ids = np.hstack(input_ids) + context = np.hstack(context) + return input_ids, context, raw_msg + + + def preprocess( + images_dict, + conversations, + tokenizer, + transform, + query_nums=64, + slice_config=None, + llm_type=None, + patch_size=14, + batch_vision=False, + max_length=2048, + ): + """ + single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation + """ + conversations = copy.deepcopy(conversations) + assert len(conversations) > 1, "conversations length must large than 2" + assert conversations[0]["role"] == "user", "the first role must be user" + + if slice_config is not None: + assert isinstance(slice_config, Dict) + assert "patch_size" in slice_config + assert "max_slice_nums" 
in slice_config + assert "scale_resolution" in slice_config + default_image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + ) + new_schema = False + use_image_id = False + if llm_type=='qwen2': + new_schema = True + use_image_id = True + image_placeholder_dict = {} + images = [] + image_id_cnt = 0 + for img_name, image in images_dict.items(): + if slice_config: + source_image, patches, best_grid = slice_image( + image, + slice_config["max_slice_nums"], + slice_config["scale_resolution"], + slice_config["patch_size"], + ) + images.append(source_image) + image_placeholder = default_image_placeholder + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + images.append(patches[i][j]) + if use_image_id: + image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder + image_id_cnt += 1 + image_placeholder += get_grid_placeholder( + tokenizer, best_grid, query_nums, new_schema = new_schema) + image_placeholder_dict[img_name] = image_placeholder + else: + images.append(image) + if use_image_id: + image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder + image_id_cnt += 1 + else: + image_placeholder = default_image_placeholder + image_placeholder_dict[img_name] = image_placeholder + + images = [transform(i) for i in images] + + if len(images_dict) == 1 and "" in images_dict: + if "" in conversations[0]["content"]: + conversations[0]["content"] = conversations[0]["content"].replace( + "", image_placeholder + ) + else: + conversations[0]["content"] = ( + image_placeholder + "\n" + conversation[0]["content"] + ) + input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length) + else: + pattern = r'' + new_conversations = [] + for conversation in conversations: + content = conversation['content'] + parts = re.split(f'({pattern})', content) + for i, part in enumerate(parts): + if not part.strip(): + continue + if re.match(pattern, part): + if part in image_placeholder_dict: + parts[i] = image_placeholder_dict[part] + else: + raise Exception(f"not found {part} in image dict") + conversation['content'] = '\n'.join(parts) + new_conversations.append(conversation) + conversations = new_conversations + + input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length) + + if batch_vision: + tgt_sizes = [] + reshape_images = [] + for image in images: + H, W = image.shape[1:] + reshape_image = reshape_by_patch(image, patch_size) + reshape_images.append(reshape_image) + tgt_sizes.append([H // patch_size, W // patch_size]) + if tgt_sizes: + tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32) + + input_dict["pixel_values"] = reshape_images + input_dict["tgt_sizes"] = tgt_sizes + + else: + input_dict["pixel_values"] = images + input_dict["tgt_sizes"] = [] + + return input_dict + + + def slice_image( + image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False + ): + original_size = image.size + original_width, original_height = original_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / \ + (scale_resolution * scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + + source_image = None + best_grid = None + patches = [] + + if multiple <= 1 or never_split: + # dont need to slice, upsample + best_size = find_best_resize( + original_size, scale_resolution, patch_size, allow_upscale=True + ) + 
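+            # find_best_resize rounds the target size to a multiple of patch_size via ensure_divide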
source_image = image.resize(best_size, Image.Resampling.BICUBIC) + else: + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + # source image, down-sampling and ensure divided by patch_size + best_resize = find_best_resize( + original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + candidate_grids = [] + + # find best grid + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + refine_size = get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) + + refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) + patches = split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + + def ensure_divide(length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + + def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + if (width * height > scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = ensure_divide(width, patch_size) + best_height = ensure_divide(height, patch_size) + return (best_width, best_height) + + + def get_refine_size( + original_size, grid, scale_resolution, patch_size, allow_upscale=False + ): + width, height = original_size + grid_x, grid_y = grid + + refine_width = ensure_divide(width, grid_x) + refine_height = ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = find_best_resize( + (grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale, + ) + + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + + return refine_size + + + def split_to_patches(image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + + return patches + + + def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): + if new_schema: + image_placeholder = ( + tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end + ) + else: + image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + ) + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(image_placeholder) + slices.append("".join(lines)) + if new_schema: + slice_placeholder = '\n'.join(slices) + else: + slice_placeholder = tokenizer.slice_start + \ + "\n".join(slices) + tokenizer.slice_end + return slice_placeholder + + + def reshape_by_patch(image_tensor, patch_size): + """ + :param image_tensor: shape [3, H, W] + :param patch_size: + :return: [3, patch_size, HW/patch_size] + """ + patches = torch.nn.functional.unfold( + 
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size) + ) + + patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) + patches = patches.permute(0, 1, 3, 2).reshape( + image_tensor.size(0), patch_size, -1) + patches=patches.cuda() + return patches + # Load your tokenizer and model with AutoAWQ + model = AutoAWQForCausalLM.from_pretrained(args.model_path).cuda() + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True,device_map={"": "cuda:0"}) + + # set some parameters + if hasattr(model.config, "slice_config"): + model.config.slice_config.max_slice_nums = 1 + slice_config = model.config.slice_config.to_dict() + else: + model.config.max_slice_nums = 1 + slice_config = model.config.to_dict() + + if hasattr(model.config, "batch_vision_input"): + batch_vision = model.config.batch_vision_input + else: + batch_vision = False + def build_transform(): + IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN + IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD + ), + ] + ) + transform_func = build_transform() + + + + + calib_data = SupervisedDataset( + dataset, + transform_func, + tokenizer, + slice_config=slice_config, + llm_type="qwen2", + patch_size=model.config.patch_size, + query_nums=model.config.query_num, + batch_vision=batch_vision, + max_length=2048, + ) + + + out_data=[] + batch_data=[] + for index in range(len(calib_data)//batch): + batch_data=[] + for j in range(batch): + batch_data.append(calib_data[j+index*batch]) + out_data.append(data_collator(batch_data)) + + + # Then just run the calibration process by one line of code: + model.quantize(calib_data=out_data[0], quant_config=quant_config) + + # remove pos_embed + if hasattr(model.model, 'resampler') and hasattr(model.model.resampler, 'pos_embed'): + del model.model.resampler.pos_embed + # Finally, save the quantized model: + model.model.config.use_cache = model.model.generation_config.use_cache = True + model.save_quantized(args.quant_path, safetensors=True, shard_size="4GB") + +if __name__ == "__main__": + main() + + From c5b28d6bde630b00dc5a3eb013d159d0fc9b0219 Mon Sep 17 00:00:00 2001 From: root <403644786@qq.com> Date: Fri, 6 Sep 2024 19:05:51 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E9=80=82=E9=85=8D=E4=BA=86minicpm3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- awq/models/__init__.py | 2 +- awq/models/auto.py | 2 +- awq/models/base.py | 2 +- awq/models/minicpm3.py | 265 ++++++++++++ awq/models/minicpmv.py | 406 ----------------- examples/minicpmv2.6_quantize.py | 718 ------------------------------- 6 files changed, 268 insertions(+), 1127 deletions(-) create mode 100644 awq/models/minicpm3.py delete mode 100644 awq/models/minicpmv.py delete mode 100644 examples/minicpmv2.6_quantize.py diff --git a/awq/models/__init__.py b/awq/models/__init__.py index cd85ae4d..615ed25c 100644 --- a/awq/models/__init__.py +++ b/awq/models/__init__.py @@ -24,4 +24,4 @@ from .deepseek_v2 import DeepseekV2AWQForCausalLM from .minicpm import MiniCPMAWQForCausalLM from .internlm2 import InternLM2AWQForCausalLM -from .minicpmv import MiniCPMVAWQForCausalLM +from .minicpm3 import MiniCPM3AWQForCausalLM diff --git a/awq/models/auto.py b/awq/models/auto.py index 201fc3dd..df4d580f 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -34,7 +34,7 @@ 
"deepseek_v2": DeepseekV2AWQForCausalLM, "minicpm": MiniCPMAWQForCausalLM, "internlm2": InternLM2AWQForCausalLM, - "minicpmv": MiniCPMVAWQForCausalLM + "minicpm3": MiniCPM3AWQForCausalLM } diff --git a/awq/models/base.py b/awq/models/base.py index 86a98f01..e8512ca7 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -84,7 +84,7 @@ "cohere": "AutoModelForCausalLM", "deepseek_v2": "AutoModelForCausalLM", "minicpm": "AutoModelForCausalLM", - "minicpmv":"AutoModelForCausalLM", + "minicpm3":"AutoModelForCausalLM", "internlm2": "AutoModelForCausalLM", } diff --git a/awq/models/minicpm3.py b/awq/models/minicpm3.py new file mode 100644 index 00000000..0ea87b38 --- /dev/null +++ b/awq/models/minicpm3.py @@ -0,0 +1,265 @@ + +from .base import BaseAWQForCausalLM +from ..quantize.quantizer import AwqQuantizer +import torch +from .base import ( + Annotated, + AwqConfig, + BaseAWQForCausalLM, + Dict, + Doc, + List, + PreTrainedTokenizer, + Union, + ) +from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device + +class CPM3AwqQuantizer(AwqQuantizer): + @torch.no_grad() + def _compute_best_clip( + self, + w: torch.Tensor, + input_feat: torch.Tensor, + n_grid=20, + max_shrink=0.5, + n_sample_token=512, + ): + assert w.dim() == 2 + org_w_shape = w.shape + # w [co, ci] -> [co, 1, n_group, group size] + # input_feat [n_token, ci] -> [1, n_token, n_group, group size] + group_size = self.group_size if self.group_size > 0 else org_w_shape[1] + input_feat = input_feat.view(-1, input_feat.shape[-1]) + input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) + + # Compute input feature step size (minimum 1) + step_size = max(1, input_feat.shape[1] // n_sample_token) + input_feat = input_feat[:, ::step_size] + + w = w.reshape(org_w_shape[0], 1, -1, group_size) + + oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM + if org_w_shape[0] % oc_batch_size != 0: + oc_batch_size = org_w_shape[0] + assert org_w_shape[0] % oc_batch_size == 0 + w_all = w + best_max_val_all = [] + + for i_b in range(org_w_shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 + + best_max_val = org_max_val.clone() + min_errs = torch.ones_like(org_max_val) * 1e9 + input_feat = input_feat.to(w.device) + org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group + + for i_s in range(int(max_shrink * n_grid)): + max_val = org_max_val * (1 - i_s / n_grid) + min_val = -max_val + cur_w = torch.clamp(w, min_val, max_val) + q_w = self.pseudo_quantize_tensor(cur_w)[0] + cur_out = (input_feat * q_w).sum(dim=-1) + + # co, 1, n_group, 1 + err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) + del cur_w + del cur_out + cur_best_idx = err < min_errs + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + best_max_val = torch.cat(best_max_val_all, dim=0) + + clear_memory(input_feat) + clear_memory(org_out) + + return best_max_val.squeeze(1) + +class MiniCPM3AWQForCausalLM(BaseAWQForCausalLM): + layer_type = "MiniCPMDecoderLayer" + max_seq_len_key = "max_position_embeddings" + @torch.no_grad() + def quantize( + self, + tokenizer: Annotated[ + PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") + ] = None, + quant_config: Annotated[ + Dict, Doc("The quantization config you want to use.") + ] = {}, + calib_data: Annotated[ + Union[str, List[str]], + Doc( + "The calibration dataset. 
Either a string pointing to Huggingface or a list of preloaded examples." + ), + ] = "pileval", + split: Annotated[str, Doc("The split of calib_data.")] = "train", + text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", + duo_scaling: Annotated[ + bool, Doc("Whether to scale using both w/x or just x.") + ] = True, + export_compatible: Annotated[ + bool, + Doc( + "This argument avoids real quantization by only applying the scales without quantizing down to FP16." + ), + ] = False, + apply_clip: Annotated[ + bool, + Doc( + "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." + ), + ] = True, + n_parallel_calib_samples: Annotated[ + int, + Doc( + "The number of parallel samples to run through the model. " + "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. " + "If None, runs through all samples at the same time. " + "You can set this to a low number for more memory efficient quantization." + ), + ] = None, + max_calib_samples: Annotated[ + int, Doc("The maximum number of samples to run through the model.") + ] = 128, + max_calib_seq_len: Annotated[ + int, + Doc( + "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." + ), + ] = 512, + max_chunk_memory: Annotated[ + int, + Doc( + "The loss computation and per-channel mean is optimized into chunked computations." + " Adjust this parameter to increase or decrease memory usage for these computations." + " Default is 1GB (1024 * 1024 * 1024)." + ), + ] = 1024 + * 1024 + * 1024, + ): + """ + The main quantization function that you can use to quantize your model. + + Example: + + ```python + from awq import AutoAWQForCausalLM + from transformers import AutoTokenizer + + model_path = "..." 
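+        # model_path can be a local MiniCPM3 checkpoint directory or a Hugging Face model id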
+ model = AutoAWQForCausalLM.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + model.quantize(tokenizer, quant_config) + ``` + """ + self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + + self.quantizer = CPM3AwqQuantizer( + self, + self.model, + tokenizer, + self.quant_config.w_bit, + self.quant_config.q_group_size, + self.quant_config.zero_point, + self.quant_config.version, + calib_data, + split, + text_column, + duo_scaling, + modules_to_not_convert=self.quant_config.modules_to_not_convert, + export_compatible=export_compatible, + apply_clip=apply_clip, + n_parallel_calib_samples=n_parallel_calib_samples, + max_calib_samples=max_calib_samples, + max_calib_seq_len=max_calib_seq_len, + max_chunk_memory=max_chunk_memory, + ) + self.quantizer.quantize() + + self.is_quantized = True + @staticmethod + def get_model_layers(model): + print(model.model.layers) + return model.model.layers + + @staticmethod + def get_act_for_scaling(module): + return dict(is_scalable=False) + + @staticmethod + def move_embed(model, device: str): + model.model.embed_tokens = model.model.embed_tokens.to(device) + + @staticmethod + def get_layers_for_scaling(module, input_feat, module_kwargs): + layers = [] + + # layers.append( + # dict( + # prev_op=module.input_layernorm, + # layers=[ + # module.self_attn.q_a_proj, + # ], + # inp=input_feat["self_attn.q_a_proj"], + # module2inspect=module.self_attn.q_a_proj, + # kwargs=module_kwargs, + # ) + # ) + # mlp + layers.append( + dict( + prev_op=module.self_attn.q_a_layernorm, + layers=[ + module.self_attn.q_b_proj, + + ], + inp=input_feat["self_attn.q_b_proj"], + module2inspect=module.self_attn.q_b_proj, + kwargs=module_kwargs, + ) + ) + + layers.append( + dict( + prev_op=module.self_attn.kv_a_layernorm, + layers=[ + module.self_attn.kv_b_proj, + ], + inp=input_feat["self_attn.kv_b_proj"], + module2inspect=module.self_attn.kv_b_proj, + kwargs=module_kwargs, + ) + ) + + + # linear 2 + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) + + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj,module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp + ) + ) + + return layers + + diff --git a/awq/models/minicpmv.py b/awq/models/minicpmv.py deleted file mode 100644 index bf12ebd7..00000000 --- a/awq/models/minicpmv.py +++ /dev/null @@ -1,406 +0,0 @@ -import tqdm -from typing import List, Tuple -from .base import BaseAWQForCausalLM -from awq.utils.fused_utils import fuse_qkv -from awq.modules.fused.block import LlamaLikeBlock -from awq.modules.fused.model import LlamaLikeModel -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer as OldLlamaDecoderLayer, -) -from transformers.models.qwen2.modeling_qwen2 import ( - Qwen2DecoderLayer as OldQwen2DecoderLayer, - Qwen2ForCausalLM as OldQwen2ForCausalLM, -) -from .base import ( - Annotated, - AwqConfig, - BaseAWQForCausalLM, - Dict, - Doc, - List, - PreTrainedTokenizer, - Union, -) - -from transformers.models.llava.modeling_llava import ( - LlavaForConditionalGeneration as OldLlavaForConditionalGeneration, -) -from awq.modules.fused.norm import FasterTransformerRMSNorm -import torch -from transformers 
import AutoProcessor -import json -from copy import deepcopy -from PIL import Image -from awq.modules.fused.attn import QuantAttentionFused -from torch import nn -from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device - -class CPMVAwqQuantizer(AwqQuantizer): - def init_quant(self, n_samples=None, max_seq_len=None): - modules = self.awq_model.get_model_layers(self.model) - samples = self.calib_data - - inps = [] - layer_kwargs = {} - - best_device = get_best_device() - modules[0] = modules[0].to(best_device) - self.awq_model.move_embed(self.model, best_device) - - # get input and kwargs to layer 0 - # with_kwargs is only supported in PyTorch 2.0 - # use this Catcher hack for now - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, *args, **kwargs): - # assume first input to forward is hidden states - if len(args) > 0: - hidden_states = args[0] - del args - else: - first_key = list(kwargs.keys())[0] - hidden_states = kwargs.pop(first_key) - - inps.append(hidden_states) - layer_kwargs.update(kwargs) - raise ValueError # early exit to break later inference - - def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): - def get_device(obj: torch.Tensor | nn.Module): - if isinstance(obj, torch.Tensor): - return obj.device - return next(obj.parameters()).device - - if get_device(obj) != device: - obj = obj.to(device) - return obj - - # patch layer 0 to catch input and kwargs - modules[0] = Catcher(modules[0]) - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, best_device) - try: - self.model(**samples) - except ValueError: # work with early exit - pass - finally: - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, "cpu") - modules[0] = modules[0].module # restore - - del samples - inps = inps[0] - - modules[0] = modules[0].cpu() - self.awq_model.move_embed(self.model, "cpu") - - clear_memory() - - return modules, layer_kwargs, inps - -class MiniCPMVAWQForCausalLM(BaseAWQForCausalLM): - layer_type = "Qwen2DecoderLayer" - max_seq_len_key = "max_position_embeddings" - - def chat( - self, - image, - msgs, - tokenizer, - processor=None, - vision_hidden_states=None, - max_new_tokens=2048, - min_new_tokens=0, - sampling=True, - max_inp_length=8192, - system_prompt='', - stream=False, - max_slice_nums=None, - use_image_id=None, - **kwargs - ): - if isinstance(msgs[0], list): - batched = True - else: - batched = False - msgs_list = msgs - images_list = image - - if batched is False: - images_list, msgs_list = [images_list], [msgs_list] - else: - assert images_list is None, "Please integrate image to msgs when using batch inference." - images_list = [None] * len(msgs_list) - assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." - - if processor is None: - if self.processor is None: - self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) - processor = self.processor - - assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." 
- assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - - prompts_lists = [] - input_images_lists = [] - for image, msgs in zip(images_list, msgs_list): - if isinstance(msgs, str): - msgs = json.loads(msgs) - copy_msgs = deepcopy(msgs) - - assert len(msgs) > 0, "msgs is empty" - assert sampling or not stream, "if use stream mode, make sure sampling=True" - - if image is not None and isinstance(copy_msgs[0]["content"], str): - copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] - - images = [] - for i, msg in enumerate(copy_msgs): - role = msg["role"] - content = msg["content"] - assert role in ["user", "assistant"] - if i == 0: - assert role == "user", "The role of first msg should be user" - if isinstance(content, str): - content = [content] - cur_msgs = [] - for c in content: - if isinstance(c, Image.Image): - images.append(c) - cur_msgs.append("(./)") - elif isinstance(c, str): - cur_msgs.append(c) - msg["content"] = "\n".join(cur_msgs) - - if system_prompt: - sys_msg = {'role': 'system', 'content': system_prompt} - copy_msgs = [sys_msg] + copy_msgs - - prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)) - input_images_lists.append(images) - - inputs = processor( - prompts_lists, - input_images_lists, - max_slice_nums=max_slice_nums, - use_image_id=use_image_id, - return_tensors="pt", - max_length=max_inp_length - ).to('cuda:0') - - if sampling: - generation_config = { - "top_p": 0.8, - "top_k": 100, - "temperature": 0.7, - "do_sample": True, - "repetition_penalty": 1.05 - } - else: - generation_config = { - "num_beams": 3, - "repetition_penalty": 1.2, - } - - if min_new_tokens > 0: - generation_config['min_new_tokens'] = min_new_tokens - - generation_config.update( - (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys() - ) - - inputs.pop("image_sizes") - with torch.inference_mode(): - res = self.generate( - **inputs, - tokenizer=tokenizer, - max_new_tokens=max_new_tokens, - vision_hidden_states=vision_hidden_states, - stream=stream, - decode_text=True, - **generation_config - ) - - if stream: - def stream_gen(): - for text in res: - for term in self.terminators: - text = text.replace(term, '') - yield text - return stream_gen() - - else: - if batched: - answer = res - else: - answer = res[0] - return answer - # @staticmethod - # def fuse_layers(model: OldQwen2ForCausalLM): - # fuser = MiniCPMVFuser(model) # 这里是算子融合 - # fuser.fuse_transformer() - - # hack to use `Qwen2VLAwqQuantizer` as quantizer - @torch.no_grad() - def quantize( - self, - tokenizer: Annotated[ - PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") - ] = None, - quant_config: Annotated[ - Dict, Doc("The quantization config you want to use.") - ] = {}, - calib_data: Annotated[ - Union[str, List[str]], - Doc( - "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples." 
- ), - ] = "pileval", - split: Annotated[str, Doc("The split of calib_data.")] = "train", - text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", - duo_scaling: Annotated[ - bool, Doc("Whether to scale using both w/x or just x.") - ] = True, - export_compatible: Annotated[ - bool, - Doc( - "This argument avoids real quantization by only applying the scales without quantizing down to FP16." - ), - ] = False, - apply_clip: Annotated[ - bool, - Doc( - "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." - ), - ] = True, - n_parallel_calib_samples: Annotated[ - int, - Doc( - "The number of parallel samples to run through the model. " - "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. " - "If None, runs through all samples at the same time. " - "You can set this to a low number for more memory efficient quantization." - ), - ] = None, - max_calib_samples: Annotated[ - int, Doc("The maximum number of samples to run through the model.") - ] = 128, - max_calib_seq_len: Annotated[ - int, - Doc( - "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." - ), - ] = 512, - max_chunk_memory: Annotated[ - int, - Doc( - "The loss computation and per-channel mean is optimized into chunked computations." - " Adjust this parameter to increase or decrease memory usage for these computations." - " Default is 1GB (1024 * 1024 * 1024)." - ), - ] = 1024 - * 1024 - * 1024, - ): - self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) - - if hasattr(self, "modules_to_not_convert"): - self.quant_config.modules_to_not_convert = self.modules_to_not_convert - - self.quantizer = CPMVAwqQuantizer( - self, - self.model, - tokenizer, - self.quant_config.w_bit, - self.quant_config.q_group_size, - self.quant_config.zero_point, - self.quant_config.version, - calib_data, - split, - text_column, - duo_scaling, - modules_to_not_convert=self.quant_config.modules_to_not_convert, - export_compatible=export_compatible, - apply_clip=apply_clip, - n_parallel_calib_samples=n_parallel_calib_samples, - max_calib_samples=max_calib_samples, - max_calib_seq_len=max_calib_seq_len, - max_chunk_memory=max_chunk_memory, - ) - self.quantizer.quantize() - - self.is_quantized = True - - @staticmethod - def get_model_layers(model: OldQwen2ForCausalLM): - return model.llm.model.layers - - @staticmethod - def get_act_for_scaling(module: OldQwen2DecoderLayer): - return dict(is_scalable=False) - - @staticmethod - def move_embed(model: OldQwen2DecoderLayer, device: str): - model.llm.model.embed_tokens = model.get_input_embeddings().to( - device - ) - - @staticmethod - def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs): - layers = [] - - # attention input - layers.append( - dict( - prev_op=module.input_layernorm, - layers=[ - module.self_attn.q_proj, - module.self_attn.k_proj, - module.self_attn.v_proj, - ], - inp=input_feat["self_attn.q_proj"], - module2inspect=module.self_attn, - kwargs=module_kwargs, - ) - ) - - # attention out - # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 - if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: - layers.append( - dict( - prev_op=module.self_attn.v_proj, - layers=[module.self_attn.o_proj], - inp=input_feat["self_attn.o_proj"], - ) - ) - - # linear 1 - layers.append( - dict( - prev_op=module.post_attention_layernorm, - 
layers=[module.mlp.gate_proj, module.mlp.up_proj], - inp=input_feat["mlp.gate_proj"], - module2inspect=module.mlp, - ) - ) - - # linear 2 - layers.append( - dict( - prev_op=module.mlp.up_proj, - layers=[module.mlp.down_proj], - inp=input_feat["mlp.down_proj"], - ) - ) - - return layers - diff --git a/examples/minicpmv2.6_quantize.py b/examples/minicpmv2.6_quantize.py deleted file mode 100644 index 6677d9c0..00000000 --- a/examples/minicpmv2.6_quantize.py +++ /dev/null @@ -1,718 +0,0 @@ -from __future__ import annotations - -import logging -from awq import AutoAWQForCausalLM -from transformers import AutoTokenizer -from torchvision import transforms -import argparse -import json -import copy -import math -import os -import re -import random -from dataclasses import dataclass, field -from typing import Dict, List, Optional -import requests -from io import BytesIO -import numpy as np -import torch -from PIL import Image -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import Dataset -def main(): -# print the loaded JSON list - - # define the argparse parser - parser = argparse.ArgumentParser(description="Quantize and save a model.") - - # add the arguments - parser.add_argument('--model-path', type=str, default="/root/ld/ld_model_pretrained/Minicpmv2_6", - help='Path to the model directory.') - parser.add_argument('--quant-path', type=str, default="/root/ld/ld_model_pretrained/Minicpmv2_6_awq_new", - help='Path to save the quantized model.') - parser.add_argument('--zero-point', action='store_true', - help='Enable zero point quantization.') - parser.add_argument('--q-group-size', type=int, default=128, - help='Quantization group size.') - parser.add_argument('--w-bit', type=int, default=4, - help='Weight bit size.') - parser.add_argument('--version', type=str, default="GEMM", - help='Quantization version.') - parser.add_argument('--batch-size', type=int, default=8, - help='Calibration forward batch size: a 4090 can typically run a batch of 8, an A100 a batch of 32.') - args = parser.parse_args() - - quant_config = {"zero_point": args.zero_point, "q_group_size": args.q_group_size, "w_bit": args.w_bit, "version": args.version} - batch = args.batch_size # calibration forward batch size: a 4090 can typically run a batch of 8, an A100 a batch of 32 - - # here, we use a caption dataset **only for demonstration**. You should replace it with your own sft dataset. - def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: - from datasets import load_dataset - dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]") - return [ - [ - {'id': str(index), - 'conversations': [{'content': '<image>\n能否协助我辨认这张图片展示的内容?', 'role': 'user'}, - {'content': sample["caption"], 'role': 'assistant'}], - 'image': sample["url"]}, - ] - for index, sample in enumerate(dataset) - ] - logger = logging.getLogger(__name__) - # load the calibration dataset - dataset = prepare_dataset(n_sample=32) - # Then you need to prepare your data for calibration. What you need to do is just put samples into a list, - # each of which is a typical chat message as shown below. 
you can specify text and image in `content` field: - # dataset = [ - # first_line - # {'id': '0', - # 'conversations': [{'content': '<image>\n能否协助我辨认这张图片展示的内容?', 'role': 'user'}, - # {'content': '2017 中共文山市委执政纪要\n政策向符合条件的职业农民倾斜.....', 'role': 'assistant'}], - # 'image': '/root/ld/ld_dataset/30k_data/63677831/198.jpg'}, - # second_line - # {'id': '1', - # 'conversations': [{'content': '<image>\n我需要帮忙确定这张图片里呈现的是什么。', 'role': 'user'}, - # {'content': '大学计算机是国家最重要的课程,我们需要从娃娃抓起....', 'role': 'assistant'}], - # 'image': '/root/ld/ld_dataset/30k_data/61520120/3.jpg'} - # ] - - ## The following is a script for data processing, no need to modify it - llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}" - - class SupervisedDataset(Dataset): - """Dataset for supervised fine-tuning.""" - - def __init__( - self, - raw_data, - transform, - tokenizer, - slice_config, - llm_type="minicpm", - patch_size=14, - query_nums=64, - batch_vision=False, - max_length=2048, - ): - super(SupervisedDataset, self).__init__() - self.raw_data = raw_data - self.tokenizer = tokenizer - self.transform = transform - self.slice_config = slice_config - self.llm_type = llm_type - self.patch_size = patch_size - self.query_nums = query_nums - self.batch_vision = batch_vision - self.max_length = max_length - - def __len__(self): - return len(self.raw_data) - - def __getitem__(self, i) -> Dict[str, torch.Tensor]: - try: - if isinstance(self.raw_data[i]["image"], str): - if self.raw_data[i]["image"].startswith("http"): - yzmdata = requests.get(self.raw_data[i]["image"]) - tempIm = BytesIO(yzmdata.content) - image = Image.open(tempIm).convert('RGB') - images_dict = { "<image>" : image } - else: - images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") } - elif isinstance(self.raw_data[i]["image"], Dict): - ### for multi-image input, the placeholder for every image is <image_xx>, such as <image_00>, <image_01> - images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()} - - ret = preprocess( - images_dict, - self.raw_data[i]["conversations"], - self.tokenizer, - self.transform, - query_nums=self.query_nums, - slice_config=self.slice_config, - llm_type=self.llm_type, - patch_size=self.patch_size, - batch_vision=self.batch_vision, - max_length=self.max_length - ) - ret = dict( - input_ids=ret["input_ids"], - position_ids=ret["position_ids"], - labels=ret["target"], - attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool), - pixel_values=ret["pixel_values"], - tgt_sizes=ret["tgt_sizes"], - image_bound=ret["image_bound"], - ) - except Exception as e: - logger.error(f"data fetch error: {e}") - return self.__getitem__(random.randint(0, len(self) - 1)) - return ret - - - def data_collator(examples, padding_value=0, max_length=2048): - def trim_and_pad(seq, batch_first, padding_value): - return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value) - - input_ids = trim_and_pad( - [example["input_ids"] for example in examples], - batch_first=True, - padding_value=padding_value, - ) - position_ids = trim_and_pad( - [example["position_ids"] for example in examples], - batch_first=True, - padding_value=padding_value, - ) - targets = trim_and_pad( - [example["labels"] for example in examples], - batch_first=True, - padding_value=-100, 
- ) - attention_mask = trim_and_pad( - [example["attention_mask"] for example in examples], - batch_first=True, - padding_value=padding_value, - ) - pixel_values = [example["pixel_values"] for example in examples] - image_bound = [example["image_bound"] for example in examples] - tgt_sizes = [example["tgt_sizes"] for example in examples] - return { - "input_ids": torch.tensor(input_ids), - "position_ids": torch.tensor(position_ids), - "labels": torch.tensor(targets), - "attention_mask": torch.tensor(attention_mask), - "image_bound": image_bound, - "tgt_sizes": tgt_sizes, - "pixel_values": pixel_values, - } - - - def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048): - """ - for single image multi-turn conversation - conversation: [{'role': 'user', 'content': 'Describe this image'}, - {'role': 'assistant', 'content': 'This is a cat.'}] - """ - if llm_type == "llama3": - input_ids, context, raw_msg = conversation_to_ids_llama3( - conversation, tokenizer - ) - elif llm_type == "qwen2": - input_ids, context, raw_msg = conversation_to_ids_qwen2( - conversation, tokenizer - ) - else: - input_ids, context, raw_msg = conversation_to_ids_minicpm( - conversation, tokenizer - ) - - ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32)) - context = torch.from_numpy(np.hstack(context, dtype=np.int8)) - if input_ids.shape[-1] > max_length: - ids =ids[:max_length] - context = context[:max_length] - logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated") - - if torch.all(context): - logger.error("No tokens available to compute loss.") - raise Exception("No tokens available to compute loss.") - - # build target - target = torch.full_like(ids, -100, dtype=torch.int32) - - for i in range(1, len(ids)): - if context[i] == 0: - target[i - 1] = ids[i] - if context[i] == 1 and context[i - 1] == 0: - if hasattr(tokenizer, "eot_id"): - target[i - 1] = tokenizer.eot_id - else: - target[i - 1] = tokenizer.eos_id - - # build image bound - if new_schema: - start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id) - end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id) - image_start_tokens = torch.where(start_cond)[0] - image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] - else: - image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0] - image_start_tokens += 1 - image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0] - if len(image_start_tokens) != len(image_end_tokens): - logger.error("image start token != image end tokens") - raise Exception("image start token != image end tokens") - - if len(image_start_tokens) > 0: - image_bound = torch.hstack( - [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)] - ) - else: - image_bound = [] - - position_ids = torch.arange(ids.size(0)).long() - return { - "input_ids": ids, - "target": target, - "image_bound": image_bound, - "raw_msg": raw_msg, - "position_ids": position_ids - } - - - def conversation_to_ids_minicpm(conversation, tokenizer): - raw_msg = "" - input_ids = [] - context = [] - for idx, msg in enumerate(conversation): - role = msg["role"] - message = msg["content"] - assert role in ["user", "assistant"] - if role == "user": - prefix = "<用户>" - else: - prefix = "" - # append eos - if idx == len(conversation) - 1: - message = message + tokenizer.eos_token - prefix_ids = tokenizer.encode(prefix)[1:] # remove bos - message_ids = 
tokenizer.encode(message)[1:] - - input_ids.append(prefix_ids) - input_ids.append(message_ids) - - context.append(np.ones((len(prefix_ids),), dtype=np.int8)) - if role == "assistant": - context.append(np.zeros((len(message_ids),), dtype=np.int8)) - else: - context.append(np.ones((len(message_ids),), dtype=np.int8)) - - raw_msg += prefix + message - - return input_ids, context, raw_msg - - - def conversation_to_ids_llama3(conversation, tokenizer): - raw_msg = "" - input_ids = [] - context = [] - raw_msg = tokenizer.apply_chat_template( - conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template, - ) - input_ids = tokenizer.apply_chat_template( - conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template, - ) - input_ids = np.array(input_ids) - - start_header_idxs = np.where( - input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>") - )[0] - assistant_idxs = np.where( - input_ids == tokenizer.convert_tokens_to_ids("assistant") - )[0] - end_header_idxs = np.where( - input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>") - )[0] - eot_idxs = np.where( - input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0] - - context = np.ones_like(input_ids, dtype=np.int8) - - for assistant_idx in assistant_idxs: - if assistant_idx in set((start_header_idxs + end_header_idxs) / 2): - st = assistant_idx + 3 # assistant<|end_header_id|>\n\n - for eot_idx in eot_idxs: - if eot_idx > st: - context[st: eot_idx + 1] = 0 - break - - input_ids = np.hstack(input_ids) - context = np.hstack(context) - - return input_ids, context, raw_msg - - - def conversation_to_ids_qwen2(conversation, tokenizer): - raw_msg = "" - chat = [] - context = [] - for idx, msg in enumerate(conversation): - role = msg["role"] - message = msg["content"] - assert role in ["user", "assistant"] - if role == "user": - prefix = "user" - else: - prefix = "assistant" - chat.append({"role":prefix, "content":message}) - raw_msg += prefix + message - assert set([i['role'] for i in chat]) & set(['assistant']) - - ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False) - input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False) - input_ids = np.array(input_ids) - - start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0] - assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0] - end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0] - - context = np.ones_like(input_ids, dtype=np.int8) - - for assistant_idx in assistant_idxs: - if assistant_idx-1 in set(start_idxs): - st = assistant_idx + 1 - for end_idx in end_idxs: - if end_idx > st: - context[st: end_idx + 1] = 0 - break - - input_ids = np.hstack(input_ids) - context = np.hstack(context) - return input_ids, context, raw_msg - - - def preprocess( - images_dict, - conversations, - tokenizer, - transform, - query_nums=64, - slice_config=None, - llm_type=None, - patch_size=14, - batch_vision=False, - max_length=2048, - ): - """ - single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation - """ - conversations = copy.deepcopy(conversations) - assert len(conversations) > 1, "conversations length must large than 2" - assert conversations[0]["role"] == "user", "the first role must be user" - - if slice_config is not None: - assert isinstance(slice_config, Dict) - assert "patch_size" in slice_config - assert "max_slice_nums" 
in slice_config - assert "scale_resolution" in slice_config - default_image_placeholder = ( - tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end - ) - new_schema = False - use_image_id = False - if llm_type=='qwen2': - new_schema = True - use_image_id = True - image_placeholder_dict = {} - images = [] - image_id_cnt = 0 - for img_name, image in images_dict.items(): - if slice_config: - source_image, patches, best_grid = slice_image( - image, - slice_config["max_slice_nums"], - slice_config["scale_resolution"], - slice_config["patch_size"], - ) - images.append(source_image) - image_placeholder = default_image_placeholder - if len(patches) > 0: - for i in range(len(patches)): - for j in range(len(patches[0])): - images.append(patches[i][j]) - if use_image_id: - image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder - image_id_cnt += 1 - image_placeholder += get_grid_placeholder( - tokenizer, best_grid, query_nums, new_schema = new_schema) - image_placeholder_dict[img_name] = image_placeholder - else: - images.append(image) - if use_image_id: - image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + default_image_placeholder - image_id_cnt += 1 - else: - image_placeholder = default_image_placeholder - image_placeholder_dict[img_name] = image_placeholder - - images = [transform(i) for i in images] - - if len(images_dict) == 1 and "<image>" in images_dict: - if "<image>" in conversations[0]["content"]: - conversations[0]["content"] = conversations[0]["content"].replace( - "<image>", image_placeholder - ) - else: - conversations[0]["content"] = ( - image_placeholder + "\n" + conversations[0]["content"] - ) - input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length) - else: - pattern = r'<image_\d+>' - new_conversations = [] - for conversation in conversations: - content = conversation['content'] - parts = re.split(f'({pattern})', content) - for i, part in enumerate(parts): - if not part.strip(): - continue - if re.match(pattern, part): - if part in image_placeholder_dict: - parts[i] = image_placeholder_dict[part] - else: - raise Exception(f"not found {part} in image dict") - conversation['content'] = '\n'.join(parts) - new_conversations.append(conversation) - conversations = new_conversations - - input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length) - - if batch_vision: - tgt_sizes = [] - reshape_images = [] - for image in images: - H, W = image.shape[1:] - reshape_image = reshape_by_patch(image, patch_size) - reshape_images.append(reshape_image) - tgt_sizes.append([H // patch_size, W // patch_size]) - if tgt_sizes: - tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32) - - input_dict["pixel_values"] = reshape_images - input_dict["tgt_sizes"] = tgt_sizes - - else: - input_dict["pixel_values"] = images - input_dict["tgt_sizes"] = [] - - return input_dict - - - def slice_image( - image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False - ): - original_size = image.size - original_width, original_height = original_size - log_ratio = math.log(original_width / original_height) - ratio = original_width * original_height / \ - (scale_resolution * scale_resolution) - multiple = min(math.ceil(ratio), max_slice_nums) - - source_image = None - best_grid = None - patches = [] - - if multiple <= 1 or never_split: - # don't need to slice, upsample - best_size = find_best_resize( - original_size, scale_resolution, patch_size, allow_upscale=True - ) - 
source_image = image.resize(best_size, Image.Resampling.BICUBIC) - else: - candidate_split_grids_nums = [] - for i in [multiple - 1, multiple, multiple + 1]: - if i == 1 or i > max_slice_nums: - continue - candidate_split_grids_nums.append(i) - - # source image, down-sampling and ensure divided by patch_size - best_resize = find_best_resize( - original_size, scale_resolution, patch_size) - source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - candidate_grids = [] - - # find best grid - for split_grids_nums in candidate_split_grids_nums: - m = 1 - while m <= split_grids_nums: - if split_grids_nums % m == 0: - candidate_grids.append([m, split_grids_nums // m]) - m += 1 - - best_grid = [1, 1] - min_error = float("inf") - for grid in candidate_grids: - error = abs(log_ratio - math.log(grid[0] / grid[1])) - if error < min_error: - best_grid = grid - min_error = error - - refine_size = get_refine_size( - original_size, best_grid, scale_resolution, patch_size, allow_upscale=True - ) - - refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) - patches = split_to_patches(refine_image, best_grid) - - return source_image, patches, best_grid - - - def ensure_divide(length, patch_size): - return max(round(length / patch_size) * patch_size, patch_size) - - - def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): - width, height = original_size - if (width * height > scale_resolution * scale_resolution) or allow_upscale: - r = width / height - height = int(scale_resolution / math.sqrt(r)) - width = int(height * r) - best_width = ensure_divide(width, patch_size) - best_height = ensure_divide(height, patch_size) - return (best_width, best_height) - - - def get_refine_size( - original_size, grid, scale_resolution, patch_size, allow_upscale=False - ): - width, height = original_size - grid_x, grid_y = grid - - refine_width = ensure_divide(width, grid_x) - refine_height = ensure_divide(height, grid_y) - - grid_width = refine_width / grid_x - grid_height = refine_height / grid_y - - best_grid_size = find_best_resize( - (grid_width, grid_height), - scale_resolution, - patch_size, - allow_upscale=allow_upscale, - ) - - refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) - - return refine_size - - - def split_to_patches(image, grid): - patches = [] - width, height = image.size - grid_x = int(width / grid[0]) - grid_y = int(height / grid[1]) - - for i in range(0, height, grid_y): - images = [] - for j in range(0, width, grid_x): - box = (j, i, j + grid_x, i + grid_y) - patch = image.crop(box) - images.append(patch) - patches.append(images) - - return patches - - - def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): - if new_schema: - image_placeholder = ( - tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end - ) - else: - image_placeholder = ( - tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end - ) - - cols = grid[0] - rows = grid[1] - slices = [] - for i in range(rows): - lines = [] - for j in range(cols): - lines.append(image_placeholder) - slices.append("".join(lines)) - if new_schema: - slice_placeholder = '\n'.join(slices) - else: - slice_placeholder = tokenizer.slice_start + \ - "\n".join(slices) + tokenizer.slice_end - return slice_placeholder - - - def reshape_by_patch(image_tensor, patch_size): - """ - :param image_tensor: shape [3, H, W] - :param patch_size: - :return: [3, patch_size, HW/patch_size] - """ - patches = torch.nn.functional.unfold( - 
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size) - ) - - patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) - patches = patches.permute(0, 1, 3, 2).reshape( - image_tensor.size(0), patch_size, -1) - patches=patches.cuda() - return patches - # Load your tokenizer and model with AutoAWQ - model = AutoAWQForCausalLM.from_pretrained(args.model_path).cuda() - tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True,device_map={"": "cuda:0"}) - - # set some parameters - if hasattr(model.config, "slice_config"): - model.config.slice_config.max_slice_nums = 1 - slice_config = model.config.slice_config.to_dict() - else: - model.config.max_slice_nums = 1 - slice_config = model.config.to_dict() - - if hasattr(model.config, "batch_vision_input"): - batch_vision = model.config.batch_vision_input - else: - batch_vision = False - def build_transform(): - IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN - IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize( - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD - ), - ] - ) - transform_func = build_transform() - - - - - calib_data = SupervisedDataset( - dataset, - transform_func, - tokenizer, - slice_config=slice_config, - llm_type="qwen2", - patch_size=model.config.patch_size, - query_nums=model.config.query_num, - batch_vision=batch_vision, - max_length=2048, - ) - - - out_data=[] - batch_data=[] - for index in range(len(calib_data)//batch): - batch_data=[] - for j in range(batch): - batch_data.append(calib_data[j+index*batch]) - out_data.append(data_collator(batch_data)) - - - # Then just run the calibration process by one line of code: - model.quantize(calib_data=out_data[0], quant_config=quant_config) - - # remove pos_embed - if hasattr(model.model, 'resampler') and hasattr(model.model.resampler, 'pos_embed'): - del model.model.resampler.pos_embed - # Finally, save the quantized model: - model.model.config.use_cache = model.model.generation_config.use_cache = True - model.save_quantized(args.quant_path, safetensors=True, shard_size="4GB") - -if __name__ == "__main__": - main() - -
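After `save_quantized`, the produced checkpoint can be loaded back with AutoAWQ for inference. The lines below are a minimal sketch only, not part of the patch: they assume AutoAWQ's `from_quantized` entry point works for this model type and reuse the `--quant-path` default from the script above; actual generation would go through MiniCPM-V 2.6's own chat interface and is not shown here.

    # minimal sketch (assumptions: from_quantized supports this checkpoint, path matches --quant-path above)
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "/root/ld/ld_model_pretrained/Minicpmv2_6_awq_new"
    model = AutoAWQForCausalLM.from_quantized(quant_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)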