Skip to content

Commit

Permalink
update
Browse files — browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jan 9, 2025
1 parent bd7547c commit 83dc334
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
15 changes: 6 additions & 9 deletions swift/megatron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from megatron.training.initialize import initialize_megatron

from swift.llm import ExportArguments, get_model_tokenizer
from .argument import MegatronArguments
from .model import get_megatron_model_meta


Expand All @@ -15,17 +16,13 @@ def convert_hf2megatron(args: ExportArguments) -> None:
kwargs['torch_dtype'] = torch.float32
hf_model, processor = get_model_tokenizer(**kwargs)
megatron_model_meta = get_megatron_model_meta(args.model)
model_provider = megatron_model_meta.get_model_provider()
megatron_model_meta.load_config(hf_model.model_info)
mg_model = megatron_model_meta.get_model_provider()()
kwargs = megatron_model_meta.load_config(hf_model.model_info)
megatron_args = MegatronArguments(kwargs)
extra_args = megatron_args.parse_to_megatron()

initialize_megatron(args_defaults=extra_args)
args = get_args()
model_provider, convert_module = get_megatron_model_convert(args.model_type)
mg_model = model_provider()
convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
if save_torch_dtype is not None:
mg_model.to(save_torch_dtype)
convert_module.save_mgmodel(mg_model, args)
megatron_model_meta.convert_hf2megatron(hf_model, mg_model)


def convert_megatron2hf(
Expand Down
5 changes: 4 additions & 1 deletion swift/megatron/model/qwen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import importlib
from typing import Any, Dict

from megatron.training import get_args

Expand All @@ -11,7 +12,7 @@
from .utils import get_model_provider


def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]:
    """Build the Megatron args dict for a Qwen model.

    Loads the generic Megatron config derived from *model_info* and
    enables ``swiglu``, the gated-SiLU MLP activation used by the Qwen
    architecture.

    Args:
        model_info: Model metadata used by ``load_config`` to derive the
            base Megatron configuration.

    Returns:
        A dict of Megatron argument overrides with ``'swiglu'`` set to True.
    """
    args_config = load_config(model_info)
    args_config['swiglu'] = True
    return args_config
Expand All @@ -23,6 +24,8 @@ def convert_megatron2hf(hf_model, mg_model):

def convert_hf2megatron(hf_model, mg_model):
    """Cast *mg_model* to the configured dtype and save it as a Megatron checkpoint.

    NOTE(review): ``convert_module`` is not defined anywhere in this file's
    visible scope (only ``importlib`` is imported) — presumably it is meant
    to be resolved dynamically; verify this does not raise NameError.
    NOTE(review): despite the name, no weight conversion from *hf_model* is
    visible here — *hf_model* is unused in this body; confirm intent.
    """
    # Megatron's global args, populated earlier by initialize_megatron().
    args = get_args()
    # In-place dtype cast of the Megatron model before serialization.
    mg_model.to(args.torch_dtype)
    convert_module.save_mgmodel(mg_model, args)


def get_qwen_model_provider():
Expand Down

0 comments on commit 83dc334

Please sign in to comment.