Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for Jamba #454

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .gemma import GemmaAWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .jamba import JambaAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
"jamba": JambaAWQForCausalLM,
}


Expand Down
3 changes: 2 additions & 1 deletion awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
"jamba": "AutoModelForCausalLM",
}


Expand Down Expand Up @@ -449,7 +450,7 @@ def from_quantized(
model,
checkpoint=model_weights_path,
device_map=device_map,
no_split_module_classes=[self.layer_type],
no_split_module_classes=[self.layer_type] if isinstance(self.layer_type, str) else self.layer_type,
offload_folder=offload_folder,
dtype=torch_dtype,
)
Expand Down
107 changes: 107 additions & 0 deletions awq/models/jamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import tqdm
import torch
from typing import List, Tuple, Union
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from transformers.models.jamba.modeling_jamba import (
JambaAttentionDecoderLayer as OldJambaAttentionDecoderLayer,
JambaMambaDecoderLayer as OldJambaMambaDecoderLayer,
JambaForCausalLM as OldJambaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class JambaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
max_seq_len_key = "max_position_embeddings"
modules_to_not_convert = ["mamba", "router"]

@staticmethod
def get_model_layers(model: OldJambaForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer]):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: OldJambaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer], input_feat, module_kwargs):
layers = []

# attention input
if isinstance(module, OldJambaAttentionDecoderLayer):
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

if hasattr(module.feed_forward, "router"):
# linear in
layers.append(
dict(
prev_op=module.pre_ff_layernorm,
layers=[
w
for expert in module.feed_forward.experts
for w in [expert.gate_proj, expert.up_proj]
],
inp=input_feat["feed_forward"],
module2inspect=module.feed_forward,
)
)

# linear out
for i, expert in enumerate(module.feed_forward.experts):
layers.append(
dict(
prev_op=expert.up_proj,
layers=[expert.down_proj],
inp=input_feat[f"feed_forward.experts.{i}.down_proj"],
)
)

else:
# linear 1
layers.append(
dict(
prev_op=module.pre_ff_layernorm,
layers=[module.feed_forward.gate_proj, module.feed_forward.up_proj],
inp=input_feat["feed_forward.gate_proj"],
module2inspect=module.feed_forward,
)
)

# linear 2
layers.append(
dict(
prev_op=module.feed_forward.up_proj,
layers=[module.feed_forward.down_proj],
inp=input_feat["feed_forward.down_proj"],
)
)

return layers
6 changes: 6 additions & 0 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,12 @@ def cache_input_hook(m, x, y, name, feat_dict):
"block_sparse_moe": layer.block_sparse_moe,
}

if self.awq_model.model_type == "jamba":
named_linears = {
**named_linears,
"feed_forward": layer.feed_forward,
}

for name in named_linears:
handles.append(
named_linears[name].register_forward_hook(
Expand Down