Adding bert - WIP #328

Open
wants to merge 1 commit into base: main

1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -14,3 +14,4 @@
from .llava import LlavaAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
from .qwen2 import Qwen2AWQForCausalLM
from .bert import BertAWQModel
3 changes: 2 additions & 1 deletion awq/models/auto.py
@@ -21,7 +21,8 @@
"qwen": QwenAWQForCausalLM,
"baichuan": BaichuanAWQForCausalLM,
"llava": LlavaAWQForCausalLM,
"qwen2": Qwen2AWQForCausalLM
"qwen2": Qwen2AWQForCausalLM,
"bert": BertAWQModel,
}


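Note on the new mapping entry: dispatch keys off the checkpoint's `config.model_type`, so a BERT-family model such as `BAAI/bge-small-en-v1.5` (whose `model_type` is `"bert"`) should now resolve to `BertAWQModel`. A minimal sketch of the intended entry point, mirroring `tests/test_bert_quant.py` below:

# Sketch only: assumes the "bert" -> BertAWQModel entry above is picked up
# when the auto class reads config.model_type.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_pretrained("BAAI/bge-small-en-v1.5")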
11 changes: 10 additions & 1 deletion awq/models/base.py
@@ -41,6 +41,7 @@
# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
"bert": "AutoModel",
"mpt": "AutoModelForCausalLM",
"llama": "AutoModelForCausalLM",
"opt": "AutoModelForCausalLM",
@@ -152,7 +153,8 @@ def forward(self, x):

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        if self.model.generation_config is not None:
            self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

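Note on the two base.py changes above: BERT-family models are loaded through `AutoModel` rather than `AutoModelForCausalLM`, and an encoder-only model loaded that way has `generation_config` set to `None`, which is why the `do_sample` assignment is now guarded. A rough, self-contained sketch of how the mapping string presumably resolves to a transformers auto class (the actual resolution code lives outside this hunk, so treat the lookup below as an assumption):

import transformers

# Illustrative only: abridged copy of the mapping above.
TRANSFORMERS_AUTO_MAPPING_DICT = {"bert": "AutoModel"}
auto_cls = getattr(transformers, TRANSFORMERS_AUTO_MAPPING_DICT["bert"])  # transformers.AutoModel
model = auto_cls.from_pretrained("BAAI/bge-small-en-v1.5")                # plain BertModel
print(model.generation_config)                                            # None -> the new guard is needed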
@@ -441,3 +443,10 @@ def _scale_activations(self, layer):
        # scale activation
        scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
        set_op_by_name(layer, scale_dict["scale_name"], scaled_act)

class BaseBetterTransformerAWQModel(BaseAWQForCausalLM):
    def __init__(
        self, model, model_type, is_quantized, config, quant_config, processor
    ):
        super().__init__(model, model_type, is_quantized, config, quant_config, processor)
        self.model.to_bettertransformer()
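`BaseBetterTransformerAWQModel` only differs from the base class by calling `to_bettertransformer()` on the wrapped model (which needs the `optimum` package installed). `BertAWQModel` in the next file still inherits from `BaseAWQForCausalLM`; a subclass opting into the BetterTransformer path could look like this hypothetical sketch (not part of the diff):

# Hypothetical, for illustration only.
class BertAWQBetterTransformerModel(BaseBetterTransformerAWQModel):
    layer_type = "BertLayer"
    max_new_tokens_key = "max_position_embeddings"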
67 changes: 67 additions & 0 deletions awq/models/bert.py
@@ -0,0 +1,67 @@
from .base import BaseAWQForCausalLM
from transformers.models.bert.modeling_bert import BertModel, BertEncoder, BertLayer

class BertAWQModel(BaseAWQForCausalLM):
    layer_type = "BertLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model: BertModel):
        # BertModel has no `prepare_inputs_for_generation`; patch in a pass-through
        # for code paths that expect a generative model.
        def prepare_inputs_for_generation(input_ids, **kwargs):
            return {"input_ids": input_ids, **kwargs}
        model.prepare_inputs_for_generation = prepare_inputs_for_generation
        return model.encoder.layer

    @staticmethod
    def get_act_for_scaling(module: BertLayer):
        return dict(
            is_scalable=True,
            scale_name="intermediate.intermediate_act_fn",
            scale_layer=module.intermediate.intermediate_act_fn,
            scale_shape=module.intermediate.dense.out_features,
        )


    def move_embed(self, model: BertModel, device: str):
        model.embeddings = model.embeddings.to(device)

        if model.pooler is not None:
            model.pooler.dense = model.pooler.dense.to(device)

    def get_layers_for_scaling(self, module: BertLayer, input_feat, module_kwargs):
        layers = []

        # module.attention
        # TODO: Handle NoOp. No previous LayerNorm/Linear in module.attention like in other models.
        # layers.append(dict(
        #     prev_op=module.identity,
        #     layers=[module.attention.self.query,
        #             module.attention.self.key, module.attention.self.value],
        #     inp=input_feat['attention.self.query'],
        #     module2inspect=module.attention, kwargs=module_kwargs,
        # ))

        # attention out
        layers.append(dict(
            prev_op=module.attention.self.value,
            layers=[module.attention.output.dense],
            inp=input_feat['attention.self.value'],
        ))

        # # linear 2
        # layers.append(dict(
        #     prev_op=module.intermediate.intermediate_act_fn,
        #     layers=[module.output.dense],
        #     inp=input_feat['intermediate.dense'],
        # ))

        # # linear 1
        # layers.append(dict(
        #     prev_op=module.attention.output.dropout,
        #     layers=[module.intermediate.dense],
        #     inp=input_feat['attention.output.dropout'],
        # ))

        return layers


35 changes: 35 additions & 0 deletions awq/modules/fused/block.py
@@ -73,6 +73,41 @@ def forward(

        return out, None, past_key_value

class BertBlock(nn.Module):
    """
    Fused BERT encoder block: QuantAttentionFused for self-attention plus the
    original MLP, with residual connections around both sub-layers (same
    pre-norm layout as the other fused blocks in this file).
    """
    def __init__(
        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta=10000, use_alibi=False
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value

class MPTBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
2 changes: 1 addition & 1 deletion awq/quantize/quantizer.py
@@ -253,7 +253,7 @@ def _search_best_scale(
        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            module2inspect.to(inp.device)
            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]
2 changes: 1 addition & 1 deletion awq/utils/calib_data.py
@@ -38,7 +38,7 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
        else:
            line = data[text_column]
            line = line.strip()
            line_encoded = tokenizer.encode(line)
            line_encoded = tokenizer.encode(line, truncation=True)
        if len(line_encoded) > 512:
            continue
        sample = torch.tensor([line_encoded])
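Why `truncation=True` helps here: with it, the tokenizer caps each calibration line at its `model_max_length` (512 for the BERT tokenizers targeted by this PR), so long lines contribute a truncated sample instead of tripping the `> 512` skip and being dropped. A small illustration, assuming a BERT-family tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
long_line = "quantization " * 2000
full = tok.encode(long_line)                     # well over 512 tokens -> previously skipped
capped = tok.encode(long_line, truncation=True)  # capped at tok.model_max_length (512)
print(len(full), len(capped))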
34 changes: 34 additions & 0 deletions tests/compare_bert.py
@@ -0,0 +1,34 @@
from transformers import AutoTokenizer, AutoModel
import torch, time

model_id = "/home/michael/embeddings/bge-small-en-v1.5-quant"

modelawq = AutoModel.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda:0")
model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def get_tks(s="The quick brown fox jumps over the lazy dog"):
    input_text = [s * 100] * 64
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    for k, v in inputs.items():
        inputs[k] = v.to(modelawq.device)
    return inputs

def time_model(m, input_ids, printit=False):
    with torch.inference_mode():
        start = time.time()
        out = m(**input_ids)
        end = time.time()
    if printit:
        print(f"Time: {end - start} {m.config._name_or_path}")
    return out


warmup = get_tks("This is a long warmup string to warm up the model. Again.")
benchmark = get_tks("This is a long benchmark string to benchmark the model.")
time_model(model, warmup, printit=False)
out_m = time_model(model, benchmark, printit=True)
time_model(modelawq, warmup, printit=False)
out_awq = time_model(modelawq, benchmark, printit=True)

print(out_m.last_hidden_state - out_awq.last_hidden_state)
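One caveat on the timings in this script: CUDA kernels launch asynchronously, so wrapping the forward call in `time.time()` alone can under-report GPU time. An optional variant with explicit synchronization (a suggestion, not part of the diff):

# Optional: synchronize around the forward pass for more reliable GPU timings.
def time_model_synced(m, inputs, printit=False):
    with torch.inference_mode():
        torch.cuda.synchronize()
        start = time.time()
        out = m(**inputs)
        torch.cuda.synchronize()
        end = time.time()
    if printit:
        print(f"Time: {end - start} {m.config._name_or_path}")
    return out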
17 changes: 17 additions & 0 deletions tests/test_bert_quant.py
@@ -0,0 +1,17 @@
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'BAAI/bge-small-en-v1.5'
quant_path = 'bge-small-en-v1.5-quant'
quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
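For completeness, a sketch of reloading the artifact this script writes. `AutoAWQForCausalLM.from_quantized` is the usual AutoAWQ loading path, but whether it works end-to-end for this WIP BERT support is untested, so treat it as an assumption (`tests/compare_bert.py` above sidesteps it by reloading the saved folder with plain `transformers.AutoModel`):

# Sketch only; untested for the BERT path in this WIP.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path)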