diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 14886a24..9b9dd73e 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -14,3 +14,4 @@
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .bert import BertAWQModel
\ No newline at end of file
diff --git a/awq/models/auto.py b/awq/models/auto.py
index c060b47f..71fcd409 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -21,7 +21,8 @@
     "qwen": QwenAWQForCausalLM,
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
-    "qwen2": Qwen2AWQForCausalLM
+    "qwen2": Qwen2AWQForCausalLM,
+    "bert": BertAWQModel,
 }
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 36e86e0a..132b1168 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -41,6 +41,7 @@
 # Since we support different `AutoModelForxxx` from transformers
 # we need to define a custom mapping dict as below:
 TRANSFORMERS_AUTO_MAPPING_DICT = {
+    "bert": "AutoModel",
     "mpt": "AutoModelForCausalLM",
     "llama": "AutoModelForCausalLM",
     "opt": "AutoModelForCausalLM",
@@ -152,7 +153,8 @@ def forward(self, x):
 
         # Save model and config files with empty state dict
         self.model.config.quantization_config = self.quant_config.to_transformers_dict()
-        self.model.generation_config.do_sample = True
+        if self.model.generation_config is not None:
+            self.model.generation_config.do_sample = True
         self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
         self.quant_config.save_pretrained(save_dir)
 
@@ -441,3 +443,10 @@
         # scale activation
         scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
         set_op_by_name(layer, scale_dict["scale_name"], scaled_act)
+
+class BaseBetterTransformerAWQModel(BaseAWQForCausalLM):
+    def __init__(
+        self, model, model_type, is_quantized, config, quant_config, processor
+    ):
+        super().__init__(model, model_type, is_quantized, config, quant_config, processor)
+        self.model.to_bettertransformer()
\ No newline at end of file
diff --git a/awq/models/bert.py b/awq/models/bert.py
new file mode 100644
index 00000000..4f3af085
--- /dev/null
+++ b/awq/models/bert.py
@@ -0,0 +1,67 @@
+from .base import BaseAWQForCausalLM
+from transformers.models.bert.modeling_bert import BertModel, BertLayer
+
+class BertAWQModel(BaseAWQForCausalLM):
+    layer_type = "BertLayer"
+    max_new_tokens_key = "max_position_embeddings"
+
+    @staticmethod
+    def get_model_layers(model: BertModel):
+        def prepare_inputs_for_generation(input_ids, **kwargs):
+            return {"input_ids": input_ids, **kwargs}
+        model.prepare_inputs_for_generation = prepare_inputs_for_generation
+        return model.encoder.layer
+
+    @staticmethod
+    def get_act_for_scaling(module: BertLayer):
+        return dict(
+            is_scalable=True,
+            scale_name="intermediate.intermediate_act_fn",
+            scale_layer=module.intermediate.intermediate_act_fn,
+            scale_shape=module.intermediate.dense.out_features
+        )
+
+
+    def move_embed(self, model: BertModel, device: str):
+        model.embeddings = model.embeddings.to(device)
+
+        if model.pooler is not None:
+            model.pooler.dense = model.pooler.dense.to(device)
+
+    def get_layers_for_scaling(self, module: BertLayer, input_feat, module_kwargs):
+        layers = []
+
+        # module.attention
+        # TODO: Handle NoOp. No previous LayerNorm/Linear in module.attention like in other models.
+        # layers.append(dict(
+        #     prev_op=module.identity,
+        #     layers=[module.attention.self.query,
+        #             module.attention.self.key, module.attention.self.value],
+        #     inp=input_feat['attention.self.query'],
+        #     module2inspect=module.attention, kwargs=module_kwargs,
+        # ))
+
+        # attention out
+        layers.append(dict(
+            prev_op=module.attention.self.value,
+            layers=[module.attention.output.dense],
+            inp=input_feat['attention.output.dense'],
+        ))
+
+        # # linear 2
+        # layers.append(dict(
+        #     prev_op=module.intermediate.intermediate_act_fn,
+        #     layers=[module.output.dense],
+        #     inp=input_feat['intermediate.dense'],
+        # ))
+
+        # # linear 1
+        # layers.append(dict(
+        #     prev_op=module.attention.output.dropout,
+        #     layers=[module.intermediate.dense],
+        #     inp=input_feat['attention.output.dropout'],
+        # ))
+
+        return layers
+
+
\ No newline at end of file
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index 47c061f4..bb33bfba 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -73,6 +73,41 @@ def forward(
 
         return out, None, past_key_value
 
+class BertBlock(nn.Module):
+    """Fused BERT-style encoder block: fused QKV attention followed by the
+    original MLP module, with a residual connection around each sub-block."""
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
+        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta=10000, use_alibi=False
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
+        ).to(dev)
+        self.norm_2 = norm_2.to(dev)
+        self.mlp = mlp.to(dev)
+        self.device = dev
+
+    def forward(
+        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out = h + self.mlp.forward(self.norm_2(h))
+
+        return out, None, past_key_value
+
 class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index 2e604547..a822b181 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -253,7 +253,7 @@ def _search_best_scale(
         # [STEP 3]: Compute output of module
         with torch.no_grad():
             module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
-
+            module2inspect.to(inp.device)
             fp16_output = module2inspect(inp, **module_kwargs)
             if isinstance(fp16_output, tuple):
                 fp16_output = fp16_output[0]
diff --git a/awq/utils/calib_data.py b/awq/utils/calib_data.py
index cc589a34..440f0845 100644
--- a/awq/utils/calib_data.py
+++ b/awq/utils/calib_data.py
@@ -38,7 +38,7 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
         else:
             line = data[text_column]
         line = line.strip()
-        line_encoded = tokenizer.encode(line)
+        line_encoded = tokenizer.encode(line, truncation=True)
         if len(line_encoded) > 512:
             continue
         sample = torch.tensor([line_encoded])
diff --git a/tests/compare_bert.py b/tests/compare_bert.py
new file mode 100644
index 00000000..671f267e
--- /dev/null
+++ b/tests/compare_bert.py
@@ -0,0 +1,34 @@
+from transformers import AutoTokenizer, AutoModel
+import torch, time
+
+model_id = "bge-small-en-v1.5-quant"  # directory written by tests/test_bert_quant.py
"/home/michael/embeddings/bge-small-en-v1.5-quant" + +modelawq = AutoModel.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda:0") +model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda:0") +tokenizer = AutoTokenizer.from_pretrained(model_id) + +def get_tks(s="The quick brown fox jumps over the lazy dog"): + input_text = [s * 100] * 64 + inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True) + for k, v in inputs.items(): + inputs[k] = v.to(modelawq.device) + return inputs + +def time_model(m, input_ids, printit=False): + with torch.inference_mode(): + start = time.time() + out = m(**input_ids) + end = time.time() + if printit: + print(f"Time: {end - start} {m.config._name_or_path}") + return out + + +warmup = get_tks("This is a long warmup string to warm up the model. Again.") +benchmark = get_tks("This is a long benchmark string to benchmark the model.") +time_model(model, warmup, printit=False) +out_m = time_model(model, benchmark, printit=True) +time_model(modelawq, warmup, printit=False) +out_awq = time_model(modelawq, benchmark, printit=True) + +print(out_m.last_hidden_state - out_awq.last_hidden_state) diff --git a/tests/test_bert_quant.py b/tests/test_bert_quant.py new file mode 100644 index 00000000..b0ce2a77 --- /dev/null +++ b/tests/test_bert_quant.py @@ -0,0 +1,17 @@ +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'BAAI/bge-small-en-v1.5' +quant_path = 'bge-small-en-v1.5-quant' +quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, truncation=True, padding=True, max_length=512, return_tensors="pt") + +# Quantize +model.quantize(tokenizer, quant_config=quant_config, ) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) \ No newline at end of file