Skip to content

fix rms_norm #945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward


class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer):
Expand Down
25 changes: 25 additions & 0 deletions lightllm/models/internvl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.qwen2.model import Qwen2TpPartModel
from lightllm.models.qwen3.model import Qwen3TpPartModel
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import (
Expand Down Expand Up @@ -297,3 +298,27 @@ def _init_config(self):
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size
return


@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
class InternVLQwen3MOETpPartModel(Qwen3MOEModel):
# weight class
pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight

# infer class
pre_layer_infer_class = LlamaMultimodalPreLayerInfer

def __init__(self, kvargs):
super().__init__(kvargs)
return

def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
self.config = json.load(json_file)["llm_config"]
# rename keys
repair_config(self.config, same_names=["num_attention_heads", "n_head"])
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
if self.finetune_config:
self.config["vocab_size"] = self.finetune_config.vocab_size
return
89 changes: 84 additions & 5 deletions lightllm/models/llama/triton_kernel/rmsnorm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import torch

import triton
import triton.language as tl


@triton.jit
def _rms_norm_fwd_fused(
def _rmsnorm_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
Expand Down Expand Up @@ -41,7 +41,7 @@ def _rms_norm_fwd_fused(
tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)


def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
def rmsnorm(x: torch.Tensor, weight, eps, out=None):
# allocate output
y = torch.empty_like(x) if out is None else out
# reshape input data into 2D tensor
Expand All @@ -61,7 +61,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
if BLOCK_SIZE > 16384:
BLOCK_SIZE = 16384
# enqueue kernel
_rms_norm_fwd_fused[(M,)](
_rmsnorm_kernel[(M,)](
x_arg,
y_arg,
weight,
Expand All @@ -77,6 +77,75 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
return y


@triton.jit
def _rms_norm_kernel(
input,
weight,
output,
in_row_stride: tl.constexpr,
in_col_stride: tl.constexpr,
out_row_stride: tl.constexpr,
out_col_stride: tl.constexpr,
eps: tl.constexpr,
N_COLS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Rms norm kernel."""
prog_id = tl.program_id(0)
offsets = tl.arange(0, BLOCK_N)

w = tl.load(weight + offsets, mask=offsets < N_COLS, other=0.0)

x_ptr = input + prog_id * in_row_stride
x = tl.load(x_ptr + offsets * in_col_stride, mask=offsets < N_COLS, other=0.0)
xf = x.to(tl.float32)

var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
out = xf / tl.sqrt(var + eps)
out = (w * out).to(x.dtype)

out_ptr = output + prog_id * out_row_stride
tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)


def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, out=None):
"""Rms norm."""

assert hidden_states.is_contiguous(), "hidden_states must be contiguous"

origin_shape = hidden_states.shape
hidden_dim = weight.shape[0]
assert hidden_dim == origin_shape[-1], f"hidden_dim {hidden_dim} != {origin_shape[-1]}"

rows = hidden_states.numel() // hidden_dim
if hidden_states.dim() == 3: # (bs, seq_len, hidden_dim)
hidden_states = hidden_states.view(rows, hidden_dim)

in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)

BLOCK_N = triton.next_power_of_2(hidden_dim)

output = torch.empty_like(hidden_states) if out is None else out

out_row_stride, out_col_stride = output.stride(0), output.stride(1)
grid = (rows,)
_rms_norm_kernel[grid](
hidden_states,
weight,
output,
in_row_stride,
in_col_stride,
out_row_stride,
out_col_stride,
eps=eps,
N_COLS=hidden_dim,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=3,
)
return output.reshape(origin_shape)


def torch_rms_norm(x, weight, eps):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

Expand All @@ -89,10 +158,20 @@ def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
# forward pass
y_tri = rmsnorm_forward(x, weight, eps)
y_tri_1 = rms_norm(x, weight, eps)
y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)

# compare
print("type:", y_tri.dtype, y_ref.dtype)
print("type:", y_tri.dtype, y_ref.dtype, y_tri_1.dtype)
print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
print("max delta:", torch.max(torch.abs(y_tri_1 - y_ref)))
assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
return


use_high_acc = os.getenv("RMSNORM_HIGH_ACCURACY", "False").upper() in ["ON", "TRUE", "1"]

if use_high_acc:
rmsnorm_forward = rms_norm
else:
rmsnorm_forward = rmsnorm
3 changes: 1 addition & 2 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import triton

from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.models.llama.triton_kernel.rmsnorm import rms_norm
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


Expand Down
102 changes: 0 additions & 102 deletions lightllm/models/vit/triton_kernel/rms_norm_vit.py

This file was deleted.