Commit 4df794b

[Bugfix] Handle Ulysses-SP sequence lengths that are not divisible by the SP degree (using padding and an attention mask) (#672)
Signed-off-by: Didan Deng <[email protected]>
1 parent 89f47c9 commit 4df794b

File tree: 4 files changed, +244 -12 lines changed
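
The core of the fix: before splitting the image-token sequence across sequence-parallel ranks, the transformer pads it up to the next multiple of the SP degree and records which positions are real with a boolean mask; the padding is sliced off again after the SP all_gather. A minimal, self-contained sketch of that bookkeeping is below; pad_for_sp is a hypothetical helper written for illustration, not a function in this repository.

import torch

def pad_for_sp(hidden_states: torch.Tensor, sp_size: int):
    # hidden_states: [batch, seq_len, channels]; pad seq_len up to a multiple of sp_size.
    batch, seq_len, channels = hidden_states.shape
    seq_padding = (-seq_len) % sp_size  # 0 when seq_len is already divisible
    mask = torch.ones(batch, seq_len + seq_padding, dtype=torch.bool, device=hidden_states.device)
    if seq_padding > 0:
        mask[:, seq_len:] = False  # padded positions are masked out of attention
        pad = torch.zeros(batch, seq_padding, channels, dtype=hidden_states.dtype, device=hidden_states.device)
        hidden_states = torch.cat([hidden_states, pad], dim=1)
    return hidden_states, mask, seq_len  # original length is kept to un-pad after all_gather

# Example from this commit's test: 289 image tokens, sp_size=4 -> pad by 3 to 292 (73 tokens per rank).
x, mask, orig_len = pad_for_sp(torch.randn(1, 289, 64), sp_size=4)
assert x.shape[1] == 292 and int(mask.sum()) == 289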

tests/e2e/offline_inference/test_sequence_parallel.py

Lines changed: 111 additions & 0 deletions
@@ -202,3 +202,114 @@ def test_sequence_parallel(
         f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}); "
         f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, dtype={dtype}"
     )
+
+
+@pytest.mark.parametrize("model_name", models)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("attn_backend", ["sdpa"])
+def test_sequence_parallel_ulysses_sp_only(
+    model_name: str,
+    dtype: torch.dtype,
+    attn_backend: str,
+):
+    """Test sequence parallel with ulysses_degree=4, ring_degree=1, and an image size (272x272) whose sequence length is NOT divisible by sp_size."""
+    ulysses_degree = 4
+    ring_degree = 1
+
+    # Skip if not enough GPUs are available for the SP run
+    if device_count() < ulysses_degree * ring_degree:
+        pytest.skip(f"Test requires {ulysses_degree * ring_degree} GPUs but only {device_count()} available")
+
+    # (272 / 8 / 2) * (272 / 8 / 2) = 17 * 17 = 289, not divisible by sp_size=4
+    height = 272
+    width = 272
+    num_inference_steps = 4  # Minimal steps for a fast test
+    seed = 42
+
+    # Step 1: Baseline (no Ulysses sequence parallel)
+    baseline_parallel_config = DiffusionParallelConfig(ulysses_degree=1, ring_degree=1)
+    baseline = Omni(
+        model=model_name,
+        parallel_config=baseline_parallel_config,
+        dtype=dtype,
+        attention_backend=attn_backend,
+    )
+    try:
+        outputs = baseline.generate(
+            PROMPT,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=0.0,
+            generator=torch.Generator(get_device_name()).manual_seed(seed),
+            num_outputs_per_prompt=1,
+        )
+        baseline_images = outputs[0].request_output[0].images
+    finally:
+        baseline.close()
+        if dist.is_initialized():
+            dist.destroy_process_group()
+        for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]:
+            os.environ.pop(key, None)
+        time.sleep(2)  # Wait for resources to release
+
+    assert baseline_images is not None
+    assert len(baseline_images) == 1
+    assert baseline_images[0].width == width
+    assert baseline_images[0].height == height
+
+    # Step 2: SP (Ulysses-SP + Ring-SP)
+    sp_parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)
+    sp = Omni(
+        model=model_name,
+        parallel_config=sp_parallel_config,
+        dtype=dtype,
+        attention_backend=attn_backend,
+    )
+    try:
+        outputs = sp.generate(
+            PROMPT,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=0.0,
+            generator=torch.Generator(get_device_name()).manual_seed(seed),
+            num_outputs_per_prompt=1,
+        )
+        sp_images = outputs[0].request_output[0].images
+    finally:
+        sp.close()
+        if dist.is_initialized():
+            dist.destroy_process_group()
+        for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]:
+            os.environ.pop(key, None)
+        time.sleep(2)
+
+    assert sp_images is not None
+    assert len(sp_images) == 1
+    assert sp_images[0].width == width
+    assert sp_images[0].height == height
+
+    # Step 3: Compare outputs
+    mean_abs_diff, max_abs_diff = _diff_metrics(baseline_images[0], sp_images[0])
+
+    # FP16/BF16 may differ slightly due to a different computation order under parallelism.
+    if dtype in (torch.float16, torch.bfloat16):
+        mean_threshold = 2e-2
+        max_threshold = 2e-1
+    else:
+        mean_threshold = 1e-2
+        max_threshold = 1e-1
+
+    print(
+        "Image diff stats (baseline ulysses_degree*ring_degree=1 vs SP): "
+        f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; "
+        f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; "
+        f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, dtype={dtype}"
+    )
+
+    assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, (
+        f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} "
+        f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}); "
+        f"ulysses_degree={ulysses_degree}, ring_degree={ring_degree}, dtype={dtype}"
+    )
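
For reference, the 272x272 size is chosen so that the image-token count is not a multiple of the SP degree. Assuming Qwen-Image's 8x VAE downsampling followed by 2x2 patchification (which is what the comment's 17 * 17 = 289 implies), the arithmetic works out as:

# Token-count check for the 272x272 case (assumes an 8x VAE downsample and 2x2 patch embedding).
height = width = 272
tokens_per_side = height // 8 // 2            # 17
seq_len = tokens_per_side * tokens_per_side   # 289
sp_size = 4
print(seq_len % sp_size)                      # 1 -> 3 padding tokens are needed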

vllm_omni/diffusion/attention/backends/abstract.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ def supports_head_size(cls, head_size: int) -> bool:
 @dataclass
 class AttentionMetadata:
     attn_mask: torch.Tensor | None = None
+    joint_attn_mask: torch.Tensor | None = None
+    # a joint mask for the joint query, key, and value; depends on the joint_strategy
     joint_query: torch.Tensor | None = None
     # a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy
     joint_key: torch.Tensor | None = None
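
For context on how the new field is meant to be populated (an illustration mirroring the qwen_image_transformer changes later in this commit, not additional repository code): attn_mask carries the image-stream mask and joint_attn_mask carries the replicated text-stream mask, matching the joint_query/joint_key/joint_value convention.

# Illustrative construction; variable names follow the qwen_image usage in this commit.
metadata = AttentionMetadata(
    joint_query=txt_query,   # replicated text-stream Q/K/V
    joint_key=txt_key,
    joint_value=txt_value,
    joint_strategy="front",  # text tokens are prepended to the image tokens
)
metadata.attn_mask = hidden_states_mask                 # per-rank image-token mask (True = real token)
metadata.joint_attn_mask = encoder_hidden_states_mask   # text-token mask, replicated across ranks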

vllm_omni/diffusion/attention/parallel/ulysses.py

Lines changed: 30 additions & 0 deletions
@@ -163,6 +163,36 @@ def pre_attention(
             joint_len=joint_len,
             joint_strategy=joint_strategy,
         )
+
+        if attn_metadata is not None:
+            if is_joint:
+                if attn_metadata.joint_attn_mask is None and attn_metadata.attn_mask is None:
+                    attn_metadata.attn_mask = None
+                else:
+                    if attn_metadata.attn_mask is None:
+                        attn_metadata.attn_mask = torch.ones(
+                            [query.shape[0], query.shape[1] - attn_metadata.joint_attn_mask.shape[1]],
+                            dtype=torch.bool,
+                            device=query.device,
+                        )
+                    elif attn_metadata.joint_attn_mask is None:
+                        attn_metadata.joint_attn_mask = torch.ones(
+                            [query.shape[0], query.shape[1] - attn_metadata.attn_mask.shape[1]],
+                            dtype=torch.bool,
+                            device=query.device,
+                        )
+                    attn_metadata.attn_mask = (
+                        torch.cat([attn_metadata.joint_attn_mask, attn_metadata.attn_mask], dim=1)
+                        if joint_strategy == "front"
+                        else torch.cat([attn_metadata.attn_mask, attn_metadata.joint_attn_mask], dim=1)
+                    )
+
+            if attn_metadata.attn_mask is not None:
+                # the final attn_mask is ready; its length should be aligned with the query length
+                assert attn_metadata.attn_mask.shape[1] == query.shape[1], (
+                    f"attn_mask length: {attn_metadata.attn_mask.shape[1]} != query length: {query.shape[1]}"
+                )
+                attn_metadata.attn_mask = attn_metadata.attn_mask.bool().contiguous()
         return query, key, value, attn_metadata, ctx
 
     def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
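
In short, pre_attention fills in whichever of the two masks is missing with an all-True mask of the complementary length, then concatenates them in the order given by joint_strategy so the result lines up with the gathered query. A standalone sketch of that rule under the same assumptions; merge_joint_masks is illustrative and not a function in the repository:

import torch

def merge_joint_masks(attn_mask, joint_attn_mask, total_len, batch, device, joint_strategy="front"):
    # Nothing to do when neither stream provides a mask.
    if attn_mask is None and joint_attn_mask is None:
        return None
    # Fill the missing mask with all-True of the complementary length.
    if attn_mask is None:
        attn_mask = torch.ones(batch, total_len - joint_attn_mask.shape[1], dtype=torch.bool, device=device)
    elif joint_attn_mask is None:
        joint_attn_mask = torch.ones(batch, total_len - attn_mask.shape[1], dtype=torch.bool, device=device)
    # "front" means the joint (text) tokens sit in front of the image tokens.
    parts = [joint_attn_mask, attn_mask] if joint_strategy == "front" else [attn_mask, joint_attn_mask]
    merged = torch.cat(parts, dim=1)
    assert merged.shape[1] == total_len, "mask length must match the joint query length"
    return merged.bool().contiguous()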

vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py

Lines changed: 101 additions & 12 deletions
@@ -23,9 +23,11 @@
     AttentionMetadata,
 )
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.attention.selector import get_attn_backend
 from vllm_omni.diffusion.cache.base import CachedTransformer
 from vllm_omni.diffusion.data import OmniDiffusionConfig
 from vllm_omni.diffusion.distributed.parallel_state import (
+    get_ring_parallel_world_size,
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
     get_sp_group,
@@ -373,7 +375,14 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         vid_freqs: torch.Tensor,
         txt_freqs: torch.Tensor,
+        hidden_states_mask: torch.Tensor | None = None,
+        encoder_hidden_states_mask: torch.Tensor | None = None,
     ):
+        # if a mask is all True, set it to None
+        if hidden_states_mask is not None and hidden_states_mask.all():
+            hidden_states_mask = None
+        if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all():
+            encoder_hidden_states_mask = None
         seq_len_txt = encoder_hidden_states.shape[1]
 
         # Compute QKV for image stream (sample projections)
@@ -416,30 +425,63 @@ def forward(
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
         # Compute joint attention
-
         if (
             self.parallel_config is not None
             and self.parallel_config.sequence_parallel_size > 1
             and not get_forward_context().split_text_embed_in_sp
         ):
             # if using sequence parallel, but not splitting text embed,
             # we need to pass text embedding to attention layer as joint qkv
+            attn_metadata = AttentionMetadata(
+                joint_query=txt_query,
+                joint_key=txt_key,
+                joint_value=txt_value,
+                joint_strategy="front",
+            )
+            if hidden_states_mask is not None:
+                attn_metadata.attn_mask = hidden_states_mask
+            if encoder_hidden_states_mask is not None:
+                attn_metadata.joint_attn_mask = encoder_hidden_states_mask
+
             joint_hidden_states = self.attn(
                 img_query,
                 img_key,
                 img_value,
-                AttentionMetadata(
-                    joint_query=txt_query,
-                    joint_key=txt_key,
-                    joint_value=txt_value,
-                    joint_strategy="front",
-                ),
+                attn_metadata,
             )
         else:
+            attn_metadata = None
+            if hidden_states_mask is not None or encoder_hidden_states_mask is not None:
+                mask_list = []
+                if encoder_hidden_states_mask is not None:
+                    mask_list.append(encoder_hidden_states_mask)
+                else:
+                    mask_list.append(
+                        torch.ones(
+                            [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]],
+                            dtype=torch.bool,
+                            device=encoder_hidden_states.device,
+                        )
+                    )
+                if hidden_states_mask is not None:
+                    mask_list.append(hidden_states_mask)
+                else:
+                    mask_list.append(
+                        torch.ones(
+                            [hidden_states.shape[0], hidden_states.shape[1]],
+                            dtype=torch.bool,
+                            device=hidden_states.device,
+                        )
+                    )
+                joint_mask = (
+                    None if len(mask_list) == 0 else torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0]
+                )
+                attn_metadata = AttentionMetadata(attn_mask=joint_mask)
             joint_hidden_states = self.attn(
                 joint_query,
                 joint_key,
                 joint_value,
+                attn_metadata,
            )
         joint_hidden_states = joint_hidden_states.flatten(2, 3)
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
@@ -547,6 +589,7 @@ def forward(
         image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
         joint_attention_kwargs: dict[str, Any] | None = None,
         modulate_index: list[int] | None = None,
+        hidden_states_mask: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Get modulation parameters for both streams
         img_mod_params = self.img_mod(temb)  # [B, 6*dim]
@@ -577,6 +620,8 @@
             encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
             vid_freqs=image_rotary_emb[0],
             txt_freqs=image_rotary_emb[1],
+            hidden_states_mask=hidden_states_mask,
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
         )
 
         # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
@@ -732,14 +777,48 @@ def forward(
         # else:
         #     lora_scale = 1.0
 
+        original_seq_len = None
+        seq_padding = 0
+        hidden_states_mask = None
+
         if self.parallel_config.sequence_parallel_size > 1:
-            hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[
-                get_sequence_parallel_rank()
-            ]
+            batch_size, seq_len, channels = hidden_states.shape
+            sp_size = get_sequence_parallel_world_size()
+
+            if seq_len % sp_size != 0:
+                # flash_attn, ring_attn, and sage_attn do not support attention_mask
+                if get_attn_backend(-1).get_name() != "SDPA" and get_attn_backend(-1).get_name() != "ASCEND":
+                    raise ValueError(
+                        f"When generating an image whose sequence length is NOT divisible by sp_size={sp_size}, "
+                        f"cannot use {get_attn_backend(-1).get_name()}, which does not support attention_mask. "
+                        f"Please switch to the SDPA or Ascend attention backend."
+                    )
+                # ring attention does not support attention_mask
+                if get_ring_parallel_world_size() > 1:
+                    raise ValueError(
+                        f"When generating an image whose sequence length is NOT divisible by sp_size={sp_size}, "
+                        f"cannot use ring attention, which does not support attention_mask. "
+                        f"Please switch to Ulysses SP only."
+                    )
+
+                seq_padding = sp_size - (seq_len % sp_size)
+                original_seq_len = seq_len
+
+                hidden_states_mask = torch.ones(
+                    batch_size, seq_len + seq_padding, dtype=torch.bool, device=hidden_states.device
+                )
+                hidden_states_mask[:, seq_len:] = False
+                padding_tensor = torch.zeros(
+                    batch_size, seq_padding, channels, dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                hidden_states = torch.cat([hidden_states, padding_tensor], dim=1)
+
+            hidden_states = torch.chunk(hidden_states, sp_size, dim=-2)[get_sequence_parallel_rank()]
             # NOTE:
             # QwenImage uses *dual-stream* (text + image) and runs a *joint attention*.
             # text embeddings to be replicated across SP ranks for correctness.
             get_forward_context().split_text_embed_in_sp = False
+
         hidden_states = self.img_in(hidden_states)
 
         # Ensure timestep tensor is on the same device and dtype as hidden_states
@@ -769,13 +848,17 @@ def forward(
 
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
 
-        def get_rotary_emb_chunk(freqs):
+        def get_rotary_emb_chunk(freqs, padding=0):
+            # Pad rotary embeddings if needed
+            if padding > 0:
+                padding_tensor = torch.zeros(padding, freqs.shape[-1], dtype=freqs.dtype, device=freqs.device)
+                freqs = torch.cat([freqs, padding_tensor], dim=0)
             freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
             return freqs
 
         if self.parallel_config.sequence_parallel_size > 1:
             img_freqs, txt_freqs = image_rotary_emb
-            img_freqs = get_rotary_emb_chunk(img_freqs)
+            img_freqs = get_rotary_emb_chunk(img_freqs, seq_padding)
             if get_forward_context().split_text_embed_in_sp:
                 txt_freqs = get_rotary_emb_chunk(txt_freqs)
             image_rotary_emb = (img_freqs, txt_freqs)
@@ -789,6 +872,7 @@ def get_rotary_emb_chunk(freqs):
                 image_rotary_emb=image_rotary_emb,
                 joint_attention_kwargs=attention_kwargs,
                 modulate_index=modulate_index,
+                hidden_states_mask=hidden_states_mask,
             )
 
         if self.zero_cond_t:
@@ -799,6 +883,11 @@ def get_rotary_emb_chunk(freqs):
 
         if self.parallel_config.sequence_parallel_size > 1:
             output = get_sp_group().all_gather(output, dim=-2)
+
+            # Remove padding if it was added
+            if original_seq_len is not None:
+                output = output[:, :original_seq_len, :]
+
         return Transformer2DModelOutput(sample=output)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
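
Putting the transformer-side changes together: the padded sequence is chunked evenly across SP ranks, hidden_states_mask travels with it into attention, and after the all_gather the padded tail is sliced away so callers never see it. A shape-only sketch (single process; torch.cat stands in for the SP all_gather, and all names are illustrative):

import torch

sp_size, batch, channels = 4, 1, 64
orig_len = 289
pad = (-orig_len) % sp_size                        # 3
x = torch.randn(batch, orig_len + pad, channels)   # sequence already padded, as in forward()

# Each rank keeps one equal chunk of the padded sequence (dim=-2 is the token dimension).
chunks = torch.chunk(x, sp_size, dim=-2)
assert all(c.shape[1] == (orig_len + pad) // sp_size for c in chunks)  # 73 tokens per rank

# ... per-rank transformer blocks run here, with the boolean mask hiding the padded tail ...

# all_gather along the token dimension reassembles the padded sequence; torch.cat emulates it.
output = torch.cat(chunks, dim=-2)
output = output[:, :orig_len, :]                   # drop the padding before returning
assert output.shape[1] == orig_len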
