Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge LoCo with Zero++ #6730

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open

Conversation

XingyuXie
Copy link

Integration of LoCo Method into ZeRO++

Overview

This PR introduces the integration of the LoCo method, as outlined in this paper, into the ZeRO++ framework of DeepSpeed. The key enhancement involves applying error feedback compensation to 4-bit gradients before communication. This approach improves pre-training loss outcomes without additional time overhead, though it requires extra GPU memory. The extent of this memory increase depends on model size and training configuration.

Experimental Results

We conducted pre-training experiments using the Llama2 architecture, adjusting the number of layers and hidden size. The experiments included:

  • A smaller-scale model with 0.8B parameters trained on 30B tokens.
  • A larger-scale model with 8B parameters trained on 5B tokens.

The training data was sampled from Redpajama-V2.

Findings:

  • Smaller Models (0.8B parameters): Significant gains were observed when applying the LoCo method.
  • Larger Models (8B parameters): The gains were present but less pronounced. This could be due to:
    1. Relatively smaller data volume.
    2. Lower pre-training loss for larger models, making significant improvements harder to achieve.

However, even a smaller pre-training loss gap in larger models can translate to meaningful gains in downstream tasks.

Example Script

For reference, the run.sh script used for the 8B parameter, 5B tokens experiment is attached. The experiment was conducted using the DeepSpeed-Megatron platform.

Acknowledgments

Special thanks to cc @GuanhuaWang for ongoing communication and guidance throughout this work.


We appreciate your consideration of this PR and welcome any feedback or questions!

@XingyuXie
Copy link
Author

@XingyuXie please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@loadams loadams requested review from GuanhuaWang and removed request for awan-10 November 12, 2024 14:48
Copy link
Member

@GuanhuaWang GuanhuaWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XingyuXie thx for this effort.

Overall looks good to me. Just left a few comments

@@ -301,6 +302,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Boolean indicating whether to use quantized zero gradients
for efficient all_2_all_reduce comm
"""
zeropp_loco_param: Optional[Dict[str, Any]] = None
"""
loco-zero++ params
Copy link
Member

@GuanhuaWang GuanhuaWang Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe having a brief introduction of two params loco needed and what are prefered default values

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the advice, we already add the basic description here in the update.

@@ -1379,7 +1382,11 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]
global_world_size = dist.get_world_size()
num_nodes = global_world_size // local_world_size
if self.all2all_process_group is not None and num_nodes > 1:
grad_partitions_for_rank = all_to_all_quant_reduce(full_grads_for_rank, self.all2all_process_group)
grad_partitions_for_rank = (
all_to_all_loco_quant_reduce(params_to_reduce, self.all2all_process_group, self.zeropp_loco_param)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why loco_quant has params_to_reduce as first argument, not full_grads_for_rank

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. This is because we want to use param.xxx within the all_to_all_loco_quant_reduce function to store, access, and modify the error feedback buffer corresponding to each parameter. For example, in the function, we use p.inter_ef_buf[0]. Since full_grads_for_rank is a list of tensors, assigning or maintaining attributes directly to it is not convenient.

  2. Another reason is to minimize modifications to stage3.py. By restricting changes to our own functions, we reduce the risk of introducing unintended bugs.

csrc/includes/quantization_utils.h Outdated Show resolved Hide resolved
Comment on lines 238 to 239
__half2 local_buffer[totalChunks * quantize::h_per_load];
__half2 err_buffer[totalChunks * quantize::h_per_load];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You use both quantize::h_per_load and quantize::h2_per_load in this kernel code. Can you double check to make sure that both of them were used correctly? This one doesn't match the code line 245 and 246.

Copy link
Author

@XingyuXie XingyuXie Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. And test on Llama2-8B. The loss curve seems okay.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants