
fix rms_norm #945


Closed · wants to merge 4 commits

Changes from 1 commit
lightllm/models/llama/layer_infer/post_layer_infer.py (4 changes: 2 additions & 2 deletions)
@@ -8,7 +8,7 @@
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from einops import rearrange
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm


Severity: high

The import path seems incorrect: this change imports rms_norm from the vit model's triton kernels, but the file being modified is llama code. Double-check this import.

Suggested change:
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

from lightllm.common.basemodel import PostLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.distributed.communication_op import all_gather
@@ -25,7 +25,7 @@ def __init__(self, network_config, mode):
return

def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)


Severity: medium

Consider whether use_custom_tensor_mananger=True should be a configurable parameter or determined based on the execution environment, rather than being hardcoded.

return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_)
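
One way to address this, sketched below: read the flag once from network_config in __init__ and pass it through at the call site instead of hardcoding True. This is an illustrative sketch, not code from this PR; the config key name use_custom_tensor_manager and the exact class skeleton are assumptions.

# Sketch only: the config key "use_custom_tensor_manager" is hypothetical and not part of this PR.
class LlamaPostLayerInfer(PostLayerInferTpl):
    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
        # Read the flag once here instead of hardcoding True at every rms_norm call site.
        self.use_custom_tensor_mananger_ = bool(network_config.get("use_custom_tensor_manager", True))
        return

    def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
        return rms_norm(
            input,
            layer_weight.final_norm_weight_,
            eps=self.eps_,
            use_custom_tensor_mananger=self.use_custom_tensor_mananger_,
        )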


def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):

lightllm/models/llama/layer_infer/transformer_layer_infer.py (6 changes: 3 additions & 3 deletions)
@@ -14,7 +14,7 @@
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm


Severity: high

The import path seems incorrect: this change imports rms_norm from the vit model's triton kernels, but the file being modified is llama code. Double-check this import.

Suggested change:
-from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

@@ -135,14 +135,14 @@ def _att_norm(
self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
-rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
+rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)


Severity: medium

Consider whether use_custom_tensor_mananger=True should be a configurable parameter or determined based on the execution environment, rather than being hardcoded.

rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)

return out

def _ffn_norm(
self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
-rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
+rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)


Severity: medium

Consider whether use_custom_tensor_mananger=True should be a configurable parameter or determined based on the execution environment, rather than being hardcoded.

rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
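
The same concern applies to the _att_norm call above. Below is a sketch of the comment's other option, deriving the flag from the execution environment instead of hardcoding it per call. This is not code from the PR: the environment variable name LIGHTLLM_USE_CUSTOM_TENSOR_MANAGER is hypothetical, and the sketch returns the kernel's output directly (as post_layer_infer.py does) rather than allocating out via self.alloc_tensor.

# Sketch only: the environment variable name below is hypothetical.
import os

USE_CUSTOM_TENSOR_MANAGER = os.getenv("LIGHTLLM_USE_CUSTOM_TENSOR_MANAGER", "1") == "1"

def _ffn_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor:
    # Pass the environment-derived flag through instead of a hardcoded True.
    return rms_norm(
        input,
        weight=layer_weight.ffn_norm_weight_.weight,
        eps=self.eps_,
        use_custom_tensor_mananger=USE_CUSTOM_TENSOR_MANAGER,
    )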

return out

def _get_qkv(