
Commit 1089e80

Author: Sanggyu Lee (committed)
Fuse LlamaAttention to attention (onert)

It fuses LlamaAttention from the TinyLlama model. The fused attention works as the onert attention op.

TICO-DCO-1.0-Signed-off-by: Sanggyu Lee <[email protected]>
1 parent fc4cb16 commit 1089e80
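In practice the fusion is applied by monkey-patching LlamaAttention.forward with the adapter added in this commit before calling tico.convert. A condensed sketch of what the example script below does (captured_input is the decode-phase input recorded with RecordingInput, as shown in the script):

import tico
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention
from tico.serialize.operators.adapters.onert.op_attention import (
    llama_attention_forward_adapter,
)

# Route every LlamaAttention.forward through the adapter so torch.export sees one
# circle::attention.llama call per decoder layer instead of the decomposed attention math.
LlamaAttention.forward = llama_attention_forward_adapter

model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").eval()
circle_model = tico.convert(model, captured_input)  # captured_input: recorded decode-phase inputs
circle_model.save("tinyllama.decode.circle")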

File tree

5 files changed, +228 -0 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# User input
prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

# Tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=30,
    truncation=True,
)

# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

from tico.utils.record_input import RecordingInput

# past_key_values
# ---------------
# During prefill, "past_key_values" is not None but an empty Cache instance.
# Passing None makes torch.export happy.


input_to_remove = [
    "attention_mask",
    # For left pad, [0, ⋯, 0, 1, ⋯, 1]
    # For right pad, [1, ⋯, 1, 0, ⋯, 0]
    # ( 0 is pad-token )
    # This script uses right pad and passes an all-1 attention mask (including pad).
    # The NPU computes all positions whether they are pad or not.
]
condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0
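# Note: with transformers' DynamicCache, get_seq_length() is 0 during prefill (empty cache)
# and greater than 0 afterwards, so this condition is meant to capture a decode-phase call,
# whose inputs become the sample used for tinyllama.decode.circle below.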

with torch.no_grad(), RecordingInput(
    model, condition_fn, input_to_remove=input_to_remove
) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    captured_input = rec.captured_input

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

# Tico
import tico
from tico.serialize.operators.adapters.onert.op_attention import (
    llama_attention_forward_adapter,
)
from transformers.models.llama.modeling_llama import LlamaAttention

LlamaAttention.forward = llama_attention_forward_adapter

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
circle_model = tico.convert(model, captured_input)
circle_model.save("tinyllama.decode.circle")
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
transformers>=4.50.1
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

from torch.library import Library

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index

lib = Library("circle", "DEF")
lib.define(
    """
    attention.llama(
        Tensor hidden_states,
        Tensor wq,
        Tensor wk,
        Tensor wv,
        Tensor wo,
        Tensor position_cos,
        Tensor position_sin,
        Tensor attention_mask,
        Tensor past_key,
        Tensor past_value,
        Tensor cache_position
    ) -> Tensor
    """
)
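
# The op takes everything one fused decoder-layer attention consumes: the layer input,
# the q/k/v/o projection weights, rotary cos/sin, the attention mask, this layer's
# key/value cache tensors, and cache_position; its single output is the attention
# block's output tensor.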

# ATTENTION FUSER


@torch.library.register_fake("circle::attention.llama")
def attention_llama(*args, **kwargs):
    (
        hidden_states,
        q_proj,
        k_proj,
        v_proj,
        o_proj,
        position_cos,
        position_sin,
        attention_mask,
        past_key,
        past_value,
        cache_position,
    ) = args
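    # A fake (meta) kernel only has to report the output's shape/dtype for tracing; the
    # fused attention output has the same shape as hidden_states, so it serves as the proxy.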
    return hidden_states


from typing import List, Optional

from transformers.cache_utils import DynamicCache
from transformers.models.llama.modeling_llama import LlamaAttention


def llama_attention_forward_adapter(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    position_embeddings: List[torch.Tensor],
    attention_mask: torch.Tensor,
    past_key_value: DynamicCache,
    cache_position: torch.Tensor,
    **kwargs,
):
    # past_key_value is a DynamicCache holding key_cache and value_cache.
    # It needs to be decomposed for tico and circle, which do not know about such containers.
    key_cache = past_key_value.key_cache  # type: ignore[union-attr]
    value_cache = past_key_value.value_cache  # type: ignore[union-attr]
    return (
        torch.ops.circle.attention.llama(
            hidden_states,
            self.q_proj.weight,
            self.k_proj.weight,
            self.v_proj.weight,
            self.o_proj.weight,
            position_embeddings[0],  # cos
            position_embeddings[1],  # sin
            attention_mask,
            # key_cache is a list with one cache tensor per decoder layer.
            # Assumption: the key cache is contiguous:
            #
            #   k_cache[0] | k_cache[1] | ... | k_cache[n]
            key_cache[self.layer_idx],
            value_cache[self.layer_idx],  # Same for value_cache
            cache_position,
        ),
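        # LlamaAttention.forward returns (attn_output, attn_weights); None stands in for
        # the attention weights, which the fused op does not compute.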
        None,
    )


@register_node_visitor
class AttentionVisitor(NodeVisitor):
    target: List[torch._ops.OpOverload] = [
        torch.ops.circle.attention.llama,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        (
            hidden_states,
            wq,
            wk,
            wv,
            wo,
            position_cos,
            position_sin,
            attention_mask,
            past_key,
            past_value,
            cache_position,
        ) = node.args

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
        )

        # Remove the last arg from inputs; it is treated as the attention op's param,
        # not a graph input.
        inputs = node.args[:-1]
        outputs = [node]
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.AttentionOptions
        )
        operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()

        return operator
