Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag #36835


Open
wants to merge 30 commits into base: main

Conversation


@inf3rnus inf3rnus commented Mar 19, 2025

What does this PR do?

modeling_utils.py has been modified to allow for parallelized loading of weights. Two env vars have been introduced: HF_ENABLE_PARALLEL_LOADING, which allows weights to be loaded in a multiprocessing pool, and HF_PARALLEL_LOADING_WORKERS, which specifies how many child processes to spawn for loading.

For now this only works for sharded torch-based weights and safetensors, although the same pattern can be extended to other weight formats by other contributors.

This can significantly speed up loading large models on big hardware, by roughly 50%.

e.g. facebook/opt-30b on an AWS EC2 g4dn.metal can be made to load in ~30s with this modification vs ~55s without it.

By napkin math, this efficiency gain should save thousands of dollars a month, if not more, in cold-boot compute costs globally for anyone using HF in the cloud.

Please let me know if anything else needs to be changed to move this forward!
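For illustration, a minimal usage sketch (the model name and worker count are just examples; the flags only need to be set before from_pretrained is called):

    import os

    # Opt in to parallel shard loading (off by default).
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
    # Optional: number of workers used for loading (defaults to 8).
    os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b")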

Before submitting

@Rocketknight1 @gante

@github-actions github-actions bot marked this pull request as draft March 19, 2025 18:41

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@inf3rnus inf3rnus marked this pull request as ready for review March 19, 2025 18:44

@ArthurZucker ArthurZucker left a comment


Very interesting!
Indeed we do this "sequentially" today. I think using an even bigger checkpoint like a 100B model or 400B would be nice.

Great splitting, having a separate function for load_shard_file!

  • missing some API simplification:
    multiprocessing.set_start_method("spawn", force=True) should be done by ourselves (because you should have access to the device you are loading on!)

Also needs to be guarded against TP (until stable / tested!)

Comment on lines 4985 to 4987
if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")):
num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
Collaborator

this is nice BUT! I think we need some guards / good defaults:

  • depends on the number of shard files
  • should depend on the number of available threads
    This will help us find good sweet spots!

Contributor Author

Fully agree on point 1, we should just min() that on len(args_list).

The second point, I'd argue, is an enhancement; for now the benefits are so great we can leave it to the user until iteration 2.

(I'll be cooking this up real soon!)
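A minimal sketch of the worker-count guard being discussed (hypothetical variable names; args_list is the per-shard argument list built in the PR):

    import os

    # Clamp the requested worker count by the number of shard files and the
    # CPUs actually available, so we never spawn more workers than useful.
    requested = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
    num_workers = min(requested, len(args_list), os.cpu_count() or 1)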

Collaborator

haha okay 🤗

Comment on lines 5013 to 5019
splits = tensor_name.split(".")
module = model_to_load

for split in splits[:-1]:
module = getattr(module, split)

last_key = splits.pop()
Collaborator

we have a util for that, see the tensor_parallel integration that also needs to get the module from the key name 😉

Contributor Author

Appreciate it! Will make the adjustment :)
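A minimal sketch of the key-to-module resolution being discussed, using torch.nn.Module.get_submodule (the approach the author adopts later in this thread; variable names are hypothetical):

    # Resolve a flat state_dict key such as "decoder.layers.0.fc1.weight"
    # into its parent module and attribute name, then assign the loaded tensor.
    key = "decoder.layers.0.fc1.weight"
    module_path, _, attr_name = key.rpartition(".")
    module = model_to_load.get_submodule(module_path)  # e.g. the fc1 nn.Linear
    setattr(module, attr_name, loaded_param)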

Collaborator

to improve!

@ArthurZucker ArthurZucker added the Core: Modeling and from_pretrained labels Mar 20, 2025

inf3rnus commented Mar 20, 2025

Also, quick Q, is it normal for some of the CI tests to fail with the 4.50.0 branch, or is my code running amok?

I checked the outputs and I don't think it's my code.

Also, I couldn't find the place where the modeling_utils.py tests run; I assumed anything with the test prefix would run...

Locally, the tests for modeling_utils (including the parallel tests) are passing, so I'm inclined to believe this code is stable, although it could certainly benefit from additional tests.

It's set up to run optionally until we suss out any remaining edge cases.

Also curious to hear what your thoughts are on how to improve the parallel tests, I can dream up some ideas, but I don't want to do anything that will result in waste.


inf3rnus commented Mar 20, 2025

No longer needed; I've encapsulated this into the model loading parallelism if block.

Any ideas on where we could put this?

missing some API simplification:
multiprocessing.set_start_method("spawn", force=True) should be done by ourselves (because you should have access to the device you are loading on!)

I could run some experiments, but if you know of a good location where env vars are reduced into global configuration settings (some place where a bunch of package-wide initialization occurs), that's probably where we want to do it.

@inf3rnus
Contributor Author

@ArthurZucker Alright, I've made the requested changes, pls lmk if anything else needs touching up!

Notable changes based on the feedback provided:

  • The multiprocessing start method (spawn) is now handled in a try/finally inside the if block that controls model loading parallelism (rough shape sketched after this list).
  • I've incorporated the use of model.get_submodule() for updating the meta model after process workers load weights.
  • Docs have been updated to remove mention of the spawn method for multiprocessing now that it's handled internally.
  • Manually setting the start method to spawn in the model loading parallelism test file has been removed.
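A rough sketch of the shape described in the first bullet (hypothetical; the actual diff may differ, and this multiprocessing path was later replaced by threads):

    import multiprocessing

    prev_method = multiprocessing.get_start_method(allow_none=True)
    try:
        # Spawn so each worker process can safely initialize its own CUDA context.
        multiprocessing.set_start_method("spawn", force=True)
        # ... run the worker pool that loads the shards in parallel ...
    finally:
        if prev_method is not None:
            multiprocessing.set_start_method(prev_method, force=True)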

@inf3rnus inf3rnus requested a review from ArthurZucker March 25, 2025 17:57
@ArthurZucker
Collaborator

Hey! Sorry, I got caught up here and there, will come back to this, but no rush!

@ArthurZucker
Collaborator

We need to test in different settings (cluster, no cluster, etc.) to make sure we are not introducing something that does not work well.


inf3rnus commented Apr 9, 2025

Np! More than appreciative you're assisting me with this :)

Just give me a list of todo items and I'll execute, dying to merge this puppy!

@inf3rnus
Contributor Author

Hey Arthur, just bumping... Dying to make the changes needed to get this over the line


@Cyrilvallez Cyrilvallez left a comment


Hey @inf3rnus! Nice work, and super sorry for the wait on this one! You can see my review comments attached! 🤗

Basically, we mainly need to be careful about sharing the model with each child process. Since the recent changes, models are always loaded on meta (except with deepspeed when non-quantized, see comments), so it won't result in too much memory overhead, but the best would still be to be able to share the model between processes if possible. Most notably because the buffers are not on meta, so a model with some large buffers could see its memory requirements surge a lot with this (both on GPU devices and CPU, depending on the situation).

Also, could you make sure to rebase on latest main, and make sure that the functions you copy/pasted to be external functions are still similar to what is being done on main (because it won't be flagged as a merge conflict here) 🤗

Comment on lines 4925 to 4936

# We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker
# We are transferring that state into the original ref (model_to_load) here
# This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors
# You could in theory return each worker's copy of the model and use .named_parameters() and .named_buffers(), but this appears to be more robust
# in that all you have to care about are the names of the layers in the state dict; as long as the logic that led to the creation of the state_dict is correct, this will also be correct
for state_dict_modules in state_dict_modules_list:
for full_name, param in state_dict_modules.items():
*module_path, attr_name = full_name.split(".")
module_path = ".".join(module_path)
module = model_to_load.get_submodule(module_path)
setattr(module, attr_name, param)
Member

Any way we could instead correctly share the model between processes here?

Comment on lines 913 to 914
# information we need in order to resolve all of the layers after multiprocessing which we write back to the original model_to_load meta model
state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys)
Member

We don't need to have this function here (it's very inefficient in the case of NOT using multiprocessing, as it's basically useless) -> only the keys of the state_dict are needed to do this after loading in the multiprocessing case

Comment on lines 4892 to 4893
# Use multiprocessing Pool for parallel execution, off by default
if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")):
Member

This should be guarded against deepspeed here, as deepspeed is the only remaining path where the model is not on meta -> it will lead to exploding the memory as each process copies the model


Cyrilvallez commented Apr 25, 2025

Also, something very concerning to me is that in the current implementation, if we have more state_dicts than num_workers, the second time a given worker starts processing a state_dict, the model it copies already has some params loaded, which will blow up the memory very quickly (memory usage will be multiplied by a bit more than num_workers), no? As we use the reference to a single model_to_load for all of arg_list. Did you try on a different checkpoint than facebook/opt-30b? It has exactly 8 state_dicts, so the issue won't be visible for it.

A simple workaround could be to deepcopy the model when instantiating the arg_list though

@inf3rnus
Contributor Author

@Cyrilvallez Thanks Cyril, really appreciate your eyes on this 🙏

Moving through your feedback today, hope to have all concerns addressed by EOTD

@inf3rnus
Contributor Author

@Cyrilvallez

I've made your requested changes. I was able to greatly simplify things by just using threads to do the parallel loading, which was not possible when I created this PR.

All tests pass for modeling_utils.py locally except for one (test_safetensors_torch_from_flax), which also fails with the latest commit on main, so I believe that's something the HF team will need to handle.

Error for that is:

ERROR: test_safetensors_torch_from_flax (tests.utils.test_modeling_utils.ModelUtilsTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/inf3rnus/Development/transformers/tests/utils/test_modeling_utils.py", line 1581, in test_safetensors_torch_from_flax
    self.assertTrue(torch.equal(p1, p2))
NotImplementedError: aten::equal: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

You left this comment:

Also, something very concerning to me is that in the current implementation, if we have more state_dicts than num_workers, the second time a given worker starts processing a state_dict, the model it copies already has some params loaded, which will blow up the memory very quickly (memory usage will be multiplied by a bit more than num_workers), no? As we use the reference to a single model_to_load for all of arg_list. Did you try on a different checkpoint than facebook/opt-30b? It has exactly 8 state_dicts, so the issue won't be visible for it.

A simple workaround could be to deepcopy the model when instantiating the arg_list though

I believe this is solved by using threads now, but please lmk if there's something I'm missing.

Please let me know if there is anything else you need from me, would love to get this merged asap! 🙏

@inf3rnus
Contributor Author

Oh yeah, one other thing, I went to run make style but it reformatted a bunch of files, so it's going to fail the format check.


@ArthurZucker ArthurZucker left a comment


Thanks for your patience! It's much better / simpler than before, let's get this merged! 🤗

):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
# Use multiprocessing Pool for parallel execution, off by default
if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled():
Collaborator

Do we have to use json.loads here? A simple cast should work!

Comment on lines 5038 to 5051
with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
for future in as_completed(futures):
result = future.result()
(
_error_msgs,
disk_offload_index,
cpu_offload_index,
) = result

# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
error_msgs += _error_msgs

pbar.update(1)
Collaborator

can we put that in an external function?
Just to not make from_pretrained bigger than it already is!
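A rough sketch of the kind of helper being asked for (hypothetical name; load_shard_file, args_list, and the library's logging module are the PR's own objects and are assumed to be in scope):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def _load_shard_files_with_threads(args_list, num_workers):
        # Run load_shard_file over every shard in a thread pool and merge the
        # per-shard error messages, keeping from_pretrained itself small.
        error_msgs = []
        disk_offload_index, cpu_offload_index = None, None
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
                futures = [executor.submit(load_shard_file, arg) for arg in args_list]
                for future in as_completed(futures):
                    _error_msgs, disk_offload_index, cpu_offload_index = future.result()
                    error_msgs += _error_msgs
                    pbar.update(1)
        return error_msgs, disk_offload_index, cpu_offload_index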


@ArthurZucker ArthurZucker left a comment


Thanks!

Comment on lines 89 to 93
def is_true(value: Optional[str]) -> bool:
if value is None:
return False
return value.upper() in ENV_VARS_TRUE_VALUES

Collaborator

not a super fan of this, let's just do this:

os.environ.get("MY_VAR", "").upper() in ENV_VARS_TRUE_VALUES



# Declare the normal model_utils.py test as a side effect of importing the module
from .test_modeling_utils import ModelUtilsTest # noqa
Collaborator

I don't think we need it to run for all models, and I'm not even sure this tests it all?


inf3rnus commented May 7, 2025

@Rocketknight1 bumping for visibility 🙏

@inf3rnus
Contributor Author

Hey all, just checking in again, figure you guys are slammed. Very appreciative of the time you've spent reviewing this 🙏

@inf3rnus
Contributor Author

@Cyrilvallez Tagging you again for a bump, it's right at the finish line. Very much appreciate everyone's eyes on this so far


@Cyrilvallez Cyrilvallez left a comment


Alright, sorry for the delay! This is indeed much cleaner, and I agree that using threads is much better as we share the objects by default. This should be thread safe as the different state dicts are orthogonal to each other, but did you check further by any chance? I.e. in the case of quantization, I did not check that we are not mutating other objects in ways that could go through race conditions, did you make sure?

Also, interested to see how much faster this is, given the GIL 🤔 Not entirely sure if the new thread would trigger new GPU streams, and if those would be run concurrently as well I must say, so hard to tell - did you benchmark a bit? 😁🤗

@Cyrilvallez
Member

Oh also, please run make style with ruff==0.11.2 to pass code quality!


inf3rnus commented May 15, 2025

@Cyrilvallez

Alright all changes requested have been made!

Alright, sorry for the delay! This is indeed much cleaner, and I agree that using threads is much better as we share the objects by default. This should be thread safe as the different state dicts are orthogonal to each other, but did you check further by any chance? I.e. in the case of quantization, I did not check that we are not mutating other objects in ways that could go through race conditions, did you make sure?

I just did a check on each of the quantizers, taking a look at _load_state_dict_into_meta_model and create_quantized_param for each quantizer method, and we should be thread safe as long as each shard is indeed responsible for a unique set of layers. The only doubt in my mind there is weight tying...

You'd know best, but if we can guarantee that tying of weights only happens to the layers in a given shard, then we should be safe. Otherwise yeah, we're going to need to introduce locks.
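For reference, a minimal sketch of the kind of lock that could guard the weight-tying path if it turned out not to be shard-local (purely hypothetical, not part of this PR):

    import threading

    _tie_weights_lock = threading.Lock()

    def tie_weights_safely(model):
        # Serialize weight tying across loader threads so two shards can never
        # tie/untie the embeddings concurrently.
        with _tie_weights_lock:
            model.tie_weights()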

And it's really only one of the quantizers that I can see this happening in; it may be happening in the others, but if it is, I'm not aware of it.

Quantizer of concern: quantizer_torchao

I don't think this should be a problem though, because after looking at the code, if I'm not mistaken, models always have their input and output embeddings / encoder and decoder weights loaded before the rest of the weights? It also only ties weights for the shard containing the module for the input embeddings.

But outside of tying weights, I don't believe there are any intersections in terms of assignment operations with respect to state: each quantization method only sets new weight values per module on a given layer via the state_dict that is constructed per thread, so each thread's set of layers in the model's overall state should be modified in isolation. They also all do this synchronously, so there should be no unexpected changes to a given module's state due to concurrent access.

quantizer_torchao below for reference (other comments continue after this code block, fyi):

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: List[str],
    ):
        """
        Each nn.Linear layer that needs to be quantized is processed here.
        First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
        """
        if self.quantization_config.quant_type == "autoquant":
            return

        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(
                param_value.to(device=target_device), requires_grad=param_value.requires_grad
            )
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            assert isinstance(self.quantization_config, TorchAoConfig)
            module._parameters[tensor_name] = torch.nn.Parameter(
                param_value, requires_grad=param_value.requires_grad
            ).to(device=target_device)
            # if we are quantizing tied parameters, to avoid tying the quantized weights
            # the correct order to do it is
            # 1. load the weight to model
            # 2. run tie_weights to populate the weights
            # 3. quantize
            input_embed = model.get_input_embeddings()
            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
                model.tie_weights()
                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)

            # handle AOPerModuleConfig, introduced in torchao 0.11.0+
            if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
                from torchao.quantization import AOPerModuleConfig

                config = self.quantization_config.get_apply_tensor_subclass()
                if isinstance(config, AOPerModuleConfig):
                    module_fqn, _ = param_name.rsplit(".", 1)
                    c = None
                    if module_fqn in config.module_fqn_to_config:
                        c = config.module_fqn_to_config[module_fqn]
                    else:
                        c = config.module_fqn_to_config.get("_default", None)
                    if c is not None:
                        # filter_fn: not filtering out any modules
                        quantize_(module, c, filter_fn=lambda x, fqn: True)
                    return

            quantize_(module, self.quantization_config.get_apply_tensor_subclass())

Also, interested to see how much faster this is, given the GIL 🤔 Not entirely sure if the new thread would trigger new GPU streams, and if those would be run concurrently as well I must say, so hard to tell - did you benchmark a bit? 😁🤗

Surprisingly (or maybe not so surprisingly), the performance is even better. I did a run on facebook/opt-30b again and you get about the same performance in terms of the time it takes to load the files. However, you see a 5-7s speed up from not having to clone processes, so that's up to 7s saved!

Presumably what's happening is that a lot of this work is IO-bound; the only major part AFAIK that isn't is the deserialization of the weights into system RAM. So while only one thread may deserialize weights at a time, the rest are free to do work while that one sits on the main thread.

If the disk can provide the throughput and you have vCPUs at your disposal, you will see speed up benefits!

The same can be said about offloading weights from RAM onto the GPU; that's all likely handled by system calls and driver code that's divorced from any work that needs to actually happen within the Python process itself.

That's my theory at least; if we had infinite time, we could certainly drill into it 😄...

Prev results were around 30s to load facebook/opt-30b

New results are around ~24s, although this is from memory as I did not write those numbers down at the time. All I know is it was faster than the previous time of 30s by about 6 seconds.

@inf3rnus
Contributor Author

@Cyrilvallez @ArthurZucker Pinging again, I appreciate your patience 🙏
