A library for incremental loading of large PyTorch checkpoints
Read a blogpost introduction by yours truly
import torch
import zipslicer
# Could be a private custom recurrent sentient transformer
# instead of a garden variety resnet
my_complicated_network = torch.hub.load(
    "pytorch/vision:v0.10.0", "resnet18", pretrained=True
)
s_dict = my_complicated_network.state_dict()
torch.save(s_dict, "my_network_checkpoint_v123.pth")
del my_complicated_network
# Later, on a smaller unrelated machine you load a "LazyStateDict"
# Which is just like a regular state dict, but it loads tensors only when it has to
lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
layer3_tensors = {}
for k in lazy_s_dict.keys():
    if k.startswith("layer3"):
        layer3_tensors[k] = lazy_s_dict[k]
# Now you have layer3's tensors and you can analyze them without breaking your RAM.
# Or you can instantiate the layers' classes in sequence and compute the whole
# network's output for a given input by threading the activations through them.
# But we will just print the tensors instead:
print(layer3_tensors)
Run this example and unit-tests:
python examples/example_resnet18.py
pytest -o log_cli=true --capture=tee-sys -p no:asyncio
Test your checkpoint for compatibility:
python tests/test_checkpoint_readonly.py your_magnificent_checkpoint.pth
If it's all green, it will work.
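As a lighter first look before running the test above, the standard library alone can confirm that a file is a ZIP archive whose entries are stored uncompressed, which is the layout recent torch.save versions produce by default. This is only a sketch: describe_checkpoint_archive is a hypothetical helper and not part of zipslicer, and the demo at the bottom builds a stand-in archive with made-up entry names so that it runs without PyTorch installed. It does not replace the compatibility test.

```python
import os
import tempfile
import zipfile


def describe_checkpoint_archive(path):
    """Return (is_zip, all_stored, entry_names) for a checkpoint file.

    Hypothetical helper: it reads only the ZIP directory, never tensor data.
    """
    if not zipfile.is_zipfile(path):
        # Legacy (pre-ZIP) torch.save output, or some other format entirely.
        return False, False, []
    with zipfile.ZipFile(path) as archive:
        infos = archive.infolist()
    # ZIP_STORED means the entry bytes sit in the file uncompressed.
    all_stored = all(info.compress_type == zipfile.ZIP_STORED for info in infos)
    return True, all_stored, [info.filename for info in infos]


# Demo on a stand-in archive; entry names here are illustrative only.
# In practice, pass the path of a real .pth file instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "stand_in_checkpoint.pth")
    with zipfile.ZipFile(demo_path, "w", compression=zipfile.ZIP_STORED) as archive:
        archive.writestr("archive/data.pkl", b"placeholder")
        archive.writestr("archive/data/0", b"\x00" * 16)
    is_zip, all_stored, names = describe_checkpoint_archive(demo_path)

print(is_zip, all_stored, names)
```

If is_zip comes back False, the file was probably written in the legacy non-ZIP format or is not a torch.save checkpoint at all.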
- Supported python and torch versions:
  - python-3.10 + torch-(1.11,1.12,stable)
  - python-3.11 + torch:stable
- Generally, zipslicer should work with a modern enough install of PyTorch - use the included safe test to check for compatibility of zipslicer with your PyTorch and your checkpoint. This is a pure Python library, so specific CPU architecture shouldn't matter.
- A checkpoint produced by saving your model's state_dict via vanilla torch.save(...) - default settings should suffice, as Torch doesn't use ZIP compression.
- An application that can take advantage of an incrementally-loaded checkpoint - i.e. if your app just loads all state_dict.items() in a loop right away it doesn't make much sense to use this library. Make sure your code reads state_dict.keys() (and state_dict.get_meta(k) if necessary) and uses these intelligently to work on a subset of state_dict[k] tensors at a time. For general inspiration you might read this (HF) and this (arxiv). With some additional engineering it should be possible to run Large Language Models like BLOOM-176B or FLAN-T5-XXL on a single mid-range GPU at home - if you are willing to wait for a night's worth of time. In the large batch regime this might even make some practical sense, for example to process a set of documents into embeddings.
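The "subset at a time" pattern from the last item can be sketched roughly as follows. It relies only on keys() and item access, which is what the text above says the lazy state dict offers. group_keys_by_prefix and process_incrementally are hypothetical helpers, not part of zipslicer, and the demo uses a plain dict of lists as a stand-in for real tensors. get_meta(k) is left out because its return format is not described here. Whether memory is actually released between groups depends on how the lazy state dict holds on to loaded tensors, which this sketch does not verify.

```python
from collections import defaultdict


def group_keys_by_prefix(keys, depth=1):
    """Group dotted parameter names by their first `depth` components."""
    groups = defaultdict(list)
    for key in keys:
        prefix = ".".join(key.split(".")[:depth])
        groups[prefix].append(key)
    return dict(groups)


def process_incrementally(state_dict, handle_group, depth=1):
    """Call handle_group(prefix, tensors) for one key group at a time.

    Values are fetched from state_dict only for the group being handled,
    and the reference to them is dropped before the next group is read.
    """
    results = {}
    for prefix, keys in group_keys_by_prefix(state_dict.keys(), depth).items():
        tensors = {k: state_dict[k] for k in keys}  # the only place values are read
        results[prefix] = handle_group(prefix, tensors)
        del tensors  # drop our reference before moving to the next group
    return results


# Stand-in for a lazily loaded state dict: a plain dict with lists as "tensors".
fake_state_dict = {
    "conv1.weight": [1.0, 2.0],
    "layer1.0.weight": [3.0],
    "layer1.0.bias": [4.0, 5.0],
    "fc.weight": [6.0],
}

# Count elements per top-level module, one module at a time.
element_counts = process_incrementally(
    fake_state_dict,
    lambda prefix, tensors: sum(len(v) for v in tensors.values()),
)
print(element_counts)
```

With a real checkpoint, handle_group would be the place to instantiate one layer, load its tensors, run it on the current activations, and return only the small result that needs to be kept.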
Generally, copying the zipslicer/zipslicer directory into your project's source tree is enough.
If you are a fan of official ceremony-driven install processes for executable modules of dubious provenance, soon there will be a possibility of installing this boutique software module via pip: pip install zipslicer
- This library is only for reading pytorch tensors from checkpoints. We leave writing for future work.
- Writing to a loaded state_dict is frowned upon, but it will work - though for now you should avoid doing this while iterating over keys and expecting the keys to reflect the update.
- Perhaps more importantly, general-purpose pickles are not supported - the design of this library doesn't allow you to load whole neural network class instances. Usually this isn't necessary, and the official PyTorch documentation recommends using state_dict for model serialization. We support state_dicts.
- Some rare tensor types (e.g. PyTorch quantized tensors - not to be confused with integer tensors, which work fine) are not yet supported. If this bothers you, share your experience in issues.
- We say "Hi" to the HF safetensors project, but note that in comparison to theirs, our approach doesn't require checkpoint conversion, which takes significant time and storage. In fact, both approaches could be complementary, as you will have to load tensors from the pytorch checkpoint somehow to convert it to safetensors - and the default loading mechanism is constrained by available RAM.
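Regarding the note above about writing while iterating over keys: one cautious pattern is to take a snapshot of the keys first and loop over that snapshot, so the loop never depends on the key view reflecting a write. The sketch below assumes only keys(), item access and item assignment. rewrite_selected is a hypothetical helper, not part of zipslicer, and the demo runs on a plain dict of floats standing in for a loaded state dict.

```python
def rewrite_selected(state_dict, should_rewrite, transform):
    """Overwrite selected entries, iterating over a fixed snapshot of keys."""
    keys_snapshot = list(state_dict.keys())  # fixed before any write happens
    rewritten = []
    for key in keys_snapshot:
        if should_rewrite(key):
            state_dict[key] = transform(state_dict[key])
            rewritten.append(key)
    return rewritten


# Stand-in for a loaded state dict: plain floats instead of tensors.
demo_state = {"layer1.weight": 1.0, "layer1.bias": 2.0, "fc.weight": 3.0}

# Zero out every bias entry.
changed = rewrite_selected(
    demo_state,
    should_rewrite=lambda key: key.endswith(".bias"),
    transform=lambda value: 0.0,
)
print(changed, demo_state)
```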
If you are interested in some of these features, consider creating an issue:
- Effective loading of tensor slices - to implement tensor parallelism in sharded deployments
- Accessing the source checkpoint over a network
- Writing to a checkpoint in-place
- Incremental conversion to other checkpoint formats