Open
Description
1. Background:
The main repository supports passing variables of primitive types to the kernel side, e.g.:
tilelang-main\examples\flash_attention\example_mha_fwd_varlen.py
@T.prim_func
def main(
    Q_unpad: T.Tensor(q_shape, dtype),
    K_unpad: T.Tensor(k_shape, dtype),
    V_unpad: T.Tensor(v_shape, dtype),
    cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
    cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
    max_seqlen_q: T.int32,  # passed as a primitive type
    Output_unpad: T.Tensor(o_shape, dtype),
):
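2. Expectation:
tilelangascend should support this as well, so that variables can be passed directly to the kernel side and the kernel keeps the same semantics and flow as the original operator. (Currently, a variable passed into the func actually becomes a constant after translation.)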
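To make the difference concrete, here is a minimal pure-Python sketch. It does not use tilelang or tilelangascend at all; `compile_with_constant`, `compile_with_runtime_scalar`, and the clamping `kernel` bodies are hypothetical stand-ins chosen only to illustrate the two behaviours (value frozen at build time versus value supplied at call time).

```python
def compile_with_constant(max_seqlen_q):
    """Hypothetical stand-in for the current behaviour:
    the scalar is captured when the kernel is built and cannot change afterwards."""
    def kernel(seqlens):
        # max_seqlen_q is effectively a constant inside the kernel
        return [min(s, max_seqlen_q) for s in seqlens]
    return kernel


def compile_with_runtime_scalar():
    """Hypothetical stand-in for the requested behaviour:
    the scalar is a real kernel argument, supplied on every launch."""
    def kernel(seqlens, max_seqlen_q):
        return [min(s, max_seqlen_q) for s in seqlens]
    return kernel


baked = compile_with_constant(4)
dynamic = compile_with_runtime_scalar()

baked_result = baked([2, 8])           # limit is fixed at 4 for every call
dynamic_result_4 = dynamic([2, 8], 4)  # same limit, passed at call time
dynamic_result_6 = dynamic([2, 8], 6)  # a different limit needs no rebuild
```

With the baked-in form, a new value of `max_seqlen_q` would require building a new kernel; with the runtime-scalar form, the same kernel can be launched with any value.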