Skip to content

Commit b582c81

Browse files
BlaziusMaximusOrbax Authors
authored and committed
Optimize safetensors loading with 1D contiguous reads and ICI resharding.
PiperOrigin-RevId: 892406249
1 parent 31f0377 commit b582c81

File tree

1 file changed

+219
-119
lines changed

1 file changed

+219
-119
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py

Lines changed: 219 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,32 @@
1414

1515
"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""
1616

17+
import asyncio
1718
import collections
1819
import json
19-
from typing import Any, Awaitable, Sequence, cast
20+
import logging
21+
import time
22+
from typing import Any, Awaitable, Sequence
2023

2124
import jax
2225
import jax.numpy as jnp
2326
import numpy as np
24-
from orbax.checkpoint._src.arrays import numpy_utils
27+
from orbax.checkpoint._src.multihost import multihost as multihost_v0
2528
from orbax.checkpoint._src.path import async_path
2629
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
2730
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
2831
from orbax.checkpoint.experimental.v1._src.path import types
32+
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
2933

3034
CheckpointLayout = checkpoint_layout.CheckpointLayout
3135
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
3236
Path = types.Path
3337

3438
HEADER_NUM_BYTES = 8
3539
SAFETENSORS_SUFFIX = ".safetensors"
40+
MAX_GAP_SIZE_BYTES = (
41+
32 * 1024 * 1024
42+
) # 32 MB gap allowed between tensors in a coalesced read block
3643

3744

3845
def _get_dtypes() -> dict[str, Any]:
@@ -92,75 +99,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]:
9299
return shape, dtype
93100

94101

95-
async def _read_non_contiguous_slice(
96-
f: async_path.AsyncFile,
97-
idx: tuple[slice, ...],
98-
stored_shape: tuple[int, ...],
99-
stored_dtype: np.dtype,
100-
tensor_file_offset: int,
101-
) -> np.ndarray:
102-
"""Reads a slice of a tensor from a file.
103-
104-
This function solves the problem of reading a multi-dimensional slice from an
105-
array where the slice's data is not stored as a single, contiguous block in
106-
the file. It does so by recursively "walking" the dimensions of the slice.
107-
108-
Args:
109-
f: The asynchronous file object (binary read mode)
110-
idx: A tuple of slice objects representing the n-dimensional slice to
111-
read.
112-
stored_shape: The shape of the tensor.
113-
stored_dtype: The `dtype` of the tensor.
114-
tensor_file_offset: The starting byte offset of the tensor's data within
115-
the file.
116-
117-
Returns:
118-
The specific tensor slice.
119-
"""
120-
# Handle 0-d scalar case
121-
if not idx:
122-
await f.seek(tensor_file_offset)
123-
num_bytes = np.dtype(stored_dtype).itemsize
124-
scalar_bytes = await f.read(num_bytes)
125-
# Reshape to () to create a 0-D NumPy array.
126-
return np.frombuffer(scalar_bytes, dtype=stored_dtype).reshape(())
127-
128-
itemsize = np.dtype(stored_dtype).itemsize
129-
130-
# Calculate the byte strides for the full tensor. The stride for a
131-
# dimension is the number of bytes to "jump" to get to the next element
132-
# in that dimension while keeping all other indices the same.
133-
global_strides = [itemsize] * len(stored_shape)
134-
for i in range(len(stored_shape) - 2, -1, -1):
135-
global_strides[i] = global_strides[i + 1] * stored_shape[i + 1]
136-
137-
async def _read_slice_recursively(dim: int, base_offset: int) -> bytes:
138-
# TODO(b/438763866) - @zachmeyers to consider alternative methods.
139-
s = idx[dim] # The slice for the current dimension.
140-
141-
# If we are at the last dimension, the data is contiguous.
142-
if dim == len(stored_shape) - 1:
143-
start = base_offset + s.start * global_strides[dim]
144-
num_bytes = (s.stop - s.start) * itemsize
145-
await f.seek(tensor_file_offset + start)
146-
return cast(bytes, await f.read(num_bytes))
147-
148-
# For all other dimensions, iterate through the indices
149-
# of the slice and make a recursive call for the next dimension.
150-
chunks = []
151-
for i in range(s.start, s.stop):
152-
offset = base_offset + i * global_strides[dim]
153-
chunk = await _read_slice_recursively(dim + 1, offset)
154-
chunks.append(chunk)
155-
156-
return b"".join(chunks)
157-
158-
# Start the recursive reading process from the first dimension.
159-
slice_bytes = await _read_slice_recursively(dim=0, base_offset=0)
160-
shard_shape = numpy_utils.slice_shape(idx)
161-
return np.frombuffer(slice_bytes, dtype=stored_dtype).reshape(shard_shape)
162-
163-
164102
async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
165103
"""Loads tensors from a safetensors file into host NumPy arrays."""
166104
header, data_start_offset = await _read_safetensors_header(path)
@@ -179,65 +117,227 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
179117
return tensors
180118

