diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 461fb6d30c4..cee44c7d26c 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -312,6 +312,7 @@ steps:
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
+  - pytest -v -s compile/test_async_tp.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental, amdproduction]
diff --git a/tests/compile/backend.py b/tests/compile/backend.py
index a21e8eca3a6..5a02c4e2b37 100644
--- a/tests/compile/backend.py
+++ b/tests/compile/backend.py
@@ -5,6 +5,8 @@
 
 from torch import fx
 
+from vllm.compilation.fx_utils import (find_specified_fn,
+                                       find_specified_fn_maybe)
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import get_current_vllm_config
 
@@ -44,3 +46,19 @@ def post_pass(self, graph: fx.Graph):
         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
+
+    def check_before_ops(self, ops,
+                         find_fn=find_specified_fn,
+                         find_fn_maybe=find_specified_fn_maybe,
+                         ops_fully_replaced=True):
+        for op in ops:
+            find_fn(self.graph_pre_pass.nodes, op)
+            if ops_fully_replaced:
+                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
+
+    def check_after_ops(self, ops,
+                        find_fn=find_specified_fn,
+                        find_fn_maybe=find_specified_fn_maybe):
+        for op in ops:
+            find_fn(self.graph_post_pass.nodes, op)
+            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py
new file mode 100644
index 00000000000..8e4e0ba8357
--- /dev/null
+++ b/tests/compile/test_async_tp.py
@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+
+import pytest
+import torch
+
+import vllm.envs as envs
+from vllm.compilation.collective_fusion import AsyncTPPass
+from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
+                         PassConfig, VllmConfig)
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.platforms import current_platform
+from vllm.utils import update_environment_variables
+
+from ..models.registry import HF_EXAMPLE_MODELS
+from ..utils import (compare_two_settings, create_new_process_for_each_test,
+                     multi_gpu_test)
+from .backend import TestBackend
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+
+class TestMMRSModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.gate_proj = torch.nn.Parameter(torch.empty(
+            (self.hidden_size * 2, hidden_size)),
+                                            requires_grad=False)
+        # Initialize weights
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+    def forward(self, hidden_states):
+        """
+        Forward pass implementing the mm + reduce scatter in the FX graph
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]
+
+
+class TestAGMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.weight = torch.nn.Parameter(torch.empty(
+            (hidden_size, hidden_size)),
+                                         requires_grad=False)
+        # Initialize weights
+        torch.nn.init.normal_(self.weight, std=0.02)
+
+    def forward(self, hidden_states):
+        """
+        Forward pass implementing the mm + all gather in the FX graph
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+        all_gather = tensor_model_parallel_all_gather(view, dim=0)
+        permute = self.weight.permute(1, 0)
+        mm = torch.mm(all_gather, permute)
+        return mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_matmul.default]
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seq_len", [16])
+@pytest.mark.parametrize("hidden_size", [16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
+                               hidden_size: int, dtype: torch.dtype):
+    num_processes = 2
+
+    def run_torch_spawn(fn, nprocs):
+        # need to use torch.mp.spawn otherwise will have problems with
+        # torch.distributed and cuda
+        torch.multiprocessing.spawn(fn,
+                                    args=(num_processes, test_model,
+                                          batch_size, seq_len, hidden_size,
+                                          dtype),
+                                    nprocs=nprocs)
+
+    run_torch_spawn(async_tp_pass_on_test_model, num_processes)
+
+
+def async_tp_pass_on_test_model(local_rank: int, world_size: int,
+                                test_model_cls: torch.nn.Module,
+                                batch_size: int, seq_len: int,
+                                hidden_size: int, dtype: torch.dtype):
+    current_platform.seed_everything(0)
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': '12345',
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # configure vllm config for AsyncTPPass
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
+        enable_async_tp=True, ), )
+    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
+
+    # this is a fake model name to construct the model config
+    # in the vllm_config, it's not really used.
+ model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + async_tp_pass = AsyncTPPass(vllm_config) + backend = TestBackend(async_tp_pass) + + model = test_model_cls(hidden_size) + + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype, + requires_grad=False) + + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) + + # In pre-nodes, all gather or reduce scatter should exist, + # fused_matmul_reduce_scatter or fused_all_gather_matmul should not + backend.check_before_ops(model.ops_in_model_before(), + ops_fully_replaced=False) + + # In post-nodes, fused_matmul_reduce_scatter or \ + # fused_all_gather_matmul should exist + backend.check_after_ops(model.ops_in_model_after()) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("async_tp_enabled", [True]) +@pytest.mark.parametrize("distributed_backend", ["mp"]) +@pytest.mark.parametrize("eager_mode", [False, True]) +def test_async_tp_pass_correctness( + model_id: str, + tp_size: int, + async_tp_enabled: bool, + distributed_backend: str, + eager_mode: bool, + num_gpus_available: int, +): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + + common_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if eager_mode: + common_args.append("--enforce-eager") + + compilation_config = { + 'level': 3, + 'compile_sizes': [2, 4, 8], + 'splitting_ops': [], + 'pass_config': { + 'enable_async_tp': async_tp_enabled + }, + } + + async_tp_env = tp_env = { + "VLLM_USE_V1": "1", + } + + aysnc_tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + json.dumps(compilation_config), + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + compare_two_settings(model_id, + aysnc_tp_args, + tp_args, + async_tp_env, + tp_env, + method="generate") diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4d56b34bdec..509593e7328 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -29,6 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool, self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: @@ -59,6 +63,15 @@ def forward(self, x): y3, resid = self.norm[2](x3, resid) # use resid here return y3 + def ops_in_model_before(self): + return [QUANT_OPS[self.key]] + + def ops_in_model_after(self): + return [ + FUSED_OPS[FusedRMSQuantKey(self.key, False)], + FUSED_OPS[FusedRMSQuantKey(self.key, True)] + ] + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) 
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes - - # static is per-tensor, dynamic is per-token - key = QuantKey(dtype=FP8_DTYPE, - static=static, - per_tensor=static, - symmetric=True) - rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] - add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] - fp8_quant = QUANT_OPS[key] - # In pre-nodes, fp8 quant should be there and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, rms_quant) is None - assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + backend.check_before_ops(model.ops_in_model_before(), find_auto_fn, + find_auto_fn_maybe) # In post-nodes, fused kernels should be there and fp8 quant should not - find_auto_fn(post_nodes, rms_quant) - find_auto_fn(post_nodes, add_rms_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + backend.check_after_ops(model.ops_in_model_after(), find_auto_fn, + find_auto_fn_maybe) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 6152f171705..2cd7ebaacec 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -5,9 +5,7 @@ import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, - find_specified_fn, - find_specified_fn_maybe, is_func) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, PassConfig, VllmConfig) @@ -21,17 +19,6 @@ from ..utils import multi_gpu_test from .backend import TestBackend -OPS_IN_MODEL_BEFORE = [ - torch.ops.vllm.all_reduce.default, -] - -OPS_IN_MODEL_AFTER = [ - torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default, -] - -OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] - prompts = [ "Hello, my name is", "The president of the United States is", @@ -78,6 +65,18 @@ def forward(self, hidden_states, residual): return norm_output, residual_output + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + + def ops_in_model(self): + return [torch.ops._C.fused_add_rms_norm.default] + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("batch_size", [8]) @@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func(hidden_states, residual) - # Check substitution worked - pre_nodes = backend_no_func.graph_pre_pass.nodes - post_nodes = backend_no_func.graph_post_pass.nodes - # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not - for op in OPS_IN_MODEL_BEFORE: - find_specified_fn(pre_nodes, op) - for op in OPS_IN_MODEL_AFTER: - assert find_specified_fn_maybe(pre_nodes, op) is None + backend_no_func.check_before_ops(model.ops_in_model_before()) # In post-nodes, reduce scatter and all gather should be there, # all 
-    for op in OPS_IN_MODEL_AFTER:
-        find_specified_fn(post_nodes, op)
-    for op in OPS_IN_MODEL_BEFORE:
-        assert find_specified_fn_maybe(post_nodes, op) is None
+    backend_no_func.check_after_ops(model.ops_in_model_after())
 
     # check if the functionalization pass is applied
