Matrix multiplication not delegated to backend during export (Whisper) #15914

@IgorSwat

🐛 Describe the bug

During the export of the whisper-tiny.en model to the ExecuTorch format, I encountered the following problem: one of the matrix multiplication operations is not delegated to the XNNPACK backend, which causes a significant slowdown (the undelegated op accounts for up to 50% of the total inference time).

The operation is located in the forward method of the Decoder module from the openai-whisper package:

def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
        the text tokens
    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
        the encoded audio features to be attended on
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
    x = x.to(xa.dtype)

    for block in self.blocks:
        x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    logits = (
        # ISSUE: the matrix multiplication below is not being delegated
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits

Steps to reproduce

I use the following code to export the module with static shapes:

import torch
import whisper
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = whisper.load_model("tiny.en").decoder
model.eval()

# Static example inputs: 128 token ids and the encoder output
# (tiny.en: n_audio_ctx = 1500, n_audio_state = 384)
inputs = (
    torch.randint(1, 100, (1, 128), dtype=torch.int),
    torch.randn(1, 1500, 384, dtype=torch.float32),
)

# Export with static shapes
exported_program = torch.export.export(model, inputs)

# Lower to the edge dialect and delegate to XNNPACK
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("whisper/exported/decoder.pte", "wb") as file:
    executorch_program.write_to_file(file)
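
Not part of the original repro, but a quick way to confirm which ops stay undelegated before serializing is to print the delegation summary of the edge program. A minimal sketch, assuming the devtools helper get_delegation_info from executorch.devtools.backend_debug is available in this nightly:

from executorch.devtools.backend_debug import get_delegation_info

edge_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
)

# Summarize delegated vs. non-delegated ops; aten.mm shows up here as the
# op left outside the XNNPACK delegate.
delegation_info = get_delegation_info(edge_program.exported_program().graph_module)
print(delegation_info.get_summary())

executorch_program = edge_program.to_executorch()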

To profile the exported model, I follow the instructions from the ExecuTorch docs.
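
For completeness, the inspection step of that flow roughly looks like the sketch below (a minimal sketch, assuming the devtools Inspector API; the file names are placeholders, the ETDump is produced by a runner built with profiling enabled, and the ETRecord is generated at export time):

from executorch.devtools import Inspector

# Placeholder paths - see the ExecuTorch profiling docs for how these files
# are generated.
inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")

# Per-operator runtime statistics; this is where native_call_mm.out shows up
# with ~50% of the total inference time.
inspector.print_data_tabular()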

Actual behavior

A single native_call_mm.out operation not delegated to the XNNPACK backend is responsible for approximately 50% of the inference time.
The profiling results are available here.

What I tried

  • Using other matrix multiplication functions (matmul() and mm()): no effect.
  • Replacing the matrix multiplication with an equivalent nn.Linear call with static shapes: this resolves the delegation issue and reduces inference time from approximately 145 ms to 80 ms (~45% faster); see the sketch below.
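
Roughly what the working replacement looks like, as a minimal self-contained sketch (not the exact patch; torch.nn.functional.linear is used here instead of an nn.Linear module, which is mathematically equivalent, and the vocabulary size below is only a placeholder):

import torch

def project_logits(x: torch.Tensor, embedding_weight: torch.Tensor) -> torch.Tensor:
    # F.linear(x, W) computes x @ W.T, i.e. the same result as
    # x @ torch.transpose(W, 0, 1) in Decoder.forward, but it is lowered as a
    # linear op that the XNNPACK partitioner delegates.
    return torch.nn.functional.linear(x, embedding_weight.to(x.dtype)).float()

# Shapes matching the static export above (1 x 128 tokens, n_state = 384);
# the vocabulary size is a placeholder, not the exact tiny.en value.
x = torch.randn(1, 128, 384)
embedding_weight = torch.randn(51864, 384)  # stands in for token_embedding.weight
logits = project_logits(x, embedding_weight)  # -> (1, 128, 51864)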

Versions

Collecting environment information...
PyTorch version: 2.10.0.dev20250916
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.12.11 (main, Jun 3 2025, 15:41:47) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.7.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] executorch==1.0.0.dev20250916
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.4
[pip3] optimum-executorch==0.2.0.dev0
[pip3] pytorch_tokenizers==0.1.0
[pip3] torch==2.10.0.dev20250916
[pip3] torchao==0.14.0.dev20250916+cpu
[pip3] torchaudio==2.8.0.dev20250916
[pip3] torchvision==0.25.0.dev20250916
[pip3] openai-whisper==20250625
