1414
1515"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""
1616
17+ import asyncio
1718import collections
1819import json
19- from typing import Any , Awaitable , Sequence , cast
20+ import logging
21+ import time
22+ from typing import Any , Awaitable , Sequence
2023
2124import jax
2225import jax .numpy as jnp
2326import numpy as np
24- from orbax .checkpoint ._src .arrays import numpy_utils
27+ from orbax .checkpoint ._src .multihost import multihost as multihost_v0
2528from orbax .checkpoint ._src .path import async_path
2629from orbax .checkpoint .experimental .v1 ._src .layout import checkpoint_layout
2730from orbax .checkpoint .experimental .v1 ._src .metadata import types as metadata_types
2831from orbax .checkpoint .experimental .v1 ._src .path import types
32+ from orbax .checkpoint .experimental .v1 ._src .synchronization import multihost
2933
3034CheckpointLayout = checkpoint_layout .CheckpointLayout
3135InvalidLayoutError = checkpoint_layout .InvalidLayoutError
3236Path = types .Path
3337
3438HEADER_NUM_BYTES = 8
3539SAFETENSORS_SUFFIX = ".safetensors"
40+ MAX_GAP_SIZE_BYTES = (
41+ 32 * 1024 * 1024
42+ ) # 32 MB gap allowed between tensors in a coalesced read block
3643
3744
3845def _get_dtypes () -> dict [str , Any ]:
@@ -92,75 +99,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]:
9299 return shape , dtype
93100
94101
async def _read_non_contiguous_slice(
    f: async_path.AsyncFile,
    idx: tuple[slice, ...],
    stored_shape: tuple[int, ...],
    stored_dtype: np.dtype,
    tensor_file_offset: int,
) -> np.ndarray:
  """Reads an n-dimensional slice of a row-major tensor from a file.

  The requested slice is generally not stored as one contiguous byte range,
  so the dimensions are walked recursively and only the innermost
  (contiguous) runs are actually read from disk.

  Args:
    f: The asynchronous file object (binary read mode).
    idx: Per-dimension slices describing the region to read.
    stored_shape: The shape of the tensor.
    stored_dtype: The `dtype` of the tensor.
    tensor_file_offset: The starting byte offset of the tensor's data within
      the file.

  Returns:
    The specific tensor slice as a NumPy array.
  """
  itemsize = np.dtype(stored_dtype).itemsize

  # 0-d scalar: exactly one element located at the tensor's offset.
  if not idx:
    await f.seek(tensor_file_offset)
    raw = await f.read(itemsize)
    return np.frombuffer(raw, dtype=stored_dtype).reshape(())

  # Row-major byte strides of the full (unsliced) tensor: the stride of a
  # dimension is the byte distance between consecutive indices along it.
  ndim = len(stored_shape)
  strides = [0] * ndim
  stride = itemsize
  for axis in reversed(range(ndim)):
    strides[axis] = stride
    stride *= stored_shape[axis]

  async def _gather(axis: int, base: int) -> bytes:
    # TODO(b/438763866) - @zachmeyers to consider alternative methods.
    sl = idx[axis]
    if axis == ndim - 1:
      # Innermost dimension: this run of elements is contiguous on disk.
      first_byte = base + sl.start * strides[axis]
      await f.seek(tensor_file_offset + first_byte)
      return await f.read((sl.stop - sl.start) * itemsize)
    # Outer dimensions: recurse once per selected index.
    pieces = [
        await _gather(axis + 1, base + row * strides[axis])
        for row in range(sl.start, sl.stop)
    ]
    return b"".join(pieces)

  payload = await _gather(0, 0)
  return np.frombuffer(payload, dtype=stored_dtype).reshape(
      numpy_utils.slice_shape(idx)
  )
