Commit f432e97

Author: sangchengmeng
Commit message: fix rmsnorm
1 parent: dc4475a

8 files changed: +37, -62 lines

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
Lines changed: 6 additions & 8 deletions

@@ -154,18 +154,16 @@ def _get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            q = rmsnorm_forward(
-                q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, use_custom_tensor_mananger=True
-            )
+            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
         q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
+        rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            use_custom_tensor_mananger=True,
+            out=cache_kv[:, :, : self.kv_lora_rank],
         )

         rotary_emb_fwd(
@@ -193,16 +191,16 @@ def _tpsp_get_qkv(
             q = layer_weight.q_weight_.mm(input)
         else:
             q = layer_weight.q_a_proj_.mm(input)
-            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
+            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
         q = layer_weight.q_b_proj_.mm(q)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
-        cache_kv[:, :, : self.kv_lora_rank] = rmsnorm_forward(
+        rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
             eps=self.eps_,
-            use_custom_tensor_mananger=True,
+            out=cache_kv[:, :, : self.kv_lora_rank],
         )
         rotary_emb_fwd(
             q_rope,

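In both _get_qkv and _tpsp_get_qkv the q and kv normalizations now write their result back into the input tensor instead of having the kernel allocate a fresh one through the cache tensor manager. A minimal before/after sketch of the calling convention (the weight tensor w and the eps value are placeholders, not taken from this commit):

    # before: the kernel allocated the output itself, optionally via the cache tensor manager
    q = rmsnorm_forward(q, weight=w, eps=1e-6, use_custom_tensor_mananger=True)

    # after: the caller chooses the destination; passing out=q normalizes q in place
    rmsnorm_forward(q, weight=w, eps=1e-6, out=q)
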
lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
Lines changed: 4 additions & 12 deletions

@@ -20,12 +20,8 @@ def _mtp_context_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        input_embdings = rmsnorm_forward(
-            input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
-        )
-        tgt_embdings = rmsnorm_forward(
-            tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
-        )
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

@@ -40,12 +36,8 @@ def _mtp_token_forward(
     ):
         tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
         assert input_embdings.shape[0] == tgt_embdings.shape[0]
-        input_embdings = rmsnorm_forward(
-            input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
-        )
-        tgt_embdings = rmsnorm_forward(
-            tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, use_custom_tensor_mananger=True
-        )
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

         cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

lightllm/models/llama/layer_infer/post_layer_infer.py
Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ def __init__(self, network_config, mode):
         return

     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)
+        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)

     def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):

lightllm/models/llama/layer_infer/transformer_layer_infer.py
Lines changed: 6 additions & 6 deletions

@@ -135,16 +135,16 @@ def _bind_attention(self):
     def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        return rmsnorm_forward(
-            input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
-        )
+        out = self.alloc_tensor(input.shape, input.dtype)
+        rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
+        return out

     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        return rmsnorm_forward(
-            input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
-        )
+        out = self.alloc_tensor(input.shape, input.dtype)
+        rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
+        return out

     def _get_qkv(
         self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight

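For reference, _att_norm and _ffn_norm above still compute ordinary RMSNorm; only the ownership of the output buffer changed (the layer now allocates it through self.alloc_tensor and passes it as out=). A plain-PyTorch equivalent of the normalization itself, included purely as a readability aid and not taken from the repository, would look like:

    import torch

    def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # y = x * weight / sqrt(mean(x ** 2, dim=-1) + eps), accumulated in float32
        x_f32 = x.float()
        variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
        return (x_f32 * torch.rsqrt(variance + eps)).to(x.dtype) * weight
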
lightllm/models/llama/triton_kernel/rmsnorm.py
Lines changed: 15 additions & 29 deletions

@@ -2,11 +2,10 @@
 import torch
 import triton
 import triton.language as tl
-from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 @triton.jit
-def _rms_norm_low_accuracy_kernel(
+def _rmsnorm_kernel(
     X,  # pointer to the input
     Y,  # pointer to the output
     W,  # pointer to the weights
@@ -42,15 +41,9 @@ def _rms_norm_low_accuracy_kernel(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


-def rmsnorm_forward_low_accuracy(x: torch.Tensor, weight, eps, use_custom_tensor_mananger: bool = False):
+def rmsnorm(x: torch.Tensor, weight, eps, out=None):
     # allocate output
-    if use_custom_tensor_mananger:
-        shape = x.shape
-        dtype = x.dtype
-        device = x.device
-        y = g_cache_manager.alloc_tensor(shape, dtype, device=device)
-    else:
-        y = torch.empty_like(x)
+    y = torch.empty_like(x) if out is None else out
     # reshape input data into 2D tensor
     x_arg = x.view(-1, x.shape[-1])
     y_arg = y.view(-1, x.shape[-1])
@@ -68,7 +61,7 @@ def rmsnorm_forward_low_accuracy(x: torch.Tensor, weight, eps, use_custom_tensor
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
     # enqueue kernel
-    _rms_norm_low_accuracy_kernel[(M,)](
+    _rmsnorm_kernel[(M,)](
         x_arg,
         y_arg,
         weight,
@@ -85,7 +78,7 @@


 @triton.jit
-def _rms_norm_high_accuracy_kernel(
+def _rms_norm_kernel(
     input,
     weight,
     output,
@@ -115,9 +108,7 @@ def _rms_norm_high_accuracy_kernel(
     tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)


-def rmsnorm_forward_high_accuracy(
-    hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False
-):
+def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, out=None):
     """Rms norm."""

     assert hidden_states.is_contiguous(), "hidden_states must be contiguous"
@@ -133,17 +124,12 @@ def rmsnorm_forward_high_accuracy(
     in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)

     BLOCK_N = triton.next_power_of_2(hidden_dim)
-    if use_custom_tensor_mananger:
-        shape = hidden_states.shape
-        dtype = hidden_states.dtype
-        device = hidden_states.device
-        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
-    else:
-        output = torch.empty_like(hidden_states)
+
+    output = torch.empty_like(hidden_states) if out is None else out

     out_row_stride, out_col_stride = output.stride(0), output.stride(1)
     grid = (rows,)
-    _rms_norm_high_accuracy_kernel[grid](
+    _rms_norm_kernel[grid](
         hidden_states,
         weight,
         output,
@@ -171,21 +157,21 @@ def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
     weight = torch.rand(w_shape, dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
     # forward pass
-    y_tri = rmsnorm_forward_low_accuracy(x, weight, eps)
-    y_tri_high_acc = rmsnorm_forward_high_accuracy(x, weight, eps)
+    y_tri = rmsnorm_forward(x, weight, eps)
+    y_tri_1 = rms_norm(x, weight, eps)
     y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)

     # compare
-    print("type:", y_tri.dtype, y_ref.dtype, y_tri_high_acc.dtype)
+    print("type:", y_tri.dtype, y_ref.dtype, y_tri_1.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
-    print("max delta:", torch.max(torch.abs(y_tri_high_acc - y_ref)))
+    print("max delta:", torch.max(torch.abs(y_tri_1 - y_ref)))
     assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
     return


 use_high_acc = os.getenv("RMSNORM_HIGH_ACCURACY", "False").upper() in ["ON", "TRUE", "1"]

 if use_high_acc:
-    rmsnorm_forward = rmsnorm_forward_high_accuracy
+    rmsnorm_forward = rms_norm
 else:
-    rmsnorm_forward = rmsnorm_forward_low_accuracy
+    rmsnorm_forward = rmsnorm

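After this change both Triton paths share the signature (x, weight, eps, out=None): passing out= writes into an existing buffer, otherwise the output comes from torch.empty_like, and rmsnorm_forward is bound to the high-accuracy rms_norm only when RMSNORM_HIGH_ACCURACY is set in the environment. A minimal usage sketch (shapes, dtype and eps are illustrative; assumes a CUDA device):

    import torch
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

    x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
    w = torch.ones(4096, dtype=torch.float16, device="cuda")

    y = rmsnorm_forward(x, w, eps=1e-6)      # allocates the output tensor internally
    rmsnorm_forward(x, w, eps=1e-6, out=x)   # writes the normalized rows back into x
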
lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
Lines changed: 2 additions & 2 deletions

@@ -36,18 +36,18 @@ def _get_qkv(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

-        q = rmsnorm_forward(
+        rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
             use_custom_tensor_mananger=True,
+            out=q.view(-1, self.head_dim_),
         )

         cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
             cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
-            use_custom_tensor_mananger=True,
         ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])

         rotary_emb_fwd(

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
Lines changed: 2 additions & 3 deletions

@@ -60,18 +60,17 @@ def _get_qkv(
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
-        q = rmsnorm_forward(
+        rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
             eps=self.eps_,
-            use_custom_tensor_mananger=True,
+            out=q.view(-1, self.head_dim_),
         )

         cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
             cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
             weight=layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
-            use_custom_tensor_mananger=True,
         ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])

         rotary_emb_fwd(

lightllm/models/vit/layer_infer/transformer_layer_infer.py
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward_high_accuracy as rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rms_norm
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
