
add deepseek v2 support #508

Merged
casper-hansen merged 4 commits into casper-hansen:main on Jun 24, 2024

Conversation

TechxGenus
Contributor

Add support for Deepseek-V2.
I wrote it with reference to the DeepSeek and AWQ papers. It works, but there are several potential issues:

  1. Due to the combination/compression strategy of MLA and RoPE, q_group_size cannot use the commonly used 128; the maximum is 64 (see the config sketch below).
  2. The performance drop after quantization is slightly larger than for common dense models like Llama. A possible reason is that MLA and DeepSeekMoE decompose both the attention and MLP layers into small matrices, which leads to a situation a bit like quantizing small models.

In addition, the fused-layer implementation is missing (given the complexity of implementing new operators, is it necessary to add fused layers for every new architecture?).
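For reference, a minimal quantization config sketch consistent with point 1 (paths are placeholders; the API usage mirrors the quantization script shared later in this thread):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"  # placeholder source checkpoint
quant_path = "DeepSeek-Coder-V2-Lite-Instruct-AWQ"          # placeholder output directory

# q_group_size must be at most 64 because of the MLA/RoPE layout described in point 1.
quant_config = {"zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM"}

# trust_remote_code is assumed to be needed for the custom DeepSeek-V2 architecture.
model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)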

@fengyang95

fengyang95 commented Jun 20, 2024

Nice job! Looking forward to the release of the deepseek-coder-v2 AWQ quantized model.

@PanameraXXX

Excellent job! I'm looking forward to the release of the DeepSeek-Coder-v2 AWQ quantized model.

@casper-hansen
Owner

Quantization works on deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct.

However, I get an error when trying to run generation. Do you have an example script that works?

Perplexity:

  • FP16: doesn't run on my 4090
  • INT4: 9.276

@TechxGenus
Contributor Author

An example script is:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("TechxGenus/DeepSeek-Coder-V2-Lite-Instruct-AWQ")
model = AutoAWQForCausalLM.from_quantized(
    "TechxGenus/DeepSeek-Coder-V2-Lite-Instruct-AWQ",
    trust_remote_code=True,
).cuda()
model.eval()

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
tokens = tokenizer("def min(arr):\n", return_tensors='pt').input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)

It works. There may be some issues when working with integrations such as transformers and accelerate.

@fengyang95

TechxGenus/DeepSeek-Coder-V2-Lite-Instruct-AWQ

Hi, do you have plans to release the deepseek-coder-v2-instruct-awq model?

@TechxGenus
Contributor Author

Hi, do you have plans to release the deepseek-coder-v2-instruct-awq model?

It's too big and I lack the hardware to do it. Will update if I have enough resources.

@casper-hansen
Owner

Hi, do you have plans to release the deepseek-coder-v2-instruct-awq model?

I can release a quantized version of the big model once merged.

An example script is:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("TechxGenus/DeepSeek-Coder-V2-Lite-Instruct-AWQ")
model = AutoAWQForCausalLM.from_quantized(
    "TechxGenus/DeepSeek-Coder-V2-Lite-Instruct-AWQ",
    trust_remote_code=True,
).cuda()
model.eval()

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
tokens = tokenizer("def min(arr):\n", return_tensors='pt').input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)

It works. There may be some issues when working with integrations such as transformers and accelerate.

I had a similar script earlier, but ran into a few errors in transformers. Did you use model.eval() to make it work? Which transformers version are you using?

@TechxGenus
Contributor Author

TechxGenus commented Jun 21, 2024

I had a similar script earlier, but ran into a few errors in transformers. Did you use model.eval() to make it work? Which transformers version are you using?

I use transformers==4.41.2. I just found that it didn't work after I deleted model.eval().
So I tested the following situations and recorded the results here:

  • use AutoAWQForCausalLM, model.eval(), single GPU: Success
  • use AutoAWQForCausalLM, no model.eval(), single GPU: Error (now: Success)
  • use AutoAWQForCausalLM, model.eval(), two GPUs: Error, success when setting use_cache=False (now: Success)
  • use AutoAWQForCausalLM (model.eval() and GPU count do not affect the result): Success now

I'll try to fix some of them tomorrow.

@TechxGenus
Contributor Author

Now it should work well.

@casper-hansen
Owner

@TechxGenus Am I reading it correctly that multi-GPU does not work with model.eval() but single-GPU does work?

@TechxGenus
Contributor Author

Am I reading it correctly that multi-GPU does not work with model.eval() but single-GPU does work?

Both work well for all settings now.

@casper-hansen
Owner

@TechxGenus I'm curious, why did you decide to quantize the gate? I see model.eval() was added in the base file. I will just have to do a few checks to see if we need to keep it:

  • does the model still finetune correctly?
  • do other models still work after adding model.eval()?

@TechxGenus
Contributor Author

why did you decide to quantize the gate?

I did not quantize the gate. I tried to filter it first, and found that when filtering mlp.gate, mlp.gate_proj would also be filtered, which caused some errors in transformers. I originally intended to modify exclude_layers_to_not_quantize to achieve this, but then I found that, unlike Mixtral (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L822), in the DeepSeek-V2 implementation the gate is not a Linear (https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct/blob/main/modeling_deepseek.py#L410-L412), so it is not detected in the first place and does not need any special handling.
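
To illustrate the filtering pitfall (a hypothetical snippet, not AutoAWQ code): a plain substring match on "mlp.gate" also catches "mlp.gate_proj", while an exact/suffix match does not.

layer_names = ["mlp.gate", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]
excluded = "mlp.gate"

# Substring filtering catches gate_proj as well:
print([n for n in layer_names if excluded in n])
# ['mlp.gate', 'mlp.gate_proj']

# A stricter match only excludes the gate itself:
print([n for n in layer_names if n == excluded or n.endswith("." + excluded)])
# ['mlp.gate']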

does the model still finetune correctly? do other models still work after adding model.eval()?

Models like Llama can generate text normally. I haven't tested finetuning, but I think setting the model back to model.train() should work fine. This modification is needed because DeepSeek-V2 uses different computations depending on whether it is in training mode (https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct/blob/main/modeling_deepseek.py#L572), and from_quantized loads the model in the training state by default. I recommend syncing with from_pretrained (https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L85-L86) and loading the model in the model.eval() state by default.
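
As a minimal illustration of why this matters (a hypothetical module, not the DeepSeek-V2 source): nn.Module instances start in training mode, so any forward pass that branches on self.training takes the training path until eval() is called.

import torch

class BranchyLayer(torch.nn.Module):
    # Mimics a layer that, like modeling_deepseek.py, picks a code path via self.training.
    def forward(self, x):
        if self.training:
            return x * 2.0  # training-only path
        return x            # inference path

layer = BranchyLayer()
print(layer.training)           # True by default, as after from_quantized
print(layer(torch.ones(1)))     # tensor([2.]) -- training branch
layer.eval()
print(layer(torch.ones(1)))     # tensor([1.]) -- inference branch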

@casper-hansen casper-hansen merged commit 6b45c95 into casper-hansen:main Jun 24, 2024
@Grey4sh

Grey4sh commented Jun 27, 2024

Where can I get the quantize script?

Here is my quantize script:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = '/var/mntpkg/deepseek-coder-v2-instruct'
quant_path = 'deepseek-coder-v2-ins-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')

I tried to quantize the DeepSeek-Coder-V2-Instruct-236B with the latest code from casper-hansen:main, but it failed several times. Here are some error logs:

AWQ:   0%|                                                                                                                                               | 0/60 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/home/chatgpt/young/awq_quant.py", line 15, in <module>
    model.quantize(tokenizer, quant_config=quant_config)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/base.py", line 198, in quantize
    self.quantizer.quantize()
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 153, in quantize
    module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/deepseek_v2.py", line 51, in get_layers_for_scaling
    inp=input_feat["self_attn.q_proj"],
KeyError: 'self_attn.q_proj'

Another one:

File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 335, in _compute_best_scale
    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
    assert torch.isnan(w).sum() == 0
AssertionError

TechxGenus added a commit to TechxGenus/AutoAWQ that referenced this pull request Jun 27, 2024
@TechxGenus
Contributor Author

    inp=input_feat["self_attn.q_proj"],
KeyError: 'self_attn.q_proj'

I fixed it using #524.

File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 335, in _compute_best_scale
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
assert torch.isnan(w).sum() == 0
AssertionError

This may require #498 and #516 to resolve.

@Grey4sh

Grey4sh commented Jun 29, 2024

I fixed it using #524.
[…]
This may require #498 and #516 to resolve.

Thanks for your quick reply. After pulling the latest code from #516 and #524, I encountered another problem.

File .../awq/quantize/quantizer.py, line 398, in _compute_best_clip
    input_feat = input_feat[:, 0 : : input_feat.shape[1] // n_sample_token]
ValueError: slice step cannot be zero

@TechxGenus
Contributor Author

File .../awq/quantize/quantizer.py, line 398, in _compute_best_clip
input_feat = input_feat[:, 0 : : input_feat.shape[1] // n_sample_token]
ValueError: slice step cannot be zero

This is usually caused by too few tokens in the calibration data being processed by one expert of the MoE model.
Which dataset are you using? I think increasing max_calib_samples or decreasing n_sample_token might help.

@Grey4sh

Grey4sh commented Jul 1, 2024

File .../awq/quantize/quantizer.py, line 398, in _compute_best_clip
input_feat = input_feat[:, 0 : : input_feat.shape[1] // n_sample_token]
ValueError: slice step cannot be zero

This is usually caused by too few tokens in the calibration data being processed by one expert of the MoE model. Which dataset are you using? I think increasing max_calib_samples or decreasing n_sample_token might help.

Since my development environment is offline, I manually downloaded pile-val-backup-val.jsonl and modified the dataset loading code.

chatgpt@chatgpt:~/young$ wc -l pile-val-backup-val.jsonl
214670 pile-val-backup-val.jsonl

awq/utils/calib_data.py

def load_local_dataset(jsonl_file_path):
    return load_dataset('json', data_files={'validation': jsonl_file_path}, split='validation')


def get_calib_dataset(
    data: Union[str, List[str], List[List[int]]] = "pileval",
    tokenizer=None,
    n_samples=256,
    block_size=512,
    split="train",
    text_column="text",
    local_data_path: str = None
):
    if isinstance(data, str):
        if local_data_path:
            dataset = load_local_dataset(local_data_path)
        elif data == "pileval":
            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
        else:
            dataset = load_dataset(data, split=split)

        dataset = dataset.shuffle(seed=42)

awq/quantize/quantizer.py

def init_quant(self, n_samples=32, seqlen=512):
    modules = self.awq_model.get_model_layers(self.model)
    samples = get_calib_dataset(
        data=self.calib_data,
        tokenizer=self.tokenizer,
        n_samples=n_samples,
        block_size=seqlen,
        split=self.split,
        text_column=self.text_column,
        local_data_path="/home/chatgpt/young/pile-val-backup-val.jsonl"
    )
    samples = torch.cat(samples, dim=0)

I followed your advice and decreased n_sample_token, but I still got the same error at input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token].

@casper-hansen
Owner

Due to the combination/compression strategy of MLA and RoPE, q_group_size cannot use the commonly used 128; the maximum is 64.

@Cucunnber you need to specify a group size of 64, as explained by @TechxGenus.

I am currently beginning to test the big model on a lot of GPUs, but I cannot promise that my attempts will be successful without more compute credits.

At the moment, this is my choice of dataset (using the batched_quantization branch):

from datasets import load_dataset

def load_openhermes_coding():
    data = load_dataset("alvarobartt/openhermes-preferences-coding", split="train")

    samples = []
    for sample in data:
        responses = [f'{response["role"]}: {response["content"]}' for response in sample["chosen"]]
        samples.append("\n".join(responses))

    return samples

# Quantize (model, tokenizer and quant_config are loaded as in the quantization script above)
model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=load_openhermes_coding(),
    # n_parallel_calib_samples=32,
    # max_calib_samples=128,
    # max_calib_seq_len=4096
)

@casper-hansen
Owner

casper-hansen commented Jul 1, 2024

Nevermind, I also got the same error. I have not seen this error before on any other model, so it will be a little hard to debug what's happening here. If you just want a quantized model, you can turn off the clipping or optimize it for DeepSeek-V2 specifically. Clipping may need to skip a specific layer.

AWQ:   3%|▎         | 2/60 [25:52<12:30:21, 776.22s/it]
Computing Best Clip:  59%|█████▉    | 290/488 [05:28<03:43,  1.13s/it]
Traceback (most recent call last):
  File "/workspace/AutoAWQ/examples/quantize.py", line 27, in <module>
    model.quantize(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/AutoAWQ/awq/models/base.py", line 230, in quantize
    self.quantizer.quantize()
  File "/workspace/AutoAWQ/awq/quantize/quantizer.py", line 177, in quantize
    clip_list = self._search_best_clip(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/AutoAWQ/awq/quantize/quantizer.py", line 469, in _search_best_clip
    max_val = self._compute_best_clip(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/AutoAWQ/awq/quantize/quantizer.py", line 493, in _compute_best_clip
    input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
ValueError: slice step cannot be zero

Potential fix (Claude 3.5 Sonnet's suggestion):

[screenshot of the suggested patch to _compute_best_clip]

@TechxGenus
Contributor Author

I followed your advice and decreased n_sample_token, but I still got the same error at input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token].

Ah, I should have made it clearer. What needs to be reduced is n_sample_token in _compute_best_clip, not n_samples in init_quant.

Nevermind, I also got the same error. I have not seen this error before on any other model, so it will be a little hard to debug what's happening here.

I encountered this problem before when quantizing a fine-tuned MoE model. It was because some experts processed too few tokens, which caused input_feat.shape[1] to be too small (less than n_sample_token=512). DeepSeek-Coder-V2 has 162 experts, so it is very likely that one of the experts processed too few tokens.
I think Claude's suggested modification is feasible. With it, an error would only be reported in the extreme case where some expert did not process any tokens at all.
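
For reference, a minimal sketch of that kind of guard (not the exact patch from the screenshot): clamp the slice step to at least 1 so an expert that saw fewer than n_sample_token tokens keeps all of its tokens instead of crashing.

import torch

def subsample_tokens(input_feat: torch.Tensor, n_sample_token: int = 512) -> torch.Tensor:
    # Clamp the stride so it can never be zero, even for experts with very few tokens.
    step = max(1, input_feat.shape[1] // n_sample_token)
    return input_feat[:, 0::step]

feat = torch.randn(1, 37, 1408)      # an expert that only received 37 calibration tokens
print(subsample_tokens(feat).shape)  # torch.Size([1, 37, 1408]) -- no "slice step cannot be zero"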

@Grey4sh

Grey4sh commented Jul 2, 2024

What needs to be reduced is n_sample_token in _compute_best_clip, not n_samples in init_quant. […] I think Claude's suggested modification is feasible.

@TechxGenus @casper-hansen Another error occurs.

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [01:07<00:00,  1.23s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Token indices sequence length is longer than the specified maximum sequence length for this model (80864 > 16384). Running this sequence through the model will result in indexing errors
AWQ:   5%|███████▌                                                                                                                                              | 3/60 [22:57<7:16:12, 459.16s/it]
Traceback (most recent call last):
  File "/home/chatgpt/young/awq_quant.py", line 15, in <module>
    model.quantize(tokenizer, quant_config=quant_config)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/models/base.py", line 198, in quantize
    self.quantizer.quantize()
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 156, in quantize
    scales_list = [
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 157, in <listcomp>
    self._search_best_scale(self.modules[i], **layer)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 278, in _search_best_scale
    best_scales = self._compute_best_scale(
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 335, in _compute_best_scale
    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
  File "/home/chatgpt/.local/lib/python3.10/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
    assert torch.isnan(w).sum() == 0
AssertionError

@casper-hansen
Owner

assert torch.isnan(w).sum() == 0

This is a long-standing issue: models cannot be quantized when some of the weights end up as NaN values. I do not have a fix for this at the moment. @TechxGenus, what do you think?
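
For anyone hitting this assertion, a small diagnostic sketch (not part of AutoAWQ) can help locate which parameters contain NaN or Inf values before quantization:

import torch

def find_nonfinite_params(model: torch.nn.Module) -> list[str]:
    # Return the names of all parameters that contain NaN or Inf values.
    return [
        name
        for name, param in model.named_parameters()
        if not torch.isfinite(param).all()
    ]

# Example usage on a toy module; for DeepSeek-V2, pass the loaded model instead.
print(find_nonfinite_params(torch.nn.Linear(8, 8)))  # [] when all weights are finite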

@TechxGenus
Contributor Author

I originally thought this would be fixed by the latest #516. Does this error still occur with the latest code? All the parts that could overflow seem to be handled.

@casper-hansen
Owner

I have now merged it into the main branch. I have spun up a machine to test quantization of this model with the latest improvements.

@casper-hansen
Owner

It works on the main branch so far:

AWQ:   7%|▋         | 4/60 [57:16<15:30:43, 997.21s/it]

@TechxGenus
Contributor Author

assert torch.isnan(w).sum() == 0

If you still get this error, you may need #532. I haven't tested this patch thoroughly.

@casper-hansen
Owner

assert torch.isnan(w).sum() == 0

If you still get this error, you may need #532. I haven't tested this patch thoroughly.

I don't think we need it so far. I have been able to progress quite a bit. However, as you can see from the estimate, this will take quite a long time to finish.

AWQ:  22%|██▏       | 13/60 [3:52:59<16:18:56, 1249.71s/it]

@Grey4sh

Grey4sh commented Jul 3, 2024

I'll give it a try too. Thanks for your support.

@Grey4sh

Grey4sh commented Jul 4, 2024

@TechxGenus @casper-hansen
Quantization finished successfully, though it did take quite a long time.

AWQ: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [15:32:32<00:00, 932.54s/it]

I ran the AWQ model with vLLM.
When I set tensor_parallel_size=8, it raises ValueError: The input size is not aligned with the quantized weight shape.
When I set tensor_parallel_size=4, it raises AttributeError: 'MergedColumnParallelLinear' object has no attribute 'weight'. Did you mean: 'qweight'?

Here is my vLLM Docker launcher script.

model_path=deepseek-coder-v2-ins-awq
tensor_parallel_size=4
port=9190

sudo docker run --gpus='"device=4,5,6,7"' \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    -v $model_path:/model/ \
    --env "HUGGING_FACE_HUB_TOKEN=<secret>" \
    -p $port:8000 \
    --ipc=host \
    vllm042_dsv2-v2:latest \
    --model /model/ \
    --served-model-name cmwCoder \
    --kv-cache-dtype fp8 \
    --quantization awq \
    --trust-remote-code \
    --seed 42 \
    --max-model-len 8192 \
    --tensor-parallel-size $tensor_parallel_size

Here is the error log when tensor_parallel_size=8

[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 169, in <module>
[rank0]:     engine = AsyncLLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 366, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 324, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 442, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 300, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 43, in _init_executor
[rank0]:     self._init_workers_ray(placement_group)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 165, in _init_workers_ray
[rank0]:     self._run_workers("load_model",
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 234, in _run_workers
[rank0]:     driver_worker_output = self.driver_worker.execute_method(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 146, in execute_method
[rank0]:     raise e
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method
[rank0]:     return executor(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 117, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 156, in load_model
[rank0]:     self.model = get_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
[rank0]:     model = _initialize_model(model_config, self.load_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model
[rank0]:     return model_class(config=model_config.hf_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 458, in __init__
[rank0]:     self.model = DeepseekV2Model(config, quant_config)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 422, in __init__
[rank0]:     self.layers = nn.ModuleList([
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp>
[rank0]:     DeepseekV2DecoderLayer(config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 363, in __init__
[rank0]:     self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 104, in __init__
[rank0]:     self.experts = nn.ModuleList([
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 105, in <listcomp>
[rank0]:     DeepseekV2MLP(hidden_size=config.hidden_size,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 68, in __init__
[rank0]:     self.down_proj = RowParallelLinear(intermediate_size,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 633, in __init__
[rank0]:     self.quant_method.create_weights(self,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/awq.py", line 91, in create_weights
[rank0]:     raise ValueError(
[rank0]: ValueError: The input size is not aligned with the quantized weight shape. This can be caused by too large tensor parallel size.
(RayWorkerWrapper pid=7240) ERROR 07-04 01:55:38 worker_base.py:145] Error executing method load_model. This might cause deadlock in distributed execution.
[... the same traceback and ValueError repeated by the RayWorkerWrapper processes across the cluster ...]

Here is the error log when tensor_parallel_size=4

[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 169, in <module>
[rank0]:     engine = AsyncLLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 366, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 324, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 442, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 300, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 43, in _init_executor
[rank0]:     self._init_workers_ray(placement_group)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 165, in _init_workers_ray
[rank0]:     self._run_workers("load_model",
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/ray_gpu_executor.py", line 234, in _run_workers
[rank0]:     driver_worker_output = self.driver_worker.execute_method(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 146, in execute_method
[rank0]:     raise e
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 137, in execute_method
[rank0]:     return executor(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 117, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 156, in load_model
[rank0]:     self.model = get_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
[rank0]:     model = _initialize_model(model_config, self.load_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 88, in _initialize_model
[rank0]:     return model_class(config=model_config.hf_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 458, in __init__
[rank0]:     self.model = DeepseekV2Model(config, quant_config)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 422, in __init__
[rank0]:     self.layers = nn.ModuleList([
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 423, in <listcomp>
[rank0]:     DeepseekV2DecoderLayer(config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 363, in __init__
[rank0]:     self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 112, in __init__
[rank0]:     self.pack_params()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 134, in pack_params
[rank0]:     w1.append(expert.gate_up_proj.weight)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1709, in __getattr__
[rank0]:     raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[rank0]: AttributeError: 'MergedColumnParallelLinear' object has no attribute 'weight'. Did you mean: 'qweight'?
(RayWorkerWrapper pid=7153) ERROR 07-04 01:58:42 worker_base.py:145] Error executing method load_model. This might cause deadlock in distributed execution.
[... the same traceback and AttributeError repeated by the RayWorkerWrapper processes across the cluster ...]
[rank0]:[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

I don't know what causes the problem: AWQ or vLLM? If you have run the DeepSeek-V2 AWQ model successfully, could you tell me how you did it? Thanks <3

@casper-hansen
Owner

vLLM has had problems with tensor parallelism on quantized models for a long time now. Please try the script I have put in the documentation and modify it to use tensor_parallel_size. BUT if this does not work, please raise an issue in vLLM:

https://casper-hansen.github.io/AutoAWQ/examples/#vllm
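
For reference, a minimal offline-inference sketch along those lines (the exact docs script is not reproduced here; the model path and engine arguments are taken from the Docker command above):

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-coder-v2-ins-awq",   # local path to the quantized checkpoint
    quantization="awq",
    tensor_parallel_size=4,
    trust_remote_code=True,
    max_model_len=8192,
)

outputs = llm.generate(["def min(arr):\n"], SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)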

@fengyang95

fengyang95 commented Jul 9, 2024

vLLM has had problems with tensor parallelism on quantized models for a long time now. Please try the script I have put in the documentation and modify it to use tensor_parallel_size. BUT if this does not work, please raise an issue in vLLM

https://casper-hansen.github.io/AutoAWQ/examples/#vllm

I tried to load the deepseek-coder-v2-instruct-awq model using the following code with 8 L40 GPUs:

from awq import AutoAWQForCausalLM

# model_id is the local path (or Hub id) of the deepseek-coder-v2-instruct-awq checkpoint
model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=True)

However, it keeps failing to load successfully. I observed that the GPU memory on GPU 0 is slowly increasing. Could you provide some suggestions?
