@@ -185,6 +185,26 @@ async def _load_safetensors_on_device(
185185 """Loads tensors from a safetensors file into on-device JAX arrays."""
186186 header , data_start_offset = await _read_safetensors_header (path )
187187 restored_pytree = {}
188+
189+ num_hosts = jax .process_count ()
190+ host_id = jax .process_index ()
191+
192+ # Build an initial mesh grouping all global devices by host
193+ devices_by_host = []
194+ for i in range (num_hosts ):
195+ devices_by_host .append ([d for d in jax .devices () if d .process_index == i ])
196+
197+ # Ensure uniform mesh shape (in case of uneven device counts, which is rare)
198+ min_devices = min (len (d ) for d in devices_by_host )
199+ devices_by_host = [d [:min_devices ] for d in devices_by_host ]
200+
201+ initial_mesh = jax .sharding .Mesh (
202+ np .array (devices_by_host ), ("hosts" , "devices" )
203+ )
204+ flat_sharding = jax .sharding .NamedSharding (
205+ initial_mesh , jax .sharding .PartitionSpec ("hosts" )
206+ )
207+
188208 async with async_path .open_file (path , mode = "rb" ) as f :
189209 for tensor_name , abstract_leaf in abstract_pytree .items ():
190210 if tensor_name not in header :
@@ -195,7 +215,6 @@ async def _load_safetensors_on_device(
195215 stored_shape , stored_dtype = _get_array_properties (header [tensor_name ])
196216 st_data_offsets = header [tensor_name ]["data_offsets" ]
197217 sharding = abstract_leaf .sharding
198- target_shape = abstract_leaf .shape
199218 target_dtype = abstract_leaf .dtype
200219
201220 if sharding is None :
@@ -211,33 +230,54 @@ async def _load_safetensors_on_device(
211230 restored_pytree [tensor_name ] = jax .device_put (np_array )
212231 continue
213232
214- device_indices_map = sharding .addressable_devices_indices_map (
215- target_shape
216- )
233+ # We have a target sharding.
234+ # Use 1D flat contiguous read + reshard logic for maximum IO throughput.
235+ total_elements = int ( np . prod ( stored_shape )) if stored_shape else 1
217236
218- device_map = []
219- for device in device_indices_map :
220- idx = device_indices_map [device ]
221- resolved_idx = numpy_utils .resolve_slice (idx , stored_shape )
222- shard_shape = numpy_utils .slice_shape (resolved_idx )
223-
224- shard_np = await _read_non_contiguous_slice (
225- f ,
226- resolved_idx ,
227- stored_shape ,
228- stored_dtype ,
229- st_data_offsets [0 ] + data_start_offset ,
230- )
231- shard_np = shard_np .reshape (shard_shape ) # pytype: disable=attribute-error
237+ # Calculate padding
238+ elements_per_host = (total_elements + num_hosts - 1 ) // num_hosts
239+ padded_elements = elements_per_host * num_hosts
240+
241+ # Calculate what this host needs to read
242+ start_idx = host_id * elements_per_host
243+ end_idx = min ((host_id + 1 ) * elements_per_host , total_elements )
244+ num_elements_to_read = max (0 , end_idx - start_idx )
245+ itemsize = np .dtype (stored_dtype ).itemsize
232246
233- if shard_np . dtype != target_dtype :
234- shard_np = shard_np . astype ( target_dtype )
247+ start_byte = st_data_offsets [ 0 ] + data_start_offset + start_idx * itemsize
248+ num_bytes = num_elements_to_read * itemsize
235249
236- device_map .append (jax .device_put (shard_np , device ))
250+ await f .seek (start_byte )
251+ raw_data = await f .read (num_bytes )
237252
238- restored_pytree [tensor_name ] = jax .make_array_from_single_device_arrays (
239- target_shape , sharding , device_map
253+ local_data = np .frombuffer (raw_data , dtype = stored_dtype )
254+ if local_data .dtype != target_dtype :
255+ local_data = local_data .astype (target_dtype )
256+
257+ if num_elements_to_read < elements_per_host :
258+ local_data = np .pad (
259+ local_data , (0 , elements_per_host - num_elements_to_read )
260+ )
261+
262+ # Put local data on all addressable devices in the flat sharding
263+ local_arrays = [
264+ jax .device_put (local_data , d )
265+ for d in flat_sharding .addressable_devices
266+ ]
267+
268+ # Create the 1D sharded array
269+ flat_array = jax .make_array_from_single_device_arrays (
270+ (padded_elements ,), flat_sharding , local_arrays
240271 )
272+
273+ # Slice off the padding and reshape
274+ if padded_elements > total_elements :
275+ flat_array = flat_array [:total_elements ]
276+
277+ reshaped_array = flat_array .reshape (stored_shape )
278+
279+ # Reshard to the target sharding
280+ restored_pytree [tensor_name ] = jax .device_put (reshaped_array , sharding )
241281 return restored_pytree
242282
243283
0 commit comments