File tree Expand file tree Collapse file tree 2 files changed +9
-11
lines changed
checkpoint/orbax/checkpoint/_src/metadata Expand file tree Collapse file tree 2 files changed +9
-11
lines changed Original file line number Diff line number Diff line change @@ -325,7 +325,13 @@ def __repr__(self):
325325 return f'SingleDeviceShardingMetadata(device_str={ self .device_str } )'
326326
327327 def __eq__ (self , other ):
328- return self .device_str == other .device_str
328+ if not isinstance (other , SingleDeviceShardingMetadata ):
329+ return False
330+ # JAX 0.10 changed CPU devices so they report as cpu:0 not TFRT_CPU_0
331+ return (
332+ self .device_str .replace ('TFRT_CPU_' , 'cpu:' )
333+ == other .device_str .replace ('TFRT_CPU_' , 'cpu:' )
334+ )
329335
330336
331337def from_jax_sharding (jax_sharding ) -> Optional [ShardingMetadata ]:
Original file line number Diff line number Diff line change @@ -93,12 +93,7 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
9393 jax_sharding = jax .sharding .SingleDeviceSharding (
9494 jax .local_devices (backend = "cpu" )[0 ]
9595 )
96- # JAX used to report its cpu devices as TFRT_CPU_0
9796 expected_single_device_sharding_metadata = (
98- sharding_metadata .SingleDeviceShardingMetadata (device_str = "TFRT_CPU_0" )
99- )
100- # ... but now uses cpu:0
101- expected_single_device_sharding_metadata2 = (
10297 sharding_metadata .SingleDeviceShardingMetadata (device_str = "cpu:0" )
10398 )
10499 converted_single_device_sharding_metadata = (
@@ -109,12 +104,9 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
109104 converted_single_device_sharding_metadata ,
110105 sharding_metadata .SingleDeviceShardingMetadata ,
111106 )
112- self .assertIn (
107+ self .assertEqual (
113108 converted_single_device_sharding_metadata ,
114- [
115- expected_single_device_sharding_metadata ,
116- expected_single_device_sharding_metadata2 ,
117- ],
109+ expected_single_device_sharding_metadata ,
118110 )
119111
120112 # Convert from `SingleDeviceShardingMetadata` to
You can’t perform that action at this time.
0 commit comments