
Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag #36835


Merged 34 commits on May 23, 2025
Changes from 10 commits

Commits (34)
8fb9b18
Get parallel loader working. Include tests.
inf3rnus Mar 19, 2025
27f36f2
Update the tests for parallel loading
inf3rnus Mar 19, 2025
7e5ecd8
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus Mar 19, 2025
e7c3ea5
Rename env variables.
inf3rnus Mar 19, 2025
7599fe2
Add docs for parallel model weight loading.
inf3rnus Mar 19, 2025
065e102
Touch up parallel model loading docs.
inf3rnus Mar 19, 2025
d31594a
Touch up parallel model loading docs again.
inf3rnus Mar 19, 2025
33b3e0f
Edit comment in test_modeling_utils_parallel_loading.py
inf3rnus Mar 19, 2025
3fb6b65
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus Mar 19, 2025
0e22c04
Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modelin…
inf3rnus Mar 19, 2025
904bdaf
Correct times for parallelized loading, previous times were for a "ho…
inf3rnus Mar 21, 2025
7e37ba4
Update parallel model loading so the spawn method is encapsulated. DR…
inf3rnus Mar 24, 2025
a203f6a
Update docs on model loading parallelism so that details on setting t…
inf3rnus Mar 24, 2025
14e9eef
Fix style on model loading parallelism changes.
inf3rnus Mar 24, 2025
fe1fc0c
Merge remote-tracking branch 'upstream/main' into 03-18-25-parallel-m…
inf3rnus Apr 8, 2025
d5637e8
Merge latest version of master's modeling_utils.
inf3rnus Apr 8, 2025
e0d37bb
Removed unused variable.
inf3rnus Apr 8, 2025
9b4165c
Fix argument packing for the parallel loader.
inf3rnus Apr 8, 2025
1085461
Fix state dict being undefined in the parallel model loader.
inf3rnus Apr 8, 2025
82ab2ec
Merge main.
inf3rnus Apr 29, 2025
7ae3db6
Rename variables used in parallel model loading for clarity. Use get_…
inf3rnus Apr 29, 2025
8d04325
Switch to the use of threads for parallel model loading.
inf3rnus Apr 29, 2025
674ec37
Update docs for parallel loading.
inf3rnus Apr 29, 2025
b8a1470
Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADI…
inf3rnus Apr 30, 2025
efb6605
Move parallelized shard loading into its own function.
inf3rnus Apr 30, 2025
c66daef
Remove use of is_true(). Favor checking env var true values for HF_EN…
inf3rnus May 1, 2025
4566c5c
Update copyright to 2025 in readme for paralell model loading.
inf3rnus May 15, 2025
610c5e3
Remove garbage collection line in load_shard_file, implicit garbage c…
inf3rnus May 15, 2025
a9cb54b
Run formatter on modeling_utils.py
inf3rnus May 15, 2025
fc76fbb
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus May 15, 2025
16f3751
Apply style fixes
github-actions[bot] May 22, 2025
cd0f42e
Merge main.
inf3rnus May 22, 2025
3b9f458
Delete tests/utils/test_modeling_utils_parallel_loading.py
inf3rnus May 22, 2025
b6bf421
Merge branch 'main' into 03-18-25-parallel-model-loading
Cyrilvallez May 23, 2025
5 changes: 5 additions & 0 deletions docs/source/en/_toctree.yml
@@ -1057,4 +1057,9 @@
- local: internal/time_series_utils
title: Utilities for Time Series
title: Internal helpers
- sections:
- local: reference/environment_variables
title: Environment Variables
title: Reference
title: API

66 changes: 66 additions & 0 deletions docs/source/en/reference/environment_variables.md
@@ -0,0 +1,66 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Environment Variables

## HF_ENABLE_PARALLEL_LOADING

Disabled by default. When enabled, torch- and safetensors-based weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speed-ups of more than 50%.

Set it to the string `"true"` or `"false"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.

For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~20s with this enabled vs. ~45s without it.

Profile before committing to this environment variable; it will not produce speed-ups for smaller models.
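
For instance, a minimal timing sketch (it reuses the `facebook/opt-30b` checkpoint from the example above purely for illustration; any checkpoint works, and loading on CPU means the `spawn` note below does not apply). Run it once with the flag set to `"true"` and once with `"false"` and compare:

```py
import os
import time

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"  # set to "false" for the baseline run

from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b")
print(f"Loaded in {time.perf_counter() - start:.1f}s")
```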

NOTE: if you are loading a model onto anything other than the CPU, you must set `multiprocessing` to use the `spawn` start method, like so:

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # NOTE: if the model loads on CPU, this is not required
    multiprocessing.set_start_method("spawn", force=True)

    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```

If loading onto a CUDA device, the code will crash if `multiprocessing.set_start_method("spawn", force=True)` is not called.

## HF_PARALLEL_LOADING_WORKERS

Determines how many child processes are used when parallel loading is enabled. The default is `8`; tune it as you see fit.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # NOTE: if the model loads on CPU, this is not required
    multiprocessing.set_start_method("spawn", force=True)

    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
271 changes: 208 additions & 63 deletions src/transformers/modeling_utils.py
@@ -874,6 +874,128 @@ def _load_state_dict_into_meta_model(
return disk_offload_index, cpu_offload_index


def resolve_state_dict_modules(model_to_load, state_dict, expected_keys):
state_dict_modules = {}

for tensor_name in state_dict.keys():
if tensor_name not in expected_keys:
continue

splits = tensor_name.split(".")
module = model_to_load
for split in splits:
try:
module = getattr(module, split)
except Exception as exception:
print(exception)
pass

state_dict_modules[tensor_name] = module

return state_dict_modules
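
As a rough illustration of what this mapping holds (a standalone toy sketch, assuming the function above is in scope; the `Tiny` module is hypothetical): each dotted tensor name resolves to the actual parameter or buffer object on the model.

```py
from torch import nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

model = Tiny()
state_dict = model.state_dict()          # keys: "linear.weight", "linear.bias"
expected_keys = list(state_dict.keys())

modules = resolve_state_dict_modules(model, state_dict, expected_keys)
print(modules["linear.weight"] is model.linear.weight)  # True: the value is the parameter object itself
```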


# This function is in global scope so it's picklable for multiprocessing
def load_shard_file(args):
(
state_dict,
shard_file,
disk_only_shard_files,
low_cpu_mem_usage,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model_to_load,
ignore_mismatched_sizes,
prefix,
loading_base_model_from_task_state_dict,
expected_keys,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
cpu_offload_folder,
cpu_offload_index,
is_offloaded_safetensors,
keep_in_fp32_modules,
unexpected_keys,
device_mesh,
) = args
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
return [], [], disk_offload_index, cpu_offload_index, {}

map_location = "cpu"
if low_cpu_mem_usage:
if shard_file.endswith(".safetensors") and not is_quantized:
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)

error_msgs = []
mismatched_keys = []

# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
model_to_load,
state_dict,
ignore_mismatched_sizes,
prefix if loading_base_model_from_task_state_dict else "",
)

if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
else:
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)

# We now figure out what changed in the state dict and store the module used for each layer; this holds the device
# information we need to resolve all of the layers after multiprocessing, which we write back to the original model_to_load meta model
state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys)
Member:

We don't need to have this function here (it's very inefficient in case of NOT using multiprocessing, as it's basically useless) -> only the keys of the state_dict are enough information to then do it after the loading in the multiprocessing case
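
A minimal sketch of that suggestion (the helper name is hypothetical, not code from this PR): have the worker return only the keys it loaded, and resolve them to modules in the parent process, and only when multiprocessing was actually used.

```py
def resolve_modules_for_keys(model, keys):
    # Walk each dotted tensor name (e.g. "encoder.layer.0.weight") down to the object it names on `model`
    resolved = {}
    for name in keys:
        obj = model
        for part in name.split("."):
            obj = getattr(obj, part)
        resolved[name] = obj
    return resolved
```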


# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict

return mismatched_keys, error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules


def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
path, name = weights_name.rsplit(".", 1)
@@ -4810,9 +4932,6 @@ def _load_pretrained_model(
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {}

# For nice tqdm bars
if checkpoint_files is not None and len(checkpoint_files) > 1:
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
checkpoint_files = [""]
@@ -4827,73 +4946,99 @@
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map)

error_msgs = []
from multiprocessing import Pool

# Prepare arguments for multiprocessing
args_list = [
(
state_dict,
shard_file,
disk_only_shard_files,
low_cpu_mem_usage,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model_to_load,
ignore_mismatched_sizes,
prefix,
loading_base_model_from_task_state_dict,
expected_keys,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
cpu_offload_folder,
cpu_offload_index,
is_offloaded_safetensors,
keep_in_fp32_modules,
unexpected_keys,
device_mesh,
)
for shard_file in checkpoint_files
]

mismatched_keys = []
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
continue
error_msgs = []

map_location = "cpu"
if low_cpu_mem_usage:
if shard_file.endswith(".safetensors") and not is_quantized:
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
# Use multiprocessing Pool for parallel execution, off by default
if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")):
Member:

This should be guarded against deepspeed here, as deepspeed is the only remaining path where the model is not on meta -> it will lead to exploding the memory as each process copies the model
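
A sketch of the kind of guard being asked for (the exact condition is an assumption, not this PR's code): fall back to the sequential path under DeepSpeed ZeRO-3, where the model is not on the meta device and every worker would otherwise copy it.

```py
parallel_loading_enabled = (
    os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false").lower() in {"1", "true", "yes"}
    and not is_deepspeed_zero3_enabled()  # ZeRO-3 keeps the model off meta, so each worker would copy it
)
```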

num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
Collaborator:

this is nice BUT! I think we need some guards / good defaults:

  • depends on the number of shard files
  • should depend on the number of available threads
This will help us find good sweet spots!

Contributor Author:

Fully agree on point 1, we should just min() that on len(args_list).

Second point, I'd argue, is an enhancement; for now the benefits are so great we can leave it to the user until iteration 2.

(I'll be cooking this up real soon!)

Collaborator:

haha okay 🤗
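
A sketch of the defaults being discussed (the cap is an assumption, not what the PR does at this point): clamp the worker count by the number of shard argument tuples and the CPUs available on the machine.

```py
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
# Cap by the number of shards to load and the available CPUs
num_workers = min(num_workers, len(args_list), os.cpu_count() or 1)
```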

state_dict_modules_list = []

with Pool(processes=num_workers) as pool:
# For nice tqdm bars
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
# NOTE: order does not matter; layers that changed per shard are unique and can be reassigned to the original meta model
for result in pool.imap_unordered(load_shard_file, args_list):
_mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = (
result
)

# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
mismatched_keys += _mismatched_keys
error_msgs += _error_msgs

# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
state_dict_modules_list.append(state_dict_modules)

# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
model_to_load,
state_dict,
ignore_mismatched_sizes,
prefix if loading_base_model_from_task_state_dict else "",
)
pbar.update(1)

if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
else:
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
if is_deepspeed_zero3_enabled():
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params)
else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
# We now update each layer of the meta model with the tensor/module refs that were set to specific devices in each worker's copy of the meta model.
# We are transferring that state into the original ref (model_to_load) here.
# This is required because model_to_load is pickled when using multiprocessing, so each worker holds a different ref to model_to_load and only part of the loaded tensor state.
# You could in theory return each worker's copy of the model and use .named_parameters() and .named_buffers(), but this appears to be more robust:
# all you have to care about are the names of the layers in the state dict; as long as the logic that led to the creation of the state_dict is correct, this will also be correct.
for state_dict_modules in state_dict_modules_list:
for tensor_name in state_dict_modules.keys():
splits = tensor_name.split(".")
module = model_to_load

for split in splits[:-1]:
module = getattr(module, split)

last_key = splits.pop()
Collaborator:

we have a util for that, see the tensor_parallel integration that also needs to get the module from the key name 😉

Contributor Author:

Appreciate it! Will make the adjustment :)
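
A minimal equivalent of the kind of util being referenced (the name and signature here are assumptions; see the tensor_parallel integration for the actual helper): split off the attribute name and let `nn.Module.get_submodule` do the walking.

```py
def get_module_and_attr(model, tensor_name):
    # "decoder.layers.0.self_attn.k_proj.weight" -> (the k_proj module, "weight")
    if "." in tensor_name:
        module_path, attr_name = tensor_name.rsplit(".", 1)
        return model.get_submodule(module_path), attr_name
    return model, tensor_name
```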


# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
tensor_ref = state_dict_modules[tensor_name]

setattr(module, last_key, tensor_ref)

del state_dict_modules_list
gc.collect()
else:
if len(args_list) > 1:
# For nice tqdm bars
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

for args in args_list:
_mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = (
load_shard_file(args)
)

mismatched_keys += _mismatched_keys
error_msgs += _error_msgs

del state_dict_modules
gc.collect()

# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0: