Skip to content

Commit 2e3f137

Browse files
authored
Merge pull request #234 from google/test_513911122
Add optional sharding option for ArrayRestoreArgs.
2 parents c58635b + 58e73ed commit 2e3f137

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.1.3] - 2022-03-03
11+
12+
### Added
13+
- `sharding` option on `ArrayRestoreArgs`
14+
1015
## [0.1.2] - 2022-02-17
1116

1217
### Added

orbax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
"""Orbax API."""
1616

1717
# A new PyPI release will be pushed every time `__version__` is increased.
18-
__version__ = '0.1.2'
18+
__version__ = '0.1.3'

orbax/checkpoint/type_handlers.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ class ArrayRestoreArgs(RestoreArgs):
248248
249249
mesh: the device mesh that the array should be restored as. Cannot be None.
250250
mesh_axes: the mesh_axes that the array should be restored as. Cannot be None.
251+
sharding: jax.sharding.Sharding object which takes precedence over mesh and
252+
mesh_axes if provided. Otherwise, mesh and mesh_axes will be used to
253+
construct a NamedSharding object.
251254
global_shape: the global shape that the array should be restored into. If not
252255
provided, the shape will be restored as written. Presently, arbitrary shape
253256
transformations are not supported (for example, reshaping to different
@@ -259,6 +262,7 @@ class ArrayRestoreArgs(RestoreArgs):
259262
restore_type: Any = jax.Array
260263
mesh: Optional[Mesh] = None
261264
mesh_axes: Optional[jax.sharding.PartitionSpec] = None
265+
sharding: Optional[jax.sharding.Sharding] = None
262266
global_shape: Optional[Tuple[int]] = None
263267

264268

@@ -321,16 +325,19 @@ async def deserialize(self,
321325
if args is None:
322326
raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
323327
args = cast(ArrayRestoreArgs, args)
324-
if args.mesh is None or args.mesh_axes is None:
328+
if args.sharding is None and (args.mesh is None or args.mesh_axes is None):
325329
raise ValueError(
326330
'Sharding of jax.Array cannot be None. Provide `mesh`'
327-
' and `mesh_axes`.'
331+
' and `mesh_axes` OR `sharding`.'
328332
)
333+
if args.sharding is None:
334+
sharding = jax.sharding.NamedSharding(args.mesh, args.mesh_axes)
335+
else:
336+
sharding = args.sharding
329337
tspec = self._get_json_tspec(info)
330338
tspec = _get_cast_tspec_deserialize(tspec, args)
331-
s = jax.sharding.NamedSharding(args.mesh, args.mesh_axes)
332339
return await serialization.async_deserialize(
333-
s,
340+
sharding,
334341
tspec,
335342
global_shape=args.global_shape,
336343
byte_limiter=info.byte_limiter,

0 commit comments

Comments
 (0)