diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py
index 767484e6f..a33dc0600 100644
--- a/swift/megatron/convert.py
+++ b/swift/megatron/convert.py
@@ -7,6 +7,7 @@
 from megatron.training.initialize import initialize_megatron
 
 from swift.llm import ExportArguments, get_model_tokenizer
+from .argument import MegatronArguments
 from .model import get_megatron_model_meta
 
 
@@ -15,17 +16,13 @@
 def convert_hf2megatron(args: ExportArguments) -> None:
     kwargs['torch_dtype'] = torch.float32
     hf_model, processor = get_model_tokenizer(**kwargs)
     megatron_model_meta = get_megatron_model_meta(args.model)
-    model_provider = megatron_model_meta.get_model_provider()
-    megatron_model_meta.load_config(hf_model.model_info)
+    kwargs = megatron_model_meta.load_config(hf_model.model_info)
+    megatron_args = MegatronArguments(**kwargs)
+    extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
-    args = get_args()
-    model_provider, convert_module = get_megatron_model_convert(args.model_type)
-    mg_model = model_provider()
-    convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
-    if save_torch_dtype is not None:
-        mg_model.to(save_torch_dtype)
-    convert_module.save_mgmodel(mg_model, args)
+    mg_model = megatron_model_meta.get_model_provider()()
+    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)
 
 
 def convert_megatron2hf(
diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py
index 0df3277c2..d96a70945 100644
--- a/swift/megatron/model/qwen.py
+++ b/swift/megatron/model/qwen.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import importlib
+from typing import Any, Dict
 
 from megatron.training import get_args
 
 from swift.llm import ModelInfo
@@ -11,7 +12,7 @@
 from .utils import get_model_provider
 
 
-def load_qwen_config(model_info: ModelInfo):
+def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]:
     args_config = load_config(model_info)
     args_config['swiglu'] = True
     return args_config
@@ -23,6 +24,8 @@ def convert_megatron2hf(hf_model, mg_model):
 
 def convert_hf2megatron(hf_model, mg_model):
     args = get_args()
+    mg_model.to(args.torch_dtype)
+    convert_module.save_mgmodel(mg_model, args)
 
 
 def get_qwen_model_provider():
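
For context, a minimal usage sketch of how the refactored entry point is driven after this patch. Both imported names come from the files above; the checkpoint id is illustrative, and any further ExportArguments fields (output path, dtype) are omitted because they are not shown in the diff.

# Not part of the patch: usage sketch for the new conversion flow.
from swift.llm import ExportArguments
from swift.megatron.convert import convert_hf2megatron

# Illustrative model id; ExportArguments is assumed to accept it directly.
args = ExportArguments(model='Qwen/Qwen2-7B-Instruct')
# Loads the HF model, builds MegatronArguments from its config, initializes
# Megatron, then delegates conversion and saving to the model meta.
convert_hf2megatron(args)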