
🐛 [Bug] Wan2.1 SDPA in Torch-TRT is slower than SDPA in ONNX as batch size (num_frames) grows #3695

@cehongwang

Description


We found that for both a 1-layer transformer and the full Wan2.1 transformer, the gap between Torch-TensorRT and ONNX-TensorRT widens as batch_size (num_frames) grows. At num_frames=1 the difference is within 5%, but at large num_frames it exceeds 50%. The gap is driven mainly by the MHA kernels, which take the majority of the time when num_frames is large.
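A back-of-envelope sketch of why the MHA kernels come to dominate as num_frames grows. The constants below (4x temporal VAE compression, spatial token count, head dimensions) are illustrative assumptions, not values taken from this issue:

```python
# Rough model of why attention dominates at high num_frames: the latent
# sequence length grows roughly linearly with num_frames, and SDPA cost
# grows quadratically with sequence length.
# Assumed (not from the issue): 4x temporal compression, 1536 spatial
# tokens per latent frame, 40 heads of dim 128.

def latent_frames(num_frames: int, temporal_stride: int = 4) -> int:
    """Latent frames after a causal 3D-VAE-style temporal compression."""
    return (num_frames - 1) // temporal_stride + 1

def sdpa_flops(num_frames: int, spatial_tokens: int = 1536,
               head_dim: int = 128, num_heads: int = 40) -> int:
    """Approximate FLOPs of one SDPA (QK^T plus AV matmuls) over all heads."""
    seq = latent_frames(num_frames) * spatial_tokens
    # two matmuls: (seq, d) x (d, seq) and (seq, seq) x (seq, d)
    return num_heads * 2 * (2 * seq * seq * head_dim)

ratio = sdpa_flops(81) / sdpa_flops(1)
print(f"{ratio:.0f}x")  # -> 441x: sequence grows 21x, attention cost 21^2
```

Under these assumptions, going from 1 to 81 frames multiplies the attention FLOPs by roughly 440x while the rest of the network scales only linearly, which matches the MHA percentage jumping from ~12% to ~75% of total time in the profiles below.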

num_frames=1

Torch-TensorRT: 7.344 ms, { "name" : "_gemm_mha_v2_myl3_43", "timeMs" : 523.523, "averageMs" : 0.934863, "medianMs" : 0.93472, "percentage" : 12.7283 }

ONNX-TensorRT: 7.292 ms, { "name" : "_gemm_mha_v2_myl3_44", "timeMs" : 257.947, "averageMs" : 0.445505, "medianMs" : 0.445504, "percentage" : 6.10986 }



num_frames=81

Torch-TensorRT: 483.11 ms, { "name" : "_gemm_mha_v2_myl3_40", "timeMs" : 4023.78, "averageMs" : 365.798, "medianMs" : 365.339, "percentage" : 75.7176 }

ONNX-TensorRT: 314.05 ms, { "name" : "_gemm_mha_v2_myl3_45", "timeMs" : 2487.11, "averageMs" : 177.651, "medianMs" : 177.962, "percentage" : 56.5674 }

Moreover, for the full model, the fp16 Torch-TensorRT engine is 36 GB while the ONNX-TensorRT engine is 26 GB. We suspect a constant-folding issue.
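A quick size sanity check that is consistent with the constant-folding suspicion. The parameter count below is an assumption (the larger Wan2.1 variant is commonly cited at 14B parameters), not a measurement from this issue:

```python
# Back-of-envelope: raw fp16 weight storage for an assumed 14B-parameter
# model. If the engine is much larger than this floor, the extra bytes
# likely come from duplicated or unfolded constants in the graph.
params = 14_000_000_000   # assumed Wan2.1 parameter count, illustrative
fp16_bytes = 2            # bytes per fp16 weight
weights_gb = params * fp16_bytes / 1e9
print(f"{weights_gb:.0f} GB")  # -> 28 GB of raw fp16 weights
```

The 26 GB ONNX-TensorRT engine sits near this floor, while the 36 GB Torch-TensorRT engine carries roughly 10 GB extra, which is the kind of overhead duplicated constants would produce.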

Metadata

Labels

bug (Something isn't working)
