Skip to content

Commit 14d5128

Browse files
angel-coreOrbax Authors
authored andcommitted
Update SingleDeviceShardingMetadata sharding metadata to handle CPU device naming string changes.
PiperOrigin-RevId: 893518418
1 parent 09d2982 commit 14d5128

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

checkpoint/orbax/checkpoint/_src/metadata/sharding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,13 @@ def __repr__(self):
325325
return f'SingleDeviceShardingMetadata(device_str={self.device_str})'
326326

327327
def __eq__(self, other):
328-
return self.device_str == other.device_str
328+
if not isinstance(other, SingleDeviceShardingMetadata):
329+
return False
330+
# JAX 0.10 changed CPU devices so they report as cpu:0 not TFRT_CPU_0
331+
return (
332+
self.device_str.replace('TFRT_CPU_', 'cpu:')
333+
== other.device_str.replace('TFRT_CPU_', 'cpu:')
334+
)
329335

330336

331337
def from_jax_sharding(jax_sharding) -> Optional[ShardingMetadata]:

checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,7 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
9393
jax_sharding = jax.sharding.SingleDeviceSharding(
9494
jax.local_devices(backend="cpu")[0]
9595
)
96-
# JAX used to report its cpu devices as TFRT_CPU_0
9796
expected_single_device_sharding_metadata = (
98-
sharding_metadata.SingleDeviceShardingMetadata(device_str="TFRT_CPU_0")
99-
)
100-
# ... but now uses cpu:0
101-
expected_single_device_sharding_metadata2 = (
10297
sharding_metadata.SingleDeviceShardingMetadata(device_str="cpu:0")
10398
)
10499
converted_single_device_sharding_metadata = (
@@ -109,12 +104,9 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
109104
converted_single_device_sharding_metadata,
110105
sharding_metadata.SingleDeviceShardingMetadata,
111106
)
112-
self.assertIn(
107+
self.assertEqual(
113108
converted_single_device_sharding_metadata,
114-
[
115-
expected_single_device_sharding_metadata,
116-
expected_single_device_sharding_metadata2,
117-
],
109+
expected_single_device_sharding_metadata,
118110
)
119111

120112
# Convert from `SingleDeviceShardingMetadata` to

0 commit comments

Comments
 (0)