fix rms_norm #945
Changes from 1 commit
File 1:

```diff
@@ -8,7 +8,7 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from einops import rearrange
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.distributed.communication_op import all_gather
@@ -25,7 +25,7 @@ def __init__(self, network_config, mode):
         return

     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+        return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)

     def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):
```
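Both the old `rmsnorm_forward` kernel and the `rms_norm` kernel introduced here are Triton RMSNorm implementations. As a point of reference only, a minimal PyTorch equivalent of what such a kernel is expected to compute is sketched below; this is an assumption for illustration, not the project's code, and the diff does not show how `use_custom_tensor_mananger=True` affects output allocation (presumably it routes the output buffer through lightllm's custom tensor manager).

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: scale each last-dim row by the reciprocal of its root mean square,
    # then apply the learned per-channel weight. Computed in fp32 for stability.
    x_fp32 = x.float()
    rms_inv = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rms_inv * weight.float()).to(x.dtype)

# Hypothetical stand-in for the call in _norm above:
# out = rms_norm_reference(hidden_states, final_norm_weight, eps=1e-6)
```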
File 2:

```diff
@@ -14,7 +14,7 @@
 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
 from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
 from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
```
Review comment on the import above: The import path seems incorrect. This change imports `rms_norm` from the vit model, but the file name indicates it should be for llama. Double check this import.
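The reviewer attached a suggested change whose body is not reproduced above. One plausible resolution, if the intent is to keep a llama-local kernel with the same keyword interface, is sketched below; the module path is hypothetical and must be checked against what the repository actually exports.

```python
# Hypothetical alternative: import an equivalent rms_norm from the llama kernels
# instead of reusing the vit-specific one. This exact path is an assumption.
from lightllm.models.llama.triton_kernel.rmsnorm import rms_norm
```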
```diff
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
```
```diff
@@ -135,14 +135,14 @@ def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
+        rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
         return out

     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
+        rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
         return out

     def _get_qkv(
```
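One thing worth double-checking in `_att_norm` and `_ffn_norm` as changed above: the old `rmsnorm_forward` call wrote its result into the pre-allocated buffer via `out=out`, while the new `rms_norm` call receives no output argument yet `return out` is unchanged. If `rms_norm` allocates and returns its own tensor (an assumption; its implementation is not part of this diff), the returned value needs to be captured, roughly as in the sketch below.

```python
# Sketch only (assumes rms_norm returns a freshly allocated, normalized tensor):
def _att_norm(self, input, infer_state, layer_weight):
    out = rms_norm(
        input,
        weight=layer_weight.att_norm_weight_.weight,
        eps=self.eps_,
        use_custom_tensor_mananger=True,
    )
    return out
```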