-    for op in OPS_IN_MODEL:
+    for op in model.ops_in_model():
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
         assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None  # noqa: E501
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     # make sure the ops were all de-functionalized
     found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in OPS_IN_MODEL:
+        for op in model.ops_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in OPS_IN_MODEL)
+    assert all(found[op] for op in model.ops_in_model())
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
new file mode 100644
index 00000000000..f651ee6912a
--- /dev/null
+++ b/vllm/compilation/collective_fusion.py
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import torch
+import torch._inductor.pattern_matcher as pm
+import torch.fx as fx
+from torch._inductor.pattern_matcher import PatternMatcherPass
+from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_tp_group
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
+
+from .vllm_inductor_pass import VllmInductorPass
+
+logger = init_logger(__name__)
+
+
+class BasePattern:
+
+    def __init__(self, dtype: torch.dtype, device: str):
+        self.dtype = dtype
+        self.device = device
+        self.tp = get_tp_group()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+
+class GEMMReduceScatterPattern(BasePattern):
+
+    def get_inputs(self):
+        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
+        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+        return [mul, mm_weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
+            mm = torch.ops.aten.mm.default(mul, mm_weight)
+            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
+                mm,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+            return reduce_scatter
+
+        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
+            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
+                mul,
+                mm_weight,
+                "avg",
+                scatter_dim=0,
+                group_name=self.tp.device_group.group_name,
+            )
+
+            return gemm_rs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AllGatherGEMMPattern(BasePattern):
+
+    def get_inputs(self):
+        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+
+        return [x, weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            all_gather = torch.ops.vllm.all_gather.default(
+                x,
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name)
+
+            return torch.ops.aten.mm.default(all_gather, weight)
+
+        def replacement(
+                x: torch.Tensor,
+                weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
+                x,
+                [weight],
+                gather_dim=0,
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
+class AsyncTPPass(VllmInductorPass):
+
+    def __init__(self, config: VllmConfig):
+        super().__init__(config)
+
+        # Enable symmetric memory for the TP process group
+        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="async_tp_pass")
+        GEMMReduceScatterPattern(self.model_dtype,
+                                 self.device).register(self.patterns)
+
+        AllGatherGEMMPattern(self.model_dtype,
+                             self.device).register(self.patterns)
+
+    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
+        # only do replace for specific shapes
+        tp_size = get_tensor_model_parallel_world_size()
+        return shape is not None and shape % tp_size == 0
+
+    def __call__(self, graph: fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_async_tp_pass")
+        count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", count)
+        self.dump_graph(graph, "after_async_tp_pass")
+        self.end_and_log()
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index f4d3fd9b457..07ebd3e1b7d 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -6,6 +6,7 @@
 from vllm.logger import init_logger
 
 from .activation_quant_fusion import ActivationQuantFusionPass
+from .collective_fusion import AsyncTPPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@@ -54,6 +55,8 @@ def configure(self, config: VllmConfig):
 
         if self.pass_config.enable_sequence_parallelism:
             self.passes += [SequenceParallelismPass(config)]
+        if self.pass_config.enable_async_tp:
+            self.passes += [AsyncTPPass(config)]
 
         self.fix_functionalization = FixFunctionalizationPass(config)
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py
index f0476bfcb65..17dded87fe8 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/sequence_parallelism.py
@@ -243,24 +243,25 @@ def __init__(self, config: VllmConfig):
             pass_name="sequence_parallelism_pass")
         for epsilon in [1e-5, 1e-6]:
             EmbeddingAllReduceRMSNormPattern(
-                epsilon, self.dtype, self.device).register(self.patterns)
+                epsilon, self.model_dtype, self.device).register(self.patterns)
 
-            MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
+            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                           self.device).register(self.patterns)
 
-            LastAllReduceRMSNormPattern(epsilon, self.dtype,
+            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)
             # WARNING: This is a hack to clear the pattern matcher cache
             # and allow multiple values of epsilon.
             torch._inductor.pattern_matcher._seen_patterns.clear()
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
-        # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
     def __call__(self, graph: fx.Graph):
+        self.begin()
         self.dump_graph(graph, "before_sequence_parallelism_pass")
         count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", count)
         self.dump_graph(graph, "after_sequence_parallelism_pass")
+        self.end_and_log()
diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py
index c95e0bce5f2..0fe73b72b1d 100644
--- a/vllm/compilation/vllm_inductor_pass.py
+++ b/vllm/compilation/vllm_inductor_pass.py
@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):
 
     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
-        self.dtype = config.model_config.dtype if config.model_config else None
+        self.model_dtype = config.model_config.dtype if config.model_config \
+            else None
         self.device = config.device_config.device if config.device_config \
             else None
         self.pass_name = self.__class__.__name__
diff --git a/vllm/config.py b/vllm/config.py
index a185a75c6bf..925be0c4ae3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3646,6 +3646,8 @@ class PassConfig:
     """Whether to enable the custom no-op elimination pass."""
     enable_sequence_parallelism: bool = False
     """Whether to enable sequence parallelism."""
