[platform] add base class for communicators (#13208)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Feb 16, 2025
1 parent 124776e commit a0231b7
Showing 13 changed files with 364 additions and 282 deletions.
117 changes: 117 additions & 0 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
"""
    Base class for device-specific communicators.
    It can use the `cpu_group` to initialize the communicator.
    If the device has a PyTorch integration (i.e. PyTorch recognizes its
    communication backend), the `device_group` will also be given.
"""

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here;
        # stack-style all-gather has compatibility issues with
        # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor

def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)

def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size

tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor

def destroy(self):
pass
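The movedim/reshape sequence at the end of `all_gather` above is easy to misread. The following standalone sketch (illustration only, not part of this commit; `world_size`, `dim`, and the tensor shapes are made up) simulates the flat buffer that `dist.all_gather_into_tensor` fills and checks that the reshaping is equivalent to concatenating the per-rank shards along the requested dimension:

import torch

# Hypothetical values, chosen only to exercise a non-trivial `dim`.
world_size = 4
dim = 1
shards = [torch.randn(2, 3, 5) for _ in range(world_size)]  # one tensor per rank
input_size = shards[0].size()

# dist.all_gather_into_tensor writes rank 0's tensor, then rank 1's, and so on
# into one flat buffer, i.e. a concatenation along dim 0.
output_tensor = torch.cat(shards, dim=0)  # shape (world_size * 2, 3, 5)

# The reshape/movedim/reshape sequence from DeviceCommunicatorBase.all_gather:
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])

# ...which is exactly a concatenation of the shards along the requested dim.
assert torch.equal(output_tensor, torch.cat(shards, dim=dim))
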
33 changes: 33 additions & 0 deletions vllm/distributed/device_communicators/cpu_communicator.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CpuCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex
self.ipex_available = True
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass

def all_reduce(self, input_):
return self.dist_module.all_reduce(input_, group=self.device_group)
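The CpuCommunicator above is the smallest example of the new extension pattern: subclass DeviceCommunicatorBase, call super().__init__ with the same four arguments, and override only the collectives that need device-specific handling. A stripped-down sketch of a hypothetical backend (illustration only, not part of this commit; the class name is invented and it assumes vLLM is importable) looks like this:

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class MyDeviceCommunicator(DeviceCommunicatorBase):
    """Hypothetical communicator for a device without its own backend."""

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # Any device-specific setup (handles, streams, ...) would go here.

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Route the reduction through the CPU group; all_gather, gather,
        # send and recv are inherited from DeviceCommunicatorBase.
        dist.all_reduce(input_, group=self.cpu_group)
        return input_
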
106 changes: 106 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CudaCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_pynccl = True

self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce

# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)

self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)

def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size

pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)

def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size

tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor

def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
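CudaCommunicator.all_reduce above tries the custom allreduce kernel first, then pynccl, and only falls back to plain torch.distributed when pynccl declines by returning None (which, per the comment in the commit, usually happens during testing). The snippet below sketches that precedence with stand-in objects so it runs anywhere; the FakeCustomAllreduce/FakePyNccl classes and the 8 MiB cut-off are invented and do not reflect the real CustomAllreduce heuristics:

from dataclasses import dataclass

import torch


@dataclass
class FakeCustomAllreduce:
    # Stand-in: the real CustomAllreduce checks size, dtype and topology.
    disabled: bool = False

    def should_custom_ar(self, t: torch.Tensor) -> bool:
        return t.numel() * t.element_size() <= 8 * 1024 * 1024  # invented cut-off


@dataclass
class FakePyNccl:
    # Stand-in: assumed here to decline (and trigger the torch.distributed
    # fallback) when disabled.
    disabled: bool = False


def pick_all_reduce_path(ca_comm, pynccl_comm, input_: torch.Tensor) -> str:
    # Mirrors the precedence in CudaCommunicator.all_reduce.
    if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
        return "custom_allreduce"
    if pynccl_comm is not None and not pynccl_comm.disabled:
        return "pynccl"
    return "torch.distributed"


print(pick_all_reduce_path(FakeCustomAllreduce(), FakePyNccl(), torch.empty(1024)))     # custom_allreduce
print(pick_all_reduce_path(FakeCustomAllreduce(), FakePyNccl(), torch.empty(1 << 24)))  # pynccl
print(pick_all_reduce_path(None, FakePyNccl(disabled=True), torch.empty(1024)))         # torch.distributed
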
33 changes: 14 additions & 19 deletions vllm/distributed/device_communicators/hpu_communicator.py
@@ -2,45 +2,40 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
-
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
+class HpuCommunicator(DeviceCommunicatorBase):
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
29 changes: 16 additions & 13 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,13 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import Optional
 
 import torch
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
@@ -16,19 +18,20 @@
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
 
         # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
         # must be used together. Therefore, the local rank and world size can
         # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+        global_rank = self.global_rank
+        global_world_size = self.global_world_size
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -55,9 +58,9 @@ def __init__(self, group: ProcessGroup):
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)