File tree Expand file tree Collapse file tree 4 files changed +17
-2
lines changed
Expand file tree Collapse file tree 4 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
88## [ Unreleased]
99
10+ ## [ 0.11.23] - 2025-08-18
11+
1012### Added
1113
1214- Add support for sharded loading from Safetensors checkpoints onto a JAX mesh.
Original file line number Diff line number Diff line change @@ -94,7 +94,16 @@ def _construct_maximal_sharding(
9494
9595
9696def construct_maximal_shardings (abstract_state : PyTree ) -> PyTree :
97- """Construct a sharding that partitions each array as much as possible."""
97+ """Construct a sharding that partitions each array as much as possible.
98+
99+ This method is subject to change and should not be considered stable.
100+
101+ Args:
102+ abstract_state: PyTree of jax.ShapeDtypeStruct.
103+
104+ Returns:
105+ PyTree of jax.sharding.Sharding.
106+ """
98107 shardings = jax .tree .map (_construct_maximal_sharding , abstract_state )
99108
100109 total_size = 0
Original file line number Diff line number Diff line change 2222 Index ,
2323 Shape ,
2424)
25+
26+ from orbax .checkpoint ._src .arrays .sharding import (
27+ construct_maximal_shardings ,
28+ )
Original file line number Diff line number Diff line change 1717# A new PyPI release will be pushed everytime `__version__` is increased.
1818# Also modify version and date in CHANGELOG.
1919# LINT.IfChange
20- __version__ = '0.11.22 '
20+ __version__ = '0.11.23 '
2121# LINT.ThenChange(//depot//orbax/checkpoint/CHANGELOG.md)
2222
2323
You can’t perform that action at this time.
0 commit comments