4 | 4 | from transformers import AutoModelForCausalLM
5 | 5 | from pathlib import Path
6 | 6 | import tqdm
| 7 | +import functools
7 | 8 | import glob
8 | 9 | import json
9 | 10 | import contextlib
@@ -46,6 +47,14 @@ def no_init_weights(attrs: list = None):
46 | 47 | if old_attr[idx] is not None:
47 | 48 | setattr(torch.Tensor, attr, old_attr[idx])
48 | 49 |
| 50 | +@contextlib.contextmanager
| 51 | +def patch_cache_file_in_parallel():
| 52 | + old_cached_file_func = transformers.utils.hub.cached_file
| 53 | + old_get_checkpoint_shard_files_func = transformers.utils.hub.get_checkpoint_shard_files
| 54 | + yield
| 55 | + transformers.utils.hub.cached_file = old_cached_file_func
| 56 | + transformers.utils.hub.get_checkpoint_shard_files = old_get_checkpoint_shard_files_func
| 57 | +
49 | 58 | def get_no_split_layer_type_name(model:torch.nn.Module):
50 | 59 | try:
51 | 60 | return model._get_no_split_modules("auto")
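The new context manager snapshots the two hub helpers that checkpoint resolution goes through and puts them back after the body runs, so the monkeypatching applied later in `from_pretrained` stays scoped to a single load. As written, the restore only happens on normal exit; a minimal sketch of the same pattern, with a try/finally added here (not part of the patch) so the originals are also restored if loading raises:

```python
import contextlib
import transformers

@contextlib.contextmanager
def patch_cache_file_in_parallel():
    # Snapshot the original hub helpers before any patching happens.
    old_cached_file = transformers.utils.hub.cached_file
    old_get_shards = transformers.utils.hub.get_checkpoint_shard_files
    try:
        yield
    finally:
        # Restore even if the body raises, so later loads use the stock helpers.
        transformers.utils.hub.cached_file = old_cached_file
        transformers.utils.hub.get_checkpoint_shard_files = old_get_shards
```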
@@ -92,6 +101,16 @@ def _get_resolved_weight_or_index_file(model_name_or_path):
92 | 101 | return str(weight_or_index_file)
93 | 102 |
94 | 103 |
| 104 | +def parallel_download_decorator(task_func_shard, *args, **kwargs):
| 105 | + with concurrent.futures.ThreadPoolExecutor() as executor:
| 106 | + def cached_file_func_in_thread(task_func, *args, **kwargs):
| 107 | + return executor.submit(task_func, *args, **kwargs)
| 108 | + transformers.utils.hub.cached_file = functools.partial(cached_file_func_in_thread, transformers.utils.hub.cached_file)
| 109 | + result = task_func_shard(*args, **kwargs)
| 110 | + result_0 = [future.result() for future in result[0]]
| 111 | + return result_0, result[1]
| 112 | +
| 113 | +
95 | 114 | def _load_check_point(model, model_name_or_path, get_keys_only: bool = False):
96 | 115 | weight_or_index_file = _get_resolved_weight_or_index_file(model_name_or_path)
97 | 116 | all_keys = set()
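The decorator wraps `get_checkpoint_shard_files`: while the wrapped call runs, `cached_file` is swapped for a version that submits the real download to a `ThreadPoolExecutor`, so the file list the helper builds comes back as a list of futures. The decorator then resolves those futures into plain paths (`result_0`) and returns them together with the untouched sharded metadata (`result[1]`). Note that this commit imports `functools` for the `partial` calls, while `concurrent.futures` is presumably imported elsewhere in the file (it is not visible in these hunks). Below is a self-contained toy that mirrors the mechanics with stand-in names (`fetch_one` for `cached_file`, `get_shards` for `get_checkpoint_shard_files`); it is illustrative only, but shows why the first element has to be unwrapped:

```python
import concurrent.futures
import functools
import time

def fetch_one(name):                     # stand-in for transformers' cached_file
    time.sleep(0.1)                      # pretend this is one blocking download
    return f"/cache/{name}"

def get_shards(names):                   # stand-in for get_checkpoint_shard_files
    # Returns (resolved file paths, metadata), like the real helper.
    return [fetch_one(n) for n in names], {"total": len(names)}

def parallel_decorator(task_func, *args, **kwargs):
    global fetch_one
    original = fetch_one
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # While task_func runs, every fetch is submitted to the pool instead of
        # blocking, so the first element of its result is a list of futures.
        fetch_one = functools.partial(executor.submit, original)
        try:
            futures, meta = task_func(*args, **kwargs)
        finally:
            fetch_one = original
    # Unwrap the futures back into the plain paths the caller expects.
    return [f.result() for f in futures], meta

files, meta = parallel_decorator(get_shards, ["a.bin", "b.bin", "c.bin"])
print(files)  # three cache paths, fetched concurrently (~0.1s total, not ~0.3s)
```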
@@ -183,16 +202,20 @@ def from_pretrained(
183 | 202 |
184 | 203 | torch_dtype = kwargs.pop("torch_dtype", auto_conf.torch_dtype)
185 | 204 |
186 | | - llm = AutoModelForCausalLM.from_pretrained(
187 | | - pretrained_model_name_or_path,
188 | | - torch_dtype=torch_dtype,
189 | | - trust_remote_code=trust_remote_code,
190 | | - attn_implementation=attn_implementation,
191 | | - # device_map="auto",
192 | | - # low_cpu_mem_usage=True,
193 | | - # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
194 | | - # offload_folder="/tmp/a2"
195 | | - )
| 205 | + with patch_cache_file_in_parallel():
| 206 | + transformers.utils.hub.get_checkpoint_shard_files = functools.partial(
| 207 | + parallel_download_decorator, transformers.utils.hub.get_checkpoint_shard_files
| 208 | + )
| 209 | + llm = AutoModelForCausalLM.from_pretrained(
| 210 | + pretrained_model_name_or_path,
| 211 | + torch_dtype=torch_dtype,
| 212 | + trust_remote_code=trust_remote_code,
| 213 | + attn_implementation=attn_implementation,
| 214 | + # device_map="auto",
| 215 | + # low_cpu_mem_usage=True,
| 216 | + # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
| 217 | + # offload_folder="/tmp/a2"
| 218 | + )
196 | 219 | return llm
197 | 220 |
198 | 221 | @classmethod
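From the caller's side nothing changes: only sharded checkpoints go through `get_checkpoint_shard_files`, so single-file models load exactly as before, while multi-shard downloads now overlap. One scoping detail worth keeping in mind is that `cached_file` remains patched from the shard fetch until `patch_cache_file_in_parallel` exits, so any file resolution that happens inside that window would receive a future rather than a path. A small sketch of calling the patched loader and checking that the hub helpers are restored afterwards; `WrapperModel` and the repo id are placeholders, since the class that owns this `from_pretrained` is not visible in the diff:

```python
import transformers

# "WrapperModel" is hypothetical; substitute the class that defines the
# patched from_pretrained shown above.
orig_cached_file = transformers.utils.hub.cached_file

model = WrapperModel.from_pretrained("org/some-sharded-checkpoint")  # hypothetical repo id

# The context manager restores the hub helpers once loading finishes,
# so later loads see the stock (serial) download behavior again.
assert transformers.utils.hub.cached_file is orig_cached_file
```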