Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add LBR_GRU to cell_fusion kernel for MTL dispatch #2868

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

h-sadia
Copy link
Contributor

@h-sadia h-sadia commented Mar 12, 2025

Description

This WIP PR has a few things left:

  • Figure out correction issues for LBR GRU
  • Consolidate common code between cell_common and cell_gru_lbr and put it in a common header file
  • Consolidate lbr gru code in the kernel and create a function to remove duplicate code
  • Ideally a new lcm function can be created for k_limit
    Fixes # (MFDNN-12712)

@h-sadia h-sadia requested a review from a team as a code owner March 12, 2025 19:50
@github-actions github-actions bot added the platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel label Mar 12, 2025
int eu_count = device_info.eu_count();
int ideal_k_block = graph::utils::lcm(
eu_count, (int)device_info.min_subgroup_size());
int ideal_k_limit = graph::utils::lcm(ideal_k_block, (int)rnn.sic);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.rnn.bias[off_ker_bias(dims.dhc, 0, c)],
ctx.rnn.alpha, ctx.rnn.tm_scales);
store_vanilla_rnn(gates.ptr, gates.strides.mb, states.ptr,
states.strides.mb, dims.dhc, n, c, g);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please keep the runtime if statements so that we get compiler tests on all cases during development.

* TO_REF(
ctx.lbr_gru.hidden_state_iter[cell_ws_state(
states.strides.mb, n, c)])
+ (1 - G0) * G2;
Copy link
Contributor

@rjoursler rjoursler Mar 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please wrap the Ht calculation into compute_gates_lbr_gru., we can then just get rid of all these intermediate variables.

Additionally, please update elemwise_fwd to use compute_gates_lbr_gru to avoid code duplication.

}

status_t compute_cell_fwd(const exec_ctx_t &ctx,
const compute::kernel_t &kernel, dim_t lay, dim_t dir, dim_t iter,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than duplicate this function, can we reasonably create a shared implementation between cell_common and this one.?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants