RCCL deadlock discussion (was: increase NCCL_STEPS to match WARPSIZE/4) #1600

Draft · wants to merge 1 commit into base: develop

Conversation


@jglaser jglaser commented Mar 13, 2025

Details

Work item: GH issue pending

What were the changes?
Increase the pipeline size (NCCL_STEPS) from 8 to 16 to maintain correct synchronization behavior

Why were the changes made?
On Frontier at OLCF (8 MI250X GPUs per node), RCCL sporadically stalls, especially when running on many (roughly 128 or more) nodes or when using the libfabric plugin (aws-ofi-rccl). The proposed change adjusts the pipeline width in the 'simple' communication protocol to be compatible with the larger warp (wavefront) size on AMD hardware, which is 64, compared to 32 on NVIDIA hardware.

How was the outcome achieved?
A debug trace with rocgdb showed that, under stall conditions during an NCCL broadcast, some threads on the first rank in the ring (GPU 0) hang inside a busy-wait loop (waitPeer):

while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
  __builtin_amdgcn_s_sleep(1);
  connStepCache = loadStepValue(connStepPtr);
  if (checkAbort(spins)) break;
  //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
  if (spins == 0 && repeat > 0) {
    repeat--;
    traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
  }
}

A closer inspection revealed that GPU 0 was eight steps (NCCL_STEPS=8) ahead of GPU 1 and waiting for it to become ready before sending, while GPU 1 was waiting on GPU 0 to receive; that is, the two GPUs were stuck in a circular deadlock inside waitPeer, which should not be possible by design. To ensure asynchronous progress (e.g., GPU 1 sending and receiving at the same time), the original NCCL design uses warp-level synchronization together with a FIFO pipeline whose depth matches the maximum number (8) of progress groups of four threads each (waitRecv, waitSend, postRecv, and postSend) that fit into a 32-thread warp. If more groups than that fit into a warp, as with AMD's 64-thread wavefront, the warp-level synchronization can stall asynchronous progress made by another pipeline. Increasing NCCL_STEPS to 16 guarantees that one wavefront maps to exactly one pipeline (see the sketch below).
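
To make the arithmetic explicit, the following is a minimal sketch (an illustration of the grouping argument above, assuming four threads per progress group; it is not RCCL source) comparing the number of progress groups per warp/wavefront with the FIFO depth NCCL_STEPS:

# pipeline_width_sketch.py (illustration only)
THREADS_PER_GROUP = 4  # waitRecv, waitSend, postRecv, postSend

def groups_per_warp(warp_size: int) -> int:
    """Number of independent progress groups that fit into one warp/wavefront."""
    return warp_size // THREADS_PER_GROUP

for label, warp_size, nccl_steps in [("NVIDIA", 32, 8), ("AMD", 64, 8), ("AMD, this PR", 64, 16)]:
    groups = groups_per_warp(warp_size)
    print(f"{label:13s} warp={warp_size:2d} groups/warp={groups:2d} "
          f"NCCL_STEPS={nccl_steps:2d} one wavefront == one pipeline: {nccl_steps == groups}")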

Additional Documentation:
It may be necessary to assess the performance implications of this patch and re-tune collective parameters.

Regarding reproducer:
On Frontier, the stalling behavior is also modulated by the libfabric version. I had the most "luck" with libfabric 1.20.1 with the CXI provider, ROCm 6.2.4, and aws-ofi-rccl commit 17d41cbf5618536c4b1076d29748416ab307040f.
I have a local reproducer (shared with AMD developers) that uses JAX, but it is not yet minimal.

Approval Checklist

Do not approve until these items are satisfied.

  • Verify the CHANGELOG has been updated, if
    • there are any NCCL API version changes,
    • any changes impact library users, and/or
    • any changes impact any other ROCm library.

@nicholasmalaya

Thank you, @jglaser! This is a valuable contribution.

@wenkaidu
Collaborator

@jglaser While I believe the deadlock you mentioned could be real, I am not convinced that the cause is WARP_SIZE. @nicholasmalaya, can you help arrange a call to discuss? I think I still have Jens' email address, as we exchanged some emails three years ago on a different issue.

@nicholasmalaya

@jglaser While I believe the deadlock you mentioned could be real, I am not convinced that the cause is WARP_SIZE. @nicholasmalaya, can you help arrange a call to discuss? I think I still have Jens' email address, as we exchanged some emails three years ago on a different issue.

Yes. I will reach out for us to discuss this in more detail.

@jglaser jglaser changed the title from "increase NCCL_STEPS to match 4*WARPSIZE" to "increase NCCL_STEPS to match WARPSIZE/4" on Mar 13, 2025
@wenkaidu
Collaborator

@jglaser can you help dump the proxy state with the patch below to confirm why the sender/receiver is stuck?

diff --git a/src/transport/net.cc b/src/transport/net.cc
index 5fd36f6f..16f07ba5 100644
--- a/src/transport/net.cc
+++ b/src/transport/net.cc
@@ -1379,6 +1379,7 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct
                 if (sub->done == sub->nsteps) *sendHead = sub->base + args->sliceSteps;
               } else {
                 *sendHead = sub->base + sub->done;
+                INFO(NCCL_COLL, "Send chan %d Posted %lx to %p", resources->channelId, sub->base + sub->done, sendHead);
               }
               if (resources->gdcSync) wc_store_fence(); // Flush out WC write
             }
@@ -1662,8 +1663,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct
               if (sub->reg) {
                 // We may have added more net steps, but reg operations only have a single step w.r.t. the GPU.
                 if (sub->transmitted == sub->nsteps) *recvTail = sub->base + args->sliceSteps;
-              } else
+              } else {
                 *recvTail = sub->base + sub->transmitted;
+                INFO(NCCL_COLL, "Recv chan %d Posted %lx to %p", resources->channelId, sub->base + sub->transmitted, recvTail);
+              }
               if (resources->gdcSync) wc_store_fence(); // Flush out WC write
             }
           }

@jglaser
Author

jglaser commented Mar 17, 2025

OK, I think the NCCL_STEPS change might actually be unnecessary, as the stalling behavior seems to be already improved (if not gone!) as of 36343be, which is a squashed merge commit from upstream NCCL.

Interestingly, RCCL performance also seems to be significantly improved (>2x) since that commit.

To find out, I used git bisect with a simple reproducer (a shell script, really) that interrupts one of the GPUs in the ring with rocgdb for a few seconds. This simulates a slow rank or an RDMA bottleneck. By simultaneously monitoring GPU usage (watch rocm-smi), one should see all GPUs performing work, then GPU 1 throttling to 0 W and the other MI250X GPUs to around 120-130 W, and finally all of them returning to the original power draw. If a deadlock occurs, GPU 1 only returns to the other GPUs' busy-wait power level after continuing in rocgdb. For commits prior to 36343be, the chance of deadlock is roughly 20% on 16 nodes (128 GPUs).

# save as wrapper.sh
if [ "${SLURM_PROCID}" = 1 ]; then
    # launch the real command in the background so rocgdb can attach to it
    "${@}" &
    pid=${!}
    sleep 20
    my_pid=$(pstree -p "${pid}" | awk -F'[()]' '{print $2; exit}')
    pstree -p "${pid}"
    echo "Interrupting process ${my_pid}"
    {  echo "t a a bt"   # backtrace of all threads
       sleep 15          # hold the rank stopped to simulate a slow GPU
       echo "cont"
       echo "detach"
       echo "quit"
     } | \
    rocgdb -p "${my_pid}"
    wait
else
    "${@}"
fi

Use the wrapper with rccl-tests like so

LD_LIBRARY_PATH=<path to rccl.so>:${LD_LIBRARY_PATH} \
NCCL_NET_GDR_LEVEL=3 \
LD_LIBRARY_PATH=<path to aws-ofi-rccl>/src/.libs:${LD_LIBRARY_PATH} \
srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 5000 -b 1G"

As the commit is unfortunately quite large, it would be interesting to pinpoint the source code line that caused the change in behavior.

@wenkaidu
Collaborator

@jglaser when there is a large number of GPUs, it is hard to debug with rocgdb. For example, we may find that GPU N is stuck waiting for a receive, but only because GPU N-1 didn't send; and GPU N-1 didn't send because it is waiting for a receive from GPU N-2, and so on. We need logs from the entire ring to identify the root of the problem. The latest develop branch has a lot of improvements for getting GPU kernel logs in this scenario, but CPU proxy-side logging is still lacking.

@jglaser
Author

jglaser commented Mar 17, 2025

@jglaser when there is a large number of GPUs, it is hard to debug with rocgdb. For example, we may find that GPU N is stuck waiting for a receive, but only because GPU N-1 didn't send; and GPU N-1 didn't send because it is waiting for a receive from GPU N-2, and so on. We need logs from the entire ring to identify the root of the problem. The latest develop branch has a lot of improvements for getting GPU kernel logs in this scenario, but CPU proxy-side logging is still lacking.

Here is a stalled trace
out_435756af02a976d9686e5f6bc8954727dc49a29b_lasthalf.txt.gz
with the ROCM 6.2 commit

After the "Continuing", the GPUs are hanging:

=========================================== ROCm System Management Interface ===========================================
===================================================== Concise Info =====================================================
Device  Node  IDs              Temp    Power   Partitions          SCLK     MCLK     Fan  Perf    PwrCap  VRAM%  GPU%
              (DID,     GUID)  (Edge)  (Avg)   (Mem, Compute, ID)
========================================================================================================================
0       4     0x7408,   63582  38.0°C  131.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
1       5     0x7408,   51740  42.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
2       6     0x7408,   15961  33.0°C  129.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
3       7     0x7408,   3099   36.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
4       8     0x7408,   13395  38.0°C  129.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
5       9     0x7408,   1553   41.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
6       10    0x7408,   62036  34.0°C  126.0W  N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  560.0W  2%     100%
7       11    0x7408,   49174  38.0°C  N/A     N/A, N/A, 0         1700Mhz  1600Mhz  0%   manual  0.0W    2%     100%
========================================================================================================================
================================================= End of ROCm SMI Log ==================================================

and a successful one with the latest HEAD
out_ccb082074351b560bbce3e1cb8d9ae2045b7beac.txt.gz

Command:

NCCL_DEBUG_SUBSYS=COLL NCCL_DEBUG=info LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/rccl/build/debug/:${LD_LIBRARY_PATH} NCCL_NET_GDR_LEVEL=3 LD_LIBRARY_PATH=/lustre/orion/world-shared/stf006/glaser/aws-ofi-rccl/src/.libs:${LD_LIBRARY_PATH} srun -N 16 --ntasks-per-node=8 -c 8 bash -c "source wrapper.sh build/broadcast_perf -n 100"

@jglaser
Author

jglaser commented Mar 17, 2025

Here is a stalled trace out_435756af02a976d9686e5f6bc8954727dc49a29b_lasthalf.txt.gz with the ROCM 6.2 commit

and the first half, too, which includes the RCCL setup
out_435756af02a976d9686e5f6bc8954727dc49a29b_firsthalf.txt.gz

@jglaser
Author

jglaser commented Mar 18, 2025

Caveat: I have another example here (RL training) which freezes even with the ccb08 commit (in an allgather).

@wenkaidu
Collaborator

@jglaser Increasing NCCL_STEPS makes the GPU kernel-side FIFO deeper. I am wondering if this somehow helps the network stack, which has another FIFO. There is NCCL_OFI_MAX_REQUESTS in the aws-ofi plugin: https://github.com/ROCm/aws-ofi-rccl/blob/17d41cbf5618536c4b1076d29748416ab307040f/include/nccl_ofi.h#L68
Can you try increasing or decreasing this number to see if it has any effect?

@jglaser
Author

jglaser commented Mar 20, 2025

Another observation: using the reproducer as before with the failing commit 43575..., but with RCCL compiled with -O0 (no optimizations), the pipeline recovers from the interruption, which it does not with -O3. This hints at a subtle correctness issue such as memory ordering, or at starvation/occupancy effects (e.g. due to the spin loops)... investigating further...

@jglaser
Author

jglaser commented Mar 20, 2025

Another observation: using the reproducer as before with the failing commit 43575..., but with RCCL compiled with -O0 (no optimizations), the pipeline recovers from the interruption, which it does not with -O3. This hints at a subtle correctness issue such as memory ordering, or at starvation/occupancy effects (e.g. due to the spin loops)... investigating further...

I should perhaps mention that I replaced the thread barrier in that commit with my own, but I am not sure whether that led to the optimization-level-dependent behavior (this one avoids potential memory-ordering issues by relying entirely on atomicCAS):

template<typename T>
__device__ T atomicBarrierCAS(T *address, int numThreads, int expectedPhase) {
    T old = *address;  // Read the initial state once before entering the loop
    T assumed;

    do {
        assumed = (expectedPhase << 31) | (old & 0x7FFFFFFF); // Ensure phase is expected

        int counter = assumed & 0x7FFFFFFF; // Extract counter (lower 31 bits)
        int phase = (assumed >> 31) & 1;    // Extract phase bit (highest bit)

        // Determine the next phase, but don't use it for the increment/decrement
        int nextPhase = phase;
        if ((phase == 0 && counter == numThreads - 1) || (phase == 1 && counter == 1)) {
            nextPhase = 1 - phase; // Flip phase for the next barrier
        }

        // Increment or decrement based on the original phase, not nextPhase
        int newCounter = (phase == 0) ? (counter + 1) : (counter - 1);

        // Construct new packed value
        T newValue = (nextPhase << 31) | newCounter;

        old = atomicCAS(address, assumed, newValue); // Attempt to update
    } while (old != assumed); // Retry if atomicCAS failed

    return old; // Return the final observed value
}

#define barrier_by_group() { \
    const int wid = threadIdx.x%WARP_SIZE; \
    if (wid == 0) { \
      uint64_t num_leaders = nthreads/WARP_SIZE; \
      atomicBarrierCAS(barriers, num_leaders, 0); \
      atomicBarrierCAS(barriers, num_leaders, 1); \
    } \
}

@wenkaidu
Collaborator

@jglaser if you are testing collectives like allreduce, broadcast, etc., barrier_by_group() will call __builtin_amdgcn_s_barrier(), because RCCL will not reduce nthreads from NCCL_MAX_NTHREADS, which is 256.

@jglaser
Author

jglaser commented Mar 24, 2025

I think I found what was causing the hang in my application: multiple (N/R)CCL communicators used by multiple CPU threads. It is well known that NCCL is not thread safe in this respect, see e.g.
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently

When I analyzed the call pattern of my app, I noticed that there was more than one communicator (intra- plus inter-node, from FSDP HYBRID_SHARD). Each communicator had its own HIP stream associated with it. However, these communicators were shared between different host threads. That creates a danger: the launch order across host threads is random, whereas the GPU synchronizes (i.e., serializes) kernels launched to the same stream. Deadlocks are therefore expected.

I confirmed the thread-unsafe behavior of RCCL with this code:

#include <iostream>
#include <thread>
#include <chrono>
#include <vector>
#include <cstdlib>
#include <sstream>
#include <rccl/rccl.h>
#include <hip/hip_runtime.h>
#include <mpi.h>

#define CHECK_HIP(call) \
    do { \
        hipError_t err = call; \
        if (err != hipSuccess) { \
            std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \
            MPI_Abort(MPI_COMM_WORLD, 1); \
        } \
    } while (0)

#define CHECK_RCCL(call) \
    do { \
        ncclResult_t res = call; \
        if (res != ncclSuccess) { \
            std::cerr << "RCCL Error: " << ncclGetErrorString(res) << "\n"; \
            MPI_Abort(MPI_COMM_WORLD, 1); \
        } \
    } while (0)

void run_nccl_op(ncclComm_t comm, hipStream_t stream, int device, int count, int delay_us, int thread_id, int world_rank) {
    CHECK_HIP(hipSetDevice(device));

    float *sendbuf, *recvbuf;
    int nranks;
    CHECK_RCCL(ncclCommCount(comm, &nranks)); // number of ranks in the communicator
    CHECK_HIP(hipMalloc(&sendbuf, count * sizeof(float)));
    CHECK_HIP(hipMalloc(&recvbuf, (size_t)count * nranks * sizeof(float))); // all-gather output needs count * nranks elements

    std::this_thread::sleep_for(std::chrono::microseconds(delay_us));

    CHECK_RCCL(ncclAllGather(sendbuf, recvbuf, count, ncclFloat, comm, stream));

    CHECK_HIP(hipFree(sendbuf));
    CHECK_HIP(hipFree(recvbuf));

    if (world_rank == 0) {
        std::ostringstream msg;
        msg << "Thread " << thread_id << " completed on rank 0\n";
        std::cerr << msg.str();
    }
}

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int world_size, world_rank;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

    int device = 0; // use fixed device, rely on ROCR_VISIBLE_DEVICES
    int count = 1024;
    int num_threads = 2;

    if (argc > 1) {
        num_threads = std::atoi(argv[1]);
    }

    ncclComm_t comm;
    hipStream_t stream;

    CHECK_HIP(hipSetDevice(device));
    CHECK_HIP(hipStreamCreate(&stream));

    ncclUniqueId id;
    if (world_rank == 0) {
        ncclGetUniqueId(&id);
    }
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
    CHECK_RCCL(ncclCommInitRank(&comm, world_size, id, world_rank));

    std::vector<std::thread> threads;
    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back(run_nccl_op, comm, stream, device, count, i * 10, i, world_rank);
    }
    for (auto& t : threads) t.join();

    CHECK_HIP(hipStreamSynchronize(stream));

    ncclCommDestroy(comm);
    CHECK_HIP(hipStreamDestroy(stream));

    std::cout << "Rank " << world_rank << " completed" << std::endl;

    MPI_Finalize();
    return 0;
}

RCCL thread-unsafe reproducer

This demonstration confirms that launching RCCL collectives to the same GPU stream, but from different host threads, produces non-deterministic results (including hangs).

module load rocm/6.2.4
hipcc -o rccl_race_test rccl_race_test.cu -lrccl -L${MPICH_DIR}/lib -lmpi ${CRAY_XPMEM_POST_LINK_OPTS} -I${MPICH_DIR}/include
# interactive job with 1 gpu/process (ROCR_VISIBLE_DEVICES=0) on 1 node
salloc -N 1 -t 02:00:00 -A <proid> -p batch -q debug --cpus-per-task=8 -S 0 --tasks-per-node=8 --gpus-per-task=1
# run 1 w/4 threads
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Thread 3 completed on rank 0
Thread 1 completed on rank 0
Thread 0 completed on rank 0
Thread 2 completed on rank 0
Rank 4 completed
Rank 6 completed
Rank 2 completed
Rank 1 completed
Rank 3 completed
Rank 5 completed
Rank 7 completed
Rank 0 completed

# run 2
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> srun -n 8 rccl_race_test 4
Rank 1 completed
^Csrun: interrupt (one more within 1 sec to abort)
srun: StepId=3231174.44 tasks 0-7: running
^Csrun: sending Ctrl-C to StepId=3231174.44
srun: forcing job termination
^C^Csrun: sending Ctrl-C to StepId=3231174.44
srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
srun: Terminating StepId=3231174.44
srun: job abort in progress
# baseline: 1 CPU thread, 10 runs
glaser@frontier10425:/lustre/orion/world-shared/stf006/glaser/twenty_questions> for Z in `seq 0 10`; do srun -n 8 rccl_race_test 1; done
Thread 0 completed on rank 0
Rank 6 completed
Rank 0 completed
Rank 4 completed
Rank 2 completed
Rank 1 completed
Rank 5 completed
Rank 7 completed
Rank 3 completed
srun: Step created for StepId=3231264.2
Thread 0 completed on rank 0
Rank 0 completed
Rank 6 completed
Rank 4 completed

@jglaser
Author

jglaser commented Mar 24, 2025

On the Python side, I confirmed that PyTorch uses multiple threads to parallelize the backward pass:

# torch_reproducer.py
import torch
from torch.nn import Linear

model = Linear(4, 4).cuda()
input = torch.randn(1, 4, device='cuda', requires_grad=True)

def custom_hook(grad):
    import os, threading
    print(f"[Grad Hook] PID {os.getpid()} Thread {threading.get_native_id()}")
    return grad

output = model(input)
output.register_hook(custom_hook)
loss = output.sum()
loss.backward()

output

[Grad Hook] PID 3371415 Thread 3373242

Notice PID != Thread ID

@jglaser
Author

jglaser commented Mar 24, 2025

Now the mitigation strategy became obvious:

  1. Put all collectives into the main thread
  2. Make sure streams synchronize between calls to different communicators

Torch has an option to turn off threading during autograd:

with set_multithreading_enabled(False):
    # trainer.train() or model.backward() ...
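
For completeness, here is a minimal, self-contained sketch of mitigation (1), assuming a recent PyTorch where set_multithreading_enabled is exported from torch.autograd; model and data_loader are placeholders for the real training objects:

import torch
from torch.autograd import set_multithreading_enabled

def train_one_step(model, batch):
    loss = model(batch).sum()
    # With autograd multithreading disabled, backward hooks (and any collectives
    # they trigger) run on the calling host thread, in a deterministic order.
    loss.backward()
    return loss.item()

def training_loop(model, data_loader):
    for batch in data_loader:
        with set_multithreading_enabled(False):
            train_one_step(model, batch)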

@jglaser
Author

jglaser commented Mar 24, 2025

To ensure that streams sync before and after collectives, a little more effort is needed: another context manager

import contextlib
import torch
import functools

import os
import torch.distributed as dist

# Get native thread ID
import ctypes
import threading

libc = ctypes.CDLL("libc.so.6")
gettid = libc.syscall
SYS_gettid = 186  # x86_64 Linux syscall number for gettid

@contextlib.contextmanager
def patch_distributed_collectives(label="DISTRIBUTED"):
    patched = {}
    # Note: list each op only once; a duplicate entry would overwrite the saved
    # original with the already-wrapped function and break the restore step.
    dist_ops = ['all_reduce', 'reduce_scatter', 'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
                'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
                'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'reduce',
                'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier' ]

    def make_wrapper(opname, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            thread_id = threading.get_ident()
            native_tid = gettid(SYS_gettid)
            pid = os.getpid()
            try:
                rank = dist.get_rank()
            except RuntimeError:
                rank = -1
            #print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname}")
            stream = torch.cuda.current_stream()
            stream.synchronize()
            try:
                return fn(*args, **kwargs)
            finally:
                stream.synchronize()

        return wrapper

    try:
        for opname in dist_ops:
            if hasattr(dist, opname):
                original = getattr(dist, opname)
                patched[opname] = original
                setattr(dist, opname, make_wrapper(opname, original))
        yield
    finally:
        for opname, original in patched.items():
            setattr(dist, opname, original)


# Example usage:
# with patch_distributed_collectives():
#     output = model(input)

Using these two together fixes the hang for me. Performance may not be optimal and NCCL usage in torch autograd may have to be re-examined.

@jglaser
Author

jglaser commented Mar 25, 2025

To ensure that streams sync before and after collectives, a little more effort is needed: another context manager (the full patch is quoted in the previous comment).

@thananon (as per our offline discussion) try this patch ... if it still hangs, the next escalation would be to replace the two stream.synchronize() calls in the patch with torch.cuda.synchronize(), which synchronizes the entire device rather than just the current stream.

@thananon
Contributor

@jglaser You found me here. Yes, I incorporated this patch in the latest job as you suggested. I hope we get good results back.

@jglaser
Author

jglaser commented Mar 29, 2025

After fixing the first hang, my app was able to make it to an error message, which I was able to fix. Subsequently, however, it still sometimes hung.

This is a minimal version of the patch; it expands the coverage of distributed collectives and has only a single (necessary) synchronization point.

import contextlib
import torch
import functools
import threading

import os
import torch.distributed as dist

@contextlib.contextmanager
def patch_distributed_collectives(logging=False):
    patched = {}
    dist_ops = ['all_reduce', 'reduce_scatter', 'reduce_scatter_tensor',
                'all_gather', 'broadcast', 'all_gather_into_tensor', 'barrier',
                'gather', 'scatter', 'all_to_all', 'send', 'isend', 'irecv', 'send_object_list',
                'recv_object_list', 'batch_isend_irecv', 'broadcast_object_list', 'reduce',
                'all_gather_object', 'gather_object', 'all_to_all_single', 'monitored_barrier',
                '_broadcast_coalesced']

    def make_wrapper(label, opname, fn, logging=False):
        stream = torch.cuda.Stream()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if logging:
                thread_id = threading.get_ident()
                native_tid = threading.current_thread().native_id
                pid = os.getpid()

                try:
                    rank = dist.get_rank()
                except RuntimeError:
                    rank = -1
                device = torch.cuda.current_device()
                current_stream = torch.cuda.current_stream()
                async_op = kwargs.get('async_op', False)
                print(f"[{label}] Rank {rank} PID {pid} TID {native_tid} (Python {thread_id}) called {opname} on device {device} stream {current_stream} async_op {async_op}")

            current_stream = torch.cuda.current_stream()
            try:
                return fn(*args, **kwargs)
            finally:
                current_stream.synchronize()

        return wrapper

    try:
        for opname in dist_ops:
            if hasattr(dist, opname):
                original = getattr(dist, opname)
                patched[opname] = original
                setattr(dist, opname, make_wrapper('DISTRIBUTED', opname, original, logging))
        yield
    finally:
        for opname, original in patched.items():
            setattr(dist, opname, original)

# Example usage:
# with patch_distributed_collectives():
#     output = model(input)

Using this version, and with autograd threads disabled as before, my app runs to completion on 64 nodes using rccl commit 532f54c (or actually, to the next OOM :)

The likely cause is that running concurrent collectives on different streams/communicators (e.g. intra- vs. inter-node) still requires stream dependencies to be set, which may (or may not) be implemented in FSDP. The synchronization above should become unnecessary once upstream NCCL 2.26 is merged into RCCL, so that users can set NCCL_LAUNCH_ORDER_IMPLICIT; then they only need to ensure that the host-side order of calls is consistent, e.g. by disabling autograd threads. A sketch of explicit stream ordering follows below.
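
As an illustration of what setting such stream dependencies could look like, here is a hedged sketch (not FSDP's actual implementation; the stream and process-group names intra_stream, inter_stream, pg_intra, and pg_inter are hypothetical) that orders work between two communicator streams before each collective:

import torch
import torch.distributed as dist

intra_stream = torch.cuda.Stream()
inter_stream = torch.cuda.Stream()

def ordered_collective(stream, other_stream, collective, *args, **kwargs):
    # GPU-side ordering: `stream` waits for everything already enqueued on
    # `other_stream`, so collectives on the two communicators cannot be
    # reordered against each other on the device.
    stream.wait_stream(other_stream)
    with torch.cuda.stream(stream):
        return collective(*args, **kwargs)

# Example (hypothetical process groups pg_intra / pg_inter):
# ordered_collective(intra_stream, inter_stream,
#                    dist.reduce_scatter_tensor, shard_out, full_grad, group=pg_intra)
# ordered_collective(inter_stream, intra_stream,
#                    dist.all_reduce, shard_out, group=pg_inter)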

@jglaser jglaser changed the title from "increase NCCL_STEPS to match WARPSIZE/4" to "RCCL deadlock discussion (was: increase NCCL_STEPS to match WARPSIZE/4)" on Mar 29, 2025
@jglaser jglaser marked this pull request as draft March 29, 2025 02:10
@jglaser
Author

jglaser commented Apr 3, 2025

pytorch/pytorch#147729 seems related

@jeffdaily
Contributor

@jglaser can you try building pytorch/pytorch#148590 to see if it resolves the deadlock? Or do you not expect it to resolve the issue?

@jglaser
Author

jglaser commented Apr 8, 2025

@jglaser can you try building pytorch/pytorch#148590 to see if it resolves the deadlock? Or do you not expect it to resolve the issue?

Will do. And just FYI, I still encountered a hang in my RL app during collectives after switching to full BF16 training (instead of mixed); will investigate. @thananon

@jeffdaily
Contributor

@jglaser Just realized that #148590 was relanded as pytorch/pytorch@acf5139 due to a merge conflict, so it has been in main since Mon, Mar 31. No need to build that PR yourself; it should be in the nightly wheels. Have you given a nightly wheel a try yet?
