Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag #36835
@@ -0,0 +1,66 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Environment Variables

## HF_ENABLE_PARALLEL_LOADING

Disabled by default. When enabled, torch and safetensors weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speed-ups greater than 50%.

Set it to the string `"true"` or `"false"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.

For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~20s with this enabled vs. ~45s without it.

Profile before committing to this environment variable; it will not produce speed-ups for smaller models.
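A minimal way to measure the effect on your own hardware (a sketch, not part of the proposed doc; the model name is just an example) is to time the pipeline call once with the flag on and once with it off:

```py
import os
import time

# Flip to "false" (or unset) for the baseline run.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    start = time.perf_counter()
    pipe = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
    print(f"Load time: {time.perf_counter() - start:.1f}s")
```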
NOTE: unless you are loading the model onto the CPU only, you must configure `multiprocessing` to use the `spawn` start method, like so:

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # NOTE: if the model loads on CPU, this is not required
    multiprocessing.set_start_method("spawn", force=True)

    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```

If loading onto a CUDA device, the code will crash unless `multiprocessing.set_start_method("spawn", force=True)` has been called.

## HF_PARALLEL_LOADING_WORKERS

Determines how many child processes are used when parallel loading is enabled. The default is `8`. Tune as you see fit.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # NOTE: if the model loads on CPU, this is not required
    multiprocessing.set_start_method("spawn", force=True)

    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
@@ -874,6 +874,128 @@ def _load_state_dict_into_meta_model(
    return disk_offload_index, cpu_offload_index


def resolve_state_dict_modules(model_to_load, state_dict, expected_keys):
    state_dict_modules = {}

    for tensor_name in state_dict.keys():
        if tensor_name not in expected_keys:
            continue

        splits = tensor_name.split(".")
        module = model_to_load
        for split in splits:
            try:
                module = getattr(module, split)
            except Exception as exception:
                print(exception)

        state_dict_modules[tensor_name] = module

    return state_dict_modules
# This function is in global scope so it's picklable for multiprocessing
def load_shard_file(args):
    (
        state_dict,
        shard_file,
        disk_only_shard_files,
        low_cpu_mem_usage,
        is_quantized,
        device_map,
        hf_quantizer,
        key_renaming_mapping,
        weights_only,
        model_to_load,
        ignore_mismatched_sizes,
        prefix,
        loading_base_model_from_task_state_dict,
        expected_keys,
        reverse_key_renaming_mapping,
        disk_offload_folder,
        disk_offload_index,
        cpu_offload_folder,
        cpu_offload_index,
        is_offloaded_safetensors,
        keep_in_fp32_modules,
        unexpected_keys,
        device_mesh,
    ) = args
    # Skip the load for shards that only contain disk-offloaded weights
    if shard_file in disk_only_shard_files:
        return [], [], disk_offload_index, cpu_offload_index, {}

    map_location = "cpu"
    if low_cpu_mem_usage:
        if shard_file.endswith(".safetensors") and not is_quantized:
            map_location = "meta"
        elif (
            device_map is not None
            and hf_quantizer is not None
            and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
        ):
            map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

    # If shard_file is "", we use the existing state_dict instead of loading it
    if shard_file != "":
        state_dict = load_state_dict(
            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
        )

    error_msgs = []
    mismatched_keys = []

    # Fix the key names
    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

    # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
    # matching the weights in the model.
    mismatched_keys += _find_mismatched_keys(
        model_to_load,
        state_dict,
        ignore_mismatched_sizes,
        prefix if loading_base_model_from_task_state_dict else "",
    )

    if low_cpu_mem_usage:
        # Skip it with fsdp on ranks other than 0
        if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
            disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                model_to_load,
                state_dict,
                shard_file,
                expected_keys,
                reverse_key_renaming_mapping,
                device_map=device_map,
                disk_offload_folder=disk_offload_folder,
                disk_offload_index=disk_offload_index,
                cpu_offload_folder=cpu_offload_folder,
                cpu_offload_index=cpu_offload_index,
                hf_quantizer=hf_quantizer,
                is_safetensors=is_offloaded_safetensors,
                keep_in_fp32_modules=keep_in_fp32_modules,
                unexpected_keys=unexpected_keys,
                device_mesh=device_mesh,
            )
    else:
        assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
        if is_deepspeed_zero3_enabled():
            error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
        else:
            model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)

    # We now figure out what in the state dict changed and store the module used for each layer; this will contain the device
    # information we need in order to resolve all of the layers after multiprocessing, which we write back to the original model_to_load meta model
    state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys)
Review comment: We don't need to have this function here (it's very inefficient in the case of NOT using multiprocessing, as it's basically useless) -> only the keys of the state_dict are enough information to then do it after the loading in the multiprocessing case.
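A hedged sketch of what that suggestion could look like (illustrative only, not code from this diff): the worker returns just the loaded tensor names, and the parent resolves them against the model once the pool has finished.

```py
def resolve_modules_from_keys(model, tensor_names):
    """Illustrative helper: map dotted tensor names back to the objects on `model`."""
    resolved = {}
    for tensor_name in tensor_names:
        obj = model
        for attr in tensor_name.split("."):
            obj = getattr(obj, attr)
        resolved[tensor_name] = obj
    return resolved
```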
    # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
    del state_dict

    return mismatched_keys, error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules
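As the comment at the top of `load_shard_file` notes, the function lives at module scope because `multiprocessing.Pool` pickles the task callable by reference. A minimal, hedged illustration of that constraint (standalone example, not code from this diff):

```py
from multiprocessing import Pool

def square(x):  # defined at module top level, so workers can import and unpickle it
    return x * x

if __name__ == "__main__":
    with Pool(processes=2) as pool:
        # A lambda or a locally defined function here would fail to pickle.
        print(sorted(pool.imap_unordered(square, [1, 2, 3])))
```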
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    if variant is not None:
        path, name = weights_name.rsplit(".", 1)
@@ -4810,9 +4932,6 @@ def _load_pretrained_model(
            cpu_offload_folder = tempfile.mkdtemp()
            cpu_offload_index = {}

        # For nice tqdm bars
        if checkpoint_files is not None and len(checkpoint_files) > 1:
            checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
        # To be able to iterate, even if we don't use it if the state_dict is already provided
        elif state_dict is not None:
            checkpoint_files = [""]
@@ -4827,73 +4946,99 @@ def _load_pretrained_model(
        expanded_device_map = expand_device_map(device_map, expected_keys)
        caching_allocator_warmup(model_to_load, expanded_device_map)

        error_msgs = []
        from multiprocessing import Pool

        # Prepare arguments for multiprocessing
        args_list = [
            (
                state_dict,
                shard_file,
                disk_only_shard_files,
                low_cpu_mem_usage,
                is_quantized,
                device_map,
                hf_quantizer,
                key_renaming_mapping,
                weights_only,
                model_to_load,
                ignore_mismatched_sizes,
                prefix,
                loading_base_model_from_task_state_dict,
                expected_keys,
                reverse_key_renaming_mapping,
                disk_offload_folder,
                disk_offload_index,
                cpu_offload_folder,
                cpu_offload_index,
                is_offloaded_safetensors,
                keep_in_fp32_modules,
                unexpected_keys,
                device_mesh,
            )
            for shard_file in checkpoint_files
        ]

        mismatched_keys = []
        # Iterate on all the shards to load the weights
        for shard_file in checkpoint_files:
            # Skip the load for shards that only contain disk-offloaded weights
            if shard_file in disk_only_shard_files:
                continue
        error_msgs = []

            map_location = "cpu"
            if low_cpu_mem_usage:
                if shard_file.endswith(".safetensors") and not is_quantized:
                    map_location = "meta"
                elif (
                    device_map is not None
                    and hf_quantizer is not None
                    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
                    and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
                ):
                    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
        # Use multiprocessing Pool for parallel execution, off by default
        if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")):
Review comment: This should be guarded against deepspeed here, as deepspeed is the only remaining path where the model is not on meta -> it will lead to exploding the memory as each process copies the model.
            num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
            logger.info(f"Loading model weights in parallel with {num_workers} workers...")
Review comment: this is nice BUT! I think we need some guards / good defaults:

Reply: Fully agree on point 1, we should just min() that on len(args_list). Second point, I'd argue, is an enhancement; for now the benefits are so great we can leave it to the user until itr 2. (I'll be cooking this up real soon!)

Reply: haha okay 🤗
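A hedged sketch of the guard agreed on above (not yet in this diff; `args_list` refers to the argument list built earlier in the hunk): clamp the worker count so the pool never spawns more processes than there are shards.

```py
import json
import os

num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
# Guard suggested in the review: never spawn more workers than there are shard files.
num_workers = min(num_workers, len(args_list))
```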
            state_dict_modules_list = []

            with Pool(processes=num_workers) as pool:
                # For nice tqdm bars
                with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
                    # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the original meta model
                    for result in pool.imap_unordered(load_shard_file, args_list):
                        _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = (
                            result
                        )

            # If shard_file is "", we use the existing state_dict instead of loading it
            if shard_file != "":
                state_dict = load_state_dict(
                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                )
                        mismatched_keys += _mismatched_keys
                        error_msgs += _error_msgs

            # Fix the key names
            state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
                        state_dict_modules_list.append(state_dict_modules)

            # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
            # matching the weights in the model.
            mismatched_keys += _find_mismatched_keys(
                model_to_load,
                state_dict,
                ignore_mismatched_sizes,
                prefix if loading_base_model_from_task_state_dict else "",
            )
                        pbar.update(1)

            if low_cpu_mem_usage:
                # Skip it with fsdp on ranks other than 0
                if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                    disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                        model_to_load,
                        state_dict,
                        shard_file,
                        expected_keys,
                        reverse_key_renaming_mapping,
                        device_map=device_map,
                        disk_offload_folder=disk_offload_folder,
                        disk_offload_index=disk_offload_index,
                        cpu_offload_folder=cpu_offload_folder,
                        cpu_offload_index=cpu_offload_index,
                        hf_quantizer=hf_quantizer,
                        is_safetensors=is_offloaded_safetensors,
                        keep_in_fp32_modules=keep_in_fp32_modules,
                        unexpected_keys=unexpected_keys,
                        device_mesh=device_mesh,
                    )
            else:
                assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
                if is_deepspeed_zero3_enabled():
                    error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
                else:
                    model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
            # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker
            # We are transferring that state into the original ref (model_to_load) here
            # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors
            # You could in theory return each worker's copy of the model and use .named_parameters() and .named_buffers(), but this appears to be more robust
            # in that all you have to care about are the names of the layers in the state dict; as long as the logic that led to the creation of the state_dict is correct, this will also be correct
            for state_dict_modules in state_dict_modules_list:
                for tensor_name in state_dict_modules.keys():
                    splits = tensor_name.split(".")
                    module = model_to_load

                    for split in splits[:-1]:
                        module = getattr(module, split)

                    last_key = splits.pop()
Review comment: we have a util for that, see the

Reply: Appreciate it! Will make the adjustment :)
            # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
            del state_dict
                    tensor_ref = state_dict_modules[tensor_name]

                    setattr(module, last_key, tensor_ref)

            del state_dict_modules_list
            gc.collect()
        else:
            if len(args_list) > 1:
                # For nice tqdm bars
                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

            for args in args_list:
                _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = (
                    load_shard_file(args)
                )

                mismatched_keys += _mismatched_keys
                error_msgs += _error_msgs

                del state_dict_modules
                gc.collect()

        # Adjust offloaded weights name and save if needed
        if disk_offload_index is not None and len(disk_offload_index) > 0:
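The writeback loop above exists because, as the in-code comments explain, each pool worker operates on a pickled copy of `model_to_load`, so device placements made inside a worker are not visible in the parent process unless they are returned and reattached. A minimal, hedged illustration of that copy semantics (standalone example, not code from this diff):

```py
from multiprocessing import Pool

class Holder:
    def __init__(self):
        self.device = "meta"

def place(holder):
    # Runs in a child process on a pickled copy of `holder`.
    holder.device = "cuda:0"
    return holder.device  # changes must be returned explicitly

if __name__ == "__main__":
    h = Holder()
    with Pool(processes=1) as pool:
        returned = pool.map(place, [h])
    print(returned)   # ['cuda:0'] -> state returned from the worker
    print(h.device)   # 'meta'    -> the parent's copy is unchanged
```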