Skip to content

Commit 9a8c458

Browse files
committed
update
1 parent 83dc334 commit 9a8c458

File tree

3 files changed

+45
-35
lines changed

3 files changed

+45
-35
lines changed

swift/megatron/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22

3-
from .convert import convert_hf2megatron, convert_megatron2hf
4-
from .utils import init_megatron_env
3+
try:
4+
from .init import init_megatron_env
5+
init_megatron_env()
6+
except Exception:
7+
# No-op re-raise: the try/except wrapper exists only so the lint check passes.
8+
raise
59

6-
init_megatron_env()
10+
from .convert import convert_hf2megatron, convert_megatron2hf

swift/megatron/init.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import shutil
3+
import sys
4+
5+
from swift.llm import git_clone_github
6+
from swift.utils import is_megatron_available, safe_ddp_context, subprocess_run
7+
8+
9+
def _rename_files():
    """Rename qwen1.x / qwen2.x converter files to qwen1_x / qwen2_x.

    Lists every entry of the Qwen checkpoint-convertor folder located under
    ``os.environ['PAI_MEGATRON_PATCH_PATH']`` and replaces the dot of a
    ``qwen1.<digit>`` or ``qwen2.<digit>`` version tag with an underscore
    (for example ``qwen1.5`` -> ``qwen1_5``). Entries whose name does not
    contain such a tag are left untouched.

    Raises:
        KeyError: if ``PAI_MEGATRON_PATCH_PATH`` is not set.
        OSError: if the folder cannot be listed or a move fails.
    """
    import re  # function-scope import keeps this block self-contained

    # Only a dot followed by a digit belongs to a version tag. A dot that
    # starts the extension (e.g. ``qwen2.py``) must not be rewritten, or the
    # file would be renamed to ``qwen2_py`` and lose its extension.
    version_dot = re.compile(r'(qwen[12])\.(?=\d)')

    megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
    qwen_folders = ['toolkits/model_checkpoints_convertor/qwen']
    for folder in qwen_folders:
        dir_path = os.path.join(megatron_patch_path, folder)
        for fname in os.listdir(dir_path):
            new_name = version_dot.sub(r'\1_', fname)
            if new_name == fname:
                continue
            old_path = os.path.join(dir_path, fname)
            # NOTE(review): presumably guards against the entry having been
            # renamed already between listing and moving - confirm.
            if os.path.exists(old_path):
                shutil.move(old_path, os.path.join(dir_path, new_name))
21+
22+
23+
def init_megatron_env() -> None:
    """Make Megatron-LM and Pai-Megatron-Patch available to this process.

    For each of ``MEGATRON_LM_PATH`` and ``PAI_MEGATRON_PATCH_PATH``: when the
    environment variable is unset, the corresponding GitHub repository is
    fetched with ``git_clone_github`` and the returned value is stored in the
    variable. Both paths are then added to ``sys.path``. If
    ``is_megatron_available()`` reports false, Megatron-LM is installed in
    editable mode from ``MEGATRON_LM_PATH``. Finally the Qwen converter files
    are renamed inside ``safe_ddp_context('rename_files')``.
    """
    if 'MEGATRON_LM_PATH' not in os.environ:
        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
    megatron_lm_path = os.environ['MEGATRON_LM_PATH']
    if not is_megatron_available():
        # Run pip through the current interpreter: a bare ``pip`` found on
        # PATH may belong to a different Python installation/environment.
        subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', megatron_lm_path])
    # Skip the append when the path is already present so that repeated
    # calls do not keep growing sys.path.
    if megatron_lm_path not in sys.path:
        sys.path.append(megatron_lm_path)

    if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
        os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
            'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
    megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
    if megatron_patch_path not in sys.path:
        sys.path.append(megatron_patch_path)

    # rename qwen1.5/2.5->qwen1_5/2_5 files
    with safe_ddp_context('rename_files'):
        _rename_files()

swift/megatron/utils.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,6 @@
1919
logger = get_logger()
2020

2121

22-
def _rename_files():
23-
megatron_patch_path = os.environ['PAI_MEGATRON_PATCH_PATH']
24-
qwen_folders = ['toolkits/model_checkpoints_convertor/qwen']
25-
for folder in qwen_folders:
26-
dir_path = os.path.join(megatron_patch_path, folder)
27-
for fname in os.listdir(dir_path):
28-
old_path = os.path.join(dir_path, fname)
29-
fname = fname.replace('qwen1.', 'qwen1_')
30-
fname = fname.replace('qwen2.', 'qwen2_')
31-
new_path = os.path.join(dir_path, fname)
32-
if old_path != new_path and os.path.exists(old_path):
33-
shutil.move(old_path, new_path)
34-
35-
36-
def init_megatron_env() -> None:
37-
if 'MEGATRON_LM_PATH' not in os.environ:
38-
os.environ['MEGATRON_LM_PATH'] = git_clone_github(
39-
'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.10.0')
40-
if not is_megatron_available():
41-
subprocess_run(['pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']])
42-
sys.path.append(os.environ['MEGATRON_LM_PATH'])
43-
44-
if 'PAI_MEGATRON_PATCH_PATH' not in os.environ:
45-
os.environ['PAI_MEGATRON_PATCH_PATH'] = git_clone_github(
46-
'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='v0.10.1')
47-
sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])
48-
49-
# rename qwen1.5/2.5->qwen1_5/2_5 files
50-
with safe_ddp_context('rename_files'):
51-
_rename_files()
52-
53-
5422
def patch_megatron(tokenizer):
5523

5624
def build_tokenizer(args):

0 commit comments

Comments
 (0)