forked from vllm-project/vllm
Commit
[platform] add base class for communicators (vllm-project#13208)
Signed-off-by: youkaichao <[email protected]>
1 parent b14b17a, commit 56401ce

Showing 13 changed files with 364 additions and 282 deletions.
vllm/distributed/device_communicators/base_device_communicator.py (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        pass
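
For context, here is a minimal usage sketch that is not part of this commit: it drives the base communicator over a gloo CPU group, assuming a torchrun-style launch so the rendezvous environment variables are already set.

# Illustrative sketch, not part of this commit: exercising the base
# communicator over a gloo CPU group (launch with torchrun so that
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set).
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


def main() -> None:
    dist.init_process_group(backend="gloo")
    cpu_group = dist.group.WORLD

    comm = DeviceCommunicatorBase(cpu_group=cpu_group,
                                  device=torch.device("cpu"),
                                  device_group=cpu_group)

    # In-place sum across ranks: every rank ends up with world_size.
    x = torch.ones(4)
    reduced = comm.all_reduce(x)
    assert torch.allclose(reduced, torch.full((4, ), float(comm.world_size)))

    # Gather one row per rank onto local rank 0 along dim 0.
    row = torch.full((1, 4), float(comm.rank_in_group))
    gathered = comm.gather(row, dst=0, dim=0)
    if comm.rank_in_group == 0:
        assert gathered.shape == (comm.world_size, 4)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()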
vllm/distributed/device_communicators/cpu_communicator.py (33 additions, 0 deletions)

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        self.ipex_available = False
        self.dist_module = torch.distributed
        try:
            import intel_extension_for_pytorch as ipex
            self.ipex_available = True
            self.dist_module = ipex.distributed
        except ImportError:
            # Intel IPEX not found. Falling back to PyTorch native
            # all_reduce for CPU (e.g. MacOS).
            pass

    def all_reduce(self, input_):
        # The reduction happens in place; torch.distributed.all_reduce
        # returns None, so return the reduced tensor explicitly.
        self.dist_module.all_reduce(input_, group=self.device_group)
        return input_
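
As a small aside that is not from the commit, the optional ipex import above can be checked at runtime; the gloo group setup below is an assumption for illustration.

# Illustrative check, not part of the commit: report which backend a
# CpuCommunicator picked after its optional ipex import.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cpu_communicator import (
    CpuCommunicator)

dist.init_process_group(backend="gloo")
comm = CpuCommunicator(cpu_group=dist.group.WORLD,
                       device=torch.device("cpu"),
                       device_group=dist.group.WORLD)
backend = "ipex.distributed" if comm.ipex_available else "torch.distributed"
print(f"rank {comm.rank_in_group}: CPU all_reduce goes through {backend}")
dist.destroy_process_group()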
vllm/distributed/device_communicators/cuda_communicator.py (106 additions, 0 deletions)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CudaCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        if "pp" in unique_name:
            # pipeline parallel does not need custom allreduce
            use_custom_allreduce = False
        else:
            from vllm.distributed.parallel_state import (
                _ENABLE_CUSTOM_ALL_REDUCE)
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        use_pynccl = True

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

    def all_reduce(self, input_):
        # always try custom allreduce first,
        # and then pynccl.
        ca_comm = self.ca_comm
        if ca_comm is not None and not ca_comm.disabled and \
                ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
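
To spell out the dispatch order in CudaCommunicator.all_reduce, here is a simplified standalone sketch of the same cascade (custom allreduce, then pynccl, then plain torch.distributed); the Backend alias and stand-in callables are illustrative, not vllm APIs.

# Illustrative sketch, not vllm code: the "custom allreduce -> pynccl ->
# torch.distributed" cascade that CudaCommunicator.all_reduce implements,
# expressed with stand-in callables. Assumes an initialized process group.
from typing import Callable, Optional

import torch
import torch.distributed as dist

# A backend takes a tensor and returns the reduced tensor, or None to
# signal "cannot handle this input, try the next backend".
Backend = Callable[[torch.Tensor], Optional[torch.Tensor]]


def cascaded_all_reduce(input_: torch.Tensor,
                        custom_ar: Optional[Backend] = None,
                        pynccl_ar: Optional[Backend] = None) -> torch.Tensor:
    # 1) Custom all-reduce, used only when it accepts the tensor
    #    (size/dtype/registration checks in the real implementation).
    if custom_ar is not None:
        out = custom_ar(input_)
        if out is not None:
            return out
    # 2) PyNccl, which may also decline (e.g. when disabled).
    if pynccl_ar is not None:
        out = pynccl_ar(input_)
        if out is not None:
            return out
    # 3) Last resort: out-of-place reduce through torch.distributed,
    #    mirroring the `input_.clone()` fallback in the commit.
    out = input_.clone()
    dist.all_reduce(out)
    return out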
(The diffs for the remaining changed files are not shown in this view.)