     AttentionMetadata,
 )
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.attention.selector import get_attn_backend
 from vllm_omni.diffusion.cache.base import CachedTransformer
 from vllm_omni.diffusion.data import OmniDiffusionConfig
 from vllm_omni.diffusion.distributed.parallel_state import (
+    get_ring_parallel_world_size,
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
     get_sp_group,
@@ -373,7 +375,14 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         vid_freqs: torch.Tensor,
         txt_freqs: torch.Tensor,
+        hidden_states_mask: torch.Tensor | None = None,
+        encoder_hidden_states_mask: torch.Tensor | None = None,
     ):
+        # An all-True mask is a no-op; drop it so the attention backend can skip masking.
+        if hidden_states_mask is not None and hidden_states_mask.all():
+            hidden_states_mask = None
+        if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all():
+            encoder_hidden_states_mask = None
         seq_len_txt = encoder_hidden_states.shape[1]

         # Compute QKV for image stream (sample projections)
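For reference, the normalization above can be read as a tiny standalone helper (an illustrative sketch; normalize_mask is not a name from this PR): an all-True boolean mask keeps every token, so it is equivalent to passing no mask, and dropping it lets the attention backend skip mask handling entirely.

import torch

def normalize_mask(mask: torch.Tensor | None) -> torch.Tensor | None:
    # An all-True mask selects every position, so it is equivalent to "no mask".
    if mask is not None and bool(mask.all()):
        return None
    return mask

# normalize_mask(torch.ones(2, 16, dtype=torch.bool)) -> None
# normalize_mask(torch.tensor([[True, False]]))       -> returned unchanged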
@@ -416,30 +425,63 @@ def forward(
         joint_value = torch.cat([txt_value, img_value], dim=1)

         # Compute joint attention
-
         if (
             self.parallel_config is not None
             and self.parallel_config.sequence_parallel_size > 1
             and not get_forward_context().split_text_embed_in_sp
         ):
             # If using sequence parallelism without splitting the text embedding,
             # pass the text QKV to the attention layer as joint QKV.
+            attn_metadata = AttentionMetadata(
+                joint_query=txt_query,
+                joint_key=txt_key,
+                joint_value=txt_value,
+                joint_strategy="front",
+            )
+            if hidden_states_mask is not None:
+                attn_metadata.attn_mask = hidden_states_mask
+            if encoder_hidden_states_mask is not None:
+                attn_metadata.joint_attn_mask = encoder_hidden_states_mask
+
             joint_hidden_states = self.attn(
                 img_query,
                 img_key,
                 img_value,
-                AttentionMetadata(
-                    joint_query=txt_query,
-                    joint_key=txt_key,
-                    joint_value=txt_value,
-                    joint_strategy="front",
-                ),
+                attn_metadata,
             )
         else:
+            attn_metadata = None
+            if hidden_states_mask is not None or encoder_hidden_states_mask is not None:
+                mask_list = []
+                if encoder_hidden_states_mask is not None:
+                    mask_list.append(encoder_hidden_states_mask)
+                else:
+                    mask_list.append(
+                        torch.ones(
+                            [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]],
+                            dtype=torch.bool,
+                            device=encoder_hidden_states.device,
+                        )
+                    )
+                if hidden_states_mask is not None:
+                    mask_list.append(hidden_states_mask)
+                else:
+                    mask_list.append(
+                        torch.ones(
+                            [hidden_states.shape[0], hidden_states.shape[1]],
+                            dtype=torch.bool,
+                            device=hidden_states.device,
+                        )
+                    )
+                # mask_list always holds exactly one text mask and one image mask here,
+                # so concatenate them along the sequence dimension (text first).
+                joint_mask = torch.cat(mask_list, dim=1)
+                attn_metadata = AttentionMetadata(attn_mask=joint_mask)
             joint_hidden_states = self.attn(
                 joint_query,
                 joint_key,
                 joint_value,
+                attn_metadata,
             )
         joint_hidden_states = joint_hidden_states.flatten(2, 3)
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
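The else branch above builds a single mask for the fused text+image sequence: if only one stream provides a mask, an all-True mask is synthesized for the other, and the two are concatenated in the same order as the joint QKV (text first, then image). A minimal standalone sketch of that construction, assuming boolean [B, S] masks; the helper name and signature are illustrative, not part of the PR:

import torch

def build_joint_mask(
    txt_mask: torch.Tensor | None,  # [B, S_txt] bool, or None
    img_mask: torch.Tensor | None,  # [B, S_img] bool, or None
    s_txt: int,
    s_img: int,
    batch: int,
    device: torch.device,
) -> torch.Tensor | None:
    # No mask on either stream: the joint attention needs no mask at all.
    if txt_mask is None and img_mask is None:
        return None
    if txt_mask is None:
        txt_mask = torch.ones(batch, s_txt, dtype=torch.bool, device=device)
    if img_mask is None:
        img_mask = torch.ones(batch, s_img, dtype=torch.bool, device=device)
    # Text tokens come first in the joint sequence, matching torch.cat([txt_*, img_*], dim=1).
    return torch.cat([txt_mask, img_mask], dim=1)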
@@ -547,6 +589,7 @@ def forward(
         image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
         joint_attention_kwargs: dict[str, Any] | None = None,
         modulate_index: list[int] | None = None,
+        hidden_states_mask: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
@@ -577,6 +620,8 @@ def forward(
             encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
             vid_freqs=image_rotary_emb[0],
             txt_freqs=image_rotary_emb[1],
+            hidden_states_mask=hidden_states_mask,
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
         )

         # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
@@ -732,14 +777,48 @@ def forward(
         # else:
         #     lora_scale = 1.0

+        original_seq_len = None
+        seq_padding = 0
+        hidden_states_mask = None
+
         if self.parallel_config.sequence_parallel_size > 1:
-            hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[
-                get_sequence_parallel_rank()
-            ]
+            batch_size, seq_len, channels = hidden_states.shape
+            sp_size = get_sequence_parallel_world_size()
+
+            if seq_len % sp_size != 0:
+                # flash_attn, ring_attn and sage_attn do not support attention_mask
+                if get_attn_backend(-1).get_name() not in ("SDPA", "ASCEND"):
+                    raise ValueError(
+                        f"The sequence length is not divisible by sp_size={sp_size}, so an attention mask is "
+                        f"required, but the {get_attn_backend(-1).get_name()} backend does not support "
+                        "attention_mask. Please switch to the SDPA or Ascend attention backend."
+                    )
+                # ring attention does not support attention_mask either
+                if get_ring_parallel_world_size() > 1:
+                    raise ValueError(
+                        f"The sequence length is not divisible by sp_size={sp_size}, so an attention mask is "
+                        "required, but ring attention does not support attention_mask. "
+                        "Please switch to Ulysses sequence parallelism only."
+                    )
+
+                seq_padding = sp_size - (seq_len % sp_size)
+                original_seq_len = seq_len
+
+                hidden_states_mask = torch.ones(
+                    batch_size, seq_len + seq_padding, dtype=torch.bool, device=hidden_states.device
+                )
+                hidden_states_mask[:, seq_len:] = False
+                padding_tensor = torch.zeros(
+                    batch_size, seq_padding, channels, dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                hidden_states = torch.cat([hidden_states, padding_tensor], dim=1)
+
+            hidden_states = torch.chunk(hidden_states, sp_size, dim=-2)[get_sequence_parallel_rank()]
             # NOTE:
             # QwenImage uses *dual-stream* (text + image) blocks and runs a *joint attention*, which requires
             # text embeddings to be replicated across SP ranks for correctness.
             get_forward_context().split_text_embed_in_sp = False
+
         hidden_states = self.img_in(hidden_states)

         # Ensure timestep tensor is on the same device and dtype as hidden_states
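The padding arithmetic above: with sp_size ranks, a sequence of length seq_len gets sp_size - (seq_len % sp_size) zero tokens appended so that torch.chunk hands every rank an equal slice, and the appended positions are marked False in hidden_states_mask. A self-contained sketch of the same idea, assuming a [B, S, C] tensor (the helper name is hypothetical):

import torch

def pad_to_multiple(x: torch.Tensor, multiple: int) -> tuple[torch.Tensor, torch.Tensor | None, int]:
    # x: [B, S, C]. Returns (padded tensor, bool keep-mask or None, original S).
    bsz, seq_len, channels = x.shape
    remainder = seq_len % multiple
    if remainder == 0:
        return x, None, seq_len
    pad = multiple - remainder
    mask = torch.ones(bsz, seq_len + pad, dtype=torch.bool, device=x.device)
    mask[:, seq_len:] = False  # padded positions must not take part in attention
    x = torch.cat([x, x.new_zeros(bsz, pad, channels)], dim=1)
    return x, mask, seq_len

# Example: seq_len=9, sp_size=4 -> pad=3, padded length 12, 3 tokens per rank.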
@@ -769,13 +848,17 @@ def forward(

         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

-        def get_rotary_emb_chunk(freqs):
+        def get_rotary_emb_chunk(freqs, padding=0):
+            # Zero-pad the rotary table so it matches the padded sequence length
+            if padding > 0:
+                padding_tensor = torch.zeros(padding, freqs.shape[-1], dtype=freqs.dtype, device=freqs.device)
+                freqs = torch.cat([freqs, padding_tensor], dim=0)
             freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
             return freqs

         if self.parallel_config.sequence_parallel_size > 1:
             img_freqs, txt_freqs = image_rotary_emb
-            img_freqs = get_rotary_emb_chunk(img_freqs)
+            img_freqs = get_rotary_emb_chunk(img_freqs, seq_padding)
             if get_forward_context().split_text_embed_in_sp:
                 txt_freqs = get_rotary_emb_chunk(txt_freqs)
             image_rotary_emb = (img_freqs, txt_freqs)
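Padding the rotary table by the same seq_padding keeps position i on every rank aligned with the row it had before chunking; the zero rows only ever pair with padded tokens, which the mask excludes anyway. A standalone sketch, assuming freqs is a [S, D] per-position table (the function and argument names are illustrative):

import torch

def chunk_freqs(freqs: torch.Tensor, sp_size: int, rank: int, padding: int = 0) -> torch.Tensor:
    # freqs: [S, D] rotary table. Zero-pad to the padded sequence length,
    # then take this rank's contiguous slice along the sequence dimension.
    if padding > 0:
        freqs = torch.cat([freqs, freqs.new_zeros(padding, freqs.shape[-1])], dim=0)
    return torch.chunk(freqs, sp_size, dim=0)[rank]

# With S=9, padding=3, sp_size=4: the padded table has 12 rows and rank r sees rows [3*r, 3*r + 3).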
@@ -789,6 +872,7 @@ def get_rotary_emb_chunk(freqs):
                 image_rotary_emb=image_rotary_emb,
                 joint_attention_kwargs=attention_kwargs,
                 modulate_index=modulate_index,
+                hidden_states_mask=hidden_states_mask,
             )

         if self.zero_cond_t:
@@ -799,6 +883,11 @@ def get_rotary_emb_chunk(freqs):

         if self.parallel_config.sequence_parallel_size > 1:
             output = get_sp_group().all_gather(output, dim=-2)
+
+        # Remove padding if it was added
+        if original_seq_len is not None:
+            output = output[:, :original_seq_len, :]
+
         return Transformer2DModelOutput(sample=output)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
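To see the padding and un-padding meet up end to end, here is a toy shape walk-through; there is no real SP group here, the all_gather is simulated with torch.cat, and all sizes are made up for illustration:

import torch

sp_size, seq_len, channels = 4, 9, 8
x = torch.randn(1, seq_len, channels)
pad = sp_size - seq_len % sp_size                      # 3 extra tokens
x = torch.cat([x, x.new_zeros(1, pad, channels)], 1)   # [1, 12, 8]
chunks = torch.chunk(x, sp_size, dim=1)                # 4 local chunks of [1, 3, 8]
output = torch.cat(chunks, dim=1)                      # simulated all_gather -> [1, 12, 8]
output = output[:, :seq_len, :]                        # drop padding -> [1, 9, 8]
assert output.shape == (1, 9, 8)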