
ggml : support broadcast for ggml_soft_max_ext and ggml_flash_attn_ext #14435


Open: ggerganov wants to merge 2 commits into master from gg/ggml-batch-soft-max-ops

Conversation

ggerganov (Member) commented on Jun 28, 2025:

Extract broadcast changes from #14363 for ggml_soft_max_ext() and ggml_flash_attn_ext():

llama.cpp/ggml/include/ggml.h, lines 1435 to 1451 in 236682a:

// a      [ne0, ne01, ne02, ne03]
// mask   [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
//
// broadcast:
//   ne02 % ne12 == 0
//   ne03 % ne13 == 0
//
// fused soft_max(a*scale + mask*(ALiBi slope))
// max_bias = 0.0f for no ALiBi
GGML_API struct ggml_tensor * ggml_soft_max_ext(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * mask,
        float                 scale,
        float                 max_bias);
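For illustration, a minimal usage sketch (not part of this PR): the mask carries one slice per head and per sequence (ne12 == ne02, ne13 == ne03), which satisfies the broadcast constraints. All tensor sizes and the scale value are hypothetical.

```c
#include "ggml.h"

struct ggml_init_params params = {
    /*.mem_size   =*/ 64*1024*1024,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ false,
};
struct ggml_context * ctx = ggml_init(params);

// a    : [128, 32, 8, 2] - hypothetical KQ scores for 32 tokens, 8 heads, 2 sequences
struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 8, 2);

// mask : [128, 32, 8, 2] - one mask slice per head and per sequence;
// on master this had to be a 2D matrix broadcast across dims 2 and 3
struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 8, 2);

// fused soft_max(a*scale + mask); max_bias = 0.0f disables ALiBi
struct ggml_tensor * cur  = ggml_soft_max_ext(ctx, a, mask, 0.125f, 0.0f);
```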

llama.cpp/ggml/include/ggml.h, lines 1876 to 1896 in 236682a:

// q:    [n_embd_k, n_batch,     n_head,    ne3]
// k:    [n_embd_k, n_kv,        n_head_kv, ne3]
// v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
// mask: [n_kv,     n_batch_pad, ne32,      1  ] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
//
// broadcast:
//   n_head % n_head_kv == 0
//   ne3    % ne32      == 0
//
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * mask,
        float                 scale,
        float                 max_bias,
        float                 logit_softcap);
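A second sketch (again hypothetical sizes, reusing the ctx from the previous block): grouped-query attention where 32 query heads share 8 KV heads, and the mask has one slice per sequence (ne32 == ne3), so both broadcast constraints hold. GGML_PAD and GGML_KQ_MASK_PAD are the ggml.h macros referenced in the comment above; the scale value is made up.

```c
const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8;
const int64_t n_batch = 1, n_kv = 256, ne3 = 4;

struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_batch, n_head,    ne3);
struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, ne3);
struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, ne3);

// mask: [n_kv, n_batch_pad, ne32, 1] with ne32 == ne3 -> a separate mask per sequence
struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
        n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), ne3, 1);

const float scale = 0.0883883f; // 1/sqrt(n_embd_head), hypothetical
struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, scale, 0.0f, 0.0f);
```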

Both changes should be quite simple. On master we assume the mask is a 2D matrix and always broadcast it across dim 2 (i.e. the heads) and dim 3. With this change, separate masks are allowed along those dims - i.e. generalized broadcast.

Currently, I've added tests and implemented support in the CPU and Metal backends. The remaining backends will fall back to CPU until this gets implemented:

  - [x] CPU
  - [x] Metal
  - [ ] CUDA
  - [ ] Vulkan
  - [ ] CANN

The fallback is okay for now since llama.cpp does not use these extensions at the moment. This support will be needed later for the #14363 PR, although it's useful to have either way.
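For reference, the broadcast presumably follows ggml's usual modulo convention for repeating the smaller operand across dims 2 and 3. The helper below is an illustration of how a kernel could select the mask slice for a given (i02, i03) slice of a; it is not the PR's kernel code.

```c
// Illustrative only - assumes ggml's standard modulo broadcast; the divisibility
// constraints (ne02 % ne12 == 0, ne03 % ne13 == 0) make the wrap-around exact.
static const char * soft_max_mask_slice(const struct ggml_tensor * mask, int64_t i02, int64_t i03) {
    const int64_t i12 = i02 % mask->ne[2]; // broadcast over dim 2 (heads)
    const int64_t i13 = i03 % mask->ne[3]; // broadcast over dim 3 (sequences)
    return (const char *) mask->data + i12*mask->nb[2] + i13*mask->nb[3];
}
```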

github-actions bot added the testing, ggml, and Apple Metal labels on Jun 28, 2025
ggerganov force-pushed the gg/ggml-batch-soft-max-ops branch from e6faa45 to 236682a on June 28, 2025
github-actions bot added the Nvidia GPU, Vulkan, and Ascend NPU labels on Jun 28, 2025
ggerganov force-pushed the gg/ggml-batch-soft-max-ops branch 3 times, most recently from 852529e to bdfd7b7, on June 28, 2025
ggerganov marked this pull request as ready for review on June 28, 2025
github-actions bot added the SYCL label on Jun 28, 2025
ggerganov force-pushed the gg/ggml-batch-soft-max-ops branch from bdfd7b7 to 461cb2f on June 29, 2025
jeffbolznv (Collaborator) commented:

I've started working on the Vulkan backend support for this.

jeffbolznv (Collaborator) commented:

Vulkan support is in #14449, targeted to this branch.
