
Commit 4319237

Use backend to replace macro to control enablement of MNNVL all reduce (#4635)
Signed-off-by: Hui Gao <[email protected]>
1 parent c592798 commit 4319237

File tree: 23 files changed, +2651 −101 lines

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 2 additions & 7 deletions
@@ -621,14 +621,12 @@ class AllreduceOp
 
     AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
     {
-        static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY");
-        bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr);
         AllReduceStrategyType runtime_strategy;
         if (mStrategy == AllReduceStrategyType::UB)
         {
             runtime_strategy = AllReduceStrategyType::UB;
         }
-        else if (force_nccl_all_reduce_strategy || mStrategy == AllReduceStrategyType::NCCL)
+        else if (mStrategy == AllReduceStrategyType::NCCL)
         {
             runtime_strategy = AllReduceStrategyType::NCCL;
         }

@@ -936,10 +934,7 @@ class AllreduceOp
 
     bool isUsingLowPrecision(size_t message_size) const noexcept
     {
-        static char* force_low_precision_allreduce_strategy_char
-            = std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY");
-        bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr)
-            || (mStrategy == AllReduceStrategyType::LOWPRECISION);
+        bool force_low_precision = mStrategy == AllReduceStrategyType::LOWPRECISION;
 
 #ifdef ENABLE_FP8
         // Use LowPrecision if PCIe and p2p support and message size is larger than 2MB
docs/source/advanced/lowprecision-pcie-allreduce.md

Lines changed: 5 additions & 9 deletions
@@ -41,12 +41,12 @@ The Low-Precision-AllReduce algorithm can be enabled in two ways:
 ```
 AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION);
 ```
-2. **Environment variable control** with AUTO strategy:
+
+2. Enable by LlmArgs
 ```
-// In your code
-AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.AUTO);
-// Set environment variable before running
-export FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY=1
+Set allreduce_strategy field in LlmArgs.
+Candidates of strategies are "AUTO", "NCCL", "UB", "MINLATENCY", "ONESHOT", "TWOSHOT", "LOWPRECISION" and "MNNVL".
+If no strategy is set, AUTO will be set.
 ```

@@ -58,8 +58,4 @@ Low-Precision-AllReduce reduces communication volume by using FP8 data format fo
 
 Users should evaluate the precision impact on their specific models and workloads.
 
-## Environment Variables
-
-- `FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY`: When set to `1`, forces the use of low-precision algorithm with AUTO strategy. If the algorithm determines it cannot provide performance benefits, it will automatically fall back to other strategies.
-
 **Note**: When compiling TensorRT-LLM without enabling the `ENABLE_FP8` option, setting Low Precision allreduce will not take effect.
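As a concrete illustration of the LlmArgs path these docs now describe, a minimal sketch (the `allreduce_strategy` field name comes from this commit's docs; the checkpoint path and parallelism settings are placeholders):

```python
# Sketch: selecting the all-reduce strategy through LlmArgs instead of an
# environment variable. Strings are case-insensitive; unrecognized values
# fall back to "AUTO".
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",        # placeholder checkpoint path
    tensor_parallel_size=2,
    allreduce_strategy="LOWPRECISION",
)
```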

examples/pytorch/out_of_tree_example/modeling_opt.py

Lines changed: 16 additions & 18 deletions
@@ -64,24 +64,22 @@ def __init__(
             config.hidden_size,
             elementwise_affine=config.layer_norm_elementwise_affine,
             dtype=config.torch_dtype)
-        self.fc1 = Linear(
-            config.hidden_size,
-            config.ffn_dim,
-            bias=config.enable_bias,
-            dtype=config.torch_dtype,
-            mapping=model_config.mapping,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            quant_config=model_config.get_quant_config(),
-        )
-        self.fc2 = Linear(
-            config.ffn_dim,
-            config.hidden_size,
-            bias=config.enable_bias,
-            dtype=config.torch_dtype,
-            mapping=model_config.mapping,
-            tensor_parallel_mode=TensorParallelMode.ROW,
-            quant_config=model_config.get_quant_config(),
-        )
+        self.fc1 = Linear(config.hidden_size,
+                          config.ffn_dim,
+                          bias=config.enable_bias,
+                          dtype=config.torch_dtype,
+                          mapping=model_config.mapping,
+                          tensor_parallel_mode=TensorParallelMode.COLUMN,
+                          quant_config=model_config.get_quant_config(),
+                          allreduce_strategy=model_config.allreduce_strategy)
+        self.fc2 = Linear(config.ffn_dim,
+                          config.hidden_size,
+                          bias=config.enable_bias,
+                          dtype=config.torch_dtype,
+                          mapping=model_config.mapping,
+                          tensor_parallel_mode=TensorParallelMode.ROW,
+                          quant_config=model_config.get_quant_config(),
+                          allreduce_strategy=model_config.allreduce_strategy)
         self.final_layer_norm = LayerNorm(
             config.hidden_size,
             elementwise_affine=config.layer_norm_elementwise_affine,

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 try:
     from ....mapping import Mapping
     from ...distributed import AllReduce, allgather
-    from ...modules.linear import AllReduceFusionOp, AllReduceParams
+    from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
 
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()

@@ -17,7 +17,7 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-        torch_op = AllReduce(p_config)
+        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
         return torch_op(tensor, all_reduce_params=all_reduce_params)
 
     @torch.library.custom_op(

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 19 additions & 16 deletions
@@ -307,14 +307,17 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
         self.mapping = mapping
         self.dtype = dtype
-        self.enable_mnnvl = (os.environ.get("TRTLLM_MNNVL_AR_ENABLED",
-                                            "0") == "1"
-                             and dtype in [torch.bfloat16, torch.float32]
-                             and (not mapping.has_cp()))
+        assert (
+            dtype in MNNVLAllReduce.get_supported_dtypes()
+            and (not mapping.has_cp())
+        ), "MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp."
 
-        if self.enable_mnnvl:
-            self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace(
-                self.mapping, dtype)
+        self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace(
+            self.mapping, dtype)
+
+    @staticmethod
+    def get_supported_dtypes():
+        return (torch.bfloat16, torch.float32)
 
     def forward(
         self,

@@ -330,7 +333,7 @@ def forward(
         Returns:
             Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Reduced tensor(s)
         """
-        if not self.enable_mnnvl or input.numel() > self.max_num_elements_mnnvl:
+        if input.numel() > self.max_num_elements_mnnvl:
             return None
 
         fusion_op = all_reduce_params.fusion_op

@@ -411,27 +414,27 @@ def __init__(self,
         For the reference implementation for each pattern, please refer to the following unit test:
         https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py
 
-        The LOWPRECISION strategy can be selected either by directly specifying it in the constructor
-        or by setting the environment variable FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY when using
-        the AUTO strategy.
+        The LOWPRECISION strategy can be selected either by directly specifying it in the constructor.
         """
 
         self.mapping = mapping
         self.workspace = None
         self.strategy = strategy
+        self.mnnvl_allreduce = None
 
-        self.force_low_precision_env = os.environ.get(
-            "FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY")
         if self.mapping.tp_size > 1:
             # When Strategy is UB, it is guaranteed that the workspace is not used.
             if self.strategy != AllReduceStrategy.UB:
-                if self.strategy == AllReduceStrategy.LOWPRECISION or self.force_low_precision_env is not None:
+                if self.strategy == AllReduceStrategy.LOWPRECISION:
                     allocate_low_presicion_allreduce_workspace(self.mapping)
                 self.workspace = get_allreduce_workspace(self.mapping)
 
         # Initialize MNNVL AllReduce if needed
-        self.mnnvl_allreduce = MNNVLAllReduce(mapping,
-                                              dtype) if dtype else None
+        if self.strategy == AllReduceStrategy.MNNVL and (
+                dtype and dtype in MNNVLAllReduce.get_supported_dtypes()
+        ) and (not self.mapping.has_cp()):
+            self.mnnvl_allreduce = MNNVLAllReduce(self.mapping,
+                                                  dtype) if dtype else None
 
     def forward(
         self,
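Since MNNVL now has to be requested explicitly rather than via `TRTLLM_MNNVL_AR_ENABLED`, a minimal construction sketch (assuming bfloat16 and no context parallelism, per `get_supported_dtypes` and the `has_cp` check above; when an input exceeds the MNNVL workspace, `forward` returns `None` and the caller falls back to another strategy):

```python
# Sketch: opting in to MNNVL all-reduce explicitly. The MNNVL path only
# engages for supported dtypes (bfloat16/float32) without context
# parallelism; otherwise AllReduce keeps mnnvl_allreduce = None.
import torch

from tensorrt_llm._torch.distributed import AllReduce
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=8, tp_size=8, rank=0)  # hypothetical single-node TP-8
allreduce = AllReduce(mapping=mapping,
                      strategy=AllReduceStrategy.MNNVL,
                      dtype=torch.bfloat16)
```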

tensorrt_llm/_torch/model_config.py

Lines changed: 21 additions & 0 deletions
@@ -8,6 +8,8 @@
 
 from tensorrt_llm import logger
 from tensorrt_llm._utils import torch_dtype_to_binding
+from tensorrt_llm.functional import AllReduceStrategy
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo

@@ -77,6 +79,7 @@ class ModelConfig(Generic[TConfig]):
 
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM
+    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
 
     # If true, enable min-latency mode. Currently only used for Llama4.
     enable_min_latency: bool = False

@@ -106,6 +109,24 @@ def __post_init__(self):
         self.is_generation = self.is_generation_model(
             self.pretrained_config.architectures)
 
+        def get_all_reduce_strategy(strategy: str = "AUTO"):
+            maps = {
+                "AUTO": AllReduceStrategy.AUTO,
+                "NCCL": AllReduceStrategy.NCCL,
+                "UB": AllReduceStrategy.UB,
+                "MINLATENCY": AllReduceStrategy.MIN_LATENCY,
+                "ONESHOT": AllReduceStrategy.ONESHOT,
+                "TWOSHOT": AllReduceStrategy.TWOSHOT,
+                "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
+                "MNNVL": AllReduceStrategy.MNNVL
+            }
+            key = strategy.upper()
+            return maps[key] if key in maps else AllReduceStrategy.AUTO
+
+        if isinstance(self.allreduce_strategy, str):
+            self.allreduce_strategy = get_all_reduce_strategy(
+                self.allreduce_strategy)
+
     @property
     def fuse_pos_embd(self):
         if self.attn_backend == 'TRTLLM':
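A small sketch of how this normalization behaves (the `pretrained_config` argument stands in for a real HF-style config, which `__post_init__` expects; only the strategy handling is the point here):

```python
# Sketch: ModelConfig normalizes string strategies in __post_init__;
# "twoshot" maps to AllReduceStrategy.TWOSHOT, and unknown names fall
# back to AUTO. `pretrained_config` is assumed available.
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.functional import AllReduceStrategy

config = ModelConfig(pretrained_config=pretrained_config,  # assumed available
                     allreduce_strategy="twoshot")
assert config.allreduce_strategy == AllReduceStrategy.TWOSHOT
```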

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 5 additions & 2 deletions
@@ -399,7 +399,8 @@ def __init__(self,
             overridden_tp_size=shared_tp_size,
             reduce_output=False)
 
-        self.allreduce = AllReduce(self.mapping)
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy)
         self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
         self.event_dict = {
             key: torch.cuda.Event()

@@ -628,7 +629,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
         self.layer_idx = layer_idx
-        self.allreduce = AllReduce(self.mapping, dtype=config.torch_dtype)
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy,
+                                   dtype=config.torch_dtype)
         self.moe_allreduce = MoEAllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None
 

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 7 additions & 3 deletions
@@ -282,7 +282,10 @@ def __init__(
             quant_config=None)
 
         self.mapping = model_config.mapping
-        self.all_reduce = AllReduce(self.mapping)
+        self.all_reduce = AllReduce(
+            mapping=model_config.mapping,
+            strategy=model_config.allreduce_strategy,
+        )
         self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
         self.aux_stream = aux_stream
 
@@ -414,7 +417,8 @@ def __init__(
             dtype=config.torch_dtype)
 
         self.mapping = model_config.mapping
-        self.all_reduce = AllReduce(self.mapping)
+        self.all_reduce = AllReduce(mapping=model_config.mapping,
+                                    strategy=model_config.allreduce_strategy)
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
 
@@ -625,7 +629,7 @@ def __init__(
             quant_config=model_config.get_quant_config(),
             skip_create_weights_in_init=model_config.
             skip_create_weights_in_init,
-        )
+            allreduce_strategy=model_config.allreduce_strategy)
 
 
 class Eagle3LlamaDecoderLayer(DecoderLayer):

tensorrt_llm/_torch/models/modeling_nemotron_nas.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
         gather_output=True,
         quant_config=model_config.get_quant_config(),
         skip_create_weights_in_init=model_config.skip_create_weights_in_init,
-    )
+        allreduce_strategy=model_config.allreduce_strategy)
 
 
 class NemotronNASAttention(Attention):

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 4 additions & 2 deletions
@@ -89,7 +89,8 @@ def __init__(
         self.top_k = config.num_experts_per_tok
         self.enable_attention_dp = model_config.mapping.enable_attention_dp
         self.mapping = model_config.mapping
-        self.allreduce = AllReduce(self.mapping)
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy)
         self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
             model_config, self.top_k)
         if self.enable_alltoall:

@@ -202,7 +203,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
             dtype=config.torch_dtype)
         self.layer_idx = layer_idx
 
-        self.allreduce = AllReduce(self.mapping)
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy)
         self.next_layer_layernorm: RMSNorm = None
 
         self.fusion_config = EagerFusionConfig()

tensorrt_llm/_torch/modules/attention.py

Lines changed: 8 additions & 6 deletions
@@ -126,7 +126,7 @@ def __init__(
                 weight_mode=WeightMode.FUSED_QKV_LINEAR),
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
-        )
+            allreduce_strategy=config.allreduce_strategy)
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
 
@@ -140,7 +140,7 @@ def __init__(
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.o_lora,
-        )
+            allreduce_strategy=config.allreduce_strategy)
 
         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend

@@ -481,7 +481,8 @@ def __init__(
                 mapping=mapping,
                 tensor_parallel_mode=TensorParallelMode.COLUMN,
                 quant_config=quant_config,
-                skip_create_weights_in_init=config.skip_create_weights_in_init)
+                skip_create_weights_in_init=config.skip_create_weights_in_init,
+                allreduce_strategy=config.allreduce_strategy)
         else:
             self.fused_a = Linear(
                 hidden_size,

@@ -501,7 +502,7 @@ def __init__(
                 tensor_parallel_mode=TensorParallelMode.COLUMN,
                 quant_config=quant_config,
                 skip_create_weights_in_init=config.skip_create_weights_in_init,
-            )
+                allreduce_strategy=config.allreduce_strategy)
             self.q_b_proj = self.q_proj
 
         self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,

@@ -517,7 +518,8 @@ def __init__(
             mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             quant_config=quant_config,
-            skip_create_weights_in_init=config.skip_create_weights_in_init)
+            skip_create_weights_in_init=config.skip_create_weights_in_init,
+            allreduce_strategy=config.allreduce_strategy)
         # This parameter will view into self.kv_b_proj.weight after loading weights.
         # For dummy weight initialization, this parameter is initialized with empty tensor.
         # Used in forward_generation only

@@ -538,7 +540,7 @@ def __init__(
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=quant_config,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
-        )
+            allreduce_strategy=config.allreduce_strategy)
 
 def yarn_get_mscale(scale=1, mscale=1):
     if scale <= 1:
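The same plumbing applies to every tensor-parallel `Linear` in the commit; a minimal sketch of a row-parallel layer carrying the configured strategy (shapes and the mapping are illustrative, and the module path follows the relative imports shown in this diff):

```python
# Sketch: a row-parallel Linear now receives allreduce_strategy from the
# model config, so the all-reduce it performs after the matmul honors the
# user's configured strategy instead of a process-wide environment variable.
import torch

from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=2, tp_size=2, rank=0)  # hypothetical 2-way TP
proj = Linear(4096, 4096,                           # illustrative shapes
              bias=False,
              dtype=torch.bfloat16,
              mapping=mapping,
              tensor_parallel_mode=TensorParallelMode.ROW,
              allreduce_strategy=AllReduceStrategy.AUTO)
```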
