Skip to content

Commit bd9b156

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Release new version. Add construct_maximal_shardings to public interface.
PiperOrigin-RevId: 796546347
1 parent fe44336 commit bd9b156

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.23] - 2025-08-18
11+
1012
### Added
1113

1214
- Add support for sharded loading from Safetensors checkpoints onto a JAX mesh.

checkpoint/orbax/checkpoint/_src/arrays/sharding.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,16 @@ def _construct_maximal_sharding(
9494

9595

9696
def construct_maximal_shardings(abstract_state: PyTree) -> PyTree:
97-
"""Construct a sharding that partitions each array as much as possible."""
97+
"""Construct a sharding that partitions each array as much as possible.
98+
99+
This method is subject to change and should not be considered stable.
100+
101+
Args:
102+
abstract_state: PyTree of jax.ShapeDtypeStruct.
103+
104+
Returns:
105+
PyTree of jax.sharding.Sharding.
106+
"""
98107
shardings = jax.tree.map(_construct_maximal_sharding, abstract_state)
99108

100109
total_size = 0

checkpoint/orbax/checkpoint/arrays.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@
2222
Index,
2323
Shape,
2424
)
25+
26+
from orbax.checkpoint._src.arrays.sharding import (
27+
construct_maximal_shardings,
28+
)

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# A new PyPI release will be pushed everytime `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
1919
# LINT.IfChange
20-
__version__ = '0.11.22'
20+
__version__ = '0.11.23'
2121
# LINT.ThenChange(//depot//orbax/checkpoint/CHANGELOG.md)
2222

2323

0 commit comments

Comments
 (0)