diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
index 89061dae3..697e4b4cd 100755
--- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
@@ -10,7 +10,6 @@
 
 from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 
 
 class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer):
diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py
index f5a6ef4b8..a724d5668 100644
--- a/lightllm/models/internvl/model.py
+++ b/lightllm/models/internvl/model.py
@@ -10,6 +10,7 @@
 from lightllm.models.phi3.model import Phi3TpPartModel
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
 from lightllm.models.deepseek2.model import Deepseek2TpPartModel
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import (
@@ -297,3 +298,27 @@ def _init_config(self):
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
         return
+
+
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class InternVLQwen3MOETpPartModel(Qwen3MOEModel):
+    # weight class
+    pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
+
+    # infer class
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            self.config = json.load(json_file)["llm_config"]
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/models/llama/triton_kernel/rmsnorm.py
index da0e1369f..2dc52d728 100644
--- a/lightllm/models/llama/triton_kernel/rmsnorm.py
+++ b/lightllm/models/llama/triton_kernel/rmsnorm.py
@@ -1,11 +1,11 @@
+import os
 import torch
-
 import triton
 import triton.language as tl
 
 
 @triton.jit
-def _rms_norm_fwd_fused(
+def _rmsnorm_kernel(
     X,  # pointer to the input
     Y,  # pointer to the output
     W,  # pointer to the weights
@@ -41,7 +41,7 @@ def _rms_norm_fwd_fused(
     tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
 
 
-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
+def rmsnorm(x: torch.Tensor, weight, eps, out=None):
     # allocate output
     y = torch.empty_like(x) if out is None else out
     # reshape input data into 2D tensor
@@ -61,7 +61,7 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
     # enqueue kernel
-    _rms_norm_fwd_fused[(M,)](
+    _rmsnorm_kernel[(M,)](
         x_arg,
         y_arg,
         weight,
@@ -77,6 +77,75 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     return y
 
 
+@triton.jit
+def _rms_norm_kernel(
+    input,
+    weight,
+    output,
+    in_row_stride: tl.constexpr,
+    in_col_stride: tl.constexpr,
+    out_row_stride: tl.constexpr,
+    out_col_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS, other=0.0)
+
+    x_ptr = input + prog_id * in_row_stride
+    x = tl.load(x_ptr + offsets * in_col_stride, mask=offsets < N_COLS, other=0.0)
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf / tl.sqrt(var + eps)
+    out = (w * out).to(x.dtype)
+
+    out_ptr = output + prog_id * out_row_stride
+    tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)
+
+
+def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5, out=None):
+    """Rms norm."""
+
+    assert hidden_states.is_contiguous(), "hidden_states must be contiguous"
+
+    origin_shape = hidden_states.shape
+    hidden_dim = weight.shape[0]
+    assert hidden_dim == origin_shape[-1], f"hidden_dim {hidden_dim} != {origin_shape[-1]}"
+
+    rows = hidden_states.numel() // hidden_dim
+    if hidden_states.dim() == 3:  # (bs, seq_len, hidden_dim)
+        hidden_states = hidden_states.view(rows, hidden_dim)
+
+    in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)
+
+    BLOCK_N = triton.next_power_of_2(hidden_dim)
+
+    output = torch.empty_like(hidden_states) if out is None else out
+
+    out_row_stride, out_col_stride = output.stride(0), output.stride(1)
+    grid = (rows,)
+    _rms_norm_kernel[grid](
+        hidden_states,
+        weight,
+        output,
+        in_row_stride,
+        in_col_stride,
+        out_row_stride,
+        out_col_stride,
+        eps=eps,
+        N_COLS=hidden_dim,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+        num_stages=3,
+    )
+    return output.reshape(origin_shape)
+
+
 def torch_rms_norm(x, weight, eps):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
 
@@ -89,10 +158,20 @@ def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"):
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
     # forward pass
     y_tri = rmsnorm_forward(x, weight, eps)
+    y_tri_1 = rms_norm(x, weight, eps)
     y_ref = torch_rms_norm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)
 
     # compare
-    print("type:", y_tri.dtype, y_ref.dtype)
+    print("type:", y_tri.dtype, y_ref.dtype, y_tri_1.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
+    print("max delta:", torch.max(torch.abs(y_tri_1 - y_ref)))
     assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
     return
+
+
+use_high_acc = os.getenv("RMSNORM_HIGH_ACCURACY", "False").upper() in ["ON", "TRUE", "1"]
+
+if use_high_acc:
+    rmsnorm_forward = rms_norm
+else:
+    rmsnorm_forward = rmsnorm
diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py
index 14ba9cfed..085e16e9d 100644
--- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py
@@ -7,11 +7,10 @@
 import triton
 
 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rms_norm
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 
diff --git a/lightllm/models/vit/triton_kernel/rms_norm_vit.py b/lightllm/models/vit/triton_kernel/rms_norm_vit.py
deleted file mode 100644
index 387bedfcf..000000000
--- a/lightllm/models/vit/triton_kernel/rms_norm_vit.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from torch import Tensor
-from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
-
-
-@triton.jit
-def rms_norm_kernel(
-    input,
-    weight,
-    output,
-    in_row_stride: tl.constexpr,
-    in_col_stride: tl.constexpr,
-    out_row_stride: tl.constexpr,
-    out_col_stride: tl.constexpr,
-    eps: tl.constexpr,
-    N_COLS: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    """Rms norm kernel."""
-    prog_id = tl.program_id(0)
-    offsets = tl.arange(0, BLOCK_N)
-
-    w = tl.load(weight + offsets, mask=offsets < N_COLS, other=0.0)
-
-    x_ptr = input + prog_id * in_row_stride
-    x = tl.load(x_ptr + offsets * in_col_stride, mask=offsets < N_COLS, other=0.0)
-    xf = x.to(tl.float32)
-
-    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
-    out = xf / tl.sqrt(var + eps)
-    out = (w * out).to(x.dtype)
-
-    out_ptr = output + prog_id * out_row_stride
-    tl.store(out_ptr + offsets * out_col_stride, out, mask=offsets < N_COLS)
-
-
-def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False):
-    """Rms norm."""
-
-    assert hidden_states.is_contiguous(), "hidden_states must be contiguous"
-
-    origin_shape = hidden_states.shape
-    hidden_dim = weight.shape[0]
-    assert hidden_dim == origin_shape[-1], f"hidden_dim {hidden_dim} != {origin_shape[-1]}"
-
-    rows = hidden_states.numel() // hidden_dim
-    if hidden_states.dim() == 3:  # (bs, seq_len, hidden_dim)
-        hidden_states = hidden_states.view(rows, hidden_dim)
-
-    in_row_stride, in_col_stride = hidden_states.stride(0), hidden_states.stride(1)
-
-    BLOCK_N = triton.next_power_of_2(hidden_dim)
-    if use_custom_tensor_mananger:
-        shape = hidden_states.shape
-        dtype = hidden_states.dtype
-        device = hidden_states.device
-        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
-    else:
-        output = torch.empty_like(hidden_states)
-
-    out_row_stride, out_col_stride = output.stride(0), output.stride(1)
-    grid = (rows,)
-    rms_norm_kernel[grid](
-        hidden_states,
-        weight,
-        output,
-        in_row_stride,
-        in_col_stride,
-        out_row_stride,
-        out_col_stride,
-        eps=eps,
-        N_COLS=hidden_dim,
-        BLOCK_N=BLOCK_N,
-        num_warps=4,
-        num_stages=3,
-    )
-    return output.reshape(origin_shape)
-
-
-def test():
-    def _rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float):
-        var = (x.float() ** 2).mean(dim=-1, keepdim=True)
-        y = x.float() / torch.sqrt(var + eps)
-        return (y * weight).to(x.dtype)
-
-    torch.manual_seed(0)
-    device, dtype = "cuda", torch.float16
-    bs, seq_len, hidden = 3, 1025, 3200
-    eps = 1e-5
-    weight = torch.randn(hidden, device=device, dtype=dtype)
-
-    # 2-D contiguous
-    x2 = torch.randn(seq_len, hidden, device=device, dtype=dtype).contiguous()
-    assert torch.allclose(rms_norm(x2, weight, eps), _rms_norm_ref(x2, weight, eps), atol=1e-3, rtol=1e-3)
-
-    # 3-D contiguous
-    x3 = torch.randn(bs, seq_len, hidden, device=device, dtype=dtype).contiguous()
-    assert torch.allclose(rms_norm(x3, weight, eps), _rms_norm_ref(x3, weight, eps), atol=1e-3, rtol=1e-3)
-
-    print("all tests pass")
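
Usage sketch for the env-var dispatch added to lightllm/models/llama/triton_kernel/rmsnorm.py (an illustrative example assuming a CUDA device, not taken from the patch itself): rmsnorm_forward is bound at import time, so RMSNORM_HIGH_ACCURACY must be set before the module is first imported; the values ON, TRUE and 1 select the new rms_norm path, anything else keeps the original rmsnorm kernel.

import os

os.environ["RMSNORM_HIGH_ACCURACY"] = "1"  # must be set before the module below is first imported

import torch
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm

# illustrative shapes/dtype; a contiguous 2D or 3D tensor whose last dim matches the weight
x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")

y = rmsnorm_forward(x, w, eps=1e-6)  # dispatches to rms_norm under this env setting
y_ref = torch_rms_norm(x.float(), w.float(), 1e-6).to(x.dtype)
print("max delta:", torch.max(torch.abs(y - y_ref)))

With the variable unset, the same call falls back to the original rmsnorm kernel, so existing callers of rmsnorm_forward keep their previous behavior by default.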