
Commit fd52647

[Model] Replace diffusers apply_rotary_emb with omni RotaryEmbedding (#496)
Signed-off-by: iwzbi <[email protected]>
1 parent 667f6cf commit fd52647

2 files changed: +24 −16 lines
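
This commit swaps diffusers' apply_rotary_emb, which consumed full-width cos/sin tables of shape [S, D], for vllm_omni's RotaryEmbedding constructed with is_neox_style=False and fed half-width tables of shape [S, D/2]. The sketch below shows the interleaved (non-neox, GPT-J style) rotation that this setting implies; it is illustrative only, and the tensor layout and the apply_interleaved_rope helper are assumptions rather than the actual vllm_omni.diffusion.layers.rope.RotaryEmbedding interface.

import torch


def apply_interleaved_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Interleaved (non-neox) rotary application, sketched for illustration.
    # x:        [B, S, H, D] query or key
    # cos, sin: [S, D/2] half-width frequency tables
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = cos[None, :, None, :].to(x.dtype)  # broadcast over batch and heads
    sin = sin[None, :, None, :].to(x.dtype)
    out_even = x_even * cos - x_odd * sin
    out_odd = x_odd * cos + x_even * sin
    # Re-interleave the rotated pairs along the last dimension.
    return torch.stack((out_even, out_odd), dim=-1).flatten(-2)

Because the old full-width tables simply repeated each half-width value twice, this rotation should produce the same result as the removed apply_rotary_emb call; the refactor changes where the rotation is implemented, not what it computes.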

vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py

Lines changed: 12 additions & 8 deletions
@@ -6,7 +6,7 @@
 
 import torch
 import torch.nn as nn
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from vllm.logger import init_logger
@@ -17,6 +17,7 @@
 
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding
 from vllm_omni.utils.platform_utils import is_npu
 
 logger = init_logger(__name__)
@@ -97,6 +98,7 @@ def __init__(
 
         self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)
 
+        self.rope = RotaryEmbedding(is_neox_style=False)
         self.attn = Attention(
             num_heads=heads,
             head_size=self.head_dim,
@@ -138,8 +140,11 @@ def forward(
             value = torch.cat([encoder_value, value], dim=1)
 
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+            cos, sin = image_rotary_emb  # [S, D/2]
+            cos = cos.to(query.dtype)
+            sin = sin.to(query.dtype)
+            query = self.rope(query, cos, sin)
+            key = self.rope(key, cos, sin)
 
         hidden_states = self.attn(
             query,
@@ -263,16 +268,15 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         is_npu = ids.device.type == "npu"
         freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
         for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
+            freqs_cis = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
                 pos[:, i],
                 theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
+                use_real=False,
                 freqs_dtype=freqs_dtype,
             )
-            cos_out.append(cos)
-            sin_out.append(sin)
+            cos_out.append(freqs_cis.real)
+            sin_out.append(freqs_cis.imag)
         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
         return freqs_cos, freqs_sin
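
The positional-embedding hunk above switches get_1d_rotary_pos_embed from use_real=True with repeat_interleave_real=True, which returned full-width [S, D] cos/sin tables, to use_real=False, which returns complex frequencies of shape [S, D/2], and then reads the half-width cos/sin off the real and imaginary parts. A small check of that relationship, assuming the get_1d_rotary_pos_embed behavior of recent diffusers releases; the sizes are made up for illustration:

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

dim = 16
pos = torch.arange(8, dtype=torch.float64)  # hypothetical positions along one axis

# Old path: real tables, each frequency repeated twice along the last dim -> [8, 16].
cos_full, sin_full = get_1d_rotary_pos_embed(
    dim, pos, use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float64
)

# New path: complex frequencies at half the width -> [8, 8].
freqs_cis = get_1d_rotary_pos_embed(dim, pos, use_real=False, freqs_dtype=torch.float64)

# The half-width tables are the real/imag parts; repeat-interleaving them
# recovers the old full-width tables.
assert torch.allclose(cos_full, freqs_cis.real.repeat_interleave(2, dim=-1).float())
assert torch.allclose(sin_full, freqs_cis.imag.repeat_interleave(2, dim=-1).float())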

vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py

Lines changed: 12 additions & 8 deletions
@@ -21,7 +21,7 @@
 import torch
 import torch.nn as nn
 from diffusers.models.attention import FeedForward
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from diffusers.utils import is_torch_npu_available
@@ -32,6 +32,7 @@
 
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding
 
 logger = init_logger(__name__)
 
@@ -96,6 +97,7 @@ def __init__(
 
         self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)
 
+        self.rope = RotaryEmbedding(is_neox_style=False)
         self.attn = Attention(
             num_heads=heads,
             head_size=self.head_dim,
@@ -137,8 +139,11 @@ def forward(
             value = torch.cat([encoder_value, value], dim=1)
 
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+            cos, sin = image_rotary_emb  # [S, D/2]
+            cos = cos.to(query.dtype)
+            sin = sin.to(query.dtype)
+            query = self.rope(query, cos, sin)
+            key = self.rope(key, cos, sin)
 
         hidden_states = self.attn(
             query,
@@ -318,16 +323,15 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         is_npu = ids.device.type == "npu"
         freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
         for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
+            freqs_cis = get_1d_rotary_pos_embed(
                 self.axes_dim[i],
                 pos[:, i],
                 theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
+                use_real=False,
                 freqs_dtype=freqs_dtype,
             )
-            cos_out.append(cos)
-            sin_out.append(sin)
+            cos_out.append(freqs_cis.real)
+            sin_out.append(freqs_cis.imag)
         freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
         freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
         return freqs_cos, freqs_sin
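
The ovis_image file receives the same two changes as longcat_image, so one end-to-end sanity check covers both: applying the old full-width tables with diffusers' apply_rotary_emb and applying the new half-width tables with an interleaved rotation should give identical results. A self-contained sketch of that check, assuming the diffusers signatures used elsewhere in this diff; the shapes and the inlined rotation are illustrative, and the real code routes through RotaryEmbedding and the Attention layer:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

B, S, H, D = 2, 8, 4, 16  # hypothetical batch, sequence, heads, head_dim
x = torch.randn(B, S, H, D)
pos = torch.arange(S, dtype=torch.float64)

# Old path: full-width cos/sin plus diffusers' apply_rotary_emb.
cos_full, sin_full = get_1d_rotary_pos_embed(D, pos, use_real=True, repeat_interleave_real=True)
ref = apply_rotary_emb(x, (cos_full, sin_full), sequence_dim=1)

# New path: half-width cos/sin applied with an interleaved (non-neox) rotation.
freqs_cis = get_1d_rotary_pos_embed(D, pos, use_real=False)
cos = freqs_cis.real.float()[None, :, None, :]
sin = freqs_cis.imag.float()[None, :, None, :]
x_even, x_odd = x[..., 0::2], x[..., 1::2]
out = torch.stack((x_even * cos - x_odd * sin, x_odd * cos + x_even * sin), dim=-1).flatten(-2)

assert torch.allclose(ref, out, atol=1e-5)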
