Description
Hi all,
As part of my research on simulating and optimizing distributed inference, I have been running extensive tests on the GLOO backend with all_gather operations to predict transmission time between devices. However, I have found weird behavior in the time scaling of all_gather when WORLD_SIZE=2, reproducible in multiple settings: an all_gather with 2 devices runs slower than with more than 2 devices (tested 3 to 8).
When the total transferred data size per device reaches roughly 8 MB to 16 MB, the curves start to diverge significantly.
This looks very strange, as I would expect a 2-device all_gather to be faster than a 3- to 8-device all_gather for the same transfer size per device. Instead, it runs significantly slower, taking on average 58% more time for a 200 MB transfer size.
I have tested this in two different settings: a local Jetson Nano cluster with 4 devices and an HPC cluster with 8 nodes.
To compute the total transfer size per device, I multiply the size of the transferred tensor by (N-1), the number of devices minus one, since each device shares its tensor with the other N-1 devices.
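For example, with a hypothetical 50 MB tensor per rank (illustrative numbers, not measured values), the accounting works out like this:

```python
# Illustrative only: per-device transfer volume for the all_gather accounting
# described above, i.e. tensor_size * (N - 1).
tensor_bytes = 50 * 1024 * 1024  # hypothetical 50 MB tensor per rank
for n_devices in (2, 4, 8):
    per_device_transfer = tensor_bytes * (n_devices - 1)
    print(f"N={n_devices}: {per_device_transfer / 2**20:.0f} MB sent per device")
```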
I ran Wireshark to understand what is going on, but I was not able to spot any anomalies there. I used the following script to test my network speed:
```python
import json
import time
import torch
import torch.distributed as dist
import fire
import os


def run(num_latency_tests=100, num_bw_tests=5, mode="send_recv", output_file="network_results.json", size_multiplier=1):
    backend = os.getenv("DIST_BACKEND", "gloo")
    device = os.getenv("DEVICE", "cuda" if backend == "nccl" else "cpu")
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # ---- Ping (latency) test ----
    latencies: list[float] = []
    tensor = torch.zeros(1, device=device)
    for _ in range(num_latency_tests):
        if rank == 0:
            start = time.perf_counter()
            dist.send(tensor, dst=1)
            dist.recv(tensor, src=1)
            if backend == "nccl":
                torch.cuda.synchronize()
            end = time.perf_counter()
            latencies.append((end - start) * 1000)
        elif rank == 1:
            dist.recv(tensor, src=0)
            dist.send(tensor, dst=0)
        dist.barrier()
        time.sleep(0.1)

    if rank == 0:
        mean = sum(latencies) / len(latencies)
        std = (sum((x - mean) ** 2 for x in latencies) / len(latencies)) ** 0.5
        print(f"[rank0] RTT (ms): min={min(latencies):.3f}, max={max(latencies):.3f}, mean={mean:.3f}, std={std:.3f}")
    dist.barrier()
    # ---- Bandwidth test with different tensor sizes ----
    sizes_bytes = [
        1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
        1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288,
        1048576, 2097152, 4194304, 8388608, 12589824, 16777216, 24576512,
        33554432, 41943040, 50331648, 67108864, 83886080, 104857600,
        125829120, 167772160, 201326592,
    ]  # 1B, 2B, 4B, 8B, ..., 200MB
    samples = []
    bandwidth_means = []
    sizes_means = []
    for size_bytes in sizes_bytes:
        size_bytes = size_bytes * size_multiplier
        num_floats = max(1, size_bytes // 4)  # float32 is 4 bytes
        big = torch.ones(num_floats, dtype=torch.float32, device=device)
        to_gather = [torch.zeros(num_floats, dtype=torch.float32, device=device) for _ in range(world_size)]
        bw_results = []
        time_results = []
        for _ in range(num_bw_tests):
            if rank == 0:
                start = time.perf_counter()
                if mode == "send_recv":
                    dist.send(big, dst=1)
                    dist.recv(big, src=1)
                elif mode == "all_gather":
                    dist.all_gather(to_gather, big)
                if backend == "nccl":
                    torch.cuda.synchronize()
                end = time.perf_counter()
                elapsed = end - start
                mbps = size_bytes / elapsed / (1024 * 1024)  # MB/s (elapsed covers send + recv)
                samples.append({"size_bytes": size_bytes, "time_s": elapsed})
                bw_results.append(mbps)
                time_results.append(elapsed)
            else:  # other ranks; send_recv mode assumes world_size == 2
                if mode == "send_recv":
                    dist.recv(big, src=0)
                    dist.send(big, dst=0)
                elif mode == "all_gather":
                    dist.all_gather(to_gather, big)
                if backend == "nccl":
                    torch.cuda.synchronize()
            dist.barrier()
            time.sleep(0.1)
        if rank == 0:
            mean_bw = sum(bw_results) / len(bw_results)
            std_bw = (sum((x - mean_bw) ** 2 for x in bw_results) / len(bw_results)) ** 0.5
            mean_time = sum(time_results) / len(time_results)
            std_time = (sum((x - mean_time) ** 2 for x in time_results) / len(time_results)) ** 0.5
            bandwidth_means.append(mean_bw)
            sizes_means.append(mean_time)
            if size_bytes < 1024:
                size_str = f"{size_bytes}B"
            elif size_bytes < 1024 * 1024:
                size_str = f"{size_bytes // 1024}KB"
            else:
                size_str = f"{size_bytes // (1024 * 1024)}MB"
            print(
                f"[rank0] Size={size_str} | Bandwidth (MB/s): min={min(bw_results):.2f}, max={max(bw_results):.2f}, mean={mean_bw:.2f}, std={std_bw:.2f}"
            )
            print(
                f"[rank0] Size={size_str} | Transfer time (s): min={min(time_results):.4f}, max={max(time_results):.4f}, mean={mean_time:.4f}, std={std_time:.4f}"
            )

    if rank == 0:
        print(f"Saving raw results to {output_file}...")
        with open(output_file, "w") as f:
            json.dump(samples, f)


if __name__ == "__main__":
    fire.Fire(run)
    dist.destroy_process_group()
```
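For completeness, here is a minimal single-machine sketch (the port and sizes below are arbitrary placeholders) that spawns gloo processes with torch.multiprocessing and times one all_gather. It goes over loopback, so the absolute numbers will not match the Jetson or HPC clusters, but it can be used to sanity-check the 2-vs-many comparison without a cluster:

```python
# Minimal local reproduction sketch; spawns `world_size` gloo processes on
# one machine and times a single all_gather of ~8 MB per rank.
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, numel):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(29500 + world_size)  # avoid port reuse across runs
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor = torch.ones(numel, dtype=torch.float32)
    gathered = [torch.zeros(numel, dtype=torch.float32) for _ in range(world_size)]

    dist.all_gather(gathered, tensor)  # warm-up so connection setup is not timed
    dist.barrier()

    start = time.perf_counter()
    dist.all_gather(gathered, tensor)
    elapsed = time.perf_counter() - start

    if rank == 0:
        mb = numel * 4 / 1e6
        print(f"world_size={world_size}: all_gather of {mb:.1f} MB per rank took {elapsed * 1e3:.2f} ms")
    dist.destroy_process_group()


if __name__ == "__main__":
    for world_size in (2, 4):
        mp.spawn(worker, args=(world_size, 2 * 1024 * 1024), nprocs=world_size, join=True)
```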