Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jan 9, 2025
1 parent 83dc334 commit 9a8c458
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 35 deletions.
10 changes: 7 additions & 3 deletions swift/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .convert import convert_hf2megatron, convert_megatron2hf
from .utils import init_megatron_env
try:
from .init import init_megatron_env
init_megatron_env()
except Exception:
# allows lint pass.
raise

init_megatron_env()
from .convert import convert_hf2megatron, convert_megatron2hf
38 changes: 38 additions & 0 deletions swift/megatron/init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import shutil
import sys

from swift.llm import git_clone_github
from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run


def _rename_files():
    """Rename qwen converter entries so their names no longer contain 'qwen1.' / 'qwen2.'.

    Reads the checkout location from the ``PAI_MEGATRON_PATCH_PATH`` environment
    variable and, inside each listed sub-folder, replaces ``qwen1.`` with
    ``qwen1_`` and ``qwen2.`` with ``qwen2_`` in every entry name. Entries whose
    name is unaffected are left alone.

    Raises:
        KeyError: if ``PAI_MEGATRON_PATCH_PATH`` is not set.
    """
    patch_root = os.environ['PAI_MEGATRON_PATCH_PATH']
    for rel_dir in ['toolkits/model_checkpoints_convertor/qwen']:
        target_dir = os.path.join(patch_root, rel_dir)
        for entry in os.listdir(target_dir):
            renamed = entry.replace('qwen1.', 'qwen1_').replace('qwen2.', 'qwen2_')
            if renamed == entry:
                # Nothing to rewrite in this name.
                continue
            source = os.path.join(target_dir, entry)
            if not os.path.exists(source):
                # The entry vanished since the listing was taken; skip it.
                continue
            shutil.move(source, os.path.join(target_dir, renamed))


def init_megatron_env() -> None:
    """Make Megatron-LM and Pai-Megatron-Patch available to the current process.

    Steps, in order:

    1. If ``MEGATRON_LM_PATH`` is not set, clone Megatron-LM (branch
       ``core_r0.10.0``) via ``git_clone_github`` and store the result in that
       environment variable (presumably the local checkout path -- confirm
       against ``git_clone_github``).
    2. If ``is_megatron_available()`` is false, install that checkout in
       editable mode with pip.
    3. Append the checkout to ``sys.path``.
    4. Do the same clone-if-unset / append for ``PAI_MEGATRON_PATCH_PATH``
       (Pai-Megatron-Patch at ``v0.10.1``), without a pip install.
    5. Run ``_rename_files()`` inside ``safe_ddp_context('rename_files')``.
    """
    if 'MEGATRON_LM_PATH' not in os.environ:
        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
    if not is_megatron_available():
        # Run pip through the interpreter executing this code so the editable
        # install lands in the same environment; a bare `pip` is resolved via
        # PATH and may belong to a different Python. Fall back to `pip` only
        # when the interpreter path is unknown (sys.executable empty/None).
        pip_cmd = [sys.executable, '-m', 'pip'] if sys.executable else ['pip']
        subprocess_run(pip_cmd + ['install', '-e', os.environ['MEGATRON_LM_PATH']])
    sys.path.append(os.environ['MEGATRON_LM_PATH'])

    if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
        os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
            'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
    sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])

    # rename qwen1.5/2.5->qwen1_5/2_5 files
    with safe_ddp_context('rename_files'):
        _rename_files()
32 changes: 0 additions & 32 deletions swift/megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,6 @@
logger = get_logger()


def _rename_files():
    """Replace 'qwen1.' / 'qwen2.' with 'qwen1_' / 'qwen2_' in qwen converter entry names.

    The base directory comes from the ``PAI_MEGATRON_PATCH_PATH`` environment
    variable; each configured sub-folder is listed once and every entry whose
    name changes under the two replacements is moved to the new name.

    Raises:
        KeyError: if ``PAI_MEGATRON_PATCH_PATH`` is not set.
    """
    base_dir = os.environ['PAI_MEGATRON_PATCH_PATH']
    sub_dirs = ['toolkits/model_checkpoints_convertor/qwen']
    for sub_dir in sub_dirs:
        folder = os.path.join(base_dir, sub_dir)
        for name in os.listdir(folder):
            new_name = name.replace('qwen1.', 'qwen1_')
            new_name = new_name.replace('qwen2.', 'qwen2_')
            before = os.path.join(folder, name)
            after = os.path.join(folder, new_name)
            # Only move entries whose path actually changes and that still exist.
            needs_move = before != after and os.path.exists(before)
            if needs_move:
                shutil.move(before, after)


def init_megatron_env() -> None:
    """Set up the Megatron-LM and Pai-Megatron-Patch checkouts for this process.

    For Megatron-LM: when ``MEGATRON_LM_PATH`` is unset it is filled with the
    result of ``git_clone_github`` (branch ``core_r0.10.0``); when
    ``is_megatron_available()`` is false the checkout is pip-installed in
    editable mode; the checkout is then appended to ``sys.path``.

    For Pai-Megatron-Patch: when ``PAI_MEGATRON_PATCH_PATH`` is unset it is
    filled with the result of ``git_clone_github`` (``v0.10.1``); the checkout
    is appended to ``sys.path``.

    Finally ``_rename_files()`` is run under ``safe_ddp_context('rename_files')``.
    NOTE(review): the values stored in the environment variables are assumed
    to be local directory paths -- confirm against ``git_clone_github``.
    """
    if 'MEGATRON_LM_PATH' not in os.environ:
        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
    if not is_megatron_available():
        # Use `<this interpreter> -m pip` rather than whatever `pip` is first
        # on PATH, so the package is installed for the Python that will import
        # it. Keep the plain `pip` command as a fallback when sys.executable
        # is empty/None.
        pip_cmd = [sys.executable, '-m', 'pip'] if sys.executable else ['pip']
        subprocess_run(pip_cmd + ['install', '-e', os.environ['MEGATRON_LM_PATH']])
    sys.path.append(os.environ['MEGATRON_LM_PATH'])

    if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
        os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
            'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
    sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])

    # rename qwen1.5/2.5->qwen1_5/2_5 files
    with safe_ddp_context('rename_files'):
        _rename_files()


def patch_megatron(tokenizer):

def build_tokenizer(args):
Expand Down

0 comments on commit 9a8c458

Please sign in to comment.