I'm experiencing an issue with DeepSpeed's universal checkpointing. After converting my DeepSpeed checkpoint to a universal checkpoint using ds_to_universal.py, resuming training from the converted checkpoint makes the model behave as if it had been initialized from scratch: the training loss is significantly higher, similar to starting training from the beginning.
Environment:
DeepSpeed version: 0.15.0
PyTorch version: 2.4.1
Python version: 3.9
Model: Llama
ZeRO Optimization Stage: 3
Number of GPUs: 16
Distributed Backend: NCCL
Steps to Reproduce:
Train a model using DeepSpeed with ZeRO Stage 3 optimization and save checkpoints.
Use ds_to_universal.py to convert the DeepSpeed checkpoint to a universal checkpoint (conversion log below; a sketch of the invocation follows the log):
[2024-10-30 11:42:34,345] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
args = Namespace(input_folder='.global/', output_folder='global_universal', num_extract_workers=1, num_merge_workers=1, keep_temp_folder=False, strict=True, inject_missing_state=False)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in ./global_universal
ds_to_universal.py:449: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
ds_to_universal.py:407: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
*** 1. Extracting ZeRO fragments
0%| | 0/16 [00:00<?, ?it/s] ds_to_universal.py:153: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:46<00:00, 6.68s/it]
*** 2. Merging slices .....
0%| | 0/201 [00:00<?, ?it/s] ds_to_universal.py:217: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
shards = [torch.load(p) for p in paths]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [01:34<00:00, 2.13it/s]
*** 3. Saving common optimizer states
ds_to_universal.py:423: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
*** Done!
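For reference, a rough sketch of how the conversion step was invoked, reconstructed from the Namespace printed in the log above (flag spellings and paths are approximate, not copied verbatim from my shell history):

```python
# Rough reconstruction of the conversion invocation; flag names and paths are
# taken from the printed Namespace and may not match the real command exactly.
import subprocess

subprocess.run(
    [
        "python", "ds_to_universal.py",
        "--input_folder", ".global/",           # folder holding the ZeRO-3 checkpoint
        "--output_folder", "global_universal",  # destination for the universal checkpoint
        "--num_extract_workers", "1",
        "--num_merge_workers", "1",
    ],
    check=True,  # fail loudly if the conversion script exits non-zero
)
```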
Modify the training script to resume training from the universal checkpoint:
Set load_universal to True in the DeepSpeed config.
Load the checkpoint using model.load_checkpoint().
Resume training using the converted universal checkpoint; a minimal sketch of the resume path is shown below.
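A minimal sketch of the resume path (not my exact training script; the tiny model and config are placeholders to keep the example self-contained, and it is assumed to run under the deepspeed launcher):

```python
# Minimal resume sketch, assuming a standard DeepSpeed setup with ZeRO-3.
import torch
import deepspeed

model = torch.nn.Linear(16, 16)  # placeholder model, stands in for my Llama setup

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    # Tell DeepSpeed the checkpoint being loaded is a universal checkpoint.
    "checkpoint": {"load_universal": True},
}

model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# Load the converted universal checkpoint (directory/tag are illustrative).
load_path, client_state = model_engine.load_checkpoint(
    "global_universal",
    load_optimizer_states=True,
    load_lr_scheduler_states=True,
)
assert load_path is not None, "no checkpoint was loaded"
```

As I understand the docs, the key piece is the checkpoint.load_universal flag; everything else is standard deepspeed.initialize() plus load_checkpoint().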
Expected Behavior:
The model should resume training from the checkpointed state, with training loss consistent with the point at which training was paused.
Actual Behavior:
Upon resuming training, the training loss is significantly higher, as if the model weights and optimizer states were not correctly restored.
It appears that the model is starting from random initialization rather than the checkpointed state.
Additional Details:
Optimizer State Loading:
After loading the checkpoint, the optimizer states do appear to be updated; a sketch of how I checked this follows.
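A sketch of the check, assuming the engine's wrapped ZeRO-3 optimizer exposes state_dict() (the exact key layout may differ):

```python
# model_engine is the engine returned by deepspeed.initialize(), after
# load_checkpoint() has run; inspect the wrapped optimizer's state.
opt_sd = model_engine.optimizer.state_dict()
print(type(model_engine.optimizer).__name__)  # the ZeRO-3 optimizer wrapper class
print(list(opt_sd.keys()))                    # top-level keys of the restored state
```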
Checkpoint Conversion:
The conversion script ds_to_universal.py runs without errors.
The generated universal checkpoint seems to have the correct structure.
Resuming from Original Checkpoint:
When resuming training from the original DeepSpeed checkpoint (before conversion), the training loss is consistent, indicating the issue arises after conversion to the universal checkpoint.
Attempts to Resolve:
Ensured that load_universal is set to True in the DeepSpeed config.
Captured the optimizer and lr_scheduler returned by deepspeed.initialize().
Monitoring:
Checked learning rate and global steps before and after loading the checkpoint.
Observed that the learning rate remains the same, and global steps do not reflect the checkpointed state.
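The check looked roughly like this (attribute names follow the public DeepSpeedEngine API as I understand it, not verbatim code from my script):

```python
def report(engine, label):
    # Read the step counter and current learning rate(s) from the engine.
    print(f"[{label}] global_steps={engine.global_steps} lr={engine.get_lr()}")

report(model_engine, "before load")               # expected: step 0, base lr
model_engine.load_checkpoint("global_universal")  # path/tag illustrative
report(model_engine, "after load")                # expected: checkpointed step/lr; unchanged in my runs
```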
Testing:
Reduced the number of extract and merge workers to 1 when running ds_to_universal.py to rule out parallelism issues.
Ensured version consistency of DeepSpeed and PyTorch across training, conversion, and resumption.
Questions:
All of the provided examples already have a relatively high loss; the ones in the repo continue training from a loss of around 7. Is there an example of continuing training from a checkpoint that has been trained well for a long time / on a large number of tokens?
Is there a known issue where ds_to_universal.py does not correctly convert optimizer states for ZeRO Stage 3, causing training to continue with a higher loss?
Any help or guidance would be greatly appreciated.