Commit 83dc334

update
1 parent bd7547c commit 83dc334

2 files changed: +10 -10 lines changed


swift/megatron/convert.py

Lines changed: 6 additions & 9 deletions

@@ -7,6 +7,7 @@
 from megatron.training.initialize import initialize_megatron
 
 from swift.llm import ExportArguments, get_model_tokenizer
+from .argument import MegatronArguments
 from .model import get_megatron_model_meta
 
 
@@ -15,17 +16,13 @@ def convert_hf2megatron(args: ExportArguments) -> None:
     kwargs['torch_dtype'] = torch.float32
     hf_model, processor = get_model_tokenizer(**kwargs)
     megatron_model_meta = get_megatron_model_meta(args.model)
-    model_provider = megatron_model_meta.get_model_provider()
-    megatron_model_meta.load_config(hf_model.model_info)
+    mg_model = megatron_model_meta.get_model_provider()()
+    kwargs = megatron_model_meta.load_config(hf_model.model_info)
+    megatron_args = MegatronArguments(kwargs)
+    extra_args = megatron_args.parse_to_megatron()
 
     initialize_megatron(args_defaults=extra_args)
-    args = get_args()
-    model_provider, convert_module = get_megatron_model_convert(args.model_type)
-    mg_model = model_provider()
-    convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
-    if save_torch_dtype is not None:
-        mg_model.to(save_torch_dtype)
-    convert_module.save_mgmodel(mg_model, args)
+    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)
 
 
 def convert_megatron2hf(
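
Pieced together from the hunk above, the refactored convert_hf2megatron now reads roughly as follows. This is a sketch, not the verbatim file: lines the diff does not show (the initial construction of kwargs from args, the module's torch import) are assumed, and the comments are editorial.

import torch  # assumed; the function uses torch.float32

from megatron.training.initialize import initialize_megatron

from swift.llm import ExportArguments, get_model_tokenizer
from .argument import MegatronArguments
from .model import get_megatron_model_meta


def convert_hf2megatron(args: ExportArguments) -> None:
    # ... kwargs is populated from `args` in context lines the diff omits ...
    kwargs['torch_dtype'] = torch.float32
    hf_model, processor = get_model_tokenizer(**kwargs)

    # The per-model meta object now drives the whole conversion.
    megatron_model_meta = get_megatron_model_meta(args.model)
    mg_model = megatron_model_meta.get_model_provider()()          # instantiate the Megatron model
    kwargs = megatron_model_meta.load_config(hf_model.model_info)  # HF config -> Megatron-style kwargs
    megatron_args = MegatronArguments(kwargs)
    extra_args = megatron_args.parse_to_megatron()

    initialize_megatron(args_defaults=extra_args)
    # Weight copy and checkpoint saving are delegated to the model meta
    # (for Qwen, the convert_hf2megatron in swift/megatron/model/qwen.py below).
    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)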

swift/megatron/model/qwen.py

Lines changed: 4 additions & 1 deletion

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import importlib
+from typing import Any, Dict
 
 from megatron.training import get_args
 
@@ -11,7 +12,7 @@
 from .utils import get_model_provider
 
 
-def load_qwen_config(model_info: ModelInfo):
+def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]:
     args_config = load_config(model_info)
     args_config['swiglu'] = True
     return args_config
@@ -23,6 +24,8 @@ def convert_megatron2hf(hf_model, mg_model):
 
 def convert_hf2megatron(hf_model, mg_model):
     args = get_args()
+    mg_model.to(args.torch_dtype)
+    convert_module.save_mgmodel(mg_model, args)
 
 
 def get_qwen_model_provider():
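
Assembled from the hunks above, the updated Qwen helpers look roughly like the sketch below. Where convert_module is defined and where the actual HF-to-Megatron weight copy happens are not visible in this diff, so they appear only as assumptions in the comments; ModelInfo and load_config are imported in context lines the diff omits.

def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]:
    # Translate the HF config into Megatron-style arguments; Qwen uses SwiGLU,
    # so the flag is forced on here.
    args_config = load_config(model_info)
    args_config['swiglu'] = True
    return args_config


def convert_hf2megatron(hf_model, mg_model):
    args = get_args()
    # ... weight copy from hf_model into mg_model is assumed to happen in
    # context lines not shown by this diff ...
    mg_model.to(args.torch_dtype)                # cast to the configured dtype before saving
    convert_module.save_mgmodel(mg_model, args)  # persist the Megatron checkpoint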