181119

120+
def _create_non_sharded_array(
121+
raw_data: memoryview | bytes,
122+
abstract_leaf: Any,
123+
stored_shape: tuple[int, ...],
124+
stored_dtype: Any,
125+
) -> jax.Array:
126+
"""Creates a non-sharded JAX array from raw bytes."""
127+
np_array = np.frombuffer(raw_data, dtype=stored_dtype).reshape(stored_shape)
128+
target_dtype = abstract_leaf.dtype
129+
if np_array.dtype != target_dtype:
130+
np_array = np_array.astype(target_dtype)
131+
return jax.device_put(np_array)
132+
133+
134+
def _create_sharded_array(
135+
raw_data: memoryview | bytes,
136+
abstract_leaf: Any,
137+
stored_shape: tuple[int, ...],
138+
stored_dtype: Any,
139+
num_hosts: int,
140+
host_id: int,
141+
flat_sharding: jax.sharding.NamedSharding,
142+
) -> jax.Array:
143+
"""Creates a sharded JAX array from raw bytes."""
144+
sharding = abstract_leaf.sharding
145+
target_dtype = abstract_leaf.dtype
146+
147+
# Use 1D flat contiguous read + reshard logic for maximum IO throughput.
148+
total_elements = int(np.prod(stored_shape)) if stored_shape else 1
149+
150+
# Calculate padding
151+
elements_per_host = (total_elements + num_hosts - 1) // num_hosts
152+
padded_elements = elements_per_host * num_hosts
153+
154+
start_idx = host_id * elements_per_host
155+
end_idx = min((host_id + 1) * elements_per_host, total_elements)
156+
num_elements_to_read = max(0, end_idx - start_idx)
157+
158+
local_data = np.frombuffer(raw_data, dtype=stored_dtype)
159+
if local_data.dtype != target_dtype:
160+
local_data = local_data.astype(target_dtype)
161+
162+
if num_elements_to_read < elements_per_host:
163+
local_data = np.pad(
164+
local_data, (0, elements_per_host - num_elements_to_read)
165+
)
166+
167+
# Put local data on all addressable devices in the flat sharding
168+
put_start_time = time.time()
169+
local_arrays = [
170+
jax.device_put(local_data, d) for d in flat_sharding.addressable_devices
171+
]
172+
put_end_time = time.time()
173+
174+
logging.info(
175+
"[Host=%s] Put %s arrays in %s seconds",
176+
host_id,
177+
len(local_arrays),
178+
put_end_time - put_start_time,
179+
)
180+
181+
# Create the 1D sharded array
182+
flat_array = jax.make_array_from_single_device_arrays(
183+
(padded_elements,), flat_sharding, local_arrays
184+
)
185+
186+
# Slice off the padding and reshape
187+
if padded_elements > total_elements:
188+
flat_array = flat_array[:total_elements]
189+
190+
reshaped_array = flat_array.reshape(stored_shape)
191+
192+
# Reshard to the target sharding
193+
reshard_start_time = time.time()
194+
target_array = jax.device_put(reshaped_array, sharding)
195+
reshard_end_time = time.time()
196+
197+
logging.info(
198+
"[Host=%s] Resharded array in %s seconds",
199+
host_id,
200+
reshard_end_time - reshard_start_time,
201+
)
202+
203+
return target_array
204+
205+
206+
async def _load_non_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
) -> jax.Array:
  """Loads a single non-sharded array from a safetensors file.

  Args:
    path: Path to the safetensors file.
    abstract_leaf: Abstract array describing the requested target dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset at which the file's data section begins.

  Returns:
    The restored array on the default device.
  """
  shape, dtype = _get_array_properties(header_info)
  begin, end = header_info["data_offsets"]
  async with async_path.open_file(path, mode="rb") as f:
    # data_offsets are relative to the start of the data section.
    await f.seek(data_start_offset + begin)
    payload = await f.read(end - begin)
  return _create_non_sharded_array(payload, abstract_leaf, shape, dtype)
