Skip to content

Commit 7ddb629

Browse files
cpgaffney1 and Orbax Authors
authored and committed
Fix pinned_host loading by ensuring memory_kind is carried through when creating in-memory buffers.
PiperOrigin-RevId: 738822867
1 parent d3e8781 commit 7ddb629

File tree

5 files changed

+56
-5
lines changed

5 files changed

+56
-5
lines changed

checkpoint/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.10] - 2025-03-20
11+
1012
### Added
1113

1214
- Add `fallback_sharding` option to `StandardRestoreArgs` to support restoring
1315
on different topologies easily.
1416

17+
### Fixed
18+
19+
- Fix pinned_host loading by ensuring `memory_kind` is carried through when creating in-memory buffers.
20+
1521
## [0.11.9] - 2025-03-17
1622

1723
### Added

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,3 +2378,24 @@ def test_save_restore_random_keys(self, use_ocdbt: bool):
23782378
self.directory, args=PyTreeRestoreArgs(pytree, restore_args)
23792379
)
23802380
test_utils.assert_tree_equal(self, pytree, restored)
2381+
2382+
def test_pinned_host_loading(self):
2383+
if multihost.is_pathways_backend():
2384+
# TODO(b/404915487): Reenable when possible.
2385+
self.skipTest('Disabled due to b/404915487.')
2386+
pytree = dict(arr=np.ones((1024, 512)))
2387+
self.handler.save(self.directory, args=PyTreeSaveArgs(pytree))
2388+
2389+
mesh = jax.sharding.Mesh(
2390+
np.asarray(jax.devices()).reshape((1, len(jax.devices()))), ('x', 'y')
2391+
)
2392+
sharding = jax.sharding.NamedSharding(
2393+
mesh, jax.sharding.PartitionSpec('x', 'y')
2394+
).with_memory_kind('pinned_host')
2395+
2396+
restore_args = dict(arr=ArrayRestoreArgs(sharding=sharding))
2397+
restored = self.handler.restore(
2398+
self.directory, args=PyTreeRestoreArgs(restore_args=restore_args)
2399+
)
2400+
expected = dict(arr=jax.device_put(np.ones((1024, 512)), sharding))
2401+
self.validate_restore(expected, restored)

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ async def _read_array_index_and_device_put(
401401
byte_limiter: ByteLimiter,
402402
strict: bool,
403403
dll: Optional[layout.DeviceLocalLayout],
404+
memory_kind: Optional[str],
404405
) -> list[jax.Array]:
405406
"""Callback that reads an array index and places on the devices."""
406407
for sl in index:
@@ -450,11 +451,10 @@ async def _read_array_index_and_device_put(
450451
f' TensorStore details: {t.spec}.'
451452
) from e
452453
for device in devices:
453-
result.append(
454-
jax.device_put(
455-
shard, Layout(dll, jax.sharding.SingleDeviceSharding(device))
456-
)
454+
sharding = jax.sharding.SingleDeviceSharding(
455+
device, memory_kind=memory_kind
457456
)
457+
result.append(jax.device_put(shard, Layout(dll, sharding)))
458458
return result
459459

460460

@@ -496,6 +496,7 @@ async def read_and_create_array(
496496
byte_limiter=byte_limiter,
497497
strict=strict,
498498
dll=dll,
499+
memory_kind=sharding.memory_kind,
499500
)
500501
for idx, devices in local_indices_devices_map.items()
501502
]

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,3 +1879,26 @@ def test_save_restore_random_keys(self, use_ocdbt: bool):
18791879
) as load_handler:
18801880
restored = load_handler.load(self.directory)
18811881
test_utils.assert_tree_equal(self, pytree, restored)
1882+
1883+
def test_pinned_host_loading(self):
1884+
if multihost.is_pathways_backend():
1885+
# TODO(b/404915487): Reenable when possible.
1886+
self.skipTest('Disabled due to b/404915487.')
1887+
pytree = dict(arr=np.ones((1024, 512)))
1888+
self.handler.save(self.directory, pytree)
1889+
1890+
mesh = jax.sharding.Mesh(
1891+
np.asarray(jax.devices()).reshape((1, len(jax.devices()))), ('x', 'y')
1892+
)
1893+
sharding = jax.sharding.NamedSharding(
1894+
mesh, jax.sharding.PartitionSpec('x', 'y')
1895+
).with_memory_kind('pinned_host')
1896+
1897+
abstract_pytree = dict(
1898+
arr=jax.ShapeDtypeStruct(
1899+
pytree['arr'].shape, pytree['arr'].dtype, sharding=sharding
1900+
)
1901+
)
1902+
restored = self.handler.load(self.directory, abstract_pytree)
1903+
expected = dict(arr=jax.device_put(np.ones((1024, 512)), sharding))
1904+
test_utils.assert_tree_equal(self, expected, restored)

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# A new PyPI release will be pushed everytime `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
1919
# LINT.IfChange
20-
__version__ = '0.11.9'
20+
__version__ = '0.11.10'
2121
# LINT.ThenChange(//depot//orbax/checkpoint/CHANGELOG.md)
2222

2323

0 commit comments

Comments
 (0)