 from megatron.training.initialize import initialize_megatron

 from swift.llm import ExportArguments, get_model_tokenizer
+from .argument import MegatronArguments
 from .model import get_megatron_model_meta


@@ -15,17 +16,13 @@ def convert_hf2megatron(args: ExportArguments) -> None:
     kwargs['torch_dtype'] = torch.float32
     hf_model, processor = get_model_tokenizer(**kwargs)
     megatron_model_meta = get_megatron_model_meta(args.model)
-    model_provider = megatron_model_meta.get_model_provider()
-    megatron_model_meta.load_config(hf_model.model_info)
+    mg_model = megatron_model_meta.get_model_provider()()
+    kwargs = megatron_model_meta.load_config(hf_model.model_info)
+    megatron_args = MegatronArguments(kwargs)
+    extra_args = megatron_args.parse_to_megatron()

     initialize_megatron(args_defaults=extra_args)
-    args = get_args()
-    model_provider, convert_module = get_megatron_model_convert(args.model_type)
-    mg_model = model_provider()
-    convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
-    if save_torch_dtype is not None:
-        mg_model.to(save_torch_dtype)
-    convert_module.save_mgmodel(mg_model, args)
+    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)


 def convert_megatron2hf(
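Pieced together from the + side of the hunk above, the updated convert_hf2megatron body reads roughly as follows. This is a sketch, not the full file: the lines before the torch_dtype assignment are elided (shown as "..."), and the added comments are interpretation based only on the names visible in the diff.

    def convert_hf2megatron(args: ExportArguments) -> None:
        ...
        kwargs['torch_dtype'] = torch.float32
        hf_model, processor = get_model_tokenizer(**kwargs)
        # Look up the per-model Megatron meta and build an (uninitialized) Megatron model.
        megatron_model_meta = get_megatron_model_meta(args.model)
        mg_model = megatron_model_meta.get_model_provider()()
        # Derive Megatron config values from the HF model info and turn them into
        # the argument defaults that initialize_megatron expects.
        kwargs = megatron_model_meta.load_config(hf_model.model_info)
        megatron_args = MegatronArguments(kwargs)
        extra_args = megatron_args.parse_to_megatron()

        initialize_megatron(args_defaults=extra_args)
        # Delegate the actual weight conversion to the per-model hook.
        megatron_model_meta.convert_hf2megatron(hf_model, mg_model)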