225+
226+
227+
async def _load_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Loads a single sharded array from a safetensors file.

  Each host reads only its own contiguous 1D chunk of the tensor's bytes
  (ceil-divided across hosts) and hands it to `_create_sharded_array`, which
  assembles the global array and reshards it to the target sharding.

  Args:
    path: Path to the safetensors file.
    abstract_leaf: Abstract array carrying the target sharding and dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset at which the file's data section begins.
    num_hosts: Total number of participating hosts.
    host_id: Index of this host in ``[0, num_hosts)``.
    flat_sharding: 1D sharding that splits the flat tensor across hosts.

  Returns:
    The tensor sharded according to ``abstract_leaf.sharding``.
  """
  stored_shape, stored_dtype = _get_array_properties(header_info)
  st_data_offsets = header_info["data_offsets"]

  total_elements = int(np.prod(stored_shape)) if stored_shape else 1
  elements_per_host = (total_elements + num_hosts - 1) // num_hosts
  start_idx = host_id * elements_per_host
  end_idx = min((host_id + 1) * elements_per_host, total_elements)
  num_elements_to_read = max(0, end_idx - start_idx)
  itemsize = np.dtype(stored_dtype).itemsize

  # Absolute file position of this host's chunk.
  start_byte = st_data_offsets[0] + data_start_offset + start_idx * itemsize
  num_bytes = num_elements_to_read * itemsize

  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(start_byte)
    # perf_counter is monotonic, so the logged interval is immune to
    # wall-clock adjustments (unlike time.time()).
    read_start_time = time.perf_counter()
    raw_data = await f.read(num_bytes)
    read_end_time = time.perf_counter()

  logging.info(
      "[Host=%s] Read %s bytes in %s seconds",
      host_id,
      num_bytes,
      read_end_time - read_start_time,
  )

  return _create_sharded_array(
      raw_data,
      abstract_leaf,
      stored_shape,
      stored_dtype,
      num_hosts,
      host_id,
      flat_sharding,
  )
272+
273+
182274
async def _load_safetensors_on_device(
    path: Path, abstract_pytree: dict[str, Any]
) -> dict[str, jax.Array]:
  """Loads tensors from a safetensors file into on-device JAX arrays.

  Tensors with a target sharding are read as contiguous per-host 1D chunks
  and resharded; tensors without one are read whole and placed on the
  default device. All tensors are loaded concurrently.

  Args:
    path: Path to the safetensors file.
    abstract_pytree: Mapping of tensor name to an abstract leaf carrying the
      target sharding and dtype.

  Returns:
    Mapping of tensor name to the restored on-device array.

  Raises:
    KeyError: If a requested tensor is missing from the file's header.
    ValueError: If hosts do not all have the same number of devices.
  """
  header, data_start_offset = await _read_safetensors_header(path)

  num_hosts = multihost.process_count()
  host_id = jax.process_index()

  # Group all global devices by owning host in a single pass over
  # jax.devices() (preserves device order within each host).
  by_host = collections.defaultdict(list)
  for d in jax.devices():
    by_host[multihost_v0.process_index_from_device(d)].append(d)
  devices_by_host = [by_host[i] for i in range(num_hosts)]

  # Ensure uniform mesh shape (in case of uneven device counts, which is rare)
  num_devices_per_host = len(devices_by_host[0])
  for host_devices in devices_by_host:
    if len(host_devices) != num_devices_per_host:
      raise ValueError("Number of devices must be the same across all hosts.")

  # A (hosts, devices) mesh whose "hosts" axis is used to split flat 1D
  # reads evenly across hosts.
  initial_mesh = jax.sharding.Mesh(
      np.array(devices_by_host), ("hosts", "devices")
  )
  flat_sharding = jax.sharding.NamedSharding(
      initial_mesh, jax.sharding.PartitionSpec("hosts")
  )

  async def _load_tensor(
      tensor_name: str, abstract_leaf: Any
  ) -> tuple[str, jax.Array]:
    # One coroutine per tensor; pairs the name with the loaded array so
    # results from asyncio.gather can be reassembled into a dict.
    if abstract_leaf.sharding is None:
      tensor = await _load_non_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
      )
    else:
      # We have a target sharding.
      tensor = await _load_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
          num_hosts,
          host_id,
          flat_sharding,
      )
    return tensor_name, tensor

  # Validate all names up front so a missing tensor fails fast, before any
  # I/O is scheduled.
  tasks = []
  for tensor_name, abstract_leaf in abstract_pytree.items():
    if tensor_name not in header:
      raise KeyError(
          f"Tensor '{tensor_name}' not found in safetensors header of {path}."
      )
    tasks.append(_load_tensor(tensor_name, abstract_leaf))

  results = await asyncio.gather(*tasks)
  return dict(results)
242342

243343

0 commit comments

Comments
 (0)