diff --git a/swift/megatron/argument.py b/swift/megatron/argument.py index 891cd3fc3..d98248e76 100644 --- a/swift/megatron/argument.py +++ b/swift/megatron/argument.py @@ -2,7 +2,7 @@ import sys from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Literal, Optional, Tuple - +import inspect @dataclass class ExtraMegatronArguments: @@ -107,6 +107,15 @@ def __post_init__(self): if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None: self.lr_decay_iters = self.train_iters - self.lr_warmup_iters + def get_matched_kwargs(args): + args_dict = asdict(args) + parameters = inspect.signature(MegatronArguments.__init__).parameters + + for k in list(args_dict.keys()): + if k not in parameters: + args_dict.pop(k) + return args_dict + def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]: new_args = [] args_dict = asdict(self) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index a33dc0600..a48b5b7d3 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -16,12 +16,14 @@ def convert_hf2megatron(args: ExportArguments) -> None: kwargs['torch_dtype'] = torch.float32 hf_model, processor = get_model_tokenizer(**kwargs) megatron_model_meta = get_megatron_model_meta(args.model) - mg_model = megatron_model_meta.get_model_provider()() kwargs = megatron_model_meta.load_config(hf_model.model_info) - megatron_args = MegatronArguments(kwargs) + megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args)) extra_args = megatron_args.parse_to_megatron() - initialize_megatron(args_defaults=extra_args) + + mg_model = megatron_model_meta.get_model_provider()() + + megatron_model_meta.convert_hf2megatron(hf_model, mg_model) diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index d96a70945..c78a7d034 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -49,4 +49,4 @@ def get_qwen_model_provider(): Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'), Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'), ]), - ], convert_megatron2hf, convert_hf2megatron, get_model_provider, load_qwen_config)) + ], convert_megatron2hf, convert_hf2megatron, get_qwen_model_provider, load_qwen_config))