Commit cfe527c: fix qwen2_5-omni (#3716)

1 parent 99df181

9 files changed: +100 -8 lines

docs/source/BestPractices/GRPO多模态训练.md

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ register_dataset(
 ```json
 {
 'images': [{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xe0\x00\x00\x01@\x08\x06\x00\x00\x00d\xc8\xafB`\x82 ...', 'path': 'CLEVR_trainA_000000.png'}],
-'messages': [{'role': 'user', 'content': 'How many items are there in the image? Output the thinking process in <think> </think> and\n final answer (number) in <answer> </answer> tags.'}, {'role': 'assistant', 'content': '<answer> 3 </answer>'}],
+'messages': [{'role': 'user', 'content': 'How many items are there in the image? Output the thinking process in <think> </think> and\n final answer (number) in <answer> </answer> tags.'}],
 'solution': '<answer> 3 </answer>'
 }
 ```

docs/source_en/BestPractices/GRPO-Multi-Modal-Training.md

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ The purpose of redefining the dataset preprocessor here is to modify the query.
 ```json
 {
 'images': [{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xe0\x00\x00\x01@\x08\x06\x00\x00\x00d\xc8\xafB`\x82 ...', 'path': 'CLEVR_trainA_000000.png'}],
-'messages': [{'role': 'user', 'content': 'How many items are there in the image? Output the thinking process in <think> </think> and\n final answer (number) in <answer> </answer> tags.'}, {'role': 'assistant', 'content': '<answer> 3 </answer>'}],
+'messages': [{'role': 'user', 'content': 'How many items are there in the image? Output the thinking process in <think> </think> and\n final answer (number) in <answer> </answer> tags.'}],
 'solution': '<answer> 3 </answer>'
 }
 ```
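The change is the same in both docs: for GRPO the ground-truth answer belongs in the `solution` field (which the accuracy reward typically reads), and `messages` must end on the user turn rather than a pre-filled assistant reply. A minimal stand-alone sketch of that transform (`to_grpo_row` is an illustrative helper, not the repo's actual preprocessor class):

```python
# Sketch only: per-row transform producing the corrected dataset format above.
def to_grpo_row(row: dict) -> dict:
    messages = row['messages']
    if messages and messages[-1]['role'] == 'assistant':
        # Strip the assistant turn from the prompt; keep its content as the
        # ground truth that the reward function compares completions against.
        return dict(row, messages=messages[:-1], solution=messages[-1]['content'])
    return row
```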
Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+CUDA_VISIBLE_DEVICES=0 \
+VIDEO_MAX_PIXELS=50176 \
+FPS_MAX_FRAMES=12 \
+MAX_PIXELS=1003520 \
+swift infer \
+    --adapters output/vx-xxx/checkpoint-xxx \
+    --stream true \
+    --load_data_args true \
+    --max_new_tokens 2048
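The three environment variables cap the vision token budget rather than the raw media size. A rough back-of-the-envelope, assuming Qwen2-VL-style vision tokenization (one visual token per 28×28 pixels after 14×14 patching and 2×2 merging); exact counts also depend on temporal frame merging, so treat this as an estimate:

```python
# Illustrative budget math under the 28x28-pixels-per-token assumption.
PIXELS_PER_TOKEN = 28 * 28                   # 784

print(1003520 // PIXELS_PER_TOKEN)           # MAX_PIXELS -> 1280 tokens per image
print(50176 // PIXELS_PER_TOKEN)             # VIDEO_MAX_PIXELS -> 64 tokens per frame
print(12 * (50176 // PIXELS_PER_TOKEN))      # FPS_MAX_FRAMES=12 -> at most 768 frame tokens
```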

examples/train/multimodal/omni/sft.sh

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+# 4*25GB
+# A demo for four modalities that can be run directly
+nproc_per_node=4
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+NPROC_PER_NODE=$nproc_per_node \
+VIDEO_MAX_PIXELS=50176 \
+FPS_MAX_FRAMES=12 \
+MAX_PIXELS=1003520 \
+swift sft \
+    --model Qwen/Qwen2.5-Omni-7B \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
+        'AI-ModelScope/LaTeX_OCR:human_handwrite#2000' \
+        'speech_asr/speech_asr_aishell1_trainsets:validation#2000' \
+        'swift/VideoChatGPT:all#2000' \
+    --train_type lora \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-4 \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --freeze_vit true \
+    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
+    --eval_steps 50 \
+    --save_steps 50 \
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
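One detail worth noting: `--gradient_accumulation_steps $(expr 16 / $nproc_per_node)` pins the effective global batch size to per_device_batch × nproc × accumulation = 1 × 4 × 4 = 16, so changing the GPU count does not change the optimization schedule. A quick check of that arithmetic:

```python
# Mirrors $(expr 16 / $nproc_per_node) from the script above.
per_device_batch = 1
for nproc in (1, 2, 4, 8):
    grad_accum = 16 // nproc
    print(nproc, grad_accum, per_device_batch * nproc * grad_accum)  # global batch is always 16
```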

swift/llm/model/model/qwen.py

Lines changed: 1 addition & 0 deletions

@@ -644,6 +644,7 @@ def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
         requires=['transformers>=4.50', 'soundfile', 'qwen_omni_utils', 'decord'],
         tags=['vision', 'video', 'audio'],
         additional_saved_files=['spk_dict.pt'],
+        ignore_patterns=[],
     ))
swift/llm/model/utils.py

Lines changed: 4 additions & 5 deletions

@@ -237,11 +237,10 @@ def safe_snapshot_download(model_id_or_path: str,
         logger.info(f'Loading the model using local model_dir: {model_dir}')
         return model_dir
     if ignore_patterns is None:
-        ignore_patterns = []
-    ignore_patterns += [
-        '*.zip', '*.gguf', '*.pth', '*.pt', 'consolidated*', 'onnx/*', '*.safetensors.md', '*.msgpack', '*.onnx',
-        '*.ot', '*.h5'
-    ]
+        ignore_patterns = [
+            '*.zip', '*.gguf', '*.pth', '*.pt', 'consolidated*', 'onnx/*', '*.safetensors.md', '*.msgpack', '*.onnx',
+            '*.ot', '*.h5'
+        ]
     if not download_model:
         ignore_patterns += ['*.bin', '*.safetensors']
     hub = get_hub(use_hf)
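This appears to be the core of the fix: previously the default patterns were appended even when a caller passed an explicit `ignore_patterns`, so `spk_dict.pt` (the Omni speaker dictionary listed in `additional_saved_files` above) was always filtered by `'*.pt'`. Now an explicit list, such as the `ignore_patterns=[]` added in qwen.py, is used verbatim and defaults apply only when nothing is passed. A small stand-alone check of the pattern behavior, illustrated with fnmatch-style globs:

```python
from fnmatch import fnmatch

default_patterns = ['*.zip', '*.gguf', '*.pth', '*.pt', 'consolidated*']

# Old behavior: defaults were appended unconditionally, so spk_dict.pt was skipped.
print(any(fnmatch('spk_dict.pt', p) for p in default_patterns))  # True -> file not downloaded

# New behavior: an explicit ignore_patterns=[] leaves nothing to match.
print(any(fnmatch('spk_dict.pt', p) for p in []))                # False -> file downloaded
```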

swift/llm/template/template/qwen.py

Lines changed: 47 additions & 0 deletions

@@ -341,6 +341,7 @@ class Qwen2_5VLTemplate(Qwen2VLTemplate):


 class Qwen2_5OmniTemplate(Template):
+    placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']

     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                     inputs: StdTemplateInputs) -> List[Context]:
@@ -376,15 +377,61 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         media_inputs.pop('input_ids')
         media_inputs.pop('attention_mask')
         media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        for media_type in ['image', 'video']:
+            token = f'<|{media_type.upper()}|>'
+            token_id = self._tokenize(token)
+            idx_list = findall(input_ids, token_id)
+            if idx_list:
+                merge_length = self.processor.omni_processor.merge_size**2
+                media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
+
+                def _get_new_tokens(i):
+                    token_len = (media_grid_thw[i].prod() // merge_length)
+                    return [token_id] * token_len
+
+                _, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+        # audio
+        feature_attention_mask = media_inputs.get('feature_attention_mask')
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1).tolist()
+            token_id = self._tokenize('<|AUDIO|>')
+            idx_list = findall(input_ids, token_id)
+
+            def _get_new_tokens(i):
+                place_num = ((audio_feature_lengths[i] - 1) // 2 + 1 - 2) // 2 + 1
+                return [token_id] * place_num
+
+            _, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+
+        encoded['labels'] = labels
         encoded.update(media_inputs)
         return encoded

     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.is_training:
             feature_attention_mask = inputs.get('feature_attention_mask')
             if feature_attention_mask is not None:
+                audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                 inputs['input_features'] = inputs['input_features'].permute(0, 2, 1)[feature_attention_mask.bool()]
                 inputs['input_features'] = inputs['input_features'].permute(1, 0)
+            else:
+                audio_feature_lengths = None
+            use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
+            video_second_per_grid = inputs.pop('video_second_per_grid', None)
+            position_ids, _, input_ids, attention_mask = model.thinker.get_rope_index(
+                inputs.get('input_ids'),
+                inputs.get('image_grid_thw'),
+                inputs.get('video_grid_thw'),
+                inputs.get('attention_mask'),
+                use_audio_in_video,
+                audio_feature_lengths,
+                video_second_per_grid,
+            )
+            inputs['input_ids'] = input_ids
+            inputs['attention_mask'] = attention_mask
+            inputs['position_ids'] = position_ids
         return inputs

     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
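The placeholder expansion above follows two formulas: image/video placeholders expand to `grid_thw.prod() // merge_size**2` tokens, and the audio count `((L - 1) // 2 + 1 - 2) // 2 + 1` appears to fold two stride-2 downsampling stages in the audio encoder (roughly one token per 4 mel frames). A quick sanity check of the audio formula, as a sketch:

```python
# Sketch: how many <|AUDIO|> placeholders a clip expands to, matching
# _get_new_tokens above for L mel frames.
def audio_place_num(num_frames: int) -> int:
    after_conv = (num_frames - 1) // 2 + 1   # first stride-2 stage
    return (after_conv - 2) // 2 + 1         # second stage (kernel 2, stride 2)

for n in (100, 400, 3000):
    print(n, audio_place_num(n))  # 100 -> 25, 400 -> 100, 3000 -> 750
```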

tests/test_align/test_template/test_audio.py

Lines changed: 2 additions & 1 deletion

@@ -13,7 +13,8 @@ def _infer_model(pt_engine, system=None, messages=None, audios=None):
         messages += [{'role': 'user', 'content': '你好'}]
         resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
         response = resp[0].choices[0].message.content
-        messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '<audio>这段语音说了什么'}]
+        messages += [{'role': 'assistant', 'content': response}]
+        messages += [{'role': 'user', 'content': '<audio>这段语音说了什么'}]
     else:
         messages = messages.copy()
     if audios is None:
