
Commit 5da0471
support ep for fsdp
1 parent f48c4af

4 files changed: +60 additions, -17 deletions

internlm/checkpoint/checkpoint_manager.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -582,7 +582,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
                 f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                 f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
             )
-        elif is_using_fsdp() and is_using_hf() and not self.auto_resume:
+        elif is_using_fsdp() and not self.auto_resume:
             pass
         else:
             load_path = self.load_ckpt_info["path"]
```
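The effect of this one-line change: with expert parallelism now supported, any FSDP run with auto_resume disabled skips the resume path, not only HuggingFace-format models, because wrap_FSDP_model (see internlm/core/fsdp.py below) already loads initial weights at wrap time. A hypothetical condensation of the branch, assuming only the helpers shown above:

```python
# should_skip_resume is illustrative, not a real function in the repo;
# is_using_fsdp is the real helper used in the diff above.
def should_skip_resume(auto_resume: bool) -> bool:
    # Under FSDP with auto_resume off, initial weights were already loaded
    # when the model was wrapped (from HF safetensors or a DCP checkpoint),
    # so the checkpoint manager has nothing left to restore here.
    return is_using_fsdp() and not auto_resume
```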

internlm/core/fsdp.py

Lines changed: 40 additions & 12 deletions

```diff
@@ -33,8 +33,10 @@
 FSDP2_SUPPORTED = False
 
 try:
+    import torch.distributed.checkpoint as dcp
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
+        get_model_state_dict,
         set_model_state_dict,
     )
```

```diff
@@ -163,8 +165,27 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     )
     fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
     fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
+
+    if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+        assert gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT) == gpc.get_world_size(ParallelMode.GLOBAL)
 
     if fsdp_mode == "v1":
+        ignored_mod = []
+        if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+            for layer_id, layer in enumerate(model.model.layers):
+                if layer_id >= gpc.config.model.first_k_dense_replace:
+                    layer.feed_forward.moe_layer.experts = FSDP(
+                        layer.feed_forward.moe_layer.experts,
+                        process_group=gpc.get_group(ParallelMode.EXPERT_DATA),
+                        sharding_strategy=ShardingStrategy.FULL_SHARD,
+                        sync_module_states=fsdp_init_method != "cuda",  # sync model parameters
+                        forward_prefetch=True,
+                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                        limit_all_gathers=True,
+                        use_orig_params=True,
+                        device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+                    )
+                    ignored_mod.append(layer.feed_forward.moe_layer.experts)
         model = FSDP(
             module=model,
             process_group=gpc.get_group(ParallelMode.GLOBAL),
```
```diff
@@ -176,6 +197,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             limit_all_gathers=True,
             use_orig_params=True,
             device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+            ignored_modules=ignored_mod,
         )
         # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
         # This hack is needed due to FSDP v1 lazy initialization in model construction.
```
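To see the wrapping pattern in isolation: each MoE layer's experts get their own inner FSDP instance sharded over the EXPERT_DATA group, and the outer global wrap is told to ignore them via ignored_modules. A minimal sketch in plain PyTorch, where expert_data_pg, global_pg, and first_k_dense are hypothetical stand-ins for the gpc lookups in the diff:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def wrap_moe_model(model: nn.Module, expert_data_pg, global_pg, first_k_dense: int) -> FSDP:
    ignored = []
    for layer_id, layer in enumerate(model.model.layers):
        if layer_id >= first_k_dense:  # layers before this index are dense, no experts
            # Each EP rank owns a different slice of experts, so experts are
            # sharded only across their expert-data replicas, not the whole world.
            experts = FSDP(
                layer.feed_forward.moe_layer.experts,
                process_group=expert_data_pg,
                sharding_strategy=ShardingStrategy.FULL_SHARD,
                use_orig_params=True,
            )
            layer.feed_forward.moe_layer.experts = experts
            ignored.append(experts)
    # The outer wrap shards the non-expert parameters over the global group
    # and must skip the expert modules that already carry their own FSDP wrap.
    return FSDP(
        model,
        process_group=global_pg,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
        ignored_modules=ignored,
    )
```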
```diff
@@ -196,7 +218,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     else:
         raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")
 
-    if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False):
+    if not gpc.config.ckpt.get("auto_resume", False):
         load_ckpt_info = gpc.config.ckpt.load_ckpt_info
         load_ckpt_path = load_ckpt_info.get("path", None)
         load_ckpt_content = load_ckpt_info.get("content", [])
```
```diff
@@ -205,19 +227,25 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             "model",
         ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
         if DCP_SUPPORTED:
-            hf = gpc.config.hf
-            mod = LazyObject(hf.mod, hf.mod_cls)
-            mod = mod.build()
-            state_dict = mod.from_pretrained(
-                pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
-            ).state_dict()
-            state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
-            set_model_state_dict(
-                model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
-            )
+            if is_using_hf():
+                hf = gpc.config.hf
+                mod = LazyObject(hf.mod, hf.mod_cls)
+                mod = mod.build()
+                state_dict = mod.from_pretrained(
+                    pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
+                ).state_dict()
+                state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
+                set_model_state_dict(
+                    model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
+                )
+            else:
+                state_dict = get_model_state_dict(model=model)
+                state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
+                dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path)
+                set_model_state_dict(model=model, model_state_dict=state_dict)
             del state_dict
             internlm_accelerator.empty_cache()
         else:
             raise RuntimeError("DCP is not supported in this version of PyTorch.")
 
-    return model
+    return model
```
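The new else branch gives non-HF checkpoints a load path through PyTorch Distributed Checkpoint (DCP): ask the wrapped model for its (sharded) state dict, let dcp.load fill those tensors in place from the checkpoint directory, then write the result back. A stand-alone sketch of that sequence, with load_dcp_checkpoint and its ckpt_dir argument as illustrative names:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

def load_dcp_checkpoint(model, ckpt_dir: str) -> None:
    # Under FSDP this returns a sharded state dict by default, so each
    # rank only materializes its own shards.
    state_dict = get_model_state_dict(model=model)
    # dcp.load resolves which shards each rank needs from the checkpoint
    # and fills the tensors of state_dict in place.
    dcp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)
    set_model_state_dict(model=model, model_state_dict=state_dict)
```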

internlm/initialize/initialize_optimizer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ def split_params_into_different_groups_for_optimizer(
 
     if is_using_fsdp():
         optimizer_mode = ParallelMode.GLOBAL
-        optimizer_mode_expert = ParallelMode.GLOBAL
+        optimizer_mode_expert = ParallelMode.EXPERT_DATA
     else:
         optimizer_mode = ParallelMode.ZERO1
         optimizer_mode_expert = ParallelMode.EXPERT_DATA
```
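The practical effect: under FSDP, expert parameters now reduce over their EXPERT_DATA group instead of the global group, matching the non-FSDP path, since each EP rank holds different experts. A simplified sketch of the grouping; build_param_groups is hypothetical (the real logic lives in split_params_into_different_groups_for_optimizer), and the ParallelMode import path is assumed:

```python
from internlm.core.context import ParallelMode  # assumed import path

def build_param_groups(dense_params, expert_params, is_fsdp: bool):
    # Dense params: reduced over the global world under FSDP,
    # over the ZeRO-1 group otherwise.
    dense_mode = ParallelMode.GLOBAL if is_fsdp else ParallelMode.ZERO1
    return [
        {"params": dense_params, "optimizer_mode": dense_mode},
        # Expert params are only replicated within their expert-data group,
        # so that is the correct scope for their optimizer reductions.
        {"params": expert_params, "optimizer_mode": ParallelMode.EXPERT_DATA},
    ]
```

The "optimizer_mode" key is the one fsdp_optimizer.py reads per param group in the next file.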

internlm/solver/optimizer/fsdp_optimizer.py

Lines changed: 18 additions & 3 deletions

```diff
@@ -16,7 +16,7 @@
     get_norm,
     release_param_grad,
 )
-from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
+from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
 from internlm.utils.config import Config
 from internlm.utils.logger import get_logger
```

```diff
@@ -37,6 +37,7 @@
 def compute_norm(
     gradients: Iterable[torch.Tensor],
     parameters: Iterable[torch.Tensor],
+    zero_mode,
 ) -> float:
     """Get L2 norm
     Arguments:
```
```diff
@@ -61,7 +62,17 @@ def compute_norm(
     if DTENSOR_SUPPORTED and isinstance(total_norm, DTensor):
         total_norm = total_norm.full_tensor()
 
-    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.GLOBAL))
+    if gpc.is_using_parallel_mode(zero_mode):
+        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+    # All-reduce (average) the norm across the expert group, since MoE params are not
+    # synced by the reduction above; the model and zero groups have already been reduced.
+    if zero_mode == ParallelMode.EXPERT_DATA:
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        total_norm = scaled_norm_tensor.item()
 
     if torch.is_tensor(total_norm):
         total_norm = total_norm.item()
```
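Worked through, the expert branch computes an average of per-slice squared norms rather than a global sum: after the EXPERT_DATA all-reduce, each EP rank holds the squared norm of its own expert slice; scaling by 1/EP_world_size and then SUM-reducing over the EXPERT group yields the mean on every rank. A stand-alone sketch with hypothetical process-group handles in place of the gpc lookups:

```python
import torch
import torch.distributed as dist

def reduce_expert_norm(total_norm: torch.Tensor, expert_data_pg, expert_pg, ep_size: int) -> float:
    # Stage 1: sum over the ZeRO-style replicas that share this expert slice.
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=expert_data_pg)
    # Stage 2: scale by 1/ep_size, then SUM over the expert group, i.e. an
    # average. E.g. with ep_size=2 and per-slice squared norms 9 and 16,
    # every rank ends up with (9 + 16) / 2 = 12.5.
    scaled = total_norm / float(ep_size)
    dist.all_reduce(scaled, group=expert_pg)  # default op is SUM
    return scaled.item()
```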
```diff
@@ -112,10 +123,14 @@ def __init__(
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
         self._fp16_param_groups = dict()
         self._fp32_param_tensor_groups = dict()
+        self._broadcast_parallel_mode = []
 
         # init fp16 and fp32 params
         for group_idx, param_group in enumerate(self.optim.param_groups):
             group_params = param_group["params"]
+
+            zero_mode = param_group["optimizer_mode"]
+            self._broadcast_parallel_mode.append(zero_mode)
 
             # fp16 FlatParam storage
             self._fp16_param_groups[group_idx] = group_params
```
```diff
@@ -142,7 +157,7 @@ def _compute_norm_with_fsdp_flatten(self, group_id):
         norm_group = 0
         if len(params) <= 0 or len(gradients) <= 0:
             return norm_group
-        norm_group = compute_norm(gradients=gradients, parameters=params)
+        norm_group = compute_norm(gradients=gradients, parameters=params, zero_mode=self._broadcast_parallel_mode[group_id])
 
         return norm_group
```
