Commit 4140cfb

[megatron] support vit_lr aligner_lr (modelscope#6469)
1 parent: 9238661

File tree

4 files changed (+163 -2 lines)

docs/source/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 3 additions & 0 deletions
@@ -283,6 +283,9 @@ Megatron training parameters inherit from Megatron parameters and basic parameters (**shared with ms-swift
 - Note: **Megatron-SWIFT training features prioritize support for the padding_free format**; do not modify this value unless there is a special reason.
 - mlp_padding_free: Defaults to False. Used to apply padding-free optimization to the MLP when padding_free is set to false. This improves training speed and reduces memory usage while allowing a custom attention_mask.
 - vit_gradient_checkpointing: Whether to enable gradient_checkpointing for the ViT part when training multimodal models. Defaults to True. (**The ViT implementation in Megatron-SWIFT uses the transformers implementation.**)
+- vit_lr: When training a multimodal large model, this parameter specifies the learning rate of the ViT. Defaults to None, i.e. equal to learning_rate.
+  - Typically used together with the `--freeze_vit false` and `--freeze_aligner false` parameters.
+- aligner_lr: When training a multimodal large model, this parameter specifies the learning rate of the aligner. Defaults to None, i.e. equal to learning_rate.
 - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`, e.g. `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None. This parameter only takes effect for `vit_gradient_checkpointing`.
 - 🔥packing: Whether to use sequence packing to improve computational efficiency (better load balancing across nodes and processes and higher GPU utilization, at the cost of extra preprocessing time) and stabilize memory usage. Defaults to False. Currently supports CPT/SFT/DPO/KTO/RM.
 - Note: **Different sequences within the same batch remain mutually invisible**, except for Qwen3-Next.

docs/source_en/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 3 additions & 0 deletions
@@ -300,6 +300,9 @@ Megatron training parameters are inherited from Megatron parameters and basic pa
 - Note: **The Megatron-SWIFT training feature prioritizes support for the padding-free format**. Unless under special circumstances, please do not modify this value.
 - mlp_padding_free: The default is False. This is used for applying padding-free optimization to the MLP when padding_free is set to false. It allows for improved training speed and reduced memory usage while customizing the attention_mask.
 - vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT (Vision Transformer) component during multimodal model training. Defaults to `True`. (**The ViT implementation in Megatron-SWIFT uses the Hugging Face `transformers` library.**)
+- vit_lr: Specifies the learning rate for the ViT module when training multimodal models. Default is `None`, same as `learning_rate`.
+  - Typically used together with `--freeze_vit false` and `--freeze_aligner false`.
+- aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`.
 - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled.
 - 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM.
 - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next.
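
The semantics documented above: each component's effective learning rate is its own override when set, otherwise the global learning_rate, and the overrides only matter for parameters that are actually trainable (hence the note about `--freeze_vit false` / `--freeze_aligner false`). A minimal sketch of that fallback logic; the function name and dict layout are illustrative, not ms-swift API:

from typing import Optional

def resolve_component_lrs(learning_rate: float,
                          vit_lr: Optional[float] = None,
                          aligner_lr: Optional[float] = None) -> dict:
    """Effective per-component learning rates; None falls back to learning_rate."""
    return {
        'llm': learning_rate,
        'vit': vit_lr if vit_lr is not None else learning_rate,
        'aligner': aligner_lr if aligner_lr is not None else learning_rate,
    }

# e.g. train the ViT more conservatively than the LLM and aligner:
print(resolve_component_lrs(1e-5, vit_lr=1e-6))
# {'llm': 1e-05, 'vit': 1e-06, 'aligner': 1e-05}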

swift/megatron/argument/megatron_args.py

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
 
     # visual
     vit_gradient_checkpointing: bool = True
+    vit_lr: Optional[float] = None
+    aligner_lr: Optional[float] = None
     gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
     # qwen3_next
     linear_num_value_heads: Optional[int] = None
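
These two optional fields feed the trainer change below: when vit_lr or aligner_lr is set, the matching parameter groups receive a learning-rate multiplier relative to the base lr (the diff computes `_lr_mult = self.args.vit_lr / lr`), so a group's effective rate works out to lr * lr_mult. A small sketch of that arithmetic; the helper name is illustrative, not ms-swift API:

from typing import Optional

def component_lr_mult(base_lr: float, component_lr: Optional[float]) -> float:
    """Multiplier so that base_lr * mult matches the component override."""
    if component_lr is None:        # unset -> group keeps the base learning rate
        return 1.0
    return component_lr / base_lr

# learning_rate=1e-5, vit_lr=1e-6, aligner_lr left unset:
print(component_lr_mult(1e-5, 1e-6))   # ~0.1 -> ViT groups train 10x slower
print(component_lr_mult(1e-5, None))   # 1.0  -> aligner follows learning_rate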

swift/megatron/trainers/base.py

Lines changed: 155 additions & 2 deletions
@@ -6,16 +6,18 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Dict, Literal
+from typing import Callable, Dict, List, Literal, Optional
 
 import megatron.core
 import torch
 import torch.nn
 from megatron.core import mpu
 from megatron.core.enums import ModelType
 from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups, param_group_identifier_keys
 from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine
+from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.moe_utils import track_moe_metrics
 from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper
 from megatron.core.utils import StragglerDetector
@@ -235,6 +237,157 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs):
             args.no_load_rng = origin_no_load_rng
             args.finetune = origin_finetune
 
+    # Code borrowed from Megatron-LM
+    def _get_param_groups(
+        self,
+        model_chunks: List[MegatronModule],
+        no_weight_decay_cond: Optional[Callable],
+        scale_lr_cond: Optional[Callable],
+        lr_mult: float,
+        lr: float,
+        min_lr: float,
+        decoupled_lr: Optional[float],
+        decoupled_min_lr: Optional[float],
+        default_skip_embedding_weight_decay: bool = False,
+    ) -> List[Dict]:
+        """Create parameter groups for optimizer.
+
+        Creates parameter groups based on weight decay condition (regularized vs
+        non regularized), learning rate scale condition (lr vs lr_mult * lr),
+        and whether it is expert parameters. scale_lr_cond is used during finetuning
+        where head of the network requires a scaled version of the base learning rate.
+
+        Args:
+            model_chunks (List[MegatronModule]): model chunks to create parameter
+                groups for.
+            no_weight_decay_cond (func, optional): function to determine whether a
+                parameter should not perform weight decay.
+            scale_lr_cond (func, optional): function to determine whether a parameter
+                should have a scaled learning rate.
+            lr_mult (float): learning rate multiplier for parameters that
+                satisfy scale_lr_cond.
+            lr (float): learning rate.
+            min_lr (float): minimum learning rate.
+            decoupled_lr (Optional[float]): optional decoupled learning rate.
+            decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
+            default_skip_embedding_weight_decay (bool): whether to skip weight decay for embedding
+                parameters by default, if no_weight_decay_cond is not provided.
+
+        Returns:
+            List of parameter groups.
+        """
+
+        use_decoupled_learning_rate = decoupled_lr is not None
+
+        # Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
+        params_map = {}
+        for model_chunk in model_chunks:
+            visual = model_chunk.module.module.visual
+            for name, param in model_chunk.named_parameters():
+                if not param.requires_grad:
+                    continue
+
+                is_expert_parallel = not getattr(param, 'allreduce', True)
+
+                if no_weight_decay_cond is not None:
+                    no_wd: bool = no_weight_decay_cond(name, param)
+                else:
+                    # Do not regularize biases and norm parameters.
+                    # optionally, also skip weight decay for embedding parameters if requested
+                    # (useful if you do not want embeddings to shrink to zero in training
+                    # https://arxiv.org/abs/2312.16903)
+                    no_wd = (
+                        name.endswith('.bias') or len(param.shape) == 1
+                        or (default_skip_embedding_weight_decay and 'embedding' in name))
+                _lr_mult = lr_mult
+                if scale_lr_cond is not None:
+                    scale_lr = scale_lr_cond(name, param)
+                else:
+                    scale_lr = False
+                # Handling multimodal models: vit_lr, aligner_lr
+                unwrapped_name = name.removeprefix('module.').removeprefix('module.')
+                is_aligner = any(unwrapped_name.startswith(f'visual.{k}') for k in visual._aligner)
+                is_vit = any(unwrapped_name.startswith(f'visual.{k}')
+                             for k in visual._vision_tower) and not is_aligner
+                if is_vit and self.args.vit_lr:
+                    scale_lr = True
+                    _lr_mult = self.args.vit_lr / lr
+                elif is_aligner and self.args.aligner_lr:
+                    scale_lr = True
+                    _lr_mult = self.args.aligner_lr / lr
+
+                if not no_wd and not scale_lr:
+                    wd_mult, _lr_mult = 1.0, 1.0
+                elif not no_wd and scale_lr:
+                    wd_mult, _lr_mult = 1.0, _lr_mult
+                elif no_wd and not scale_lr:
+                    wd_mult, _lr_mult = 0.0, 1.0
+                else:
+                    wd_mult, _lr_mult = 0.0, _lr_mult
+
+                is_decoupled_lr = False
+                # For input/embedding and output layer: embedding.word_embeddings.weight /
+                # output_layer.weight.
+                if use_decoupled_learning_rate and getattr(param, 'is_embedding_or_output_parameter', False):
+                    is_decoupled_lr = True
+
+                key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
+                if key not in params_map:
+                    params_map[key] = []
+                params_map[key].append(param)
+
+        # Distributed checkpoint requires all ranks to have the same param groups,
+        # so we need to align the param groups across ranks, otherwise we may have
+        # runtime error when loading the checkpoint or numerical error when resuming training.
+        params_key = list(params_map.keys())
+        gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather_object(gathered_params_key, params_key)
+        for keys in gathered_params_key:
+            for key in keys:
+                if key not in params_key:
+                    params_key.append(key)
+
+        param_groups = []
+        for key in params_key:
+            wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr = key
+            params = params_map[key] if key in params_map else []
+            param_group = {
+                'params': params,
+                'wd_mult': wd_mult,
+                'lr_mult': _lr_mult,
+                'is_expert_parallel': is_expert_parallel,
+                'is_decoupled_lr': is_decoupled_lr,
+            }
+            # Ensure param_group has required keys for matching when loading optimizer state
+            # See MegatronOptimizer._filter_and_reorder_param_groups.
+            assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'}
+            param_groups.append(param_group)
+
+        param_groups = _update_min_and_max_lr_in_param_groups(
+            param_groups,
+            lr=lr,
+            min_lr=min_lr,
+            decoupled_lr=decoupled_lr,
+            decoupled_min_lr=decoupled_min_lr,
+        )
+
+        return param_groups
+
+    @contextmanager
+    def _patch_get_param_groups(self):
+        if not self.args.megatron_model_meta.is_multimodal or (self.args.vit_lr is None
+                                                               and self.args.aligner_lr is None):
+            yield
+            return
+        from megatron.core import optimizer
+
+        _get_param_groups = optimizer._get_param_groups
+        optimizer._get_param_groups = self._get_param_groups
+        try:
+            yield
+        finally:
+            optimizer._get_param_groups = _get_param_groups
+
     def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs):
 
         args = get_args()
@@ -254,7 +407,7 @@ def new_model_provider_func(*_args, **kwargs):
            return model
 
        self._init_multimodal_full()
-       with self._patch_load_state_dict(self._load_base_checkpoint):
+       with self._patch_load_state_dict(self._load_base_checkpoint), self._patch_get_param_groups():
            model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer(
                new_model_provider_func, model_type, *_args, **kwargs)
        if args.initialize_embedding:
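
The `_patch_get_param_groups` context manager above works by temporarily rebinding the module-level `megatron.core.optimizer._get_param_groups` to the vit/aligner-aware method while `setup_model_and_optimizer` runs, and restoring the original in the `finally` block. A standalone sketch of that patch-and-restore pattern, using a toy module and toy functions (nothing below is Megatron-LM or ms-swift API):

import types
from contextlib import contextmanager

# Toy stand-in for megatron.core.optimizer: a module whose internal helper
# framework code calls when building the optimizer.
toy_optimizer = types.ModuleType('toy_optimizer')
toy_optimizer._get_param_groups = lambda params: [{'params': params, 'lr_mult': 1.0}]

def custom_get_param_groups(params):
    # Pretend this is the vit/aligner-aware grouping.
    return [{'params': params, 'lr_mult': 0.1}]

@contextmanager
def patch_get_param_groups(module, replacement):
    original = module._get_param_groups
    module._get_param_groups = replacement      # route calls through the custom grouping
    try:
        yield
    finally:
        module._get_param_groups = original     # always restore, even on error

def build_optimizer(params):
    # Framework-style code looks the helper up on the module at call time,
    # which is why rebinding the module attribute is enough.
    return toy_optimizer._get_param_groups(params)

with patch_get_param_groups(toy_optimizer, custom_get_param_groups):
    print(build_optimizer(['w1', 'w2']))   # lr_mult 0.1 -> patched path
print(build_optimizer(['w1', 'w2']))       # lr_mult 1.0 -> original restored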
