Skip to content

Commit bd2ea41

Browse files
committed
Replace model attention modules with custom SDPA attention before export
1 parent 0028ccf commit bd2ea41

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

export_et.py

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# XnnpackDynamicallyQuantizedPartitioner,
2525
# )
2626
from executorch_portable_utils import export_to_edge
27+
from export_et_util import replace_attention_with_sdpa_attention
2728

2829
from quantize import get_precision
2930
from torch._export import capture_pre_autograd_graph
@@ -106,6 +107,9 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901
106107
else:
107108
raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
108109

110+
111+
112+
replace_attention_with_sdpa_attention(export_model)
109113
with torch.nn.attention.sdpa_kernel(
110114
[torch.nn.attention.SDPBackend.MATH]
111115
), torch.no_grad():

export_et_util.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
22
from build.model import Attention
3+
from torch import nn
34

4-
class AttentionWithSDPA(nn.Module):
5+
class SDPAAttention(nn.Module):
56
def __init__(self, attention: Attention):
67
super().__init__()
78

@@ -51,3 +52,11 @@ def forward(
5152
)
5253
output = output.view(bsz, seqlen, self.dim)
5354
return self.wo(output)
55+
56+
57+
def replace_attention_with_sdpa_attention(module: nn.Module) -> None:
    """Recursively swap every ``Attention`` submodule of *module* for an
    equivalent ``SDPAAttention`` wrapper, mutating *module* in place.

    Args:
        module: root module whose subtree is rewritten.

    Returns:
        None — the replacement happens in place on *module*.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, Attention):
            # Not an attention layer itself — descend and rewrite its subtree.
            replace_attention_with_sdpa_attention(submodule)
            continue
        # Replacing an existing child name does not resize the module's
        # child dict, so mutating during named_children() iteration is safe.
        setattr(module, child_name, SDPAAttention(submodule))

0 commit comments

Comments
 (0)