add qwen2vl support #599

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions awq/models/base.py
@@ -75,6 +75,7 @@
"baichuan": "AutoModelForCausalLM",
"llava": "AutoModelForVision2Seq",
"qwen2": "AutoModelForCausalLM",
"qwen2_vl": "AutoModelForVision2Seq",
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
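For context, the value added here names the transformers Auto class used to instantiate checkpoints whose `config.model_type` is `qwen2_vl`. A minimal, hypothetical sketch of how such a string-keyed mapping is typically resolved at load time (the dict excerpt and loader call below are illustrative assumptions, not code from this PR):

import transformers

# Illustrative excerpt of the mapping edited above (assumption: not the full dict from base.py).
AUTO_MAPPING = {"qwen2": "AutoModelForCausalLM", "qwen2_vl": "AutoModelForVision2Seq"}

def resolve_auto_class(model_type: str):
    # Look up the class name by model type and fetch the class object from transformers.
    return getattr(transformers, AUTO_MAPPING[model_type])

auto_cls = resolve_auto_class("qwen2_vl")  # -> transformers.AutoModelForVision2Seq
# A loader would then call, e.g., auto_cls.from_pretrained(model_path, ...) to build the HF model.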
255 changes: 255 additions & 0 deletions awq/models/qwen2vl.py
@@ -0,0 +1,255 @@
"""hack to use qwen2vl model with awq"""
import torch
from torch import nn
from typing_extensions import TYPE_CHECKING

from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device
from .base import (
Annotated,
AwqConfig,
BaseAWQForCausalLM,
Dict,
Doc,
List,
PreTrainedTokenizer,
Union,
)


if TYPE_CHECKING:
from transformers import Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer



# hack to
# 1. use `self.calib_data` as processed input data
# 2. set the `layer_kwargs` and `inps` correctly
class Qwen2VLAwqQuantizer(AwqQuantizer):
def init_quant(self, n_samples=None, max_seq_len=None):
modules = self.awq_model.get_model_layers(self.model)
samples = self.calib_data

inps = []
layer_kwargs = {}

best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)

        # get the inputs and kwargs for layer 0
        # (forward hooks with `with_kwargs` are only available in PyTorch >= 2.0,
        # so use this Catcher hack for now; a hook-based alternative is sketched after this file)
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args, **kwargs):
# assume first input to forward is hidden states
if len(args) > 0:
hidden_states = args[0]
del args
else:
first_key = list(kwargs.keys())[0]
hidden_states = kwargs.pop(first_key)

inps.append(hidden_states)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference

def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device):
def get_device(obj: torch.Tensor | nn.Module):
if isinstance(obj, torch.Tensor):
return obj.device
return next(obj.parameters()).device

if get_device(obj) != device:
obj = obj.to(device)
return obj

# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])
for k, v in samples.items():
if isinstance(v, (torch.Tensor, nn.Module)):
samples[k] = move_to_device(v, best_device)
try:
self.model(**samples)
except ValueError: # work with early exit
pass
finally:
for k, v in samples.items():
if isinstance(v, (torch.Tensor, nn.Module)):
samples[k] = move_to_device(v, "cpu")
modules[0] = modules[0].module # restore

del samples
inps = inps[0]

modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")

clear_memory()

return modules, layer_kwargs, inps


class Qwen2VLAWQForConditionalGeneration(BaseAWQForCausalLM):
layer_type = "Qwen2VLDecoderLayer"
max_seq_len_key = "max_position_embeddings"
modules_to_not_convert = ["visual"]

@staticmethod
def get_model_layers(model: "Qwen2VLForConditionalGeneration"):
return model.model.layers

@staticmethod
    def get_act_for_scaling(module: "Qwen2VLDecoderLayer"):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: "Qwen2VLForConditionalGeneration", device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.visual = model.visual.to(device)

@staticmethod
def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs):
layers = []

# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers

# hack to use `Qwen2VLAwqQuantizer` as quantizer
@torch.no_grad()
def quantize(
self,
tokenizer: Annotated[
PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
] = None,
quant_config: Annotated[
Dict, Doc("The quantization config you want to use.")
] = {},
calib_data: Annotated[
Union[str, List[str]],
Doc(
"The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
),
] = "pileval",
split: Annotated[str, Doc("The split of calib_data.")] = "train",
text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
duo_scaling: Annotated[
bool, Doc("Whether to scale using both w/x or just x.")
] = True,
export_compatible: Annotated[
bool,
Doc(
"This argument avoids real quantization by only applying the scales without quantizing down to FP16."
),
] = False,
apply_clip: Annotated[
bool,
Doc(
"Whether to apply clipping to the model during quantization. Some models may perform better with this set to False."
),
] = True,
n_parallel_calib_samples: Annotated[
int,
Doc(
"The number of parallel samples to run through the model. "
"A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. "
"If None, runs through all samples at the same time. "
"You can set this to a low number for more memory efficient quantization."
),
] = None,
max_calib_samples: Annotated[
int, Doc("The maximum number of samples to run through the model.")
] = 128,
max_calib_seq_len: Annotated[
int,
Doc(
"The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len."
),
] = 512,
max_chunk_memory: Annotated[
int,
Doc(
"The loss computation and per-channel mean is optimized into chunked computations."
" Adjust this parameter to increase or decrease memory usage for these computations."
" Default is 1GB (1024 * 1024 * 1024)."
),
] = 1024
* 1024
* 1024,
):
self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

if hasattr(self, "modules_to_not_convert"):
self.quant_config.modules_to_not_convert = self.modules_to_not_convert

self.quantizer = Qwen2VLAwqQuantizer(
self,
self.model,
tokenizer,
self.quant_config.w_bit,
self.quant_config.q_group_size,
self.quant_config.zero_point,
self.quant_config.version,
calib_data,
split,
text_column,
duo_scaling,
modules_to_not_convert=self.quant_config.modules_to_not_convert,
export_compatible=export_compatible,
apply_clip=apply_clip,
n_parallel_calib_samples=n_parallel_calib_samples,
max_calib_samples=max_calib_samples,
max_calib_seq_len=max_calib_seq_len,
max_chunk_memory=max_chunk_memory,
)
self.quantizer.quantize()

self.is_quantized = True
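A note on the Catcher wrapper in `init_quant` above: it exists because forward hooks with `with_kwargs=True` only arrived in PyTorch 2.0. For readers on newer PyTorch, here is a hedged sketch of the hook-based equivalent (illustrative only, not part of this diff; the function and variable names are made up, and it returns the captured values instead of mutating enclosing state):

from torch import nn

def capture_layer0_inputs(layer0: nn.Module, model: nn.Module, samples: dict):
    """Capture the hidden states and kwargs passed to the first decoder layer."""
    inps, layer_kwargs = [], {}

    def pre_hook(module, args, kwargs):
        # Assume the first positional arg (or first kwarg) is the hidden states, as in the Catcher.
        hidden_states = args[0] if args else kwargs.pop(next(iter(kwargs)))
        inps.append(hidden_states)
        layer_kwargs.update(kwargs)
        raise ValueError  # early exit: stop the rest of the forward pass

    handle = layer0.register_forward_pre_hook(pre_hook, with_kwargs=True)
    try:
        model(**samples)
    except ValueError:  # expected early exit
        pass
    finally:
        handle.remove()
    return inps[0], layer_kwargs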
87 changes: 87 additions & 0 deletions examples/quantize_qwen2vl.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import logging

from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLProcessor

from awq.models.qwen2vl import Qwen2VLAWQForConditionalGeneration


logging.basicConfig(
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)

# Specify paths and hyperparameters for quantization
model_path = "your_model_path"
quant_path = "your_quantized_model_path"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load your processor and model with AutoAWQ
processor = Qwen2VLProcessor.from_pretrained(model_path)
model = Qwen2VLAWQForConditionalGeneration.from_pretrained(
model_path, model_type="qwen2_vl", use_cache=False, attn_implementation="flash_attention_2"
)


# Next, prepare your data for calibration: just put the samples into a list,
# each of which is a typical chat message as shown below. You can specify text and images in the `content` field:
# dataset = [
# # message 0
# [
# {"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": "Tell me who you are."},
# {"role": "assistant", "content": "I am a large language model named Qwen..."},
# ],
# # message 1
# [
# {
# "role": "user",
# "content": [
# {"type": "image", "image": "file:///path/to/your/image.jpg"},
# {"type": "text", "text": "Output all text in the image"},
# ],
# },
# {"role": "assistant", "content": "The text in the image is balabala..."},
# ],
# # other messages...
# ...,
# ]
# Here, we use a caption dataset **only for demonstration**. You should replace it with your own SFT dataset.
def prepare_dataset(n_sample: int = 8) -> list[list[dict]]:
from datasets import load_dataset

dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]")
return [
[
{
"role": "user",
"content": [
{"type": "image", "image": sample["url"]},
{"type": "text", "text": "generate a caption for this image"},
],
},
{"role": "assistant", "content": sample["caption"]},
]
for sample in dataset
]


dataset = prepare_dataset()

# process the dataset into tensors
text = processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(dataset)
inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

# Then just run the calibration process with a single line of code:
model.quantize(calib_data=inputs, quant_config=quant_config)

# Finally, save the quantized model:
model.model.config.use_cache = model.model.generation_config.use_cache = True
model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
processor.save_pretrained(quant_path)

# You now have your own AWQ-quantized model ready for deployment. Enjoy!
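As a closing illustration (not part of this PR), a checkpoint saved this way can typically be reloaded for inference with the stock transformers classes, since transformers dispatches AWQ checkpoints to the AutoAWQ kernels when the autoawq package is installed. Whether your installed versions support this combination is an assumption, as are the paths, prompt, and generation settings below:

import torch
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

quant_path = "your_quantized_model_path"
processor = Qwen2VLProcessor.from_pretrained(quant_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    quant_path, torch_dtype=torch.float16, device_map="auto"
)

# A single-image chat request in the same messages format used for calibration above.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/your/image.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt"
).to(model.device)

# Generate and decode only the newly produced tokens.
output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0])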