Skip to content

Commit 1803922

Browse files
ChromeHeartsOrbax Authors
authored andcommitted
Update benchmark utility to support replicated arrays
PiperOrigin-RevId: 852469369
1 parent 39262e1 commit 1803922

14 files changed

+730
-172
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B slice 16"
3+
num_repeats: 1
4+
5+
mesh_config:
6+
mesh_axes: ["replica", "model"]
7+
# Should match reference_sharding_path.
8+
ici_parallelism: {"replica": 1, "model": 64}
9+
dcn_parallelism: {"replica": 16}
10+
11+
# Note: checkpoint_config field not specified.
12+
13+
benchmarks:
14+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
15+
options:
16+
# --- Generator Options ---
17+
# These keys must match the attributes of the `V1BenchmarkOptions` class
18+
# associated with the `V1Benchmark` generator.
19+
async_enabled: true
20+
use_ocdbt: true
21+
use_zarr3: true
22+
use_replica_parallel: [true, false]
23+
use_compression: true
24+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
25+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
26+
use_load_and_broadcast: true # speed up the loading
27+
num_of_repeats: 20
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B slice 2"
3+
num_repeats: 20
4+
5+
mesh_config:
6+
mesh_axes: ["replica", "model"]
7+
# Should match reference_sharding_path.
8+
ici_parallelism: {"replica": 1, "model": 64}
9+
dcn_parallelism: {"replica": 2}
10+
11+
# Note: checkpoint_config field not specified.
12+
13+
benchmarks:
14+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
15+
options:
16+
# --- Generator Options ---
17+
# These keys must match the attributes of the `V1BenchmarkOptions` class
18+
# associated with the `V1Benchmark` generator.
19+
async_enabled: true
20+
use_ocdbt: true
21+
use_zarr3: true
22+
use_replica_parallel: [true, false]
23+
use_compression: true
24+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
25+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
26+
use_load_and_broadcast: true # speed up the loading
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B slice 32"
3+
num_repeats: 1
4+
5+
mesh_config:
6+
mesh_axes: ["replica", "model"]
7+
# Should match reference_sharding_path.
8+
ici_parallelism: {"replica": 1, "model": 64}
9+
dcn_parallelism: {"replica": 32}
10+
11+
# Note: checkpoint_config field not specified.
12+
13+
benchmarks:
14+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
15+
options:
16+
# --- Generator Options ---
17+
# These keys must match the attributes of the `V1BenchmarkOptions` class
18+
# associated with the `V1Benchmark` generator.
19+
async_enabled: true
20+
use_ocdbt: true
21+
use_zarr3: true
22+
use_replica_parallel: [true, false]
23+
use_compression: true
24+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
25+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
26+
use_load_and_broadcast: true # speed up the loading
27+
num_of_repeats: 20
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B slice 4"
3+
num_repeats: 20
4+
5+
mesh_config:
6+
mesh_axes: ["replica", "model"]
7+
# Should match reference_sharding_path.
8+
ici_parallelism: {"replica": 1, "model": 64}
9+
dcn_parallelism: {"replica": 4}
10+
11+
# Note: checkpoint_config field not specified.
12+
13+
benchmarks:
14+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
15+
options:
16+
# --- Generator Options ---
17+
# These keys must match the attributes of the `V1BenchmarkOptions` class
18+
# associated with the `V1Benchmark` generator.
19+
async_enabled: true
20+
use_ocdbt: true
21+
use_zarr3: true
22+
use_replica_parallel: [true, false]
23+
use_compression: true
24+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
25+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
26+
use_load_and_broadcast: true # speed up the loading
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# The name for the entire test suite run.
2+
suite_name: "Llama 3.1 70B slice 8"
3+
num_repeats: 1 # depends on the number of repeats in the benchmark.
4+
5+
mesh_config:
6+
mesh_axes: ["replica", "model"]
7+
# Should match reference_sharding_path.
8+
ici_parallelism: {"replica": 1, "model": 64}
9+
dcn_parallelism: {"replica": 8}
10+
11+
# Note: checkpoint_config field not specified.
12+
13+
benchmarks:
14+
- generator: "orbax.checkpoint._src.testing.benchmarks.v1.replica_parallel_multislice_benchmark.ReplicaParallelMultislice"
15+
options:
16+
# --- Generator Options ---
17+
# These keys must match the attributes of the `V1BenchmarkOptions` class
18+
# associated with the `V1Benchmark` generator.
19+
async_enabled: true
20+
use_ocdbt: true
21+
use_zarr3: true
22+
use_replica_parallel: [true, false]
23+
use_compression: true
24+
reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt"
25+
reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json"
26+
use_load_and_broadcast: true # speed up the loading
27+
num_of_repeats: 20
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for multi-slice benchmarks."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any
20+
21+
from absl import logging
22+
from etils import epath
23+
import jax
24+
from orbax.checkpoint import v1 as ocp
25+
from orbax.checkpoint._src.multihost import multislice
26+
from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation
27+
28+
29+
def get_multi_slice_abstract_state(
30+
context: ocp.Context,
31+
global_mesh: jax.sharding.Mesh,
32+
*,
33+
reference_checkpoint_path: epath.Path,
34+
reference_sharding_path: epath.Path,
35+
) -> Any:
36+
"""Returns the abstract state for all replicas."""
37+
with ocp.Context(context=context):
38+
metadata = ocp.pytree_metadata(reference_checkpoint_path)
39+
# Abstract tree has shardings on a single replica.
40+
single_replica_abstract_state = (
41+
checkpoint_generation.get_abstract_state_from_sharding_config(
42+
reference_sharding_path,
43+
metadata.metadata,
44+
devices=multislice.replica_devices(
45+
global_mesh, replica_id=0, replica_axis_index=0
46+
).tolist(),
47+
)
48+
)
49+
50+
# Blow shardings up to all replicas.
51+
def _multi_replica_sharding(abstract_arr: jax.ShapeDtypeStruct):
52+
logging.info(
53+
"Original (single-replica) sharding: %s", abstract_arr.sharding
54+
)
55+
assert isinstance(abstract_arr.sharding, jax.sharding.NamedSharding)
56+
single_replica_mesh = abstract_arr.sharding.mesh
57+
single_replica_partition_spec = abstract_arr.sharding.spec
58+
multi_replica_sharding = jax.sharding.NamedSharding(
59+
jax.sharding.Mesh(
60+
devices=global_mesh.devices.reshape(
61+
-1, *single_replica_mesh.devices.shape
62+
),
63+
axis_names=["replica", *single_replica_mesh.axis_names],
64+
),
65+
spec=jax.sharding.PartitionSpec(*single_replica_partition_spec),
66+
)
67+
logging.info("Multi-replica sharding: %s", multi_replica_sharding)
68+
return jax.ShapeDtypeStruct(
69+
shape=abstract_arr.shape,
70+
dtype=abstract_arr.dtype,
71+
sharding=multi_replica_sharding,
72+
)
73+
74+
return jax.tree.map(
75+
_multi_replica_sharding,
76+
single_replica_abstract_state,
77+
)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import os
17+
18+
from absl.testing import absltest
19+
from absl.testing import parameterized
20+
from etils import epath
21+
import jax
22+
import jax.numpy as jnp
23+
import numpy as np
24+
from orbax.checkpoint import v1 as ocp
25+
from orbax.checkpoint._src.testing.benchmarks.v1 import multi_slice_util
26+
27+
28+
_REQUIRED_DEVICE_COUNT = 16
29+
30+
31+
class MultiSliceUtilTest(parameterized.TestCase):
32+
33+
def setUp(self):
34+
self._prev_xla_flags = os.environ.get('XLA_FLAGS')
35+
os.environ['XLA_FLAGS'] = (
36+
self._prev_xla_flags or ''
37+
) + ' --xla_force_host_platform_device_count=16'
38+
super().setUp()
39+
if jax.local_device_count() != _REQUIRED_DEVICE_COUNT:
40+
self.skipTest(
41+
f'Test requires {_REQUIRED_DEVICE_COUNT} local devices, but only'
42+
f' {jax.local_device_count()} are available. Set XLA_FLAGS='
43+
f'"--xla_force_host_platform_device_count={_REQUIRED_DEVICE_COUNT}"'
44+
' before JAX initializes.'
45+
)
46+
self.directory = epath.Path(self.create_tempdir().full_path)
47+
48+
def tearDown(self):
49+
if self._prev_xla_flags is None:
50+
os.environ.pop('XLA_FLAGS', None)
51+
else:
52+
os.environ['XLA_FLAGS'] = self._prev_xla_flags
53+
super().tearDown()
54+
55+
def test_get_multi_slice_abstract_state(self):
56+
# Setup real checkpoint and sharding config
57+
pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}}
58+
ref_ckpt_path = self.directory / 'ref_ckpt'
59+
ocp.save_pytree(ref_ckpt_path, pytree)
60+
61+
sharding_config = {
62+
'a': {
63+
'shape': [32],
64+
'dtype': 'int32',
65+
'sharding': {
66+
'mesh': {'shape': [4], 'axes': ['model']},
67+
'spec': ['model'],
68+
},
69+
},
70+
'b.c': {
71+
'shape': [8, 8],
72+
'dtype': 'float32',
73+
'sharding': {
74+
'mesh': {'shape': [4], 'axes': ['model']},
75+
'spec': [None, 'model'],
76+
},
77+
},
78+
}
79+
sharding_config_path = self.directory / 'sharding_config.json'
80+
sharding_config_path.write_text(json.dumps(sharding_config))
81+
global_mesh = jax.sharding.Mesh(
82+
np.array(jax.devices()).reshape((4, 4)), ('replica', 'model')
83+
)
84+
85+
abstract_pytree = multi_slice_util.get_multi_slice_abstract_state(
86+
context=ocp.Context(),
87+
global_mesh=global_mesh,
88+
reference_checkpoint_path=ref_ckpt_path,
89+
reference_sharding_path=sharding_config_path,
90+
)
91+
self.assertEqual(
92+
{'replica': 4, 'model': 4}, abstract_pytree['a'].sharding.mesh.shape
93+
)
94+
self.assertEqual(
95+
jax.sharding.PartitionSpec('model'), abstract_pytree['a'].sharding.spec
96+
)
97+
self.assertEqual(
98+
{'replica': 4, 'model': 4},
99+
abstract_pytree['b']['c'].sharding.mesh.shape,
100+
)
101+
self.assertEqual(
102+
jax.sharding.PartitionSpec(None, 'model'),
103+
abstract_pytree['b']['c'].sharding.spec,
104+
)
105+
106+
107+
if __name__ == '__main__':
108+
absltest.main()

0 commit comments

Comments
 (0)