Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jan 9, 2025
1 parent 9a8c458 commit 836fbcf
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
11 changes: 10 additions & 1 deletion swift/megatron/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple

import inspect

@dataclass
class ExtraMegatronArguments:
Expand Down Expand Up @@ -107,6 +107,15 @@ def __post_init__(self):
if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None:
self.lr_decay_iters = self.train_iters - self.lr_warmup_iters

@staticmethod
def get_matched_kwargs(args):
    """Return the fields of ``args`` that ``MegatronArguments.__init__`` accepts.

    ``args`` is converted to a plain dict with ``dataclasses.asdict`` and every
    key that is not a parameter name of ``MegatronArguments.__init__`` is
    discarded, so the result can be splatted into the constructor.

    Args:
        args: A dataclass instance whose fields are used as candidate kwargs.

    Returns:
        A new dict holding only the entries whose key names a parameter of
        ``MegatronArguments.__init__``.
    """
    # Declared static: the function takes no ``self`` and is invoked on the
    # class itself, so it must not be bound when accessed through an instance.
    parameters = inspect.signature(MegatronArguments.__init__).parameters
    return {k: v for k, v in asdict(args).items() if k in parameters}

def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
new_args = []
args_dict = asdict(self)
Expand Down
8 changes: 5 additions & 3 deletions swift/megatron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ def convert_hf2megatron(args: ExportArguments) -> None:
kwargs['torch_dtype'] = torch.float32
hf_model, processor = get_model_tokenizer(**kwargs)
megatron_model_meta = get_megatron_model_meta(args.model)
mg_model = megatron_model_meta.get_model_provider()()
kwargs = megatron_model_meta.load_config(hf_model.model_info)
megatron_args = MegatronArguments(kwargs)
megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))
extra_args = megatron_args.parse_to_megatron()

initialize_megatron(args_defaults=extra_args)

mg_model = megatron_model_meta.get_model_provider()()


megatron_model_meta.convert_hf2megatron(hf_model, mg_model)


Expand Down
2 changes: 1 addition & 1 deletion swift/megatron/model/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def get_qwen_model_provider():
Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
]),
], convert_megatron2hf, convert_hf2megatron, get_model_provider, load_qwen_config))
], convert_megatron2hf, convert_hf2megatron, get_qwen_model_provider, load_qwen_config))

0 comments on commit 836fbcf

Please sign in to comment.