@@ -33,8 +33,10 @@
 FSDP2_SUPPORTED = False
 
 try:
+    import torch.distributed.checkpoint as dcp
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
+        get_model_state_dict,
         set_model_state_dict,
     )
 
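For context, the guarded import above feeds a capability flag that the loading code further down checks (`if DCP_SUPPORTED:`). A minimal sketch of that pattern; the `except` branch is outside this hunk, so its exact shape here is an assumption:

DCP_SUPPORTED = False

try:
    import torch.distributed.checkpoint as dcp  # noqa: F401
    from torch.distributed.checkpoint.state_dict import (  # noqa: F401
        StateDictOptions,
        get_model_state_dict,
        set_model_state_dict,
    )

    DCP_SUPPORTED = True
except ImportError:
    # Older PyTorch without the DCP state-dict helpers: leave the flag False
    # so the loading path below can fail with a clear error instead.
    pass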
@@ -163,8 +165,27 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
         )
     fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
     fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
+
+    if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+        assert gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT) == gpc.get_world_size(ParallelMode.GLOBAL)
 
     if fsdp_mode == "v1":
+        ignored_mod = []
+        if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+            for layer_id, layer in enumerate(model.model.layers):
+                if layer_id >= gpc.config.model.first_k_dense_replace:
+                    layer.feed_forward.moe_layer.experts = FSDP(
+                        layer.feed_forward.moe_layer.experts,
+                        process_group=gpc.get_group(ParallelMode.EXPERT_DATA),
+                        sharding_strategy=ShardingStrategy.FULL_SHARD,
+                        sync_module_states=fsdp_init_method != "cuda",  # sync model parameters
+                        forward_prefetch=True,
+                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                        limit_all_gathers=True,
+                        use_orig_params=True,
+                        device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+                    )
+                    ignored_mod.append(layer.feed_forward.moe_layer.experts)
         model = FSDP(
             module=model,
             process_group=gpc.get_group(ParallelMode.GLOBAL),
@@ -176,6 +197,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             limit_all_gathers=True,
             use_orig_params=True,
             device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+            ignored_modules=ignored_mod,
         )
         # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
         # This hack is needed due to FSDP v1 lazy initialization in model construction.
@@ -196,7 +218,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     else:
         raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")
 
-    if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False):
+    if not gpc.config.ckpt.get("auto_resume", False):
         load_ckpt_info = gpc.config.ckpt.load_ckpt_info
         load_ckpt_path = load_ckpt_info.get("path", None)
         load_ckpt_content = load_ckpt_info.get("content", [])
@@ -205,19 +227,25 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             "model",
         ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
         if DCP_SUPPORTED:
-            hf = gpc.config.hf
-            mod = LazyObject(hf.mod, hf.mod_cls)
-            mod = mod.build()
-            state_dict = mod.from_pretrained(
-                pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
-            ).state_dict()
-            state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
-            set_model_state_dict(
-                model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
-            )
+            if is_using_hf():
+                hf = gpc.config.hf
+                mod = LazyObject(hf.mod, hf.mod_cls)
+                mod = mod.build()
+                state_dict = mod.from_pretrained(
+                    pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
+                ).state_dict()
+                state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
+                set_model_state_dict(
+                    model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
+                )
+            else:
+                state_dict = get_model_state_dict(model=model)
+                state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
+                dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path)
+                set_model_state_dict(model=model, model_state_dict=state_dict)
             del state_dict
             internlm_accelerator.empty_cache()
         else:
             raise RuntimeError("DCP is not supported in this version of PyTorch.")
 
-    return model
+    return model