Commit 99c515b

patch parallel download shard (#159)
1 parent aa99dcb commit 99c515b

1 file changed (+33, -10 lines)

qllm/modeling/base.py
Lines changed: 33 additions & 10 deletions
@@ -4,6 +4,7 @@
 from transformers import AutoModelForCausalLM
 from pathlib import Path
 import tqdm
+import functools
 import glob
 import json
 import contextlib
@@ -46,6 +47,14 @@ def no_init_weights(attrs: list = None):
         if old_attr[idx] is not None:
             setattr(torch.Tensor, attr, old_attr[idx])
 
+@contextlib.contextmanager
+def patch_cache_file_in_parallel():
+    old_cached_file_func = transformers.utils.hub.cached_file
+    old_get_checkpoint_shard_files_func = transformers.utils.hub.get_checkpoint_shard_files
+    yield
+    transformers.utils.hub.cached_file = old_cached_file_func
+    transformers.utils.hub.get_checkpoint_shard_files = old_get_checkpoint_shard_files_func
+
 def get_no_split_layer_type_name(model:torch.nn.Module):
     try:
         return model._get_no_split_modules("auto")
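
Note: patch_cache_file_in_parallel saves the two transformers.utils.hub functions that get monkey-patched during loading and reassigns them after the yield. A minimal sketch of the same save/restore pattern, hardened with try/finally so the originals come back even if loading raises; this variant is illustrative only and not part of the commit:

import contextlib
import transformers.utils.hub

@contextlib.contextmanager
def _patched_hub_funcs_sketch():  # hypothetical name, not in the commit
    # Save the originals before the body of the with-block monkey-patches them.
    saved_cached_file = transformers.utils.hub.cached_file
    saved_get_shards = transformers.utils.hub.get_checkpoint_shard_files
    try:
        yield
    finally:
        # Restore unconditionally, even if loading raised inside the block.
        transformers.utils.hub.cached_file = saved_cached_file
        transformers.utils.hub.get_checkpoint_shard_files = saved_get_shards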
@@ -92,6 +101,16 @@ def _get_resolved_weight_or_index_file(model_name_or_path):
     return str(weight_or_index_file)
 
 
+def parallel_download_decorator(task_func_shard, *args, **kwargs):
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        def cached_file_func_in_thread(task_func, *args, **kwargs):
+            return executor.submit(task_func, *args, **kwargs)
+        transformers.utils.hub.cached_file = functools.partial(cached_file_func_in_thread, transformers.utils.hub.cached_file)
+        result = task_func_shard(*args, **kwargs)
+        result_0 = [future.result() for future in result[0]]
+        return result_0, result[1]
+
+
 def _load_check_point(model, model_name_or_path, get_keys_only: bool = False):
     weight_or_index_file = _get_resolved_weight_or_index_file(model_name_or_path)
     all_keys = set()
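
Note: parallel_download_decorator is meant to wrap transformers.utils.hub.get_checkpoint_shard_files. While the wrapped call runs, cached_file is swapped for a version that submits the real download to a ThreadPoolExecutor and returns a Future, so the wrapped function collects Futures instead of file paths; the decorator then resolves them (result[0]) before handing back the usual (filenames, metadata) pair. A stand-alone sketch of the same idea without monkey-patching, assuming huggingface_hub is installed; the repo id and shard names below are placeholders:

import concurrent.futures
from huggingface_hub import hf_hub_download

def download_shards_in_parallel(repo_id, shard_filenames, max_workers=8):
    # Fetch every shard of a checkpoint concurrently with a thread pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(hf_hub_download, repo_id=repo_id, filename=name)
            for name in shard_filenames
        ]
        # Resolve in submission order so the returned paths line up with the input names.
        return [future.result() for future in futures]

# Hypothetical usage:
# paths = download_shards_in_parallel(
#     "some-org/some-sharded-model",
#     ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
# )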
@@ -183,16 +202,20 @@ def from_pretrained(
 
         torch_dtype = kwargs.pop("torch_dtype", auto_conf.torch_dtype)
 
-        llm = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            torch_dtype=torch_dtype,
-            trust_remote_code=trust_remote_code,
-            attn_implementation=attn_implementation,
-            # device_map="auto",
-            # low_cpu_mem_usage=True,
-            # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
-            # offload_folder="/tmp/a2"
-        )
+        with patch_cache_file_in_parallel():
+            transformers.utils.hub.get_checkpoint_shard_files = functools.partial(
+                parallel_download_decorator, transformers.utils.hub.get_checkpoint_shard_files
+            )
+            llm = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=trust_remote_code,
+                attn_implementation=attn_implementation,
+                # device_map="auto",
+                # low_cpu_mem_usage=True,
+                # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
+                # offload_folder="/tmp/a2"
+            )
         return llm
 
     @classmethod
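
Note: transformers only calls get_checkpoint_shard_files when the checkpoint is sharded (an index file such as model.safetensors.index.json is present), so single-file checkpoints load exactly as before. Read in isolation, the new with-block amounts to the following sketch, using the helpers defined earlier in this file; the model id is a placeholder:

import functools
import transformers.utils.hub
from transformers import AutoModelForCausalLM

with patch_cache_file_in_parallel():
    # While the block is active, shard resolution goes through the thread-pool
    # wrapper; the context manager puts the original function back on exit.
    transformers.utils.hub.get_checkpoint_shard_files = functools.partial(
        parallel_download_decorator, transformers.utils.hub.get_checkpoint_shard_files
    )
    llm = AutoModelForCausalLM.from_pretrained("some-org/some-sharded-model")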
