
[BUG] [Fix-Suggested] ZeRO Stage 3 Overwrites Module ID Attribute Causing Incorrect Expert Placement on GPUs #6772

Open
traincheck-team opened this issue Nov 20, 2024 · 0 comments
Labels
bug Something isn't working training


Description

We observed incorrect GPU placement when running MoE with ZeRO Stage 3. We use module.id to control which expert is loaded onto which GPU for fine-grained control, and we found that module.id is corrupted after deepspeed.initialize.

Suspected Root Cause

DeepSpeed's ZeRO Stage 3 sets and uses module.id to manage its internal state, as seen in runtime/zero/parameter_offload.py:L271.

This practice is brittle because:

  1. id is an overly generic attribute name that can easily collide with user-defined attributes.
  2. There is no check on the .id attribute before it is set, so an existing value is silently overwritten, causing hard-to-diagnose problems.

In the specific bug we encountered (bug.py below), each expert module is identified by its .id attribute, but during initialization the .id is overwritten by the _register_hooks_recursively function in deepspeed/runtime/zero/stage3.py, scrambling the expert-to-GPU placement.

To reproduce

The overwrite of the .id attribute by ZeRO Stage 3 can be reproduced as follows:

  1. Install deepspeed 0.15.4

  2. Run bug.py with deepspeed --num_gpus=2 bug.py (the --num_gpus value does not matter; use 1 if you do not have a multi-GPU node).

import torch
import deepspeed
from torch.nn import Module, Linear

# Define a simple expert module
class Expert(Module):
    def __init__(self, id):
        super().__init__()
        self.id = id  # ID for custom GPU placement
        self.fc = Linear(128, 128)

    def forward(self, x):
        return self.fc(x)

# Create a model with 60 experts
class MoEModel(Module):
    def __init__(self):
        super().__init__()
        self.experts = torch.nn.ModuleList([Expert(i) for i in range(60)])
    def forward(self, x, expert_id):
        return self.experts[expert_id](x)

# Helper function to log expert ids
def log_expert_ids(model, rank):
    loaded_experts = [e.id for e in model.experts]
    print(f"[rank {rank}] expert ids: {loaded_experts}")

def main():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()

    # Create model
    model = MoEModel()
    log_expert_ids(model, rank)  # prints 0, 1, 2, ..., 59

    # Configure DeepSpeed
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=3e-5),
        config={
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {"stage": 3,}
        }
    )

    # print model ids again after deepspeed.initialize
    log_expert_ids(model, rank)  # prints 2, 4, 6, ..., 120

    # if you call deepspeed.initialize again here, the ids become completely scrambled.

    dummy_input = torch.randn(1, 128).cuda(rank)
    for expert_id in range(60):
        model_engine(dummy_input, expert_id=expert_id)

if __name__ == "__main__":
    main()
  3. The script prints the expert ids twice, once before deepspeed.initialize and once after. Observe that the first print gives 0, 1, 2, ..., 59, while the second gives 2, 4, 6, 8, ..., 120.

In ZeRO Stage 3's _register_hooks_recursively, module.id is set to a value based on a counter (my_count), which conflicts with the user-defined .id attributes used for expert placement.
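The counter behavior can be reproduced without DeepSpeed or a GPU. The sketch below is framework-free: Node and register_recursively are hypothetical stand-ins, and it assumes (not confirmed in this report) that the counter is a mutable default list that persists across calls. Under that assumption, a pre-order walk over a tree shaped like MoEModel yields exactly the 2, 4, ..., 120 ids seen above, and a second walk shifts them further, matching the repeated-initialize symptom.

```python
class Node:
    """Hypothetical stand-in for nn.Module; only the child tree matters here."""

    def __init__(self, *children, user_id=None):
        self.children = list(children)
        if user_id is not None:
            self.id = user_id  # user-defined placement id, like Expert.id in bug.py


def register_recursively(node, count=[0]):
    # Assumed pattern: a mutable default counter shared across calls,
    # plus an unconditional write to `.id` on every node (pre-order walk).
    node.id = count[0]
    count[0] += 1
    for child in node.children:
        register_recursively(child, count=count)


# Same shape as MoEModel: root -> ModuleList -> 60 x (Expert -> fc)
experts = [Node(Node(), user_id=i) for i in range(60)]
model = Node(Node(*experts))

before = [e.id for e in experts]
register_recursively(model)  # first "initialize"
after_first = [e.id for e in experts]
register_recursively(model)  # second "initialize": the counter keeps going
after_second = [e.id for e in experts]

print(before[:4], after_first[:4], after_first[-1], after_second[:4])
# [0, 1, 2, 3] [2, 4, 6, 8] 120 [124, 126, 128, 130]
```

The root and the ModuleList consume ids 0 and 1, and each Expert's fc child consumes the odd id after its parent, which is why only even ids land on the experts.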

Bug Significance

This bug can significantly affect model behavior when expert modules are incorrectly placed across GPUs, leading to incorrect training outcomes or potential crashes. Ensuring that internal DeepSpeed modifications do not overwrite user-defined attributes is crucial for stability and expected functionality.
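To make the placement impact concrete, here is a sketch that assumes a hypothetical round-robin rule (expert id modulo world size; the report does not state the exact rule used). With 2 GPUs, the original ids split the experts evenly, while the corrupted ids 2, 4, ..., 120 are all even and send every expert to rank 0.

```python
from collections import Counter


def place_expert(expert_id, world_size):
    # Hypothetical placement rule: round-robin experts over ranks by id.
    return expert_id % world_size


world_size = 2
intended_ids = list(range(60))                   # ids before deepspeed.initialize
corrupted_ids = [2 + 2 * i for i in range(60)]   # ids observed after (2, 4, ..., 120)

intended_counts = Counter(place_expert(i, world_size) for i in intended_ids)
corrupted_counts = Counter(place_expert(i, world_size) for i in corrupted_ids)

print(intended_counts)   # Counter({0: 30, 1: 30})
print(corrupted_counts)  # Counter({0: 60})
```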

Even if user-side conflicts are out of scope, DeepSpeed itself can accidentally modify these attributes. For example, the same problem can be reproduced by calling deepspeed.initialize multiple times.

Thus, we argue for three fixes / engineering practices for this issue.

Expected Behavior / Suggested Fix

  1. Use a Specific Attribute for Internal IDs: Instead of overwriting .id, use a more specific attribute name such as _deepspeed_id to avoid conflicts with user-defined attributes.
  2. Restrict Attribute Modification: Modify the __setattr__ method to only allow setting fields that have not been previously set, preventing unintentional overwrites of user-defined attributes.
  3. Forbid Duplicate deepspeed.initialize Calls: We have observed many issues caused by accidental duplicate calls to deepspeed.initialize, so we suggest forbidding duplicate calls by recording the models / optimizers that have already been initialized, as mentioned in [BUG] [Fix-Suggested] KeyError in stage_1_and_2.py Due to Optimizer-Model Parameter Mismatch #6770.
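A minimal sketch of how the three suggestions could fit together, using plain Python rather than DeepSpeed internals. FakeModule, set_internal_id, assign_internal_ids and guarded_initialize are hypothetical names for illustration only; _deepspeed_id is the attribute name proposed above, not an existing DeepSpeed API. The helpers only rely on hasattr, setattr and children(), so the same idea should carry over to torch.nn.Module objects.

```python
import itertools
import weakref

INTERNAL_ID_ATTR = "_deepspeed_id"  # name proposed in fix 1; hypothetical here


class FakeModule:
    """Minimal stand-in for torch.nn.Module (assumption: only children() matters)."""

    def __init__(self, *children, id=None):
        self._children = list(children)
        if id is not None:
            self.id = id  # user-defined attribute that must survive

    def children(self):
        return iter(self._children)


def set_internal_id(module, value):
    # Fixes 1 and 2: namespaced attribute, and a write-once check instead of a blind setattr.
    if hasattr(module, INTERNAL_ID_ATTR):
        raise RuntimeError(f"refusing to overwrite {INTERNAL_ID_ATTR} on {type(module).__name__}")
    setattr(module, INTERNAL_ID_ATTR, value)


def assign_internal_ids(root):
    # A fresh counter per call, rather than a mutable default shared across calls.
    counter = itertools.count()

    def visit(module):
        set_internal_id(module, next(counter))
        for child in module.children():
            visit(child)

    visit(root)


_initialized_models = weakref.WeakSet()  # does not keep models alive


def guarded_initialize(model):
    # Fix 3: remember which models were initialized and reject repeats.
    if model in _initialized_models:
        raise RuntimeError("this model was already passed to initialize()")
    assign_internal_ids(model)
    _initialized_models.add(model)
    return model


experts = [FakeModule(FakeModule(), id=i) for i in range(60)]
model = guarded_initialize(FakeModule(FakeModule(*experts)))
print([e.id for e in experts][:4])             # [0, 1, 2, 3]: user ids untouched
print([e._deepspeed_id for e in experts][:4])  # [2, 4, 6, 8]

second_call_rejected = False
try:
    guarded_initialize(model)
except RuntimeError:
    second_call_rejected = True
print(second_call_rejected)  # True
```

A WeakSet is used so that recording initialized models does not prevent them from being garbage-collected; a real implementation inside DeepSpeed might key on the engine or optimizer instead.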

ds_report output

collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/xxx/python3.10/site-packages/torch']
torch version .................... 2.2.2+cu121
deepspeed install path ........... ['/home/xxx/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.15.4, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 2.2, cuda 12.1
shared memory (/dev/shm) size .... 31.24 GB

I would be more than happy to contribute the suggested fixes; let me know what you think!
