Skip to content

Commit 3a9ba8f

Browse files
BlaziusMaximus and Orbax Authors
authored and committed
Optimize safetensors loading with 1D contiguous reads and ICI resharding.
PiperOrigin-RevId: 892406249
1 parent 21f35b3 commit 3a9ba8f

File tree

1 file changed

+63
-23
lines changed

1 file changed

+63
-23
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,26 @@ async def _load_safetensors_on_device(
185185
"""Loads tensors from a safetensors file into on-device JAX arrays."""
186186
header, data_start_offset = await _read_safetensors_header(path)
187187
restored_pytree = {}
188+
189+
num_hosts = jax.process_count()
190+
host_id = jax.process_index()
191+
192+
# Build an initial mesh grouping all global devices by host
193+
devices_by_host = []
194+
for i in range(num_hosts):
195+
devices_by_host.append([d for d in jax.devices() if d.process_index == i])
196+
197+
# Ensure uniform mesh shape (in case of uneven device counts, which is rare)
198+
min_devices = min(len(d) for d in devices_by_host)
199+
devices_by_host = [d[:min_devices] for d in devices_by_host]
200+
201+
initial_mesh = jax.sharding.Mesh(
202+
np.array(devices_by_host), ("hosts", "devices")
203+
)
204+
flat_sharding = jax.sharding.NamedSharding(
205+
initial_mesh, jax.sharding.PartitionSpec("hosts")
206+
)
207+
188208
async with async_path.open_file(path, mode="rb") as f:
189209
for tensor_name, abstract_leaf in abstract_pytree.items():
190210
if tensor_name not in header:
@@ -195,7 +215,6 @@ async def _load_safetensors_on_device(
195215
stored_shape, stored_dtype = _get_array_properties(header[tensor_name])
196216
st_data_offsets = header[tensor_name]["data_offsets"]
197217
sharding = abstract_leaf.sharding
198-
target_shape = abstract_leaf.shape
199218
target_dtype = abstract_leaf.dtype
200219

201220
if sharding is None:
@@ -211,33 +230,54 @@ async def _load_safetensors_on_device(
211230
restored_pytree[tensor_name] = jax.device_put(np_array)
212231
continue
213232

214-
device_indices_map = sharding.addressable_devices_indices_map(
215-
target_shape
216-
)
233+
# We have a target sharding.
234+
# Use 1D flat contiguous read + reshard logic for maximum IO throughput.
235+
total_elements = int(np.prod(stored_shape)) if stored_shape else 1
217236

218-
device_map = []
219-
for device in device_indices_map:
220-
idx = device_indices_map[device]
221-
resolved_idx = numpy_utils.resolve_slice(idx, stored_shape)
222-
shard_shape = numpy_utils.slice_shape(resolved_idx)
223-
224-
shard_np = await _read_non_contiguous_slice(
225-
f,
226-
resolved_idx,
227-
stored_shape,
228-
stored_dtype,
229-
st_data_offsets[0] + data_start_offset,
230-
)
231-
shard_np = shard_np.reshape(shard_shape) # pytype: disable=attribute-error
237+
# Calculate padding
238+
elements_per_host = (total_elements + num_hosts - 1) // num_hosts
239+
padded_elements = elements_per_host * num_hosts
240+
241+
# Calculate what this host needs to read
242+
start_idx = host_id * elements_per_host
243+
end_idx = min((host_id + 1) * elements_per_host, total_elements)
244+
num_elements_to_read = max(0, end_idx - start_idx)
245+
itemsize = np.dtype(stored_dtype).itemsize
232246

233-
if shard_np.dtype != target_dtype:
234-
shard_np = shard_np.astype(target_dtype)
247+
start_byte = st_data_offsets[0] + data_start_offset + start_idx * itemsize
248+
num_bytes = num_elements_to_read * itemsize
235249

236-
device_map.append(jax.device_put(shard_np, device))
250+
await f.seek(start_byte)
251+
raw_data = await f.read(num_bytes)
237252

238-
restored_pytree[tensor_name] = jax.make_array_from_single_device_arrays(
239-
target_shape, sharding, device_map
253+
local_data = np.frombuffer(raw_data, dtype=stored_dtype)
254+
if local_data.dtype != target_dtype:
255+
local_data = local_data.astype(target_dtype)
256+
257+
if num_elements_to_read < elements_per_host:
258+
local_data = np.pad(
259+
local_data, (0, elements_per_host - num_elements_to_read)
260+
)
261+
262+
# Put local data on all addressable devices in the flat sharding
263+
local_arrays = [
264+
jax.device_put(local_data, d)
265+
for d in flat_sharding.addressable_devices
266+
]
267+
268+
# Create the 1D sharded array
269+
flat_array = jax.make_array_from_single_device_arrays(
270+
(padded_elements,), flat_sharding, local_arrays
240271
)
272+
273+
# Slice off the padding and reshape
274+
if padded_elements > total_elements:
275+
flat_array = flat_array[:total_elements]
276+
277+
reshaped_array = flat_array.reshape(stored_shape)
278+
279+
# Reshard to the target sharding
280+
restored_pytree[tensor_name] = jax.device_put(reshaped_array, sharding)
241281
return restored_pytree
242282

243283

0 commit comments

Comments
 (0)