diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index cc36b6450..9862c7842 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -4,6 +4,7 @@ from typing import Literal, Optional
 
 import torch
+import torch.distributed as dist
 
 from swift.utils import get_logger
 from .base_args import BaseArguments, to_abspath
@@ -44,6 +45,12 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_ollama: bool = False
     gguf_file: Optional[str] = None
 
+    # megatron
+    to_megatron: bool = False
+    to_hf: bool = False
+    target_tensor_model_parallel_size: int = 1
+    target_pipeline_model_parallel_size: int = 1
+
     # push to ms hub
     push_to_hub: bool = False
     # 'user_name/repo_name' or 'repo_name'
@@ -65,6 +72,10 @@ def _init_output_dir(self):
             suffix = f'{self.quant_method}-int{self.quant_bits}'
         elif self.to_ollama:
             suffix = 'ollama'
+        elif self.to_megatron:
+            suffix = f'tp{self.target_tensor_model_parallel_size}-pp{self.target_pipeline_model_parallel_size}'
+        elif self.to_hf:
+            suffix = 'hf'
         else:
             return
@@ -81,6 +92,14 @@ def __post_init__(self):
             raise ValueError('Please specify `--quant_bits`.')
         if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
             self.torch_dtype = torch.float16
+        if self.to_megatron or self.to_hf:
+            os.environ['RANK'] = '0'
+            os.environ['LOCAL_RANK'] = '0'
+            os.environ['WORLD_SIZE'] = '1'
+            os.environ['LOCAL_WORLD_SIZE'] = '1'
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
+            dist.init_process_group(backend='nccl')
         BaseArguments.__post_init__(self)
         self._init_output_dir()
diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py
index 1adb901de..3ef4b0ff6 100644
--- a/swift/llm/export/export.py
+++ b/swift/llm/export/export.py
@@ -25,6 +25,12 @@ def run(self):
             quantize_model(args)
         elif args.to_ollama:
             export_to_ollama(args)
+        elif args.to_megatron:
+            from swift.megatron import convert_hf2megatron
+            convert_hf2megatron(args)
+        elif args.to_hf:
+            from swift.megatron import convert_megatron2hf
+            convert_megatron2hf(args)
         elif args.push_to_hub:
             model_dir = args.adapters and args.adapters[0] or args.model_dir
             assert model_dir, f'model_dir: {model_dir}'
diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py
index 5ad49dc76..1ae2dbb2d 100644
--- a/swift/llm/model/register.py
+++ b/swift/llm/model/register.py
@@ -307,10 +307,10 @@ def get_all_models() -> List[str]:
     return models
 
 
-def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
+def _get_matched_model_meta(model_id_or_path, model_mapping) -> Optional[ModelMeta]:
     model_name = get_model_name(model_id_or_path).lower()
-    for model_type, model_meta in MODEL_MAPPING.items():
-        model_group = model_meta.get_matched_model_group(model_name)
+    for model_type, model_meta in model_mapping.items():
+        model_group = ModelMeta.get_matched_model_group(model_meta, model_name)
         if model_group is not None:
             model_meta = deepcopy(model_meta)
             for k, v in asdict(model_group).items():
@@ -319,6 +319,10 @@ def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
         return model_meta
 
 
+def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
+    return _get_matched_model_meta(model_id_or_path, MODEL_MAPPING)
+
+
 def _get_model_info(model_dir: str, model_type: Optional[str], quantization_config) -> ModelInfo:
     config_dict = PretrainedConfig.get_config_dict(model_dir)[0]
     if quantization_config is not None:
diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py
new file mode 100644
index 000000000..0a75b056d
--- /dev/null
+++ b/swift/megatron/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+try:
+    from .init import init_megatron_env
+    init_megatron_env()
+except Exception:
+    # allows lint to pass.
+    raise
+
+from .convert import convert_hf2megatron, convert_megatron2hf
diff --git a/swift/megatron/argument.py b/swift/megatron/argument.py
new file mode 100644
index 000000000..d98248e76
--- /dev/null
+++ b/swift/megatron/argument.py
@@ -0,0 +1,142 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
+import sys
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+
+@dataclass
+class ExtraMegatronArguments:
+    padded_vocab_size: Optional[int] = None
+
+    target_tensor_model_parallel_size: int = 1
+    target_pipeline_model_parallel_size: int = 1
+
+
+@dataclass
+class MegatronMixin:
+    num_layers: Optional[int] = None
+    hidden_size: Optional[int] = None
+    ffn_hidden_size: Optional[int] = None
+    num_attention_heads: Optional[int] = None
+    num_query_groups: Optional[int] = None
+    group_query_attention: Optional[bool] = None
+    max_position_embeddings: Optional[int] = None
+    norm_epsilon: Optional[float] = None
+    swiglu: Optional[bool] = None
+    rotary_base: Optional[int] = None
+    disable_bias_linear: bool = True
+    add_qkv_bias: bool = True
+
+    train_iters: Optional[int] = None
+    lr_warmup_iters: Optional[int] = None
+    eval_iters: Optional[int] = None
+    lr_decay_iters: Optional[int] = None
+    save: Optional[str] = None
+    load: Optional[str] = None
+    tensorboard_dir: Optional[str] = None
+    log_interval: int = 10
+    log_throughput: bool = False
+    eval_interval: Optional[int] = None
+    save_interval: int = 500
+
+    position_embedding_type: str = 'rope'
+    rotary_percent: float = 1.
+    rotary_seq_len_interpolation_factor: int = 1
+    no_bias_swiglu_fusion: bool = False
+    attention_dropout: float = 0.
+    hidden_dropout: float = 0.
+
+    optimizer: str = 'adam'
+    weight_decay: float = 0.1
+    clip_grad: float = 1.
+    adam_beta1: float = 0.9
+    adam_beta2: float = 0.95
+    adam_eps: float = 1e-8
+    init_method_std: float = 0.01
+    micro_batch_size: int = 1
+    global_batch_size: int = 16
+    recompute_method: Optional[str] = None
+    recompute_granularity: Optional[str] = 'selective'
+    no_rope_fusion: bool = False
+    use_flash_attn: bool = False
+    use_cpu_initialization: Optional[bool] = None
+
+    dataloader_type: str = 'cyclic'
+    lr: float = 1e-5
+    lr_decay_style: str = 'cosine'
+    min_lr: float = 1e-6
+    fp16: bool = False
+    bf16: bool = False
+    tensor_model_parallel_size: int = 1
+    pipeline_model_parallel_size: int = 1
+    context_parallel_size: int = 1
+    seed: int = 42
+    sequence_parallel: bool = False
+    transformer_impl: str = 'transformer_engine'
+
+    apply_query_key_layer_scaling: bool = False  # enable when training with fp16
+    num_workers: int = 8
+
+    log_timers_to_tensorboard: bool = True
+    log_batch_size_to_tensorboard: bool = True
+    log_validation_ppl_to_tensorboard: bool = True
+    log_memory_to_tensorboard: bool = True
+    tensorboard_log_interval: int = 1
+    tensorboard_queue_size: int = 10
+    untie_embeddings_and_output_weights: bool = True
+    seq_length: Optional[int] = None
+
+    no_save_optim: bool = False
+    no_save_rng: bool = False
+    no_load_optim: bool = False
+    no_load_rng: bool = False
+    loss_scale: Optional[float] = None
+    use_distributed_optimizer: bool = True
+    normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
+    calculate_per_token_loss: bool = True
+
+
+@dataclass
+class MegatronArguments(ExtraMegatronArguments, MegatronMixin):
+
+    def __post_init__(self):
+        if self.group_query_attention is None:
+            self.group_query_attention = self.num_query_groups > 1
+        if self.eval_interval is None:
+            self.eval_interval = self.save_interval
+        if self.lr_decay_iters is None and self.train_iters is not None and self.lr_warmup_iters is not None:
+            self.lr_decay_iters = self.train_iters - self.lr_warmup_iters
+
+    @staticmethod
+    def get_matched_kwargs(args) -> Dict[str, Any]:
+        args_dict = asdict(args)
+        parameters = inspect.signature(MegatronArguments.__init__).parameters
+
+        for k in list(args_dict.keys()):
+            if k not in parameters:
+                args_dict.pop(k)
+        return args_dict
+
+    def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
+        new_args = []
+        args_dict = asdict(self)
+        extra_args = {}
+        for k, value in args_dict.items():
+            if k in ExtraMegatronArguments.__annotations__:
+                extra_args[k] = value
+                continue
+            if value is None or value is False:
+                continue
+            new_args.append(f"--{k.replace('_', '-')}")
+            if isinstance(value, list):
+                new_args += [str(v) for v in value]
+            elif value is not True:
+                new_args.append(str(value))
+
+        return new_args, extra_args
+
+    def parse_to_megatron(self) -> Dict[str, Any]:
+        new_args, extra_args = self._args_to_argv()
+        sys._old_argv = sys.argv
+        sys.argv = sys.argv[:1] + new_args
+
+        return extra_args
diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py
new file mode 100644
index 000000000..a48b5b7d3
--- /dev/null
+++ b/swift/megatron/convert.py
@@ -0,0 +1,40 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+from megatron.training.initialize import initialize_megatron
+
+from swift.llm import ExportArguments, get_model_tokenizer
+from .argument import MegatronArguments
+from .model import get_megatron_model_meta
+
+
+def convert_hf2megatron(args: ExportArguments) -> None:
+    kwargs = args.get_model_kwargs()
+    kwargs['torch_dtype'] = torch.float32
+    hf_model, processor = get_model_tokenizer(**kwargs)
+    megatron_model_meta = get_megatron_model_meta(args.model)
+    kwargs = megatron_model_meta.load_config(hf_model.model_info)
+    megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))
+    extra_args = megatron_args.parse_to_megatron()
+    initialize_megatron(args_defaults=extra_args)
+
+    mg_model = megatron_model_meta.get_model_provider()()
+    megatron_model_meta.convert_hf2megatron(hf_model, mg_model)
+
+
+def convert_megatron2hf(args: ExportArguments) -> None:
+    # mirrors convert_hf2megatron: build the HF model and the megatron model, then delegate
+    # the weight conversion to the registered MegatronModelMeta.
+    kwargs = args.get_model_kwargs()
+    kwargs['torch_dtype'] = torch.float32
+    hf_model, processor = get_model_tokenizer(**kwargs)
+    megatron_model_meta = get_megatron_model_meta(args.model)
+    kwargs = megatron_model_meta.load_config(hf_model.model_info)
+    megatron_args = MegatronArguments(**kwargs, **MegatronArguments.get_matched_kwargs(args))
+    extra_args = megatron_args.parse_to_megatron()
+    initialize_megatron(args_defaults=extra_args)
+
+    mg_model = megatron_model_meta.get_model_provider()()
+    megatron_model_meta.convert_megatron2hf(hf_model, mg_model)
diff --git a/swift/megatron/init.py b/swift/megatron/init.py
new file mode 100644
index 000000000..4291ce62a
--- /dev/null
+++ b/swift/megatron/init.py
@@ -0,0 +1,38 @@
+import os
+import shutil
+import sys
+
+from swift.llm import git_clone_github
+from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run
+
+
+def _rename_files():
+    # Rename files such as 'qwen2.5*.py' -> 'qwen2_5*.py' so that they are importable as Python modules.
+    megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
+    qwen_folders = ['toolkits/model_checkpoints_convertor/qwen']
+    for folder in qwen_folders:
+        dir_path = os.path.join(megatron_patch_path, folder)
+        for fname in os.listdir(dir_path):
+            old_path = os.path.join(dir_path, fname)
+            fname = fname.replace('qwen1.', 'qwen1_')
+            fname = fname.replace('qwen2.', 'qwen2_')
+            new_path = os.path.join(dir_path, fname)
+            if old_path != new_path and os.path.exists(old_path):
+                shutil.move(old_path, new_path)
+
+
+def init_megatron_env() -> None:
+    if 'MEGATRON_LM_PATH' not in os.environ:
+        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
+            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
+    if not is_megatron_available():
+        subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']])
+    sys.path.append(os.environ['MEGATRON_LM_PATH'])
+
+    if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
+        os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
+            'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
+    sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])
+
+    # rename qwen1.5/2.5 -> qwen1_5/2_5 files
+    with safe_ddp_context('rename_files'):
+        _rename_files()
diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py
new file mode 100644
index 000000000..6fda16d80
--- /dev/null
+++ b/swift/megatron/model/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from . import qwen
+from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py
new file mode 100644
index 000000000..83ae73153
--- /dev/null
+++ b/swift/megatron/model/config.py
@@ -0,0 +1,29 @@
+from typing import Any, Dict
+
+from swift.llm import ModelInfo
+
+# mapping: megatron argument name -> candidate attribute names in the HF config
+config_mapping = {
+    'num_layers': ['num_hidden_layers'],
+    'hidden_size': ['hidden_size'],
+    'ffn_hidden_size': ['intermediate_size'],
+    'num_attention_heads': ['num_attention_heads'],
+    'num_query_groups': ['num_key_value_heads'],
+    'max_position_embeddings': ['max_position_embeddings'],
+    'norm_epsilon': ['rms_norm_eps'],
+    'rotary_base': ['rope_theta'],
+    'padded_vocab_size': ['vocab_size'],
+    'attention_dropout': ['attention_dropout']
+}
+
+
+def load_config(model_info: ModelInfo) -> Dict[str, Any]:
+    model_config = model_info.config
+    megatron_config = {}
+    for k, value in config_mapping.items():
+        for v in value:
+            assert hasattr(model_config, v)
+            if k == 'rotary_base':
+                megatron_config[k] = int(getattr(model_config, v))
+            else:
+                megatron_config[k] = getattr(model_config, v)
+    return megatron_config
diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py
new file mode 100644
index 000000000..929eae691
--- /dev/null
+++ b/swift/megatron/model/constant.py
@@ -0,0 +1,2 @@
+class MegatronModelType:
+    qwen = 'qwen'
diff --git a/swift/megatron/model/qwen.py b/swift/megatron/model/qwen.py
new file mode 100644
index 000000000..c78a7d034
--- /dev/null
+++ b/swift/megatron/model/qwen.py
@@ -0,0 +1,52 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import importlib
+from typing import Any, Dict
+
+from megatron.training import get_args
+
+from swift.llm import Model, ModelGroup, ModelInfo
+from .config import load_config
+from .constant import MegatronModelType
+from .register import MegatronModelMeta, register_megatron_model
+from .utils import get_model_provider
+
+
+def load_qwen_config(model_info: ModelInfo) -> Dict[str, Any]:
+    args_config = load_config(model_info)
+    args_config['swiglu'] = True
+    return args_config
+
+
+def convert_megatron2hf(hf_model, mg_model):
+    # TODO: not implemented yet.
+    pass
+
+
+def convert_hf2megatron(hf_model, mg_model):
+    # NOTE: `convert_module` (the Pai-Megatron-Patch qwen checkpoint convertor providing
+    # `save_mgmodel`) is not imported here yet.
+    args = get_args()
+    mg_model.to(args.params_dtype)
+    convert_module.save_mgmodel(mg_model, args)
+
+
+def get_qwen_model_provider():
+    module_prefix = 'megatron_patch.model.qwen2'
+    gpt_model_cls = importlib.import_module(f'{module_prefix}.model').GPTModel
+    transformer_config_cls = getattr(
+        importlib.import_module(f'{module_prefix}.transformer_config'), 'Qwen2TransformerConfig')
+    layer_spec_module = importlib.import_module(f'{module_prefix}.layer_specs')
+    model_provider = get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module)
+    return model_provider
+
+
+register_megatron_model(
+    MegatronModelMeta(MegatronModelType.qwen, [
+        ModelGroup([
+            Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
+            Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
+            Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
+            Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
+            Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
+            Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
+            Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
+        ]),
+    ], convert_megatron2hf, convert_hf2megatron, get_qwen_model_provider, load_qwen_config))
diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py
new file mode 100644
index 000000000..a98ad3fa8
--- /dev/null
+++ b/swift/megatron/model/register.py
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from swift.llm import ModelGroup
+from swift.llm.model.register import _get_matched_model_meta
+
+MEGATRON_MODEL_MAPPING = {}
+
+
+@dataclass
+class MegatronModelMeta:
+    megatron_model_type: Optional[str]
+    model_groups: List[ModelGroup]
+
+    convert_megatron2hf: Callable
+    convert_hf2megatron: Callable
+    get_model_provider: Callable
+    load_config: Callable
+
+
+def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = False):
+    megatron_model_type = model_meta.megatron_model_type
+    if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
+        raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.')
+
+    MEGATRON_MODEL_MAPPING[megatron_model_type] = model_meta
+
+
+def get_megatron_model_meta(model_id_or_path: str) -> Optional[MegatronModelMeta]:
+    return _get_matched_model_meta(model_id_or_path, MEGATRON_MODEL_MAPPING)
diff --git a/swift/megatron/model/utils.py b/swift/megatron/model/utils.py
new file mode 100644
index 000000000..7dd4d4a4a
--- /dev/null
+++ b/swift/megatron/model/utils.py
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from megatron.training import get_args
+from megatron.training.arguments import core_transformer_config_from_args
+
+
+def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module):
+
+    def model_provider(pre_process=True, post_process=True):
+
+        args = get_args()
+        config = core_transformer_config_from_args(args, transformer_config_cls)
+        transformer_layer_spec = layer_spec_module.get_gpt_layer_with_transformer_engine_spec(
+            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
+        model = gpt_model_cls(
+            config=config,
+            transformer_layer_spec=transformer_layer_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=True,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base,
+            seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
+        return model
+
+    return model_provider
diff --git a/swift/megatron/utils.py b/swift/megatron/utils.py
new file mode 100644
index 000000000..6568974a4
--- /dev/null
+++ b/swift/megatron/utils.py
@@ -0,0 +1,247 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from functools import partial, wraps
+from typing import Any, Dict, List, Mapping, Optional
+
+import torch
+import torch.distributed as dist
+from megatron.core import mpu
+from megatron.training import get_args, global_vars, initialize, training
+from megatron.training.utils import (average_losses_across_data_parallel_group, get_batch_on_this_cp_rank,
+                                     get_ltor_masks_and_position_ids)
+
+from swift.llm import LazyLLMDataset, Template
+from swift.utils import append_to_jsonl, get_dist_setting, get_logger, is_master
+
+logger = get_logger()
+
+
+def patch_megatron(tokenizer):
+
+    def build_tokenizer(args):
+        args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
+        return tokenizer
+
+    global_vars.build_tokenizer = build_tokenizer
+
+    _old_initialize_distributed = initialize._initialize_distributed
+
+    @wraps(_old_initialize_distributed)
+    def _initialize_distributed(*_args, **kwargs):
+        # reuse the process group already created by swift instead of letting megatron create a new one.
+        args = get_args()
+        if dist.is_initialized():
+            args.rank, args.local_rank, args.world_size, args.local_world_size = get_dist_setting()
+            torch.cuda.set_device(args.local_rank)
+        return _old_initialize_distributed(*_args, **kwargs)
+
+    initialize._initialize_distributed = _initialize_distributed
+
+    _old_load_state_dict = torch.nn.Module.load_state_dict
+
+    @wraps(_old_load_state_dict)
+    def _load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, *args, **kwargs):
+        # if the only mismatched keys are TransformerEngine `_extra_state` buffers, fall back to non-strict loading.
+        if strict:
+            keys = self.state_dict().keys() ^ state_dict.keys()
+            new_keys = [k for k in keys if not k.endswith('_extra_state')]
+            if keys and not new_keys:
+                strict = False
+        return _old_load_state_dict(self, state_dict, strict, *args, **kwargs)
+
+    torch.nn.Module.load_state_dict = _load_state_dict
+
+    _old_training_log = training.training_log
+
+    @wraps(_old_training_log)
+    def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
+                     report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
+        # additionally write the training metrics to `{save}/logging.jsonl` on the master rank.
+        args = get_args()
+        if is_master() and iteration % args.log_interval == 0:
+            logging_path = os.path.join(args.save, 'logging.jsonl')
+            logs = {}
+            for k, v in loss_dict.items():
+                if isinstance(v, torch.Tensor):
+                    v = v.item()
+                logs[k] = round(v, 8)
+            logs['grad_norm'] = round(grad_norm, 8)
+            logs['learning_rate'] = round(learning_rate, 8)
+            logs['consumed_samples'] = args.consumed_train_samples
+            logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
+
+            append_to_jsonl(logging_path, logs)
+        return _old_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
+                                 loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
+                                 num_zeros_in_grad, *_args, **kwargs)
+
+    training.training_log = training_log
+
+
+def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+    """Loss function. Copied from Pai-Megatron-Patch.
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+    """
+    args = get_args()
+
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    if args.context_parallel_size > 1:
+        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
+        dist.all_reduce(loss, group=mpu.get_context_parallel_group())
+        loss = loss[0] / loss[1]
+    else:
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+    # Check individual rank losses are not NaN prior to DP all-reduce.
+    if args.check_for_nan_in_loss_and_grad:
+        global_rank = dist.get_rank()
+        assert not loss.isnan(), (f'Rank {global_rank}: found NaN in local forward loss calculation. '
+                                  f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}')
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss * args.context_parallel_size, {'loss': averaged_loss[0]}
+
+
+def get_batch_on_this_tp_rank(data_iterator):
+    # copy from Megatron-LM and made some changes.
+    args = get_args()
+
+    def _broadcast(item):
+        if item is not None:
+            dist.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+
+    if mpu.get_tensor_model_parallel_rank() == 0:
+
+        if data_iterator is not None:
+            data = next(data_iterator)
+        else:
+            data = None
+
+        args.seq_length = data['tokens'].shape[1]
+        _broadcast(torch.tensor(args.seq_length).cuda(non_blocking=True))
+
+        batch = {
+            'tokens': data['tokens'].cuda(non_blocking=True),
+            'labels': data['labels'].cuda(non_blocking=True),
+            'loss_mask': data['loss_mask'].cuda(non_blocking=True),
+            'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
+            'position_ids': data['position_ids'].cuda(non_blocking=True)
+        }
+
+        if args.pipeline_model_parallel_size == 1:
+            _broadcast(batch['tokens'])
+            _broadcast(batch['labels'])
+            _broadcast(batch['loss_mask'])
+            _broadcast(batch['attention_mask'])
+            _broadcast(batch['position_ids'])
+
+        elif mpu.is_pipeline_first_stage():
+            _broadcast(batch['tokens'])
+            _broadcast(batch['attention_mask'])
+            _broadcast(batch['position_ids'])
+
+        elif mpu.is_pipeline_last_stage():
+            _broadcast(batch['labels'])
+            _broadcast(batch['loss_mask'])
+            _broadcast(batch['attention_mask'])
+
+    else:
+        seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+        _broadcast(seq_length)
+        args.seq_length = seq_length.item()
+        tokens = torch.empty((args.micro_batch_size, args.seq_length),
+                             dtype=torch.int64,
+                             device=torch.cuda.current_device())
+        labels = torch.empty((args.micro_batch_size, args.seq_length),
+                             dtype=torch.int64,
+                             device=torch.cuda.current_device())
+        loss_mask = torch.empty((args.micro_batch_size, args.seq_length),
+                                dtype=torch.float32,
+                                device=torch.cuda.current_device())
+        if args.create_attention_mask_in_dataloader:
+            attention_mask = torch.empty((args.micro_batch_size, 1, args.seq_length, args.seq_length),
+                                         dtype=torch.bool,
+                                         device=torch.cuda.current_device())
+        else:
+            attention_mask = None
+        position_ids = torch.empty((args.micro_batch_size, args.seq_length),
+                                   dtype=torch.int64,
+                                   device=torch.cuda.current_device())
+
+        if args.pipeline_model_parallel_size == 1:
+            _broadcast(tokens)
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)
+            _broadcast(position_ids)
+
+        elif mpu.is_pipeline_first_stage():
+            labels = None
+            loss_mask = None
+
+            _broadcast(tokens)
+            _broadcast(attention_mask)
+            _broadcast(position_ids)
+
+        elif mpu.is_pipeline_last_stage():
+            tokens = None
+            position_ids = None
+
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)
+
+        batch = {
+            'tokens': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids
+        }
+
+    return batch
+
+
+def forward_step(data_iterator, model):
+    batch = get_batch_on_this_tp_rank(data_iterator)
+    batch = get_batch_on_this_cp_rank(batch)
+    tokens, labels, loss_mask, attention_mask, position_ids = batch.values()
+    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+    return output_tensor, partial(loss_func, loss_mask)
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples, train_dataset: LazyLLMDataset,
+                                       val_dataset: LazyLLMDataset, template: Template):
+    # train_val_test_num_samples: ignored
+    assert not hasattr(training, '_old_build_pretraining_data_loader')
+    _old_build_pretraining_data_loader = training.build_pretraining_data_loader
+
+    def data_collator(batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = template.data_collator(batch, padding_to)
+        labels = res['labels']
+        new_labels = torch.zeros_like(labels)
+        new_labels[:, :-1] = labels[:, 1:]
+        new_labels[:, -1] = -100
+        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(new_labels, -100, False, False, True)
+        return {
+            'tokens': res['input_ids'],
+            'labels': new_labels,
+            'attention_mask': attention_mask,
+            'loss_mask': loss_mask,
+            'position_ids': position_ids
+        }
+
+    @wraps(_old_build_pretraining_data_loader)
+    def build_pretraining_data_loader(*args, **kwargs):
+        res = _old_build_pretraining_data_loader(*args, **kwargs)
+        if res is not None:
+            res.collate_fn = data_collator
+        return res
+
+    training.build_pretraining_data_loader = build_pretraining_data_loader
+    training._old_build_pretraining_data_loader = _old_build_pretraining_data_loader
+    return train_dataset, val_dataset, None
diff --git a/tests/megatron/test_export.py b/tests/megatron/test_export.py
new file mode 100644
index 000000000..5492e4fe6
--- /dev/null
+++ b/tests/megatron/test_export.py
@@ -0,0 +1,21 @@
+import os
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
+def hf2megatron():
+    from swift.llm import export_main, ExportArguments
+    export_main(
+        ExportArguments(
+            model='Qwen/Qwen2.5-7B-Instruct',
+            to_megatron=True,
+            target_tensor_model_parallel_size=2,
+            torch_dtype='bfloat16'))
+
+
+def megatron2hf():
+    from swift.llm import export_main, ExportArguments
+
+
+if __name__ == '__main__':
+    hf2megatron()
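
Usage note: `megatron2hf()` in the test above is still a stub. Below is a minimal sketch of how the reverse direction could be invoked once it is wired up, mirroring `hf2megatron()`; it is an assumption, not part of this patch. In particular, how the exported Megatron checkpoint directory (the `-tp2-pp1` output produced by `hf2megatron()`) is located and loaded is not settled here.

def megatron2hf():
    from swift.llm import export_main, ExportArguments
    export_main(
        ExportArguments(
            # HF model id used to build the target HF skeleton.
            model='Qwen/Qwen2.5-7B-Instruct',
            # assumption: the Megatron checkpoint produced by hf2megatron() would be supplied
            # via a load/checkpoint argument once convert_megatron2hf is fully implemented.
            to_hf=True,
            target_tensor_model_parallel_size=2,
            torch_dtype='bfloat16'))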