Distributed training is stuck #21217
Comments
The latest jax and jaxlib versions are 0.4.28 — can you try them first? 0.4.23 is pretty old.
On GPU, it might also be worth trying a configuration that has one GPU per process. That may avoid a class of deadlocks in NVIDIA's NCCL library.
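One common way to get one GPU per process is to restrict each process's visible devices before JAX initializes CUDA. The helper below is a hypothetical sketch (the function name and rank argument are assumptions, not part of any JAX API); it simply sets `CUDA_VISIBLE_DEVICES` so NCCL sees exactly one device per rank:

```python
import os

def restrict_to_one_gpu(local_rank):
    """Hypothetical helper: make this process see only one GPU.

    Must run before importing jax (and hence before CUDA is
    initialized), so each rank gets exactly one device.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    return os.environ["CUDA_VISIBLE_DEVICES"]
```

Each launched process would call this with its local rank (0..GPUs-per-host minus 1) before any JAX import.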
@hawkinsp Hi, how can I configure JAX so that each process uses one GPU? On a single host I can train on multiple GPUs with pmap, but multi-host training gets stuck.
How did you launch the job? Are you using a cluster scheduler of some kind? If you're using one that JAX already integrates with (e.g., SLURM), we have code to handle this already, but perhaps you're not using one. Basically you need to do two things: a) […] b) when you call […]. Does that answer the question?
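When launching manually (no scheduler), every process needs a unique, contiguous global `process_id`. A minimal sketch of computing it from a host index and a local rank (the helper name and arguments are hypothetical, just for illustration):

```python
def global_process_id(host_index, local_rank, procs_per_host):
    """Hypothetical helper: derive the globally unique process id.

    Ids must be unique and contiguous in [0, num_processes); process 0
    typically runs on the host that also acts as the coordinator.
    """
    return host_index * procs_per_host + local_rank
```

For example, with 8 processes per host, local rank 2 on the second host would be global process 10.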
@hawkinsp
host1:
Also, I have tried launching distributed training on two A100 machines: single-machine training is OK, but distributed training is stuck. I am trying SLURM but have not succeeded yet. Is SLURM necessary for JAX distributed training?
Please confirm you're using jax 0.4.28. @syyxsxx Those warnings might mean compilation is slow, but they shouldn't cause a deadlock. If you're using […] I think I'll need a reproduction of the problem to help further.
Description
I use two 4090 hosts for data-parallel distributed training via jax.distributed, like this:
jax.distributed.initialize(coordinator_address="[ip]:[port]",
                           num_processes=2,
                           process_id=[index])
Training gets stuck during all_reduce ops.
How can I debug this problem?
Are there any examples of parallel distributed training?
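For reference, a minimal two-host setup might look like the sketch below. The coordinator address is a placeholder, and the `init_args` helper and `RUN_DISTRIBUTED`/`PROCESS_ID` environment variables are assumptions for illustration; only `jax.distributed.initialize` and its three keyword arguments come from the snippet above. Every host must pass the same `coordinator_address` and `num_processes`, and a unique `process_id` (0 on the coordinator host):

```python
import os

# Hypothetical coordinator address: host 0's reachable IP and a free port.
COORDINATOR = "192.0.2.1:1234"

def init_args(process_id, num_processes=2):
    """Build the keyword arguments each host passes to
    jax.distributed.initialize. Identical on every host except
    process_id, which must be unique per process."""
    return dict(coordinator_address=COORDINATOR,
                num_processes=num_processes,
                process_id=process_id)

# Guarded so this sketch imports cleanly outside a real cluster.
if os.environ.get("RUN_DISTRIBUTED"):
    import jax
    jax.distributed.initialize(**init_args(int(os.environ["PROCESS_ID"])))
```

Host 0 (the coordinator) would run with `PROCESS_ID=0`, host 1 with `PROCESS_ID=1`; the call blocks until all processes connect, so a firewall blocking the coordinator port is a common cause of hangs.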
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.26.3
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1
$ nvidia-smi
Mon May 13 19:53:47 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:45:00.0 Off | Off |
| 30% 28C P2 39W / 450W | 406MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 Off | 00000000:46:00.0 Off | Off |
| 72% 61C P2 411W / 450W | 19697MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 Off | 00000000:49:00.0 Off | Off |
| 78% 62C P2 418W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 Off | 00000000:4E:00.0 Off | Off |
| 73% 61C P2 396W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 Off | 00000000:4F:00.0 Off | Off |
| 71% 61C P2 407W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 Off | 00000000:C5:00.0 Off | Off |
| 73% 61C P2 411W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA GeForce RTX 4090 Off | 00000000:C6:00.0 Off | Off |
| 80% 63C P2 416W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA GeForce RTX 4090 Off | 00000000:C9:00.0 Off | Off |
| 78% 62C P2 402W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 8 NVIDIA GeForce RTX 4090 Off | 00000000:CE:00.0 Off | Off |
| 73% 61C P2 382W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 9 NVIDIA GeForce RTX 4090 Off | 00000000:CF:00.0 Off | Off |
| 78% 62C P2 404W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+