+    enable_async_tp: bool = False
+    """Whether to enable async TP."""
 
     def uuid(self):
         """
@@ -3655,7 +3657,8 @@ def uuid(self):
         compilation.
         """
         include = {
-            "enable_fusion", "enable_noop", "enable_sequence_parallelism"
+            "enable_fusion", "enable_noop", "enable_sequence_parallelism",
+            "enable_async_tp"
         }
         dict_ = {k: v for k, v in asdict(self).items() if k in include}
         return InductorPass.hash_dict(dict_)
@@ -4268,6 +4271,12 @@ def __post_init__(self):
 
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
+
+        # async tp is built on top of sequence parallelism
+        # and requires it to be enabled.
+        if self.compilation_config.pass_config.enable_async_tp:
+            self.compilation_config.pass_config.enable_sequence_parallelism = \
+                True
         if self.compilation_config.pass_config.enable_sequence_parallelism:
             self.compilation_config.custom_ops.append("+rms_norm")
         if envs.VLLM_USE_V1 and self.model_config is not None and \
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 51c519d8f86..112861990c3 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -120,7 +120,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group.reduce_scatter(tensor, dim)
+    return group._reduce_scatter_out_place(tensor, dim)
 
 
 def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
@@ -136,7 +136,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group.all_gather(tensor, dim)
+    return group._all_gather_out_place(tensor, dim)
 
 
 def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
@@ -161,6 +161,7 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
         op_func=reduce_scatter,
         mutates_args=[],
         fake_impl=reduce_scatter_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
@@ -168,6 +169,7 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
         op_func=all_gather,
         mutates_args=[],
         fake_impl=all_gather_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
 
@@ -367,6 +369,16 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if self.use_custom_op_call:
+            return torch.ops.vllm.all_gather(input_,
+                                             dim,
+                                             world_size,
+                                             group_name=self.unique_name)
+        else:
+            return self._all_gather_out_place(input_, dim)
+
+    def _all_gather_out_place(self, input_: torch.Tensor,
+                              dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
 
     def reduce_scatter(self,
@@ -379,6 +391,16 @@ def reduce_scatter(self,
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if self.use_custom_op_call:
+            return torch.ops.vllm.reduce_scatter(input_,
+                                                 dim,
+                                                 world_size,
+                                                 group_name=self.unique_name)
+        else:
+            return self._reduce_scatter_out_place(input_, dim)
+
+    def _reduce_scatter_out_place(self, input_: torch.Tensor,
+                                  dim: int) -> torch.Tensor:
         return self.device_communicator.reduce_scatter(input_, dim)
 
     def gather(self,