162-
163-
164102async def _load_safetensors_as_numpy (path : Path ) -> dict [str , np .ndarray ]:
165103 """Loads tensors from a safetensors file into host NumPy arrays."""
166104 header , data_start_offset = await _read_safetensors_header (path )
@@ -179,65 +117,227 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
179117 return tensors
180118
181119
120+ def _create_non_sharded_array (
121+ raw_data : memoryview | bytes ,
122+ abstract_leaf : Any ,
123+ stored_shape : tuple [int , ...],
124+ stored_dtype : Any ,
125+ ) -> jax .Array :
126+ """Creates a non-sharded JAX array from raw bytes."""
127+ np_array = np .frombuffer (raw_data , dtype = stored_dtype ).reshape (stored_shape )
128+ target_dtype = abstract_leaf .dtype
129+ if np_array .dtype != target_dtype :
130+ np_array = np_array .astype (target_dtype )
131+ return jax .device_put (np_array )
132+
133+
134+ def _create_sharded_array (
135+ raw_data : memoryview | bytes ,
136+ abstract_leaf : Any ,
137+ stored_shape : tuple [int , ...],
138+ stored_dtype : Any ,
139+ num_hosts : int ,
140+ host_id : int ,
141+ flat_sharding : jax .sharding .NamedSharding ,
142+ ) -> jax .Array :
143+ """Creates a sharded JAX array from raw bytes."""
144+ sharding = abstract_leaf .sharding
145+ target_dtype = abstract_leaf .dtype
146+
147+ # Use 1D flat contiguous read + reshard logic for maximum IO throughput.
148+ total_elements = int (np .prod (stored_shape )) if stored_shape else 1
149+
150+ # Calculate padding
151+ elements_per_host = (total_elements + num_hosts - 1 ) // num_hosts
152+ padded_elements = elements_per_host * num_hosts
153+
154+ start_idx = host_id * elements_per_host
155+ end_idx = min ((host_id + 1 ) * elements_per_host , total_elements )
156+ num_elements_to_read = max (0 , end_idx - start_idx )
157+
158+ local_data = np .frombuffer (raw_data , dtype = stored_dtype )
159+ if local_data .dtype != target_dtype :
160+ local_data = local_data .astype (target_dtype )
161+
162+ if num_elements_to_read < elements_per_host :
163+ local_data = np .pad (
164+ local_data , (0 , elements_per_host - num_elements_to_read )
165+ )
166+
167+ # Put local data on all addressable devices in the flat sharding
168+ put_start_time = time .time ()
169+ local_arrays = [
170+ jax .device_put (local_data , d ) for d in flat_sharding .addressable_devices
171+ ]
172+ put_end_time = time .time ()
173+
174+ logging .info (
175+ "[Host=%s] Put %s arrays in %s seconds" ,
176+ host_id ,
177+ len (local_arrays ),
178+ put_end_time - put_start_time ,
179+ )
180+
181+ # Create the 1D sharded array
182+ flat_array = jax .make_array_from_single_device_arrays (
183+ (padded_elements ,), flat_sharding , local_arrays
184+ )
185+
186+ # Slice off the padding and reshape
187+ if padded_elements > total_elements :
188+ flat_array = flat_array [:total_elements ]
189+
190+ reshaped_array = flat_array .reshape (stored_shape )
191+
192+ # Reshard to the target sharding
193+ reshard_start_time = time .time ()
194+ target_array = jax .device_put (reshaped_array , sharding )
195+ reshard_end_time = time .time ()
196+
197+ logging .info (
198+ "[Host=%s] Resharded array in %s seconds" ,
199+ host_id ,
200+ reshard_end_time - reshard_start_time ,
201+ )
202+
203+ return target_array
204+
205+
async def _load_non_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
) -> jax.Array:
  """Loads a single non-sharded array from a safetensors file.

  Issues one contiguous read covering the tensor's full byte range, then
  materializes it as an unsharded JAX array.
  """
  shape, dtype = _get_array_properties(header_info)
  begin, end = header_info["data_offsets"]

  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(data_start_offset + begin)
    payload = await f.read(end - begin)

  return _create_non_sharded_array(payload, abstract_leaf, shape, dtype)
225+
226+
async def _load_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Loads a single sharded array from a safetensors file.

  The tensor is treated as a flat 1-D buffer split evenly across hosts, so
  each host issues exactly one contiguous read for its own portion.
  """
  shape, dtype = _get_array_properties(header_info)
  data_offsets = header_info["data_offsets"]
  itemsize = np.dtype(dtype).itemsize

  # This host's slice of the evenly (ceil) divided element range.
  element_count = int(np.prod(shape)) if shape else 1
  per_host = -(-element_count // num_hosts)  # ceiling division
  first = host_id * per_host
  last = min(first + per_host, element_count)
  read_elements = max(0, last - first)

  read_offset = data_offsets[0] + data_start_offset + first * itemsize
  read_bytes = read_elements * itemsize

  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(read_offset)
    started = time.time()
    payload = await f.read(read_bytes)
    elapsed = time.time() - started

  logging.info(
      "[Host=%s] Read %s bytes in %s seconds",
      host_id,
      read_bytes,
      elapsed,
  )

  return _create_sharded_array(
      payload,
      abstract_leaf,
      shape,
      dtype,
      num_hosts,
      host_id,
      flat_sharding,
  )
272+
273+
async def _load_safetensors_on_device(
    path: Path, abstract_pytree: dict[str, Any]
) -> dict[str, jax.Array]:
  """Loads tensors from a safetensors file into on-device JAX arrays."""
  header, data_start_offset = await _read_safetensors_header(path)

  num_hosts = multihost.process_count()
  host_id = jax.process_index()

  # Group all global devices by their owning host, preserving device order.
  by_host = collections.defaultdict(list)
  for dev in jax.devices():
    by_host[multihost_v0.process_index_from_device(dev)].append(dev)
  grouped = [by_host[host] for host in range(num_hosts)]

  # Ensure uniform mesh shape (in case of uneven device counts, which is rare)
  expected = len(grouped[0])
  if any(len(host_devices) != expected for host_devices in grouped):
    raise ValueError("Number of devices must be the same across all hosts.")

  # A (hosts, devices) mesh sharded along "hosts" gives each host one flat
  # contiguous chunk of every tensor during loading.
  host_mesh = jax.sharding.Mesh(np.array(grouped), ("hosts", "devices"))
  flat_sharding = jax.sharding.NamedSharding(
      host_mesh, jax.sharding.PartitionSpec("hosts")
  )

  async def _load_one(name: str, leaf: Any) -> tuple[str, jax.Array]:
    if leaf.sharding is None:
      array = await _load_non_sharded_array(
          path, leaf, header[name], data_start_offset
      )
    else:
      # A target sharding was requested; load via the flat host-split path.
      array = await _load_sharded_array(
          path,
          leaf,
          header[name],
          data_start_offset,
          num_hosts,
          host_id,
          flat_sharding,
      )
    return name, array

  pending = []
  for name, leaf in abstract_pytree.items():
    if name not in header:
      raise KeyError(
          f"Tensor '{name}' not found in safetensors header of {path}."
      )
    pending.append(_load_one(name, leaf))

  loaded = await asyncio.gather(*pending)
  return dict(loaded)
242342
243343
0 commit comments