
[BUG] [Fix Suggestion] Uneven head sequence parallelism #6774

Open
Eugene29 opened this issue Nov 21, 2024 · 0 comments
Labels: bug, training

Describe the bug

deepspeed 0.15.4 will think you are using uneven-head SP even when you aren't, and raise the following assert:

```
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
```

This happens because by the second all2all the head dimension is already parallelized across SP ranks, so `num_heads % seq_world_size != 0` can return true. The second all2all input has shape [B, s, hc/sp, hs], and hc/sp is not always divisible by sp.

The relevant check in `single_all_to_all` (deepspeed 0.15.4):

```python
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
```
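As a concrete illustration (the numbers here are hypothetical), take a model whose total head count equals the sequence-parallel size. The first all2all passes the divisibility check, but the second one sees a single head per rank, gets routed onto the uneven-head path by mistake, and the assert above then fails:

```python
# Hypothetical numbers: total head count equal to the sequence-parallel size.
hc, sp = 8, 8

# First all2all (pre-attention): dim 2 still holds all hc heads.
print(hc % sp != 0)          # False -> even path, as expected

# Second all2all (post-attention): dim 2 already holds hc // sp heads per rank.
num_heads = hc // sp         # == 1
print(num_heads % sp != 0)   # True  -> uneven-head path taken by mistake
print(num_heads > sp)        # False -> the assert above fails
```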

To Reproduce

To reproduce the error, set the sequence-parallel size equal to the head count (SP = head_count).

Fix Suggestion:
Adjust num_heads so it reflects the total head count before the divisibility check. When scatter_idx < 2 (the second, post-attention all2all), dim 2 holds hc/sp heads, so scale it back up by the SP world size:

```python
num_heads = input.shape[2]
if scatter_idx < 2:
    num_heads = seq_world_size * num_heads
```
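For clarity, a sketch of where the adjustment would sit in single_all_to_all (only the head-count computation changes; the uneven-head branch quoted above stays as in deepspeed 0.15.4):

```python
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    num_heads = input.shape[2]
    # scatter_idx < 2 marks the post-attention all2all: dim 2 already holds
    # hc/sp heads, so scale back up to the total head count before the checks.
    if scatter_idx < 2:
        num_heads = seq_world_size * num_heads

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # ... unchanged uneven-head branch from the snippet above ...
        ...
```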
Eugene29 added the bug and training labels on Nov 21, 2024