
Distributed training is stuck #21217

Open
syyxsxx opened this issue May 14, 2024 · 6 comments
Labels
bug (Something isn't working), NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments


syyxsxx commented May 14, 2024

Description

I am using two RTX 4090 hosts for data-parallel distributed training via jax.distributed, like this:

jax.distributed.initialize(coordinator_address="[ip]:[port]",
                           num_processes=2,
                           process_id=[index])

Training gets stuck in the all_reduce ops.
[screenshot]
How can I debug this problem? Are there any examples of multi-host distributed training?
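
For context, the step that hangs looks roughly like this (a sketch with a placeholder coordinator address and a toy loss, not the actual training code):

    import functools
    import jax
    import jax.numpy as jnp

    jax.distributed.initialize(coordinator_address="[ip]:[port]",  # placeholder
                               num_processes=2,
                               process_id=0)  # 0 on host0, 1 on host1

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(params, batch):
        # Toy loss; the real model and optimizer are omitted.
        grads = jax.grad(lambda p: jnp.mean((p * batch) ** 2))(params)
        # The cross-host all-reduce where training appears to hang.
        grads = jax.lax.pmean(grads, axis_name="batch")
        return params - 0.01 * grads

    n = jax.local_device_count()
    params = train_step(jnp.zeros((n, 128)), jnp.ones((n, 128)))
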

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.26.3
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1

$ nvidia-smi
Mon May 13 19:53:47 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:45:00.0 Off | Off |
| 30% 28C P2 39W / 450W | 406MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 Off | 00000000:46:00.0 Off | Off |
| 72% 61C P2 411W / 450W | 19697MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 Off | 00000000:49:00.0 Off | Off |
| 78% 62C P2 418W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 Off | 00000000:4E:00.0 Off | Off |
| 73% 61C P2 396W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 Off | 00000000:4F:00.0 Off | Off |
| 71% 61C P2 407W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 Off | 00000000:C5:00.0 Off | Off |
| 73% 61C P2 411W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA GeForce RTX 4090 Off | 00000000:C6:00.0 Off | Off |
| 80% 63C P2 416W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA GeForce RTX 4090 Off | 00000000:C9:00.0 Off | Off |
| 78% 62C P2 402W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 8 NVIDIA GeForce RTX 4090 Off | 00000000:CE:00.0 Off | Off |
| 73% 61C P2 382W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 9 NVIDIA GeForce RTX 4090 Off | 00000000:CF:00.0 Off | Off |
| 78% 62C P2 404W / 450W | 19689MiB / 24564MiB | 100% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+

@syyxsxx syyxsxx added the bug Something isn't working label May 14, 2024
@yashk2810 yashk2810 added the NVIDIA GPU Issues specific to NVIDIA GPUs label May 14, 2024
@yueshengys (Member)

The latest jax and jaxlib versions are 0.4.28; can you try them first? 0.4.23 is pretty old.

@hawkinsp (Member)

On GPU, it might also be worth trying a configuration that has one GPU per process. That may avoid a class of deadlocks in NVIDIA's NCCL library.

syyxsxx (Author) commented May 15, 2024

> On GPU, it might also be worth trying a configuration that has one GPU per process. That may avoid a class of deadlocks in NVIDIA's NCCL library.

@hawkinsp Hi, how do I configure JAX so that each process uses only one GPU? P.S. On a single host I can train across multiple GPUs with pmap, but multi-host training gets stuck.

@hawkinsp (Member)

How did you launch the job? Are you using a cluster scheduler of some kind? If you're using one that JAX already integrates with (e.g., SLURM), we have code to handle this already, but perhaps you're not using one.

Basically you need to do two things:
a) Run one process per GPU, and arrange for it to see only the GPU it is supposed to use. If you are running multiple processes on a single machine with multiple GPUs, you can limit which GPUs any given JAX process sees by setting JAX_CUDA_VISIBLE_DEVICES to 0, 1, ... for each process within a machine.

b) When you call jax.distributed.initialize to set up a distributed training job, set process_id and num_processes to reflect the fact that you have one process per GPU (https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html); see the sketch below.
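
A minimal sketch of (a) and (b) combined, assuming two hosts with 10 GPUs each as in the nvidia-smi output above. HOST_INDEX and LOCAL_GPU are hypothetical values your launcher would supply to each process, and plain CUDA_VISIBLE_DEVICES is used here in place of JAX_CUDA_VISIBLE_DEVICES (either should work):

    import os

    # Hypothetical per-process values, supplied by whatever launches the processes.
    HOST_INDEX = 0       # 0 on host0, 1 on host1
    LOCAL_GPU = 0        # 0..9 on each host, one process per GPU
    GPUS_PER_HOST = 10

    # (a) Make this process see only its own GPU, before JAX initializes its backend.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(LOCAL_GPU)

    import jax

    # (b) num_processes and process_id now count one process per GPU, not per host.
    jax.distributed.initialize(
        coordinator_address="[ip]:[port]",   # placeholder
        num_processes=2 * GPUS_PER_HOST,
        process_id=HOST_INDEX * GPUS_PER_HOST + LOCAL_GPU,
    )

    print(jax.process_index(), jax.local_devices())  # expect exactly one cuda device
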

Does that answer the question?

syyxsxx (Author) commented May 28, 2024

@hawkinsp
hi,
I launch the job by starting a process manually on each machine. I set the GPUs with CUDA_VISIBLE_DEVICES, and I have also set process_id and num_processes. The code looks like this:
host0:

jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
                           num_processes=2,
                           process_id=0)

host1:

jax.distributed.initialize(coordinator_address="66.181.42.141:8889",
                           num_processes=2,
                           process_id=1)

I have also tried launching the distributed training on two A100 machines; single-machine training works fine, but distributed training gets stuck.
[screenshot of warnings]

I am trying SLURM but have not succeeded yet. Is SLURM necessary for JAX distributed training?

@hawkinsp (Member)

Please confirm you're using jax 0.4.28.

@syyxsxx Those warnings might mean compilation is slow, but shouldn't cause a deadlock.

If you're using pmap yourself explicitly, then one thing to make sure is that both processes are performing the same pmaps in the same order.

I think I'll need a reproduction of the problem to help further.
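
For reference, a two-process sanity check of roughly this shape (placeholder coordinator address; run the same script on both hosts with only process_id changed) is usually enough to show whether the collectives themselves are hanging:

    import functools
    import jax
    import jax.numpy as jnp

    jax.distributed.initialize(coordinator_address="[ip]:[port]",  # placeholder
                               num_processes=2,
                               process_id=0)  # 0 on host0, 1 on host1

    @functools.partial(jax.pmap, axis_name="i")
    def all_reduce(x):
        return jax.lax.psum(x, axis_name="i")

    # Every device contributes 1.0, so each entry should equal jax.device_count().
    out = all_reduce(jnp.ones((jax.local_device_count(),)))
    print(jax.process_index(), out)
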
