[Feature][Kernel][DSR1]: Make fused_grouped_topk more fused (integrate TRT-LLM kernel) #28086

@robertgshaw2-redhat

🚀 The feature, motivation and pitch

Right now, this operation runs as multiple separate steps rather than a single kernel:

def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
    topk_values, topk_indices = ops.grouped_topk(
        scores,
        scores_with_bias.to(scores.dtype),
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
    )
    return topk_values.to(torch.float32), topk_indices.to(torch.int32)

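We should fuse the steps around ops.grouped_topk into one single kernel that also handles:

  • the scoring function (sigmoid)
  • the e_score_correction_bias addition
  • the output dtype conversions (float32 weights, int32 indices)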
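
To pin down what a fused kernel would have to reproduce, here is a minimal pure-Python sketch of the routing for a single token. It is not the vLLM or TRT-LLM implementation. The name `reference_grouped_topk` is hypothetical, and the grouping rule (rank each group by the sum of its two largest biased scores, select experts by biased score, take weights from the unbiased score) is an assumption based on DeepSeek-V3-style routing, which this issue does not spell out.

```python
import math


def reference_grouped_topk(
    logits,
    bias,
    topk,
    num_expert_group,
    topk_group,
    renormalize=True,
    routed_scaling_factor=1.0,
):
    """Route ONE token. Hypothetical reference, not the real kernel.

    Returns (weights, expert_ids), both Python lists of length `topk`.
    """
    num_experts = len(logits)
    assert num_experts % num_expert_group == 0, "experts must split evenly"
    group_size = num_experts // num_expert_group

    # Step 1: scoring function (sigmoid branch only).
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]

    # Step 2: bias addition. Assumed to affect selection only, not weights.
    biased = [s + b for s, b in zip(scores, bias)]

    # Step 3 (assumed rule): score each group by its two largest biased
    # scores, then keep the best `topk_group` groups.
    group_scores = []
    for g in range(num_expert_group):
        members = sorted(biased[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(sum(members[:2]))
    kept = set(
        sorted(range(num_expert_group), key=lambda g: group_scores[g], reverse=True)[
            :topk_group
        ]
    )

    # Step 4: top-k experts among the kept groups, ranked by biased score.
    candidates = [e for e in range(num_experts) if e // group_size in kept]
    expert_ids = sorted(candidates, key=lambda e: biased[e], reverse=True)[:topk]

    # Step 5: weights come from the unbiased scores, then optional
    # renormalization and scaling.
    weights = [scores[e] for e in expert_ids]
    if renormalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    weights = [w * routed_scaling_factor for w in weights]
    return weights, expert_ids
```

A reference like this can serve as a slow oracle when checking a fused kernel's output on small inputs. The output dtype conversions (float32 weights, int32 indices) are left out because plain Python lists carry no dtype.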

Alternatives

See below. We should pull in the kernel from TRT-LLM.

Additional context

See below. We should pull in the kernel from TRT-LLM.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
