@@ -248,6 +248,9 @@ class ArrayRestoreArgs(RestoreArgs):
248248
249249 mesh: the device mesh that the array should be restored as. Cannot be None.
250250 mesh_axes: the mesh_axes that the array should be restored as. Cannot be None.
251+ sharding: jax.sharding.Sharding object which takes precedence over mesh and
252+ mesh_axes if provided. Otherwise, mesh and mesh_axes will be used to
253+ construct a NamedSharding object.
251254 global_shapes: the global shape that the array should be restored into. If not
252255 provided, the shape will be restored as written. Presently, arbitrary shape
253256 transformations are not supported (for example, reshaping to different
@@ -259,6 +262,7 @@ class ArrayRestoreArgs(RestoreArgs):
259262 restore_type : Any = jax .Array
260263 mesh : Optional [Mesh ] = None
261264 mesh_axes : Optional [jax .sharding .PartitionSpec ] = None
265+ sharding : Optional [jax .sharding .Sharding ] = None
262266 global_shape : Optional [Tuple [int ]] = None
263267
264268
@@ -321,16 +325,19 @@ async def deserialize(self,
321325 if args is None :
322326 raise ValueError ('Must provide ArrayRestoreArgs to restore as jax.Array.' )
323327 args = cast (ArrayRestoreArgs , args )
324- if args .mesh is None or args .mesh_axes is None :
328+ if args .sharding is None and ( args . mesh is None or args .mesh_axes is None ) :
325329 raise ValueError (
326330 'Sharding of jax.Array cannot be None. Provide `mesh`'
327- ' and `mesh_axes`.'
331+ ' and `mesh_axes` OR `sharding` .'
328332 )
333+ if args .sharding is None :
334+ sharding = jax .sharding .NamedSharding (args .mesh , args .mesh_axes )
335+ else :
336+ sharding = args .sharding
329337 tspec = self ._get_json_tspec (info )
330338 tspec = _get_cast_tspec_deserialize (tspec , args )
331- s = jax .sharding .NamedSharding (args .mesh , args .mesh_axes )
332339 return await serialization .async_deserialize (
333- s ,
340+ sharding ,
334341 tspec ,
335342 global_shape = args .global_shape ,
336343 byte_limiter = info .byte_limiter ,
0 commit comments