Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add deepseek v2 support #508

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
from .starcoder2 import Starcoder2AWQForCausalLM
from .phi3 import Phi3AWQForCausalLM
from .cohere import CohereAWQForCausalLM
from .deepseek_v2 import DeepseekV2AWQForCausalLM
from .minicpm import MiniCPMAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"starcoder2": Starcoder2AWQForCausalLM,
"phi3": Phi3AWQForCausalLM,
"cohere": CohereAWQForCausalLM,
"deepseek_v2": DeepseekV2AWQForCausalLM,
"minicpm": MiniCPMAWQForCausalLM,
}

Expand Down
5 changes: 4 additions & 1 deletion awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@
"starcoder2": "AutoModelForCausalLM",
"phi3": "AutoModelForCausalLM",
"cohere": "AutoModelForCausalLM",
"minicpm":"AutoModelForCausalLM"
"deepseek_v2": "AutoModelForCausalLM",
"minicpm":"AutoModelForCausalLM",
}


Expand Down Expand Up @@ -506,6 +507,8 @@ def from_quantized(
max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
)

model.eval()

return self(
model,
model_type,
Expand Down
128 changes: 128 additions & 0 deletions awq/models/deepseek_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM


class DeepseekV2AWQForCausalLM(BaseAWQForCausalLM):
layer_type = "DeepseekV2DecoderLayer"
max_seq_len_key = "max_position_embeddings"

@staticmethod
def get_model_layers(model):
return model.model.layers

@staticmethod
def get_act_for_scaling(module):
return dict(is_scalable=False)

@staticmethod
def move_embed(model, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(
module, input_feat, module_kwargs
):
layers = []

if hasattr(module.self_attn, "q_proj"):
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.kv_a_proj_with_mqa,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
else:
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_a_proj,
module.self_attn.kv_a_proj_with_mqa,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
layers.append(
dict(
prev_op=module.self_attn.q_a_layernorm,
layers=[
module.self_attn.q_b_proj,
],
inp=input_feat["self_attn.q_b_proj"],
)
)

# kv layernorm
layers.append(
dict(
prev_op=module.self_attn.kv_a_layernorm,
layers=[
module.self_attn.kv_b_proj,
],
inp=input_feat["self_attn.kv_b_proj"],
)
)

if hasattr(module.mlp, "gate"):
# linear in
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[
w
for expert in module.mlp.experts
for w in [expert.gate_proj, expert.up_proj]
] + [module.mlp.shared_experts.gate_proj, module.mlp.shared_experts.up_proj],
inp=input_feat["mlp"],
module2inspect=module.mlp,
)
)

# linear out
for i, expert in enumerate(module.mlp.experts):
layers.append(
dict(
prev_op=expert.up_proj,
layers=[expert.down_proj],
inp=input_feat[f"mlp.experts.{i}.down_proj"],
)
)
layers.append(
dict(
prev_op=module.mlp.shared_experts.up_proj,
layers=[module.mlp.shared_experts.down_proj],
inp=input_feat[f"mlp.shared_experts.down_proj"],
)
)
else:
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers
6 changes: 6 additions & 0 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,12 @@ def cache_input_hook(m, x, y, name, feat_dict):
"block_sparse_moe": layer.block_sparse_moe,
}

if self.awq_model.model_type == "deepseek_v2":
named_linears = {
**named_linears,
"mlp": layer.mlp,
}

for name in named_linears:
handles.append(
named_linears[name].register_forward_hook(
Expand Down