Support Pipeline Parallel in Knowledge Distillation #11531

Open
zhaoyang-star opened this issue Dec 10, 2024 · 5 comments
@zhaoyang-star

As the NeMo docs state, the knowledge distillation (KD) implementation in NeMo has the following limitations:

  • Only Megatron Core-based GPT models are supported
  • Only logit-pair distillation is supported for now
  • Pipeline parallelism not yet supported
  • FSDP strategy not yet supported

The third and fourth points are very important for KD, since the teacher model is much larger than the student, and TP alone may deliver lower training throughput than PP. Will these two features be on the roadmap? Thanks. @AAnoosheh @ko3n1g
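
For context, "logit-pair distillation" here means matching the student's output logits against the teacher's. Below is a minimal, framework-agnostic sketch in plain PyTorch; the temperature scaling and the function name are illustrative assumptions, not NeMo's or Model Optimizer's actual implementation:

    import torch.nn.functional as F

    def logit_distillation_loss(student_logits, teacher_logits, temperature=1.0):
        """KL divergence between temperature-softened teacher and student distributions."""
        s = F.log_softmax(student_logits / temperature, dim=-1)
        t = F.softmax(teacher_logits / temperature, dim=-1)
        # "batchmean" matches the mathematical definition of KL divergence.
        return F.kl_div(s, t, reduction="batchmean") * temperature**2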

@AAnoosheh
Collaborator

AAnoosheh commented Dec 12, 2024

Hi @zhaoyang-star,

PP will be available soon, first in Megatron-LM and then in NeMo immediately afterwards.

PyTorch FSDP is supported for Model Optimizer-based distillation in general, though the NeMo FSDP strategy has not been looked into yet and I cannot provide an ETA right now.

Thank you for your interest!

@zhaoyang-star
Author

@AAnoosheh Thanks for your great work! I recently used distillation in NeMo and it really helped me a lot.
Just one more question: how can I record kd_loss in TensorBoard? reduced_train_loss is already recorded there, but I tried printing kd_loss in /usr/local/lib/python3.10/dist-packages/modelopt/torch/distill/loss_balancers.py and got no output:

        # Weighted sum of the individual KD losses.
        aggregate_loss = sum(
            loss * weight for loss, weight in zip(kd_loss_dict.values(), self._kd_loss_weight)
        )
        kd_loss = aggregate_loss
        if output_loss is not None:
            # Blend in the original (cross-entropy) student loss with the remaining weight.
            aggregate_loss += (1.0 - sum(self._kd_loss_weight)) * output_loss
        # Tried three ways of printing; none of them showed up in the training logs.
        print(f"total_loss: {aggregate_loss}, ce_loss: {output_loss}, kd_loss: {kd_loss}", flush=True)
        print_rank_0(f"total_loss: {aggregate_loss}, ce_loss: {output_loss}, kd_loss: {kd_loss}")
        logging.info(f"total_loss: {aggregate_loss}, ce_loss: {output_loss}, kd_loss: {kd_loss}")

        return aggregate_loss
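
For what it's worth, here is a minimal sketch of how those three values could be written to TensorBoard directly with torch.utils.tensorboard, independent of NeMo's own logger; the log directory, the module-level step counter, and the helper name log_kd_scalars are illustrative assumptions:

    import torch.distributed as dist
    from torch.utils.tensorboard import SummaryWriter

    _writer = None  # lazily created on first call
    _step = 0

    def log_kd_scalars(total_loss, ce_loss, kd_loss):
        """Write the distillation losses to TensorBoard from rank 0 only."""
        global _writer, _step
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return
        if _writer is None:
            _writer = SummaryWriter(log_dir="tb_kd")  # hypothetical output directory
        _writer.add_scalar("loss/total", float(total_loss), _step)
        if ce_loss is not None:
            _writer.add_scalar("loss/ce", float(ce_loss), _step)
        _writer.add_scalar("loss/kd", float(kd_loss), _step)
        _step += 1

Calling log_kd_scalars(aggregate_loss, output_loss, kd_loss) just before the return above would make the curves visible alongside reduced_train_loss, though hooking into NeMo's PyTorch Lightning logger (e.g. self.log in the training step) would likely be the more idiomatic route.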

@AAnoosheh
Collaborator

Hi, I need to look into that and see how NeMo's logging works, thanks for letting me know.

@zhaoyang-star
Author

zhaoyang-star commented Dec 17, 2024

> Hi, I need to look into that and see how NeMo's logging works, thanks for letting me know.

@AAnoosheh Thanks again for the quick reply. There is very little information about this online. Recently I tried to use the NeMo Distillation API to train an 8B model and ran into two problems. The main one is that training throughput is very low: roughly 150 TFLOPs per GPU with TP enabled, whereas pretraining the same model from scratch reaches about 490 TFLOPs per GPU with PP enabled. I suspect this is because the NeMo Distillation API only supports TP. The machine I used is an H800, whose intra-node bandwidth is half that of the H100.

Looking forward to your feedback.

@SeanLi-OI

SeanLi-OI commented Dec 23, 2024

> Hi @zhaoyang-star,
>
> PP will be available soon, first in Megatron-LM and then in NeMo immediately afterwards.
>
> PyTorch FSDP is supported for Model Optimizer-based distillation in general, though the NeMo FSDP strategy has not been looked into yet and I cannot provide an ETA right now.
>
> Thank you for your interest!

@AAnoosheh Hi, may I inquire if there is a roadmap or ETA for Knowledge Distillation in Megatron-LM? I've noticed that there is currently no support for KD in Megatron-LM.
