
Commit 75f4536

pufanyi authored and kcz358 committed
Fix InternVL2 model sharding (#481)
1 parent e43a580 commit 75f4536

File tree: 1 file changed (+51, -3)

lmms_eval/models/internvl2.py

Lines changed: 51 additions & 3 deletions
```diff
@@ -119,12 +119,55 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=3
     return pixel_values, num_patches_list
 
 
+import math
 from datetime import timedelta
 
 from accelerate.state import AcceleratorState
 from accelerate.utils import InitProcessGroupKwargs
 
 
+# The reason for writing the code this way is to avoid errors that occur during multi-GPU inference due to tensors not being on the same device. By ensuring that the first and last layers of the large language model (LLM) are on the same device, we prevent such errors.
+def split_model(model_name, num_layers=None):
+    device_map = {}
+    world_size = torch.cuda.device_count()
+    if num_layers is None:
+        num_layers = {
+            "InternVL2_5-1B": 24,
+            "InternVL2_5-2B": 24,
+            "InternVL2_5-4B": 36,
+            "InternVL2_5-8B": 32,
+            "InternVL2_5-26B": 48,
+            "InternVL2_5-38B": 64,
+            "InternVL2_5-78B": 80,
+            "InternVL2-1B": 24,
+            "InternVL2-2B": 24,
+            "InternVL2-4B": 32,
+            "InternVL2-8B": 32,
+            "InternVL2-26B": 48,
+            "InternVL2-40B": 60,
+            "InternVL2-Llama3-76B": 80,
+        }[model_name]
+    # Since the first GPU will be used for ViT, treat it as half a GPU.
+    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+    num_layers_per_gpu = [num_layers_per_gpu] * world_size
+    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+    layer_cnt = 0
+    for i, num_layer in enumerate(num_layers_per_gpu):
+        for j in range(num_layer):
+            device_map[f"language_model.model.layers.{layer_cnt}"] = i
+            layer_cnt += 1
+    device_map["vision_model"] = 0
+    device_map["mlp1"] = 0
+    device_map["language_model.model.tok_embeddings"] = 0
+    device_map["language_model.model.embed_tokens"] = 0
+    device_map["language_model.output"] = 0
+    device_map["language_model.model.norm"] = 0
+    device_map["language_model.lm_head"] = 0
+    device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+
+    return device_map
+
+
 @register_model("internvl2")
 class InternVL2(lmms):
     def __init__(
```
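To make the sharding concrete, here is a minimal standalone sketch of the same arithmetic (not the committed code): `world_size` is pinned to a hypothetical 4 GPUs instead of calling `torch.cuda.device_count()`, using the 32-layer InternVL2-8B entry from the table above.

```python
import math

# Sketch of split_model's arithmetic with world_size pinned to 4.
world_size = 4
num_layers = 32  # InternVL2-8B

# GPU 0 also hosts the ViT and embeddings, so it only gets half a share.
per_gpu = math.ceil(num_layers / (world_size - 0.5))  # ceil(32 / 3.5) = 10
shares = [per_gpu] * world_size                       # [10, 10, 10, 10]
shares[0] = math.ceil(shares[0] * 0.5)                # GPU 0 -> 5

device_map, layer_cnt = {}, 0
for gpu, n in enumerate(shares):
    for _ in range(n):
        device_map[f"language_model.model.layers.{layer_cnt}"] = gpu
        layer_cnt += 1

# shares sums to 35 > 32, so keys for layers 32..34 are created even
# though the model has none; they are harmless since no module matches
# them. Pinning the last real layer back to GPU 0 places it with the
# embeddings and lm_head, which is the fix this commit is after.
device_map[f"language_model.model.layers.{num_layers - 1}"] = 0

print(device_map["language_model.model.layers.0"],   # 0
      device_map["language_model.model.layers.31"])  # 0
```

Layers 0-4 land on GPU 0, 5-14 on GPU 1, 15-24 on GPU 2, 25-31 on GPU 3, and layer 31 is then re-pinned to GPU 0 so the first and last decoder layers share a device.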
```diff
@@ -134,13 +177,14 @@ def __init__(
         device: str = "cuda:0",
         device_map: str = "cuda:0",
         batch_size: str = "1",
+        num_frame: int = 32,
+        num_layers=None,
         **kwargs,
     ):
         super().__init__()
 
         self.path = pretrained
-        self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
-        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
+        self.num_frame = num_frame
 
         batch_size = int(batch_size)
         assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
```
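For reference, a hypothetical single-process instantiation exercising both new arguments (the `pretrained` keyword and checkpoint id are illustrative, not taken from this diff):

```python
from lmms_eval.models.internvl2 import InternVL2

# Hypothetical values: device_map="auto" routes through split_model()
# in the next hunk, and num_layers overrides the name-based lookup
# table for checkpoints whose names the table does not list.
model = InternVL2(
    pretrained="OpenGVLab/InternVL2-8B",
    device_map="auto",
    num_frame=16,   # video frames to sample (default 32)
    num_layers=32,  # optional; inferred from the model name when None
)
```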
```diff
@@ -154,11 +198,15 @@ def __init__(
             self.device_map = f"cuda:{accelerator.local_process_index}"
         elif accelerator.num_processes == 1 and device_map == "auto":
             self._device = torch.device(device)
+            device_map = split_model(pretrained.split("/")[-1], num_layers=num_layers)
             self.device_map = device_map
         else:
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
             self.device_map = f"cuda:{accelerator.local_process_index}"
 
+        self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
+        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, device_map=device_map)
+
         if accelerator.num_processes > 1:
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
             # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
```
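Note that the model load now happens only after `device_map` has been resolved, so `from_pretrained` receives either the plain `"cuda:N"` string or the per-module dict from `split_model`. A quick sanity-check sketch of the invariant the code comment promises, assuming at least two visible GPUs:

```python
import torch
from lmms_eval.models.internvl2 import split_model

if torch.cuda.device_count() >= 2:
    dm = split_model("InternVL2-8B")  # 32 decoder layers
    # The first and last decoder layers, the final norm, and the head
    # must all share GPU 0, otherwise generation can fail with tensors
    # on different devices.
    for key in (
        "language_model.model.layers.0",
        "language_model.model.layers.31",
        "language_model.model.norm",
        "language_model.lm_head",
    ):
        assert dm[key] == 0, key
```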
```diff
@@ -269,7 +317,7 @@ def generate_until(self, requests) -> List[str]:
             elif self.modality == "video":
                 assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
                 video_path = visuals[0]
-                pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+                pixel_values, num_patches_list = load_video(video_path, num_segments=self.num_frame)
                 pixel_values = pixel_values.to(torch.bfloat16).cuda()
                 video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
                 question = video_prefix + contexts
```
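The hard-coded `num_segments=8` becomes `self.num_frame`; dropping `max_num=1` is a no-op since it matches `load_video`'s default (visible in the first hunk header). Each sampled frame contributes one `<image>` placeholder to the prompt, e.g. for a hypothetical 4-frame run:

```python
# Hypothetical: 4 sampled frames, one patch each, mirroring the
# video_prefix construction in generate_until above.
num_patches_list = [1, 1, 1, 1]
video_prefix = "".join(f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list)))
print(video_prefix + "Describe the video.")
# Frame1: <image>
# Frame2: <image>
# Frame3: <image>
# Frame4: <image>
# Describe the video.
```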
