
Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag #36835


Merged

34 commits merged into huggingface:main on May 23, 2025

Conversation


@inf3rnus inf3rnus commented Mar 19, 2025

What does this PR do?

modeling_utils.py has been modified to allow for the parallelized loading of weights. Two env vars have been introduced: HF_ENABLE_PARALLEL_LOADING, which allows weights to be loaded in a multiprocessing pool, and HF_PARALLEL_LOADING_WORKERS, which specifies how many child processes you want to spawn to handle loading.

This only works for sharded torch-based weights and safetensors for now, although the same pattern can be extended to other weight formats by other contributors.

This can significantly speed up loading large models on big hardware, by roughly 50%.

e.g., facebook/opt-30b on an AWS EC2 g4dn.metal can be made to load in ~30s with this modification vs. ~55s without it.

By napkin math, this efficiency gain should equate to thousands of dollars per month, if not more, in cold-boot compute expenses saved globally for anyone using HF on the cloud.
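For reference, a minimal usage sketch under the assumptions of this PR (the env vars are the ones introduced here; the model name and worker count are just examples):

    import os

    # Opt in to parallelized shard loading before transformers starts loading.
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
    os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"

    from transformers import AutoModelForCausalLM

    # Checkpoint shards are now loaded by a pool of workers instead of sequentially.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b")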

Please let me know if anything else needs to be changed to move this forward!

Before submitting

@Rocketknight1 @gante

@github-actions github-actions bot marked this pull request as draft March 19, 2025 18:41
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@inf3rnus inf3rnus marked this pull request as ready for review March 19, 2025 18:44

@ArthurZucker ArthurZucker left a comment


Very interesting!
Indeed, we do this "sequentially" today. I think using an even bigger checkpoint, like a 100B or 400B model, would be nice.

Great splitting, having a separate function for load_shard_file!

  • missing some API simplification:
    multiprocessing.set_start_method("spawn", force=True) should be done by ourselves (because you should have access to the device you are loading on!)

Also needs to be guarded against TP (until stable / tested!)

Comment on lines 4985 to 4987
if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")):
num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
Collaborator

This is nice, BUT! I think we need some guards / good defaults:

  • depends on the number of shard files
  • should depend on the number of available threads
    This will help us find good sweet spots!

Contributor Author

Fully agree on point 1, we should just min() that with len(args_list).

The second point, I'd argue, is an enhancement; for now the benefits are great enough that we can leave it to the user until iteration 2.

(I'll be cooking this up real soon!)
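A rough sketch of that guard, assuming num_workers and args_list are the names discussed in this thread:

    import os

    # Cap the requested worker count by the number of shard files and available CPUs.
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
    num_workers = min(num_workers, len(args_list), os.cpu_count() or 1)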

Collaborator

haha okay 🤗

Comment on lines 5013 to 5019
splits = tensor_name.split(".")
module = model_to_load

for split in splits[:-1]:
module = getattr(module, split)

last_key = splits.pop()
Collaborator

We have a util for that; see the tensor_parallel integration, which also needs to get the module from the key name 😉
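For illustration, a sketch of what that could look like with the existing helper (assuming get_module_from_name from transformers.quantizers.quantizers_utils, the util used by the tensor-parallel path; value stands in for the loaded tensor):

    import torch
    from transformers.quantizers.quantizers_utils import get_module_from_name

    # Resolves e.g. "decoder.layers.0.fc1.weight" -> (module, "weight") in one call,
    # replacing the manual tensor_name.split(".") walk above.
    module, param_type = get_module_from_name(model_to_load, tensor_name)
    # Assign the loaded tensor the same way the quantizer code paths do.
    module._parameters[param_type] = torch.nn.Parameter(value, requires_grad=value.requires_grad)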

Contributor Author

Appreciate it! Will make the adjustment :)

Collaborator

to improve!

Member

I think we can remove this file! I'll see to adding some tests myself once this has been merged!

@ArthurZucker ArthurZucker added the Core: Modeling (Internals of the library; Models) and from_pretrained labels Mar 20, 2025

inf3rnus commented Mar 20, 2025

Also, quick Q, is it normal for some of the CI tests to fail with the 4.50.0 branch, or is my code running amok?

I checked the outputs and I don't think it's my code.

Also, I couldn't find the place where the model_utils.py tests run; I assumed anything with the test prefix would run...

Locally, the tests for model_utils (including the parallel tests) are passing, so I'm inclined to believe this code is stable, although it could certainly benefit from additional tests.

It's set up to run optionally until we suss out any remaining edge cases.

Also curious to hear your thoughts on how to improve the parallel tests; I can dream up some ideas, but I don't want to do anything that would result in waste.


inf3rnus commented Mar 20, 2025

No longer needed, I've encapsulated this into the model loading parallelism if block

Any ideas on where we could put this?

missing some API simplification:
multiprocessing.set_start_method("spawn", force=True) should be done by ourselves (because you should have access to the device you are loading on!)

I could run some experiments, but if you know of a good location where env vars are reduced into global configuration settings (some place where a bunch of package-wide initialization occurs), that's probably where we want to do it.

@inf3rnus
Contributor Author

@ArthurZucker Alright, I've made the requested changes, pls lmk if anything else needs touching up!

Notable changes based on the feedback provided:

  • The multiprocessing start method (spawn) is now handled in a try/finally inside the if block that controls model loading parallelism (rough sketch below).
  • I've incorporated the use of model.get_submodule() for updating the meta model after process workers load weights.
  • Docs have been updated to remove mention of the spawn method for multiprocessing now that it's handled internally.
  • Manually setting the start method to spawn in the model loading parallelism test file has been removed.
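A minimal sketch of that try/finally pattern (illustrative only, not the exact PR code):

    import multiprocessing

    # Force "spawn" for the duration of the parallel load, then restore the
    # previously configured start method (possibly None, i.e. the default).
    original_method = multiprocessing.get_start_method(allow_none=True)
    try:
        multiprocessing.set_start_method("spawn", force=True)
        # ... dispatch shard loading to the worker pool here ...
    finally:
        multiprocessing.set_start_method(original_method, force=True)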

@inf3rnus inf3rnus requested a review from ArthurZucker March 25, 2025 17:57

@ArthurZucker ArthurZucker left a comment


Thanks!

Comment on lines 89 to 93
def is_true(value: Optional[str]) -> bool:
if value is None:
return False
return value.upper() in ENV_VARS_TRUE_VALUES

Collaborator

not a super fan of this, let's just do this:

os.environ.get("MY_VAR", "").upper() in ENV_VARS_TRUE_VALUES



# Declare the normal model_utils.py test as a side effect of importing the module
from .test_modeling_utils import ModelUtilsTest  # noqa
Collaborator

I don't think we need it to run on all models; I'm not even sure this tests it all?


inf3rnus commented May 7, 2025

@Rocketknight1 bumping for visibility 🙏

@inf3rnus
Contributor Author

Hey all, just checking in again, figure you guys are slammed. Very appreciative of the time you've spent reviewing this 🙏

@inf3rnus
Contributor Author

@Cyrilvallez Tagging you again for a bump, it's right at the finish line. Very much appreciate everyone's eyes on this so far


@Cyrilvallez Cyrilvallez left a comment


Alright, sorry for the delay! This is indeed much cleaner, and I agree that using threads is much better since we share the objects by default. This should be thread-safe as the different state dicts are orthogonal to each other, but did you check further by any chance? I.e., in the case of quantization, I did not check that we are not mutating other objects in ways that could run into race conditions - did you make sure?

Also, interested to see how much faster this is, given the GIL 🤔 Not entirely sure if the new threads would trigger new GPU streams, and if those would run concurrently as well I must say, so hard to tell - did you benchmark a bit? 😁🤗

@Cyrilvallez
Member

Oh also, please run make style with ruff==0.11.2 to pass code quality!


inf3rnus commented May 15, 2025

@Cyrilvallez

Alright all changes requested have been made!

Alright, sorry for the delay! This is indeed much cleaner, and I agree that using threads is much better since we share the objects by default. This should be thread-safe as the different state dicts are orthogonal to each other, but did you check further by any chance? I.e., in the case of quantization, I did not check that we are not mutating other objects in ways that could run into race conditions - did you make sure?

I just did a check on each of the quantizers, looking at _load_state_dict_into_meta_model and create_quantized_param for each quantizer method, and we should be thread-safe as long as each shard is indeed responsible for a unique set of layers. The only doubt in my mind there is weight tying...

You'd know best, but if we can guarantee that weight tying only happens to the layers in a given shard, then we should be safe. Otherwise, yeah, we're going to need to introduce locks.

And it's really only one of the quantizers where I can see this happening; it may be happening in the others, but if so, I'm not aware of it.

Quantizer of concern: quantizer_torchao

I don't think this should be a problem though, because after looking at the code, if I'm not mistaken, models always load their input and output embeddings / encoder and decoder weights before loading the rest of the weights? It also only ties weights for the shard containing the input embeddings module.

Outside of weight tying, though, I don't believe there are any overlapping assignment operations on shared state: each quantization method only sets new weight values per module on a given layer via the state_dict that is constructed per thread, so each thread modifies its own set of layers of the model's overall state in isolation. They also all do this synchronously, so there should be no unexpected changes to a given module's state due to concurrent access.

quantizer_torchao's create_quantized_param is below for reference (other comments continue after this code block, FYI):

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: List[str],
    ):
        """
        Each nn.Linear layer that needs to be quantized is processed here.
        First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
        """
        if self.quantization_config.quant_type == "autoquant":
            return

        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(
                param_value.to(device=target_device), requires_grad=param_value.requires_grad
            )
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            assert isinstance(self.quantization_config, TorchAoConfig)
            module._parameters[tensor_name] = torch.nn.Parameter(
                param_value, requires_grad=param_value.requires_grad
            ).to(device=target_device)
            # if we are quantizing tied parameters, to avoid tying the quantized weights
            # the correct order to do it is
            # 1. load the weight to model
            # 2. run tie_weights to populate the weights
            # 3. quantize
            input_embed = model.get_input_embeddings()
            if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
                model.tie_weights()
                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)

            # handle AOPerModuleConfig, introduced in torchao 0.11.0+
            if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
                from torchao.quantization import AOPerModuleConfig

                config = self.quantization_config.get_apply_tensor_subclass()
                if isinstance(config, AOPerModuleConfig):
                    module_fqn, _ = param_name.rsplit(".", 1)
                    c = None
                    if module_fqn in config.module_fqn_to_config:
                        c = config.module_fqn_to_config[module_fqn]
                    else:
                        c = config.module_fqn_to_config.get("_default", None)
                    if c is not None:
                        # filter_fn: not filtering out any modules
                        quantize_(module, c, filter_fn=lambda x, fqn: True)
                    return

            quantize_(module, self.quantization_config.get_apply_tensor_subclass())

Also, interested to see how much faster this is, given the GIL 🤔 Not entirely sure if the new threads would trigger new GPU streams, and if those would run concurrently as well I must say, so hard to tell - did you benchmark a bit? 😁🤗

Surprisingly (and maybe not so surprisingly), the performance is even better. I did a run on facebook/opt-30b again and you get about the same performance in terms of the time it takes to load the files. However, you see a 5-7s speedup from not having to clone processes, so that's up to ~7s saved!

Presumably what's happening is that a lot of this work is I/O bound; the only major part AFAIK that isn't is the deserialization of the weights into system RAM. So while only one thread may be deserializing weights at a time, the rest are free to keep doing I/O-bound work while that one is busy.

If the disk can provide the throughput and you have vCPUs at your disposal, you will see speedups!

The same can be said about moving weights from RAM onto the GPU; that's all likely handled by system calls and driver code divorced from any work that actually needs to happen within the Python process itself.

That's my theory at least; if we had infinite time, we could certainly drill into it 😄...
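A minimal sketch of the thread-based loading loop, assuming load_shard_file, args_list, and num_workers are the per-shard helper, argument list, and worker count discussed earlier in this thread:

    from concurrent.futures import ThreadPoolExecutor

    # Much of the deserialization and host-to-device copy work releases the GIL,
    # so threads can overlap the I/O-bound portions of each shard.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(load_shard_file, args_list))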

Previous results were around 30s to load facebook/opt-30b.

New results are around ~24s, although this is from memory as I did not record those numbers at the time. All I know is it was faster than the previous time of 30s by about 6 seconds.

@inf3rnus
Contributor Author

@Cyrilvallez @ArthurZucker Pinging again, I appreciate your patience 🙏

@Cyrilvallez
Member

@bot /style

Contributor

Style fixes have been applied. View the workflow run here.


@Cyrilvallez Cyrilvallez left a comment


Hey! Sorry, things were a bit crazy lately 😬 I was able to run make style for you thanks to our new bot; however, could you rebase/merge very quickly to take care of the tiny conflict in the imports? 🤗
Also, let's remove the additional test file as I pointed out; I'll see to adding some tests myself after I merge!

Member

I think we can remove this file! I'll see to adding some tests myself once this has been merged!


inf3rnus commented May 22, 2025

@Cyrilvallez All changes have been made! Big time thank you for looking at this. I figure you guys are absolutely slammed right now


Cyrilvallez commented May 23, 2025

Alright, merging! Thanks for the great contribution and for bearing with us 🙏🤗

@Cyrilvallez Cyrilvallez enabled auto-merge (squash) May 23, 2025 16:27
@Cyrilvallez Cyrilvallez merged commit d5f992f into huggingface:main May 23, 2025
20 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@inf3rnus
Contributor Author

Thanks folks, viva la Huggingface! 🤗

redmoe-moutain pushed a commit to redmoe-moutain/transformers that referenced this pull request Jun 10, 2025
Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (huggingface#36835)

* Get parallel loader working. Include tests.

* Update the tests for parallel loading

* Rename env variables.

* Add docs for parallel model weight loading.

* Touch up parallel model loading docs.

* Touch up parallel model loading docs again.

* Edit comment in test_modeling_utils_parallel_loading.py

* Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py

* Correct times for parallelized loading, previous times were for a "hot" filesystem

* Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule.

* Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally.

* Fix style on model loading parallelism changes.

* Merge latest version of master's modeling_utils.

* Removed unused variable.

* Fix argument packing for the parallel loader.

* Fix state dict being undefined in the parallel model loader.

* Rename variables used in parallel model loading for clarity. Use get_module_from_name().

* Switch to the use of threads for parallel model loading.

* Update docs for parallel loading.

* Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting.

* Move parallelized shard loading into its own function.

* Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING.

* Update copyright to 2025 in readme for parallel model loading.

* Remove garbage collection line in load_shard_file, implicit garbage collection already occurs.

* Run formatter on modeling_utils.py

* Apply style fixes

* Delete tests/utils/test_modeling_utils_parallel_loading.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <[email protected]>