support megatron #2885

Open · wants to merge 11 commits into main
19 changes: 19 additions & 0 deletions swift/llm/argument/export_args.py
@@ -4,6 +4,7 @@
from typing import Literal, Optional

import torch
import torch.distributed as dist

from swift.utils import get_logger
from .base_args import BaseArguments, to_abspath
@@ -44,6 +45,12 @@ class ExportArguments(MergeArguments, BaseArguments):
to_ollama: bool = False
gguf_file: Optional[str] = None

# megatron
to_megatron: bool = False
to_hf: bool = False
target_tensor_model_parallel_size: int = 1
target_pipeline_model_parallel_size: int = 1

# push to ms hub
push_to_hub: bool = False
# 'user_name/repo_name' or 'repo_name'
@@ -65,6 +72,10 @@ def _init_output_dir(self):
suffix = f'{self.quant_method}-int{self.quant_bits}'
elif self.to_ollama:
suffix = 'ollama'
elif self.to_megatron:
suffix = f'tp{self.target_tensor_model_parallel_size}-pp{self.target_pipeline_model_parallel_size}'
elif self.to_hf:
suffix = 'hf'
else:
return

@@ -81,6 +92,14 @@ def __post_init__(self):
raise ValueError('Please specify `--quant_bits`.')
if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
self.torch_dtype = torch.float16
if self.to_megatron or self.to_hf:
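            # fake a single-rank torch.distributed setup so Megatron's
            # initialization can run inside this one-process export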
os.environ['RANK'] = '0'
os.environ['LOCAL_RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['LOCAL_WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
dist.init_process_group(backend='nccl')

BaseArguments.__post_init__(self)
self._init_output_dir()
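Taken together, the new flags make a one-process export self-contained: `__post_init__` fakes a single-rank distributed environment and `_init_output_dir` derives a `tp{N}-pp{M}` suffix. A minimal usage sketch (the model id is illustrative, and a visible CUDA device is assumed since the process group uses NCCL):

from swift.llm import ExportArguments

args = ExportArguments(
    model='Qwen/Qwen2-7B-Instruct',  # illustrative model id, not from this PR
    to_megatron=True,
    target_tensor_model_parallel_size=2,
    target_pipeline_model_parallel_size=1,
)
# __post_init__ has now set RANK=0, WORLD_SIZE=1, etc. and called
# dist.init_process_group(backend='nccl'); the output dir suffix is 'tp2-pp1'.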
6 changes: 6 additions & 0 deletions swift/llm/export/export.py
@@ -25,6 +25,12 @@ def run(self):
quantize_model(args)
elif args.to_ollama:
export_to_ollama(args)
elif args.to_megatron:
from swift.megatron import convert_hf2megatron
convert_hf2megatron(args)
elif args.to_hf:
from swift.megatron import convert_megatron2hf
convert_megatron2hf(args)
elif args.push_to_hub:
model_dir = args.adapters and args.adapters[0] or args.model_dir
assert model_dir, f'model_dir: {model_dir}'
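The new branches import `swift.megatron` lazily, so users who never touch Megatron pay no import (or repo-clone) cost. An end-to-end sketch, assuming `export_main` is the usual ms-swift export entry point (treat the name as an assumption if your version differs):

from swift.llm import ExportArguments, export_main

# convert an HF checkpoint to Megatron format with tp=2 (model id illustrative)
export_main(ExportArguments(
    model='Qwen/Qwen2-7B-Instruct',
    to_megatron=True,
    target_tensor_model_parallel_size=2,
))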
10 changes: 7 additions & 3 deletions swift/llm/model/register.py
@@ -307,10 +307,10 @@ def get_all_models() -> List[str]:
return models


-def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
+def _get_matched_model_meta(model_id_or_path, model_mapping) -> Optional[ModelMeta]:
     model_name = get_model_name(model_id_or_path).lower()
-    for model_type, model_meta in MODEL_MAPPING.items():
-        model_group = model_meta.get_matched_model_group(model_name)
+    for model_type, model_meta in model_mapping.items():
+        model_group = ModelMeta.get_matched_model_group(model_meta, model_name)
         if model_group is not None:
             model_meta = deepcopy(model_meta)
             for k, v in asdict(model_group).items():
@@ -319,6 +319,10 @@ def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
return model_meta


def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
return _get_matched_model_meta(model_id_or_path, MODEL_MAPPING)


def _get_model_info(model_dir: str, model_type: Optional[str], quantization_config) -> ModelInfo:
config_dict = PretrainedConfig.get_config_dict(model_dir)[0]
if quantization_config is not None:
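The refactor separates the matching logic from the registry it searches, so the same matcher can be reused for other model registries. A sketch of the intended reuse (the Megatron-side registry and wrapper names are assumptions based on `swift/megatron/model/register.py` in this PR):

from swift.llm.model.register import _get_matched_model_meta

MEGATRON_MODEL_MAPPING = {}  # hypothetical second registry, keyed like MODEL_MAPPING

def get_matched_megatron_model_meta(model_id_or_path: str):
    # works for any meta object exposing a compatible get_matched_model_group()
    return _get_matched_model_meta(model_id_or_path, MEGATRON_MODEL_MAPPING)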
10 changes: 10 additions & 0 deletions swift/megatron/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

try:
from .init import init_megatron_env
init_megatron_env()
except Exception:
    # the try/except exists only to satisfy linters; failures still propagate
raise

from .convert import convert_hf2megatron, convert_megatron2hf
142 changes: 142 additions & 0 deletions swift/megatron/argument.py
@@ -0,0 +1,142 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import sys
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple

@dataclass
class ExtraMegatronArguments:
padded_vocab_size: Optional[int] = None

target_tensor_model_parallel_size: int = 1
target_pipeline_model_parallel_size: int = 1


@dataclass
class MegatronMixin:
    num_layers: Optional[int] = None
    hidden_size: Optional[int] = None
    ffn_hidden_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    num_query_groups: Optional[int] = None
    group_query_attention: Optional[bool] = None
    max_position_embeddings: Optional[int] = None
    norm_epsilon: Optional[float] = None
    swiglu: Optional[bool] = None
    rotary_base: Optional[int] = None
    disable_bias_linear: bool = True
    add_qkv_bias: bool = True

    train_iters: Optional[int] = None
    lr_warmup_iters: Optional[int] = None
    eval_iters: Optional[int] = None
    lr_decay_iters: Optional[int] = None
    save: Optional[str] = None
    load: Optional[str] = None
    tensorboard_dir: Optional[str] = None
    log_interval: int = 10
    log_throughput: bool = False
    eval_interval: Optional[int] = None
    save_interval: int = 500

    position_embedding_type: str = 'rope'
    rotary_percent: float = 1.
    rotary_seq_len_interpolation_factor: int = 1
    no_bias_swiglu_fusion: bool = False
    attention_dropout: float = 0.
    hidden_dropout: float = 0.

    optimizer: str = 'adam'
    weight_decay: float = 0.1
    clip_grad: float = 1.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    init_method_std: float = 0.01
    micro_batch_size: int = 1
    global_batch_size: int = 16
    recompute_method: Optional[str] = None
    recompute_granularity: Optional[str] = 'selective'
    no_rope_fusion: bool = False
    use_flash_attn: bool = False
    use_cpu_initialization: Optional[bool] = None

    dataloader_type: str = 'cyclic'
    lr: float = 1e-5
    lr_decay_style: str = 'cosine'
    min_lr: float = 1e-6
    fp16: bool = False
    bf16: bool = False
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    context_parallel_size: int = 1
    seed: int = 42
    sequence_parallel: bool = False
    transformer_impl: str = 'transformer_engine'

    apply_query_key_layer_scaling: bool = False  # typically enabled with fp16
    num_workers: int = 8

    log_timers_to_tensorboard: bool = True
    log_batch_size_to_tensorboard: bool = True
    log_validation_ppl_to_tensorboard: bool = True
    log_memory_to_tensorboard: bool = True
    tensorboard_log_interval: int = 1
    tensorboard_queue_size: int = 10
    untie_embeddings_and_output_weights: bool = True
    seq_length: Optional[int] = None

    no_save_optim: bool = False
    no_save_rng: bool = False
    no_load_optim: bool = False
    no_load_rng: bool = False
    loss_scale: Optional[float] = None
    use_distributed_optimizer: bool = True
    normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
    calculate_per_token_loss: bool = True


@dataclass
class MegatronArguments(ExtraMegatronArguments, MegatronMixin):

    def __post_init__(self):
        if self.group_query_attention is None:
            self.group_query_attention = self.num_query_groups > 1
        if self.eval_interval is None:
            self.eval_interval = self.save_interval
        if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None:
            self.lr_decay_iters = self.train_iters - self.lr_warmup_iters

    @staticmethod
    def get_matched_kwargs(args) -> Dict[str, Any]:
        args_dict = asdict(args)
        parameters = inspect.signature(MegatronArguments.__init__).parameters
        # keep only the keys that MegatronArguments actually accepts
        for k in list(args_dict.keys()):
            if k not in parameters:
                args_dict.pop(k)
        return args_dict

def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
new_args = []
args_dict = asdict(self)
extra_args = {}
for k, value in args_dict.items():
if k in ExtraMegatronArguments.__annotations__:
extra_args[k] = value
continue
if value is None or value is False:
continue
new_args.append(f"--{k.replace('_', '-')}")
if isinstance(value, list):
new_args += [str(v) for v in value]
elif value is not True:
new_args.append(str(value))

return new_args, extra_args

def parse_to_megatron(self):
new_args, extra_args = self._args_to_argv()
sys._old_argv = sys.argv
sys.argv = sys.argv[:1] + new_args

return extra_args
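To see what `parse_to_megatron()` produces, here is a hedged walk-through with illustrative Qwen2-style values (none of these numbers come from the PR; note that importing `swift.megatron` triggers the environment bootstrap in `init.py`, and that the call rewrites `sys.argv`):

from swift.megatron.argument import MegatronArguments

args = MegatronArguments(
    num_layers=28, hidden_size=3584, ffn_hidden_size=18944,
    num_attention_heads=28, num_query_groups=4, swiglu=True,
    max_position_embeddings=32768, norm_epsilon=1e-6,
    rotary_base=1000000, seq_length=4096, padded_vocab_size=152064)
extra_args = args.parse_to_megatron()
# sys.argv now carries kebab-case flags: None/False fields are dropped,
# True booleans become bare flags, everything else becomes '--flag value',
# e.g. [..., '--num-layers', '28', ..., '--swiglu', '--disable-bias-linear', ...]
# extra_args holds the ExtraMegatronArguments fields for args_defaults:
# {'padded_vocab_size': 152064, 'target_tensor_model_parallel_size': 1,
#  'target_pipeline_model_parallel_size': 1}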
40 changes: 40 additions & 0 deletions swift/megatron/convert.py
@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from megatron.training.initialize import initialize_megatron

from swift.llm import ExportArguments, get_model_tokenizer
from .argument import MegatronArguments
from .model import get_megatron_model_meta


def convert_hf2megatron(args: ExportArguments) -> None:
kwargs = args.get_model_kwargs()
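    # load the HF weights in fp32 for the conversion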
kwargs['torch_dtype'] = torch.float32
hf_model, processor = get_model_tokenizer(**kwargs)
megatron_model_meta = get_megatron_model_meta(args.model)
kwargs = megatron_model_meta.load_config(hf_model.model_info)
megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))
extra_args = megatron_args.parse_to_megatron()
initialize_megatron(args_defaults=extra_args)

    mg_model = megatron_model_meta.get_model_provider()()
    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)


def convert_megatron2hf(args: ExportArguments) -> None:
    # mirrors convert_hf2megatron; that the model meta exposes a symmetric
    # convert_megatron2hf hook is an assumption of this sketch
    kwargs = args.get_model_kwargs()
    kwargs['torch_dtype'] = torch.float32
    hf_model, processor = get_model_tokenizer(**kwargs)
    megatron_model_meta = get_megatron_model_meta(args.model)
    kwargs = megatron_model_meta.load_config(hf_model.model_info)
    megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))
    extra_args = megatron_args.parse_to_megatron()
    initialize_megatron(args_defaults=extra_args)

    mg_model = megatron_model_meta.get_model_provider()()
    megatron_model_meta.convert_megatron2hf(hf_model, mg_model)
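Both converters are normally reached through the export dispatch in `swift/llm/export/export.py`; a direct call looks like this sketch (model id illustrative; the single-rank environment from `ExportArguments.__post_init__` and a CUDA device are assumed):

from swift.llm import ExportArguments
from swift.megatron import convert_hf2megatron

convert_hf2megatron(ExportArguments(
    model='Qwen/Qwen2-7B-Instruct',
    to_megatron=True,
    target_tensor_model_parallel_size=2,
))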
38 changes: 38 additions & 0 deletions swift/megatron/init.py
@@ -0,0 +1,38 @@
import os
import shutil
import sys

from swift.llm import git_clone_github
from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run


def _rename_files():
megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
qwen_folders = ['toolkits/model_checkpoints_convertor/qwen']
for folder in qwen_folders:
dir_path = os.path.join(megatron_patch_path, folder)
for fname in os.listdir(dir_path):
old_path = os.path.join(dir_path, fname)
fname = fname.replace('qwen1.', 'qwen1_')
fname = fname.replace('qwen2.', 'qwen2_')
new_path = os.path.join(dir_path, fname)
if old_path != new_path and os.path.exists(old_path):
shutil.move(old_path, new_path)


def init_megatron_env() -> None:
if 'MEGATRON_LM_PATH' not in os.environ:
os.environ['MEGATRON_LM_PATH'] = git_clone_github(
'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
if not is_megatron_available():
subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']])
sys.path.append(os.environ['MEGATRON_LM_PATH'])

if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])

# rename qwen1.5/2.5->qwen1_5/2_5 files
with safe_ddp_context('rename_files'):
_rename_files()
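The rename exists because Python cannot import a module whose file name contains extra dots. A quick check of the pure string rule `_rename_files` applies (the convertor file names here are hypothetical):

for fname in ['hf2mcore_qwen1.5.py', 'hf2mcore_qwen2.5_convertor.py']:
    print(fname.replace('qwen1.', 'qwen1_').replace('qwen2.', 'qwen2_'))
# hf2mcore_qwen1_5.py
# hf2mcore_qwen2_5_convertor.py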
3 changes: 3 additions & 0 deletions swift/megatron/model/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import qwen
from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
29 changes: 29 additions & 0 deletions swift/megatron/model/config.py
@@ -0,0 +1,29 @@
from typing import Any, Dict

from swift.llm import ModelInfo

config_mapping = {
'num_layers': ['num_hidden_layers'],
'hidden_size': ['hidden_size'],
'ffn_hidden_size': ['intermediate_size'],
'num_attention_heads': ['num_attention_heads'],
'num_query_groups': ['num_key_value_heads'],
'max_position_embeddings': ['max_position_embeddings'],
'norm_epsilon': ['rms_norm_eps'],
'rotary_base': ['rope_theta'],
'padded_vocab_size': ['vocab_size'],
'attention_dropout': ['attention_dropout']
}


def load_config(model_info: ModelInfo) -> Dict[str, Any]:
model_config = model_info.config
megatron_config = {}
for k, value in config_mapping.items():
for v in value:
assert hasattr(model_config, v)
if k == 'rotary_base':
megatron_config[k] = int(getattr(model_config, v))
else:
megatron_config[k] = getattr(model_config, v)
return megatron_config
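As a sanity check, this is what `load_config` would return for a Qwen2-7B-style config; the numbers are typical Qwen2 values, not taken from the PR, and `SimpleNamespace` stands in for the real `ModelInfo`/config objects (importing the subpackage triggers the environment setup from `init.py`):

from types import SimpleNamespace

from swift.megatron.model.config import load_config

cfg = SimpleNamespace(
    num_hidden_layers=28, hidden_size=3584, intermediate_size=18944,
    num_attention_heads=28, num_key_value_heads=4,
    max_position_embeddings=32768, rms_norm_eps=1e-6,
    rope_theta=1000000.0, vocab_size=152064, attention_dropout=0.0)

print(load_config(SimpleNamespace(config=cfg)))
# {'num_layers': 28, 'hidden_size': 3584, 'ffn_hidden_size': 18944,
#  'num_attention_heads': 28, 'num_query_groups': 4,
#  'max_position_embeddings': 32768, 'norm_epsilon': 1e-06,
#  'rotary_base': 1000000, 'padded_vocab_size': 152064, 'attention_dropout': 0.0}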
2 changes: 2 additions & 0 deletions swift/megatron/model/constant.py
@@ -0,0 +1,2 @@
class MegatronModelType:
qwen = 'qwen'