From d875242adcb2d420fd5be5e53e25d487c0490966 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 8 Jan 2025 23:40:14 +0800
Subject: [PATCH 01/11] support megatron

---
 swift/llm/argument/export_args.py | 20 +++++++++++++++++++-
 swift/llm/export/export.py | 6 ++++++
 swift/megatron/__init__.py | 2 ++
 swift/megatron/convert.py | 2 ++
 swift/megatron/utils.py | 2 ++
 5 files changed, 31 insertions(+), 1 deletion(-)
 create mode 100644 swift/megatron/__init__.py
 create mode 100644 swift/megatron/convert.py
 create mode 100644 swift/megatron/utils.py

diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index cc36b6450..a9e20ec59 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -2,7 +2,7 @@
 import os
 from dataclasses import dataclass
 from typing import Literal, Optional
-
+import torch.distributed as dist
 import torch
 
 from swift.utils import get_logger
@@ -44,6 +44,12 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_ollama: bool = False
     gguf_file: Optional[str] = None
 
+    # megatron
+    to_megatron: bool = False
+    to_hf: bool = False
+    target_tensor_model_parallel_size: int = 1
+    target_pipeline_model_parallel_size: int = 1
+
     # push to ms hub
     push_to_hub: bool = False
     # 'user_name/repo_name' or 'repo_name'
@@ -65,6 +71,10 @@ def _init_output_dir(self):
             suffix = f'{self.quant_method}-int{self.quant_bits}'
         elif self.to_ollama:
             suffix = 'ollama'
+        elif self.to_megatron:
+            suffix = f'tp{self.target_tensor_model_parallel_size}-pp{self.target_pipeline_model_parallel_size}'
+        elif self.to_hf:
+            suffix = 'hf'
         else:
             return
 
@@ -81,6 +91,14 @@ def __post_init__(self):
             raise ValueError('Please specify `--quant_bits`.')
         if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
             self.torch_dtype = torch.float16
+        if self.to_megatron or self.to_hf:
+            os.environ['RANK'] = '0'
+            os.environ['LOCAL_RANK'] = '0'
+            os.environ['WORLD_SIZE'] = '1'
+            os.environ['LOCAL_WORLD_SIZE'] = '1'
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
+            dist.init_process_group(backend='nccl')
 
         BaseArguments.__post_init__(self)
         self._init_output_dir()
diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py
index 1adb901de..fdb50b307 100644
--- a/swift/llm/export/export.py
+++ b/swift/llm/export/export.py
@@ -25,6 +25,12 @@ def run(self):
             quantize_model(args)
         elif args.to_ollama:
             export_to_ollama(args)
+        elif args.to_megatron:
+            from swift.megatron import convert_hf_to_megatron
+            convert_hf_to_megatron(args)
+        elif args.to_hf:
+            from swift.megatron import convert_megatron_to_hf
+            convert_megatron_to_hf(args)
         elif args.push_to_hub:
             model_dir = args.adapters and args.adapters[0] or args.model_dir
             assert model_dir, f'model_dir: {model_dir}'
diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py
new file mode 100644
index 000000000..1fd393c25
--- /dev/null
+++ b/swift/megatron/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py
new file mode 100644
index 000000000..1fd393c25
--- /dev/null
+++ b/swift/megatron/convert.py
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py
new file mode 100644
index 000000000..1fd393c25
--- /dev/null
+++ b/swift/megatron/utils.py
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+ From 13e4a65484394dfdf4a1ee105173e6b5d0ba9804 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 11:41:23 +0800 Subject: [PATCH 02/11] update --- swift/megatron/__init__.py | 4 + swift/megatron/argument.py | 133 ++++++++++++ swift/megatron/convert.py | 2 - swift/megatron/convert/__init__.py | 3 + swift/megatron/convert/hf2megatron.py | 32 +++ swift/megatron/convert/megatron2hf.py | 17 ++ swift/megatron/model/__init__.py | 6 + swift/megatron/model/config.py | 27 +++ swift/megatron/model/constant.py | 3 + swift/megatron/model/qwen.py | 37 ++++ swift/megatron/model/register.py | 30 +++ swift/megatron/model/utils.py | 27 +++ swift/megatron/utils.py | 288 ++++++++++++++++++++++++++ 13 files changed, 607 insertions(+), 2 deletions(-) create mode 100644 swift/megatron/argument.py delete mode 100644 swift/megatron/convert.py create mode 100644 swift/megatron/convert/__init__.py create mode 100644 swift/megatron/convert/hf2megatron.py create mode 100644 swift/megatron/convert/megatron2hf.py create mode 100644 swift/megatron/model/__init__.py create mode 100644 swift/megatron/model/config.py create mode 100644 swift/megatron/model/constant.py create mode 100644 swift/megatron/model/qwen.py create mode 100644 swift/megatron/model/register.py create mode 100644 swift/megatron/model/utils.py diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 1fd393c25..0d49bea96 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -1,2 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + +from .utils import init_megatron_env + +init_megatron_env() diff --git a/swift/megatron/argument.py b/swift/megatron/argument.py new file mode 100644 index 000000000..16c142ddf --- /dev/null +++ b/swift/megatron/argument.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field, asdict +from typing import Optional, Literal, Dict, Any, Tuple, List +import sys + + +@dataclass +class ExtraMegatronArguments: + padded_vocab_size: Optional[int] = None + + target_tensor_model_parallel_size: int = 1 + target_pipeline_model_parallel_size: int = 1 + + +@dataclass +class MegatronMixin: + num_layers: Optional[int] = None # + hidden_size: Optional[int] = None # + ffn_hidden_size: Optional[int] = None # + num_attention_heads: Optional[int] = None # + num_query_groups: Optional[int] = None # + group_query_attention: Optional[bool] = None # + max_position_embeddings: Optional[int] = None # + norm_epsilon: Optional[float] = None # + swiglu: Optional[bool] = None # + rotary_base: Optional[int] = None # + disable_bias_linear: bool = True # + add_qkv_bias: bool = True # + + train_iters: Optional[int] = None # + lr_warmup_iters: Optional[int] = None # + eval_iters: Optional[int] = None # + lr_decay_iters: Optional[int] = None # + save: Optional[str] = None # + load: Optional[str] = None + tensorboard_dir: Optional[str] = None # ! + log_interval: int = 10 # + log_throughput: bool = False + eval_interval: Optional[int] = None # + save_interval: int = 500 # + + position_embedding_type: str = 'rope' # + rotary_percent: float = 1. # + rotary_seq_len_interpolation_factor: int = 1 # + no_bias_swiglu_fusion: bool = False # + attention_dropout: float = 0. # + hidden_dropout: float = 0. # + + optimizer: str = 'adam' + weight_decay: float = 0.1 # + clip_grad: float = 1. 
# + adam_beta1: float = 0.9 # + adam_beta2: float = 0.95 # + adam_eps: float = 1e-8 + init_method_std: float = 0.01 # + micro_batch_size: int = 1 # + global_batch_size: int = 16 # + recompute_method: Optional[str] = None + recompute_granularity: Optional[str] = 'selective' + no_rope_fusion: bool = False + use_flash_attn: bool = False + use_cpu_initialization: Optional[bool] = None + + dataloader_type: str = 'cyclic' + lr: float = 1e-5 # + lr_decay_style: str = 'cosine' # + min_lr: int = 1e-6 + fp16: bool = False + bf16: bool = False + tensor_model_parallel_size: int = 1 # + pipeline_model_parallel_size: int = 1 # + context_parallel_size: int = 1 # + seed: int = 42 + sequence_parallel: bool = False + transformer_impl: str = 'transformer_engine' + + apply_query_key_layer_scaling: bool = False # fp16 + num_workers: int = 8 + + log_timers_to_tensorboard: bool = True # + log_batch_size_to_tensorboard: bool = True # + log_validation_ppl_to_tensorboard: bool = True # + log_memory_to_tensorboard: bool = True # + tensorboard_log_interval: int = 1 # + tensorboard_queue_size: int = 10 # + untie_embeddings_and_output_weights: bool = True + seq_length: Optional[int] = None # + + no_save_optim: bool = False # + no_save_rng: bool = False # + no_load_optim: bool = False # + no_load_rng: bool = False # + loss_scale: Optional[float] = None + use_distributed_optimizer: bool = True + normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' # + calculate_per_token_loss: bool = True + + +@dataclass +class MegatronArguments(ExtraMegatronArguments, MegatronMixin): + + def __post_init__(self): + if self.group_query_attention is None: + self.group_query_attention = True if self.num_query_groups > 1 else False + if self.eval_interval is None: + self.eval_interval = self.save_interval + if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None: + self.lr_decay_iters = self.train_iters - self.lr_warmup_iters + + def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]: + new_args = [] + args_dict = asdict(self) + extra_args = {} + for k, value in args_dict.items(): + if k in ExtraMegatronArguments.__annotations__: + extra_args[k] = value + continue + if value is None or value is False: + continue + new_args.append(f"--{k.replace('_', '-')}") + if isinstance(value, list): + new_args += [str(v) for v in value] + elif value is not True: + new_args.append(str(value)) + + return new_args, extra_args + + def parse_to_megatron(self): + new_args, extra_args = self._args_to_argv() + sys._old_argv = sys.argv + sys.argv = sys.argv[:1] + new_args + + return extra_args diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py deleted file mode 100644 index 1fd393c25..000000000 --- a/swift/megatron/convert.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - diff --git a/swift/megatron/convert/__init__.py b/swift/megatron/convert/__init__.py new file mode 100644 index 000000000..c4175f177 --- /dev/null +++ b/swift/megatron/convert/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .hf2megatron import convert_hf2megatron +from .megatron2hf import convert_megatron2hf diff --git a/swift/megatron/convert/hf2megatron.py b/swift/megatron/convert/hf2megatron.py new file mode 100644 index 000000000..f369d7331 --- /dev/null +++ b/swift/megatron/convert/hf2megatron.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+
+import torch
+from swift.llm import get_model_tokenizer, ExportArguments
+
+from ..model import get_megatron_model_meta
+
+
+def convert_hf2megatron(
+    args: ExportArguments
+) -> None:
+
+    from megatron.training.initialize import initialize_megatron
+    from megatron.training import get_args
+    kwargs = args.get_model_kwargs()
+    kwargs['torch_dtype'] = torch.float32
+    hf_model, processor = get_model_tokenizer(**kwargs)
+    megatron_model_meta = get_megatron_model_meta(args.model)
+    megatron_model_meta.get_model_provider()
+    megatron_model_meta.load_config(hf_model.model_info)
+
+
+    initialize_megatron(args_defaults=extra_args)
+    args = get_args()
+    model_provider, convert_module = get_megatron_model_convert(args.model_type)
+    mg_model = model_provider()
+    convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
+    if save_torch_dtype is not None:
+        mg_model.to(save_torch_dtype)
+    convert_module.save_mgmodel(mg_model, args)
+
+
diff --git a/swift/megatron/convert/megatron2hf.py b/swift/megatron/convert/megatron2hf.py
new file mode 100644
index 000000000..02711e6d1
--- /dev/null
+++ b/swift/megatron/convert/megatron2hf.py
@@ -0,0 +1,17 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+
+def convert_megatron2hf(
+    hf_model,
+    extra_args: Dict[str, Any],
+) -> None:
+    from megatron.training.initialize import initialize_megatron
+    from megatron.training import get_args
+    initialize_megatron(args_defaults=extra_args)
+    args = get_args()
+
+    model_provider, convert_module = get_megatron_model_convert(args.model_type)
+    convert_module.model_provider = model_provider
+    mg_model = convert_module.load_megatron_model(args)  # no copy
+    convert_module.convert_checkpoint_from_megatron_to_transformers(mg_model, hf_model, args)
diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py
new file mode 100644
index 000000000..573c9d4af
--- /dev/null
+++ b/swift/megatron/model/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .register import (
+    register_megatron_model, get_megatron_model_meta, MegatronModelMeta
+)
+
+
diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py
new file mode 100644
index 000000000..efb99c751
--- /dev/null
+++ b/swift/megatron/model/config.py
@@ -0,0 +1,27 @@
+from swift.llm import ModelInfo
+from typing import Dict, Any
+
+config_mapping = {
+    'num_layers': ['num_hidden_layers'],
+    'hidden_size': ['hidden_size'],
+    'ffn_hidden_size': ['intermediate_size'],
+    'num_attention_heads': ['num_attention_heads'],
+    'num_query_groups': ['num_key_value_heads'],
+    'max_position_embeddings': ['max_position_embeddings'],
+    'norm_epsilon': ['rms_norm_eps'],
+    'rotary_base': ['rope_theta'],
+    'padded_vocab_size': ['vocab_size'],
+    'attention_dropout': ['attention_dropout']
+}
+
+def load_config(model_info: ModelInfo) -> Dict[str, Any]:
+    model_config = model_info.config
+    megatron_config = {}
+    for k, value in config_mapping.items():
+        for v in value:
+            assert hasattr(model_config, v)
+            if k == 'rotary_base':
+                megatron_config[k] = int(getattr(model_config, v))
+            else:
+                megatron_config[k] = getattr(model_config, v)
+    return megatron_config
diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py
new file mode 100644
index 000000000..5c314ec1d
--- /dev/null
+++ b/swift/megatron/model/constant.py
@@ -0,0 +1,3 @@
+
+class MegatronModelType:
+    qwen = 'qwen'
diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py
new file mode 100644
index 000000000..61589e93b
--- /dev/null
+++ b/swift/megatron/model/qwen.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from swift.llm import ModelInfo, ModelGroup, Model
+from .register import register_megatron_model, MegatronModelMeta
+from .utils import get_model_provider
+from .constant import MegatronModelType
+from .config import load_config
+
+def load_qwen_config(model_info: ModelInfo):
+    args_config = load_config(model_info)
+    args_config['swiglu'] = True
+    return args_config
+
+def convert_megatron2hf():
+    pass
+
+def convert_hf2megatron():
+    pass
+
+
+register_megatron_model(MegatronModelMeta(
+MegatronModelType.qwen,[
+    ModelGroup([
+        Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
+        Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
+        Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
+        Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
+        Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
+        Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
+        Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
+    ]),
+    ],
+    convert_megatron2hf,
+    convert_hf2megatron,
+    get_model_provider,
+    load_qwen_config
+))
diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py
new file mode 100644
index 000000000..39a01d814
--- /dev/null
+++ b/swift/megatron/model/register.py
@@ -0,0 +1,30 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Callable, List, Optional
+from dataclasses import dataclass
+from swift.llm import ModelGroup
+from swift.llm.model.register import _get_matched_model_meta
+
+MEGATRON_MODEL_MAPPING = {}
+
+
+@dataclass
+class MegatronModelMeta:
+    megatron_model_type: Optional[str]
+    model_groups: List[ModelGroup]
+
+    convert_megatron2hf: Callable
+    convert_hf2megatron: Callable
+    get_model_provider: Callable
+    load_config: Callable
+
+def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = False):
+    megatron_model_type = model_meta.megatron_model_type
+    if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
+        raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.')
+
+    MEGATRON_MODEL_MAPPING[megatron_model_type] = model_meta
+
+
+def get_megatron_model_meta(model_id_or_path: str) -> Optional[MegatronModelMeta]:
+    return _get_matched_model_meta(model_id_or_path, MEGATRON_MODEL_MAPPING)
+
diff --git a/swift/megatron/model/utils.py b/swift/megatron/model/utils.py
new file mode 100644
index 000000000..4f1692d6a
--- /dev/null
+++ b/swift/megatron/model/utils.py
@@ -0,0 +1,27 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module):
+    def model_provider(pre_process=True, post_process=True):
+        from megatron.training import get_args
+        from megatron.training.arguments import core_transformer_config_from_args
+        args = get_args()
+        config = core_transformer_config_from_args(args, transformer_config_cls)
+        transformer_layer_spec = layer_spec_module.get_gpt_layer_with_transformer_engine_spec(
+            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
+        model = gpt_model_cls(
+            config=config,
+            transformer_layer_spec=transformer_layer_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=True,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base,
+            seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
+        return model
+    return model_provider
+
diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py
index 1fd393c25..5270f5899 100644
--- a/swift/megatron/utils.py
+++ b/swift/megatron/utils.py
@@ -1,2 +1,290 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os +import shutil +import sys +from functools import partial, wraps +from typing import Any, Dict, List, Mapping, Optional +import torch +import torch.distributed as dist + +from swift.llm import LazyLLMDataset, Template, git_clone_github +from swift.utils import ( + append_to_jsonl, get_dist_setting, get_logger, is_master, subprocess_run, +is_megatron_available, safe_ddp_context) + + +logger = get_logger() + + +def _rename_files(): + megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] + qwen_folders = ['toolkits/model_checkpoints_convertor/qwen'] + for folder in qwen_folders: + dir_path = os.path.join(megatron_patch_path, folder) + for fname in os.listdir(dir_path): + old_path = os.path.join(dir_path, fname) + fname = fname.replace('qwen1.', 'qwen1_') + fname = fname.replace('qwen2.', 'qwen2_') + new_path = os.path.join(dir_path, fname) + if old_path != new_path and os.path.exists(old_path): + shutil.move(old_path, new_path) + + +def init_megatron_env() -> None: + if 'MEGATRON_LM_PATH' not in os.environ: + megatron_path = git_clone_github( + 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') + else: + megatron_path = os.environ['MEGATRON_LM_PATH'] + if not is_megatron_available(): + subprocess_run(['pip', 'install', '-e', megatron_path]) + sys.path.append(megatron_path) + + if 'PAI_MEGATRON_PATCH_PATH' not in os.environ: + megatron_patch_path = git_clone_github( + 'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') + else: + megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] + sys.path.append(megatron_patch_path) + + # rename qwen1.5/2.5->qwen1_5/2_5 files + with safe_ddp_context(): + _rename_files() + +def patch_megatron(tokenizer): + + def build_tokenizer(args): + args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size + return tokenizer + + from megatron.training import get_args, training, initialize, global_vars + global_vars.build_tokenizer = build_tokenizer + + _old_initialize_distributed = initialize._initialize_distributed + + @wraps(_old_initialize_distributed) + def _initialize_distributed(*_args, **kwargs): + args = get_args() + if dist.is_initialized(): + args.rank, args.local_rank, args.world_size, args.local_world_size = get_dist_setting() + torch.cuda.set_device(args.local_rank) + return _old_initialize_distributed(*_args, **kwargs) + + initialize._initialize_distributed = _initialize_distributed + + _old_load_state_dict = torch.nn.Module.load_state_dict + + @wraps(_old_load_state_dict) + def _load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, *args, **kwargs): + if strict: + keys = self.state_dict().keys() ^ state_dict.keys() + new_keys = [k for k in keys if not k.endswith('_extra_state')] + if keys and not new_keys: + strict = False + return _old_load_state_dict(self, state_dict, strict, *args, **kwargs) + + torch.nn.Module.load_state_dict = _load_state_dict + + _old_training_log = training.training_log + + @wraps(_old_training_log) + def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, + report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs): + args = get_args() + if is_master() and iteration % args.log_interval == 0: + logging_path = os.path.join(args.save, 'logging.jsonl') + logs = {} + for k, v in loss_dict.items(): + if isinstance(v, torch.Tensor): + v = v.item() + logs[k] = round(v, 8) + logs['grad_norm'] = round(grad_norm, 8) + logs['learning_rate'] = round(learning_rate, 8) + 
logs['consumed_samples'] = args.consumed_train_samples + logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}' + + append_to_jsonl(logging_path, logs) + return _old_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, + loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, + num_zeros_in_grad, *_args, **kwargs) + + training.training_log = training_log + + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + """Loss function. copy from Pai-Megatron-Patch + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + """ + from megatron.training import get_args + from megatron.core import mpu + from megatron.training.utils import average_losses_across_data_parallel_group + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + if args.context_parallel_size > 1: + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) + dist.all_reduce(loss, group=mpu.get_context_parallel_group()) + loss = loss[0] / loss[1] + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = dist.get_rank() + assert not loss.isnan(), (f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}') + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss * args.context_parallel_size, {'loss': averaged_loss[0]} + + +def get_batch_on_this_tp_rank(data_iterator): + # copy from Megatron-LM and made some changes. 
+ from megatron.training import get_args + from megatron.core import mpu + args = get_args() + + def _broadcast(item): + if item is not None: + dist.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + + if mpu.get_tensor_model_parallel_rank() == 0: + + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + args.seq_length = data['tokens'].shape[1] + _broadcast(torch.tensor(args.seq_length).cuda(non_blocking=True)) + batch = { + 'tokens': data['tokens'].cuda(non_blocking=True), + 'labels': data['labels'].cuda(non_blocking=True), + 'loss_mask': data['loss_mask'].cuda(non_blocking=True), + 'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True), + 'position_ids': data['position_ids'].cuda(non_blocking=True) + } + + if args.pipeline_model_parallel_size == 1: + _broadcast(batch['tokens']) + _broadcast(batch['labels']) + _broadcast(batch['loss_mask']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + elif mpu.is_pipeline_first_stage(): + _broadcast(batch['tokens']) + _broadcast(batch['attention_mask']) + _broadcast(batch['position_ids']) + + elif mpu.is_pipeline_last_stage(): + _broadcast(batch['labels']) + _broadcast(batch['loss_mask']) + _broadcast(batch['attention_mask']) + + else: + seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device()) + _broadcast(seq_length) + args.seq_length = seq_length.item() + tokens = torch.empty((args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device()) + labels = torch.empty((args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device()) + loss_mask = torch.empty((args.micro_batch_size, args.seq_length), + dtype=torch.float32, + device=torch.cuda.current_device()) + if args.create_attention_mask_in_dataloader: + attention_mask = torch.empty((args.micro_batch_size, 1, args.seq_length, args.seq_length), + dtype=torch.bool, + device=torch.cuda.current_device()) + else: + attention_mask = None + position_ids = torch.empty((args.micro_batch_size, args.seq_length), + dtype=torch.int64, + device=torch.cuda.current_device()) + + if args.pipeline_model_parallel_size == 1: + _broadcast(tokens) + _broadcast(labels) + _broadcast(loss_mask) + _broadcast(attention_mask) + _broadcast(position_ids) + + elif mpu.is_pipeline_first_stage(): + labels = None + loss_mask = None + + _broadcast(tokens) + _broadcast(attention_mask) + _broadcast(position_ids) + + elif mpu.is_pipeline_last_stage(): + tokens = None + position_ids = None + + _broadcast(labels) + _broadcast(loss_mask) + _broadcast(attention_mask) + + batch = { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'attention_mask': attention_mask, + 'position_ids': position_ids + } + + return batch + + +def forward_step(data_iterator, model): + from megatron.training.utils import get_batch_on_this_cp_rank + batch = get_batch_on_this_tp_rank(data_iterator) + batch = get_batch_on_this_cp_rank(batch) + tokens, labels, loss_mask, attention_mask, position_ids = batch.values() + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples, train_dataset: LazyLLMDataset, + val_dataset: LazyLLMDataset, template: Template): + # train_val_test_num_samples: ignored + from megatron.training import training + from 
megatron.training.utils import get_ltor_masks_and_position_ids + + assert not hasattr(training, '_old_build_pretraining_data_loader') + _old_build_pretraining_data_loader = training.build_pretraining_data_loader + + def data_collator(batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]: + res = template.data_collator(batch, padding_to) + labels = res['labels'] + new_labels = torch.zeros_like(labels) + new_labels[:, :-1] = labels[:, 1:] + new_labels[:, -1] = -100 + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(new_labels, -100, False, False, True) + return { + 'tokens': res['input_ids'], + 'labels': new_labels, + 'attention_mask': attention_mask, + 'loss_mask': loss_mask, + 'position_ids': position_ids + } + + @wraps(_old_build_pretraining_data_loader) + def build_pretraining_data_loader(*args, **kwargs): + res = _old_build_pretraining_data_loader(*args, **kwargs) + if res is not None: + res.collate_fn = data_collator + return res + + training.build_pretraining_data_loader = build_pretraining_data_loader + training._old_build_pretraining_data_loader = _old_build_pretraining_data_loader + return train_dataset, val_dataset, None From f230e01875805b1aa5006b91b6a405bd4cdc7fbb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 11:41:51 +0800 Subject: [PATCH 03/11] update --- swift/llm/argument/export_args.py | 3 ++- swift/megatron/__init__.py | 3 +-- swift/megatron/argument.py | 4 ++-- swift/megatron/convert/hf2megatron.py | 9 ++------- swift/megatron/model/__init__.py | 6 +----- swift/megatron/model/config.py | 4 +++- swift/megatron/model/constant.py | 1 - swift/megatron/model/qwen.py | 22 ++++++++++------------ swift/megatron/model/register.py | 5 +++-- swift/megatron/model/utils.py | 4 +++- swift/megatron/utils.py | 13 +++++-------- 11 files changed, 32 insertions(+), 42 deletions(-) diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index a9e20ec59..9862c7842 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -2,8 +2,9 @@ import os from dataclasses import dataclass from typing import Literal, Optional -import torch.distributed as dist + import torch +import torch.distributed as dist from swift.utils import get_logger from .base_args import BaseArguments, to_abspath diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 0d49bea96..9dde29f3e 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - -from .utils import init_megatron_env +from .utils import init_megatron_env init_megatron_env() diff --git a/swift/megatron/argument.py b/swift/megatron/argument.py index 16c142ddf..891cd3fc3 100644 --- a/swift/megatron/argument.py +++ b/swift/megatron/argument.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from dataclasses import dataclass, field, asdict -from typing import Optional, Literal, Dict, Any, Tuple, List import sys +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Literal, Optional, Tuple @dataclass diff --git a/swift/megatron/convert/hf2megatron.py b/swift/megatron/convert/hf2megatron.py index f369d7331..d4ed29c48 100644 --- a/swift/megatron/convert/hf2megatron.py +++ b/swift/megatron/convert/hf2megatron.py @@ -1,14 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import torch -from swift.llm import get_model_tokenizer, ExportArguments +from swift.llm import ExportArguments, get_model_tokenizer from ..model import get_megatron_model_meta -def convert_hf2megatron( - args: ExportArguments -) -> None: +def convert_hf2megatron(args: ExportArguments) -> None: from megatron.training.initialize import initialize_megatron from megatron.training import get_args @@ -19,7 +17,6 @@ def convert_hf2megatron( megatron_model_meta.get_model_provider() megatron_model_meta.load_config(hf_model.model_info) - initialize_megatron(args_defaults=extra_args) args = get_args() model_provider, convert_module = get_megatron_model_convert(args.model_type) @@ -28,5 +25,3 @@ def convert_hf2megatron( if save_torch_dtype is not None: mg_model.to(save_torch_dtype) convert_module.save_mgmodel(mg_model, args) - - diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 573c9d4af..ca15b80b5 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,6 +1,2 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .register import ( - register_megatron_model, get_megatron_model_meta, MegatronModelMeta -) - - +from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py index efb99c751..83ae73153 100644 --- a/swift/megatron/model/config.py +++ b/swift/megatron/model/config.py @@ -1,5 +1,6 @@ +from typing import Any, Dict + from swift.llm import ModelInfo -from typing import Dict, Any config_mapping = { 'num_layers': ['num_hidden_layers'], @@ -14,6 +15,7 @@ 'attention_dropout': ['attention_dropout'] } + def load_config(model_info: ModelInfo) -> Dict[str, Any]: model_config = model_info.config megatron_config = {} diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py index 5c314ec1d..929eae691 100644 --- a/swift/megatron/model/constant.py +++ b/swift/megatron/model/constant.py @@ -1,3 +1,2 @@ - class MegatronModelType: qwen = 'qwen' diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index 61589e93b..c4ff24f30 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -1,25 +1,28 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from swift.llm import ModelInfo, ModelGroup, Model -from .register import register_megatron_model, MegatronModelMeta -from .utils import get_model_provider -from .constant import MegatronModelType +from swift.llm import Model, ModelGroup, ModelInfo from .config import load_config +from .constant import MegatronModelType +from .register import MegatronModelMeta, register_megatron_model +from .utils import get_model_provider + def load_qwen_config(model_info: ModelInfo): args_config = load_config(model_info) args_config['swiglu'] = True return args_config + def convert_megatron2hf(): pass + def convert_hf2megatron(): pass -register_megatron_model(MegatronModelMeta( -MegatronModelType.qwen,[ +register_megatron_model( + MegatronModelMeta(MegatronModelType.qwen, [ ModelGroup([ Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'), Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'), @@ -29,9 +32,4 @@ def convert_hf2megatron(): Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'), Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'), ]), - ], - convert_megatron2hf, - convert_hf2megatron, - get_model_provider, - load_qwen_config -)) + ], convert_megatron2hf, convert_hf2megatron, get_model_provider, load_qwen_config)) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 39a01d814..a98ad3fa8 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Callable, List, Optional from dataclasses import dataclass +from typing import Callable, List, Optional + from swift.llm import ModelGroup from swift.llm.model.register import _get_matched_model_meta @@ -17,6 +18,7 @@ class MegatronModelMeta: get_model_provider: Callable load_config: Callable + def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = model_meta.megatron_model_type if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: @@ -27,4 +29,3 @@ def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = F def get_megatron_model_meta(model_id_or_path: str) -> Optional[MegatronModelMeta]: return _get_matched_model_meta(model_id_or_path, MEGATRON_MODEL_MAPPING) - diff --git a/swift/megatron/model/utils.py b/swift/megatron/model/utils.py index 4f1692d6a..6232ecb3f 100644 --- a/swift/megatron/model/utils.py +++ b/swift/megatron/model/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+ def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module): + def model_provider(pre_process=True, post_process=True): from megatron.training import get_args from megatron.training.arguments import core_transformer_config_from_args @@ -23,5 +25,5 @@ def model_provider(pre_process=True, post_process=True): rotary_base=args.rotary_base, seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor) return model - return model_provider + return model_provider diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py index 5270f5899..8c4ff82a4 100644 --- a/swift/megatron/utils.py +++ b/swift/megatron/utils.py @@ -9,10 +9,8 @@ import torch.distributed as dist from swift.llm import LazyLLMDataset, Template, git_clone_github -from swift.utils import ( - append_to_jsonl, get_dist_setting, get_logger, is_master, subprocess_run, -is_megatron_available, safe_ddp_context) - +from swift.utils import (append_to_jsonl, get_dist_setting, get_logger, is_master, is_megatron_available, + safe_ddp_context, subprocess_run) logger = get_logger() @@ -33,8 +31,7 @@ def _rename_files(): def init_megatron_env() -> None: if 'MEGATRON_LM_PATH' not in os.environ: - megatron_path = git_clone_github( - 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') + megatron_path = git_clone_github('https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') else: megatron_path = os.environ['MEGATRON_LM_PATH'] if not is_megatron_available(): @@ -42,8 +39,7 @@ def init_megatron_env() -> None: sys.path.append(megatron_path) if 'PAI_MEGATRON_PATCH_PATH' not in os.environ: - megatron_patch_path = git_clone_github( - 'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') + megatron_patch_path = git_clone_github('https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') else: megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] sys.path.append(megatron_patch_path) @@ -52,6 +48,7 @@ def init_megatron_env() -> None: with safe_ddp_context(): _rename_files() + def patch_megatron(tokenizer): def build_tokenizer(args): From 9b92ae037ef32596705cc4d73b75ec2f226671a0 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 11:44:51 +0800 Subject: [PATCH 04/11] update --- swift/llm/model/register.py | 10 +++++++--- tests/megatron/test_export.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) create mode 100644 tests/megatron/test_export.py diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index 5ad49dc76..1ae2dbb2d 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -307,10 +307,10 @@ def get_all_models() -> List[str]: return models -def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]: +def _get_matched_model_meta(model_id_or_path, model_mapping) -> Optional[ModelMeta]: model_name = get_model_name(model_id_or_path).lower() - for model_type, model_meta in MODEL_MAPPING.items(): - model_group = model_meta.get_matched_model_group(model_name) + for model_type, model_meta in model_mapping.items(): + model_group = ModelMeta.get_matched_model_group(model_meta, model_name) if model_group is not None: model_meta = deepcopy(model_meta) for k, v in asdict(model_group).items(): @@ -319,6 +319,10 @@ def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]: return model_meta +def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]: + return _get_matched_model_meta(model_id_or_path, MODEL_MAPPING) + + def _get_model_info(model_dir: str, 
model_type: Optional[str], quantization_config) -> ModelInfo: config_dict = PretrainedConfig.get_config_dict(model_dir)[0] if quantization_config is not None: diff --git a/tests/megatron/test_export.py b/tests/megatron/test_export.py new file mode 100644 index 000000000..b808a87f5 --- /dev/null +++ b/tests/megatron/test_export.py @@ -0,0 +1,16 @@ + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +def hf2megatron(): + from swift.llm import export_main, ExportArguments + export_main(ExportArguments(model='Qwen/Qwen2.5-7B-Instruct', to_megatron=True, + tensor_model_parallel_size=2, torch_dtype='bfloat16')) + + +def megatron2hf(): + from swift.llm import export_main, ExportArguments + + +if __name__ == '__main__': + hf2megatron() From e02c51914fa28a77d32fadd5dfcab4c30cfe7512 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 11:45:04 +0800 Subject: [PATCH 05/11] update --- tests/megatron/test_export.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/megatron/test_export.py b/tests/megatron/test_export.py index b808a87f5..d7399d8a2 100644 --- a/tests/megatron/test_export.py +++ b/tests/megatron/test_export.py @@ -1,11 +1,13 @@ - import os + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + def hf2megatron(): from swift.llm import export_main, ExportArguments - export_main(ExportArguments(model='Qwen/Qwen2.5-7B-Instruct', to_megatron=True, - tensor_model_parallel_size=2, torch_dtype='bfloat16')) + export_main( + ExportArguments( + model='Qwen/Qwen2.5-7B-Instruct', to_megatron=True, tensor_model_parallel_size=2, torch_dtype='bfloat16')) def megatron2hf(): From b9b85e5ced74cd6d9f2d1ccc30be2f120f71f355 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 11:58:03 +0800 Subject: [PATCH 06/11] update --- swift/llm/export/export.py | 2 +- swift/megatron/__init__.py | 1 + .../{convert/hf2megatron.py => convert.py} | 15 +++++++++++++++ swift/megatron/convert/__init__.py | 3 --- swift/megatron/convert/megatron2hf.py | 17 ----------------- swift/megatron/utils.py | 18 ++++++++---------- tests/megatron/test_export.py | 5 ++++- 7 files changed, 29 insertions(+), 32 deletions(-) rename swift/megatron/{convert/hf2megatron.py => convert.py} (65%) delete mode 100644 swift/megatron/convert/__init__.py delete mode 100644 swift/megatron/convert/megatron2hf.py diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index fdb50b307..42f5fcea8 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -26,7 +26,7 @@ def run(self): elif args.to_ollama: export_to_ollama(args) elif args.to_megatron: - from swift.megatron import convert_hf_to_megatron + from swift.megatron import convert_hf2megatron convert_hf_to_megatron(args) elif args.to_hf: from swift.megatron import convert_megatron_to_hf diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 9dde29f3e..418fba95b 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from .convert import convert_hf2megatron, convert_megatron2hf from .utils import init_megatron_env init_megatron_env() diff --git a/swift/megatron/convert/hf2megatron.py b/swift/megatron/convert.py similarity index 65% rename from swift/megatron/convert/hf2megatron.py rename to swift/megatron/convert.py index d4ed29c48..ac8ce307b 100644 --- a/swift/megatron/convert/hf2megatron.py +++ b/swift/megatron/convert.py @@ -25,3 +25,18 @@ def convert_hf2megatron(args: ExportArguments) -> None: if save_torch_dtype is not None: mg_model.to(save_torch_dtype) convert_module.save_mgmodel(mg_model, args) + + +def convert_megatron2hf( + hf_model, + extra_args: Dict[str, Any], +) -> None: + from megatron.training.initialize import initialize_megatron + from megatron.training import get_args + initialize_megatron(args_defaults=extra_args) + args = get_args() + + model_provider, convert_module = get_megatron_model_convert(args.model_type) + convert_module.model_provider = model_provider + mg_model = convert_module.load_megatron_model(args) # no copy + convert_module.convert_checkpoint_from_megatron_to_transformers(mg_model, hf_model, args) diff --git a/swift/megatron/convert/__init__.py b/swift/megatron/convert/__init__.py deleted file mode 100644 index c4175f177..000000000 --- a/swift/megatron/convert/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from .hf2megatron import convert_hf2megatron -from .megatron2hf import convert_megatron2hf diff --git a/swift/megatron/convert/megatron2hf.py b/swift/megatron/convert/megatron2hf.py deleted file mode 100644 index 02711e6d1..000000000 --- a/swift/megatron/convert/megatron2hf.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict - - -def convert_megatron2hf( - hf_model, - extra_args: Dict[str, Any], -) -> None: - from megatron.training.initialize import initialize_megatron - from megatron.training import get_args - initialize_megatron(args_defaults=extra_args) - args = get_args() - - model_provider, convert_module = get_megatron_model_convert(args.model_type) - convert_module.model_provider = model_provider - mg_model = convert_module.load_megatron_model(args) # no copy - convert_module.convert_checkpoint_from_megatron_to_transformers(mg_model, hf_model, args) diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py index 8c4ff82a4..af2b7364d 100644 --- a/swift/megatron/utils.py +++ b/swift/megatron/utils.py @@ -31,21 +31,19 @@ def _rename_files(): def init_megatron_env() -> None: if 'MEGATRON_LM_PATH' not in os.environ: - megatron_path = git_clone_github('https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') - else: - megatron_path = os.environ['MEGATRON_LM_PATH'] + os.environ['MEGATRON_LM_PATH'] = git_clone_github( + 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') if not is_megatron_available(): - subprocess_run(['pip', 'install', '-e', megatron_path]) - sys.path.append(megatron_path) + subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) + sys.path.append(os.environ['MEGATRON_LM_PATH']) if 'PAI_MEGATRON_PATCH_PATH' not in os.environ: - megatron_patch_path = git_clone_github('https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') - else: - megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] - sys.path.append(megatron_patch_path) + os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github( + 'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') + 
sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH']) # rename qwen1.5/2.5->qwen1_5/2_5 files - with safe_ddp_context(): + with safe_ddp_context('rename_files'): _rename_files() diff --git a/tests/megatron/test_export.py b/tests/megatron/test_export.py index d7399d8a2..5492e4fe6 100644 --- a/tests/megatron/test_export.py +++ b/tests/megatron/test_export.py @@ -7,7 +7,10 @@ def hf2megatron(): from swift.llm import export_main, ExportArguments export_main( ExportArguments( - model='Qwen/Qwen2.5-7B-Instruct', to_megatron=True, tensor_model_parallel_size=2, torch_dtype='bfloat16')) + model='Qwen/Qwen2.5-7B-Instruct', + to_megatron=True, + target_tensor_model_parallel_size=2, + torch_dtype='bfloat16')) def megatron2hf(): From 65fcd636852e494c1000c4cb5e0c56e1264aa701 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 13:00:08 +0800 Subject: [PATCH 07/11] update --- swift/llm/export/export.py | 6 +++--- swift/megatron/convert.py | 6 ++++-- swift/megatron/model/__init__.py | 1 + swift/megatron/model/qwen.py | 4 ++++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index 42f5fcea8..3ef4b0ff6 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -27,10 +27,10 @@ def run(self): export_to_ollama(args) elif args.to_megatron: from swift.megatron import convert_hf2megatron - convert_hf_to_megatron(args) + convert_hf2megatron(args) elif args.to_hf: - from swift.megatron import convert_megatron_to_hf - convert_megatron_to_hf(args) + from swift.megatron import convert_megatron2hf + convert_megatron2hf(args) elif args.push_to_hub: model_dir = args.adapters and args.adapters[0] or args.model_dir assert model_dir, f'model_dir: {model_dir}' diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index ac8ce307b..78469b5ce 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -1,9 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + import torch from swift.llm import ExportArguments, get_model_tokenizer -from ..model import get_megatron_model_meta +from .model import get_megatron_model_meta def convert_hf2megatron(args: ExportArguments) -> None: @@ -14,7 +16,7 @@ def convert_hf2megatron(args: ExportArguments) -> None: kwargs['torch_dtype'] = torch.float32 hf_model, processor = get_model_tokenizer(**kwargs) megatron_model_meta = get_megatron_model_meta(args.model) - megatron_model_meta.get_model_provider() + model_provider = megatron_model_meta.get_model_provider() megatron_model_meta.load_config(hf_model.model_info) initialize_megatron(args_defaults=extra_args) diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index ca15b80b5..371955130 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model +from . 
import qwen diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index c4ff24f30..71fed3065 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -21,6 +21,10 @@ def convert_hf2megatron(): pass +def get_qwen_model_provider(): + pass + + register_megatron_model( MegatronModelMeta(MegatronModelType.qwen, [ ModelGroup([ From bd7547caae761d2c907083a006baf8c2175a2eae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 13:35:13 +0800 Subject: [PATCH 08/11] update --- swift/megatron/convert.py | 7 ++----- swift/megatron/model/__init__.py | 2 +- swift/megatron/model/qwen.py | 18 ++++++++++++++---- swift/megatron/model/utils.py | 6 ++++-- swift/megatron/utils.py | 14 ++++---------- 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 78469b5ce..767484e6f 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -3,15 +3,14 @@ from typing import Any, Dict import torch +from megatron.training import get_args +from megatron.training.initialize import initialize_megatron from swift.llm import ExportArguments, get_model_tokenizer from .model import get_megatron_model_meta def convert_hf2megatron(args: ExportArguments) -> None: - - from megatron.training.initialize import initialize_megatron - from megatron.training import get_args kwargs = args.get_model_kwargs() kwargs['torch_dtype'] = torch.float32 hf_model, processor = get_model_tokenizer(**kwargs) @@ -33,8 +32,6 @@ def convert_megatron2hf( hf_model, extra_args: Dict[str, Any], ) -> None: - from megatron.training.initialize import initialize_megatron - from megatron.training import get_args initialize_megatron(args_defaults=extra_args) args = get_args() diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 371955130..6fda16d80 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model from . import qwen +from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index 71fed3065..0df3277c2 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -1,5 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import importlib + +from megatron.training import get_args + from swift.llm import Model, ModelGroup, ModelInfo from .config import load_config from .constant import MegatronModelType @@ -13,16 +17,22 @@ def load_qwen_config(model_info: ModelInfo): return args_config -def convert_megatron2hf(): +def convert_megatron2hf(hf_model, mg_model): pass -def convert_hf2megatron(): - pass +def convert_hf2megatron(hf_model, mg_model): + args = get_args() def get_qwen_model_provider(): - pass + module_prefix = 'megatron_patch.model.qwen2' + gpt_model_cls = importlib.import_module(f'{module_prefix}.model').GPTModel + transformer_config_cls = getattr( + importlib.import_module(f'{module_prefix}.transformer_config'), 'Qwen2TransformerConfig') + layer_spec_module = importlib.import_module(f'{module_prefix}.layer_specs') + model_provider = get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module) + return model_provider register_megatron_model( diff --git a/swift/megatron/model/utils.py b/swift/megatron/model/utils.py index 6232ecb3f..7dd4d4a4a 100644 --- a/swift/megatron/model/utils.py +++ b/swift/megatron/model/utils.py @@ -1,11 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from megatron.training import get_args +from megatron.training.arguments import core_transformer_config_from_args + def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module): def model_provider(pre_process=True, post_process=True): - from megatron.training import get_args - from megatron.training.arguments import core_transformer_config_from_args + args = get_args() config = core_transformer_config_from_args(args, transformer_config_cls) transformer_layer_spec = layer_spec_module.get_gpt_layer_with_transformer_engine_spec( diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py index af2b7364d..0d99f1951 100644 --- a/swift/megatron/utils.py +++ b/swift/megatron/utils.py @@ -7,6 +7,10 @@ import torch import torch.distributed as dist +from megatron.core import mpu +from megatron.training import get_args, global_vars, initialize, training +from megatron.training.utils import (average_losses_across_data_parallel_group, get_batch_on_this_cp_rank, + get_ltor_masks_and_position_ids) from swift.llm import LazyLLMDataset, Template, git_clone_github from swift.utils import (append_to_jsonl, get_dist_setting, get_logger, is_master, is_megatron_available, @@ -53,7 +57,6 @@ def build_tokenizer(args): args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size return tokenizer - from megatron.training import get_args, training, initialize, global_vars global_vars.build_tokenizer = build_tokenizer _old_initialize_distributed = initialize._initialize_distributed @@ -114,9 +117,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): loss_mask (torch.Tensor): Used to mask out some portions of the loss output_tensor (torch.Tensor): The tensor with the losses """ - from megatron.training import get_args - from megatron.core import mpu - from megatron.training.utils import average_losses_across_data_parallel_group args = get_args() losses = output_tensor.float() @@ -142,8 +142,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): def get_batch_on_this_tp_rank(data_iterator): # copy from Megatron-LM and made some changes. 
- from megatron.training import get_args - from megatron.core import mpu args = get_args() def _broadcast(item): @@ -241,7 +239,6 @@ def _broadcast(item): def forward_step(data_iterator, model): - from megatron.training.utils import get_batch_on_this_cp_rank batch = get_batch_on_this_tp_rank(data_iterator) batch = get_batch_on_this_cp_rank(batch) tokens, labels, loss_mask, attention_mask, position_ids = batch.values() @@ -252,9 +249,6 @@ def forward_step(data_iterator, model): def train_valid_test_datasets_provider(train_val_test_num_samples, train_dataset: LazyLLMDataset, val_dataset: LazyLLMDataset, template: Template): # train_val_test_num_samples: ignored - from megatron.training import training - from megatron.training.utils import get_ltor_masks_and_position_ids - assert not hasattr(training, '_old_build_pretraining_data_loader') _old_build_pretraining_data_loader = training.build_pretraining_data_loader From 83dc3346c09e279a104a07eb035a8d11f15e680c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 14:07:47 +0800 Subject: [PATCH 09/11] update --- swift/megatron/convert.py | 15 ++++++--------- swift/megatron/model/qwen.py | 5 ++++- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 767484e6f..a33dc0600 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -7,6 +7,7 @@ from megatron.training.initialize import initialize_megatron from swift.llm import ExportArguments, get_model_tokenizer +from .argument import MegatronArguments from .model import get_megatron_model_meta @@ -15,17 +16,13 @@ def convert_hf2megatron(args: ExportArguments) -> None: kwargs['torch_dtype'] = torch.float32 hf_model, processor = get_model_tokenizer(**kwargs) megatron_model_meta = get_megatron_model_meta(args.model) - model_provider = megatron_model_meta.get_model_provider() - megatron_model_meta.load_config(hf_model.model_info) + mg_model = megatron_model_meta.get_model_provider()() + kwargs = megatron_model_meta.load_config(hf_model.model_info) + megatron_args = MegatronArguments(kwargs) + extra_args = megatron_args.parse_to_megatron() initialize_megatron(args_defaults=extra_args) - args = get_args() - model_provider, convert_module = get_megatron_model_convert(args.model_type) - mg_model = model_provider() - convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args) - if save_torch_dtype is not None: - mg_model.to(save_torch_dtype) - convert_module.save_mgmodel(mg_model, args) + megatron_model_meta.convert_hf2megatron(hf_model, mg_model) def convert_megatron2hf( diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index 0df3277c2..d96a70945 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import importlib +from typing import Any, Dict from megatron.training import get_args @@ -11,7 +12,7 @@ from .utils import get_model_provider -def load_qwen_config(model_info: ModelInfo): +def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]: args_config = load_config(model_info) args_config['swiglu'] = True return args_config @@ -23,6 +24,8 @@ def convert_megatron2hf(hf_model, mg_model): def convert_hf2megatron(hf_model, mg_model): args = get_args() + mg_model.to(args.torch_dtype) + convert_module.save_mgmodel(mg_model, args) def get_qwen_model_provider(): From 9a8c458c5e7c57d272cf8f7f88f6e33c23fabb80 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 14:17:12 +0800 Subject: [PATCH 10/11] update --- swift/megatron/__init__.py | 10 +++++++--- swift/megatron/init.py | 38 ++++++++++++++++++++++++++++++++++++++ swift/megatron/utils.py | 32 -------------------------------- 3 files changed, 45 insertions(+), 35 deletions(-) create mode 100644 swift/megatron/init.py diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 418fba95b..0a75b056d 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .convert import convert_hf2megatron, convert_megatron2hf -from .utils import init_megatron_env +try: + from .init import init_megatron_env + init_megatron_env() +except Exception: + # allows lint pass. + raise -init_megatron_env() +from .convert import convert_hf2megatron, convert_megatron2hf diff --git a/swift/megatron/init.py b/swift/megatron/init.py new file mode 100644 index 000000000..4291ce62a --- /dev/null +++ b/swift/megatron/init.py @@ -0,0 +1,38 @@ +import os +import shutil +import sys + +from swift.llm import git_clone_github +from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run + + +def _rename_files(): + megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] + qwen_folders = ['toolkits/model_checkpoints_convertor/qwen'] + for folder in qwen_folders: + dir_path = os.path.join(megatron_patch_path, folder) + for fname in os.listdir(dir_path): + old_path = os.path.join(dir_path, fname) + fname = fname.replace('qwen1.', 'qwen1_') + fname = fname.replace('qwen2.', 'qwen2_') + new_path = os.path.join(dir_path, fname) + if old_path != new_path and os.path.exists(old_path): + shutil.move(old_path, new_path) + + +def init_megatron_env() -> None: + if 'MEGATRON_LM_PATH' not in os.environ: + os.environ['MEGATRON_LM_PATH'] = git_clone_github( + 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') + if not is_megatron_available(): + subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) + sys.path.append(os.environ['MEGATRON_LM_PATH']) + + if 'PAI_MEGATRON_PATCH_PATH' not in os.environ: + os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github( + 'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') + sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH']) + + # rename qwen1.5/2.5->qwen1_5/2_5 files + with safe_ddp_context('rename_files'): + _rename_files() diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py index 0d99f1951..6568974a4 100644 --- a/swift/megatron/utils.py +++ b/swift/megatron/utils.py @@ -19,38 +19,6 @@ logger = get_logger() -def _rename_files(): - megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH'] - qwen_folders = ['toolkits/model_checkpoints_convertor/qwen'] - for folder in qwen_folders: - dir_path = os.path.join(megatron_patch_path, folder) - for fname in 
os.listdir(dir_path): - old_path = os.path.join(dir_path, fname) - fname = fname.replace('qwen1.', 'qwen1_') - fname = fname.replace('qwen2.', 'qwen2_') - new_path = os.path.join(dir_path, fname) - if old_path != new_path and os.path.exists(old_path): - shutil.move(old_path, new_path) - - -def init_megatron_env() -> None: - if 'MEGATRON_LM_PATH' not in os.environ: - os.environ['MEGATRON_LM_PATH'] = git_clone_github( - 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0') - if not is_megatron_available(): - subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']]) - sys.path.append(os.environ['MEGATRON_LM_PATH']) - - if 'PAI_MEGATRON_PATCH_PATH' not in os.environ: - os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github( - 'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1') - sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH']) - - # rename qwen1.5/2.5->qwen1_5/2_5 files - with safe_ddp_context('rename_files'): - _rename_files() - - def patch_megatron(tokenizer): def build_tokenizer(args): From 836fbcfd6d24fd93769db62801001c9a055f70df Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Jan 2025 16:44:54 +0800 Subject: [PATCH 11/11] update --- swift/megatron/argument.py | 11 ++++++++++- swift/megatron/convert.py | 8 +++++--- swift/megatron/model/qwen.py | 2 +- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/swift/megatron/argument.py b/swift/megatron/argument.py index 891cd3fc3..d98248e76 100644 --- a/swift/megatron/argument.py +++ b/swift/megatron/argument.py @@ -2,7 +2,7 @@ import sys from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Literal, Optional, Tuple - +import inspect @dataclass class ExtraMegatronArguments: @@ -107,6 +107,15 @@ def __post_init__(self): if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None: self.lr_decay_iters = self.train_iters - self.lr_warmup_iters + def get_matched_kwargs(args): + args_dict = asdict(args) + parameters = inspect.signature(MegatronArguments.__init__).parameters + + for k in list(args_dict.keys()): + if k not in parameters: + args_dict.pop(k) + return args_dict + def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]: new_args = [] args_dict = asdict(self) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index a33dc0600..a48b5b7d3 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -16,12 +16,14 @@ def convert_hf2megatron(args: ExportArguments) -> None: kwargs['torch_dtype'] = torch.float32 hf_model, processor = get_model_tokenizer(**kwargs) megatron_model_meta = get_megatron_model_meta(args.model) - mg_model = megatron_model_meta.get_model_provider()() kwargs = megatron_model_meta.load_config(hf_model.model_info) - megatron_args = MegatronArguments(kwargs) + megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args)) extra_args = megatron_args.parse_to_megatron() - initialize_megatron(args_defaults=extra_args) + + mg_model = megatron_model_meta.get_model_provider()() + + megatron_model_meta.convert_hf2megatron(hf_model, mg_model) diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py index d96a70945..c78a7d034 100644 --- a/swift/megatron/model/qwen.py +++ b/swift/megatron/model/qwen.py @@ -49,4 +49,4 @@ def get_qwen_model_provider(): Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'), Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'), ]), - ], convert_megatron2hf, 
convert_hf2megatron, get_model_provider, load_qwen_config)) + ], convert_megatron2hf, convert_hf2megatron, get_qwen_model_provider, load_qwen_config))
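
Note: the snippet below is a minimal, self-contained sketch of the kwargs-filtering pattern that PATCH 11 introduces via `MegatronArguments.get_matched_kwargs` and then consumes in `convert_hf2megatron` (`MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))`). The `SourceArgs`/`TargetArgs` dataclasses here are hypothetical stand-ins for `ExportArguments`/`MegatronArguments`, not swift classes, and the dict comprehension is a condensed equivalent of the loop in the patch.

# Sketch of the get_matched_kwargs pattern (illustrative names, not swift code).
import inspect
from dataclasses import asdict, dataclass


@dataclass
class SourceArgs:            # stand-in for ExportArguments
    model: str = 'some-model'
    torch_dtype: str = 'float32'
    quant_bits: int = 0      # not accepted by TargetArgs; should be dropped


@dataclass
class TargetArgs:            # stand-in for MegatronArguments
    model: str = ''
    torch_dtype: str = 'float32'

    @staticmethod
    def get_matched_kwargs(args) -> dict:
        # Keep only the fields of `args` that TargetArgs.__init__ accepts,
        # so the two argument dataclasses can evolve independently.
        args_dict = asdict(args)
        parameters = inspect.signature(TargetArgs.__init__).parameters
        return {k: v for k, v in args_dict.items() if k in parameters}


matched = TargetArgs.get_matched_kwargs(SourceArgs())
target = TargetArgs(**matched)   # quant_bits is silently ignored
print(target)

The point of the pattern is that the export-side argument object can be passed wholesale; anything Megatron does not understand is filtered out before `initialize_megatron(args_defaults=extra_args)` is reached.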