Running DeepSeek-VL2 with multiple cards #8
Comments
Well, the code works if I make the following change:

outputs = vl_gpt.language.generate(
    input_ids=prepare_inputs["input_ids"].to(vl_gpt.device),
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True
)

But it just prints some meaningless words.
Hi @robinren03, please consider this minimal code for model sharding. Based on my preliminary tests, it has successfully run our 16B and 27B models. However, I have not extensively tested which sharding strategy performs best, so you may want to tune these parameters later based on your specific requirements.

import torch
from transformers import AutoModelForCausalLM

from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
from deepseek_vl2.utils.io import load_pil_images


def split_model(model_name):
    device_map = {}
    model_splits = {
        'deepseek-ai/deepseek-vl2-small': [13, 14],  # 2 GPUs for the 16B model
        'deepseek-ai/deepseek-vl2': [10, 10, 10],    # 3 GPUs for the 27B model
    }
    num_layers_per_gpu = model_splits[model_name]
    num_layers = sum(num_layers_per_gpu)
    layer_cnt = 0
    # distribute the decoder layers across the GPUs
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    # keep the vision tower, projector, embeddings, head and the last
    # decoder layer together on GPU 0
    device_map['vision'] = 0
    device_map['projector'] = 0
    device_map['image_newline'] = 0
    device_map['view_seperator'] = 0
    device_map['language.model.embed_tokens'] = 0
    device_map['language.model.norm'] = 0
    device_map['language.lm_head'] = 0
    device_map[f'language.model.layers.{num_layers - 1}'] = 0
    return device_map
# specify the path to the model
model_path = 'deepseek-ai/deepseek-vl2'
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

device_map = split_model(model_path)
vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=device_map
).eval()
## single image conversation example
conversation = [
    {
        "role": "<|User|>",
        "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
        "images": ["./images/visual_grounding.jpeg"],
    },
    {"role": "<|Assistant|>", "content": ""},
]

# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=pil_images,
    force_batchify=True,
    system_prompt=""
).to(vl_gpt.device)

# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

# run the model to get the response
outputs = vl_gpt.language.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True
)

answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
After the update, deepseek-vl2 appears to have an issue with the KV cache.
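One quick way to check whether the KV cache is actually the culprit (a minimal sketch, reusing the inputs_embeds and prepare_inputs from the snippet above; decoding without the cache is much slower but should be numerically equivalent under greedy decoding) is to rerun generation with use_cache=False:

# rerun greedy decoding without the KV cache; if the output becomes sensible,
# the problem is in how cached keys/values are handled across the shards
# rather than in the sharding itself
outputs_no_cache = vl_gpt.language.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=64,
    do_sample=False,
    use_cache=False
)
print(tokenizer.decode(outputs_no_cache[0].cpu().tolist(), skip_special_tokens=False))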
Doesn't work for me either. I am on two H100 cards. I followed #8 (comment) but I still get garbage output. Here is my code (I downloaded the model beforehand to a local directory):

import torch
from transformers import AutoModelForCausalLM

from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
from deepseek_vl2.utils.io import load_pil_images
def split_model():
    device_map = {}
    num_layers_per_gpu = [15, 15]
    num_layers = sum(num_layers_per_gpu)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision'] = 0
    device_map['projector'] = 0
    device_map['image_newline'] = 0
    device_map['view_seperator'] = 0
    device_map['language.model.embed_tokens'] = 0
    device_map['language.model.norm'] = 0
    device_map['language.lm_head'] = 0
    device_map[f'language.model.layers.{num_layers - 1}'] = 0
    return device_map
# specify the path to the model
model_path = 'deepseek'
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained("deepseek-ai/deepseek-vl2")
tokenizer = vl_chat_processor.tokenizer

device_map = split_model()
vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=device_map
).eval()
## single image conversation example
conversation = [
    {
        "role": "<|User|>",
        "content": "<image>\n<|ref|>All Apps Button<|/ref|>.",
        "images": ["android_home_screen.png"],
    },
    {"role": "<|Assistant|>", "content": ""},
]

# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=pil_images,
    force_batchify=True,
    system_prompt=""
).to(vl_gpt.device)

# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

# run the model to get the response
outputs = vl_gpt.language.generate(
    inputs_embeds=inputs_embeds,
    input_ids=prepare_inputs["input_ids"],
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True
)

answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
I used three 16 GB NVIDIA V100 cards to load the DeepSeek-VL2-small model, and it failed.
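For what it's worth, the layer counts above assume two cards; with three 16 GB cards, GPU 0 also has to hold the vision tower, projector, embeddings and head, so it probably needs fewer decoder layers than the other GPUs. Below is a sketch of an adapted split: the [7, 10, 10] numbers are only a guess derived from the [13, 14] split above (which implies 27 decoder layers) and would need tuning; note also that V100s lack native bfloat16 support, which may be a separate problem.

def split_model_small_3gpu():
    # uneven split: GPU 0 additionally hosts vision/projector/embeddings/head,
    # so give it fewer of the 27 decoder layers (illustrative numbers only)
    device_map = {}
    num_layers_per_gpu = [7, 10, 10]
    num_layers = sum(num_layers_per_gpu)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for _ in range(num_layer):
            device_map[f'language.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    for name in ['vision', 'projector', 'image_newline', 'view_seperator',
                 'language.model.embed_tokens', 'language.model.norm',
                 'language.lm_head', f'language.model.layers.{num_layers - 1}']:
        device_map[name] = 0
    return device_map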
I have 3 A6000 GPUs with 48 GB of memory each, and I need to use TP to load the DeepSeek-VL2 model (not the tiny/small variants) onto the GPUs.
Here is my code.
And it runs into the following problem:
Is there any official guideline on using TP (i.e. MP) or PP to run the larger models? Or could you please point out my mistake? Thanks a lot!
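As an aside, if hand-writing the device map is the sticking point, a simpler starting point is to let from_pretrained build the map automatically under per-GPU memory caps. This is only a sketch: the max_memory values are illustrative, it relies on accelerate's layer-wise placement rather than real tensor parallelism, and it may place modules differently from the hand-written map above, which the maintainer found to work.

import torch
from transformers import AutoModelForCausalLM

# let accelerate place the modules automatically, capping each 48 GB A6000
# slightly below its capacity to leave room for activations and the KV cache
vl_gpt = AutoModelForCausalLM.from_pretrained(
    'deepseek-ai/deepseek-vl2',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    max_memory={0: '40GiB', 1: '40GiB', 2: '40GiB'},
).eval()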